created scoring tables and basic logic

This commit is contained in:
2026-04-21 11:44:53 -04:00
parent be4b473a3c
commit 674edafe94
9 changed files with 843 additions and 38 deletions
@@ -0,0 +1,394 @@
"""Calculate legislator topic scores from bill topic metadata and roll-call votes."""
from __future__ import annotations
import argparse
from collections import defaultdict
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Iterable
from sqlalchemy import Integer, delete, extract, func, select, tuple_
from sqlalchemy.orm import Session
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.congress import (
Bill,
BillTopic,
BillTopicPosition,
LegislatorBillScore,
LegislatorScore,
Vote,
VoteRecord,
)
# Vote strings (compared after strip + casefold) that count as voting "yes".
SUPPORT_VOTES = frozenset(("yea", "aye", "yes"))
# Vote strings (compared after strip + casefold) that count as voting "no".
OPPOSE_VOTES = frozenset(("nay", "no"))

# Score given when a vote is missing or is neither a recognised yes nor no.
NEUTRAL_SCORE = 50.0
# Score given when the vote lines up with the topic.
SUPPORT_SCORE = 100.0
# Score given when the vote goes against the topic.
OPPOSE_SCORE = 1.0

# Aggregate-score key: (legislator_id, year, topic).
ScoreKey = tuple[int, int, str]
@dataclass(frozen=True)
class VoteScoreInput:
"""Raw vote data needed for one bill/topic/legislator scoring event."""
bill_id: int
bill_topic_id: int
legislator_id: int
year: int
topic: str
support_position: str | BillTopicPosition
vote_position: str | None
@dataclass(frozen=True)
class ComputedBillScore:
"""Per-bill source score for one legislator/year/topic."""
bill_id: int
bill_topic_id: int
legislator_id: int
year: int
topic: str
score: float
@dataclass(frozen=True)
class ScoreRunResult:
    """Immutable tally of what one scoring run did."""

    # Number of bills that were scored during the run.
    processed_bills: int
    # Total per-bill score rows stored across all processed bills.
    bill_score_rows: int
    # Sum of the aggregate rows refreshed for each bill; an aggregate row
    # touched by several bills is counted once per bill.
    aggregate_score_rows: int
def score_vote(
    vote_position: str | None,
    support_position: str | BillTopicPosition | None,
) -> float | None:
    """Score a single vote against a bill's stance on a topic.

    Args:
        vote_position: Raw vote string; compared after stripping whitespace
            and casefolding. ``None`` means no vote position is available.
        support_position: Stance of the bill toward the topic, as an enum
            member or a raw string.

    Returns:
        ``None`` when the stance cannot be normalized; ``NEUTRAL_SCORE`` when
        the vote is missing or is not a recognised yes/no value;
        ``SUPPORT_SCORE`` when the vote favours the topic; ``OPPOSE_SCORE``
        otherwise.
    """
    stance = normalize_support_position(support_position)
    if stance is None:
        return None

    cleaned_vote = None if vote_position is None else vote_position.strip().casefold()
    if cleaned_vote in SUPPORT_VOTES:
        voted_yes = True
    elif cleaned_vote in OPPOSE_VOTES:
        voted_yes = False
    else:
        # Missing votes and anything that is not a clear yes/no stay neutral.
        return NEUTRAL_SCORE

    # Any stance other than FOR means a yes vote works against the topic.
    yes_helps_topic = stance is BillTopicPosition.FOR
    if voted_yes == yes_helps_topic:
        return SUPPORT_SCORE
    return OPPOSE_SCORE
def normalize_support_position(
    support_position: str | BillTopicPosition | None,
) -> BillTopicPosition | None:
    """Coerce a raw stance value into a ``BillTopicPosition`` member.

    Args:
        support_position: An enum member, a string to look up by value after
            stripping whitespace and casefolding, or ``None``.

    Returns:
        The matching enum member, or ``None`` when the input is ``None`` or
        the string is rejected by ``BillTopicPosition`` with ``ValueError``.
    """
    # None and ready-made enum members pass straight through.
    if support_position is None or isinstance(support_position, BillTopicPosition):
        return support_position

    cleaned = support_position.strip().casefold()
    try:
        stance = BillTopicPosition(cleaned)
    except ValueError:
        stance = None
    return stance
