adding calculate_legislator_scores.py summarize_bills.py and extract_bill_topics.py
This commit is contained in:
@@ -0,0 +1 @@
|
|||||||
|
"""Prompt benchmarking system for evaluating LLMs via vLLM."""
|
||||||
@@ -0,0 +1,574 @@
|
|||||||
|
"""Calculate legislator topic scores from bill topics and roll-call votes."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Annotated, Sequence
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from sqlalchemy import (
|
||||||
|
ColumnElement,
|
||||||
|
Integer,
|
||||||
|
Select,
|
||||||
|
and_,
|
||||||
|
case,
|
||||||
|
cast,
|
||||||
|
delete,
|
||||||
|
extract,
|
||||||
|
func,
|
||||||
|
or_,
|
||||||
|
select,
|
||||||
|
tuple_,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from pipelines.congress_vote_context import create_score_run, finalize_score_run
|
||||||
|
from pipelines.orm.common import get_postgres_engine
|
||||||
|
from pipelines.orm.data_science_dev.congress import (
|
||||||
|
BillTopic,
|
||||||
|
BillTopicPosition,
|
||||||
|
LegislatorScore,
|
||||||
|
SubjectType,
|
||||||
|
Vote,
|
||||||
|
VoteClassification,
|
||||||
|
VoteEffect,
|
||||||
|
VoteMeasureLink,
|
||||||
|
VoteMeasureRole,
|
||||||
|
VotePositionMeaning,
|
||||||
|
VoteRelationship,
|
||||||
|
VoteRecord,
|
||||||
|
)
|
||||||
|
from pipelines.pipelines.jobs.extract_bill_topics import normalize_topic_label
|
||||||
|
from pipelines.web.scoring import (
|
||||||
|
OPPOSE_POSITIONS,
|
||||||
|
SUPPORT_POSITIONS,
|
||||||
|
normalized_position_expression,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DELETE_BATCH_SIZE = 5_000
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ScoreDiagnostics:
|
||||||
|
"""Counts for the input stages required to calculate legislator scores."""
|
||||||
|
|
||||||
|
bill_topic_rows: int
|
||||||
|
linked_vote_rows: int
|
||||||
|
vote_record_rows: int
|
||||||
|
topic_vote_links: int
|
||||||
|
scorable_vote_records: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LegislatorScoreInput:
|
||||||
|
"""One aggregated score ready to store in legislator_score."""
|
||||||
|
|
||||||
|
legislator_id: int
|
||||||
|
year: int
|
||||||
|
topic: str
|
||||||
|
score: float
|
||||||
|
supportive: int
|
||||||
|
opposed: int
|
||||||
|
|
||||||
|
|
||||||
|
def create_legislator_score_query(
|
||||||
|
*,
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: Sequence[int] | None = None,
|
||||||
|
topics: Sequence[str] | None = None,
|
||||||
|
) -> Select:
|
||||||
|
"""Build the aggregate score query from extracted bill topics and vote records."""
|
||||||
|
normalized_vote = normalized_position_expression(VoteRecord.position)
|
||||||
|
supportive_vote = _supportive_vote_expression(normalized_vote)
|
||||||
|
opposed_vote = _opposed_vote_expression(normalized_vote)
|
||||||
|
supportive_count = func.sum(supportive_vote)
|
||||||
|
opposed_count = func.sum(opposed_vote)
|
||||||
|
total_count = supportive_count + opposed_count
|
||||||
|
vote_year = cast(extract("year", Vote.vote_date), Integer)
|
||||||
|
score = (100.0 * supportive_count / func.nullif(total_count, 0)).label("score")
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(
|
||||||
|
VoteRecord.legislator_id.label("legislator_id"),
|
||||||
|
vote_year.label("year"),
|
||||||
|
BillTopic.topic.label("topic"),
|
||||||
|
score,
|
||||||
|
supportive_count.label("supportive"),
|
||||||
|
opposed_count.label("opposed"),
|
||||||
|
total_count.label("total"),
|
||||||
|
)
|
||||||
|
.select_from(VoteRecord)
|
||||||
|
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
|
||||||
|
.where(
|
||||||
|
*_eligible_vote_filters(),
|
||||||
|
_is_scorable_position(normalized_vote),
|
||||||
|
)
|
||||||
|
.group_by(VoteRecord.legislator_id, vote_year, BillTopic.topic)
|
||||||
|
.having(total_count > 0)
|
||||||
|
.order_by(VoteRecord.legislator_id, vote_year, BillTopic.topic)
|
||||||
|
)
|
||||||
|
if congress is not None:
|
||||||
|
stmt = stmt.where(Vote.congress == congress)
|
||||||
|
if bill_ids:
|
||||||
|
stmt = stmt.where(VoteMeasureLink.measure_id.in_(list(bill_ids)))
|
||||||
|
|
||||||
|
normalized_topics = _normalize_topics(topics)
|
||||||
|
if normalized_topics:
|
||||||
|
stmt = stmt.where(BillTopic.topic.in_(normalized_topics))
|
||||||
|
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
|
def collect_legislator_scores(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: Sequence[int] | None = None,
|
||||||
|
topics: Sequence[str] | None = None,
|
||||||
|
) -> list[LegislatorScoreInput]:
|
||||||
|
"""Run the aggregate query and return score rows."""
|
||||||
|
rows = session.execute(
|
||||||
|
create_legislator_score_query(
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
topics=topics,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
LegislatorScoreInput(
|
||||||
|
legislator_id=int(row.legislator_id),
|
||||||
|
year=int(row.year),
|
||||||
|
topic=str(row.topic),
|
||||||
|
score=float(row.score),
|
||||||
|
supportive=int(row.supportive),
|
||||||
|
opposed=int(row.opposed),
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
if row.score is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def collect_score_diagnostics(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: Sequence[int] | None = None,
|
||||||
|
topics: Sequence[str] | None = None,
|
||||||
|
) -> ScoreDiagnostics:
|
||||||
|
"""Count score pipeline inputs for explaining empty score runs."""
|
||||||
|
normalized_topics = _normalize_topics(topics)
|
||||||
|
vote_filters = _vote_scope_filters(congress=congress, bill_ids=bill_ids)
|
||||||
|
topic_filters = _topic_scope_filters(bill_ids=bill_ids, topics=normalized_topics)
|
||||||
|
normalized_vote = normalized_position_expression(VoteRecord.position)
|
||||||
|
eligible_vote_filters = _eligible_vote_filters()
|
||||||
|
|
||||||
|
bill_topic_rows = session.scalar(
|
||||||
|
select(func.count(BillTopic.id)).where(*topic_filters)
|
||||||
|
)
|
||||||
|
linked_vote_rows = session.scalar(
|
||||||
|
select(func.count(func.distinct(Vote.id)))
|
||||||
|
.select_from(Vote)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.where(*vote_filters, *eligible_vote_filters)
|
||||||
|
)
|
||||||
|
vote_record_rows = session.scalar(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(VoteRecord)
|
||||||
|
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.where(*vote_filters, *eligible_vote_filters)
|
||||||
|
)
|
||||||
|
topic_vote_links = session.scalar(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(VoteRecord)
|
||||||
|
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
|
||||||
|
.where(*vote_filters, *topic_filters, *eligible_vote_filters)
|
||||||
|
)
|
||||||
|
scorable_vote_records = session.scalar(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(VoteRecord)
|
||||||
|
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
|
||||||
|
.where(
|
||||||
|
*vote_filters,
|
||||||
|
*topic_filters,
|
||||||
|
*eligible_vote_filters,
|
||||||
|
_is_scorable_position(normalized_vote),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return ScoreDiagnostics(
|
||||||
|
bill_topic_rows=bill_topic_rows or 0,
|
||||||
|
linked_vote_rows=linked_vote_rows or 0,
|
||||||
|
vote_record_rows=vote_record_rows or 0,
|
||||||
|
topic_vote_links=topic_vote_links or 0,
|
||||||
|
scorable_vote_records=scorable_vote_records or 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def store_legislator_scores(
|
||||||
|
session: Session,
|
||||||
|
rows: Sequence[LegislatorScoreInput],
|
||||||
|
*,
|
||||||
|
score_run_id: int | None,
|
||||||
|
replace_all: bool = False,
|
||||||
|
) -> int:
|
||||||
|
"""Replace matching score rows and insert the newly calculated scores."""
|
||||||
|
if replace_all:
|
||||||
|
session.execute(delete(LegislatorScore))
|
||||||
|
elif rows:
|
||||||
|
keys = [
|
||||||
|
(row.legislator_id, row.year, row.topic)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
for key_batch in _batched(keys, DELETE_BATCH_SIZE):
|
||||||
|
session.execute(
|
||||||
|
delete(LegislatorScore).where(
|
||||||
|
tuple_(
|
||||||
|
LegislatorScore.legislator_id,
|
||||||
|
LegislatorScore.year,
|
||||||
|
LegislatorScore.topic,
|
||||||
|
).in_(key_batch)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
LegislatorScore(
|
||||||
|
legislator_id=row.legislator_id,
|
||||||
|
year=row.year,
|
||||||
|
topic=row.topic,
|
||||||
|
score=row.score,
|
||||||
|
score_run_id=score_run_id,
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return len(rows)
|
||||||
|
|
||||||
|
|
||||||
|
def _supportive_vote_expression(
|
||||||
|
normalized_vote: ColumnElement[str | None],
|
||||||
|
) -> ColumnElement[int]:
|
||||||
|
supports_text = _position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT)
|
||||||
|
opposes_text = _position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT)
|
||||||
|
return case(
|
||||||
|
(
|
||||||
|
and_(
|
||||||
|
BillTopic.support_position == BillTopicPosition.FOR,
|
||||||
|
supports_text,
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
and_(
|
||||||
|
BillTopic.support_position == BillTopicPosition.AGAINST,
|
||||||
|
opposes_text,
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
else_=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _opposed_vote_expression(
|
||||||
|
normalized_vote: ColumnElement[str | None],
|
||||||
|
) -> ColumnElement[int]:
|
||||||
|
supports_text = _position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT)
|
||||||
|
opposes_text = _position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT)
|
||||||
|
return case(
|
||||||
|
(
|
||||||
|
and_(
|
||||||
|
BillTopic.support_position == BillTopicPosition.FOR,
|
||||||
|
opposes_text,
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
and_(
|
||||||
|
BillTopic.support_position == BillTopicPosition.AGAINST,
|
||||||
|
supports_text,
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
else_=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _position_matches_effect(
|
||||||
|
normalized_vote: ColumnElement[str | None],
|
||||||
|
effect: VoteEffect,
|
||||||
|
) -> ColumnElement[bool]:
|
||||||
|
return or_(
|
||||||
|
and_(
|
||||||
|
normalized_vote.in_(sorted(SUPPORT_POSITIONS)),
|
||||||
|
VotePositionMeaning.yea_effect == effect,
|
||||||
|
),
|
||||||
|
and_(
|
||||||
|
normalized_vote.in_(sorted(OPPOSE_POSITIONS)),
|
||||||
|
VotePositionMeaning.nay_effect == effect,
|
||||||
|
),
|
||||||
|
and_(
|
||||||
|
normalized_vote == "present",
|
||||||
|
VotePositionMeaning.present_effect == effect,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_scorable_position(normalized_vote: ColumnElement[str | None]) -> ColumnElement[bool]:
|
||||||
|
return or_(
|
||||||
|
_position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT),
|
||||||
|
_position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_topics(topics: Sequence[str] | None) -> list[str]:
|
||||||
|
normalized: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for topic in topics or []:
|
||||||
|
value = normalize_topic_label(topic)
|
||||||
|
if value and value not in seen:
|
||||||
|
normalized.append(value)
|
||||||
|
seen.add(value)
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def _batched[T](items: Sequence[T], batch_size: int) -> list[Sequence[T]]:
|
||||||
|
return [
|
||||||
|
items[index : index + batch_size]
|
||||||
|
for index in range(0, len(items), batch_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _vote_scope_filters(
|
||||||
|
*,
|
||||||
|
congress: int | None,
|
||||||
|
bill_ids: Sequence[int] | None,
|
||||||
|
) -> list[ColumnElement[bool]]:
|
||||||
|
filters: list[ColumnElement[bool]] = []
|
||||||
|
if congress is not None:
|
||||||
|
filters.append(Vote.congress == congress)
|
||||||
|
if bill_ids:
|
||||||
|
filters.append(VoteMeasureLink.measure_id.in_(list(bill_ids)))
|
||||||
|
return filters
|
||||||
|
|
||||||
|
|
||||||
|
def _topic_scope_filters(
|
||||||
|
*,
|
||||||
|
bill_ids: Sequence[int] | None,
|
||||||
|
topics: Sequence[str],
|
||||||
|
) -> list[ColumnElement[bool]]:
|
||||||
|
filters: list[ColumnElement[bool]] = []
|
||||||
|
if bill_ids:
|
||||||
|
filters.append(BillTopic.bill_id.in_(list(bill_ids)))
|
||||||
|
if topics:
|
||||||
|
filters.append(BillTopic.topic.in_(list(topics)))
|
||||||
|
return filters
|
||||||
|
|
||||||
|
|
||||||
|
def _has_score_scope(
|
||||||
|
*,
|
||||||
|
congress: int | None,
|
||||||
|
bill_ids: Sequence[int] | None,
|
||||||
|
topics: Sequence[str] | None,
|
||||||
|
) -> bool:
|
||||||
|
return congress is not None or bool(bill_ids) or bool(topics)
|
||||||
|
|
||||||
|
|
||||||
|
def _eligible_vote_filters() -> list[ColumnElement[bool]]:
|
||||||
|
return [
|
||||||
|
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||||
|
VoteClassification.vote_relationship == VoteRelationship.DIRECT_TEXT_VOTE,
|
||||||
|
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||||
|
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||||
|
VoteClassification.is_special_rule.is_(False),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
congress: Annotated[
|
||||||
|
int | None,
|
||||||
|
typer.Option(help="Only score votes from one Congress."),
|
||||||
|
] = None,
|
||||||
|
bill_ids: Annotated[
|
||||||
|
list[int] | None,
|
||||||
|
typer.Option(
|
||||||
|
"--bill-id",
|
||||||
|
help="Only score votes linked to one internal bill.id. Repeatable.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
topics: Annotated[
|
||||||
|
list[str] | None,
|
||||||
|
typer.Option("--topic", help="Only score one normalized topic. Repeatable."),
|
||||||
|
] = None,
|
||||||
|
replace_all: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(
|
||||||
|
help="Delete every existing legislator score before inserting. "
|
||||||
|
"Unfiltered runs do this automatically."
|
||||||
|
),
|
||||||
|
] = False,
|
||||||
|
dry_run: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Calculate scores without writing to the database."),
|
||||||
|
] = False,
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
|
||||||
|
diagnose: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Log input-stage counts even when rows are calculated."),
|
||||||
|
] = False,
|
||||||
|
) -> None:
|
||||||
|
"""CLI entrypoint for calculating and storing legislator topic scores."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||||
|
with Session(engine) as session:
|
||||||
|
rows = collect_legislator_scores(
|
||||||
|
session,
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
topics=topics,
|
||||||
|
)
|
||||||
|
logger.info("Calculated %d legislator topic score rows", len(rows))
|
||||||
|
if diagnose or not rows:
|
||||||
|
diagnostics = collect_score_diagnostics(
|
||||||
|
session,
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
topics=topics,
|
||||||
|
)
|
||||||
|
_log_diagnostics(diagnostics)
|
||||||
|
|
||||||
|
if dry_run:
|
||||||
|
session.rollback()
|
||||||
|
return
|
||||||
|
|
||||||
|
score_run = create_score_run(session)
|
||||||
|
should_replace_all = replace_all or not _has_score_scope(
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
topics=topics,
|
||||||
|
)
|
||||||
|
written = store_legislator_scores(
|
||||||
|
session,
|
||||||
|
rows,
|
||||||
|
score_run_id=score_run.id,
|
||||||
|
replace_all=should_replace_all,
|
||||||
|
)
|
||||||
|
included_vote_count = session.scalar(
|
||||||
|
select(func.count(func.distinct(Vote.id)))
|
||||||
|
.select_from(VoteRecord)
|
||||||
|
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
|
||||||
|
.where(
|
||||||
|
*_vote_scope_filters(congress=congress, bill_ids=bill_ids),
|
||||||
|
*_topic_scope_filters(bill_ids=bill_ids, topics=_normalize_topics(topics)),
|
||||||
|
*_eligible_vote_filters(),
|
||||||
|
_is_scorable_position(normalized_position_expression(VoteRecord.position)),
|
||||||
|
)
|
||||||
|
) or 0
|
||||||
|
total_scoped_votes = session.scalar(
|
||||||
|
select(func.count(func.distinct(Vote.id)))
|
||||||
|
.select_from(Vote)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.where(*_vote_scope_filters(congress=congress, bill_ids=bill_ids))
|
||||||
|
) or 0
|
||||||
|
finalize_score_run(
|
||||||
|
session,
|
||||||
|
score_run=score_run,
|
||||||
|
included_vote_count=included_vote_count,
|
||||||
|
excluded_vote_count=max(total_scoped_votes - included_vote_count, 0),
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
logger.info("Stored %d legislator topic score rows", written)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_diagnostics(diagnostics: ScoreDiagnostics) -> None:
|
||||||
|
logger.info(
|
||||||
|
"Score input diagnostics: bill_topic_rows=%d linked_vote_rows=%d "
|
||||||
|
"vote_record_rows=%d topic_vote_links=%d scorable_vote_records=%d",
|
||||||
|
diagnostics.bill_topic_rows,
|
||||||
|
diagnostics.linked_vote_rows,
|
||||||
|
diagnostics.vote_record_rows,
|
||||||
|
diagnostics.topic_vote_links,
|
||||||
|
diagnostics.scorable_vote_records,
|
||||||
|
)
|
||||||
|
if diagnostics.bill_topic_rows == 0:
|
||||||
|
logger.warning(
|
||||||
|
"No extracted bill topics matched the score scope. Run "
|
||||||
|
"pipelines.tools.extract_bill_topics after bill summarization."
|
||||||
|
)
|
||||||
|
elif diagnostics.linked_vote_rows == 0:
|
||||||
|
logger.warning("No direct substantive text votes matched the score scope.")
|
||||||
|
elif diagnostics.vote_record_rows == 0:
|
||||||
|
logger.warning("No individual vote records matched the score scope.")
|
||||||
|
elif diagnostics.topic_vote_links == 0:
|
||||||
|
logger.warning(
|
||||||
|
"Bill topics exist, but none are attached to bills that have eligible scored votes."
|
||||||
|
)
|
||||||
|
elif diagnostics.scorable_vote_records == 0:
|
||||||
|
logger.warning(
|
||||||
|
"Topic-vote links exist, but no joined vote records had Yea/Aye/Yes/Nay/No positions."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
typer.run(main)
|
||||||
@@ -0,0 +1,682 @@
|
|||||||
|
"""Extract bill topics from bill text using a configurable topic catalog."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated, Any, Sequence
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import typer
|
||||||
|
from sqlalchemy import ColumnElement, Select, delete, exists, func, select
|
||||||
|
from sqlalchemy.orm import Session, selectinload
|
||||||
|
|
||||||
|
from pipelines.config import OpenAIConfig, get_config_dir, get_openai_config
|
||||||
|
from pipelines.orm.common import get_postgres_engine
|
||||||
|
from pipelines.orm.data_science_dev.congress import (
|
||||||
|
Bill,
|
||||||
|
BillText,
|
||||||
|
BillTopic,
|
||||||
|
BillTopicPosition,
|
||||||
|
SubjectType,
|
||||||
|
VoteClassification,
|
||||||
|
VoteRelationship,
|
||||||
|
VoteTextTarget,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
|
||||||
|
OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
|
||||||
|
REQUEST_TIMEOUT_SECONDS = 60
|
||||||
|
DEFAULT_TOPICS_PATH = get_config_dir() / "congressional_issues_comprehensive.json"
|
||||||
|
|
||||||
|
|
||||||
|
class TopicExtractionError(RuntimeError):
|
||||||
|
"""Raised when a topic extraction request or response is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TopicCatalog:
|
||||||
|
"""Loaded topic catalog with categories for prompting and flat candidates."""
|
||||||
|
|
||||||
|
topics_by_category: dict[str, list[str]]
|
||||||
|
candidate_topics: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TopicExtractionDiagnostics:
|
||||||
|
"""Counts for the bill summary inputs needed by topic extraction."""
|
||||||
|
|
||||||
|
bill_rows: int
|
||||||
|
bill_text_rows: int
|
||||||
|
summarized_bill_text_rows: int
|
||||||
|
bills_with_summaries: int
|
||||||
|
bill_topic_rows: int
|
||||||
|
selected_bills: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ExtractedBillTopic:
|
||||||
|
"""One extracted bill topic and yes-vote stance."""
|
||||||
|
|
||||||
|
topic: str
|
||||||
|
support_position: BillTopicPosition
|
||||||
|
confidence: float | None = None
|
||||||
|
evidence: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _select_bill_text_for_topic_extraction(bill: Bill) -> BillText | None:
|
||||||
|
"""Pick one summarized bill_text row from the already-loaded relationship."""
|
||||||
|
for bill_text in bill.bill_texts:
|
||||||
|
if bill_text.summary and bill_text.summary.strip():
|
||||||
|
return bill_text
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_topic_label(value: str) -> str:
|
||||||
|
"""Normalize a topic label for storage, comparison, and de-duping."""
|
||||||
|
normalized = value.strip().strip("\"'")
|
||||||
|
normalized = normalized.strip().rstrip(".").strip()
|
||||||
|
return re.sub(r"\s+", " ", normalized).lower()
|
||||||
|
|
||||||
|
|
||||||
|
def load_topic_catalog(path: Path | None = None) -> TopicCatalog:
|
||||||
|
"""Load, validate, normalize, and flatten the bill topic catalog."""
|
||||||
|
topics_path = path or DEFAULT_TOPICS_PATH
|
||||||
|
try:
|
||||||
|
raw = json.loads(topics_path.read_text())
|
||||||
|
except FileNotFoundError as exc:
|
||||||
|
msg = f"Topic catalog not found: {topics_path}"
|
||||||
|
raise TopicExtractionError(msg) from exc
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
msg = f"Topic catalog is not valid JSON: {topics_path}: {exc}"
|
||||||
|
raise TopicExtractionError(msg) from exc
|
||||||
|
|
||||||
|
if not isinstance(raw, dict):
|
||||||
|
msg = "Topic catalog root must be an object mapping category names to lists"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
topics_by_category: dict[str, list[str]] = {}
|
||||||
|
candidate_topics: list[str] = []
|
||||||
|
seen_topics: set[str] = set()
|
||||||
|
|
||||||
|
for category, topics in raw.items():
|
||||||
|
if not isinstance(category, str) or not category.strip():
|
||||||
|
msg = "Topic catalog category names must be non-empty strings"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
if not isinstance(topics, list):
|
||||||
|
msg = f"Topic catalog category {category!r} must contain a list"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
normalized_topics: list[str] = []
|
||||||
|
for topic in topics:
|
||||||
|
if not isinstance(topic, str):
|
||||||
|
msg = f"Topic catalog category {category!r} contains a non-string topic"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
normalized_topic = normalize_topic_label(topic)
|
||||||
|
if not normalized_topic:
|
||||||
|
msg = f"Topic catalog category {category!r} contains a blank topic"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
if normalized_topic in seen_topics:
|
||||||
|
continue
|
||||||
|
seen_topics.add(normalized_topic)
|
||||||
|
normalized_topics.append(normalized_topic)
|
||||||
|
candidate_topics.append(normalized_topic)
|
||||||
|
|
||||||
|
topics_by_category[category.strip()] = normalized_topics
|
||||||
|
|
||||||
|
return TopicCatalog(
|
||||||
|
topics_by_category=topics_by_category,
|
||||||
|
candidate_topics=candidate_topics,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_topic_extraction_messages(
|
||||||
|
*,
|
||||||
|
bill: Bill,
|
||||||
|
bill_text: str,
|
||||||
|
candidate_topics: Sequence[str],
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
"""Build GPT messages for extracting a bill's scored topics."""
|
||||||
|
normalized_candidates = [normalize_topic_label(topic) for topic in candidate_topics]
|
||||||
|
candidate_list = "\n".join(f"- {topic}" for topic in normalized_candidates)
|
||||||
|
metadata = "\n".join(
|
||||||
|
(
|
||||||
|
f"Congress: {bill.congress}",
|
||||||
|
f"Bill: {bill.bill_type} {bill.number}",
|
||||||
|
f"Title: {bill.title_short or bill.title or bill.official_title or ''}",
|
||||||
|
f"Top subject term: {bill.subjects_top_term or ''}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
system_prompt = (
|
||||||
|
"You extract policy topics from U.S. congressional bills.\n"
|
||||||
|
'For each selected topic, decide whether a Yes/Yea vote on the bill is "for" or "against" that topic.\n'
|
||||||
|
'Use "support_position": "for" when a Yes/Yea vote advances or supports the topic.\n'
|
||||||
|
'Use "support_position": "against" when a Yes/Yea vote restricts, repeals, blocks, or opposes the topic.\n'
|
||||||
|
"Select only topics from the provided candidate topic list.\n"
|
||||||
|
"Omit topics that are not materially addressed by the bill.\n"
|
||||||
|
"Return strict JSON only, with this shape:\n"
|
||||||
|
'{"topics":[{"topic":"candidate topic","support_position":"for","confidence":0.0,"evidence":"short reason"}]}'
|
||||||
|
)
|
||||||
|
user_prompt = "\n\n".join(
|
||||||
|
(
|
||||||
|
"BILL METADATA:",
|
||||||
|
metadata,
|
||||||
|
"CANDIDATE TOPICS:",
|
||||||
|
candidate_list,
|
||||||
|
"BILL TEXT:",
|
||||||
|
bill_text,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": user_prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def call_openai_topic_extraction(
|
||||||
|
*,
|
||||||
|
openai_config: OpenAIConfig,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
) -> str:
|
||||||
|
"""Call GPT and return the assistant message content."""
|
||||||
|
|
||||||
|
response = httpx.post(
|
||||||
|
openai_config.openai_chat_completions_url,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {openai_config.api_key}",
|
||||||
|
"OpenAI-Project": openai_config.openai_project_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"model": "gpt-5.4-mini",
|
||||||
|
"messages": messages,
|
||||||
|
},
|
||||||
|
timeout=openai_config.timeout_seconds,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return extract_message_content(response.json())
|
||||||
|
|
||||||
|
|
||||||
|
def extract_message_content(data: dict[str, Any]) -> str:
|
||||||
|
"""Extract message content from a chat-completions response body."""
|
||||||
|
choices = data.get("choices")
|
||||||
|
if not isinstance(choices, list) or not choices:
|
||||||
|
msg = "Chat completion response did not contain choices"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
first = choices[0]
|
||||||
|
if not isinstance(first, dict):
|
||||||
|
msg = "Chat completion choice must be an object"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
message = first.get("message")
|
||||||
|
if isinstance(message, dict) and isinstance(message.get("content"), str):
|
||||||
|
return message["content"]
|
||||||
|
if isinstance(first.get("text"), str):
|
||||||
|
return first["text"]
|
||||||
|
|
||||||
|
msg = "Chat completion response did not contain message content"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_topic_extraction_response(response_text: str) -> list[ExtractedBillTopic]:
|
||||||
|
"""Parse, normalize, validate, and de-dupe a topic extraction response."""
|
||||||
|
payload = _load_json_response(response_text)
|
||||||
|
topics = payload.get("topics")
|
||||||
|
if not isinstance(topics, list):
|
||||||
|
msg = "Topic extraction response must contain a topics list"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
deduped: dict[tuple[str, BillTopicPosition], ExtractedBillTopic] = {}
|
||||||
|
for item in topics:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
msg = "Topic extraction response topics must be objects"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
raw_topic = _extract_topic_label(item)
|
||||||
|
topic = normalize_topic_label(raw_topic)
|
||||||
|
if not topic:
|
||||||
|
msg = "Topic extraction response topic must not be blank"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
raw_position = item.get("support_position")
|
||||||
|
try:
|
||||||
|
support_position = BillTopicPosition(raw_position)
|
||||||
|
except ValueError as exc:
|
||||||
|
msg = f"Invalid support_position: {raw_position!r}"
|
||||||
|
raise TopicExtractionError(msg) from exc
|
||||||
|
|
||||||
|
confidence = _parse_confidence(item.get("confidence"))
|
||||||
|
evidence = item.get("evidence")
|
||||||
|
if evidence is not None and not isinstance(evidence, str):
|
||||||
|
evidence = str(evidence)
|
||||||
|
|
||||||
|
extracted = ExtractedBillTopic(
|
||||||
|
topic=topic,
|
||||||
|
support_position=support_position,
|
||||||
|
confidence=confidence,
|
||||||
|
evidence=evidence,
|
||||||
|
)
|
||||||
|
key = (topic, support_position)
|
||||||
|
existing = deduped.get(key)
|
||||||
|
if existing is None or _confidence_rank(extracted) > _confidence_rank(existing):
|
||||||
|
deduped[key] = extracted
|
||||||
|
|
||||||
|
return list(deduped.values())
|
||||||
|
|
||||||
|
|
||||||
|
def extract_topics_for_bill_text(
|
||||||
|
*,
|
||||||
|
openai_config: OpenAIConfig,
|
||||||
|
bill: Bill,
|
||||||
|
text: str,
|
||||||
|
candidate_topics: Sequence[str],
|
||||||
|
) -> list[ExtractedBillTopic]:
|
||||||
|
"""Extract accepted catalog topics for a bill text string."""
|
||||||
|
normalized_candidates = {normalize_topic_label(topic) for topic in candidate_topics}
|
||||||
|
messages = build_topic_extraction_messages(
|
||||||
|
bill=bill,
|
||||||
|
bill_text=text,
|
||||||
|
candidate_topics=sorted(normalized_candidates),
|
||||||
|
)
|
||||||
|
response_text = call_openai_topic_extraction(
|
||||||
|
openai_config=openai_config,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
extracted_topics = parse_topic_extraction_response(response_text)
|
||||||
|
return [topic for topic in extracted_topics if topic.topic in normalized_candidates]
|
||||||
|
|
||||||
|
|
||||||
|
def store_bill_topic_result(
|
||||||
|
*,
|
||||||
|
session: Session,
|
||||||
|
bill: Bill,
|
||||||
|
topics: Sequence[ExtractedBillTopic],
|
||||||
|
replace_existing: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Store extracted topics for one bill."""
|
||||||
|
if replace_existing:
|
||||||
|
session.execute(delete(BillTopic).where(BillTopic.bill_id == bill.id))
|
||||||
|
|
||||||
|
for topic in topics:
|
||||||
|
session.add(
|
||||||
|
BillTopic(
|
||||||
|
bill_id=bill.id,
|
||||||
|
topic=normalize_topic_label(topic.topic),
|
||||||
|
support_position=topic.support_position,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_select_bills_for_topic_extraction(
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: list[int] | None = None,
|
||||||
|
bill_text_ids: list[int] | None = None,
|
||||||
|
with_votes_only: bool = False,
|
||||||
|
force: bool = False,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> Select[tuple[Bill]]:
|
||||||
|
"""Select bill rows that have summarized bill_text rows for topic extraction."""
|
||||||
|
has_summary = (BillText.summary.is_not(None), BillText.summary != "")
|
||||||
|
summarized_text_filters: list[ColumnElement[bool]] = [
|
||||||
|
BillText.bill_id == Bill.id,
|
||||||
|
*has_summary,
|
||||||
|
]
|
||||||
|
if with_votes_only:
|
||||||
|
summarized_text_filters.append(
|
||||||
|
exists(
|
||||||
|
select(VoteTextTarget.vote_id)
|
||||||
|
.join(
|
||||||
|
VoteClassification,
|
||||||
|
VoteClassification.vote_id == VoteTextTarget.vote_id,
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
VoteTextTarget.voted_text_version_id == BillText.id,
|
||||||
|
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||||
|
VoteClassification.vote_relationship
|
||||||
|
== VoteRelationship.DIRECT_TEXT_VOTE,
|
||||||
|
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||||
|
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||||
|
VoteClassification.is_special_rule.is_(False),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
summarized_text_exists = exists(select(BillText.id).where(*summarized_text_filters))
|
||||||
|
stmt = (
|
||||||
|
select(Bill)
|
||||||
|
.where(summarized_text_exists)
|
||||||
|
.options(selectinload(Bill.bill_texts.and_(*summarized_text_filters[1:])))
|
||||||
|
.order_by(Bill.id)
|
||||||
|
)
|
||||||
|
if congress is not None:
|
||||||
|
stmt = stmt.where(Bill.congress == congress)
|
||||||
|
if bill_ids:
|
||||||
|
stmt = stmt.where(Bill.id.in_(bill_ids))
|
||||||
|
if bill_text_ids:
|
||||||
|
selected_text_exists = exists(
|
||||||
|
select(BillText.id).where(
|
||||||
|
BillText.bill_id == Bill.id,
|
||||||
|
BillText.id.in_(bill_text_ids),
|
||||||
|
*summarized_text_filters[1:],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
stmt = stmt.where(selected_text_exists)
|
||||||
|
if not force:
|
||||||
|
stmt = stmt.where(
|
||||||
|
~exists(select(BillTopic.id).where(BillTopic.bill_id == Bill.id))
|
||||||
|
)
|
||||||
|
if limit is not None:
|
||||||
|
stmt = stmt.limit(limit)
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
|
def collect_topic_extraction_diagnostics(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: list[int] | None = None,
|
||||||
|
bill_text_ids: list[int] | None = None,
|
||||||
|
with_votes_only: bool = False,
|
||||||
|
force: bool = False,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> TopicExtractionDiagnostics:
|
||||||
|
"""Count topic extraction inputs for explaining empty selections."""
|
||||||
|
bill_filters = []
|
||||||
|
bill_text_filters: list[ColumnElement[bool]] = []
|
||||||
|
if congress is not None:
|
||||||
|
bill_filters.append(Bill.congress == congress)
|
||||||
|
if bill_ids:
|
||||||
|
bill_filters.append(Bill.id.in_(bill_ids))
|
||||||
|
bill_text_filters.append(BillText.bill_id.in_(bill_ids))
|
||||||
|
if bill_text_ids:
|
||||||
|
bill_text_filters.append(BillText.id.in_(bill_text_ids))
|
||||||
|
if with_votes_only:
|
||||||
|
bill_text_filters.append(
|
||||||
|
exists(
|
||||||
|
select(VoteTextTarget.vote_id)
|
||||||
|
.join(
|
||||||
|
VoteClassification,
|
||||||
|
VoteClassification.vote_id == VoteTextTarget.vote_id,
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
VoteTextTarget.voted_text_version_id == BillText.id,
|
||||||
|
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||||
|
VoteClassification.vote_relationship
|
||||||
|
== VoteRelationship.DIRECT_TEXT_VOTE,
|
||||||
|
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||||
|
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||||
|
VoteClassification.is_special_rule.is_(False),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
has_summary = (BillText.summary.is_not(None), BillText.summary != "")
|
||||||
|
summary_filters = [*bill_text_filters, *has_summary]
|
||||||
|
|
||||||
|
bills_with_summaries = session.scalar(
|
||||||
|
select(func.count(func.distinct(Bill.id)))
|
||||||
|
.select_from(Bill)
|
||||||
|
.join(BillText, BillText.bill_id == Bill.id)
|
||||||
|
.where(*bill_filters, *summary_filters)
|
||||||
|
)
|
||||||
|
selected_bills = session.scalar(
|
||||||
|
select(func.count()).select_from(
|
||||||
|
create_select_bills_for_topic_extraction(
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
bill_text_ids=bill_text_ids,
|
||||||
|
with_votes_only=with_votes_only,
|
||||||
|
force=force,
|
||||||
|
limit=limit,
|
||||||
|
).subquery()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return TopicExtractionDiagnostics(
|
||||||
|
bill_rows=session.scalar(select(func.count(Bill.id)).where(*bill_filters)) or 0,
|
||||||
|
bill_text_rows=_count_bill_texts(
|
||||||
|
session,
|
||||||
|
bill_filters=bill_filters,
|
||||||
|
bill_text_filters=bill_text_filters,
|
||||||
|
),
|
||||||
|
summarized_bill_text_rows=_count_bill_texts(
|
||||||
|
session,
|
||||||
|
bill_filters=bill_filters,
|
||||||
|
bill_text_filters=summary_filters,
|
||||||
|
),
|
||||||
|
bills_with_summaries=bills_with_summaries or 0,
|
||||||
|
bill_topic_rows=session.scalar(select(func.count(BillTopic.id))) or 0,
|
||||||
|
selected_bills=selected_bills or 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_json_response(response_text: str) -> dict[str, Any]:
|
||||||
|
text = response_text.strip()
|
||||||
|
fenced = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.DOTALL)
|
||||||
|
if fenced:
|
||||||
|
text = fenced.group(1).strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = json.loads(text)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
msg = f"Topic extraction response is not valid JSON: {exc}"
|
||||||
|
raise TopicExtractionError(msg) from exc
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
msg = "Topic extraction response must be a JSON object"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_confidence(raw: Any) -> float | None:
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return float(raw)
|
||||||
|
except (TypeError, ValueError) as exc:
|
||||||
|
msg = f"Invalid confidence: {raw!r}"
|
||||||
|
raise TopicExtractionError(msg) from exc
|
||||||
|
|
||||||
|
|
||||||
|
def _confidence_rank(topic: ExtractedBillTopic) -> tuple[int, float]:
|
||||||
|
if topic.confidence is None:
|
||||||
|
return (0, 0.0)
|
||||||
|
return (1, topic.confidence)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_topic_label(item: dict[str, Any]) -> str:
|
||||||
|
raw_topic = item.get("topic")
|
||||||
|
if isinstance(raw_topic, str):
|
||||||
|
return raw_topic
|
||||||
|
if isinstance(raw_topic, dict):
|
||||||
|
for key in ("topic", "label", "name", "title"):
|
||||||
|
value = raw_topic.get(key)
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value
|
||||||
|
|
||||||
|
msg = "Topic extraction response topic must be a string"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_bill_texts(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
bill_filters: Sequence[ColumnElement[bool]],
|
||||||
|
bill_text_filters: Sequence[ColumnElement[bool]],
|
||||||
|
) -> int:
|
||||||
|
stmt = select(func.count(BillText.id))
|
||||||
|
if bill_filters:
|
||||||
|
stmt = stmt.join(Bill, Bill.id == BillText.bill_id).where(*bill_filters)
|
||||||
|
return session.scalar(stmt.where(*bill_text_filters)) or 0
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
topics_path: Annotated[
|
||||||
|
Path, typer.Option(help="Path to congressional issue topic JSON.")
|
||||||
|
] = DEFAULT_TOPICS_PATH,
|
||||||
|
congress: Annotated[
|
||||||
|
int | None, typer.Option(help="Only process one Congress.")
|
||||||
|
] = None,
|
||||||
|
bill_ids: Annotated[
|
||||||
|
list[int] | None,
|
||||||
|
typer.Option(
|
||||||
|
"--bill-id",
|
||||||
|
help="Only process one internal bill.id. Repeat for multiple bills.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
bill_text_ids: Annotated[
|
||||||
|
list[int] | None,
|
||||||
|
typer.Option(
|
||||||
|
"--bill-text-id",
|
||||||
|
help="Only process one internal bill_text.id. Repeat for multiple rows.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
with_votes_only: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(
|
||||||
|
"--with-votes-only",
|
||||||
|
help="Only process summarized bill_text rows linked to at least one vote.",
|
||||||
|
),
|
||||||
|
] = True,
|
||||||
|
limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
|
||||||
|
force: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Regenerate topics for bills that already have topics."),
|
||||||
|
] = False,
|
||||||
|
dry_run: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Select bills and print diagnostics without calling OpenAI."),
|
||||||
|
] = False,
|
||||||
|
diagnose: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Log input-stage counts before processing."),
|
||||||
|
] = False,
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""CLI entrypoint for generating and storing bill topics."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
topic_catalog = load_topic_catalog(topics_path)
|
||||||
|
logger.info(
|
||||||
|
"Loaded %d candidate topics from %s",
|
||||||
|
len(topic_catalog.candidate_topics),
|
||||||
|
topics_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||||
|
with Session(engine) as session:
|
||||||
|
if diagnose or dry_run:
|
||||||
|
diagnostics = collect_topic_extraction_diagnostics(
|
||||||
|
session,
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
bill_text_ids=bill_text_ids,
|
||||||
|
with_votes_only=with_votes_only,
|
||||||
|
force=force,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
_log_topic_extraction_diagnostics(diagnostics)
|
||||||
|
if dry_run:
|
||||||
|
return
|
||||||
|
|
||||||
|
openai_config = get_openai_config()
|
||||||
|
|
||||||
|
stmt = create_select_bills_for_topic_extraction(
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
bill_text_ids=bill_text_ids,
|
||||||
|
with_votes_only=with_votes_only,
|
||||||
|
force=force,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
bills = session.scalars(stmt).all()
|
||||||
|
logger.info("Selected %d bills for topic extraction", len(bills))
|
||||||
|
|
||||||
|
written = 0
|
||||||
|
failed = 0
|
||||||
|
for index, bill in enumerate(bills, 1):
|
||||||
|
bill_text = _select_bill_text_for_topic_extraction(bill)
|
||||||
|
if bill_text is None:
|
||||||
|
logger.warning("Skipping bill id=%s: no usable summary", bill.id)
|
||||||
|
continue
|
||||||
|
summary = bill_text.summary.strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
extracted_topics = extract_topics_for_bill_text(
|
||||||
|
openai_config=openai_config,
|
||||||
|
bill=bill,
|
||||||
|
text=summary,
|
||||||
|
candidate_topics=topic_catalog.candidate_topics,
|
||||||
|
)
|
||||||
|
except (httpx.HTTPError, TopicExtractionError):
|
||||||
|
failed += 1
|
||||||
|
logger.exception(
|
||||||
|
"Skipping bill id=%s after topic extraction failure", bill.id
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
store_bill_topic_result(
|
||||||
|
session=session,
|
||||||
|
bill=bill,
|
||||||
|
topics=extracted_topics,
|
||||||
|
replace_existing=True,
|
||||||
|
)
|
||||||
|
written += 1
|
||||||
|
if index % 100 == 0:
|
||||||
|
session.commit()
|
||||||
|
logger.info(
|
||||||
|
"Stored %d topics for bill id=%s",
|
||||||
|
len(extracted_topics),
|
||||||
|
bill.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
logger.info(
|
||||||
|
"Done: stored topic results for %d bills; failed %d bills",
|
||||||
|
written,
|
||||||
|
failed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_topic_extraction_diagnostics(
|
||||||
|
diagnostics: TopicExtractionDiagnostics,
|
||||||
|
) -> None:
|
||||||
|
logger.info(
|
||||||
|
"Topic extraction diagnostics: bill_rows=%d bill_text_rows=%d "
|
||||||
|
"summarized_bill_text_rows=%d bills_with_summaries=%d "
|
||||||
|
"bill_topic_rows=%d selected_bills=%d",
|
||||||
|
diagnostics.bill_rows,
|
||||||
|
diagnostics.bill_text_rows,
|
||||||
|
diagnostics.summarized_bill_text_rows,
|
||||||
|
diagnostics.bills_with_summaries,
|
||||||
|
diagnostics.bill_topic_rows,
|
||||||
|
diagnostics.selected_bills,
|
||||||
|
)
|
||||||
|
if diagnostics.bill_rows == 0:
|
||||||
|
logger.warning("No bills matched the topic extraction scope.")
|
||||||
|
elif diagnostics.bill_text_rows == 0:
|
||||||
|
logger.warning("No bill_text rows matched the topic extraction scope.")
|
||||||
|
elif diagnostics.summarized_bill_text_rows == 0:
|
||||||
|
logger.warning(
|
||||||
|
"No summarized bill_text rows matched the topic extraction scope. "
|
||||||
|
"Run pipelines.tools.summarize_bills first."
|
||||||
|
)
|
||||||
|
elif diagnostics.selected_bills == 0 and diagnostics.bill_topic_rows > 0:
|
||||||
|
logger.warning(
|
||||||
|
"No bills selected because matching bills already have topics. "
|
||||||
|
"Use --force to regenerate them."
|
||||||
|
)
|
||||||
|
elif diagnostics.selected_bills == 0:
|
||||||
|
logger.warning("No bills selected for topic extraction.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
typer.run(main)
|
||||||
@@ -0,0 +1,309 @@
|
|||||||
|
"""Summarize bill_text rows with GPT-5 and store results in the database."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import tomllib
|
||||||
|
from os import getenv
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import typer
|
||||||
|
from sqlalchemy import Select, exists, or_, select
|
||||||
|
from sqlalchemy.orm import Session, selectinload
|
||||||
|
|
||||||
|
from tiktoken import get_encoding
|
||||||
|
|
||||||
|
|
||||||
|
from pipelines.config import get_config_dir
|
||||||
|
from pipelines.orm.common import get_postgres_engine
|
||||||
|
from pipelines.orm.data_science_dev.congress import (
|
||||||
|
Bill,
|
||||||
|
BillText,
|
||||||
|
SubjectType,
|
||||||
|
VoteClassification,
|
||||||
|
VoteRelationship,
|
||||||
|
VoteTextTarget,
|
||||||
|
)
|
||||||
|
from pipelines.tools.bill_token_compression import compress_bill_text
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
|
||||||
|
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
|
||||||
|
REQUEST_TIMEOUT_SECONDS = 60
|
||||||
|
|
||||||
|
|
||||||
|
def load_summarization_prompts(
|
||||||
|
section: str = "summarization",
|
||||||
|
) -> dict[str, str]:
|
||||||
|
summarization_prompts = get_config_dir() / "prompts" / "summarization_prompts.toml"
|
||||||
|
|
||||||
|
return tomllib.loads(summarization_prompts.read_text())[section]
|
||||||
|
|
||||||
|
|
||||||
|
class BillSummaryError(RuntimeError):
|
||||||
|
"""Raised when a bill summary request or response is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
def call_openai_summary(
|
||||||
|
*,
|
||||||
|
model: str,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
) -> str:
|
||||||
|
"""Call GPT and return the assistant message content."""
|
||||||
|
api_key = getenv("CLOSEDAI_TOKEN")
|
||||||
|
if not api_key:
|
||||||
|
msg = "CLOSEDAI_TOKEN is required"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
response = httpx.post(
|
||||||
|
OPENAI_CHAT_COMPLETIONS_URL,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"OpenAI-Project": OPENAI_PROJECT_ID,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
},
|
||||||
|
timeout=REQUEST_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
logger.info(f"{response.text=}")
|
||||||
|
response.raise_for_status()
|
||||||
|
return extract_message_content(response.json())
|
||||||
|
|
||||||
|
|
||||||
|
def build_bill_summary_messages(
|
||||||
|
*,
|
||||||
|
bill_text: BillText,
|
||||||
|
summarization_prompts: dict[str, str],
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
"""Build the GPT prompt messages plus compressed text and user prompt."""
|
||||||
|
if not bill_text.text_content:
|
||||||
|
msg = f"bill_text id={bill_text.id} has no text_content"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
compressed_text = compress_bill_text(bill_text.text_content)
|
||||||
|
if not compressed_text:
|
||||||
|
msg = f"bill_text id={bill_text.id} has no summarizable text_content"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
user_prompt = summarization_prompts["user_template"].format(
|
||||||
|
text_content=compressed_text
|
||||||
|
)
|
||||||
|
|
||||||
|
user_prompt_tokens = len(get_encoding("o200k_base").encode(user_prompt))
|
||||||
|
logger.info(f"{user_prompt_tokens=}")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": summarization_prompts["system_prompt"]},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": user_prompt,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
return messages, user_prompt_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def summarize_bill_text(
|
||||||
|
*,
|
||||||
|
model: str,
|
||||||
|
bill_text: BillText,
|
||||||
|
summarization_prompts: dict[str, str],
|
||||||
|
) -> str:
|
||||||
|
"""Generate and return a summary for one bill_text row."""
|
||||||
|
messages, user_prompt_tokens = build_bill_summary_messages(
|
||||||
|
bill_text=bill_text,
|
||||||
|
summarization_prompts=summarization_prompts,
|
||||||
|
)
|
||||||
|
# This may only be for gpt-5.4 mini I need to read the docs
|
||||||
|
if user_prompt_tokens > 272000:
|
||||||
|
msg = f"Compressed bill_text id={bill_text.id} is too long for summarization ({user_prompt_tokens} tokens)"
|
||||||
|
logger.warning(msg)
|
||||||
|
return None
|
||||||
|
|
||||||
|
summary = call_openai_summary(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
).strip()
|
||||||
|
if not summary:
|
||||||
|
msg = f"Model returned an empty summary for bill_text id={bill_text.id}"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
def store_bill_summary_result(
|
||||||
|
*,
|
||||||
|
bill_text: BillText,
|
||||||
|
summary: str,
|
||||||
|
model: str,
|
||||||
|
) -> None:
|
||||||
|
"""Store a generated summary and the prompt/model metadata that produced it."""
|
||||||
|
bill_text.summary = summary
|
||||||
|
bill_text.summarization_model = model
|
||||||
|
bill_text.summarization_system_prompt_version = "v1.2"
|
||||||
|
bill_text.summarization_user_prompt_version = "v1"
|
||||||
|
|
||||||
|
|
||||||
|
def create_select_bill_texts_for_summarization(
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: list[int] | None = None,
|
||||||
|
bill_text_ids: list[int] | None = None,
|
||||||
|
with_votes_only: bool = False,
|
||||||
|
force: bool = False,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> Select:
|
||||||
|
"""Select bill_text rows that have source text and need summaries."""
|
||||||
|
stmt = (
|
||||||
|
select(BillText)
|
||||||
|
.join(Bill, Bill.id == BillText.bill_id)
|
||||||
|
.where(BillText.text_content.is_not(None), BillText.text_content != "")
|
||||||
|
.options(selectinload(BillText.bill))
|
||||||
|
.order_by(BillText.id)
|
||||||
|
)
|
||||||
|
if congress is not None:
|
||||||
|
stmt = stmt.where(Bill.congress == congress)
|
||||||
|
if bill_ids:
|
||||||
|
stmt = stmt.where(BillText.bill_id.in_(bill_ids))
|
||||||
|
if bill_text_ids:
|
||||||
|
stmt = stmt.where(BillText.id.in_(bill_text_ids))
|
||||||
|
if with_votes_only:
|
||||||
|
stmt = stmt.where(
|
||||||
|
exists(
|
||||||
|
select(VoteTextTarget.vote_id)
|
||||||
|
.join(
|
||||||
|
VoteClassification,
|
||||||
|
VoteClassification.vote_id == VoteTextTarget.vote_id,
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
VoteTextTarget.voted_text_version_id == BillText.id,
|
||||||
|
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||||
|
VoteClassification.vote_relationship
|
||||||
|
== VoteRelationship.DIRECT_TEXT_VOTE,
|
||||||
|
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||||
|
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||||
|
VoteClassification.is_special_rule.is_(False),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not force:
|
||||||
|
stmt = stmt.where(or_(BillText.summary.is_(None), BillText.summary == ""))
|
||||||
|
if limit is not None:
|
||||||
|
stmt = stmt.limit(limit)
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
|
def extract_message_content(data: dict[str, Any]) -> str:
|
||||||
|
"""Extract message content from a chat-completions response body."""
|
||||||
|
choices = data.get("choices")
|
||||||
|
if not isinstance(choices, list) or not choices:
|
||||||
|
msg = "Chat completion response did not contain choices"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
first = choices[0]
|
||||||
|
if not isinstance(first, dict):
|
||||||
|
msg = "Chat completion choice must be an object"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
message = first.get("message")
|
||||||
|
if isinstance(message, dict) and isinstance(message.get("content"), str):
|
||||||
|
return message["content"]
|
||||||
|
if isinstance(first.get("text"), str):
|
||||||
|
return first["text"]
|
||||||
|
|
||||||
|
msg = "Chat completion response did not contain message content"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
model: Annotated[str, typer.Option(help="OpenAI model id.")] = "gpt-5.4-mini",
|
||||||
|
congress: Annotated[
|
||||||
|
int | None, typer.Option(help="Only process one Congress.")
|
||||||
|
] = None,
|
||||||
|
bill_ids: Annotated[
|
||||||
|
list[int] | None,
|
||||||
|
typer.Option(
|
||||||
|
"--bill-id",
|
||||||
|
help="Only process one internal bill.id. Repeat for multiple bills.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
bill_text_ids: Annotated[
|
||||||
|
list[int] | None,
|
||||||
|
typer.Option(
|
||||||
|
"--bill-text-id",
|
||||||
|
help="Only process one internal bill_text.id. Repeat for multiple rows.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
with_votes_only: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(
|
||||||
|
"--with-votes-only",
|
||||||
|
help="Only process bill_text rows linked to at least one vote.",
|
||||||
|
),
|
||||||
|
] = False,
|
||||||
|
limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
|
||||||
|
force: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Regenerate summaries for rows that already have a summary."),
|
||||||
|
] = False,
|
||||||
|
dry_run: Annotated[
|
||||||
|
bool, typer.Option(help="Print summaries without writing them to the database.")
|
||||||
|
] = False,
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""CLI entrypoint for generating and storing bill summaries."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
if not getenv("CLOSEDAI_TOKEN"):
|
||||||
|
message = "CLOSEDAI_TOKEN is required"
|
||||||
|
raise typer.BadParameter(message)
|
||||||
|
|
||||||
|
summarization_prompts = load_summarization_prompts()
|
||||||
|
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||||
|
with Session(engine) as session:
|
||||||
|
stmt = create_select_bill_texts_for_summarization(
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
bill_text_ids=bill_text_ids,
|
||||||
|
with_votes_only=with_votes_only,
|
||||||
|
force=force,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
bill_texts = session.scalars(stmt).all()
|
||||||
|
logger.info("Selected %d bill_text rows for summarization", len(bill_texts))
|
||||||
|
|
||||||
|
written = 0
|
||||||
|
for index, bill_text in enumerate(bill_texts, 1):
|
||||||
|
summary = summarize_bill_text(
|
||||||
|
model=model,
|
||||||
|
bill_text=bill_text,
|
||||||
|
summarization_prompts=summarization_prompts,
|
||||||
|
)
|
||||||
|
if summary is None:
|
||||||
|
logger.warning("Skipping bill_text id=%s", bill_text.id)
|
||||||
|
continue
|
||||||
|
store_bill_summary_result(
|
||||||
|
bill_text=bill_text,
|
||||||
|
summary=summary,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
if index % 100 == 0:
|
||||||
|
session.commit()
|
||||||
|
written += 1
|
||||||
|
session.commit()
|
||||||
|
logger.info("Stored summary for bill_text id=%s", bill_text.id)
|
||||||
|
|
||||||
|
logger.info("Done: stored %d summaries", written)
|
||||||
|
|
||||||
|
|
||||||
|
def cli() -> None:
|
||||||
|
"""Typer entry point."""
|
||||||
|
typer.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
Reference in New Issue
Block a user