From 674edafe94fd39b28d50a2fe55e3003c02d78088 Mon Sep 17 00:00:00 2001 From: Richie Cahill Date: Tue, 21 Apr 2026 11:44:53 -0400 Subject: [PATCH] created scoring tables and basic logic --- ...islatorscore_and_billtopic_ef4bc5411176.py | 245 +++++++++++ alembic/env.py | 26 +- alembic/script.py.mako | 2 +- .../orm/data_science_dev/congress/__init__.py | 15 +- .../orm/data_science_dev/congress/bill.py | 86 +++- .../data_science_dev/congress/legislator.py | 75 +++- .../orm/data_science_dev/congress/vote.py | 22 +- pipelines/orm/data_science_dev/models.py | 16 +- .../tools/calculate_legislator_scores.py | 394 ++++++++++++++++++ 9 files changed, 843 insertions(+), 38 deletions(-) create mode 100644 alembic/data_science_dev/versions/2026_04_21-adding_legislatorscore_and_billtopic_ef4bc5411176.py create mode 100644 pipelines/tools/calculate_legislator_scores.py diff --git a/alembic/data_science_dev/versions/2026_04_21-adding_legislatorscore_and_billtopic_ef4bc5411176.py b/alembic/data_science_dev/versions/2026_04_21-adding_legislatorscore_and_billtopic_ef4bc5411176.py new file mode 100644 index 0000000..b4991b9 --- /dev/null +++ b/alembic/data_science_dev/versions/2026_04_21-adding_legislatorscore_and_billtopic_ef4bc5411176.py @@ -0,0 +1,245 @@ +"""adding LegislatorScore and BillTopic. + +Revision ID: ef4bc5411176 +Revises: 5cd7eee3549d +Create Date: 2026-04-21 11:35:18.977213 + +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import sqlalchemy as sa + +from alembic import op +from pipelines.orm import DataScienceDevBase + +if TYPE_CHECKING: + from collections.abc import Sequence + +# revision identifiers, used by Alembic. +revision: str = "ef4bc5411176" +down_revision: str | None = "5cd7eee3549d" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +schema = DataScienceDevBase.schema_name + + +def upgrade() -> None: + """Upgrade.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "bill_topic", + sa.Column("bill_id", sa.Integer(), nullable=False), + sa.Column("topic", sa.String(), nullable=False), + sa.Column( + "support_position", + sa.Enum("for", "against", name="bill_topic_position", native_enum=False), + nullable=False, + ), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "created", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["bill_id"], + [f"{schema}.bill.id"], + name=op.f("fk_bill_topic_bill_id_bill"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_bill_topic")), + sa.UniqueConstraint( + "bill_id", + "topic", + "support_position", + name="uq_bill_topic_bill_id_topic_support_position", + ), + schema=schema, + ) + op.create_index( + "ix_bill_topic_topic", "bill_topic", ["topic"], unique=False, schema=schema + ) + op.create_table( + "legislator_score", + sa.Column("legislator_id", sa.Integer(), nullable=False), + sa.Column("year", sa.Integer(), nullable=False), + sa.Column("topic", sa.String(), nullable=False), + sa.Column("score", sa.Float(), nullable=False), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "created", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["legislator_id"], + [f"{schema}.legislator.id"], + name=op.f("fk_legislator_score_legislator_id_legislator"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator_score")), + sa.UniqueConstraint( + "legislator_id", + "year", + "topic", + name="uq_legislator_score_legislator_id_year_topic", + ), + schema=schema, + ) + op.create_index( + op.f("ix_legislator_score_legislator_id"), + "legislator_score", + ["legislator_id"], + unique=False, + schema=schema, + ) + op.create_index( + "ix_legislator_score_year_topic", + "legislator_score", + ["year", "topic"], + unique=False, + schema=schema, + ) + op.create_table( + "legislator_bill_score", + sa.Column("bill_id", sa.Integer(), nullable=False), + sa.Column("bill_topic_id", sa.Integer(), nullable=False), + sa.Column("legislator_id", sa.Integer(), nullable=False), + sa.Column("year", sa.Integer(), nullable=False), + sa.Column("topic", sa.String(), nullable=False), + sa.Column("score", sa.Float(), nullable=False), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "created", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["bill_id"], + [f"{schema}.bill.id"], + name=op.f("fk_legislator_bill_score_bill_id_bill"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["bill_topic_id"], + [f"{schema}.bill_topic.id"], + name=op.f("fk_legislator_bill_score_bill_topic_id_bill_topic"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["legislator_id"], + [f"{schema}.legislator.id"], + name=op.f("fk_legislator_bill_score_legislator_id_legislator"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator_bill_score")), + sa.UniqueConstraint( + "bill_topic_id", + "legislator_id", + "year", + name="uq_legislator_bill_score_bill_topic_id_legislator_id_year", + ), + schema=schema, + ) + op.create_index( + op.f("ix_legislator_bill_score_bill_id"), + "legislator_bill_score", + ["bill_id"], + unique=False, + schema=schema, + ) + op.create_index( + op.f("ix_legislator_bill_score_bill_topic_id"), + "legislator_bill_score", + ["bill_topic_id"], + unique=False, + schema=schema, + ) + op.create_index( + op.f("ix_legislator_bill_score_legislator_id"), + "legislator_bill_score", + ["legislator_id"], + unique=False, + schema=schema, + ) + op.create_index( + "ix_legislator_bill_score_year_topic", + "legislator_bill_score", + ["year", "topic"], + unique=False, + schema=schema, + ) + op.add_column( + "bill", + sa.Column("score_processed_at", sa.DateTime(timezone=True), nullable=True), + schema=schema, + ) + op.add_column( + "bill_text", sa.Column("summary", sa.String(), nullable=True), schema=schema + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("bill_text", "summary", schema=schema) + op.drop_column("bill", "score_processed_at", schema=schema) + op.drop_index( + "ix_legislator_bill_score_year_topic", + table_name="legislator_bill_score", + schema=schema, + ) + op.drop_index( + op.f("ix_legislator_bill_score_legislator_id"), + table_name="legislator_bill_score", + schema=schema, + ) + op.drop_index( + op.f("ix_legislator_bill_score_bill_topic_id"), + table_name="legislator_bill_score", + schema=schema, + ) + op.drop_index( + op.f("ix_legislator_bill_score_bill_id"), + table_name="legislator_bill_score", + schema=schema, + ) + op.drop_table("legislator_bill_score", schema=schema) + op.drop_index( + "ix_legislator_score_year_topic", table_name="legislator_score", schema=schema + ) + op.drop_index( + op.f("ix_legislator_score_legislator_id"), + table_name="legislator_score", + schema=schema, + ) + op.drop_table("legislator_score", schema=schema) + op.drop_index("ix_bill_topic_topic", table_name="bill_topic", schema=schema) + op.drop_table("bill_topic", schema=schema) + # ### end Alembic commands ### diff --git a/alembic/env.py b/alembic/env.py index 85fd3f3..702e7d4 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -40,8 +40,12 @@ def dynamic_schema(filename: str, _options: dict[Any, Any]) -> None: """Dynamic schema.""" original_file = Path(filename).read_text() schema_name = base_class.schema_name - dynamic_schema_file_part1 = original_file.replace(f"schema='{schema_name}'", "schema=schema") - dynamic_schema_file = dynamic_schema_file_part1.replace(f"'{schema_name}.", "f'{schema}.") + dynamic_schema_file_part1 = original_file.replace( + f"schema='{schema_name}'", "schema=schema" + ) + dynamic_schema_file = dynamic_schema_file_part1.replace( + f"'{schema_name}.", "f'{schema}." + ) Path(filename).write_text(dynamic_schema_file) @@ -49,7 +53,10 @@ def dynamic_schema(filename: str, _options: dict[Any, Any]) -> None: def import_postgresql(filename: str, _options: dict[Any, Any]) -> None: """Add postgresql dialect import when postgresql types are used.""" content = Path(filename).read_text() - if "postgresql." in content and "from sqlalchemy.dialects import postgresql" not in content: + if ( + "postgresql." in content + and "from sqlalchemy.dialects import postgresql" not in content + ): content = content.replace( "import sqlalchemy as sa\n", "import sqlalchemy as sa\nfrom sqlalchemy.dialects import postgresql\n", @@ -66,8 +73,17 @@ def ruff_check_and_format(filename: str, _options: dict[Any, Any]) -> None: def include_name( name: str | None, - type_: Literal["schema", "table", "column", "index", "unique_constraint", "foreign_key_constraint"], - _parent_names: MutableMapping[Literal["schema_name", "table_name", "schema_qualified_table_name"], str | None], + type_: Literal[ + "schema", + "table", + "column", + "index", + "unique_constraint", + "foreign_key_constraint", + ], + _parent_names: MutableMapping[ + Literal["schema_name", "table_name", "schema_qualified_table_name"], str | None + ], ) -> bool: """Filter tables to be included in the migration. diff --git a/alembic/script.py.mako b/alembic/script.py.mako index a9f9996..6fcca21 100644 --- a/alembic/script.py.mako +++ b/alembic/script.py.mako @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING import sqlalchemy as sa from alembic import op -from python.orm import ${config.attributes["base"].__name__} +from pipelines.orm import ${config.attributes["base"].__name__} if TYPE_CHECKING: from collections.abc import Sequence diff --git a/pipelines/orm/data_science_dev/congress/__init__.py b/pipelines/orm/data_science_dev/congress/__init__.py index 366fcda..c2f1630 100644 --- a/pipelines/orm/data_science_dev/congress/__init__.py +++ b/pipelines/orm/data_science_dev/congress/__init__.py @@ -1,8 +1,15 @@ -"""init.""" +"""Congress ORM models.""" -from pipelines.orm.data_science_dev.congress.bill import Bill, BillText +from pipelines.orm.data_science_dev.congress.bill import ( + Bill, + BillText, + BillTopic, + BillTopicPosition, +) from pipelines.orm.data_science_dev.congress.legislator import ( + LegislatorBillScore, Legislator, + LegislatorScore, LegislatorSocialMedia, ) from pipelines.orm.data_science_dev.congress.vote import Vote, VoteRecord @@ -10,7 +17,11 @@ from pipelines.orm.data_science_dev.congress.vote import Vote, VoteRecord __all__ = [ "Bill", "BillText", + "BillTopic", + "BillTopicPosition", "Legislator", + "LegislatorBillScore", + "LegislatorScore", "LegislatorSocialMedia", "Vote", "VoteRecord", diff --git a/pipelines/orm/data_science_dev/congress/bill.py b/pipelines/orm/data_science_dev/congress/bill.py index 87c9c8d..e9a11f0 100644 --- a/pipelines/orm/data_science_dev/congress/bill.py +++ b/pipelines/orm/data_science_dev/congress/bill.py @@ -2,22 +2,37 @@ from __future__ import annotations -from datetime import date +from datetime import date, datetime +from enum import StrEnum from typing import TYPE_CHECKING -from sqlalchemy import ForeignKey, Index, UniqueConstraint +from sqlalchemy import DateTime, Enum, ForeignKey, Index, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship from pipelines.orm.data_science_dev.base import DataScienceDevTableBase if TYPE_CHECKING: + from pipelines.orm.data_science_dev.congress.legislator import LegislatorBillScore from pipelines.orm.data_science_dev.congress.vote import Vote +class BillTopicPosition(StrEnum): + """Whether a yes vote on a bill is for or against a topic.""" + + FOR = "for" + AGAINST = "against" + + class Bill(DataScienceDevTableBase): """Legislation with congress number, type, titles, status, and sponsor.""" __tablename__ = "bill" + __table_args__ = ( + UniqueConstraint( + "congress", "bill_type", "number", name="uq_bill_congress_type_number" + ), + Index("ix_bill_congress", "congress"), + ) congress: Mapped[int] bill_type: Mapped[str] @@ -33,6 +48,7 @@ class Bill(DataScienceDevTableBase): sponsor_bioguide_id: Mapped[str | None] subjects_top_term: Mapped[str | None] + score_processed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) votes: Mapped[list[Vote]] = relationship( "Vote", @@ -43,12 +59,15 @@ class Bill(DataScienceDevTableBase): back_populates="bill", cascade="all, delete-orphan", ) - - __table_args__ = ( - UniqueConstraint( - "congress", "bill_type", "number", name="uq_bill_congress_type_number" - ), - Index("ix_bill_congress", "congress"), + topics: Mapped[list[BillTopic]] = relationship( + "BillTopic", + back_populates="bill", + cascade="all, delete-orphan", + ) + legislator_bill_scores: Mapped[list[LegislatorBillScore]] = relationship( + "LegislatorBillScore", + back_populates="bill", + cascade="all, delete-orphan", ) @@ -56,17 +75,50 @@ class BillText(DataScienceDevTableBase): """Stores different text versions of a bill (introduced, enrolled, etc.).""" __tablename__ = "bill_text" - - bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE")) - version_code: Mapped[str] - version_name: Mapped[str | None] - text_content: Mapped[str | None] - date: Mapped[date | None] - - bill: Mapped[Bill] = relationship("Bill", back_populates="bill_texts") - __table_args__ = ( UniqueConstraint( "bill_id", "version_code", name="uq_bill_text_bill_id_version_code" ), ) + + bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE")) + version_code: Mapped[str] + version_name: Mapped[str | None] + text_content: Mapped[str | None] + summary: Mapped[str | None] + date: Mapped[date | None] + + bill: Mapped[Bill] = relationship("Bill", back_populates="bill_texts") + + +class BillTopic(DataScienceDevTableBase): + """One bill stance on one topic used to score roll-call votes.""" + + __tablename__ = "bill_topic" + __table_args__ = ( + UniqueConstraint( + "bill_id", + "topic", + "support_position", + name="uq_bill_topic_bill_id_topic_support_position", + ), + Index("ix_bill_topic_topic", "topic"), + ) + + bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE")) + topic: Mapped[str] + support_position: Mapped[BillTopicPosition] = mapped_column( + Enum( + BillTopicPosition, + values_callable=lambda enum_cls: [member.value for member in enum_cls], + native_enum=False, + name="bill_topic_position", + ) + ) + + bill: Mapped[Bill] = relationship("Bill", back_populates="topics") + legislator_bill_scores: Mapped[list[LegislatorBillScore]] = relationship( + "LegislatorBillScore", + back_populates="bill_topic", + cascade="all, delete-orphan", + ) diff --git a/pipelines/orm/data_science_dev/congress/legislator.py b/pipelines/orm/data_science_dev/congress/legislator.py index 3c274b4..474b8f4 100644 --- a/pipelines/orm/data_science_dev/congress/legislator.py +++ b/pipelines/orm/data_science_dev/congress/legislator.py @@ -5,12 +5,13 @@ from __future__ import annotations from datetime import date from typing import TYPE_CHECKING -from sqlalchemy import ForeignKey, Text +from sqlalchemy import ForeignKey, Index, Text, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship from pipelines.orm.data_science_dev.base import DataScienceDevTableBase if TYPE_CHECKING: + from pipelines.orm.data_science_dev.congress.bill import Bill, BillTopic from pipelines.orm.data_science_dev.congress.vote import VoteRecord @@ -50,6 +51,16 @@ class Legislator(DataScienceDevTableBase): back_populates="legislator", cascade="all, delete-orphan", ) + scores: Mapped[list[LegislatorScore]] = relationship( + "LegislatorScore", + back_populates="legislator", + cascade="all, delete-orphan", + ) + bill_scores: Mapped[list[LegislatorBillScore]] = relationship( + "LegislatorBillScore", + back_populates="legislator", + cascade="all, delete-orphan", + ) class LegislatorSocialMedia(DataScienceDevTableBase): @@ -66,3 +77,65 @@ class LegislatorSocialMedia(DataScienceDevTableBase): legislator: Mapped[Legislator] = relationship( back_populates="social_media_accounts" ) + + +class LegislatorScore(DataScienceDevTableBase): + """Computed topic score for a legislator in one calendar year.""" + + __tablename__ = "legislator_score" + __table_args__ = ( + UniqueConstraint( + "legislator_id", + "year", + "topic", + name="uq_legislator_score_legislator_id_year_topic", + ), + Index("ix_legislator_score_year_topic", "year", "topic"), + ) + + legislator_id: Mapped[int] = mapped_column( + ForeignKey("main.legislator.id", ondelete="CASCADE"), + index=True, + ) + year: Mapped[int] + topic: Mapped[str] + score: Mapped[float] + + legislator: Mapped[Legislator] = relationship(back_populates="scores") + + + +class LegislatorBillScore(DataScienceDevTableBase): + """Per-bill source score used to maintain aggregate legislator scores.""" + + __tablename__ = "legislator_bill_score" + __table_args__ = ( + UniqueConstraint( + "bill_topic_id", + "legislator_id", + "year", + name="uq_legislator_bill_score_bill_topic_id_legislator_id_year", + ), + Index("ix_legislator_bill_score_year_topic", "year", "topic"), + ) + + bill_id: Mapped[int] = mapped_column( + ForeignKey("main.bill.id", ondelete="CASCADE"), + index=True, + ) + bill_topic_id: Mapped[int] = mapped_column( + ForeignKey("main.bill_topic.id", ondelete="CASCADE"), + index=True, + ) + legislator_id: Mapped[int] = mapped_column( + ForeignKey("main.legislator.id", ondelete="CASCADE"), + index=True, + ) + year: Mapped[int] + topic: Mapped[str] + score: Mapped[float] + + bill: Mapped[Bill] = relationship(back_populates="legislator_bill_scores") + bill_topic: Mapped[BillTopic] = relationship(back_populates="legislator_bill_scores") + legislator: Mapped[Legislator] = relationship(back_populates="bill_scores") + diff --git a/pipelines/orm/data_science_dev/congress/vote.py b/pipelines/orm/data_science_dev/congress/vote.py index ce4de3f..3e67f24 100644 --- a/pipelines/orm/data_science_dev/congress/vote.py +++ b/pipelines/orm/data_science_dev/congress/vote.py @@ -44,6 +44,17 @@ class Vote(DataScienceDevTableBase): """Roll call votes with counts and optional bill linkage.""" __tablename__ = "vote" + __table_args__ = ( + UniqueConstraint( + "congress", + "chamber", + "session", + "number", + name="uq_vote_congress_chamber_session_number", + ), + Index("ix_vote_date", "vote_date"), + Index("ix_vote_congress_chamber", "congress", "chamber"), + ) congress: Mapped[int] chamber: Mapped[str] @@ -71,14 +82,3 @@ class Vote(DataScienceDevTableBase): cascade="all, delete-orphan", ) - __table_args__ = ( - UniqueConstraint( - "congress", - "chamber", - "session", - "number", - name="uq_vote_congress_chamber_session_number", - ), - Index("ix_vote_date", "vote_date"), - Index("ix_vote_congress_chamber", "congress", "chamber"), - ) diff --git a/pipelines/orm/data_science_dev/models.py b/pipelines/orm/data_science_dev/models.py index 4264c5c..671b877 100644 --- a/pipelines/orm/data_science_dev/models.py +++ b/pipelines/orm/data_science_dev/models.py @@ -2,14 +2,28 @@ from __future__ import annotations -from pipelines.orm.data_science_dev.congress import Bill, BillText, Legislator, Vote, VoteRecord +from pipelines.orm.data_science_dev.congress import ( + Bill, + BillText, + BillTopic, + BillTopicPosition, + Legislator, + LegislatorBillScore, + LegislatorScore, + Vote, + VoteRecord, +) from pipelines.orm.data_science_dev.posts import partitions # noqa: F401 — registers partition classes in metadata from pipelines.orm.data_science_dev.posts.tables import Posts __all__ = [ "Bill", "BillText", + "BillTopic", + "BillTopicPosition", "Legislator", + "LegislatorBillScore", + "LegislatorScore", "Posts", "Vote", "VoteRecord", diff --git a/pipelines/tools/calculate_legislator_scores.py b/pipelines/tools/calculate_legislator_scores.py new file mode 100644 index 0000000..82d7327 --- /dev/null +++ b/pipelines/tools/calculate_legislator_scores.py @@ -0,0 +1,394 @@ +"""Calculate legislator topic scores from bill topic metadata and roll-call votes.""" + +from __future__ import annotations + +import argparse +from collections import defaultdict +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Iterable + +from sqlalchemy import Integer, delete, extract, func, select, tuple_ +from sqlalchemy.orm import Session + +from pipelines.orm.common import get_postgres_engine +from pipelines.orm.data_science_dev.congress import ( + Bill, + BillTopic, + BillTopicPosition, + LegislatorBillScore, + LegislatorScore, + Vote, + VoteRecord, +) + +SUPPORT_VOTES = frozenset({"yea", "aye", "yes"}) +OPPOSE_VOTES = frozenset({"nay", "no"}) +NEUTRAL_SCORE = 50.0 +SUPPORT_SCORE = 100.0 +OPPOSE_SCORE = 1.0 +ScoreKey = tuple[int, int, str] + + +@dataclass(frozen=True) +class VoteScoreInput: + """Raw vote data needed for one bill/topic/legislator scoring event.""" + + bill_id: int + bill_topic_id: int + legislator_id: int + year: int + topic: str + support_position: str | BillTopicPosition + vote_position: str | None + + +@dataclass(frozen=True) +class ComputedBillScore: + """Per-bill source score for one legislator/year/topic.""" + + bill_id: int + bill_topic_id: int + legislator_id: int + year: int + topic: str + score: float + + +@dataclass(frozen=True) +class ScoreRunResult: + """Summary for a scoring job run.""" + + processed_bills: int + bill_score_rows: int + aggregate_score_rows: int + + +def score_vote( + vote_position: str | None, + support_position: str | BillTopicPosition | None, +) -> float | None: + """Return a 1-100 score where 50 is neutral.""" + stance = normalize_support_position(support_position) + if stance is None: + return None + if vote_position is None: + return NEUTRAL_SCORE + + vote = vote_position.strip().casefold() + if vote not in SUPPORT_VOTES | OPPOSE_VOTES: + return NEUTRAL_SCORE + + voted_yes = vote in SUPPORT_VOTES + yes_is_for_topic = stance is BillTopicPosition.FOR + return SUPPORT_SCORE if voted_yes == yes_is_for_topic else OPPOSE_SCORE + + +def normalize_support_position( + support_position: str | BillTopicPosition | None, +) -> BillTopicPosition | None: + """Normalize a DB enum/string stance value.""" + if support_position is None: + return None + if isinstance(support_position, BillTopicPosition): + return support_position + value = support_position.strip().casefold() + try: + return BillTopicPosition(value) + except ValueError: + return None + + +def calculate_bill_score_values( + vote_inputs: Iterable[VoteScoreInput], +) -> list[ComputedBillScore]: + """Aggregate raw vote inputs into per-bill source scores.""" + grouped: dict[tuple[int, int, int, int, str], list[float]] = defaultdict(list) + for vote_input in vote_inputs: + score = score_vote(vote_input.vote_position, vote_input.support_position) + if score is None: + continue + key = ( + vote_input.bill_id, + vote_input.bill_topic_id, + vote_input.legislator_id, + vote_input.year, + vote_input.topic, + ) + grouped[key].append(score) + + return [ + ComputedBillScore( + bill_id=bill_id, + bill_topic_id=bill_topic_id, + legislator_id=legislator_id, + year=year, + topic=topic, + score=sum(scores) / len(scores), + ) + for (bill_id, bill_topic_id, legislator_id, year, topic), scores in sorted( + grouped.items() + ) + ] + + +def calculate_and_store_legislator_scores( + session: Session, + *, + congress: int | None = None, + bill_ids: list[int] | None = None, + topics: list[str] | None = None, + force: bool = False, + limit: int | None = None, +) -> ScoreRunResult: + """Score selected bills and refresh aggregate legislator score rows.""" + selected_bill_ids = select_bill_ids_to_score( + session, + congress=congress, + bill_ids=bill_ids, + topics=topics, + force=force, + limit=limit, + ) + result = ScoreRunResult( + processed_bills=0, + bill_score_rows=0, + aggregate_score_rows=0, + ) + for bill_id in selected_bill_ids: + bill_score_rows, aggregate_score_rows = score_bill( + session, + bill_id=bill_id, + topics=topics, + mark_processed=topics is None, + ) + result = ScoreRunResult( + processed_bills=result.processed_bills + 1, + bill_score_rows=result.bill_score_rows + bill_score_rows, + aggregate_score_rows=result.aggregate_score_rows + aggregate_score_rows, + ) + session.commit() + return result + + +def select_bill_ids_to_score( + session: Session, + *, + congress: int | None = None, + bill_ids: list[int] | None = None, + topics: list[str] | None = None, + force: bool = False, + limit: int | None = None, +) -> list[int]: + """Select bills with topic metadata and votes that should be scored.""" + stmt = ( + select(Bill.id) + .join(BillTopic, BillTopic.bill_id == Bill.id) + .join(Vote, Vote.bill_id == Bill.id) + .distinct() + .order_by(Bill.id) + ) + if not force: + stmt = stmt.where(Bill.score_processed_at.is_(None)) + if congress is not None: + stmt = stmt.where(Bill.congress == congress) + if bill_ids: + stmt = stmt.where(Bill.id.in_(bill_ids)) + if topics: + stmt = stmt.where(BillTopic.topic.in_(topics)) + if limit is not None: + stmt = stmt.limit(limit) + return list(session.scalars(stmt)) + + +def score_bill( + session: Session, + *, + bill_id: int, + topics: list[str] | None = None, + mark_processed: bool = True, +) -> tuple[int, int]: + """Score all selected vote records for one bill and refresh aggregates.""" + prior_keys = _existing_score_keys_for_bill(session, bill_id=bill_id, topics=topics) + session.execute(_delete_bill_scores_statement(bill_id=bill_id, topics=topics)) + session.flush() + + scores = calculate_bill_score_values( + _load_bill_vote_score_inputs(session, bill_id=bill_id, topics=topics) + ) + session.add_all( + LegislatorBillScore( + bill_id=score.bill_id, + bill_topic_id=score.bill_topic_id, + legislator_id=score.legislator_id, + year=score.year, + topic=score.topic, + score=score.score, + ) + for score in scores + ) + if mark_processed: + bill = session.get(Bill, bill_id) + if bill is not None: + bill.score_processed_at = datetime.now(tz=UTC) + session.flush() + + affected_keys = prior_keys | { + (score.legislator_id, score.year, score.topic) for score in scores + } + aggregate_rows = refresh_aggregate_scores(session, affected_keys) + return len(scores), aggregate_rows + + +def refresh_aggregate_scores(session: Session, keys: set[ScoreKey]) -> int: + """Refresh aggregate legislator_score rows from per-bill source scores.""" + if not keys: + return 0 + + key_tuple = tuple_( + LegislatorScore.legislator_id, + LegislatorScore.year, + LegislatorScore.topic, + ) + session.execute(delete(LegislatorScore).where(key_tuple.in_(list(keys)))) + session.flush() + + source_key_tuple = tuple_( + LegislatorBillScore.legislator_id, + LegislatorBillScore.year, + LegislatorBillScore.topic, + ) + rows = session.execute( + select( + LegislatorBillScore.legislator_id, + LegislatorBillScore.year, + LegislatorBillScore.topic, + func.avg(LegislatorBillScore.score).label("score"), + ) + .where(source_key_tuple.in_(list(keys))) + .group_by( + LegislatorBillScore.legislator_id, + LegislatorBillScore.year, + LegislatorBillScore.topic, + ) + ).all() + session.add_all( + LegislatorScore( + legislator_id=row.legislator_id, + year=row.year, + topic=row.topic, + score=float(row.score), + ) + for row in rows + ) + session.flush() + return len(rows) + + +def _load_bill_vote_score_inputs( + session: Session, + *, + bill_id: int, + topics: list[str] | None, +) -> list[VoteScoreInput]: + year = extract("year", Vote.vote_date).cast(Integer).label("year") + stmt = ( + select( + Vote.bill_id, + BillTopic.id.label("bill_topic_id"), + VoteRecord.legislator_id, + year, + BillTopic.topic, + BillTopic.support_position, + VoteRecord.position, + ) + .join(Vote, Vote.id == VoteRecord.vote_id) + .join(BillTopic, BillTopic.bill_id == Vote.bill_id) + .where(Vote.bill_id == bill_id) + ) + if topics: + stmt = stmt.where(BillTopic.topic.in_(topics)) + + return [ + VoteScoreInput( + bill_id=row.bill_id, + bill_topic_id=row.bill_topic_id, + legislator_id=row.legislator_id, + year=int(row.year), + topic=row.topic, + support_position=row.support_position, + vote_position=row.position, + ) + for row in session.execute(stmt) + ] + + +def _existing_score_keys_for_bill( + session: Session, + *, + bill_id: int, + topics: list[str] | None, +) -> set[ScoreKey]: + stmt = select( + LegislatorBillScore.legislator_id, + LegislatorBillScore.year, + LegislatorBillScore.topic, + ).where(LegislatorBillScore.bill_id == bill_id) + if topics: + stmt = stmt.where(LegislatorBillScore.topic.in_(topics)) + return {(row.legislator_id, row.year, row.topic) for row in session.execute(stmt)} + + +def _delete_bill_scores_statement(*, bill_id: int, topics: list[str] | None): + stmt = delete(LegislatorBillScore).where(LegislatorBillScore.bill_id == bill_id) + if topics: + stmt = stmt.where(LegislatorBillScore.topic.in_(topics)) + return stmt + + +def main() -> None: + """CLI entrypoint.""" + parser = argparse.ArgumentParser( + description="Calculate legislator_score rows from bill_topic and vote_record data." + ) + parser.add_argument("--congress", type=int, help="Only score bills from one Congress.") + parser.add_argument( + "--bill-id", + action="append", + dest="bill_ids", + type=int, + help="Only score one bill id. Repeat for multiple bills.", + ) + parser.add_argument( + "--topic", + action="append", + dest="topics", + help="Only calculate one topic. Repeat for multiple topics.", + ) + parser.add_argument( + "--force", + action="store_true", + help="Reprocess bills even when bill.score_processed_at is already set.", + ) + parser.add_argument("--limit", type=int, help="Maximum number of bills to process.") + args = parser.parse_args() + + engine = get_postgres_engine(name="DATA_SCIENCE_DEV") + with Session(engine) as session: + result = calculate_and_store_legislator_scores( + session, + congress=args.congress, + bill_ids=args.bill_ids, + topics=args.topics, + force=args.force, + limit=args.limit, + ) + print( + "Processed " + f"{result.processed_bills} bills; stored {result.bill_score_rows} bill score rows; " + f"refreshed {result.aggregate_score_rows} aggregate score rows." + ) + + +if __name__ == "__main__": + main()