def calculate_bill_score_values(
    vote_inputs: Iterable[VoteScoreInput],
) -> list[ComputedBillScore]:
    """Collapse raw vote inputs into one averaged score per group.

    Inputs are grouped by (bill_id, bill_topic_id, legislator_id, year, topic).
    Inputs that ``score_vote`` cannot score (it returns ``None``) are dropped.

    Args:
        vote_inputs: Raw vote rows to score.

    Returns:
        One ``ComputedBillScore`` per group, holding the arithmetic mean of
        the group's vote scores, ordered by the group key.
    """
    buckets: dict[tuple[int, int, int, int, str], list[float]] = {}
    for item in vote_inputs:
        value = score_vote(item.vote_position, item.support_position)
        if value is not None:
            group = (
                item.bill_id,
                item.bill_topic_id,
                item.legislator_id,
                item.year,
                item.topic,
            )
            buckets.setdefault(group, []).append(value)

    results: list[ComputedBillScore] = []
    # Keys are unique, so sorting the keys gives the same order as sorting items.
    for group in sorted(buckets):
        bill_id, bill_topic_id, legislator_id, year, topic = group
        values = buckets[group]
        results.append(
            ComputedBillScore(
                bill_id=bill_id,
                bill_topic_id=bill_topic_id,
                legislator_id=legislator_id,
                year=year,
                topic=topic,
                score=sum(values) / len(values),
            )
        )
    return results
def calculate_and_store_legislator_scores(
    session: Session,
    *,
    congress: int | None = None,
    bill_ids: list[int] | None = None,
    topics: list[str] | None = None,
    force: bool = False,
    limit: int | None = None,
) -> ScoreRunResult:
    """Score selected bills and refresh aggregate legislator score rows.

    Args:
        session: Open SQLAlchemy session used for all reads and writes.
        congress: When given, only bills from this Congress are scored.
        bill_ids: When non-empty, only these bill ids are scored.
        topics: When non-empty, only these topics are (re)scored and the
            bills are not marked as processed, because other topics may
            still be pending.
        force: Rescore bills even if ``score_processed_at`` is already set.
        limit: Maximum number of bills to process.

    Returns:
        Totals for the run: bills processed, per-bill score rows stored and
        aggregate rows refreshed.
    """
    selected_bill_ids = select_bill_ids_to_score(
        session,
        congress=congress,
        bill_ids=bill_ids,
        topics=topics,
        force=force,
        limit=limit,
    )

    # The query helpers in this module treat both None and an empty list as
    # "no topic filter" (they test ``if topics:``). Use the same truthiness
    # test here so a run with ``topics=[]`` scores every topic AND marks the
    # bill processed, instead of leaving it to be rescored on every run.
    scores_every_topic = not topics

    bill_score_rows = 0
    aggregate_score_rows = 0
    for bill_id in selected_bill_ids:
        stored, refreshed = score_bill(
            session,
            bill_id=bill_id,
            topics=topics,
            mark_processed=scores_every_topic,
        )
        bill_score_rows += stored
        aggregate_score_rows += refreshed
        # Commit after each bill so finished bills stay stored if a later
        # bill fails.
        session.commit()

    return ScoreRunResult(
        processed_bills=len(selected_bill_ids),
        bill_score_rows=bill_score_rows,
        aggregate_score_rows=aggregate_score_rows,
    )
def select_bill_ids_to_score(
    session: Session,
    *,
    congress: int | None = None,
    bill_ids: list[int] | None = None,
    topics: list[str] | None = None,
    force: bool = False,
    limit: int | None = None,
) -> list[int]:
    """Return the ids of bills that are eligible for scoring, in id order.

    A bill qualifies only if it joins to at least one ``BillTopic`` row and at
    least one ``Vote`` row.

    Args:
        session: Session used to run the query.
        congress: When given, restrict to bills from this Congress.
        bill_ids: When non-empty, restrict to these bill ids.
        topics: When non-empty, restrict to bills having one of these topics.
        force: When false, skip bills whose ``score_processed_at`` is set.
        limit: When given, cap the number of ids returned.

    Returns:
        Distinct bill ids ordered by id.
    """
    filters = []
    if not force:
        filters.append(Bill.score_processed_at.is_(None))
    if congress is not None:
        filters.append(Bill.congress == congress)
    # Empty lists are falsy, so they apply no filter, just like None.
    if bill_ids:
        filters.append(Bill.id.in_(bill_ids))
    if topics:
        filters.append(BillTopic.topic.in_(topics))

    query = (
        select(Bill.id)
        .join(BillTopic, BillTopic.bill_id == Bill.id)
        .join(Vote, Vote.bill_id == Bill.id)
        .distinct()
        .order_by(Bill.id)
    )
    for condition in filters:
        query = query.where(condition)
    if limit is not None:
        query = query.limit(limit)

    return list(session.scalars(query))
def score_bill(
    session: Session,
    *,
    bill_id: int,
    topics: list[str] | None = None,
    mark_processed: bool = True,
) -> tuple[int, int]:
    """Rebuild the per-bill scores for one bill and refresh affected aggregates.

    Args:
        session: Session used for all reads and writes; nothing is committed.
        bill_id: Bill to score.
        topics: When non-empty, only these topics are deleted and rescored.
        mark_processed: When true, stamp the bill's ``score_processed_at``
            with the current UTC time.

    Returns:
        ``(per-bill score rows stored, aggregate rows refreshed)``.
    """
    # Capture the keys of the rows about to be deleted, so aggregates that
    # lose their last source row are still refreshed.
    stale_keys = _existing_score_keys_for_bill(session, bill_id=bill_id, topics=topics)
    session.execute(_delete_bill_scores_statement(bill_id=bill_id, topics=topics))
    session.flush()

    vote_inputs = _load_bill_vote_score_inputs(session, bill_id=bill_id, topics=topics)
    computed = calculate_bill_score_values(vote_inputs)
    new_rows = [
        LegislatorBillScore(
            bill_id=item.bill_id,
            bill_topic_id=item.bill_topic_id,
            legislator_id=item.legislator_id,
            year=item.year,
            topic=item.topic,
            score=item.score,
        )
        for item in computed
    ]
    session.add_all(new_rows)

    if mark_processed:
        bill = session.get(Bill, bill_id)
        if bill is not None:
            bill.score_processed_at = datetime.now(tz=UTC)
    # Push the new rows to the database before the aggregates are recomputed.
    session.flush()

    fresh_keys = {(item.legislator_id, item.year, item.topic) for item in computed}
    refreshed = refresh_aggregate_scores(session, stale_keys | fresh_keys)
    return len(computed), refreshed
def refresh_aggregate_scores(session: Session, keys: set[ScoreKey]) -> int:
    """Rebuild ``LegislatorScore`` rows for the given keys.

    Existing aggregate rows for the keys are deleted, then recreated as the
    average of the matching ``LegislatorBillScore`` rows. A key with no
    remaining source rows ends up with no aggregate row.

    Args:
        session: Session used for the delete, the query and the inserts;
            changes are flushed but not committed.
        keys: ``(legislator_id, year, topic)`` keys to refresh.

    Returns:
        Number of aggregate rows written; ``0`` when ``keys`` is empty.
    """
    if not keys:
        return 0

    wanted = list(keys)

    aggregate_key = tuple_(
        LegislatorScore.legislator_id,
        LegislatorScore.year,
        LegislatorScore.topic,
    )
    session.execute(delete(LegislatorScore).where(aggregate_key.in_(wanted)))
    session.flush()

    group_columns = (
        LegislatorBillScore.legislator_id,
        LegislatorBillScore.year,
        LegislatorBillScore.topic,
    )
    averages_stmt = (
        select(*group_columns, func.avg(LegislatorBillScore.score).label("score"))
        .where(tuple_(*group_columns).in_(wanted))
        .group_by(*group_columns)
    )
    averaged = session.execute(averages_stmt).all()

    session.add_all(
        [
            LegislatorScore(
                legislator_id=row.legislator_id,
                year=row.year,
                topic=row.topic,
                score=float(row.score),
            )
            for row in averaged
        ]
    )
    session.flush()
    return len(averaged)
def _load_bill_vote_score_inputs(
    session: Session,
    *,
    bill_id: int,
    topics: list[str] | None,
) -> list[VoteScoreInput]:
    """Load one ``VoteScoreInput`` per vote record and bill topic pair.

    Args:
        session: Session used to run the query.
        bill_id: Bill whose votes are loaded.
        topics: When non-empty, only bill topics with these labels are used.

    Returns:
        The scoring inputs for the bill, in query order.
    """
    vote_year = extract("year", Vote.vote_date).cast(Integer).label("year")
    columns = (
        Vote.bill_id,
        BillTopic.id.label("bill_topic_id"),
        VoteRecord.legislator_id,
        vote_year,
        BillTopic.topic,
        BillTopic.support_position,
        VoteRecord.position,
    )
    query = (
        select(*columns)
        .join(Vote, Vote.id == VoteRecord.vote_id)
        .join(BillTopic, BillTopic.bill_id == Vote.bill_id)
        .where(Vote.bill_id == bill_id)
    )
    if topics:
        query = query.where(BillTopic.topic.in_(topics))

    fetched = session.execute(query)
    # NOTE(review): int(row.year) raises if the extracted year is NULL;
    # presumably vote_date is never NULL - confirm against the schema.
    return [
        VoteScoreInput(
            bill_id=row.bill_id,
            bill_topic_id=row.bill_topic_id,
            legislator_id=row.legislator_id,
            year=int(row.year),
            topic=row.topic,
            support_position=row.support_position,
            vote_position=row.position,
        )
        for row in fetched
    ]
def _existing_score_keys_for_bill(
    session: Session,
    *,
    bill_id: int,
    topics: list[str] | None,
) -> set[ScoreKey]:
    """Return the aggregate keys already stored for a bill's score rows.

    Args:
        session: Session used to run the query.
        bill_id: Bill whose ``LegislatorBillScore`` rows are inspected.
        topics: When non-empty, only rows with these topics are considered.

    Returns:
        The distinct ``(legislator_id, year, topic)`` keys of matching rows.
    """
    conditions = [LegislatorBillScore.bill_id == bill_id]
    if topics:
        conditions.append(LegislatorBillScore.topic.in_(topics))

    query = select(
        LegislatorBillScore.legislator_id,
        LegislatorBillScore.year,
        LegislatorBillScore.topic,
    ).where(*conditions)

    return {
        (legislator_id, year, topic)
        for legislator_id, year, topic in session.execute(query)
    }
def _delete_bill_scores_statement(*, bill_id: int, topics: list[str] | None):
    """Build the DELETE statement that removes a bill's per-bill score rows.

    Args:
        bill_id: Bill whose ``LegislatorBillScore`` rows are removed.
        topics: When non-empty, only rows with these topics are removed.

    Returns:
        An unexecuted SQLAlchemy delete statement.
    """
    conditions = [LegislatorBillScore.bill_id == bill_id]
    if topics:
        conditions.append(LegislatorBillScore.topic.in_(topics))
    return delete(LegislatorBillScore).where(*conditions)
def main() -> None:
    """Parse command-line options, run the scoring job and print a summary."""
    options = _build_argument_parser().parse_args()

    with Session(get_postgres_engine(name="DATA_SCIENCE_DEV")) as session:
        result = calculate_and_store_legislator_scores(
            session,
            congress=options.congress,
            bill_ids=options.bill_ids,
            topics=options.topics,
            force=options.force,
            limit=options.limit,
        )

    summary = (
        "Processed "
        f"{result.processed_bills} bills; stored {result.bill_score_rows} bill score rows; "
        f"refreshed {result.aggregate_score_rows} aggregate score rows."
    )
    print(summary)


def _build_argument_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the scoring command-line interface."""
    parser = argparse.ArgumentParser(
        description="Calculate legislator_score rows from bill_topic and vote_record data."
    )
    parser.add_argument("--congress", type=int, help="Only score bills from one Congress.")
    # Repeatable options collect their values into a list (None when unused).
    parser.add_argument(
        "--bill-id",
        action="append",
        dest="bill_ids",
        type=int,
        help="Only score one bill id. Repeat for multiple bills.",
    )
    parser.add_argument(
        "--topic",
        action="append",
        dest="topics",
        help="Only calculate one topic. Repeat for multiple topics.",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Reprocess bills even when bill.score_processed_at is already set.",
    )
    parser.add_argument("--limit", type=int, help="Maximum number of bills to process.")
    return parser


if __name__ == "__main__":
    main()