Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 36718bbce0 | |||
| f33a5c2233 | |||
| 2facb82bd4 | |||
| 8d5a6e202b | |||
| f32c895561 | |||
| 09f7f0187f | |||
| 3056c19f69 | |||
| 88ec8015ba | |||
| 3f397f9bee | |||
| 242e5123ac |
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,60 @@
|
|||||||
|
"""adding FailedIngestion.
|
||||||
|
|
||||||
|
Revision ID: 2f43120e3ffc
|
||||||
|
Revises: f99be864fe69
|
||||||
|
Create Date: 2026-03-24 23:46:17.277897
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "2f43120e3ffc"
|
||||||
|
down_revision: str | None = "f99be864fe69"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"failed_ingestion",
|
||||||
|
sa.Column("raw_line", sa.Text(), nullable=False),
|
||||||
|
sa.Column("error", sa.Text(), nullable=False),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_failed_ingestion")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("failed_ingestion", schema=schema)
|
||||||
|
# ### end Alembic commands ###
|
||||||
File diff suppressed because it is too large
Load Diff
+1391
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,79 @@
|
|||||||
|
"""Attach all partition tables to the posts parent table.
|
||||||
|
|
||||||
|
Alembic autogenerate creates partition tables as standalone tables but does not
|
||||||
|
emit the ALTER TABLE ... ATTACH PARTITION statements needed for PostgreSQL to
|
||||||
|
route inserts to the correct partition.
|
||||||
|
|
||||||
|
Revision ID: a1b2c3d4e5f6
|
||||||
|
Revises: 605b1794838f
|
||||||
|
Create Date: 2026-03-25 10:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
from pipelines.orm.data_science_dev.posts.partitions import (
|
||||||
|
PARTITION_END_YEAR,
|
||||||
|
PARTITION_START_YEAR,
|
||||||
|
iso_weeks_in_year,
|
||||||
|
week_bounds,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "a1b2c3d4e5f6"
|
||||||
|
down_revision: str | None = "605b1794838f"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
ALREADY_ATTACHED_QUERY = text("""
|
||||||
|
SELECT inhrelid::regclass::text
|
||||||
|
FROM pg_inherits
|
||||||
|
WHERE inhparent = :parent::regclass
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Attach all weekly partition tables to the posts parent table."""
|
||||||
|
connection = op.get_bind()
|
||||||
|
already_attached = {
|
||||||
|
row[0]
|
||||||
|
for row in connection.execute(
|
||||||
|
ALREADY_ATTACHED_QUERY, {"parent": f"{schema}.posts"}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
|
||||||
|
for week in range(1, iso_weeks_in_year(year) + 1):
|
||||||
|
table_name = f"posts_{year}_{week:02d}"
|
||||||
|
qualified_name = f"{schema}.{table_name}"
|
||||||
|
if qualified_name in already_attached:
|
||||||
|
continue
|
||||||
|
start, end = week_bounds(year, week)
|
||||||
|
start_str = start.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
end_str = end.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
op.execute(
|
||||||
|
f"ALTER TABLE {schema}.posts "
|
||||||
|
f"ATTACH PARTITION {qualified_name} "
|
||||||
|
f"FOR VALUES FROM ('{start_str}') TO ('{end_str}')"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Detach all weekly partition tables from the posts parent table."""
|
||||||
|
for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
|
||||||
|
for week in range(1, iso_weeks_in_year(year) + 1):
|
||||||
|
table_name = f"posts_{year}_{week:02d}"
|
||||||
|
op.execute(
|
||||||
|
f"ALTER TABLE {schema}.posts DETACH PARTITION {schema}.{table_name}"
|
||||||
|
)
|
||||||
@@ -0,0 +1,229 @@
|
|||||||
|
"""adding congress data.
|
||||||
|
|
||||||
|
Revision ID: 83bfc8af92d8
|
||||||
|
Revises: a1b2c3d4e5f6
|
||||||
|
Create Date: 2026-03-27 10:43:02.324510
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "83bfc8af92d8"
|
||||||
|
down_revision: str | None = "a1b2c3d4e5f6"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"bill",
|
||||||
|
sa.Column("congress", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("bill_type", sa.String(), nullable=False),
|
||||||
|
sa.Column("number", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("title", sa.String(), nullable=True),
|
||||||
|
sa.Column("title_short", sa.String(), nullable=True),
|
||||||
|
sa.Column("official_title", sa.String(), nullable=True),
|
||||||
|
sa.Column("status", sa.String(), nullable=True),
|
||||||
|
sa.Column("status_at", sa.Date(), nullable=True),
|
||||||
|
sa.Column("sponsor_bioguide_id", sa.String(), nullable=True),
|
||||||
|
sa.Column("subjects_top_term", sa.String(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"congress", "bill_type", "number", name="uq_bill_congress_type_number"
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_bill_congress", "bill", ["congress"], unique=False, schema=schema
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"legislator",
|
||||||
|
sa.Column("bioguide_id", sa.Text(), nullable=False),
|
||||||
|
sa.Column("thomas_id", sa.String(), nullable=True),
|
||||||
|
sa.Column("lis_id", sa.String(), nullable=True),
|
||||||
|
sa.Column("govtrack_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("opensecrets_id", sa.String(), nullable=True),
|
||||||
|
sa.Column("fec_ids", sa.String(), nullable=True),
|
||||||
|
sa.Column("first_name", sa.String(), nullable=False),
|
||||||
|
sa.Column("last_name", sa.String(), nullable=False),
|
||||||
|
sa.Column("official_full_name", sa.String(), nullable=True),
|
||||||
|
sa.Column("nickname", sa.String(), nullable=True),
|
||||||
|
sa.Column("birthday", sa.Date(), nullable=True),
|
||||||
|
sa.Column("gender", sa.String(), nullable=True),
|
||||||
|
sa.Column("current_party", sa.String(), nullable=True),
|
||||||
|
sa.Column("current_state", sa.String(), nullable=True),
|
||||||
|
sa.Column("current_district", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("current_chamber", sa.String(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_bioguide_id"),
|
||||||
|
"legislator",
|
||||||
|
["bioguide_id"],
|
||||||
|
unique=True,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"bill_text",
|
||||||
|
sa.Column("bill_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("version_code", sa.String(), nullable=False),
|
||||||
|
sa.Column("version_name", sa.String(), nullable=True),
|
||||||
|
sa.Column("text_content", sa.String(), nullable=True),
|
||||||
|
sa.Column("date", sa.Date(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_id"],
|
||||||
|
[f"{schema}.bill.id"],
|
||||||
|
name=op.f("fk_bill_text_bill_id_bill"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill_text")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"bill_id", "version_code", name="uq_bill_text_bill_id_version_code"
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"vote",
|
||||||
|
sa.Column("congress", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("chamber", sa.String(), nullable=False),
|
||||||
|
sa.Column("session", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("number", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("vote_type", sa.String(), nullable=True),
|
||||||
|
sa.Column("question", sa.String(), nullable=True),
|
||||||
|
sa.Column("result", sa.String(), nullable=True),
|
||||||
|
sa.Column("result_text", sa.String(), nullable=True),
|
||||||
|
sa.Column("vote_date", sa.Date(), nullable=False),
|
||||||
|
sa.Column("yea_count", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("nay_count", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("not_voting_count", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("present_count", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("bill_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_id"], [f"{schema}.bill.id"], name=op.f("fk_vote_bill_id_bill")
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_vote")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"congress",
|
||||||
|
"chamber",
|
||||||
|
"session",
|
||||||
|
"number",
|
||||||
|
name="uq_vote_congress_chamber_session_number",
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_vote_congress_chamber",
|
||||||
|
"vote",
|
||||||
|
["congress", "chamber"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index("ix_vote_date", "vote", ["vote_date"], unique=False, schema=schema)
|
||||||
|
op.create_table(
|
||||||
|
"vote_record",
|
||||||
|
sa.Column("vote_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("legislator_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("position", sa.String(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["legislator_id"],
|
||||||
|
[f"{schema}.legislator.id"],
|
||||||
|
name=op.f("fk_vote_record_legislator_id_legislator"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["vote_id"],
|
||||||
|
[f"{schema}.vote.id"],
|
||||||
|
name=op.f("fk_vote_record_vote_id_vote"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint(
|
||||||
|
"vote_id", "legislator_id", name=op.f("pk_vote_record")
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("vote_record", schema=schema)
|
||||||
|
op.drop_index("ix_vote_date", table_name="vote", schema=schema)
|
||||||
|
op.drop_index("ix_vote_congress_chamber", table_name="vote", schema=schema)
|
||||||
|
op.drop_table("vote", schema=schema)
|
||||||
|
op.drop_table("bill_text", schema=schema)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_bioguide_id"), table_name="legislator", schema=schema
|
||||||
|
)
|
||||||
|
op.drop_table("legislator", schema=schema)
|
||||||
|
op.drop_index("ix_bill_congress", table_name="bill", schema=schema)
|
||||||
|
op.drop_table("bill", schema=schema)
|
||||||
|
# ### end Alembic commands ###
|
||||||
+68
@@ -0,0 +1,68 @@
|
|||||||
|
"""adding LegislatorSocialMedia.
|
||||||
|
|
||||||
|
Revision ID: 5cd7eee3549d
|
||||||
|
Revises: 83bfc8af92d8
|
||||||
|
Create Date: 2026-03-29 11:53:44.224799
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "5cd7eee3549d"
|
||||||
|
down_revision: str | None = "83bfc8af92d8"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"legislator_social_media",
|
||||||
|
sa.Column("legislator_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("platform", sa.String(), nullable=False),
|
||||||
|
sa.Column("account_name", sa.String(), nullable=False),
|
||||||
|
sa.Column("url", sa.String(), nullable=True),
|
||||||
|
sa.Column("source", sa.String(), nullable=False),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["legislator_id"],
|
||||||
|
[f"{schema}.legislator.id"],
|
||||||
|
name=op.f("fk_legislator_social_media_legislator_id_legislator"),
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator_social_media")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("legislator_social_media", schema=schema)
|
||||||
|
# ### end Alembic commands ###
|
||||||
+245
@@ -0,0 +1,245 @@
|
|||||||
|
"""adding LegislatorScore and BillTopic.
|
||||||
|
|
||||||
|
Revision ID: ef4bc5411176
|
||||||
|
Revises: 5cd7eee3549d
|
||||||
|
Create Date: 2026-04-21 11:35:18.977213
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "ef4bc5411176"
|
||||||
|
down_revision: str | None = "5cd7eee3549d"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"bill_topic",
|
||||||
|
sa.Column("bill_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("topic", sa.String(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"support_position",
|
||||||
|
sa.Enum("for", "against", name="bill_topic_position", native_enum=False),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_id"],
|
||||||
|
[f"{schema}.bill.id"],
|
||||||
|
name=op.f("fk_bill_topic_bill_id_bill"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill_topic")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"bill_id",
|
||||||
|
"topic",
|
||||||
|
"support_position",
|
||||||
|
name="uq_bill_topic_bill_id_topic_support_position",
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_bill_topic_topic", "bill_topic", ["topic"], unique=False, schema=schema
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"legislator_score",
|
||||||
|
sa.Column("legislator_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("year", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("topic", sa.String(), nullable=False),
|
||||||
|
sa.Column("score", sa.Float(), nullable=False),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["legislator_id"],
|
||||||
|
[f"{schema}.legislator.id"],
|
||||||
|
name=op.f("fk_legislator_score_legislator_id_legislator"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator_score")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"legislator_id",
|
||||||
|
"year",
|
||||||
|
"topic",
|
||||||
|
name="uq_legislator_score_legislator_id_year_topic",
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_score_legislator_id"),
|
||||||
|
"legislator_score",
|
||||||
|
["legislator_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_legislator_score_year_topic",
|
||||||
|
"legislator_score",
|
||||||
|
["year", "topic"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"legislator_bill_score",
|
||||||
|
sa.Column("bill_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("bill_topic_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("legislator_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("year", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("topic", sa.String(), nullable=False),
|
||||||
|
sa.Column("score", sa.Float(), nullable=False),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_id"],
|
||||||
|
[f"{schema}.bill.id"],
|
||||||
|
name=op.f("fk_legislator_bill_score_bill_id_bill"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_topic_id"],
|
||||||
|
[f"{schema}.bill_topic.id"],
|
||||||
|
name=op.f("fk_legislator_bill_score_bill_topic_id_bill_topic"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["legislator_id"],
|
||||||
|
[f"{schema}.legislator.id"],
|
||||||
|
name=op.f("fk_legislator_bill_score_legislator_id_legislator"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator_bill_score")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"bill_topic_id",
|
||||||
|
"legislator_id",
|
||||||
|
"year",
|
||||||
|
name="uq_legislator_bill_score_bill_topic_id_legislator_id_year",
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_bill_score_bill_id"),
|
||||||
|
"legislator_bill_score",
|
||||||
|
["bill_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_bill_score_bill_topic_id"),
|
||||||
|
"legislator_bill_score",
|
||||||
|
["bill_topic_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_bill_score_legislator_id"),
|
||||||
|
"legislator_bill_score",
|
||||||
|
["legislator_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_legislator_bill_score_year_topic",
|
||||||
|
"legislator_bill_score",
|
||||||
|
["year", "topic"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"bill",
|
||||||
|
sa.Column("score_processed_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"bill_text", sa.Column("summary", sa.String(), nullable=True), schema=schema
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column("bill_text", "summary", schema=schema)
|
||||||
|
op.drop_column("bill", "score_processed_at", schema=schema)
|
||||||
|
op.drop_index(
|
||||||
|
"ix_legislator_bill_score_year_topic",
|
||||||
|
table_name="legislator_bill_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_bill_score_legislator_id"),
|
||||||
|
table_name="legislator_bill_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_bill_score_bill_topic_id"),
|
||||||
|
table_name="legislator_bill_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_bill_score_bill_id"),
|
||||||
|
table_name="legislator_bill_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_table("legislator_bill_score", schema=schema)
|
||||||
|
op.drop_index(
|
||||||
|
"ix_legislator_score_year_topic", table_name="legislator_score", schema=schema
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_score_legislator_id"),
|
||||||
|
table_name="legislator_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_table("legislator_score", schema=schema)
|
||||||
|
op.drop_index("ix_bill_topic_topic", table_name="bill_topic", schema=schema)
|
||||||
|
op.drop_table("bill_topic", schema=schema)
|
||||||
|
# ### end Alembic commands ###
|
||||||
+146
@@ -0,0 +1,146 @@
|
|||||||
|
"""removed LegislatorBillScore.
|
||||||
|
|
||||||
|
Revision ID: b63ed11d6775
|
||||||
|
Revises: 7d15f9b7c8a2
|
||||||
|
Create Date: 2026-04-21 22:46:48.058542
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "b63ed11d6775"
|
||||||
|
down_revision: str | None = "7d15f9b7c8a2"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_bill_score_bill_id"),
|
||||||
|
table_name="legislator_bill_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_bill_score_bill_topic_id"),
|
||||||
|
table_name="legislator_bill_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_bill_score_legislator_id"),
|
||||||
|
table_name="legislator_bill_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_bill_score_year_topic"),
|
||||||
|
table_name="legislator_bill_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_table("legislator_bill_score", schema=schema)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"legislator_bill_score",
|
||||||
|
sa.Column("bill_id", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column("bill_topic_id", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column("legislator_id", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column("year", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column("topic", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"score",
|
||||||
|
sa.DOUBLE_PRECISION(precision=53),
|
||||||
|
autoincrement=False,
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
postgresql.TIMESTAMP(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
autoincrement=False,
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
postgresql.TIMESTAMP(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
autoincrement=False,
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_id"],
|
||||||
|
[f"{schema}.bill.id"],
|
||||||
|
name=op.f("fk_legislator_bill_score_bill_id_bill"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_topic_id"],
|
||||||
|
[f"{schema}.bill_topic.id"],
|
||||||
|
name=op.f("fk_legislator_bill_score_bill_topic_id_bill_topic"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["legislator_id"],
|
||||||
|
[f"{schema}.legislator.id"],
|
||||||
|
name=op.f("fk_legislator_bill_score_legislator_id_legislator"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator_bill_score")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"bill_topic_id",
|
||||||
|
"legislator_id",
|
||||||
|
"year",
|
||||||
|
name=op.f("uq_legislator_bill_score_bill_topic_id_legislator_id_year"),
|
||||||
|
postgresql_include=[],
|
||||||
|
postgresql_nulls_not_distinct=False,
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_bill_score_year_topic"),
|
||||||
|
"legislator_bill_score",
|
||||||
|
["year", "topic"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_bill_score_legislator_id"),
|
||||||
|
"legislator_bill_score",
|
||||||
|
["legislator_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_bill_score_bill_topic_id"),
|
||||||
|
"legislator_bill_score",
|
||||||
|
["bill_topic_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_bill_score_bill_id"),
|
||||||
|
"legislator_bill_score",
|
||||||
|
["bill_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
+54
@@ -0,0 +1,54 @@
|
|||||||
|
"""add bill_text summarization metadata.
|
||||||
|
|
||||||
|
Revision ID: 7d15f9b7c8a2
|
||||||
|
Revises: ef4bc5411176
|
||||||
|
Create Date: 2026-04-22 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "7d15f9b7c8a2"
|
||||||
|
down_revision: str | None = "ef4bc5411176"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
op.add_column(
|
||||||
|
"bill_text",
|
||||||
|
sa.Column("summarization_model", sa.String(), nullable=True),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"bill_text",
|
||||||
|
sa.Column("summarization_user_prompt_version", sa.String(), nullable=True),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"bill_text",
|
||||||
|
sa.Column("summarization_system_prompt_version", sa.String(), nullable=True),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
op.drop_column(
|
||||||
|
"bill_text", "summarization_system_prompt_version", schema=schema
|
||||||
|
)
|
||||||
|
op.drop_column("bill_text", "summarization_user_prompt_version", schema=schema)
|
||||||
|
op.drop_column("bill_text", "summarization_model", schema=schema)
|
||||||
+98
@@ -0,0 +1,98 @@
|
|||||||
|
"""adding LegislatorScoreFake.
|
||||||
|
|
||||||
|
Revision ID: 06f833813bd7
|
||||||
|
Revises: b63ed11d6775
|
||||||
|
Create Date: 2026-04-22 18:41:07.484609
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "06f833813bd7"
|
||||||
|
down_revision: str | None = "b63ed11d6775"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"legislator_score_fake",
|
||||||
|
sa.Column("legislator_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("year", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("topic", sa.String(), nullable=False),
|
||||||
|
sa.Column("score", sa.Float(), nullable=False),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["legislator_id"],
|
||||||
|
[f"{schema}.legislator.id"],
|
||||||
|
name=op.f("fk_legislator_score_fake_legislator_id_legislator"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator_score_fake")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"legislator_id",
|
||||||
|
"year",
|
||||||
|
"topic",
|
||||||
|
name="uq_legislator_score_fake_legislator_id_year_topic",
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_score_fake_legislator_id"),
|
||||||
|
"legislator_score_fake",
|
||||||
|
["legislator_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_legislator_score_fake_year_topic",
|
||||||
|
"legislator_score_fake",
|
||||||
|
["year", "topic"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index(
|
||||||
|
"ix_legislator_score_fake_year_topic",
|
||||||
|
table_name="legislator_score_fake",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_score_fake_legislator_id"),
|
||||||
|
table_name="legislator_score_fake",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_table("legislator_score_fake", schema=schema)
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
"""add vote.bill_text_id linkage.
|
||||||
|
|
||||||
|
Revision ID: 9c7d4a2e1b10
|
||||||
|
Revises: 06f833813bd7
|
||||||
|
Create Date: 2026-04-23 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "9c7d4a2e1b10"
|
||||||
|
down_revision: str | None = "06f833813bd7"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
op.add_column(
|
||||||
|
"vote",
|
||||||
|
sa.Column("bill_text_id", sa.Integer(), nullable=True),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_vote_bill_text_id",
|
||||||
|
"vote",
|
||||||
|
["bill_text_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_foreign_key(
|
||||||
|
"fk_vote_bill_text_id_bill_text",
|
||||||
|
"vote",
|
||||||
|
"bill_text",
|
||||||
|
["bill_text_id"],
|
||||||
|
["id"],
|
||||||
|
source_schema=schema,
|
||||||
|
referent_schema=schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
op.drop_constraint(
|
||||||
|
"fk_vote_bill_text_id_bill_text",
|
||||||
|
"vote",
|
||||||
|
schema=schema,
|
||||||
|
type_="foreignkey",
|
||||||
|
)
|
||||||
|
op.drop_index("ix_vote_bill_text_id", table_name="vote", schema=schema)
|
||||||
|
op.drop_column("vote", "bill_text_id", schema=schema)
|
||||||
+844
@@ -0,0 +1,844 @@
|
|||||||
|
"""canonical vote context v3.
|
||||||
|
|
||||||
|
Revision ID: 1f8c0e7a9d21
|
||||||
|
Revises: 9c7d4a2e1b10
|
||||||
|
Create Date: 2026-04-25 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
|
revision: str = "1f8c0e7a9d21"
|
||||||
|
down_revision: str | None = "9c7d4a2e1b10"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
op.create_table(
|
||||||
|
"ingest_run",
|
||||||
|
sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("git_sha", sa.String(), nullable=True),
|
||||||
|
sa.Column("classifier_version", sa.String(), nullable=True),
|
||||||
|
sa.Column("source_snapshot_label", sa.String(), nullable=True),
|
||||||
|
sa.Column("status", sa.String(), nullable=False),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_ingest_run")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"source_artifact",
|
||||||
|
sa.Column("source_kind", sa.String(), nullable=False),
|
||||||
|
sa.Column("congress", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("chamber", sa.String(), nullable=True),
|
||||||
|
sa.Column("local_path", sa.String(), nullable=False),
|
||||||
|
sa.Column("source_url", sa.String(), nullable=True),
|
||||||
|
sa.Column("sha256", sa.String(), nullable=False),
|
||||||
|
sa.Column("byte_size", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("modified_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("ingested_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("ingest_run_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["ingest_run_id"],
|
||||||
|
[f"{schema}.ingest_run.id"],
|
||||||
|
name=op.f("fk_source_artifact_ingest_run_id_ingest_run"),
|
||||||
|
ondelete="SET NULL",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_source_artifact")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_source_artifact_source_kind",
|
||||||
|
"source_artifact",
|
||||||
|
["source_kind"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_source_artifact_congress",
|
||||||
|
"source_artifact",
|
||||||
|
["congress"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"score_run",
|
||||||
|
sa.Column("ingest_run_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("classifier_version", sa.String(), nullable=True),
|
||||||
|
sa.Column("scoring_version", sa.String(), nullable=True),
|
||||||
|
sa.Column("included_vote_count", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("excluded_vote_count", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["ingest_run_id"],
|
||||||
|
[f"{schema}.ingest_run.id"],
|
||||||
|
name=op.f("fk_score_run_ingest_run_id_ingest_run"),
|
||||||
|
ondelete="SET NULL",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_score_run")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"legislator_score",
|
||||||
|
sa.Column("score_run_id", sa.Integer(), nullable=True),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_score_score_run_id"),
|
||||||
|
"legislator_score",
|
||||||
|
["score_run_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_foreign_key(
|
||||||
|
op.f("fk_legislator_score_score_run_id_score_run"),
|
||||||
|
"legislator_score",
|
||||||
|
"score_run",
|
||||||
|
["score_run_id"],
|
||||||
|
["id"],
|
||||||
|
source_schema=schema,
|
||||||
|
referent_schema=schema,
|
||||||
|
ondelete="CASCADE",
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"bill_text",
|
||||||
|
sa.Column("source_datetime_raw", sa.String(), nullable=True),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"bill_text", sa.Column("text_url_xml", sa.String(), nullable=True), schema=schema
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"bill_text", sa.Column("text_url_pdf", sa.String(), nullable=True), schema=schema
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"bill_text",
|
||||||
|
sa.Column("text_url_html", sa.String(), nullable=True),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"bill_text",
|
||||||
|
sa.Column("source_artifact_id", sa.Integer(), nullable=True),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_foreign_key(
|
||||||
|
op.f("fk_bill_text_source_artifact_id_source_artifact"),
|
||||||
|
"bill_text",
|
||||||
|
"source_artifact",
|
||||||
|
["source_artifact_id"],
|
||||||
|
["id"],
|
||||||
|
source_schema=schema,
|
||||||
|
referent_schema=schema,
|
||||||
|
ondelete="SET NULL",
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"bill_action",
|
||||||
|
sa.Column("bill_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("sequence", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("action_date", sa.Date(), nullable=False),
|
||||||
|
sa.Column("action_time", sa.String(), nullable=True),
|
||||||
|
sa.Column("action_text", sa.String(), nullable=False),
|
||||||
|
sa.Column("action_type", sa.String(), nullable=True),
|
||||||
|
sa.Column("action_code", sa.String(), nullable=True),
|
||||||
|
sa.Column("source_system_code", sa.String(), nullable=True),
|
||||||
|
sa.Column("source_system_name", sa.String(), nullable=True),
|
||||||
|
sa.Column("source_artifact_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_id"],
|
||||||
|
[f"{schema}.bill.id"],
|
||||||
|
name=op.f("fk_bill_action_bill_id_bill"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["source_artifact_id"],
|
||||||
|
[f"{schema}.source_artifact.id"],
|
||||||
|
name=op.f("fk_bill_action_source_artifact_id_source_artifact"),
|
||||||
|
ondelete="SET NULL",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill_action")),
|
||||||
|
sa.UniqueConstraint("bill_id", "sequence", name="uq_bill_action_bill_id_sequence"),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"bill_action_recorded_vote",
|
||||||
|
sa.Column("bill_action_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("congress", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("chamber", sa.String(), nullable=False),
|
||||||
|
sa.Column("session_number", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("roll_number", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("vote_datetime", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("vote_url", sa.String(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_action_id"],
|
||||||
|
[f"{schema}.bill_action.id"],
|
||||||
|
name=op.f("fk_bill_action_recorded_vote_bill_action_id_bill_action"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill_action_recorded_vote")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"bill_action_id",
|
||||||
|
"congress",
|
||||||
|
"chamber",
|
||||||
|
"session_number",
|
||||||
|
"roll_number",
|
||||||
|
name="uq_bill_action_recorded_vote_match_key",
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"bill_relation",
|
||||||
|
sa.Column("bill_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("related_bill_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("relationship_type", sa.String(), nullable=False),
|
||||||
|
sa.Column("identified_by", sa.String(), nullable=True),
|
||||||
|
sa.Column("latest_action_date", sa.Date(), nullable=True),
|
||||||
|
sa.Column("latest_action_text", sa.String(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_id"],
|
||||||
|
[f"{schema}.bill.id"],
|
||||||
|
name=op.f("fk_bill_relation_bill_id_bill"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["related_bill_id"],
|
||||||
|
[f"{schema}.bill.id"],
|
||||||
|
name=op.f("fk_bill_relation_related_bill_id_bill"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill_relation")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_bill_relation_bill_id",
|
||||||
|
"bill_relation",
|
||||||
|
["bill_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_bill_relation_related_bill_id",
|
||||||
|
"bill_relation",
|
||||||
|
["related_bill_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"amendment",
|
||||||
|
sa.Column("congress", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("amendment_type", sa.String(), nullable=False),
|
||||||
|
sa.Column("number", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("chamber", sa.String(), nullable=False),
|
||||||
|
sa.Column("description", sa.String(), nullable=True),
|
||||||
|
sa.Column("purpose", sa.String(), nullable=True),
|
||||||
|
sa.Column("amended_bill_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("amended_amendment_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("source_path", sa.String(), nullable=True),
|
||||||
|
sa.Column("source_artifact_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["amended_amendment_id"],
|
||||||
|
[f"{schema}.amendment.id"],
|
||||||
|
name=op.f("fk_amendment_amended_amendment_id_amendment"),
|
||||||
|
ondelete="SET NULL",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["amended_bill_id"],
|
||||||
|
[f"{schema}.bill.id"],
|
||||||
|
name=op.f("fk_amendment_amended_bill_id_bill"),
|
||||||
|
ondelete="SET NULL",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["source_artifact_id"],
|
||||||
|
[f"{schema}.source_artifact.id"],
|
||||||
|
name=op.f("fk_amendment_source_artifact_id_source_artifact"),
|
||||||
|
ondelete="SET NULL",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_amendment")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"congress",
|
||||||
|
"amendment_type",
|
||||||
|
"number",
|
||||||
|
name="uq_amendment_congress_type_number",
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"amendment_action",
|
||||||
|
sa.Column("amendment_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("sequence", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("action_date", sa.Date(), nullable=False),
|
||||||
|
sa.Column("action_time", sa.String(), nullable=True),
|
||||||
|
sa.Column("action_text", sa.String(), nullable=False),
|
||||||
|
sa.Column("action_type", sa.String(), nullable=True),
|
||||||
|
sa.Column("action_code", sa.String(), nullable=True),
|
||||||
|
sa.Column("source_system_code", sa.String(), nullable=True),
|
||||||
|
sa.Column("source_system_name", sa.String(), nullable=True),
|
||||||
|
sa.Column("source_artifact_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["amendment_id"],
|
||||||
|
[f"{schema}.amendment.id"],
|
||||||
|
name=op.f("fk_amendment_action_amendment_id_amendment"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["source_artifact_id"],
|
||||||
|
[f"{schema}.source_artifact.id"],
|
||||||
|
name=op.f("fk_amendment_action_source_artifact_id_source_artifact"),
|
||||||
|
ondelete="SET NULL",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_amendment_action")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"amendment_id",
|
||||||
|
"sequence",
|
||||||
|
name="uq_amendment_action_amendment_id_sequence",
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"amendment_action_recorded_vote",
|
||||||
|
sa.Column("amendment_action_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("congress", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("chamber", sa.String(), nullable=False),
|
||||||
|
sa.Column("session_number", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("roll_number", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("vote_datetime", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("vote_url", sa.String(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["amendment_action_id"],
|
||||||
|
[f"{schema}.amendment_action.id"],
|
||||||
|
name=op.f(
|
||||||
|
"fk_amendment_action_recorded_vote_amendment_action_id_amendment_action"
|
||||||
|
),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_amendment_action_recorded_vote")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"amendment_action_id",
|
||||||
|
"congress",
|
||||||
|
"chamber",
|
||||||
|
"session_number",
|
||||||
|
"roll_number",
|
||||||
|
name="uq_amendment_action_recorded_vote_match_key",
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
op.drop_constraint(
|
||||||
|
"uq_vote_congress_chamber_session_number",
|
||||||
|
"vote",
|
||||||
|
schema=schema,
|
||||||
|
type_="unique",
|
||||||
|
)
|
||||||
|
op.alter_column("vote", "session", new_column_name="session_year", schema=schema)
|
||||||
|
op.alter_column("vote", "number", new_column_name="roll_number", schema=schema)
|
||||||
|
op.add_column("vote", sa.Column("session_number", sa.Integer(), nullable=True), schema=schema)
|
||||||
|
op.add_column(
|
||||||
|
"vote",
|
||||||
|
sa.Column("vote_datetime", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"vote", sa.Column("raw_vote_source_url", sa.String(), nullable=True), schema=schema
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"vote",
|
||||||
|
sa.Column("raw_bill_ref", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"vote",
|
||||||
|
sa.Column(
|
||||||
|
"raw_amendment_ref",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"vote",
|
||||||
|
sa.Column(
|
||||||
|
"raw_nomination_ref",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"vote",
|
||||||
|
sa.Column(
|
||||||
|
"raw_treaty_ref",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"vote",
|
||||||
|
sa.Column("raw_vote_source_artifact_id", sa.Integer(), nullable=True),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_foreign_key(
|
||||||
|
op.f("fk_vote_raw_vote_source_artifact_id_source_artifact"),
|
||||||
|
"vote",
|
||||||
|
"source_artifact",
|
||||||
|
["raw_vote_source_artifact_id"],
|
||||||
|
["id"],
|
||||||
|
source_schema=schema,
|
||||||
|
referent_schema=schema,
|
||||||
|
ondelete="SET NULL",
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
sa.text(
|
||||||
|
f"""
|
||||||
|
UPDATE {schema}.vote
|
||||||
|
SET session_number = session_year - (((congress - 1) * 2) + 1789) + 1
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
op.alter_column("vote", "session_number", nullable=False, schema=schema)
|
||||||
|
op.create_unique_constraint(
|
||||||
|
"uq_vote_congress_chamber_session_number_roll_number",
|
||||||
|
"vote",
|
||||||
|
["congress", "chamber", "session_number", "roll_number"],
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_constraint(
|
||||||
|
op.f("fk_vote_bill_id_bill"),
|
||||||
|
"vote",
|
||||||
|
schema=schema,
|
||||||
|
type_="foreignkey",
|
||||||
|
)
|
||||||
|
op.drop_constraint(
|
||||||
|
"fk_vote_bill_text_id_bill_text",
|
||||||
|
"vote",
|
||||||
|
schema=schema,
|
||||||
|
type_="foreignkey",
|
||||||
|
)
|
||||||
|
op.drop_index("ix_vote_bill_text_id", table_name="vote", schema=schema)
|
||||||
|
op.drop_column("vote", "bill_id", schema=schema)
|
||||||
|
op.drop_column("vote", "bill_text_id", schema=schema)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"vote_action_match",
|
||||||
|
sa.Column("vote_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("action_scope", sa.String(), nullable=False),
|
||||||
|
sa.Column("bill_action_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("amendment_action_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("is_selected", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("match_method", sa.String(), nullable=False),
|
||||||
|
sa.Column("match_reason", sa.String(), nullable=True),
|
||||||
|
sa.Column("match_confidence", sa.String(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["amendment_action_id"],
|
||||||
|
[f"{schema}.amendment_action.id"],
|
||||||
|
name=op.f("fk_vote_action_match_amendment_action_id_amendment_action"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_action_id"],
|
||||||
|
[f"{schema}.bill_action.id"],
|
||||||
|
name=op.f("fk_vote_action_match_bill_action_id_bill_action"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["vote_id"],
|
||||||
|
[f"{schema}.vote.id"],
|
||||||
|
name=op.f("fk_vote_action_match_vote_id_vote"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_vote_action_match")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_vote_action_match_vote_id",
|
||||||
|
"vote_action_match",
|
||||||
|
["vote_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"uq_vote_action_match_selected_vote_id",
|
||||||
|
"vote_action_match",
|
||||||
|
["vote_id"],
|
||||||
|
unique=True,
|
||||||
|
schema=schema,
|
||||||
|
postgresql_where=sa.text("is_selected"),
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"vote_classification",
|
||||||
|
sa.Column("vote_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("subject_type", sa.String(), nullable=False),
|
||||||
|
sa.Column("measure_type", sa.String(), nullable=True),
|
||||||
|
sa.Column("measure_subtype", sa.String(), nullable=True),
|
||||||
|
sa.Column("measure_function", sa.String(), nullable=True),
|
||||||
|
sa.Column("vote_relationship", sa.String(), nullable=False),
|
||||||
|
sa.Column("is_legislation_related", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("is_direct_vote_on_legislative_text", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("is_substantive_policy_vote", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("is_lawmaking_vehicle", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("is_special_rule", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("classification_method", sa.String(), nullable=False),
|
||||||
|
sa.Column("classification_confidence_reason", sa.String(), nullable=True),
|
||||||
|
sa.Column("confidence", sa.String(), nullable=False),
|
||||||
|
sa.Column("classified_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("classification_version", sa.String(), nullable=False),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["vote_id"],
|
||||||
|
[f"{schema}.vote.id"],
|
||||||
|
name=op.f("fk_vote_classification_vote_id_vote"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_vote_classification")),
|
||||||
|
sa.UniqueConstraint("vote_id", name=op.f("uq_vote_classification_vote_id")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_vote_classification_subject_type",
|
||||||
|
"vote_classification",
|
||||||
|
["subject_type"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"vote_measure_link",
|
||||||
|
sa.Column("vote_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("measure_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("role", sa.String(), nullable=False),
|
||||||
|
sa.Column("source", sa.String(), nullable=False),
|
||||||
|
sa.Column("confidence", sa.String(), nullable=False),
|
||||||
|
sa.Column("notes", sa.String(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["measure_id"],
|
||||||
|
[f"{schema}.bill.id"],
|
||||||
|
name=op.f("fk_vote_measure_link_measure_id_bill"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_vote_measure_link")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_vote_measure_link_vote_id",
|
||||||
|
"vote_measure_link",
|
||||||
|
["vote_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_foreign_key(
|
||||||
|
op.f("fk_vote_measure_link_vote_id_vote"),
|
||||||
|
"vote_measure_link",
|
||||||
|
"vote",
|
||||||
|
["vote_id"],
|
||||||
|
["id"],
|
||||||
|
source_schema=schema,
|
||||||
|
referent_schema=schema,
|
||||||
|
ondelete="CASCADE",
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"vote_text_target",
|
||||||
|
sa.Column("vote_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("text_target_type", sa.String(), nullable=False),
|
||||||
|
sa.Column("voted_text_version_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("resulting_text_version_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("related_amendment_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("text_target_basis", sa.String(), nullable=False),
|
||||||
|
sa.Column("text_resolution_method", sa.String(), nullable=False),
|
||||||
|
sa.Column("text_resolution_confidence_reason", sa.String(), nullable=True),
|
||||||
|
sa.Column("confidence", sa.String(), nullable=False),
|
||||||
|
sa.Column("notes", sa.String(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["related_amendment_id"],
|
||||||
|
[f"{schema}.amendment.id"],
|
||||||
|
name=op.f("fk_vote_text_target_related_amendment_id_amendment"),
|
||||||
|
ondelete="SET NULL",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["resulting_text_version_id"],
|
||||||
|
[f"{schema}.bill_text.id"],
|
||||||
|
name=op.f("fk_vote_text_target_resulting_text_version_id_bill_text"),
|
||||||
|
ondelete="SET NULL",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["vote_id"],
|
||||||
|
[f"{schema}.vote.id"],
|
||||||
|
name=op.f("fk_vote_text_target_vote_id_vote"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["voted_text_version_id"],
|
||||||
|
[f"{schema}.bill_text.id"],
|
||||||
|
name=op.f("fk_vote_text_target_voted_text_version_id_bill_text"),
|
||||||
|
ondelete="SET NULL",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_vote_text_target")),
|
||||||
|
sa.UniqueConstraint("vote_id", name=op.f("uq_vote_text_target_vote_id")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"vote_position_meaning",
|
||||||
|
sa.Column("vote_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("yea_effect", sa.String(), nullable=False),
|
||||||
|
sa.Column("nay_effect", sa.String(), nullable=False),
|
||||||
|
sa.Column("present_effect", sa.String(), nullable=False),
|
||||||
|
sa.Column("polarity_confidence", sa.String(), nullable=False),
|
||||||
|
sa.Column("polarity_method", sa.String(), nullable=False),
|
||||||
|
sa.Column("notes", sa.String(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["vote_id"],
|
||||||
|
[f"{schema}.vote.id"],
|
||||||
|
name=op.f("fk_vote_position_meaning_vote_id_vote"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_vote_position_meaning")),
|
||||||
|
sa.UniqueConstraint("vote_id", name=op.f("uq_vote_position_meaning_vote_id")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"vote_context_audit",
|
||||||
|
sa.Column("vote_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("step", sa.String(), nullable=False),
|
||||||
|
sa.Column("message", sa.String(), nullable=False),
|
||||||
|
sa.Column("severity", sa.String(), nullable=False),
|
||||||
|
sa.Column("source_path", sa.String(), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["vote_id"],
|
||||||
|
[f"{schema}.vote.id"],
|
||||||
|
name=op.f("fk_vote_context_audit_vote_id_vote"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_vote_context_audit")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_vote_context_audit_vote_id",
|
||||||
|
"vote_context_audit",
|
||||||
|
["vote_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
raise NotImplementedError("Downgrade is not supported for canonical vote context v3.")
|
||||||
@@ -0,0 +1,203 @@
|
|||||||
|
"""add supporting indexes for congress vote context and scoring.
|
||||||
|
|
||||||
|
Revision ID: a7b91c4e2d30
|
||||||
|
Revises: 1f8c0e7a9d21
|
||||||
|
Create Date: 2026-04-26 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
|
revision: str = "a7b91c4e2d30"
|
||||||
|
down_revision: str | None = "1f8c0e7a9d21"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def _dedupe_source_artifacts() -> None:
|
||||||
|
op.execute(
|
||||||
|
sa.text(
|
||||||
|
f"""
|
||||||
|
CREATE TEMP TABLE tmp_source_artifact_dups AS
|
||||||
|
WITH ranked AS (
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
first_value(id) OVER (
|
||||||
|
PARTITION BY ingest_run_id, local_path, sha256
|
||||||
|
ORDER BY id
|
||||||
|
) AS keep_id,
|
||||||
|
row_number() OVER (
|
||||||
|
PARTITION BY ingest_run_id, local_path, sha256
|
||||||
|
ORDER BY id
|
||||||
|
) AS rn
|
||||||
|
FROM {schema}.source_artifact
|
||||||
|
WHERE ingest_run_id IS NOT NULL
|
||||||
|
)
|
||||||
|
SELECT id, keep_id
|
||||||
|
FROM ranked
|
||||||
|
WHERE rn > 1
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for table_name, column_name in (
|
||||||
|
("bill_text", "source_artifact_id"),
|
||||||
|
("bill_action", "source_artifact_id"),
|
||||||
|
("amendment", "source_artifact_id"),
|
||||||
|
("amendment_action", "source_artifact_id"),
|
||||||
|
("vote", "raw_vote_source_artifact_id"),
|
||||||
|
):
|
||||||
|
op.execute(
|
||||||
|
sa.text(
|
||||||
|
f"""
|
||||||
|
UPDATE {schema}.{table_name} AS target
|
||||||
|
SET {column_name} = d.keep_id
|
||||||
|
FROM tmp_source_artifact_dups AS d
|
||||||
|
WHERE target.{column_name} = d.id
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
sa.text(
|
||||||
|
f"""
|
||||||
|
DELETE FROM {schema}.source_artifact AS source_artifact
|
||||||
|
USING tmp_source_artifact_dups AS d
|
||||||
|
WHERE source_artifact.id = d.id
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
op.execute(sa.text("DROP TABLE tmp_source_artifact_dups"))
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
_dedupe_source_artifacts()
|
||||||
|
|
||||||
|
op.create_index(
|
||||||
|
"uq_source_artifact_ingest_identity",
|
||||||
|
"source_artifact",
|
||||||
|
["ingest_run_id", "local_path", "sha256"],
|
||||||
|
unique=True,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_bill_action_recorded_vote_match_tuple",
|
||||||
|
"bill_action_recorded_vote",
|
||||||
|
["congress", "chamber", "session_number", "roll_number"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_amendment_action_recorded_vote_match_tuple",
|
||||||
|
"amendment_action_recorded_vote",
|
||||||
|
["congress", "chamber", "session_number", "roll_number"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_vote_classification_eligible_vote_id",
|
||||||
|
"vote_classification",
|
||||||
|
["vote_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
postgresql_where=sa.text(
|
||||||
|
"subject_type = 'measure' "
|
||||||
|
"AND vote_relationship = 'direct_text_vote' "
|
||||||
|
"AND is_direct_vote_on_legislative_text "
|
||||||
|
"AND is_substantive_policy_vote "
|
||||||
|
"AND NOT is_special_rule"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_vote_measure_link_vote_id_role",
|
||||||
|
"vote_measure_link",
|
||||||
|
["vote_id", "role"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_vote_measure_link_measure_id_role",
|
||||||
|
"vote_measure_link",
|
||||||
|
["measure_id", "role"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_vote_text_target_voted_text_version_id",
|
||||||
|
"vote_text_target",
|
||||||
|
["voted_text_version_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
postgresql_where=sa.text("voted_text_version_id IS NOT NULL"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_vote_context_audit_severity_vote_id",
|
||||||
|
"vote_context_audit",
|
||||||
|
["severity", "vote_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_legislator_current_chamber",
|
||||||
|
"legislator",
|
||||||
|
["current_chamber"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
op.drop_index("ix_legislator_current_chamber", table_name="legislator", schema=schema)
|
||||||
|
op.drop_index(
|
||||||
|
"ix_vote_context_audit_severity_vote_id",
|
||||||
|
table_name="vote_context_audit",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
"ix_vote_text_target_voted_text_version_id",
|
||||||
|
table_name="vote_text_target",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
"ix_vote_measure_link_measure_id_role",
|
||||||
|
table_name="vote_measure_link",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
"ix_vote_measure_link_vote_id_role",
|
||||||
|
table_name="vote_measure_link",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
"ix_vote_classification_eligible_vote_id",
|
||||||
|
table_name="vote_classification",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
"ix_amendment_action_recorded_vote_match_tuple",
|
||||||
|
table_name="amendment_action_recorded_vote",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
"ix_bill_action_recorded_vote_match_tuple",
|
||||||
|
table_name="bill_action_recorded_vote",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
"uq_source_artifact_ingest_identity",
|
||||||
|
table_name="source_artifact",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
"""adding PostTopic.
|
||||||
|
|
||||||
|
Revision ID: 032e26bbfcb5
|
||||||
|
Revises: a7b91c4e2d30
|
||||||
|
Create Date: 2026-04-26 14:34:35.688341
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "032e26bbfcb5"
|
||||||
|
down_revision: str | None = "a7b91c4e2d30"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"post_topic",
|
||||||
|
sa.Column("post_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("topic_id", sa.SmallInteger(), nullable=False),
|
||||||
|
sa.Column("topic_label", sa.String(), nullable=True),
|
||||||
|
sa.Column("model_version", sa.String(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_post_topic")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_post_topic_post_id", "post_topic", ["post_id"], unique=False, schema=schema
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index("ix_post_topic_post_id", table_name="post_topic", schema=schema)
|
||||||
|
op.drop_table("post_topic", schema=schema)
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
"""adding PG Vector.
|
||||||
|
|
||||||
|
Revision ID: b9360b0b0c22
|
||||||
|
Revises: 032e26bbfcb5
|
||||||
|
Create Date: 2026-04-26 14:35:08.770128
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "b9360b0b0c22"
|
||||||
|
down_revision: str | None = "032e26bbfcb5"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
op.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
op.execute("DROP EXTENSION IF EXISTS vector")
|
||||||
+138
@@ -0,0 +1,138 @@
|
|||||||
|
"""Alembic."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
from alembic.script import write_hooks
|
||||||
|
from sqlalchemy.schema import CreateSchema
|
||||||
|
|
||||||
|
from pipelines.common import bash_wrapper
|
||||||
|
from pipelines.orm.common import get_postgres_engine
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import MutableMapping
|
||||||
|
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
base_class: type[DeclarativeBase] = config.attributes.get("base")
|
||||||
|
if base_class is None:
|
||||||
|
error = "No base class provided. Use the database CLI to run alembic commands."
|
||||||
|
raise RuntimeError(error)
|
||||||
|
|
||||||
|
target_metadata = base_class.metadata
|
||||||
|
logging.basicConfig(
|
||||||
|
level="DEBUG",
|
||||||
|
datefmt="%Y-%m-%dT%H:%M:%S%z",
|
||||||
|
format="%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s",
|
||||||
|
handlers=[logging.StreamHandler(sys.stdout)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@write_hooks.register("dynamic_schema")
|
||||||
|
def dynamic_schema(filename: str, _options: dict[Any, Any]) -> None:
|
||||||
|
"""Dynamic schema."""
|
||||||
|
original_file = Path(filename).read_text()
|
||||||
|
schema_name = base_class.schema_name
|
||||||
|
dynamic_schema_file_part1 = original_file.replace(
|
||||||
|
f"schema='{schema_name}'", "schema=schema"
|
||||||
|
)
|
||||||
|
dynamic_schema_file = dynamic_schema_file_part1.replace(
|
||||||
|
f"'{schema_name}.", "f'{schema}."
|
||||||
|
)
|
||||||
|
Path(filename).write_text(dynamic_schema_file)
|
||||||
|
|
||||||
|
|
||||||
|
@write_hooks.register("import_postgresql")
|
||||||
|
def import_postgresql(filename: str, _options: dict[Any, Any]) -> None:
|
||||||
|
"""Add postgresql dialect import when postgresql types are used."""
|
||||||
|
content = Path(filename).read_text()
|
||||||
|
if (
|
||||||
|
"postgresql." in content
|
||||||
|
and "from sqlalchemy.dialects import postgresql" not in content
|
||||||
|
):
|
||||||
|
content = content.replace(
|
||||||
|
"import sqlalchemy as sa\n",
|
||||||
|
"import sqlalchemy as sa\nfrom sqlalchemy.dialects import postgresql\n",
|
||||||
|
)
|
||||||
|
Path(filename).write_text(content)
|
||||||
|
|
||||||
|
|
||||||
|
@write_hooks.register("ruff")
|
||||||
|
def ruff_check_and_format(filename: str, _options: dict[Any, Any]) -> None:
|
||||||
|
"""Docstring for ruff_check_and_format."""
|
||||||
|
bash_wrapper(f"ruff check --fix {filename}")
|
||||||
|
bash_wrapper(f"ruff format {filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def include_name(
|
||||||
|
name: str | None,
|
||||||
|
type_: Literal[
|
||||||
|
"schema",
|
||||||
|
"table",
|
||||||
|
"column",
|
||||||
|
"index",
|
||||||
|
"unique_constraint",
|
||||||
|
"foreign_key_constraint",
|
||||||
|
],
|
||||||
|
_parent_names: MutableMapping[
|
||||||
|
Literal["schema_name", "table_name", "schema_qualified_table_name"], str | None
|
||||||
|
],
|
||||||
|
) -> bool:
|
||||||
|
"""Filter tables to be included in the migration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the table.
|
||||||
|
type_ (str): The type of the table.
|
||||||
|
_parent_names (MutableMapping): The names of the parent tables.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the table should be included, False otherwise.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if type_ == "schema":
|
||||||
|
# allows a database with multiple schemas to have separate alembic revisions
|
||||||
|
return name == target_metadata.schema
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
"""Run migrations in 'online' mode.
|
||||||
|
|
||||||
|
In this scenario we need to create an Engine
|
||||||
|
and associate a connection with the context.
|
||||||
|
|
||||||
|
"""
|
||||||
|
env_prefix = config.attributes.get("env_prefix", "POSTGRES")
|
||||||
|
connectable = get_postgres_engine(name=env_prefix)
|
||||||
|
|
||||||
|
with connectable.connect() as connection:
|
||||||
|
schema = base_class.schema_name
|
||||||
|
if not connectable.dialect.has_schema(connection, schema):
|
||||||
|
answer = input(f"Schema {schema!r} does not exist. Create it? [y/N] ")
|
||||||
|
if answer.lower() != "y":
|
||||||
|
error = f"Schema {schema!r} does not exist. Exiting."
|
||||||
|
raise SystemExit(error)
|
||||||
|
connection.execute(CreateSchema(schema))
|
||||||
|
connection.commit()
|
||||||
|
|
||||||
|
context.configure(
|
||||||
|
connection=connection,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
include_schemas=True,
|
||||||
|
version_table_schema=schema,
|
||||||
|
include_name=include_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
connection.commit()
|
||||||
|
|
||||||
|
|
||||||
|
run_migrations_online()
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
"""${message}.
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from pipelines.orm import ${config.attributes["base"].__name__}
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: str | None = ${repr(down_revision)}
|
||||||
|
branch_labels: str | Sequence[str] | None = ${repr(branch_labels)}
|
||||||
|
depends_on: str | Sequence[str] | None = ${repr(depends_on)}
|
||||||
|
|
||||||
|
schema=${config.attributes["base"].__name__}.schema_name
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
+123
@@ -0,0 +1,123 @@
|
|||||||
|
"""CLI wrapper around alembic for multi-database support.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
database <db_name> <command> [args...]
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
database van_inventory upgrade head
|
||||||
|
database van_inventory downgrade head-1
|
||||||
|
database van_inventory revision --autogenerate -m "add meals table"
|
||||||
|
database van_inventory check
|
||||||
|
database richie check
|
||||||
|
database richie upgrade head
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from importlib import import_module
|
||||||
|
from typing import TYPE_CHECKING, Annotated
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from alembic.config import CommandLine, Config
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DatabaseConfig:
|
||||||
|
"""Configuration for a database."""
|
||||||
|
|
||||||
|
env_prefix: str
|
||||||
|
version_location: str
|
||||||
|
base_module: str
|
||||||
|
base_class_name: str
|
||||||
|
models_module: str
|
||||||
|
script_location: str = "alembic"
|
||||||
|
file_template: str = "%%(year)d_%%(month).2d_%%(day).2d-%%(slug)s_%%(rev)s"
|
||||||
|
|
||||||
|
def get_base(self) -> type[DeclarativeBase]:
|
||||||
|
"""Import and return the Base class."""
|
||||||
|
module = import_module(self.base_module)
|
||||||
|
return getattr(module, self.base_class_name)
|
||||||
|
|
||||||
|
def import_models(self) -> None:
|
||||||
|
"""Import ORM models so alembic autogenerate can detect them."""
|
||||||
|
import_module(self.models_module)
|
||||||
|
|
||||||
|
def alembic_config(self) -> Config:
|
||||||
|
"""Build an alembic Config for this database."""
|
||||||
|
# Runtime import needed — Config is in TYPE_CHECKING for the return type annotation
|
||||||
|
from alembic.config import Config as AlembicConfig # noqa: PLC0415
|
||||||
|
|
||||||
|
cfg = AlembicConfig()
|
||||||
|
cfg.set_main_option("script_location", self.script_location)
|
||||||
|
cfg.set_main_option("file_template", self.file_template)
|
||||||
|
cfg.set_main_option("prepend_sys_path", ".")
|
||||||
|
cfg.set_main_option("version_path_separator", "os")
|
||||||
|
cfg.set_main_option("version_locations", self.version_location)
|
||||||
|
cfg.set_main_option("revision_environment", "true")
|
||||||
|
cfg.set_section_option(
|
||||||
|
"post_write_hooks", "hooks", "dynamic_schema,import_postgresql,ruff"
|
||||||
|
)
|
||||||
|
cfg.set_section_option(
|
||||||
|
"post_write_hooks", "dynamic_schema.type", "dynamic_schema"
|
||||||
|
)
|
||||||
|
cfg.set_section_option(
|
||||||
|
"post_write_hooks", "import_postgresql.type", "import_postgresql"
|
||||||
|
)
|
||||||
|
cfg.set_section_option("post_write_hooks", "ruff.type", "ruff")
|
||||||
|
cfg.attributes["base"] = self.get_base()
|
||||||
|
cfg.attributes["env_prefix"] = self.env_prefix
|
||||||
|
self.import_models()
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
DATABASES: dict[str, DatabaseConfig] = {
|
||||||
|
"data_science_dev": DatabaseConfig(
|
||||||
|
env_prefix="DATA_SCIENCE_DEV",
|
||||||
|
version_location="alembic/data_science_dev/versions",
|
||||||
|
base_module="pipelines.orm.data_science_dev.base",
|
||||||
|
base_class_name="DataScienceDevBase",
|
||||||
|
models_module="pipelines.orm.data_science_dev.models",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
app = typer.Typer(help="Multi-database alembic wrapper.")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command(
|
||||||
|
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
||||||
|
)
|
||||||
|
def main(
|
||||||
|
ctx: typer.Context,
|
||||||
|
db_name: Annotated[
|
||||||
|
str, typer.Argument(help=f"Database name. Options: {', '.join(DATABASES)}")
|
||||||
|
],
|
||||||
|
command: Annotated[
|
||||||
|
str,
|
||||||
|
typer.Argument(
|
||||||
|
help="Alembic command (upgrade, downgrade, revision, check, etc.)"
|
||||||
|
),
|
||||||
|
],
|
||||||
|
) -> None:
|
||||||
|
"""Run an alembic command against the specified database."""
|
||||||
|
db_config = DATABASES.get(db_name)
|
||||||
|
if not db_config:
|
||||||
|
typer.echo(
|
||||||
|
f"Unknown database: {db_name!r}. Available: {', '.join(DATABASES)}",
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
alembic_cfg = db_config.alembic_config()
|
||||||
|
|
||||||
|
cmd_line = CommandLine()
|
||||||
|
options = cmd_line.parser.parse_args([command, *ctx.args])
|
||||||
|
cmd_line.run_cmd(alembic_cfg, options)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
@@ -0,0 +1,72 @@
|
|||||||
|
"""common."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from os import getenv
|
||||||
|
from subprocess import PIPE, Popen
|
||||||
|
|
||||||
|
from apprise import Apprise
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def configure_logger(level: str = "INFO") -> None:
|
||||||
|
"""Configure the logger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level (str, optional): The logging level. Defaults to "INFO".
|
||||||
|
"""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=level,
|
||||||
|
datefmt="%Y-%m-%dT%H:%M:%S%z",
|
||||||
|
format="%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s",
|
||||||
|
handlers=[logging.StreamHandler(sys.stdout)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def bash_wrapper(command: str) -> tuple[str, int]:
|
||||||
|
"""Execute a bash command and capture the output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command (str): The bash command to be executed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[str, int]: A tuple containing the output of the command (stdout) as a string,
|
||||||
|
the error output (stderr) as a string (optional), and the return code as an integer.
|
||||||
|
"""
|
||||||
|
# This is a acceptable risk
|
||||||
|
process = Popen(command.split(), stdout=PIPE, stderr=PIPE)
|
||||||
|
output, error = process.communicate()
|
||||||
|
if error:
|
||||||
|
logger.error(f"{error=}")
|
||||||
|
return error.decode(), process.returncode
|
||||||
|
|
||||||
|
return output.decode(), process.returncode
|
||||||
|
|
||||||
|
|
||||||
|
def signal_alert(body: str, title: str = "") -> None:
|
||||||
|
"""Send a signal alert.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
body (str): The body of the alert.
|
||||||
|
title (str, optional): The title of the alert. Defaults to "".
|
||||||
|
"""
|
||||||
|
apprise_client = Apprise()
|
||||||
|
|
||||||
|
from_phone = getenv("SIGNAL_ALERT_FROM_PHONE")
|
||||||
|
to_phone = getenv("SIGNAL_ALERT_TO_PHONE")
|
||||||
|
if not from_phone or not to_phone:
|
||||||
|
logger.info("SIGNAL_ALERT_FROM_PHONE or SIGNAL_ALERT_TO_PHONE not set")
|
||||||
|
return
|
||||||
|
|
||||||
|
apprise_client.add(f"signal://localhost:8989/{from_phone}/{to_phone}")
|
||||||
|
|
||||||
|
apprise_client.notify(title=title, body=body)
|
||||||
|
|
||||||
|
|
||||||
|
def utcnow() -> datetime:
|
||||||
|
"""Get the current UTC time."""
|
||||||
|
return datetime.now(tz=UTC)
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from os import getenv
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import tomllib
|
import tomllib
|
||||||
|
|
||||||
@@ -68,15 +69,54 @@ class BenchmarkConfig:
|
|||||||
return cls(**raw)
|
return cls(**raw)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OpenAIConfig:
|
||||||
|
"""OpenAI API configuration."""
|
||||||
|
|
||||||
|
api_key: str
|
||||||
|
openai_project_id: str
|
||||||
|
openai_chat_completions_url: str
|
||||||
|
model: str
|
||||||
|
timeout_seconds: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_toml(cls, config_path: Path) -> OpenAIConfig:
|
||||||
|
"""Load OpenAI config from a TOML file."""
|
||||||
|
raw = tomllib.loads(config_path.read_text()).get("openai", {})
|
||||||
|
api_key = getenv("CLOSEDAI_TOKEN")
|
||||||
|
if not api_key:
|
||||||
|
message = "CLOSEDAI_TOKEN is required"
|
||||||
|
raise KeyError(message)
|
||||||
|
return cls(
|
||||||
|
api_key=api_key,
|
||||||
|
openai_project_id=raw.get(
|
||||||
|
"openai_project_id", "proj_fQBPEXFgnS87Fk6wZwploFwE"
|
||||||
|
),
|
||||||
|
openai_chat_completions_url=raw.get(
|
||||||
|
"openai_chat_completions_url",
|
||||||
|
"https://api.openai.com/v1/chat/completions",
|
||||||
|
),
|
||||||
|
model=raw.get("model", "gpt-5.4-mini"),
|
||||||
|
timeout_seconds=raw.get("timeout_seconds", 60),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_config_dir() -> Path:
|
def get_config_dir() -> Path:
|
||||||
"""Get the path to the config file."""
|
"""Get the path to the config directory."""
|
||||||
return Path(__file__).resolve().parent.parent.parent / "config"
|
return Path(__file__).resolve().parents[2] / "config"
|
||||||
|
|
||||||
|
|
||||||
def default_config_path() -> Path:
|
def default_config_path() -> Path:
|
||||||
"""Get the path to the config file."""
|
"""Get the path to the config file."""
|
||||||
return get_config_dir() / "config.toml"
|
return get_config_dir() / "config.toml"
|
||||||
|
|
||||||
|
|
||||||
|
def get_openai_config(config_path: Path | None = None) -> OpenAIConfig:
|
||||||
|
if config_path is None:
|
||||||
|
config_path = default_config_path()
|
||||||
|
return OpenAIConfig.from_toml(config_path)
|
||||||
|
|
||||||
|
|
||||||
def get_finetune_config(config_path: Path | None = None) -> FinetuneConfig:
|
def get_finetune_config(config_path: Path | None = None) -> FinetuneConfig:
|
||||||
if config_path is None:
|
if config_path is None:
|
||||||
config_path = default_config_path()
|
config_path = default_config_path()
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Prompt benchmarking system for evaluating LLMs via vLLM."""
|
||||||
@@ -0,0 +1,235 @@
|
|||||||
|
"""Docker container lifecycle management for BERTopic jobs on Jeeves."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
JOBMode = Literal["train", "infer"]
|
||||||
|
IMAGE_NAME = "bert-topic:latest"
|
||||||
|
REPO_DIR = Path(__file__).resolve().parents[3]
|
||||||
|
DEFAULT_CACHE_ROOT = Path("/zfs/storage/main/ds_thing/models/bert_topic")
|
||||||
|
DEFAULT_POSTGRES_SOCKET_DIR = Path("/run/postgresql")
|
||||||
|
DB_ENV_VARS = (
|
||||||
|
"DATA_SCIENCE_DEV_DB",
|
||||||
|
"DATA_SCIENCE_DEV_HOST",
|
||||||
|
"DATA_SCIENCE_DEV_PORT",
|
||||||
|
"DATA_SCIENCE_DEV_USER",
|
||||||
|
"DATA_SCIENCE_DEV_PASSWORD",
|
||||||
|
)
|
||||||
|
|
||||||
|
app = typer.Typer(help="BERTopic container management.")
|
||||||
|
|
||||||
|
|
||||||
|
def _container_name(mode: JOBMode) -> str:
|
||||||
|
"""Return the Docker container name for the selected BERTopic job."""
|
||||||
|
return f"bert-topic-{mode}"
|
||||||
|
|
||||||
|
|
||||||
|
def _module_name(mode: JOBMode) -> str:
|
||||||
|
"""Return the Python module to run inside the container."""
|
||||||
|
return f"pipelines.bert_topic.{mode}"
|
||||||
|
|
||||||
|
|
||||||
|
def _env_args(*, use_postgres_socket: bool) -> list[str]:
|
||||||
|
"""Pass through database environment variables from the host shell."""
|
||||||
|
required = [
|
||||||
|
"DATA_SCIENCE_DEV_DB",
|
||||||
|
"DATA_SCIENCE_DEV_PORT",
|
||||||
|
"DATA_SCIENCE_DEV_USER",
|
||||||
|
]
|
||||||
|
if not use_postgres_socket:
|
||||||
|
required.append("DATA_SCIENCE_DEV_HOST")
|
||||||
|
missing = [name for name in required if not os.getenv(name)]
|
||||||
|
if missing:
|
||||||
|
message = "Missing required database environment variables: " + ", ".join(
|
||||||
|
missing
|
||||||
|
)
|
||||||
|
raise RuntimeError(message)
|
||||||
|
args: list[str] = []
|
||||||
|
if use_postgres_socket:
|
||||||
|
args.extend(["-e", f"DATA_SCIENCE_DEV_HOST={DEFAULT_POSTGRES_SOCKET_DIR}"])
|
||||||
|
for name in DB_ENV_VARS:
|
||||||
|
if use_postgres_socket and name == "DATA_SCIENCE_DEV_HOST":
|
||||||
|
continue
|
||||||
|
if os.getenv(name):
|
||||||
|
args.extend(["-e", name])
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def build_image() -> None:
|
||||||
|
"""Build the BERTopic Docker image."""
|
||||||
|
dockerfile = REPO_DIR / "pipelines/containers/docker_files/Dockerfile.bert_topic"
|
||||||
|
logger.info("Building BERTopic image: %s", IMAGE_NAME)
|
||||||
|
result = subprocess.run(
|
||||||
|
[
|
||||||
|
"docker",
|
||||||
|
"build",
|
||||||
|
"--network",
|
||||||
|
"host",
|
||||||
|
"-f",
|
||||||
|
str(dockerfile),
|
||||||
|
"-t",
|
||||||
|
IMAGE_NAME,
|
||||||
|
str(REPO_DIR),
|
||||||
|
],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
message = (
|
||||||
|
"Failed to build BERTopic image. "
|
||||||
|
f"docker build stderr:\n{result.stderr.strip()}"
|
||||||
|
)
|
||||||
|
raise RuntimeError(message)
|
||||||
|
logger.info("Image built: %s", IMAGE_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
def stop_job(*, mode: JOBMode) -> None:
|
||||||
|
"""Stop and remove the BERTopic container for the selected mode."""
|
||||||
|
container_name = _container_name(mode)
|
||||||
|
logger.info("Stopping BERTopic container: %s", container_name)
|
||||||
|
subprocess.run(["docker", "stop", container_name], capture_output=True, check=False)
|
||||||
|
subprocess.run(
|
||||||
|
["docker", "rm", "-f", container_name], capture_output=True, check=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def start_job(
|
||||||
|
*,
|
||||||
|
mode: JOBMode,
|
||||||
|
cache_root: Path = DEFAULT_CACHE_ROOT,
|
||||||
|
postgres_socket_dir: Path = DEFAULT_POSTGRES_SOCKET_DIR,
|
||||||
|
detach: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Run BERTopic training or inference in Docker on Jeeves."""
|
||||||
|
cache_root = cache_root.resolve()
|
||||||
|
cache_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
postgres_socket_dir = postgres_socket_dir.resolve()
|
||||||
|
stop_job(mode=mode)
|
||||||
|
use_postgres_socket = postgres_socket_dir.exists()
|
||||||
|
|
||||||
|
command = [
|
||||||
|
"docker",
|
||||||
|
"run",
|
||||||
|
"--name",
|
||||||
|
_container_name(mode),
|
||||||
|
"--ipc=host",
|
||||||
|
"-v",
|
||||||
|
f"{cache_root}:/cache",
|
||||||
|
*_env_args(use_postgres_socket=use_postgres_socket),
|
||||||
|
IMAGE_NAME,
|
||||||
|
_module_name(mode),
|
||||||
|
]
|
||||||
|
if use_postgres_socket:
|
||||||
|
command[7:7] = ["-v", f"{postgres_socket_dir}:{DEFAULT_POSTGRES_SOCKET_DIR}"]
|
||||||
|
if detach:
|
||||||
|
command.insert(2, "-d")
|
||||||
|
|
||||||
|
logger.info("Starting BERTopic %s container", mode)
|
||||||
|
logger.info(" Cache root: %s", cache_root)
|
||||||
|
if use_postgres_socket:
|
||||||
|
logger.info(" Postgres socket: %s", postgres_socket_dir)
|
||||||
|
result = subprocess.run(command, text=True, capture_output=detach, check=False)
|
||||||
|
if result.returncode != 0:
|
||||||
|
detail = (
|
||||||
|
result.stderr.strip() if result.stderr else f"exit code {result.returncode}"
|
||||||
|
)
|
||||||
|
raise RuntimeError(f"BERTopic container failed to start: {detail}")
|
||||||
|
if detach:
|
||||||
|
logger.info("Container started: %s", result.stdout.strip()[:12])
|
||||||
|
else:
|
||||||
|
logger.info("BERTopic %s run complete", mode)
|
||||||
|
|
||||||
|
|
||||||
|
def logs_job(*, mode: JOBMode) -> str | None:
|
||||||
|
"""Return recent logs from the BERTopic container, or None if absent."""
|
||||||
|
result = subprocess.run(
|
||||||
|
["docker", "logs", "--tail", "100", _container_name(mode)],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return None
|
||||||
|
return result.stdout + result.stderr
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def build(
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""Build the BERTopic Docker image."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
build_image()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command("run")
|
||||||
|
def run_job_command(
|
||||||
|
mode: Annotated[JOBMode, typer.Option(help="Which BERTopic job to run")] = "train",
|
||||||
|
cache_root: Annotated[
|
||||||
|
Path, typer.Option(help="Host path mounted to /cache for model and HF cache")
|
||||||
|
] = DEFAULT_CACHE_ROOT,
|
||||||
|
postgres_socket_dir: Annotated[
|
||||||
|
Path, typer.Option(help="Host Postgres socket directory to mount into the container")
|
||||||
|
] = DEFAULT_POSTGRES_SOCKET_DIR,
|
||||||
|
detach: Annotated[
|
||||||
|
bool, typer.Option(help="Start the container in the background")
|
||||||
|
] = False,
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""Run BERTopic training or inference inside Docker."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
start_job(
|
||||||
|
mode=mode,
|
||||||
|
cache_root=cache_root,
|
||||||
|
postgres_socket_dir=postgres_socket_dir,
|
||||||
|
detach=detach,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command("stop")
|
||||||
|
def stop_job_command(
|
||||||
|
mode: Annotated[
|
||||||
|
JOBMode, typer.Option(help="Which BERTopic container to stop")
|
||||||
|
] = "train",
|
||||||
|
) -> None:
|
||||||
|
"""Stop and remove the BERTopic container."""
|
||||||
|
stop_job(mode=mode)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command("logs")
|
||||||
|
def logs_job_command(
|
||||||
|
mode: Annotated[
|
||||||
|
JOBMode, typer.Option(help="Which BERTopic container logs to show")
|
||||||
|
] = "train",
|
||||||
|
) -> None:
|
||||||
|
"""Show recent logs from the BERTopic container."""
|
||||||
|
output = logs_job(mode=mode)
|
||||||
|
if output is None:
|
||||||
|
typer.echo(f"No BERTopic container found for mode={mode}.")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
typer.echo(output)
|
||||||
|
|
||||||
|
|
||||||
|
def cli() -> None:
|
||||||
|
"""Typer entry point."""
|
||||||
|
app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
FROM python:3.12-bookworm
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
ENV PIP_NO_CACHE_DIR=1
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
build-essential \
|
||||||
|
gcc \
|
||||||
|
g++ \
|
||||||
|
git \
|
||||||
|
libgomp1 \
|
||||||
|
libpq-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY pipelines ./pipelines
|
||||||
|
|
||||||
|
RUN python -m pip install --upgrade pip setuptools wheel && \
|
||||||
|
python -m pip install \
|
||||||
|
torch \
|
||||||
|
--index-url https://download.pytorch.org/whl/cpu && \
|
||||||
|
python -m pip install \
|
||||||
|
typer \
|
||||||
|
sqlalchemy \
|
||||||
|
bertopic \
|
||||||
|
sentence-transformers \
|
||||||
|
scikit-learn \
|
||||||
|
pandas \
|
||||||
|
numpy \
|
||||||
|
"psycopg[binary]"
|
||||||
|
|
||||||
|
ENV HF_HOME=/cache/huggingface
|
||||||
|
ENV TRANSFORMERS_CACHE=/cache/huggingface
|
||||||
|
|
||||||
|
ENTRYPOINT ["python", "-m"]
|
||||||
|
CMD ["pipelines.bert_topic.train"]
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
FROM ghcr.io/unslothai/unsloth:latest
|
||||||
|
|
||||||
|
RUN pip install --no-cache-dir typer
|
||||||
|
|
||||||
|
WORKDIR /workspace
|
||||||
|
COPY python/prompt_bench/finetune.py python/prompt_bench/finetune.py
|
||||||
|
COPY config/prompts/summarization_prompts.toml config/prompts/summarization_prompts.toml
|
||||||
|
COPY python/prompt_bench/__init__.py python/prompt_bench/__init__.py
|
||||||
|
COPY python/__init__.py python/__init__.py
|
||||||
|
|
||||||
|
ENTRYPOINT ["python", "-m", "pipelines.prompt_bench.finetune"]
|
||||||
@@ -9,7 +9,7 @@ from typing import Annotated
|
|||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
from pipelines.tools.containers.lib import check_gpu_free
|
from pipelines.pipelines.containers.lib import check_gpu_free
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -27,7 +27,7 @@ def build_image() -> None:
|
|||||||
"docker",
|
"docker",
|
||||||
"build",
|
"build",
|
||||||
"-f",
|
"-f",
|
||||||
str(REPO_DIR / "pipelines/pipelines/tools/Dockerfile.finetune"),
|
str(REPO_DIR / "pipelines/containers/docker_files/Dockerfile.finetune"),
|
||||||
"-t",
|
"-t",
|
||||||
FINETUNE_IMAGE,
|
FINETUNE_IMAGE,
|
||||||
".",
|
".",
|
||||||
@@ -133,7 +133,7 @@ def build() -> None:
|
|||||||
@app.command()
|
@app.command()
|
||||||
def run(
|
def run(
|
||||||
dataset: Annotated[Path, typer.Option(help="Fine-tuning JSONL")] = REPO_DIR
|
dataset: Annotated[Path, typer.Option(help="Fine-tuning JSONL")] = REPO_DIR
|
||||||
/ "/zfs/storage/data_science/data/finetune_dataset.jsonl",
|
/ "data/finetune_dataset.jsonl",
|
||||||
output_dir: Annotated[
|
output_dir: Annotated[
|
||||||
Path, typer.Option(help="Where to save the trained model")
|
Path, typer.Option(help="Where to save the trained model")
|
||||||
] = REPO_DIR / "data/output/qwen-bill-summarizer",
|
] = REPO_DIR / "data/output/qwen-bill-summarizer",
|
||||||
@@ -0,0 +1,574 @@
|
|||||||
|
"""Calculate legislator topic scores from bill topics and roll-call votes."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Annotated, Sequence
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from sqlalchemy import (
|
||||||
|
ColumnElement,
|
||||||
|
Integer,
|
||||||
|
Select,
|
||||||
|
and_,
|
||||||
|
case,
|
||||||
|
cast,
|
||||||
|
delete,
|
||||||
|
extract,
|
||||||
|
func,
|
||||||
|
or_,
|
||||||
|
select,
|
||||||
|
tuple_,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from pipelines.congress_vote_context import create_score_run, finalize_score_run
|
||||||
|
from pipelines.orm.common import get_postgres_engine
|
||||||
|
from pipelines.orm.data_science_dev.congress import (
|
||||||
|
BillTopic,
|
||||||
|
BillTopicPosition,
|
||||||
|
LegislatorScore,
|
||||||
|
SubjectType,
|
||||||
|
Vote,
|
||||||
|
VoteClassification,
|
||||||
|
VoteEffect,
|
||||||
|
VoteMeasureLink,
|
||||||
|
VoteMeasureRole,
|
||||||
|
VotePositionMeaning,
|
||||||
|
VoteRelationship,
|
||||||
|
VoteRecord,
|
||||||
|
)
|
||||||
|
from pipelines.pipelines.jobs.extract_bill_topics import normalize_topic_label
|
||||||
|
from pipelines.web.scoring import (
|
||||||
|
OPPOSE_POSITIONS,
|
||||||
|
SUPPORT_POSITIONS,
|
||||||
|
normalized_position_expression,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DELETE_BATCH_SIZE = 5_000
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ScoreDiagnostics:
|
||||||
|
"""Counts for the input stages required to calculate legislator scores."""
|
||||||
|
|
||||||
|
bill_topic_rows: int
|
||||||
|
linked_vote_rows: int
|
||||||
|
vote_record_rows: int
|
||||||
|
topic_vote_links: int
|
||||||
|
scorable_vote_records: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LegislatorScoreInput:
|
||||||
|
"""One aggregated score ready to store in legislator_score."""
|
||||||
|
|
||||||
|
legislator_id: int
|
||||||
|
year: int
|
||||||
|
topic: str
|
||||||
|
score: float
|
||||||
|
supportive: int
|
||||||
|
opposed: int
|
||||||
|
|
||||||
|
|
||||||
|
def create_legislator_score_query(
|
||||||
|
*,
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: Sequence[int] | None = None,
|
||||||
|
topics: Sequence[str] | None = None,
|
||||||
|
) -> Select:
|
||||||
|
"""Build the aggregate score query from extracted bill topics and vote records."""
|
||||||
|
normalized_vote = normalized_position_expression(VoteRecord.position)
|
||||||
|
supportive_vote = _supportive_vote_expression(normalized_vote)
|
||||||
|
opposed_vote = _opposed_vote_expression(normalized_vote)
|
||||||
|
supportive_count = func.sum(supportive_vote)
|
||||||
|
opposed_count = func.sum(opposed_vote)
|
||||||
|
total_count = supportive_count + opposed_count
|
||||||
|
vote_year = cast(extract("year", Vote.vote_date), Integer)
|
||||||
|
score = (100.0 * supportive_count / func.nullif(total_count, 0)).label("score")
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(
|
||||||
|
VoteRecord.legislator_id.label("legislator_id"),
|
||||||
|
vote_year.label("year"),
|
||||||
|
BillTopic.topic.label("topic"),
|
||||||
|
score,
|
||||||
|
supportive_count.label("supportive"),
|
||||||
|
opposed_count.label("opposed"),
|
||||||
|
total_count.label("total"),
|
||||||
|
)
|
||||||
|
.select_from(VoteRecord)
|
||||||
|
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
|
||||||
|
.where(
|
||||||
|
*_eligible_vote_filters(),
|
||||||
|
_is_scorable_position(normalized_vote),
|
||||||
|
)
|
||||||
|
.group_by(VoteRecord.legislator_id, vote_year, BillTopic.topic)
|
||||||
|
.having(total_count > 0)
|
||||||
|
.order_by(VoteRecord.legislator_id, vote_year, BillTopic.topic)
|
||||||
|
)
|
||||||
|
if congress is not None:
|
||||||
|
stmt = stmt.where(Vote.congress == congress)
|
||||||
|
if bill_ids:
|
||||||
|
stmt = stmt.where(VoteMeasureLink.measure_id.in_(list(bill_ids)))
|
||||||
|
|
||||||
|
normalized_topics = _normalize_topics(topics)
|
||||||
|
if normalized_topics:
|
||||||
|
stmt = stmt.where(BillTopic.topic.in_(normalized_topics))
|
||||||
|
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
|
def collect_legislator_scores(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: Sequence[int] | None = None,
|
||||||
|
topics: Sequence[str] | None = None,
|
||||||
|
) -> list[LegislatorScoreInput]:
|
||||||
|
"""Run the aggregate query and return score rows."""
|
||||||
|
rows = session.execute(
|
||||||
|
create_legislator_score_query(
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
topics=topics,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
LegislatorScoreInput(
|
||||||
|
legislator_id=int(row.legislator_id),
|
||||||
|
year=int(row.year),
|
||||||
|
topic=str(row.topic),
|
||||||
|
score=float(row.score),
|
||||||
|
supportive=int(row.supportive),
|
||||||
|
opposed=int(row.opposed),
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
if row.score is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def collect_score_diagnostics(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: Sequence[int] | None = None,
|
||||||
|
topics: Sequence[str] | None = None,
|
||||||
|
) -> ScoreDiagnostics:
|
||||||
|
"""Count score pipeline inputs for explaining empty score runs."""
|
||||||
|
normalized_topics = _normalize_topics(topics)
|
||||||
|
vote_filters = _vote_scope_filters(congress=congress, bill_ids=bill_ids)
|
||||||
|
topic_filters = _topic_scope_filters(bill_ids=bill_ids, topics=normalized_topics)
|
||||||
|
normalized_vote = normalized_position_expression(VoteRecord.position)
|
||||||
|
eligible_vote_filters = _eligible_vote_filters()
|
||||||
|
|
||||||
|
bill_topic_rows = session.scalar(
|
||||||
|
select(func.count(BillTopic.id)).where(*topic_filters)
|
||||||
|
)
|
||||||
|
linked_vote_rows = session.scalar(
|
||||||
|
select(func.count(func.distinct(Vote.id)))
|
||||||
|
.select_from(Vote)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.where(*vote_filters, *eligible_vote_filters)
|
||||||
|
)
|
||||||
|
vote_record_rows = session.scalar(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(VoteRecord)
|
||||||
|
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.where(*vote_filters, *eligible_vote_filters)
|
||||||
|
)
|
||||||
|
topic_vote_links = session.scalar(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(VoteRecord)
|
||||||
|
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
|
||||||
|
.where(*vote_filters, *topic_filters, *eligible_vote_filters)
|
||||||
|
)
|
||||||
|
scorable_vote_records = session.scalar(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(VoteRecord)
|
||||||
|
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
|
||||||
|
.where(
|
||||||
|
*vote_filters,
|
||||||
|
*topic_filters,
|
||||||
|
*eligible_vote_filters,
|
||||||
|
_is_scorable_position(normalized_vote),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return ScoreDiagnostics(
|
||||||
|
bill_topic_rows=bill_topic_rows or 0,
|
||||||
|
linked_vote_rows=linked_vote_rows or 0,
|
||||||
|
vote_record_rows=vote_record_rows or 0,
|
||||||
|
topic_vote_links=topic_vote_links or 0,
|
||||||
|
scorable_vote_records=scorable_vote_records or 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def store_legislator_scores(
|
||||||
|
session: Session,
|
||||||
|
rows: Sequence[LegislatorScoreInput],
|
||||||
|
*,
|
||||||
|
score_run_id: int | None,
|
||||||
|
replace_all: bool = False,
|
||||||
|
) -> int:
|
||||||
|
"""Replace matching score rows and insert the newly calculated scores."""
|
||||||
|
if replace_all:
|
||||||
|
session.execute(delete(LegislatorScore))
|
||||||
|
elif rows:
|
||||||
|
keys = [
|
||||||
|
(row.legislator_id, row.year, row.topic)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
for key_batch in _batched(keys, DELETE_BATCH_SIZE):
|
||||||
|
session.execute(
|
||||||
|
delete(LegislatorScore).where(
|
||||||
|
tuple_(
|
||||||
|
LegislatorScore.legislator_id,
|
||||||
|
LegislatorScore.year,
|
||||||
|
LegislatorScore.topic,
|
||||||
|
).in_(key_batch)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
LegislatorScore(
|
||||||
|
legislator_id=row.legislator_id,
|
||||||
|
year=row.year,
|
||||||
|
topic=row.topic,
|
||||||
|
score=row.score,
|
||||||
|
score_run_id=score_run_id,
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return len(rows)
|
||||||
|
|
||||||
|
|
||||||
|
def _supportive_vote_expression(
|
||||||
|
normalized_vote: ColumnElement[str | None],
|
||||||
|
) -> ColumnElement[int]:
|
||||||
|
supports_text = _position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT)
|
||||||
|
opposes_text = _position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT)
|
||||||
|
return case(
|
||||||
|
(
|
||||||
|
and_(
|
||||||
|
BillTopic.support_position == BillTopicPosition.FOR,
|
||||||
|
supports_text,
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
and_(
|
||||||
|
BillTopic.support_position == BillTopicPosition.AGAINST,
|
||||||
|
opposes_text,
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
else_=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _opposed_vote_expression(
|
||||||
|
normalized_vote: ColumnElement[str | None],
|
||||||
|
) -> ColumnElement[int]:
|
||||||
|
supports_text = _position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT)
|
||||||
|
opposes_text = _position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT)
|
||||||
|
return case(
|
||||||
|
(
|
||||||
|
and_(
|
||||||
|
BillTopic.support_position == BillTopicPosition.FOR,
|
||||||
|
opposes_text,
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
and_(
|
||||||
|
BillTopic.support_position == BillTopicPosition.AGAINST,
|
||||||
|
supports_text,
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
else_=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _position_matches_effect(
|
||||||
|
normalized_vote: ColumnElement[str | None],
|
||||||
|
effect: VoteEffect,
|
||||||
|
) -> ColumnElement[bool]:
|
||||||
|
return or_(
|
||||||
|
and_(
|
||||||
|
normalized_vote.in_(sorted(SUPPORT_POSITIONS)),
|
||||||
|
VotePositionMeaning.yea_effect == effect,
|
||||||
|
),
|
||||||
|
and_(
|
||||||
|
normalized_vote.in_(sorted(OPPOSE_POSITIONS)),
|
||||||
|
VotePositionMeaning.nay_effect == effect,
|
||||||
|
),
|
||||||
|
and_(
|
||||||
|
normalized_vote == "present",
|
||||||
|
VotePositionMeaning.present_effect == effect,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_scorable_position(normalized_vote: ColumnElement[str | None]) -> ColumnElement[bool]:
|
||||||
|
return or_(
|
||||||
|
_position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT),
|
||||||
|
_position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_topics(topics: Sequence[str] | None) -> list[str]:
|
||||||
|
normalized: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for topic in topics or []:
|
||||||
|
value = normalize_topic_label(topic)
|
||||||
|
if value and value not in seen:
|
||||||
|
normalized.append(value)
|
||||||
|
seen.add(value)
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def _batched[T](items: Sequence[T], batch_size: int) -> list[Sequence[T]]:
|
||||||
|
return [
|
||||||
|
items[index : index + batch_size]
|
||||||
|
for index in range(0, len(items), batch_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _vote_scope_filters(
|
||||||
|
*,
|
||||||
|
congress: int | None,
|
||||||
|
bill_ids: Sequence[int] | None,
|
||||||
|
) -> list[ColumnElement[bool]]:
|
||||||
|
filters: list[ColumnElement[bool]] = []
|
||||||
|
if congress is not None:
|
||||||
|
filters.append(Vote.congress == congress)
|
||||||
|
if bill_ids:
|
||||||
|
filters.append(VoteMeasureLink.measure_id.in_(list(bill_ids)))
|
||||||
|
return filters
|
||||||
|
|
||||||
|
|
||||||
|
def _topic_scope_filters(
|
||||||
|
*,
|
||||||
|
bill_ids: Sequence[int] | None,
|
||||||
|
topics: Sequence[str],
|
||||||
|
) -> list[ColumnElement[bool]]:
|
||||||
|
filters: list[ColumnElement[bool]] = []
|
||||||
|
if bill_ids:
|
||||||
|
filters.append(BillTopic.bill_id.in_(list(bill_ids)))
|
||||||
|
if topics:
|
||||||
|
filters.append(BillTopic.topic.in_(list(topics)))
|
||||||
|
return filters
|
||||||
|
|
||||||
|
|
||||||
|
def _has_score_scope(
|
||||||
|
*,
|
||||||
|
congress: int | None,
|
||||||
|
bill_ids: Sequence[int] | None,
|
||||||
|
topics: Sequence[str] | None,
|
||||||
|
) -> bool:
|
||||||
|
return congress is not None or bool(bill_ids) or bool(topics)
|
||||||
|
|
||||||
|
|
||||||
|
def _eligible_vote_filters() -> list[ColumnElement[bool]]:
|
||||||
|
return [
|
||||||
|
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||||
|
VoteClassification.vote_relationship == VoteRelationship.DIRECT_TEXT_VOTE,
|
||||||
|
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||||
|
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||||
|
VoteClassification.is_special_rule.is_(False),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
congress: Annotated[
|
||||||
|
int | None,
|
||||||
|
typer.Option(help="Only score votes from one Congress."),
|
||||||
|
] = None,
|
||||||
|
bill_ids: Annotated[
|
||||||
|
list[int] | None,
|
||||||
|
typer.Option(
|
||||||
|
"--bill-id",
|
||||||
|
help="Only score votes linked to one internal bill.id. Repeatable.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
topics: Annotated[
|
||||||
|
list[str] | None,
|
||||||
|
typer.Option("--topic", help="Only score one normalized topic. Repeatable."),
|
||||||
|
] = None,
|
||||||
|
replace_all: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(
|
||||||
|
help="Delete every existing legislator score before inserting. "
|
||||||
|
"Unfiltered runs do this automatically."
|
||||||
|
),
|
||||||
|
] = False,
|
||||||
|
dry_run: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Calculate scores without writing to the database."),
|
||||||
|
] = False,
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
|
||||||
|
diagnose: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Log input-stage counts even when rows are calculated."),
|
||||||
|
] = False,
|
||||||
|
) -> None:
|
||||||
|
"""CLI entrypoint for calculating and storing legislator topic scores."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||||
|
with Session(engine) as session:
|
||||||
|
rows = collect_legislator_scores(
|
||||||
|
session,
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
topics=topics,
|
||||||
|
)
|
||||||
|
logger.info("Calculated %d legislator topic score rows", len(rows))
|
||||||
|
if diagnose or not rows:
|
||||||
|
diagnostics = collect_score_diagnostics(
|
||||||
|
session,
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
topics=topics,
|
||||||
|
)
|
||||||
|
_log_diagnostics(diagnostics)
|
||||||
|
|
||||||
|
if dry_run:
|
||||||
|
session.rollback()
|
||||||
|
return
|
||||||
|
|
||||||
|
score_run = create_score_run(session)
|
||||||
|
should_replace_all = replace_all or not _has_score_scope(
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
topics=topics,
|
||||||
|
)
|
||||||
|
written = store_legislator_scores(
|
||||||
|
session,
|
||||||
|
rows,
|
||||||
|
score_run_id=score_run.id,
|
||||||
|
replace_all=should_replace_all,
|
||||||
|
)
|
||||||
|
included_vote_count = session.scalar(
|
||||||
|
select(func.count(func.distinct(Vote.id)))
|
||||||
|
.select_from(VoteRecord)
|
||||||
|
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
|
||||||
|
.where(
|
||||||
|
*_vote_scope_filters(congress=congress, bill_ids=bill_ids),
|
||||||
|
*_topic_scope_filters(bill_ids=bill_ids, topics=_normalize_topics(topics)),
|
||||||
|
*_eligible_vote_filters(),
|
||||||
|
_is_scorable_position(normalized_position_expression(VoteRecord.position)),
|
||||||
|
)
|
||||||
|
) or 0
|
||||||
|
total_scoped_votes = session.scalar(
|
||||||
|
select(func.count(func.distinct(Vote.id)))
|
||||||
|
.select_from(Vote)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.where(*_vote_scope_filters(congress=congress, bill_ids=bill_ids))
|
||||||
|
) or 0
|
||||||
|
finalize_score_run(
|
||||||
|
session,
|
||||||
|
score_run=score_run,
|
||||||
|
included_vote_count=included_vote_count,
|
||||||
|
excluded_vote_count=max(total_scoped_votes - included_vote_count, 0),
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
logger.info("Stored %d legislator topic score rows", written)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_diagnostics(diagnostics: ScoreDiagnostics) -> None:
|
||||||
|
logger.info(
|
||||||
|
"Score input diagnostics: bill_topic_rows=%d linked_vote_rows=%d "
|
||||||
|
"vote_record_rows=%d topic_vote_links=%d scorable_vote_records=%d",
|
||||||
|
diagnostics.bill_topic_rows,
|
||||||
|
diagnostics.linked_vote_rows,
|
||||||
|
diagnostics.vote_record_rows,
|
||||||
|
diagnostics.topic_vote_links,
|
||||||
|
diagnostics.scorable_vote_records,
|
||||||
|
)
|
||||||
|
if diagnostics.bill_topic_rows == 0:
|
||||||
|
logger.warning(
|
||||||
|
"No extracted bill topics matched the score scope. Run "
|
||||||
|
"pipelines.tools.extract_bill_topics after bill summarization."
|
||||||
|
)
|
||||||
|
elif diagnostics.linked_vote_rows == 0:
|
||||||
|
logger.warning("No direct substantive text votes matched the score scope.")
|
||||||
|
elif diagnostics.vote_record_rows == 0:
|
||||||
|
logger.warning("No individual vote records matched the score scope.")
|
||||||
|
elif diagnostics.topic_vote_links == 0:
|
||||||
|
logger.warning(
|
||||||
|
"Bill topics exist, but none are attached to bills that have eligible scored votes."
|
||||||
|
)
|
||||||
|
elif diagnostics.scorable_vote_records == 0:
|
||||||
|
logger.warning(
|
||||||
|
"Topic-vote links exist, but no joined vote records had Yea/Aye/Yes/Nay/No positions."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
typer.run(main)
|
||||||
@@ -0,0 +1,682 @@
|
|||||||
|
"""Extract bill topics from bill text using a configurable topic catalog."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated, Any, Sequence
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import typer
|
||||||
|
from sqlalchemy import ColumnElement, Select, delete, exists, func, select
|
||||||
|
from sqlalchemy.orm import Session, selectinload
|
||||||
|
|
||||||
|
from pipelines.config import OpenAIConfig, get_config_dir, get_openai_config
|
||||||
|
from pipelines.orm.common import get_postgres_engine
|
||||||
|
from pipelines.orm.data_science_dev.congress import (
|
||||||
|
Bill,
|
||||||
|
BillText,
|
||||||
|
BillTopic,
|
||||||
|
BillTopicPosition,
|
||||||
|
SubjectType,
|
||||||
|
VoteClassification,
|
||||||
|
VoteRelationship,
|
||||||
|
VoteTextTarget,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
|
||||||
|
OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
|
||||||
|
REQUEST_TIMEOUT_SECONDS = 60
|
||||||
|
DEFAULT_TOPICS_PATH = get_config_dir() / "congressional_issues_comprehensive.json"
|
||||||
|
|
||||||
|
|
||||||
|
class TopicExtractionError(RuntimeError):
|
||||||
|
"""Raised when a topic extraction request or response is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TopicCatalog:
|
||||||
|
"""Loaded topic catalog with categories for prompting and flat candidates."""
|
||||||
|
|
||||||
|
topics_by_category: dict[str, list[str]]
|
||||||
|
candidate_topics: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TopicExtractionDiagnostics:
|
||||||
|
"""Counts for the bill summary inputs needed by topic extraction."""
|
||||||
|
|
||||||
|
bill_rows: int
|
||||||
|
bill_text_rows: int
|
||||||
|
summarized_bill_text_rows: int
|
||||||
|
bills_with_summaries: int
|
||||||
|
bill_topic_rows: int
|
||||||
|
selected_bills: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ExtractedBillTopic:
|
||||||
|
"""One extracted bill topic and yes-vote stance."""
|
||||||
|
|
||||||
|
topic: str
|
||||||
|
support_position: BillTopicPosition
|
||||||
|
confidence: float | None = None
|
||||||
|
evidence: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _select_bill_text_for_topic_extraction(bill: Bill) -> BillText | None:
|
||||||
|
"""Pick one summarized bill_text row from the already-loaded relationship."""
|
||||||
|
for bill_text in bill.bill_texts:
|
||||||
|
if bill_text.summary and bill_text.summary.strip():
|
||||||
|
return bill_text
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_topic_label(value: str) -> str:
|
||||||
|
"""Normalize a topic label for storage, comparison, and de-duping."""
|
||||||
|
normalized = value.strip().strip("\"'")
|
||||||
|
normalized = normalized.strip().rstrip(".").strip()
|
||||||
|
return re.sub(r"\s+", " ", normalized).lower()
|
||||||
|
|
||||||
|
|
||||||
|
def load_topic_catalog(path: Path | None = None) -> TopicCatalog:
|
||||||
|
"""Load, validate, normalize, and flatten the bill topic catalog."""
|
||||||
|
topics_path = path or DEFAULT_TOPICS_PATH
|
||||||
|
try:
|
||||||
|
raw = json.loads(topics_path.read_text())
|
||||||
|
except FileNotFoundError as exc:
|
||||||
|
msg = f"Topic catalog not found: {topics_path}"
|
||||||
|
raise TopicExtractionError(msg) from exc
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
msg = f"Topic catalog is not valid JSON: {topics_path}: {exc}"
|
||||||
|
raise TopicExtractionError(msg) from exc
|
||||||
|
|
||||||
|
if not isinstance(raw, dict):
|
||||||
|
msg = "Topic catalog root must be an object mapping category names to lists"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
topics_by_category: dict[str, list[str]] = {}
|
||||||
|
candidate_topics: list[str] = []
|
||||||
|
seen_topics: set[str] = set()
|
||||||
|
|
||||||
|
for category, topics in raw.items():
|
||||||
|
if not isinstance(category, str) or not category.strip():
|
||||||
|
msg = "Topic catalog category names must be non-empty strings"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
if not isinstance(topics, list):
|
||||||
|
msg = f"Topic catalog category {category!r} must contain a list"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
normalized_topics: list[str] = []
|
||||||
|
for topic in topics:
|
||||||
|
if not isinstance(topic, str):
|
||||||
|
msg = f"Topic catalog category {category!r} contains a non-string topic"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
normalized_topic = normalize_topic_label(topic)
|
||||||
|
if not normalized_topic:
|
||||||
|
msg = f"Topic catalog category {category!r} contains a blank topic"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
if normalized_topic in seen_topics:
|
||||||
|
continue
|
||||||
|
seen_topics.add(normalized_topic)
|
||||||
|
normalized_topics.append(normalized_topic)
|
||||||
|
candidate_topics.append(normalized_topic)
|
||||||
|
|
||||||
|
topics_by_category[category.strip()] = normalized_topics
|
||||||
|
|
||||||
|
return TopicCatalog(
|
||||||
|
topics_by_category=topics_by_category,
|
||||||
|
candidate_topics=candidate_topics,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_topic_extraction_messages(
|
||||||
|
*,
|
||||||
|
bill: Bill,
|
||||||
|
bill_text: str,
|
||||||
|
candidate_topics: Sequence[str],
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
"""Build GPT messages for extracting a bill's scored topics."""
|
||||||
|
normalized_candidates = [normalize_topic_label(topic) for topic in candidate_topics]
|
||||||
|
candidate_list = "\n".join(f"- {topic}" for topic in normalized_candidates)
|
||||||
|
metadata = "\n".join(
|
||||||
|
(
|
||||||
|
f"Congress: {bill.congress}",
|
||||||
|
f"Bill: {bill.bill_type} {bill.number}",
|
||||||
|
f"Title: {bill.title_short or bill.title or bill.official_title or ''}",
|
||||||
|
f"Top subject term: {bill.subjects_top_term or ''}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
system_prompt = (
|
||||||
|
"You extract policy topics from U.S. congressional bills.\n"
|
||||||
|
'For each selected topic, decide whether a Yes/Yea vote on the bill is "for" or "against" that topic.\n'
|
||||||
|
'Use "support_position": "for" when a Yes/Yea vote advances or supports the topic.\n'
|
||||||
|
'Use "support_position": "against" when a Yes/Yea vote restricts, repeals, blocks, or opposes the topic.\n'
|
||||||
|
"Select only topics from the provided candidate topic list.\n"
|
||||||
|
"Omit topics that are not materially addressed by the bill.\n"
|
||||||
|
"Return strict JSON only, with this shape:\n"
|
||||||
|
'{"topics":[{"topic":"candidate topic","support_position":"for","confidence":0.0,"evidence":"short reason"}]}'
|
||||||
|
)
|
||||||
|
user_prompt = "\n\n".join(
|
||||||
|
(
|
||||||
|
"BILL METADATA:",
|
||||||
|
metadata,
|
||||||
|
"CANDIDATE TOPICS:",
|
||||||
|
candidate_list,
|
||||||
|
"BILL TEXT:",
|
||||||
|
bill_text,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": user_prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def call_openai_topic_extraction(
|
||||||
|
*,
|
||||||
|
openai_config: OpenAIConfig,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
) -> str:
|
||||||
|
"""Call GPT and return the assistant message content."""
|
||||||
|
|
||||||
|
response = httpx.post(
|
||||||
|
openai_config.openai_chat_completions_url,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {openai_config.api_key}",
|
||||||
|
"OpenAI-Project": openai_config.openai_project_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"model": "gpt-5.4-mini",
|
||||||
|
"messages": messages,
|
||||||
|
},
|
||||||
|
timeout=openai_config.timeout_seconds,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return extract_message_content(response.json())
|
||||||
|
|
||||||
|
|
||||||
|
def extract_message_content(data: dict[str, Any]) -> str:
|
||||||
|
"""Extract message content from a chat-completions response body."""
|
||||||
|
choices = data.get("choices")
|
||||||
|
if not isinstance(choices, list) or not choices:
|
||||||
|
msg = "Chat completion response did not contain choices"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
first = choices[0]
|
||||||
|
if not isinstance(first, dict):
|
||||||
|
msg = "Chat completion choice must be an object"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
message = first.get("message")
|
||||||
|
if isinstance(message, dict) and isinstance(message.get("content"), str):
|
||||||
|
return message["content"]
|
||||||
|
if isinstance(first.get("text"), str):
|
||||||
|
return first["text"]
|
||||||
|
|
||||||
|
msg = "Chat completion response did not contain message content"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_topic_extraction_response(response_text: str) -> list[ExtractedBillTopic]:
|
||||||
|
"""Parse, normalize, validate, and de-dupe a topic extraction response."""
|
||||||
|
payload = _load_json_response(response_text)
|
||||||
|
topics = payload.get("topics")
|
||||||
|
if not isinstance(topics, list):
|
||||||
|
msg = "Topic extraction response must contain a topics list"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
deduped: dict[tuple[str, BillTopicPosition], ExtractedBillTopic] = {}
|
||||||
|
for item in topics:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
msg = "Topic extraction response topics must be objects"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
raw_topic = _extract_topic_label(item)
|
||||||
|
topic = normalize_topic_label(raw_topic)
|
||||||
|
if not topic:
|
||||||
|
msg = "Topic extraction response topic must not be blank"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
raw_position = item.get("support_position")
|
||||||
|
try:
|
||||||
|
support_position = BillTopicPosition(raw_position)
|
||||||
|
except ValueError as exc:
|
||||||
|
msg = f"Invalid support_position: {raw_position!r}"
|
||||||
|
raise TopicExtractionError(msg) from exc
|
||||||
|
|
||||||
|
confidence = _parse_confidence(item.get("confidence"))
|
||||||
|
evidence = item.get("evidence")
|
||||||
|
if evidence is not None and not isinstance(evidence, str):
|
||||||
|
evidence = str(evidence)
|
||||||
|
|
||||||
|
extracted = ExtractedBillTopic(
|
||||||
|
topic=topic,
|
||||||
|
support_position=support_position,
|
||||||
|
confidence=confidence,
|
||||||
|
evidence=evidence,
|
||||||
|
)
|
||||||
|
key = (topic, support_position)
|
||||||
|
existing = deduped.get(key)
|
||||||
|
if existing is None or _confidence_rank(extracted) > _confidence_rank(existing):
|
||||||
|
deduped[key] = extracted
|
||||||
|
|
||||||
|
return list(deduped.values())
|
||||||
|
|
||||||
|
|
||||||
|
def extract_topics_for_bill_text(
|
||||||
|
*,
|
||||||
|
openai_config: OpenAIConfig,
|
||||||
|
bill: Bill,
|
||||||
|
text: str,
|
||||||
|
candidate_topics: Sequence[str],
|
||||||
|
) -> list[ExtractedBillTopic]:
|
||||||
|
"""Extract accepted catalog topics for a bill text string."""
|
||||||
|
normalized_candidates = {normalize_topic_label(topic) for topic in candidate_topics}
|
||||||
|
messages = build_topic_extraction_messages(
|
||||||
|
bill=bill,
|
||||||
|
bill_text=text,
|
||||||
|
candidate_topics=sorted(normalized_candidates),
|
||||||
|
)
|
||||||
|
response_text = call_openai_topic_extraction(
|
||||||
|
openai_config=openai_config,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
extracted_topics = parse_topic_extraction_response(response_text)
|
||||||
|
return [topic for topic in extracted_topics if topic.topic in normalized_candidates]
|
||||||
|
|
||||||
|
|
||||||
|
def store_bill_topic_result(
|
||||||
|
*,
|
||||||
|
session: Session,
|
||||||
|
bill: Bill,
|
||||||
|
topics: Sequence[ExtractedBillTopic],
|
||||||
|
replace_existing: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Store extracted topics for one bill."""
|
||||||
|
if replace_existing:
|
||||||
|
session.execute(delete(BillTopic).where(BillTopic.bill_id == bill.id))
|
||||||
|
|
||||||
|
for topic in topics:
|
||||||
|
session.add(
|
||||||
|
BillTopic(
|
||||||
|
bill_id=bill.id,
|
||||||
|
topic=normalize_topic_label(topic.topic),
|
||||||
|
support_position=topic.support_position,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_select_bills_for_topic_extraction(
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: list[int] | None = None,
|
||||||
|
bill_text_ids: list[int] | None = None,
|
||||||
|
with_votes_only: bool = False,
|
||||||
|
force: bool = False,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> Select[tuple[Bill]]:
|
||||||
|
"""Select bill rows that have summarized bill_text rows for topic extraction."""
|
||||||
|
has_summary = (BillText.summary.is_not(None), BillText.summary != "")
|
||||||
|
summarized_text_filters: list[ColumnElement[bool]] = [
|
||||||
|
BillText.bill_id == Bill.id,
|
||||||
|
*has_summary,
|
||||||
|
]
|
||||||
|
if with_votes_only:
|
||||||
|
summarized_text_filters.append(
|
||||||
|
exists(
|
||||||
|
select(VoteTextTarget.vote_id)
|
||||||
|
.join(
|
||||||
|
VoteClassification,
|
||||||
|
VoteClassification.vote_id == VoteTextTarget.vote_id,
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
VoteTextTarget.voted_text_version_id == BillText.id,
|
||||||
|
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||||
|
VoteClassification.vote_relationship
|
||||||
|
== VoteRelationship.DIRECT_TEXT_VOTE,
|
||||||
|
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||||
|
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||||
|
VoteClassification.is_special_rule.is_(False),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
summarized_text_exists = exists(select(BillText.id).where(*summarized_text_filters))
|
||||||
|
stmt = (
|
||||||
|
select(Bill)
|
||||||
|
.where(summarized_text_exists)
|
||||||
|
.options(selectinload(Bill.bill_texts.and_(*summarized_text_filters[1:])))
|
||||||
|
.order_by(Bill.id)
|
||||||
|
)
|
||||||
|
if congress is not None:
|
||||||
|
stmt = stmt.where(Bill.congress == congress)
|
||||||
|
if bill_ids:
|
||||||
|
stmt = stmt.where(Bill.id.in_(bill_ids))
|
||||||
|
if bill_text_ids:
|
||||||
|
selected_text_exists = exists(
|
||||||
|
select(BillText.id).where(
|
||||||
|
BillText.bill_id == Bill.id,
|
||||||
|
BillText.id.in_(bill_text_ids),
|
||||||
|
*summarized_text_filters[1:],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
stmt = stmt.where(selected_text_exists)
|
||||||
|
if not force:
|
||||||
|
stmt = stmt.where(
|
||||||
|
~exists(select(BillTopic.id).where(BillTopic.bill_id == Bill.id))
|
||||||
|
)
|
||||||
|
if limit is not None:
|
||||||
|
stmt = stmt.limit(limit)
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
|
def collect_topic_extraction_diagnostics(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: list[int] | None = None,
|
||||||
|
bill_text_ids: list[int] | None = None,
|
||||||
|
with_votes_only: bool = False,
|
||||||
|
force: bool = False,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> TopicExtractionDiagnostics:
|
||||||
|
"""Count topic extraction inputs for explaining empty selections."""
|
||||||
|
bill_filters = []
|
||||||
|
bill_text_filters: list[ColumnElement[bool]] = []
|
||||||
|
if congress is not None:
|
||||||
|
bill_filters.append(Bill.congress == congress)
|
||||||
|
if bill_ids:
|
||||||
|
bill_filters.append(Bill.id.in_(bill_ids))
|
||||||
|
bill_text_filters.append(BillText.bill_id.in_(bill_ids))
|
||||||
|
if bill_text_ids:
|
||||||
|
bill_text_filters.append(BillText.id.in_(bill_text_ids))
|
||||||
|
if with_votes_only:
|
||||||
|
bill_text_filters.append(
|
||||||
|
exists(
|
||||||
|
select(VoteTextTarget.vote_id)
|
||||||
|
.join(
|
||||||
|
VoteClassification,
|
||||||
|
VoteClassification.vote_id == VoteTextTarget.vote_id,
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
VoteTextTarget.voted_text_version_id == BillText.id,
|
||||||
|
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||||
|
VoteClassification.vote_relationship
|
||||||
|
== VoteRelationship.DIRECT_TEXT_VOTE,
|
||||||
|
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||||
|
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||||
|
VoteClassification.is_special_rule.is_(False),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
has_summary = (BillText.summary.is_not(None), BillText.summary != "")
|
||||||
|
summary_filters = [*bill_text_filters, *has_summary]
|
||||||
|
|
||||||
|
bills_with_summaries = session.scalar(
|
||||||
|
select(func.count(func.distinct(Bill.id)))
|
||||||
|
.select_from(Bill)
|
||||||
|
.join(BillText, BillText.bill_id == Bill.id)
|
||||||
|
.where(*bill_filters, *summary_filters)
|
||||||
|
)
|
||||||
|
selected_bills = session.scalar(
|
||||||
|
select(func.count()).select_from(
|
||||||
|
create_select_bills_for_topic_extraction(
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
bill_text_ids=bill_text_ids,
|
||||||
|
with_votes_only=with_votes_only,
|
||||||
|
force=force,
|
||||||
|
limit=limit,
|
||||||
|
).subquery()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return TopicExtractionDiagnostics(
|
||||||
|
bill_rows=session.scalar(select(func.count(Bill.id)).where(*bill_filters)) or 0,
|
||||||
|
bill_text_rows=_count_bill_texts(
|
||||||
|
session,
|
||||||
|
bill_filters=bill_filters,
|
||||||
|
bill_text_filters=bill_text_filters,
|
||||||
|
),
|
||||||
|
summarized_bill_text_rows=_count_bill_texts(
|
||||||
|
session,
|
||||||
|
bill_filters=bill_filters,
|
||||||
|
bill_text_filters=summary_filters,
|
||||||
|
),
|
||||||
|
bills_with_summaries=bills_with_summaries or 0,
|
||||||
|
bill_topic_rows=session.scalar(select(func.count(BillTopic.id))) or 0,
|
||||||
|
selected_bills=selected_bills or 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_json_response(response_text: str) -> dict[str, Any]:
|
||||||
|
text = response_text.strip()
|
||||||
|
fenced = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.DOTALL)
|
||||||
|
if fenced:
|
||||||
|
text = fenced.group(1).strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = json.loads(text)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
msg = f"Topic extraction response is not valid JSON: {exc}"
|
||||||
|
raise TopicExtractionError(msg) from exc
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
msg = "Topic extraction response must be a JSON object"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_confidence(raw: Any) -> float | None:
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return float(raw)
|
||||||
|
except (TypeError, ValueError) as exc:
|
||||||
|
msg = f"Invalid confidence: {raw!r}"
|
||||||
|
raise TopicExtractionError(msg) from exc
|
||||||
|
|
||||||
|
|
||||||
|
def _confidence_rank(topic: ExtractedBillTopic) -> tuple[int, float]:
|
||||||
|
if topic.confidence is None:
|
||||||
|
return (0, 0.0)
|
||||||
|
return (1, topic.confidence)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_topic_label(item: dict[str, Any]) -> str:
|
||||||
|
raw_topic = item.get("topic")
|
||||||
|
if isinstance(raw_topic, str):
|
||||||
|
return raw_topic
|
||||||
|
if isinstance(raw_topic, dict):
|
||||||
|
for key in ("topic", "label", "name", "title"):
|
||||||
|
value = raw_topic.get(key)
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value
|
||||||
|
|
||||||
|
msg = "Topic extraction response topic must be a string"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_bill_texts(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
bill_filters: Sequence[ColumnElement[bool]],
|
||||||
|
bill_text_filters: Sequence[ColumnElement[bool]],
|
||||||
|
) -> int:
|
||||||
|
stmt = select(func.count(BillText.id))
|
||||||
|
if bill_filters:
|
||||||
|
stmt = stmt.join(Bill, Bill.id == BillText.bill_id).where(*bill_filters)
|
||||||
|
return session.scalar(stmt.where(*bill_text_filters)) or 0
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
topics_path: Annotated[
|
||||||
|
Path, typer.Option(help="Path to congressional issue topic JSON.")
|
||||||
|
] = DEFAULT_TOPICS_PATH,
|
||||||
|
congress: Annotated[
|
||||||
|
int | None, typer.Option(help="Only process one Congress.")
|
||||||
|
] = None,
|
||||||
|
bill_ids: Annotated[
|
||||||
|
list[int] | None,
|
||||||
|
typer.Option(
|
||||||
|
"--bill-id",
|
||||||
|
help="Only process one internal bill.id. Repeat for multiple bills.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
bill_text_ids: Annotated[
|
||||||
|
list[int] | None,
|
||||||
|
typer.Option(
|
||||||
|
"--bill-text-id",
|
||||||
|
help="Only process one internal bill_text.id. Repeat for multiple rows.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
with_votes_only: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(
|
||||||
|
"--with-votes-only",
|
||||||
|
help="Only process summarized bill_text rows linked to at least one vote.",
|
||||||
|
),
|
||||||
|
] = True,
|
||||||
|
limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
|
||||||
|
force: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Regenerate topics for bills that already have topics."),
|
||||||
|
] = False,
|
||||||
|
dry_run: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Select bills and print diagnostics without calling OpenAI."),
|
||||||
|
] = False,
|
||||||
|
diagnose: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Log input-stage counts before processing."),
|
||||||
|
] = False,
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""CLI entrypoint for generating and storing bill topics."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
topic_catalog = load_topic_catalog(topics_path)
|
||||||
|
logger.info(
|
||||||
|
"Loaded %d candidate topics from %s",
|
||||||
|
len(topic_catalog.candidate_topics),
|
||||||
|
topics_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||||
|
with Session(engine) as session:
|
||||||
|
if diagnose or dry_run:
|
||||||
|
diagnostics = collect_topic_extraction_diagnostics(
|
||||||
|
session,
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
bill_text_ids=bill_text_ids,
|
||||||
|
with_votes_only=with_votes_only,
|
||||||
|
force=force,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
_log_topic_extraction_diagnostics(diagnostics)
|
||||||
|
if dry_run:
|
||||||
|
return
|
||||||
|
|
||||||
|
openai_config = get_openai_config()
|
||||||
|
|
||||||
|
stmt = create_select_bills_for_topic_extraction(
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
bill_text_ids=bill_text_ids,
|
||||||
|
with_votes_only=with_votes_only,
|
||||||
|
force=force,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
bills = session.scalars(stmt).all()
|
||||||
|
logger.info("Selected %d bills for topic extraction", len(bills))
|
||||||
|
|
||||||
|
written = 0
|
||||||
|
failed = 0
|
||||||
|
for index, bill in enumerate(bills, 1):
|
||||||
|
bill_text = _select_bill_text_for_topic_extraction(bill)
|
||||||
|
if bill_text is None:
|
||||||
|
logger.warning("Skipping bill id=%s: no usable summary", bill.id)
|
||||||
|
continue
|
||||||
|
summary = bill_text.summary.strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
extracted_topics = extract_topics_for_bill_text(
|
||||||
|
openai_config=openai_config,
|
||||||
|
bill=bill,
|
||||||
|
text=summary,
|
||||||
|
candidate_topics=topic_catalog.candidate_topics,
|
||||||
|
)
|
||||||
|
except (httpx.HTTPError, TopicExtractionError):
|
||||||
|
failed += 1
|
||||||
|
logger.exception(
|
||||||
|
"Skipping bill id=%s after topic extraction failure", bill.id
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
store_bill_topic_result(
|
||||||
|
session=session,
|
||||||
|
bill=bill,
|
||||||
|
topics=extracted_topics,
|
||||||
|
replace_existing=True,
|
||||||
|
)
|
||||||
|
written += 1
|
||||||
|
if index % 100 == 0:
|
||||||
|
session.commit()
|
||||||
|
logger.info(
|
||||||
|
"Stored %d topics for bill id=%s",
|
||||||
|
len(extracted_topics),
|
||||||
|
bill.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
logger.info(
|
||||||
|
"Done: stored topic results for %d bills; failed %d bills",
|
||||||
|
written,
|
||||||
|
failed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_topic_extraction_diagnostics(
|
||||||
|
diagnostics: TopicExtractionDiagnostics,
|
||||||
|
) -> None:
|
||||||
|
logger.info(
|
||||||
|
"Topic extraction diagnostics: bill_rows=%d bill_text_rows=%d "
|
||||||
|
"summarized_bill_text_rows=%d bills_with_summaries=%d "
|
||||||
|
"bill_topic_rows=%d selected_bills=%d",
|
||||||
|
diagnostics.bill_rows,
|
||||||
|
diagnostics.bill_text_rows,
|
||||||
|
diagnostics.summarized_bill_text_rows,
|
||||||
|
diagnostics.bills_with_summaries,
|
||||||
|
diagnostics.bill_topic_rows,
|
||||||
|
diagnostics.selected_bills,
|
||||||
|
)
|
||||||
|
if diagnostics.bill_rows == 0:
|
||||||
|
logger.warning("No bills matched the topic extraction scope.")
|
||||||
|
elif diagnostics.bill_text_rows == 0:
|
||||||
|
logger.warning("No bill_text rows matched the topic extraction scope.")
|
||||||
|
elif diagnostics.summarized_bill_text_rows == 0:
|
||||||
|
logger.warning(
|
||||||
|
"No summarized bill_text rows matched the topic extraction scope. "
|
||||||
|
"Run pipelines.tools.summarize_bills first."
|
||||||
|
)
|
||||||
|
elif diagnostics.selected_bills == 0 and diagnostics.bill_topic_rows > 0:
|
||||||
|
logger.warning(
|
||||||
|
"No bills selected because matching bills already have topics. "
|
||||||
|
"Use --force to regenerate them."
|
||||||
|
)
|
||||||
|
elif diagnostics.selected_bills == 0:
|
||||||
|
logger.warning("No bills selected for topic extraction.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
typer.run(main)
|
||||||
@@ -0,0 +1,309 @@
|
|||||||
|
"""Summarize bill_text rows with GPT-5 and store results in the database."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import tomllib
|
||||||
|
from os import getenv
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import typer
|
||||||
|
from sqlalchemy import Select, exists, or_, select
|
||||||
|
from sqlalchemy.orm import Session, selectinload
|
||||||
|
|
||||||
|
from tiktoken import get_encoding
|
||||||
|
|
||||||
|
|
||||||
|
from pipelines.config import get_config_dir
|
||||||
|
from pipelines.orm.common import get_postgres_engine
|
||||||
|
from pipelines.orm.data_science_dev.congress import (
|
||||||
|
Bill,
|
||||||
|
BillText,
|
||||||
|
SubjectType,
|
||||||
|
VoteClassification,
|
||||||
|
VoteRelationship,
|
||||||
|
VoteTextTarget,
|
||||||
|
)
|
||||||
|
from pipelines.tools.bill_token_compression import compress_bill_text
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
|
||||||
|
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
|
||||||
|
REQUEST_TIMEOUT_SECONDS = 60
|
||||||
|
|
||||||
|
|
||||||
|
def load_summarization_prompts(
|
||||||
|
section: str = "summarization",
|
||||||
|
) -> dict[str, str]:
|
||||||
|
summarization_prompts = get_config_dir() / "prompts" / "summarization_prompts.toml"
|
||||||
|
|
||||||
|
return tomllib.loads(summarization_prompts.read_text())[section]
|
||||||
|
|
||||||
|
|
||||||
|
class BillSummaryError(RuntimeError):
|
||||||
|
"""Raised when a bill summary request or response is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
def call_openai_summary(
|
||||||
|
*,
|
||||||
|
model: str,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
) -> str:
|
||||||
|
"""Call GPT and return the assistant message content."""
|
||||||
|
api_key = getenv("CLOSEDAI_TOKEN")
|
||||||
|
if not api_key:
|
||||||
|
msg = "CLOSEDAI_TOKEN is required"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
response = httpx.post(
|
||||||
|
OPENAI_CHAT_COMPLETIONS_URL,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"OpenAI-Project": OPENAI_PROJECT_ID,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
},
|
||||||
|
timeout=REQUEST_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
logger.info(f"{response.text=}")
|
||||||
|
response.raise_for_status()
|
||||||
|
return extract_message_content(response.json())
|
||||||
|
|
||||||
|
|
||||||
|
def build_bill_summary_messages(
|
||||||
|
*,
|
||||||
|
bill_text: BillText,
|
||||||
|
summarization_prompts: dict[str, str],
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
"""Build the GPT prompt messages plus compressed text and user prompt."""
|
||||||
|
if not bill_text.text_content:
|
||||||
|
msg = f"bill_text id={bill_text.id} has no text_content"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
compressed_text = compress_bill_text(bill_text.text_content)
|
||||||
|
if not compressed_text:
|
||||||
|
msg = f"bill_text id={bill_text.id} has no summarizable text_content"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
user_prompt = summarization_prompts["user_template"].format(
|
||||||
|
text_content=compressed_text
|
||||||
|
)
|
||||||
|
|
||||||
|
user_prompt_tokens = len(get_encoding("o200k_base").encode(user_prompt))
|
||||||
|
logger.info(f"{user_prompt_tokens=}")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": summarization_prompts["system_prompt"]},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": user_prompt,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
return messages, user_prompt_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def summarize_bill_text(
|
||||||
|
*,
|
||||||
|
model: str,
|
||||||
|
bill_text: BillText,
|
||||||
|
summarization_prompts: dict[str, str],
|
||||||
|
) -> str:
|
||||||
|
"""Generate and return a summary for one bill_text row."""
|
||||||
|
messages, user_prompt_tokens = build_bill_summary_messages(
|
||||||
|
bill_text=bill_text,
|
||||||
|
summarization_prompts=summarization_prompts,
|
||||||
|
)
|
||||||
|
# This may only be for gpt-5.4 mini I need to read the docs
|
||||||
|
if user_prompt_tokens > 272000:
|
||||||
|
msg = f"Compressed bill_text id={bill_text.id} is too long for summarization ({user_prompt_tokens} tokens)"
|
||||||
|
logger.warning(msg)
|
||||||
|
return None
|
||||||
|
|
||||||
|
summary = call_openai_summary(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
).strip()
|
||||||
|
if not summary:
|
||||||
|
msg = f"Model returned an empty summary for bill_text id={bill_text.id}"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
def store_bill_summary_result(
|
||||||
|
*,
|
||||||
|
bill_text: BillText,
|
||||||
|
summary: str,
|
||||||
|
model: str,
|
||||||
|
) -> None:
|
||||||
|
"""Store a generated summary and the prompt/model metadata that produced it."""
|
||||||
|
bill_text.summary = summary
|
||||||
|
bill_text.summarization_model = model
|
||||||
|
bill_text.summarization_system_prompt_version = "v1.2"
|
||||||
|
bill_text.summarization_user_prompt_version = "v1"
|
||||||
|
|
||||||
|
|
||||||
|
def create_select_bill_texts_for_summarization(
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: list[int] | None = None,
|
||||||
|
bill_text_ids: list[int] | None = None,
|
||||||
|
with_votes_only: bool = False,
|
||||||
|
force: bool = False,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> Select:
|
||||||
|
"""Select bill_text rows that have source text and need summaries."""
|
||||||
|
stmt = (
|
||||||
|
select(BillText)
|
||||||
|
.join(Bill, Bill.id == BillText.bill_id)
|
||||||
|
.where(BillText.text_content.is_not(None), BillText.text_content != "")
|
||||||
|
.options(selectinload(BillText.bill))
|
||||||
|
.order_by(BillText.id)
|
||||||
|
)
|
||||||
|
if congress is not None:
|
||||||
|
stmt = stmt.where(Bill.congress == congress)
|
||||||
|
if bill_ids:
|
||||||
|
stmt = stmt.where(BillText.bill_id.in_(bill_ids))
|
||||||
|
if bill_text_ids:
|
||||||
|
stmt = stmt.where(BillText.id.in_(bill_text_ids))
|
||||||
|
if with_votes_only:
|
||||||
|
stmt = stmt.where(
|
||||||
|
exists(
|
||||||
|
select(VoteTextTarget.vote_id)
|
||||||
|
.join(
|
||||||
|
VoteClassification,
|
||||||
|
VoteClassification.vote_id == VoteTextTarget.vote_id,
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
VoteTextTarget.voted_text_version_id == BillText.id,
|
||||||
|
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||||
|
VoteClassification.vote_relationship
|
||||||
|
== VoteRelationship.DIRECT_TEXT_VOTE,
|
||||||
|
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||||
|
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||||
|
VoteClassification.is_special_rule.is_(False),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not force:
|
||||||
|
stmt = stmt.where(or_(BillText.summary.is_(None), BillText.summary == ""))
|
||||||
|
if limit is not None:
|
||||||
|
stmt = stmt.limit(limit)
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
|
def extract_message_content(data: dict[str, Any]) -> str:
|
||||||
|
"""Extract message content from a chat-completions response body."""
|
||||||
|
choices = data.get("choices")
|
||||||
|
if not isinstance(choices, list) or not choices:
|
||||||
|
msg = "Chat completion response did not contain choices"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
first = choices[0]
|
||||||
|
if not isinstance(first, dict):
|
||||||
|
msg = "Chat completion choice must be an object"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
message = first.get("message")
|
||||||
|
if isinstance(message, dict) and isinstance(message.get("content"), str):
|
||||||
|
return message["content"]
|
||||||
|
if isinstance(first.get("text"), str):
|
||||||
|
return first["text"]
|
||||||
|
|
||||||
|
msg = "Chat completion response did not contain message content"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
model: Annotated[str, typer.Option(help="OpenAI model id.")] = "gpt-5.4-mini",
|
||||||
|
congress: Annotated[
|
||||||
|
int | None, typer.Option(help="Only process one Congress.")
|
||||||
|
] = None,
|
||||||
|
bill_ids: Annotated[
|
||||||
|
list[int] | None,
|
||||||
|
typer.Option(
|
||||||
|
"--bill-id",
|
||||||
|
help="Only process one internal bill.id. Repeat for multiple bills.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
bill_text_ids: Annotated[
|
||||||
|
list[int] | None,
|
||||||
|
typer.Option(
|
||||||
|
"--bill-text-id",
|
||||||
|
help="Only process one internal bill_text.id. Repeat for multiple rows.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
with_votes_only: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(
|
||||||
|
"--with-votes-only",
|
||||||
|
help="Only process bill_text rows linked to at least one vote.",
|
||||||
|
),
|
||||||
|
] = False,
|
||||||
|
limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
|
||||||
|
force: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Regenerate summaries for rows that already have a summary."),
|
||||||
|
] = False,
|
||||||
|
dry_run: Annotated[
|
||||||
|
bool, typer.Option(help="Print summaries without writing them to the database.")
|
||||||
|
] = False,
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""CLI entrypoint for generating and storing bill summaries."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
if not getenv("CLOSEDAI_TOKEN"):
|
||||||
|
message = "CLOSEDAI_TOKEN is required"
|
||||||
|
raise typer.BadParameter(message)
|
||||||
|
|
||||||
|
summarization_prompts = load_summarization_prompts()
|
||||||
|
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||||
|
with Session(engine) as session:
|
||||||
|
stmt = create_select_bill_texts_for_summarization(
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
bill_text_ids=bill_text_ids,
|
||||||
|
with_votes_only=with_votes_only,
|
||||||
|
force=force,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
bill_texts = session.scalars(stmt).all()
|
||||||
|
logger.info("Selected %d bill_text rows for summarization", len(bill_texts))
|
||||||
|
|
||||||
|
written = 0
|
||||||
|
for index, bill_text in enumerate(bill_texts, 1):
|
||||||
|
summary = summarize_bill_text(
|
||||||
|
model=model,
|
||||||
|
bill_text=bill_text,
|
||||||
|
summarization_prompts=summarization_prompts,
|
||||||
|
)
|
||||||
|
if summary is None:
|
||||||
|
logger.warning("Skipping bill_text id=%s", bill_text.id)
|
||||||
|
continue
|
||||||
|
store_bill_summary_result(
|
||||||
|
bill_text=bill_text,
|
||||||
|
summary=summary,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
if index % 100 == 0:
|
||||||
|
session.commit()
|
||||||
|
written += 1
|
||||||
|
session.commit()
|
||||||
|
logger.info("Stored summary for bill_text id=%s", bill_text.id)
|
||||||
|
|
||||||
|
logger.info("Done: stored %d summaries", written)
|
||||||
|
|
||||||
|
|
||||||
|
def cli() -> None:
|
||||||
|
"""Typer entry point."""
|
||||||
|
typer.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
+14
-3
@@ -17,6 +17,10 @@ NAMING_CONVENTION = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseSetupError(RuntimeError):
|
||||||
|
"""Raised when database configuration is missing or invalid."""
|
||||||
|
|
||||||
|
|
||||||
def get_connection_info(name: str) -> tuple[str, str, str, str, str | None]:
|
def get_connection_info(name: str) -> tuple[str, str, str, str, str | None]:
|
||||||
"""Get connection info from environment variables."""
|
"""Get connection info from environment variables."""
|
||||||
database = getenv(f"{name}_DB")
|
database = getenv(f"{name}_DB")
|
||||||
@@ -27,11 +31,18 @@ def get_connection_info(name: str) -> tuple[str, str, str, str, str | None]:
|
|||||||
|
|
||||||
if None in (database, host, port, username):
|
if None in (database, host, port, username):
|
||||||
error = f"Missing environment variables for Postgres connection.\n{database=}\n{host=}\n{port=}\n{username=}\n"
|
error = f"Missing environment variables for Postgres connection.\n{database=}\n{host=}\n{port=}\n{username=}\n"
|
||||||
raise ValueError(error)
|
raise DatabaseSetupError(error)
|
||||||
return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password))
|
return cast(
|
||||||
|
"tuple[str, str, str, str, str | None]",
|
||||||
|
(database, host, port, username, password),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_postgres_engine(*, name: str = "POSTGRES", pool_pre_ping: bool = True) -> Engine:
|
def get_postgres_engine(
|
||||||
|
*,
|
||||||
|
name: str = "POSTGRES",
|
||||||
|
pool_pre_ping: bool = True,
|
||||||
|
) -> Engine:
|
||||||
"""Create a SQLAlchemy engine from environment variables."""
|
"""Create a SQLAlchemy engine from environment variables."""
|
||||||
database, host, port, username, password = get_connection_info(name)
|
database, host, port, username, password = get_connection_info(name)
|
||||||
|
|
||||||
|
|||||||
@@ -1,17 +1,86 @@
|
|||||||
"""init."""
|
"""Congress ORM models."""
|
||||||
|
|
||||||
from pipelines.orm.data_science_dev.congress.bill import Bill, BillText
|
from pipelines.orm.data_science_dev.congress.bill import (
|
||||||
|
Bill,
|
||||||
|
BillAction,
|
||||||
|
BillActionRecordedVote,
|
||||||
|
BillRelation,
|
||||||
|
BillText,
|
||||||
|
BillTopic,
|
||||||
|
BillTopicPosition,
|
||||||
|
)
|
||||||
|
from pipelines.orm.data_science_dev.congress.amendment import (
|
||||||
|
Amendment,
|
||||||
|
AmendmentAction,
|
||||||
|
AmendmentActionRecordedVote,
|
||||||
|
)
|
||||||
|
from pipelines.orm.data_science_dev.congress.context import (
|
||||||
|
ClassificationMethod,
|
||||||
|
ConfidenceLevel,
|
||||||
|
IngestRun,
|
||||||
|
MeasureFunction,
|
||||||
|
MeasureSubtype,
|
||||||
|
ScoreRun,
|
||||||
|
SourceArtifact,
|
||||||
|
SubjectType,
|
||||||
|
TextResolutionMethod,
|
||||||
|
TextTargetBasis,
|
||||||
|
TextTargetType,
|
||||||
|
VoteActionMatch,
|
||||||
|
VoteActionScope,
|
||||||
|
VoteClassification,
|
||||||
|
VoteContextAudit,
|
||||||
|
VoteEffect,
|
||||||
|
VoteMeasureLink,
|
||||||
|
VoteMeasureRole,
|
||||||
|
VotePositionMeaning,
|
||||||
|
VoteRelationship,
|
||||||
|
VoteTextTarget,
|
||||||
|
)
|
||||||
from pipelines.orm.data_science_dev.congress.legislator import (
|
from pipelines.orm.data_science_dev.congress.legislator import (
|
||||||
Legislator,
|
Legislator,
|
||||||
|
LegislatorScore,
|
||||||
LegislatorSocialMedia,
|
LegislatorSocialMedia,
|
||||||
|
LegislatorScoreFake,
|
||||||
)
|
)
|
||||||
from pipelines.orm.data_science_dev.congress.vote import Vote, VoteRecord
|
from pipelines.orm.data_science_dev.congress.vote import Vote, VoteRecord
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"Amendment",
|
||||||
|
"AmendmentAction",
|
||||||
|
"AmendmentActionRecordedVote",
|
||||||
"Bill",
|
"Bill",
|
||||||
|
"BillAction",
|
||||||
|
"BillActionRecordedVote",
|
||||||
|
"BillRelation",
|
||||||
"BillText",
|
"BillText",
|
||||||
|
"BillTopic",
|
||||||
|
"BillTopicPosition",
|
||||||
|
"ClassificationMethod",
|
||||||
|
"ConfidenceLevel",
|
||||||
|
"IngestRun",
|
||||||
"Legislator",
|
"Legislator",
|
||||||
|
"LegislatorScore",
|
||||||
|
"LegislatorScoreFake",
|
||||||
"LegislatorSocialMedia",
|
"LegislatorSocialMedia",
|
||||||
|
"MeasureFunction",
|
||||||
|
"MeasureSubtype",
|
||||||
|
"ScoreRun",
|
||||||
|
"SourceArtifact",
|
||||||
|
"SubjectType",
|
||||||
|
"TextResolutionMethod",
|
||||||
|
"TextTargetBasis",
|
||||||
|
"TextTargetType",
|
||||||
"Vote",
|
"Vote",
|
||||||
|
"VoteActionMatch",
|
||||||
|
"VoteActionScope",
|
||||||
|
"VoteClassification",
|
||||||
|
"VoteContextAudit",
|
||||||
|
"VoteEffect",
|
||||||
|
"VoteMeasureLink",
|
||||||
|
"VoteMeasureRole",
|
||||||
|
"VotePositionMeaning",
|
||||||
|
"VoteRelationship",
|
||||||
"VoteRecord",
|
"VoteRecord",
|
||||||
|
"VoteTextTarget",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1,127 @@
|
|||||||
|
"""Amendment models and official action context."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date, datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, ForeignKey, Index, UniqueConstraint
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
|
||||||
|
|
||||||
|
|
||||||
|
class Amendment(DataScienceDevTableBase):
|
||||||
|
"""Congressional amendment linked to a bill or to another amendment."""
|
||||||
|
|
||||||
|
__tablename__ = "amendment"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"congress",
|
||||||
|
"amendment_type",
|
||||||
|
"number",
|
||||||
|
name="uq_amendment_congress_type_number",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
congress: Mapped[int]
|
||||||
|
amendment_type: Mapped[str]
|
||||||
|
number: Mapped[int]
|
||||||
|
chamber: Mapped[str]
|
||||||
|
description: Mapped[str | None]
|
||||||
|
purpose: Mapped[str | None]
|
||||||
|
amended_bill_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("main.bill.id", ondelete="SET NULL")
|
||||||
|
)
|
||||||
|
amended_amendment_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("main.amendment.id", ondelete="SET NULL")
|
||||||
|
)
|
||||||
|
source_path: Mapped[str | None]
|
||||||
|
source_artifact_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("main.source_artifact.id", ondelete="SET NULL")
|
||||||
|
)
|
||||||
|
|
||||||
|
actions: Mapped[list[AmendmentAction]] = relationship(
|
||||||
|
"AmendmentAction",
|
||||||
|
back_populates="amendment",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
amended_amendment: Mapped[Amendment | None] = relationship(
|
||||||
|
"Amendment",
|
||||||
|
remote_side="Amendment.id",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AmendmentAction(DataScienceDevTableBase):
|
||||||
|
"""Official action row for an amendment."""
|
||||||
|
|
||||||
|
__tablename__ = "amendment_action"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"amendment_id",
|
||||||
|
"sequence",
|
||||||
|
name="uq_amendment_action_amendment_id_sequence",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
amendment_id: Mapped[int] = mapped_column(
|
||||||
|
ForeignKey("main.amendment.id", ondelete="CASCADE")
|
||||||
|
)
|
||||||
|
sequence: Mapped[int]
|
||||||
|
action_date: Mapped[date]
|
||||||
|
action_time: Mapped[str | None]
|
||||||
|
action_text: Mapped[str]
|
||||||
|
action_type: Mapped[str | None]
|
||||||
|
action_code: Mapped[str | None]
|
||||||
|
source_system_code: Mapped[str | None]
|
||||||
|
source_system_name: Mapped[str | None]
|
||||||
|
source_artifact_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("main.source_artifact.id", ondelete="SET NULL")
|
||||||
|
)
|
||||||
|
|
||||||
|
amendment: Mapped[Amendment] = relationship(
|
||||||
|
"Amendment",
|
||||||
|
back_populates="actions",
|
||||||
|
)
|
||||||
|
recorded_votes: Mapped[list[AmendmentActionRecordedVote]] = relationship(
|
||||||
|
"AmendmentActionRecordedVote",
|
||||||
|
back_populates="amendment_action",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AmendmentActionRecordedVote(DataScienceDevTableBase):
|
||||||
|
"""Recorded vote nested under one official amendment action."""
|
||||||
|
|
||||||
|
__tablename__ = "amendment_action_recorded_vote"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"amendment_action_id",
|
||||||
|
"congress",
|
||||||
|
"chamber",
|
||||||
|
"session_number",
|
||||||
|
"roll_number",
|
||||||
|
name="uq_amendment_action_recorded_vote_match_key",
|
||||||
|
),
|
||||||
|
Index(
|
||||||
|
"ix_amendment_action_recorded_vote_match_tuple",
|
||||||
|
"congress",
|
||||||
|
"chamber",
|
||||||
|
"session_number",
|
||||||
|
"roll_number",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
amendment_action_id: Mapped[int] = mapped_column(
|
||||||
|
ForeignKey("main.amendment_action.id", ondelete="CASCADE")
|
||||||
|
)
|
||||||
|
congress: Mapped[int]
|
||||||
|
chamber: Mapped[str]
|
||||||
|
session_number: Mapped[int]
|
||||||
|
roll_number: Mapped[int]
|
||||||
|
vote_datetime: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||||
|
vote_url: Mapped[str | None]
|
||||||
|
|
||||||
|
amendment_action: Mapped[AmendmentAction] = relationship(
|
||||||
|
"AmendmentAction",
|
||||||
|
back_populates="recorded_votes",
|
||||||
|
)
|
||||||
@@ -1,23 +1,48 @@
|
|||||||
"""Bill model - legislation introduced in Congress."""
|
"""Bill models for legislation, official actions, text versions, and topic tags."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import date
|
from datetime import date, datetime
|
||||||
|
from enum import StrEnum
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy import ForeignKey, Index, UniqueConstraint
|
from sqlalchemy import DateTime, Enum, ForeignKey, Index, UniqueConstraint
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
|
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pipelines.orm.data_science_dev.congress.vote import Vote
|
from pipelines.orm.data_science_dev.congress.context import VoteMeasureLink
|
||||||
|
|
||||||
|
|
||||||
|
class BillTopicPosition(StrEnum):
|
||||||
|
"""Whether a yes vote on a bill is for or against a topic."""
|
||||||
|
|
||||||
|
FOR = "for"
|
||||||
|
AGAINST = "against"
|
||||||
|
|
||||||
|
|
||||||
|
def _enum_column(enum_cls: type[StrEnum], *, name: str) -> Enum:
|
||||||
|
"""Build a portable SQLAlchemy enum column for StrEnum values."""
|
||||||
|
|
||||||
|
return Enum(
|
||||||
|
enum_cls,
|
||||||
|
values_callable=lambda enum_type: [member.value for member in enum_type],
|
||||||
|
native_enum=False,
|
||||||
|
name=name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Bill(DataScienceDevTableBase):
|
class Bill(DataScienceDevTableBase):
|
||||||
"""Legislation with congress number, type, titles, status, and sponsor."""
|
"""Legislation with congress number, type, titles, status, and sponsor."""
|
||||||
|
|
||||||
__tablename__ = "bill"
|
__tablename__ = "bill"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"congress", "bill_type", "number", name="uq_bill_congress_type_number"
|
||||||
|
),
|
||||||
|
Index("ix_bill_congress", "congress"),
|
||||||
|
)
|
||||||
|
|
||||||
congress: Mapped[int]
|
congress: Mapped[int]
|
||||||
bill_type: Mapped[str]
|
bill_type: Mapped[str]
|
||||||
@@ -33,22 +58,39 @@ class Bill(DataScienceDevTableBase):
|
|||||||
sponsor_bioguide_id: Mapped[str | None]
|
sponsor_bioguide_id: Mapped[str | None]
|
||||||
|
|
||||||
subjects_top_term: Mapped[str | None]
|
subjects_top_term: Mapped[str | None]
|
||||||
|
score_processed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||||
|
|
||||||
votes: Mapped[list[Vote]] = relationship(
|
|
||||||
"Vote",
|
|
||||||
back_populates="bill",
|
|
||||||
)
|
|
||||||
bill_texts: Mapped[list[BillText]] = relationship(
|
bill_texts: Mapped[list[BillText]] = relationship(
|
||||||
"BillText",
|
"BillText",
|
||||||
back_populates="bill",
|
back_populates="bill",
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
topics: Mapped[list[BillTopic]] = relationship(
|
||||||
__table_args__ = (
|
"BillTopic",
|
||||||
UniqueConstraint(
|
back_populates="bill",
|
||||||
"congress", "bill_type", "number", name="uq_bill_congress_type_number"
|
cascade="all, delete-orphan",
|
||||||
),
|
)
|
||||||
Index("ix_bill_congress", "congress"),
|
bill_actions: Mapped[list[BillAction]] = relationship(
|
||||||
|
"BillAction",
|
||||||
|
back_populates="bill",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
outgoing_bill_relations: Mapped[list[BillRelation]] = relationship(
|
||||||
|
"BillRelation",
|
||||||
|
foreign_keys="BillRelation.bill_id",
|
||||||
|
back_populates="bill",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
incoming_bill_relations: Mapped[list[BillRelation]] = relationship(
|
||||||
|
"BillRelation",
|
||||||
|
foreign_keys="BillRelation.related_bill_id",
|
||||||
|
back_populates="related_bill",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
vote_measure_links: Mapped[list[VoteMeasureLink]] = relationship(
|
||||||
|
"VoteMeasureLink",
|
||||||
|
back_populates="measure",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -56,17 +98,147 @@ class BillText(DataScienceDevTableBase):
|
|||||||
"""Stores different text versions of a bill (introduced, enrolled, etc.)."""
|
"""Stores different text versions of a bill (introduced, enrolled, etc.)."""
|
||||||
|
|
||||||
__tablename__ = "bill_text"
|
__tablename__ = "bill_text"
|
||||||
|
|
||||||
bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
|
|
||||||
version_code: Mapped[str]
|
|
||||||
version_name: Mapped[str | None]
|
|
||||||
text_content: Mapped[str | None]
|
|
||||||
date: Mapped[date | None]
|
|
||||||
|
|
||||||
bill: Mapped[Bill] = relationship("Bill", back_populates="bill_texts")
|
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
UniqueConstraint(
|
UniqueConstraint(
|
||||||
"bill_id", "version_code", name="uq_bill_text_bill_id_version_code"
|
"bill_id", "version_code", name="uq_bill_text_bill_id_version_code"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
|
||||||
|
version_code: Mapped[str]
|
||||||
|
version_name: Mapped[str | None]
|
||||||
|
text_content: Mapped[str | None]
|
||||||
|
summary: Mapped[str | None]
|
||||||
|
summarization_model: Mapped[str | None]
|
||||||
|
summarization_user_prompt_version: Mapped[str | None]
|
||||||
|
summarization_system_prompt_version: Mapped[str | None]
|
||||||
|
date: Mapped[date | None]
|
||||||
|
source_datetime_raw: Mapped[str | None]
|
||||||
|
text_url_xml: Mapped[str | None]
|
||||||
|
text_url_pdf: Mapped[str | None]
|
||||||
|
text_url_html: Mapped[str | None]
|
||||||
|
source_artifact_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("main.source_artifact.id", ondelete="SET NULL")
|
||||||
|
)
|
||||||
|
|
||||||
|
bill: Mapped[Bill] = relationship("Bill", back_populates="bill_texts")
|
||||||
|
|
||||||
|
|
||||||
|
class BillAction(DataScienceDevTableBase):
|
||||||
|
"""Official action row from Bill Status XML."""
|
||||||
|
|
||||||
|
__tablename__ = "bill_action"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("bill_id", "sequence", name="uq_bill_action_bill_id_sequence"),
|
||||||
|
)
|
||||||
|
|
||||||
|
bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
|
||||||
|
sequence: Mapped[int]
|
||||||
|
action_date: Mapped[date]
|
||||||
|
action_time: Mapped[str | None]
|
||||||
|
action_text: Mapped[str]
|
||||||
|
action_type: Mapped[str | None]
|
||||||
|
action_code: Mapped[str | None]
|
||||||
|
source_system_code: Mapped[str | None]
|
||||||
|
source_system_name: Mapped[str | None]
|
||||||
|
source_artifact_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("main.source_artifact.id", ondelete="SET NULL")
|
||||||
|
)
|
||||||
|
|
||||||
|
bill: Mapped[Bill] = relationship("Bill", back_populates="bill_actions")
|
||||||
|
recorded_votes: Mapped[list[BillActionRecordedVote]] = relationship(
|
||||||
|
"BillActionRecordedVote",
|
||||||
|
back_populates="bill_action",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BillActionRecordedVote(DataScienceDevTableBase):
|
||||||
|
"""Recorded vote nested under one official bill action."""
|
||||||
|
|
||||||
|
__tablename__ = "bill_action_recorded_vote"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"bill_action_id",
|
||||||
|
"congress",
|
||||||
|
"chamber",
|
||||||
|
"session_number",
|
||||||
|
"roll_number",
|
||||||
|
name="uq_bill_action_recorded_vote_match_key",
|
||||||
|
),
|
||||||
|
Index(
|
||||||
|
"ix_bill_action_recorded_vote_match_tuple",
|
||||||
|
"congress",
|
||||||
|
"chamber",
|
||||||
|
"session_number",
|
||||||
|
"roll_number",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
bill_action_id: Mapped[int] = mapped_column(
|
||||||
|
ForeignKey("main.bill_action.id", ondelete="CASCADE")
|
||||||
|
)
|
||||||
|
congress: Mapped[int]
|
||||||
|
chamber: Mapped[str]
|
||||||
|
session_number: Mapped[int]
|
||||||
|
roll_number: Mapped[int]
|
||||||
|
vote_datetime: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||||
|
vote_url: Mapped[str | None]
|
||||||
|
|
||||||
|
bill_action: Mapped[BillAction] = relationship(
|
||||||
|
"BillAction",
|
||||||
|
back_populates="recorded_votes",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BillRelation(DataScienceDevTableBase):
|
||||||
|
"""Relationship between one bill/resolution and another."""
|
||||||
|
|
||||||
|
__tablename__ = "bill_relation"
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_bill_relation_bill_id", "bill_id"),
|
||||||
|
Index("ix_bill_relation_related_bill_id", "related_bill_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
|
||||||
|
related_bill_id: Mapped[int] = mapped_column(
|
||||||
|
ForeignKey("main.bill.id", ondelete="CASCADE")
|
||||||
|
)
|
||||||
|
relationship_type: Mapped[str]
|
||||||
|
identified_by: Mapped[str | None]
|
||||||
|
latest_action_date: Mapped[date | None]
|
||||||
|
latest_action_text: Mapped[str | None]
|
||||||
|
|
||||||
|
bill: Mapped[Bill] = relationship(
|
||||||
|
"Bill",
|
||||||
|
foreign_keys=[bill_id],
|
||||||
|
back_populates="outgoing_bill_relations",
|
||||||
|
)
|
||||||
|
related_bill: Mapped[Bill] = relationship(
|
||||||
|
"Bill",
|
||||||
|
foreign_keys=[related_bill_id],
|
||||||
|
back_populates="incoming_bill_relations",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BillTopic(DataScienceDevTableBase):
|
||||||
|
"""One bill stance on one topic used to score roll-call votes."""
|
||||||
|
|
||||||
|
__tablename__ = "bill_topic"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"bill_id",
|
||||||
|
"topic",
|
||||||
|
"support_position",
|
||||||
|
name="uq_bill_topic_bill_id_topic_support_position",
|
||||||
|
),
|
||||||
|
Index("ix_bill_topic_topic", "topic"),
|
||||||
|
)
|
||||||
|
|
||||||
|
bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
|
||||||
|
topic: Mapped[str]
|
||||||
|
support_position: Mapped[BillTopicPosition] = mapped_column(
|
||||||
|
_enum_column(BillTopicPosition, name="bill_topic_position")
|
||||||
|
)
|
||||||
|
|
||||||
|
bill: Mapped[Bill] = relationship("Bill", back_populates="topics")
|
||||||
|
|||||||
@@ -0,0 +1,462 @@
|
|||||||
|
"""Canonical vote context, artifact tracking, and run metadata models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, Enum, ForeignKey, Index, func, text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pipelines.orm.data_science_dev.congress.amendment import Amendment, AmendmentAction
|
||||||
|
from pipelines.orm.data_science_dev.congress.bill import Bill, BillAction, BillText
|
||||||
|
from pipelines.orm.data_science_dev.congress.legislator import LegislatorScore
|
||||||
|
from pipelines.orm.data_science_dev.congress.vote import Vote
|
||||||
|
|
||||||
|
|
||||||
|
def _enum_column(enum_cls: type[StrEnum], *, name: str) -> Enum:
|
||||||
|
"""Build a portable SQLAlchemy enum column for StrEnum values."""
|
||||||
|
|
||||||
|
return Enum(
|
||||||
|
enum_cls,
|
||||||
|
values_callable=lambda enum_type: [member.value for member in enum_type],
|
||||||
|
native_enum=False,
|
||||||
|
name=name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfidenceLevel(StrEnum):
|
||||||
|
"""Low/medium/high confidence buckets."""
|
||||||
|
|
||||||
|
HIGH = "high"
|
||||||
|
MEDIUM = "medium"
|
||||||
|
LOW = "low"
|
||||||
|
|
||||||
|
|
||||||
|
class VoteActionScope(StrEnum):
|
||||||
|
"""Whether a matched action came from bill or amendment context."""
|
||||||
|
|
||||||
|
BILL = "bill"
|
||||||
|
AMENDMENT = "amendment"
|
||||||
|
|
||||||
|
|
||||||
|
class SubjectType(StrEnum):
|
||||||
|
"""The direct legal/procedural subject of the vote."""
|
||||||
|
|
||||||
|
MEASURE = "measure"
|
||||||
|
AMENDMENT = "amendment"
|
||||||
|
NOMINATION = "nomination"
|
||||||
|
TREATY = "treaty"
|
||||||
|
QUORUM = "quorum"
|
||||||
|
CHAMBER_ADMIN = "chamber_admin"
|
||||||
|
UNKNOWN = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class MeasureSubtype(StrEnum):
|
||||||
|
"""Formal congressional measure subtype."""
|
||||||
|
|
||||||
|
BILL = "bill"
|
||||||
|
JOINT_RESOLUTION = "joint_resolution"
|
||||||
|
CONCURRENT_RESOLUTION = "concurrent_resolution"
|
||||||
|
SIMPLE_RESOLUTION = "simple_resolution"
|
||||||
|
|
||||||
|
|
||||||
|
class MeasureFunction(StrEnum):
|
||||||
|
"""Semantic function of a measure beyond its formal subtype."""
|
||||||
|
|
||||||
|
SUBSTANTIVE_MEASURE = "substantive_measure"
|
||||||
|
SPECIAL_RULE = "special_rule"
|
||||||
|
BUDGET_RESOLUTION = "budget_resolution"
|
||||||
|
CHAMBER_INTERNAL = "chamber_internal"
|
||||||
|
COMMEMORATIVE_OR_SENSE_OF = "commemorative_or_sense_of"
|
||||||
|
UNKNOWN = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class VoteRelationship(StrEnum):
|
||||||
|
"""The vote's relationship to the direct subject and its text."""
|
||||||
|
|
||||||
|
DIRECT_TEXT_VOTE = "direct_text_vote"
|
||||||
|
AMENDMENT_TEXT_VOTE = "amendment_text_vote"
|
||||||
|
PROCEDURAL_RELATED_TO_MEASURE = "procedural_related_to_measure"
|
||||||
|
PROCEDURAL_RELATED_TO_AMENDMENT = "procedural_related_to_amendment"
|
||||||
|
NON_LEGISLATIVE = "non_legislative"
|
||||||
|
UNKNOWN = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationMethod(StrEnum):
|
||||||
|
"""How the final classification was derived."""
|
||||||
|
|
||||||
|
RECORDED_VOTE_ACTION_EXACT = "recorded_vote_action_exact"
|
||||||
|
RECORDED_VOTE_ACTION_DUPLICATE_SOURCE_DEDUPED = (
|
||||||
|
"recorded_vote_action_duplicate_source_deduped"
|
||||||
|
)
|
||||||
|
VOTE_XML_ONLY = "vote_xml_only"
|
||||||
|
QUESTION_TEXT_ONLY = "question_text_only"
|
||||||
|
MANUAL_REVIEW = "manual_review"
|
||||||
|
|
||||||
|
|
||||||
|
class VoteMeasureRole(StrEnum):
|
||||||
|
"""How one measure relates to one classified vote."""
|
||||||
|
|
||||||
|
VOTED_ON = "voted_on"
|
||||||
|
RULE_FOR = "rule_for"
|
||||||
|
UNDERLYING_BILL = "underlying_bill"
|
||||||
|
PROCEDURAL_TARGET = "procedural_target"
|
||||||
|
AMENDS = "amends"
|
||||||
|
AMENDED_BY = "amended_by"
|
||||||
|
CONFERENCE_REPORT_FOR = "conference_report_for"
|
||||||
|
RELATED_ONLY = "related_only"
|
||||||
|
|
||||||
|
|
||||||
|
class TextTargetType(StrEnum):
|
||||||
|
"""Which kind of legislative text was the object of a vote."""
|
||||||
|
|
||||||
|
BILL_TEXT = "bill_text"
|
||||||
|
RESOLUTION_TEXT = "resolution_text"
|
||||||
|
AMENDMENT_TEXT = "amendment_text"
|
||||||
|
NONE = "none"
|
||||||
|
UNKNOWN = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class TextTargetBasis(StrEnum):
|
||||||
|
"""How the text target should be interpreted."""
|
||||||
|
|
||||||
|
EXACT_ACTION_TEXT_VERSION = "exact_action_text_version"
|
||||||
|
RESULTING_ENGROSSED_VERSION = "resulting_engrossed_version"
|
||||||
|
RECEIVED_PRIOR_CHAMBER_VERSION = "received_prior_chamber_version"
|
||||||
|
AMENDMENT_TEXT = "amendment_text"
|
||||||
|
RULE_RESOLUTION_TEXT = "rule_resolution_text"
|
||||||
|
NO_TEXT_TARGET = "no_text_target"
|
||||||
|
UNKNOWN = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class TextResolutionMethod(StrEnum):
|
||||||
|
"""How the official text target was resolved."""
|
||||||
|
|
||||||
|
TEXT_EXACT_ACTION_DATE_AND_CODE = "text_exact_action_date_and_code"
|
||||||
|
TEXT_EXACT_ACTION_DATE_WRONG_CODE = "text_exact_action_date_wrong_code"
|
||||||
|
TEXT_PRIOR_VERSION_CODE_MATCH = "text_prior_version_code_match"
|
||||||
|
TEXT_RECEIVED_PRIOR_CHAMBER_VERSION = "text_received_prior_chamber_version"
|
||||||
|
TEXT_RESULTING_ENROLLED_ONLY = "text_resulting_enrolled_only"
|
||||||
|
AMENDMENT_TEXT_UNMODELED_PHASE1 = "amendment_text_unmodeled_phase1"
|
||||||
|
NO_TEXT_TARGET = "no_text_target"
|
||||||
|
UNKNOWN = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class VoteEffect(StrEnum):
|
||||||
|
"""Meaning of one member position relative to the target text/procedure."""
|
||||||
|
|
||||||
|
SUPPORTS_TEXT = "supports_text"
|
||||||
|
OPPOSES_TEXT = "opposes_text"
|
||||||
|
ADVANCES_PROCEDURE = "advances_procedure"
|
||||||
|
BLOCKS_PROCEDURE = "blocks_procedure"
|
||||||
|
UNKNOWN = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class IngestRun(DataScienceDevTableBase):
|
||||||
|
"""One full ingestion or context rebuild run."""
|
||||||
|
|
||||||
|
__tablename__ = "ingest_run"
|
||||||
|
|
||||||
|
started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||||
|
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||||
|
git_sha: Mapped[str | None]
|
||||||
|
classifier_version: Mapped[str | None]
|
||||||
|
source_snapshot_label: Mapped[str | None]
|
||||||
|
status: Mapped[str]
|
||||||
|
|
||||||
|
source_artifacts: Mapped[list[SourceArtifact]] = relationship(
|
||||||
|
"SourceArtifact",
|
||||||
|
back_populates="ingest_run",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
score_runs: Mapped[list[ScoreRun]] = relationship(
|
||||||
|
"ScoreRun",
|
||||||
|
back_populates="ingest_run",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SourceArtifact(DataScienceDevTableBase):
|
||||||
|
"""Local artifact manifest entry for reproducibility."""
|
||||||
|
|
||||||
|
__tablename__ = "source_artifact"
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_source_artifact_source_kind", "source_kind"),
|
||||||
|
Index("ix_source_artifact_congress", "congress"),
|
||||||
|
Index(
|
||||||
|
"uq_source_artifact_ingest_identity",
|
||||||
|
"ingest_run_id",
|
||||||
|
"local_path",
|
||||||
|
"sha256",
|
||||||
|
unique=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
source_kind: Mapped[str]
|
||||||
|
congress: Mapped[int]
|
||||||
|
chamber: Mapped[str | None]
|
||||||
|
local_path: Mapped[str]
|
||||||
|
source_url: Mapped[str | None]
|
||||||
|
sha256: Mapped[str]
|
||||||
|
byte_size: Mapped[int]
|
||||||
|
modified_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||||
|
ingested_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||||
|
ingest_run_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("main.ingest_run.id", ondelete="SET NULL")
|
||||||
|
)
|
||||||
|
|
||||||
|
ingest_run: Mapped[IngestRun | None] = relationship(
|
||||||
|
"IngestRun",
|
||||||
|
back_populates="source_artifacts",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreRun(DataScienceDevTableBase):
|
||||||
|
"""One full score recomputation tied to one ingest snapshot."""
|
||||||
|
|
||||||
|
__tablename__ = "score_run"
|
||||||
|
|
||||||
|
ingest_run_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("main.ingest_run.id", ondelete="SET NULL")
|
||||||
|
)
|
||||||
|
classifier_version: Mapped[str | None]
|
||||||
|
scoring_version: Mapped[str | None]
|
||||||
|
included_vote_count: Mapped[int]
|
||||||
|
excluded_vote_count: Mapped[int]
|
||||||
|
started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||||
|
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||||
|
|
||||||
|
ingest_run: Mapped[IngestRun | None] = relationship(
|
||||||
|
"IngestRun",
|
||||||
|
back_populates="score_runs",
|
||||||
|
)
|
||||||
|
scores: Mapped[list[LegislatorScore]] = relationship(
|
||||||
|
"LegislatorScore",
|
||||||
|
back_populates="score_run",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VoteActionMatch(DataScienceDevTableBase):
|
||||||
|
"""A candidate or selected official action match for one raw vote."""
|
||||||
|
|
||||||
|
__tablename__ = "vote_action_match"
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_vote_action_match_vote_id", "vote_id"),
|
||||||
|
Index(
|
||||||
|
"uq_vote_action_match_selected_vote_id",
|
||||||
|
"vote_id",
|
||||||
|
unique=True,
|
||||||
|
postgresql_where=text("is_selected"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
vote_id: Mapped[int] = mapped_column(ForeignKey("main.vote.id", ondelete="CASCADE"))
|
||||||
|
action_scope: Mapped[VoteActionScope] = mapped_column(
|
||||||
|
_enum_column(VoteActionScope, name="vote_action_scope")
|
||||||
|
)
|
||||||
|
bill_action_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("main.bill_action.id", ondelete="CASCADE")
|
||||||
|
)
|
||||||
|
amendment_action_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("main.amendment_action.id", ondelete="CASCADE")
|
||||||
|
)
|
||||||
|
is_selected: Mapped[bool]
|
||||||
|
match_method: Mapped[str]
|
||||||
|
match_reason: Mapped[str | None]
|
||||||
|
match_confidence: Mapped[ConfidenceLevel] = mapped_column(
|
||||||
|
_enum_column(ConfidenceLevel, name="vote_action_match_confidence")
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
vote: Mapped[Vote] = relationship("Vote", back_populates="action_matches")
|
||||||
|
bill_action: Mapped[BillAction | None] = relationship("BillAction")
|
||||||
|
amendment_action: Mapped[AmendmentAction | None] = relationship("AmendmentAction")
|
||||||
|
|
||||||
|
|
||||||
|
class VoteClassification(DataScienceDevTableBase):
|
||||||
|
"""Normalized classification for what a vote was legally/procedurally on."""
|
||||||
|
|
||||||
|
__tablename__ = "vote_classification"
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_vote_classification_subject_type", "subject_type"),
|
||||||
|
Index(
|
||||||
|
"ix_vote_classification_eligible_vote_id",
|
||||||
|
"vote_id",
|
||||||
|
postgresql_where=text(
|
||||||
|
"subject_type = 'measure' "
|
||||||
|
"AND vote_relationship = 'direct_text_vote' "
|
||||||
|
"AND is_direct_vote_on_legislative_text "
|
||||||
|
"AND is_substantive_policy_vote "
|
||||||
|
"AND NOT is_special_rule"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
vote_id: Mapped[int] = mapped_column(
|
||||||
|
ForeignKey("main.vote.id", ondelete="CASCADE"),
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
subject_type: Mapped[SubjectType] = mapped_column(
|
||||||
|
_enum_column(SubjectType, name="vote_subject_type")
|
||||||
|
)
|
||||||
|
measure_type: Mapped[str | None]
|
||||||
|
measure_subtype: Mapped[MeasureSubtype | None] = mapped_column(
|
||||||
|
_enum_column(MeasureSubtype, name="vote_measure_subtype")
|
||||||
|
)
|
||||||
|
measure_function: Mapped[MeasureFunction | None] = mapped_column(
|
||||||
|
_enum_column(MeasureFunction, name="vote_measure_function")
|
||||||
|
)
|
||||||
|
vote_relationship: Mapped[VoteRelationship] = mapped_column(
|
||||||
|
_enum_column(VoteRelationship, name="vote_relationship")
|
||||||
|
)
|
||||||
|
is_legislation_related: Mapped[bool]
|
||||||
|
is_direct_vote_on_legislative_text: Mapped[bool]
|
||||||
|
is_substantive_policy_vote: Mapped[bool]
|
||||||
|
is_lawmaking_vehicle: Mapped[bool]
|
||||||
|
is_special_rule: Mapped[bool]
|
||||||
|
classification_method: Mapped[ClassificationMethod] = mapped_column(
|
||||||
|
_enum_column(ClassificationMethod, name="vote_classification_method")
|
||||||
|
)
|
||||||
|
classification_confidence_reason: Mapped[str | None]
|
||||||
|
confidence: Mapped[ConfidenceLevel] = mapped_column(
|
||||||
|
_enum_column(ConfidenceLevel, name="vote_classification_confidence")
|
||||||
|
)
|
||||||
|
classified_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||||
|
classification_version: Mapped[str]
|
||||||
|
|
||||||
|
vote: Mapped[Vote] = relationship("Vote", back_populates="classification")
|
||||||
|
|
||||||
|
|
||||||
|
class VoteMeasureLink(DataScienceDevTableBase):
|
||||||
|
"""Relationship between a classified vote and one bill/resolution measure."""
|
||||||
|
|
||||||
|
__tablename__ = "vote_measure_link"
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_vote_measure_link_vote_id", "vote_id"),
|
||||||
|
Index("ix_vote_measure_link_vote_id_role", "vote_id", "role"),
|
||||||
|
Index("ix_vote_measure_link_measure_id_role", "measure_id", "role"),
|
||||||
|
)
|
||||||
|
|
||||||
|
vote_id: Mapped[int] = mapped_column(ForeignKey("main.vote.id", ondelete="CASCADE"))
|
||||||
|
measure_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
|
||||||
|
role: Mapped[VoteMeasureRole] = mapped_column(
|
||||||
|
_enum_column(VoteMeasureRole, name="vote_measure_role")
|
||||||
|
)
|
||||||
|
source: Mapped[str]
|
||||||
|
confidence: Mapped[ConfidenceLevel] = mapped_column(
|
||||||
|
_enum_column(ConfidenceLevel, name="vote_measure_link_confidence")
|
||||||
|
)
|
||||||
|
notes: Mapped[str | None]
|
||||||
|
|
||||||
|
vote: Mapped[Vote] = relationship("Vote", back_populates="vote_measure_links")
|
||||||
|
measure: Mapped[Bill] = relationship("Bill", back_populates="vote_measure_links")
|
||||||
|
|
||||||
|
|
||||||
|
class VoteTextTarget(DataScienceDevTableBase):
|
||||||
|
"""Official text target, if any, resolved for one classified vote."""
|
||||||
|
|
||||||
|
__tablename__ = "vote_text_target"
|
||||||
|
__table_args__ = (
|
||||||
|
Index(
|
||||||
|
"ix_vote_text_target_voted_text_version_id",
|
||||||
|
"voted_text_version_id",
|
||||||
|
postgresql_where=text("voted_text_version_id IS NOT NULL"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
vote_id: Mapped[int] = mapped_column(
|
||||||
|
ForeignKey("main.vote.id", ondelete="CASCADE"),
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
text_target_type: Mapped[TextTargetType] = mapped_column(
|
||||||
|
_enum_column(TextTargetType, name="vote_text_target_type")
|
||||||
|
)
|
||||||
|
voted_text_version_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("main.bill_text.id", ondelete="SET NULL")
|
||||||
|
)
|
||||||
|
resulting_text_version_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("main.bill_text.id", ondelete="SET NULL")
|
||||||
|
)
|
||||||
|
related_amendment_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("main.amendment.id", ondelete="SET NULL")
|
||||||
|
)
|
||||||
|
text_target_basis: Mapped[TextTargetBasis] = mapped_column(
|
||||||
|
_enum_column(TextTargetBasis, name="vote_text_target_basis")
|
||||||
|
)
|
||||||
|
text_resolution_method: Mapped[TextResolutionMethod] = mapped_column(
|
||||||
|
_enum_column(TextResolutionMethod, name="vote_text_resolution_method")
|
||||||
|
)
|
||||||
|
text_resolution_confidence_reason: Mapped[str | None]
|
||||||
|
confidence: Mapped[ConfidenceLevel] = mapped_column(
|
||||||
|
_enum_column(ConfidenceLevel, name="vote_text_target_confidence")
|
||||||
|
)
|
||||||
|
notes: Mapped[str | None]
|
||||||
|
|
||||||
|
vote: Mapped[Vote] = relationship("Vote", back_populates="text_target")
|
||||||
|
voted_text_version: Mapped[BillText | None] = relationship(
|
||||||
|
"BillText",
|
||||||
|
foreign_keys=[voted_text_version_id],
|
||||||
|
)
|
||||||
|
resulting_text_version: Mapped[BillText | None] = relationship(
|
||||||
|
"BillText",
|
||||||
|
foreign_keys=[resulting_text_version_id],
|
||||||
|
)
|
||||||
|
related_amendment: Mapped[Amendment | None] = relationship("Amendment")
|
||||||
|
|
||||||
|
|
||||||
|
class VotePositionMeaning(DataScienceDevTableBase):
|
||||||
|
"""Meaning of Yea/Nay/Present positions for one classified vote."""
|
||||||
|
|
||||||
|
__tablename__ = "vote_position_meaning"
|
||||||
|
|
||||||
|
vote_id: Mapped[int] = mapped_column(
|
||||||
|
ForeignKey("main.vote.id", ondelete="CASCADE"),
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
yea_effect: Mapped[VoteEffect] = mapped_column(
|
||||||
|
_enum_column(VoteEffect, name="vote_yea_effect")
|
||||||
|
)
|
||||||
|
nay_effect: Mapped[VoteEffect] = mapped_column(
|
||||||
|
_enum_column(VoteEffect, name="vote_nay_effect")
|
||||||
|
)
|
||||||
|
present_effect: Mapped[VoteEffect] = mapped_column(
|
||||||
|
_enum_column(VoteEffect, name="vote_present_effect")
|
||||||
|
)
|
||||||
|
polarity_confidence: Mapped[ConfidenceLevel] = mapped_column(
|
||||||
|
_enum_column(ConfidenceLevel, name="vote_polarity_confidence")
|
||||||
|
)
|
||||||
|
polarity_method: Mapped[str]
|
||||||
|
notes: Mapped[str | None]
|
||||||
|
|
||||||
|
vote: Mapped[Vote] = relationship("Vote", back_populates="position_meaning")
|
||||||
|
|
||||||
|
|
||||||
|
class VoteContextAudit(DataScienceDevTableBase):
|
||||||
|
"""Audit/event row for ambiguous or noteworthy vote-context decisions."""
|
||||||
|
|
||||||
|
__tablename__ = "vote_context_audit"
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_vote_context_audit_vote_id", "vote_id"),
|
||||||
|
Index("ix_vote_context_audit_severity_vote_id", "severity", "vote_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
vote_id: Mapped[int] = mapped_column(ForeignKey("main.vote.id", ondelete="CASCADE"))
|
||||||
|
step: Mapped[str]
|
||||||
|
message: Mapped[str]
|
||||||
|
severity: Mapped[str]
|
||||||
|
source_path: Mapped[str | None]
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
vote: Mapped[Vote] = relationship("Vote", back_populates="context_audit_rows")
|
||||||
@@ -5,12 +5,13 @@ from __future__ import annotations
|
|||||||
from datetime import date
|
from datetime import date
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy import ForeignKey, Text
|
from sqlalchemy import ForeignKey, Index, Text, UniqueConstraint
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
|
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from pipelines.orm.data_science_dev.congress.context import ScoreRun
|
||||||
from pipelines.orm.data_science_dev.congress.vote import VoteRecord
|
from pipelines.orm.data_science_dev.congress.vote import VoteRecord
|
||||||
|
|
||||||
|
|
||||||
@@ -18,6 +19,7 @@ class Legislator(DataScienceDevTableBase):
|
|||||||
"""Members of Congress with identification and current term info."""
|
"""Members of Congress with identification and current term info."""
|
||||||
|
|
||||||
__tablename__ = "legislator"
|
__tablename__ = "legislator"
|
||||||
|
__table_args__ = (Index("ix_legislator_current_chamber", "current_chamber"),)
|
||||||
|
|
||||||
bioguide_id: Mapped[str] = mapped_column(Text, unique=True, index=True)
|
bioguide_id: Mapped[str] = mapped_column(Text, unique=True, index=True)
|
||||||
|
|
||||||
@@ -50,6 +52,11 @@ class Legislator(DataScienceDevTableBase):
|
|||||||
back_populates="legislator",
|
back_populates="legislator",
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
scores: Mapped[list[LegislatorScore]] = relationship(
|
||||||
|
"LegislatorScore",
|
||||||
|
back_populates="legislator",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LegislatorSocialMedia(DataScienceDevTableBase):
|
class LegislatorSocialMedia(DataScienceDevTableBase):
|
||||||
@@ -66,3 +73,59 @@ class LegislatorSocialMedia(DataScienceDevTableBase):
|
|||||||
legislator: Mapped[Legislator] = relationship(
|
legislator: Mapped[Legislator] = relationship(
|
||||||
back_populates="social_media_accounts"
|
back_populates="social_media_accounts"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LegislatorScore(DataScienceDevTableBase):
|
||||||
|
"""Computed topic score for a legislator in one calendar year."""
|
||||||
|
|
||||||
|
__tablename__ = "legislator_score"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"legislator_id",
|
||||||
|
"year",
|
||||||
|
"topic",
|
||||||
|
name="uq_legislator_score_legislator_id_year_topic",
|
||||||
|
),
|
||||||
|
Index("ix_legislator_score_year_topic", "year", "topic"),
|
||||||
|
)
|
||||||
|
|
||||||
|
legislator_id: Mapped[int] = mapped_column(
|
||||||
|
ForeignKey("main.legislator.id", ondelete="CASCADE"),
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
score_run_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("main.score_run.id", ondelete="CASCADE"),
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
year: Mapped[int]
|
||||||
|
topic: Mapped[str]
|
||||||
|
score: Mapped[float]
|
||||||
|
|
||||||
|
legislator: Mapped[Legislator] = relationship(back_populates="scores")
|
||||||
|
score_run: Mapped[ScoreRun | None] = relationship(
|
||||||
|
"ScoreRun",
|
||||||
|
back_populates="scores",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LegislatorScoreFake(DataScienceDevTableBase):
|
||||||
|
"""Computed topic score for a legislator in one calendar year."""
|
||||||
|
|
||||||
|
__tablename__ = "legislator_score_fake"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"legislator_id",
|
||||||
|
"year",
|
||||||
|
"topic",
|
||||||
|
name="uq_legislator_score_fake_legislator_id_year_topic",
|
||||||
|
),
|
||||||
|
Index("ix_legislator_score_fake_year_topic", "year", "topic"),
|
||||||
|
)
|
||||||
|
|
||||||
|
legislator_id: Mapped[int] = mapped_column(
|
||||||
|
ForeignKey("main.legislator.id", ondelete="CASCADE"),
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
year: Mapped[int]
|
||||||
|
topic: Mapped[str]
|
||||||
|
score: Mapped[float]
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
"""Vote model - roll call votes in Congress."""
|
"""Vote models for raw roll-call data and member positions."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import date
|
from datetime import date, datetime
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy import ForeignKey, Index, UniqueConstraint
|
from sqlalchemy import DateTime, ForeignKey, Index, UniqueConstraint
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from pipelines.orm.data_science_dev.base import (
|
from pipelines.orm.data_science_dev.base import (
|
||||||
@@ -14,9 +15,15 @@ from pipelines.orm.data_science_dev.base import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pipelines.orm.data_science_dev.congress.bill import Bill
|
from pipelines.orm.data_science_dev.congress.context import (
|
||||||
|
VoteActionMatch,
|
||||||
|
VoteClassification,
|
||||||
|
VoteContextAudit,
|
||||||
|
VoteMeasureLink,
|
||||||
|
VotePositionMeaning,
|
||||||
|
VoteTextTarget,
|
||||||
|
)
|
||||||
from pipelines.orm.data_science_dev.congress.legislator import Legislator
|
from pipelines.orm.data_science_dev.congress.legislator import Legislator
|
||||||
from pipelines.orm.data_science_dev.congress.vote import Vote
|
|
||||||
|
|
||||||
|
|
||||||
class VoteRecord(DataScienceDevBase):
|
class VoteRecord(DataScienceDevBase):
|
||||||
@@ -41,14 +48,26 @@ class VoteRecord(DataScienceDevBase):
|
|||||||
|
|
||||||
|
|
||||||
class Vote(DataScienceDevTableBase):
|
class Vote(DataScienceDevTableBase):
|
||||||
"""Roll call votes with counts and optional bill linkage."""
|
"""Raw roll call vote facts from House or Senate vote sources."""
|
||||||
|
|
||||||
__tablename__ = "vote"
|
__tablename__ = "vote"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"congress",
|
||||||
|
"chamber",
|
||||||
|
"session_number",
|
||||||
|
"roll_number",
|
||||||
|
name="uq_vote_congress_chamber_session_number_roll_number",
|
||||||
|
),
|
||||||
|
Index("ix_vote_date", "vote_date"),
|
||||||
|
Index("ix_vote_congress_chamber", "congress", "chamber"),
|
||||||
|
)
|
||||||
|
|
||||||
congress: Mapped[int]
|
congress: Mapped[int]
|
||||||
chamber: Mapped[str]
|
chamber: Mapped[str]
|
||||||
session: Mapped[int]
|
session_year: Mapped[int]
|
||||||
number: Mapped[int]
|
session_number: Mapped[int]
|
||||||
|
roll_number: Mapped[int]
|
||||||
|
|
||||||
vote_type: Mapped[str | None]
|
vote_type: Mapped[str | None]
|
||||||
question: Mapped[str | None]
|
question: Mapped[str | None]
|
||||||
@@ -56,29 +75,57 @@ class Vote(DataScienceDevTableBase):
|
|||||||
result_text: Mapped[str | None]
|
result_text: Mapped[str | None]
|
||||||
|
|
||||||
vote_date: Mapped[date]
|
vote_date: Mapped[date]
|
||||||
|
vote_datetime: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||||
|
raw_vote_source_url: Mapped[str | None]
|
||||||
|
|
||||||
yea_count: Mapped[int | None]
|
yea_count: Mapped[int | None]
|
||||||
nay_count: Mapped[int | None]
|
nay_count: Mapped[int | None]
|
||||||
not_voting_count: Mapped[int | None]
|
not_voting_count: Mapped[int | None]
|
||||||
present_count: Mapped[int | None]
|
present_count: Mapped[int | None]
|
||||||
|
|
||||||
bill_id: Mapped[int | None] = mapped_column(ForeignKey("main.bill.id"))
|
raw_bill_ref: Mapped[dict | None] = mapped_column(JSONB)
|
||||||
|
raw_amendment_ref: Mapped[dict | None] = mapped_column(JSONB)
|
||||||
|
raw_nomination_ref: Mapped[dict | None] = mapped_column(JSONB)
|
||||||
|
raw_treaty_ref: Mapped[dict | None] = mapped_column(JSONB)
|
||||||
|
raw_vote_source_artifact_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("main.source_artifact.id", ondelete="SET NULL")
|
||||||
|
)
|
||||||
|
|
||||||
bill: Mapped[Bill | None] = relationship("Bill", back_populates="votes")
|
|
||||||
vote_records: Mapped[list[VoteRecord]] = relationship(
|
vote_records: Mapped[list[VoteRecord]] = relationship(
|
||||||
"VoteRecord",
|
"VoteRecord",
|
||||||
back_populates="vote",
|
back_populates="vote",
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
action_matches: Mapped[list[VoteActionMatch]] = relationship(
|
||||||
__table_args__ = (
|
"VoteActionMatch",
|
||||||
UniqueConstraint(
|
back_populates="vote",
|
||||||
"congress",
|
cascade="all, delete-orphan",
|
||||||
"chamber",
|
)
|
||||||
"session",
|
classification: Mapped[VoteClassification | None] = relationship(
|
||||||
"number",
|
"VoteClassification",
|
||||||
name="uq_vote_congress_chamber_session_number",
|
back_populates="vote",
|
||||||
),
|
cascade="all, delete-orphan",
|
||||||
Index("ix_vote_date", "vote_date"),
|
uselist=False,
|
||||||
Index("ix_vote_congress_chamber", "congress", "chamber"),
|
)
|
||||||
|
vote_measure_links: Mapped[list[VoteMeasureLink]] = relationship(
|
||||||
|
"VoteMeasureLink",
|
||||||
|
back_populates="vote",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
text_target: Mapped[VoteTextTarget | None] = relationship(
|
||||||
|
"VoteTextTarget",
|
||||||
|
back_populates="vote",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
uselist=False,
|
||||||
|
)
|
||||||
|
position_meaning: Mapped[VotePositionMeaning | None] = relationship(
|
||||||
|
"VotePositionMeaning",
|
||||||
|
back_populates="vote",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
uselist=False,
|
||||||
|
)
|
||||||
|
context_audit_rows: Mapped[list[VoteContextAudit]] = relationship(
|
||||||
|
"VoteContextAudit",
|
||||||
|
back_populates="vote",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,15 +2,81 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pipelines.orm.data_science_dev.congress import Bill, BillText, Legislator, Vote, VoteRecord
|
from pipelines.orm.data_science_dev.congress import (
|
||||||
|
Amendment,
|
||||||
|
AmendmentAction,
|
||||||
|
AmendmentActionRecordedVote,
|
||||||
|
Bill,
|
||||||
|
BillAction,
|
||||||
|
BillActionRecordedVote,
|
||||||
|
BillRelation,
|
||||||
|
BillText,
|
||||||
|
BillTopic,
|
||||||
|
BillTopicPosition,
|
||||||
|
ClassificationMethod,
|
||||||
|
ConfidenceLevel,
|
||||||
|
IngestRun,
|
||||||
|
Legislator,
|
||||||
|
LegislatorScore,
|
||||||
|
MeasureFunction,
|
||||||
|
MeasureSubtype,
|
||||||
|
ScoreRun,
|
||||||
|
SourceArtifact,
|
||||||
|
SubjectType,
|
||||||
|
TextResolutionMethod,
|
||||||
|
TextTargetBasis,
|
||||||
|
TextTargetType,
|
||||||
|
Vote,
|
||||||
|
VoteActionMatch,
|
||||||
|
VoteActionScope,
|
||||||
|
VoteClassification,
|
||||||
|
VoteContextAudit,
|
||||||
|
VoteEffect,
|
||||||
|
VoteMeasureLink,
|
||||||
|
VoteMeasureRole,
|
||||||
|
VotePositionMeaning,
|
||||||
|
VoteRelationship,
|
||||||
|
VoteRecord,
|
||||||
|
VoteTextTarget,
|
||||||
|
)
|
||||||
from pipelines.orm.data_science_dev.posts import partitions # noqa: F401 — registers partition classes in metadata
|
from pipelines.orm.data_science_dev.posts import partitions # noqa: F401 — registers partition classes in metadata
|
||||||
from pipelines.orm.data_science_dev.posts.tables import Posts
|
from pipelines.orm.data_science_dev.posts.tables import Posts
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"Amendment",
|
||||||
|
"AmendmentAction",
|
||||||
|
"AmendmentActionRecordedVote",
|
||||||
"Bill",
|
"Bill",
|
||||||
|
"BillAction",
|
||||||
|
"BillActionRecordedVote",
|
||||||
|
"BillRelation",
|
||||||
"BillText",
|
"BillText",
|
||||||
|
"BillTopic",
|
||||||
|
"BillTopicPosition",
|
||||||
|
"ClassificationMethod",
|
||||||
|
"ConfidenceLevel",
|
||||||
|
"IngestRun",
|
||||||
"Legislator",
|
"Legislator",
|
||||||
|
"LegislatorScore",
|
||||||
|
"MeasureFunction",
|
||||||
|
"MeasureSubtype",
|
||||||
"Posts",
|
"Posts",
|
||||||
|
"ScoreRun",
|
||||||
|
"SourceArtifact",
|
||||||
|
"SubjectType",
|
||||||
|
"TextResolutionMethod",
|
||||||
|
"TextTargetBasis",
|
||||||
|
"TextTargetType",
|
||||||
"Vote",
|
"Vote",
|
||||||
|
"VoteActionMatch",
|
||||||
|
"VoteActionScope",
|
||||||
|
"VoteClassification",
|
||||||
|
"VoteContextAudit",
|
||||||
|
"VoteEffect",
|
||||||
|
"VoteMeasureLink",
|
||||||
|
"VoteMeasureRole",
|
||||||
|
"VotePositionMeaning",
|
||||||
|
"VoteRelationship",
|
||||||
"VoteRecord",
|
"VoteRecord",
|
||||||
|
"VoteTextTarget",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -3,9 +3,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pipelines.orm.data_science_dev.posts.failed_ingestion import FailedIngestion
|
from pipelines.orm.data_science_dev.posts.failed_ingestion import FailedIngestion
|
||||||
from pipelines.orm.data_science_dev.posts.tables import Posts
|
from pipelines.orm.data_science_dev.posts.tables import Posts, PostTopic
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"FailedIngestion",
|
"FailedIngestion",
|
||||||
"Posts",
|
"Posts",
|
||||||
|
"PostTopic",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1,195 @@
|
|||||||
|
"""Shared language filter constants for post sampling queries."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
ENGLISH_LANGS = (
|
||||||
|
'["", "", ""]',
|
||||||
|
'[""]',
|
||||||
|
"[]",
|
||||||
|
'["", "eng"]',
|
||||||
|
'["eng", "", ""]',
|
||||||
|
'["eng", ""]',
|
||||||
|
'["eng"]',
|
||||||
|
'["eng", "aar"]',
|
||||||
|
'["eng", "abk", "afr"]',
|
||||||
|
'["eng", "afr"]',
|
||||||
|
'["eng", "afr", "abk"]',
|
||||||
|
'["eng", "afr", "anp"]',
|
||||||
|
'["eng", "afr", "ber"]',
|
||||||
|
'["eng", "afr", "dan"]',
|
||||||
|
'["eng", "afr", "deu"]',
|
||||||
|
'["eng", "afr", "est"]',
|
||||||
|
'["eng", "afr", "fra"]',
|
||||||
|
'["eng", "afr", "ind"]',
|
||||||
|
'["eng", "afr", "lat"]',
|
||||||
|
'["eng", "afr", "nld"]',
|
||||||
|
'["eng", "afr", "nor"]',
|
||||||
|
'["eng", "afr", "pol"]',
|
||||||
|
'["eng", "afr", "por"]',
|
||||||
|
'["eng", "afr", "ron"]',
|
||||||
|
'["eng", "afr", "slk"]',
|
||||||
|
'["eng", "afr", "spa"]',
|
||||||
|
'["eng", "afr", "tgl"]',
|
||||||
|
'["eng", "afr", "tuk"]',
|
||||||
|
'["eng", "afr", "tur"]',
|
||||||
|
'["eng", "afr", "ukr"]',
|
||||||
|
'["eng", "afr", "vol"]',
|
||||||
|
'["eng", "agq"]',
|
||||||
|
'["eng", "ain"]',
|
||||||
|
'["eng", "ain", "amh"]',
|
||||||
|
'["eng", "ain", "jpn"]',
|
||||||
|
'["eng", "aka"]',
|
||||||
|
'["eng", "amh"]',
|
||||||
|
'["eng", "amh", "afr"]',
|
||||||
|
'["eng", "amh", "ara"]',
|
||||||
|
'["eng", "amh", "fra"]',
|
||||||
|
'["eng", "anp"]',
|
||||||
|
'["eng", "anp", "hye"]',
|
||||||
|
'["eng", "anp", "sqi"]',
|
||||||
|
'["eng", "", "ara"]',
|
||||||
|
'["eng", "ara", ""]',
|
||||||
|
'["eng", "ara"]',
|
||||||
|
'["eng", "ara", "afr"]',
|
||||||
|
'["eng", "ara", "anp"]',
|
||||||
|
'["eng", "ara", "ars"]',
|
||||||
|
'["eng", "ara", "bul"]',
|
||||||
|
'["eng", "ara", "cat"]',
|
||||||
|
'["eng", "ara", "deu"]',
|
||||||
|
'["eng", "ara", "ell"]',
|
||||||
|
'["eng", "ara", "fas"]',
|
||||||
|
'["eng", "ara", "fra"]',
|
||||||
|
'["eng", "ara", "heb"]',
|
||||||
|
'["eng", "ara", "hin"]',
|
||||||
|
'["eng", "ara", "ind"]',
|
||||||
|
'["eng", "ara", "ita"]',
|
||||||
|
'["eng", "ara", "jpn"]',
|
||||||
|
'["eng", "ara", "kas"]',
|
||||||
|
'["eng", "ara", "kor"]',
|
||||||
|
'["eng", "ara", "nob"]',
|
||||||
|
'["eng", "ara", "nor"]',
|
||||||
|
'["eng", "ara", "rus"]',
|
||||||
|
'["eng", "ara", "spa"]',
|
||||||
|
'["eng", "ara", "swe"]',
|
||||||
|
'["eng", "ara", "tam"]',
|
||||||
|
'["eng", "ara", "tur"]',
|
||||||
|
'["eng", "ara", "urd"]',
|
||||||
|
'["eng", "ara", "zho"]',
|
||||||
|
'["eng", "arg"]',
|
||||||
|
'["eng", "arg", "amh"]',
|
||||||
|
'["eng", "arg", "aze"]',
|
||||||
|
'["eng", "ars"]',
|
||||||
|
'["eng", "ars", "ara"]',
|
||||||
|
'["eng", "asm"]',
|
||||||
|
'["eng", "ava", "sqi"]',
|
||||||
|
'["eng", "ave"]',
|
||||||
|
'["eng", "aze"]',
|
||||||
|
'["eng", "aze", "deu"]',
|
||||||
|
'["eng", "aze", "hye"]',
|
||||||
|
'["eng", "aze", "ita"]',
|
||||||
|
'["eng", "aze", "rus"]',
|
||||||
|
'["eng", "bam", ""]',
|
||||||
|
'["eng", "bel"]',
|
||||||
|
'["eng", "bel", "rus"]',
|
||||||
|
'["eng", "ben"]',
|
||||||
|
'["eng", "ben", "deu"]',
|
||||||
|
'["eng", "ben", "fra"]',
|
||||||
|
'["eng", "ben", "hin"]',
|
||||||
|
'["eng", "ben", "mya"]',
|
||||||
|
'["eng", "ber"]',
|
||||||
|
'["eng", "ber", "afr"]',
|
||||||
|
'["eng", "ber", "deu"]',
|
||||||
|
'["eng", "ber", "est"]',
|
||||||
|
'["eng", "ber", "hun"]',
|
||||||
|
'["eng", "ber", "isl"]',
|
||||||
|
'["eng", "ber", "jpn"]',
|
||||||
|
'["eng", "ber", "lat"]',
|
||||||
|
'["eng", "ber", "nor"]',
|
||||||
|
'["eng", "ber", "pol"]',
|
||||||
|
'["eng", "ber", "por"]',
|
||||||
|
'["eng", "ber", "ron"]',
|
||||||
|
'["eng", "ber", "run"]',
|
||||||
|
'["eng", "ber", "slk"]',
|
||||||
|
'["eng", "ber", "spa"]',
|
||||||
|
'["eng", "ber", "tgl"]',
|
||||||
|
'["eng", "ber", "tlh"]',
|
||||||
|
'["eng", "ber", "tuk"]',
|
||||||
|
'["eng", "bod"]',
|
||||||
|
'["eng", "bod", "nep"]',
|
||||||
|
'["eng", "bos", "hrv"]',
|
||||||
|
'["eng", "bos", "srp"]',
|
||||||
|
'["eng", "bul"]',
|
||||||
|
'["eng", "bul", "deu"]',
|
||||||
|
'["eng", "bul", "fra"]',
|
||||||
|
'["eng", "bul", "jpn"]',
|
||||||
|
'["eng", "bul", "mkd"]',
|
||||||
|
'["eng", "bul", "mri"]',
|
||||||
|
'["eng", "bul", "nld"]',
|
||||||
|
'["eng", "bul", "rus"]',
|
||||||
|
'["eng", "bul", "srp"]',
|
||||||
|
'["eng", "cat"]',
|
||||||
|
'["eng", "cat", "fra"]',
|
||||||
|
'["eng", "cat", "ind"]',
|
||||||
|
'["eng", "cat", "isl"]',
|
||||||
|
'["eng", "cat", "jpn"]',
|
||||||
|
'["eng", "cat", "nld"]',
|
||||||
|
'["eng", "cat", "spa"]',
|
||||||
|
'["eng", "ces"]',
|
||||||
|
'["eng", "ces", "deu"]',
|
||||||
|
'["eng", "ces", "ell"]',
|
||||||
|
'["eng", "ces", "haw"]',
|
||||||
|
'["eng", "ces", "ind"]',
|
||||||
|
'["eng", "ces", "ita"]',
|
||||||
|
'["eng", "ces", "jpn"]',
|
||||||
|
'["eng", "ces", "por"]',
|
||||||
|
'["eng", "ces", "rus"]',
|
||||||
|
'["eng", "ces", "slk"]',
|
||||||
|
'["eng", "ces", "spa"]',
|
||||||
|
'["eng", "ces", "tuk"]',
|
||||||
|
'["eng", "cha"]',
|
||||||
|
'["eng", "chr"]',
|
||||||
|
'["eng", "chr", "ara"]',
|
||||||
|
'["eng", "chr", "deu"]',
|
||||||
|
'["eng", "chr", "ell"]',
|
||||||
|
'["eng", "chr", "fil"]',
|
||||||
|
'["eng", "chr", "isl"]',
|
||||||
|
'["eng", "chr", "kor"]',
|
||||||
|
'["eng", "chr", "rus"]',
|
||||||
|
'["eng", "chr", "spa"]',
|
||||||
|
'["eng", "chr", "zho"]',
|
||||||
|
'["eng", "chu", "oci"]',
|
||||||
|
'["eng", "cor"]',
|
||||||
|
'["eng", "", "cos"]',
|
||||||
|
'["eng", "cos"]',
|
||||||
|
'["eng", "cym"]',
|
||||||
|
'["eng", "cym", "deu"]',
|
||||||
|
'["eng", "cym", "fra"]',
|
||||||
|
'["eng", "cym", "jpn"]',
|
||||||
|
'["eng", "cym", "spa"]',
|
||||||
|
'["eng", "cym", "zho"]',
|
||||||
|
'["eng", "dan"]',
|
||||||
|
'["eng", "dan", "ber"]',
|
||||||
|
'["eng", "dan", "deu"]',
|
||||||
|
'["eng", "dan", "ell"]',
|
||||||
|
'["eng", "dan", "est"]',
|
||||||
|
'["eng", "dan", "fas"]',
|
||||||
|
'["eng", "dan", "fin"]',
|
||||||
|
'["eng", "dan", "fra"]',
|
||||||
|
'["eng", "dan", "gle"]',
|
||||||
|
'["eng", "dan", "hun"]',
|
||||||
|
'["eng", "dan", "isl"]',
|
||||||
|
'["eng", "dan", "ita"]',
|
||||||
|
'["eng", "dan", "jpn"]',
|
||||||
|
'["eng", "dan", "lat"]',
|
||||||
|
'["eng", "dan", "nld"]',
|
||||||
|
'["eng", "dan", "nob"]',
|
||||||
|
'["eng", "dan", "nor"]',
|
||||||
|
'["eng", "dan", "por"]',
|
||||||
|
'["eng", "dan", "rus"]',
|
||||||
|
'["eng", "dan", "slk"]',
|
||||||
|
'["eng", "dan", "spa"]',
|
||||||
|
'["eng", "dan", "swe"]',
|
||||||
|
'["eng", "dan", "tuk"]',
|
||||||
|
'["eng", "dan", "zho"]',
|
||||||
|
'["eng", "deu", ""]',
|
||||||
|
'["eng", "deu"]',
|
||||||
|
)
|
||||||
@@ -1,13 +1,36 @@
|
|||||||
"""Posts parent table with PostgreSQL weekly range partitioning on date column."""
|
"""Posts parent table and PostTopic table for the data_science_dev database."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pipelines.orm.data_science_dev.base import DataScienceDevBase
|
from pipelines.orm.data_science_dev.base import (
|
||||||
|
DataScienceDevBase,
|
||||||
|
DataScienceDevTableBase,
|
||||||
|
)
|
||||||
from pipelines.orm.data_science_dev.posts.columns import PostsColumns
|
from pipelines.orm.data_science_dev.posts.columns import PostsColumns
|
||||||
|
|
||||||
|
|
||||||
|
from sqlalchemy import BigInteger, Index, SmallInteger
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
|
||||||
class Posts(PostsColumns, DataScienceDevBase):
|
class Posts(PostsColumns, DataScienceDevBase):
|
||||||
"""Parent partitioned table for posts, partitioned by week on `date`."""
|
"""Parent partitioned table for posts, partitioned by week on `date`."""
|
||||||
|
|
||||||
__tablename__ = "posts"
|
__tablename__ = "posts"
|
||||||
__table_args__ = ({"postgresql_partition_by": "RANGE (date)"},)
|
__table_args__ = ({"postgresql_partition_by": "RANGE (date)"},)
|
||||||
|
|
||||||
|
|
||||||
|
class PostTopic(DataScienceDevTableBase):
|
||||||
|
"""Stores BERTopic topic assignments for posts.
|
||||||
|
|
||||||
|
post_id references main.posts but without a FK constraint
|
||||||
|
since posts is a partitioned table.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "post_topic"
|
||||||
|
__table_args__ = (Index("ix_post_topic_post_id", "post_id"),)
|
||||||
|
|
||||||
|
post_id: Mapped[int] = mapped_column(BigInteger)
|
||||||
|
topic_id: Mapped[int] = mapped_column(SmallInteger)
|
||||||
|
topic_label: Mapped[str | None]
|
||||||
|
model_version: Mapped[str | None]
|
||||||
|
|||||||
@@ -0,0 +1,155 @@
|
|||||||
|
"""Thing."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal, TypeVar
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Callable, Mapping, Sequence
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
modes = Literal["normal", "early_error"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExecutorResults[R]:
|
||||||
|
"""Dataclass to store the results and exceptions of the parallel execution."""
|
||||||
|
|
||||||
|
results: list[R]
|
||||||
|
exceptions: list[BaseException]
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""Return a string representation of the object."""
|
||||||
|
return f"results={self.results} exceptions={self.exceptions}"
|
||||||
|
|
||||||
|
|
||||||
|
def _parallelize_base[R](
|
||||||
|
executor_type: type[ThreadPoolExecutor | ProcessPoolExecutor],
|
||||||
|
func: Callable[..., R],
|
||||||
|
kwargs_list: Sequence[Mapping[str, Any]],
|
||||||
|
max_workers: int | None,
|
||||||
|
progress_tracker: int | None,
|
||||||
|
mode: modes,
|
||||||
|
) -> ExecutorResults:
|
||||||
|
total_work = len(kwargs_list)
|
||||||
|
|
||||||
|
with executor_type(max_workers=max_workers) as executor:
|
||||||
|
futures = [executor.submit(func, **kwarg) for kwarg in kwargs_list]
|
||||||
|
|
||||||
|
results = []
|
||||||
|
exceptions = []
|
||||||
|
for index, future in enumerate(futures, 1):
|
||||||
|
if exception := future.exception():
|
||||||
|
logger.error(f"{future} raised {exception.__class__.__name__}")
|
||||||
|
exceptions.append(exception)
|
||||||
|
if mode == "early_error":
|
||||||
|
executor.shutdown(wait=False)
|
||||||
|
raise exception
|
||||||
|
continue
|
||||||
|
|
||||||
|
results.append(future.result())
|
||||||
|
|
||||||
|
if progress_tracker and index % progress_tracker == 0:
|
||||||
|
logger.info(f"Progress: {index}/{total_work}")
|
||||||
|
|
||||||
|
return ExecutorResults(results, exceptions)
|
||||||
|
|
||||||
|
|
||||||
|
def parallelize_thread[R](
|
||||||
|
func: Callable[..., R],
|
||||||
|
kwargs_list: Sequence[Mapping[str, Any]],
|
||||||
|
max_workers: int | None = None,
|
||||||
|
progress_tracker: int | None = None,
|
||||||
|
mode: modes = "normal",
|
||||||
|
) -> ExecutorResults:
|
||||||
|
"""Generic function to run a function with multiple arguments in threads.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (Callable[..., R]): Function to run in threads.
|
||||||
|
kwargs_list (Sequence[Mapping[str, Any]]): List of dictionaries with the arguments for the function.
|
||||||
|
max_workers (int, optional): Number of workers to use. Defaults to 8.
|
||||||
|
progress_tracker (int, optional): Number of tasks to complete before logging progress.
|
||||||
|
mode (modes, optional): Mode to use. Defaults to "normal".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[list[R], list[Exception]]: List with the results and a list with the exceptions.
|
||||||
|
"""
|
||||||
|
return _parallelize_base(
|
||||||
|
executor_type=ThreadPoolExecutor,
|
||||||
|
func=func,
|
||||||
|
kwargs_list=kwargs_list,
|
||||||
|
max_workers=max_workers,
|
||||||
|
progress_tracker=progress_tracker,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parallelize_process[R](
|
||||||
|
func: Callable[..., R],
|
||||||
|
kwargs_list: Sequence[Mapping[str, Any]],
|
||||||
|
max_workers: int | None = None,
|
||||||
|
progress_tracker: int | None = None,
|
||||||
|
mode: modes = "normal",
|
||||||
|
) -> ExecutorResults:
|
||||||
|
"""Generic function to run a function with multiple arguments in process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (Callable[..., R]): Function to run in process.
|
||||||
|
kwargs_list (Sequence[Mapping[str, Any]]): List of dictionaries with the arguments for the function.
|
||||||
|
max_workers (int, optional): Number of workers to use. Defaults to 4.
|
||||||
|
progress_tracker (int, optional): Number of tasks to complete before logging progress.
|
||||||
|
mode (modes, optional): Mode to use. Defaults to "normal".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[list[R], list[Exception]]: List with the results and a list with the exceptions.
|
||||||
|
"""
|
||||||
|
if max_workers and max_workers > cpu_count():
|
||||||
|
error = f"max_workers must be less than or equal to {cpu_count()}"
|
||||||
|
raise RuntimeError(error)
|
||||||
|
|
||||||
|
return process_executor_unchecked(
|
||||||
|
func=func,
|
||||||
|
kwargs_list=kwargs_list,
|
||||||
|
max_workers=max_workers,
|
||||||
|
progress_tracker=progress_tracker,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def process_executor_unchecked[R](
|
||||||
|
func: Callable[..., R],
|
||||||
|
kwargs_list: Sequence[Mapping[str, Any]],
|
||||||
|
max_workers: int | None,
|
||||||
|
progress_tracker: int | None,
|
||||||
|
mode: modes = "normal",
|
||||||
|
) -> ExecutorResults:
|
||||||
|
"""Generic function to run a function with multiple arguments in parallel.
|
||||||
|
|
||||||
|
Note: this function does not check if the number of workers is greater than the number of CPUs.
|
||||||
|
This can cause the system to become unresponsive.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (Callable[..., R]): Function to run in parallel.
|
||||||
|
kwargs_list (Sequence[Mapping[str, Any]]): List of dictionaries with the arguments for the function.
|
||||||
|
max_workers (int, optional): Number of workers to use. Defaults to 8.
|
||||||
|
progress_tracker (int, optional): Number of tasks to complete before logging progress.
|
||||||
|
mode (modes, optional): Mode to use. Defaults to "normal".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[list[R], list[Exception]]: List with the results and a list with the exceptions.
|
||||||
|
"""
|
||||||
|
return _parallelize_base(
|
||||||
|
executor_type=ProcessPoolExecutor,
|
||||||
|
func=func,
|
||||||
|
kwargs_list=kwargs_list,
|
||||||
|
max_workers=max_workers,
|
||||||
|
progress_tracker=progress_tracker,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
# Unsloth fine-tuning container for Qwen 3.5 4B on RTX 3090.
|
|
||||||
#
|
|
||||||
# Build:
|
|
||||||
# docker build -f pipelines/pipelines/tools/Dockerfile.finetune -t bill-finetune .
|
|
||||||
#
|
|
||||||
# Run:
|
|
||||||
# docker run --rm --device=nvidia.com/gpu=all --ipc=host \
|
|
||||||
# -v $(pwd)/output:/workspace/output \
|
|
||||||
# -v $(pwd)/output/finetune_dataset.jsonl:/workspace/dataset.jsonl:ro \
|
|
||||||
# -v /zfs/models/hf:/models \
|
|
||||||
# bill-finetune \
|
|
||||||
# --dataset /workspace/dataset.jsonl \
|
|
||||||
# --output-dir /workspace/output/qwen-bill-summarizer
|
|
||||||
|
|
||||||
FROM ghcr.io/unslothai/unsloth:latest
|
|
||||||
|
|
||||||
RUN pip install --no-cache-dir typer rouge-score
|
|
||||||
|
|
||||||
WORKDIR /workspace
|
|
||||||
COPY pipelines/tools/__init__.py pipelines/tools/__init__.py
|
|
||||||
COPY pipelines/tools/finetune.py pipelines/tools/finetune.py
|
|
||||||
COPY pipelines/tools/summarization_eval.py pipelines/tools/summarization_eval.py
|
|
||||||
COPY summarization_prompts.toml config/prompts/summarization_prompts.toml
|
|
||||||
COPY config.toml pipelines/tools/config.toml
|
|
||||||
|
|
||||||
ENTRYPOINT ["python", "-m", "pipelines.tools.finetune"]
|
|
||||||
@@ -25,8 +25,6 @@ from datasets import Dataset
|
|||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import SFTTrainer
|
from trl import SFTTrainer
|
||||||
|
|
||||||
from .summarization_eval import make_compute_metrics
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -189,9 +187,6 @@ def main(
|
|||||||
optim="adamw_8bit",
|
optim="adamw_8bit",
|
||||||
seed=42,
|
seed=42,
|
||||||
report_to="none",
|
report_to="none",
|
||||||
metric_for_best_model="eval_composite",
|
|
||||||
greater_is_better=True,
|
|
||||||
predict_with_generate=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = SFTTrainer(
|
trainer = SFTTrainer(
|
||||||
@@ -202,7 +197,6 @@ def main(
|
|||||||
args=training_args,
|
args=training_args,
|
||||||
max_seq_length=config.training.max_seq_length,
|
max_seq_length=config.training.max_seq_length,
|
||||||
packing=True,
|
packing=True,
|
||||||
compute_metrics=make_compute_metrics(tokenizer),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@@ -1,426 +0,0 @@
|
|||||||
"""Summarization evaluation for Congressional bill summaries.
|
|
||||||
|
|
||||||
Three use cases from one module:
|
|
||||||
|
|
||||||
1. Data filtering — score GPT batch outputs before building the fine-tune JSONL:
|
|
||||||
from summarization_eval import filter_dataset
|
|
||||||
filter_dataset("output/finetune_dataset.jsonl", "output/filtered_dataset.jsonl")
|
|
||||||
|
|
||||||
2. Training compute_metrics hook — plug into SFTTrainer for ROUGE-based checkpoint selection:
|
|
||||||
from summarization_eval import make_compute_metrics
|
|
||||||
trainer = SFTTrainer(..., compute_metrics=make_compute_metrics(tokenizer))
|
|
||||||
|
|
||||||
3. Inference eval — score a finished model against held-out references:
|
|
||||||
from summarization_eval import evaluate_file
|
|
||||||
results = evaluate_file("output/predictions.jsonl", "output/references.jsonl")
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from rouge_score import rouge_scorer
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Constants
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
SECTION_HEADERS = [
|
|
||||||
"OPERATIVE ACTIONS",
|
|
||||||
"AFFECTED POPULATIONS",
|
|
||||||
"MECHANISMS",
|
|
||||||
"POLICY THREADS",
|
|
||||||
"SYMBOLIC/PROCEDURAL ONLY",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Weighted composite: de-emphasise unigram overlap, weight phrase + structure equally
|
|
||||||
ROUGE_WEIGHTS = {
|
|
||||||
"rouge1": 0.2,
|
|
||||||
"rouge2": 0.4,
|
|
||||||
"rougeL": 0.4,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Composite score floor below which a training example is considered low quality
|
|
||||||
FILTER_THRESHOLD = 0.25
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Core data structures
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SummaryScore:
|
|
||||||
"""Scores for a single (prediction, reference) pair."""
|
|
||||||
|
|
||||||
rouge1: float
|
|
||||||
rouge2: float
|
|
||||||
rougeL: float
|
|
||||||
composite: float
|
|
||||||
has_all_sections: bool # True = all 5 headers present
|
|
||||||
missing_sections: list[str]
|
|
||||||
structural_fail: bool # True = one or more headers missing (hard guardrail)
|
|
||||||
|
|
||||||
def as_dict(self) -> dict:
|
|
||||||
return {
|
|
||||||
"rouge1": self.rouge1,
|
|
||||||
"rouge2": self.rouge2,
|
|
||||||
"rougeL": self.rougeL,
|
|
||||||
"composite": self.composite,
|
|
||||||
"has_all_sections": self.has_all_sections,
|
|
||||||
"missing_sections": self.missing_sections,
|
|
||||||
"structural_fail": self.structural_fail,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BatchResult:
|
|
||||||
"""Aggregate results over a batch of summaries."""
|
|
||||||
|
|
||||||
n_total: int
|
|
||||||
n_structural_fail: int
|
|
||||||
n_scored: int # excludes structural failures
|
|
||||||
rouge1_mean: float
|
|
||||||
rouge2_mean: float
|
|
||||||
rougeL_mean: float
|
|
||||||
composite_mean: float
|
|
||||||
scores: list[SummaryScore]
|
|
||||||
|
|
||||||
def as_dict(self) -> dict:
|
|
||||||
return {
|
|
||||||
"n_total": self.n_total,
|
|
||||||
"n_structural_fail": self.n_structural_fail,
|
|
||||||
"n_scored": self.n_scored,
|
|
||||||
"rouge1_mean": self.rouge1_mean,
|
|
||||||
"rouge2_mean": self.rouge2_mean,
|
|
||||||
"rougeL_mean": self.rougeL_mean,
|
|
||||||
"composite_mean": self.composite_mean,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Core scoring
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
_scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
|
|
||||||
|
|
||||||
|
|
||||||
def check_sections(text: str) -> tuple[bool, list[str]]:
|
|
||||||
"""Return (all_present, missing_headers) for the 5 required section headers."""
|
|
||||||
missing = [h for h in SECTION_HEADERS if h not in text.upper()]
|
|
||||||
return len(missing) == 0, missing
|
|
||||||
|
|
||||||
|
|
||||||
def score_pair(prediction: str, reference: str) -> SummaryScore:
|
|
||||||
"""Score a single (prediction, reference) pair.
|
|
||||||
|
|
||||||
If the prediction is missing any section header, structural_fail is True
|
|
||||||
and ROUGE scores are still computed (so you can inspect quality even on
|
|
||||||
structural failures) but the example should be treated as a guardrail failure.
|
|
||||||
"""
|
|
||||||
has_all, missing = check_sections(prediction)
|
|
||||||
|
|
||||||
rouge = _scorer.score(reference, prediction)
|
|
||||||
r1 = rouge["rouge1"].fmeasure
|
|
||||||
r2 = rouge["rouge2"].fmeasure
|
|
||||||
rl = rouge["rougeL"].fmeasure
|
|
||||||
composite = (
|
|
||||||
ROUGE_WEIGHTS["rouge1"] * r1
|
|
||||||
+ ROUGE_WEIGHTS["rouge2"] * r2
|
|
||||||
+ ROUGE_WEIGHTS["rougeL"] * rl
|
|
||||||
)
|
|
||||||
|
|
||||||
return SummaryScore(
|
|
||||||
rouge1=r1,
|
|
||||||
rouge2=r2,
|
|
||||||
rougeL=rl,
|
|
||||||
composite=composite,
|
|
||||||
has_all_sections=has_all,
|
|
||||||
missing_sections=missing,
|
|
||||||
structural_fail=not has_all,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def score_batch(pairs: list[tuple[str, str]]) -> BatchResult:
|
|
||||||
"""Score a list of (prediction, reference) pairs and return aggregate results.
|
|
||||||
|
|
||||||
Structural failures are counted separately and excluded from ROUGE means
|
|
||||||
so a batch with broken formatting doesn't drag down the score unfairly.
|
|
||||||
"""
|
|
||||||
scores = [score_pair(pred, ref) for pred, ref in pairs]
|
|
||||||
|
|
||||||
structural_fails = [s for s in scores if s.structural_fail]
|
|
||||||
valid = [s for s in scores if not s.structural_fail]
|
|
||||||
|
|
||||||
if valid:
|
|
||||||
rouge1_mean = float(np.mean([s.rouge1 for s in valid]))
|
|
||||||
rouge2_mean = float(np.mean([s.rouge2 for s in valid]))
|
|
||||||
rougeL_mean = float(np.mean([s.rougeL for s in valid]))
|
|
||||||
composite_mean = float(np.mean([s.composite for s in valid]))
|
|
||||||
else:
|
|
||||||
rouge1_mean = rouge2_mean = rougeL_mean = composite_mean = 0.0
|
|
||||||
|
|
||||||
return BatchResult(
|
|
||||||
n_total=len(scores),
|
|
||||||
n_structural_fail=len(structural_fails),
|
|
||||||
n_scored=len(valid),
|
|
||||||
rouge1_mean=rouge1_mean,
|
|
||||||
rouge2_mean=rouge2_mean,
|
|
||||||
rougeL_mean=rougeL_mean,
|
|
||||||
composite_mean=composite_mean,
|
|
||||||
scores=scores,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Use case 1: Data filtering
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def filter_dataset(
|
|
||||||
input_path: Path | str,
|
|
||||||
output_path: Path | str,
|
|
||||||
*,
|
|
||||||
threshold: float = FILTER_THRESHOLD,
|
|
||||||
) -> tuple[int, int]:
|
|
||||||
"""Filter a fine-tuning JSONL by ROUGE composite score and section guardrail.
|
|
||||||
|
|
||||||
Each line must be a ChatML messages dict:
|
|
||||||
{"messages": [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}]}
|
|
||||||
|
|
||||||
The assistant turn is the prediction. The reference is the same assistant
|
|
||||||
turn — filtering here uses composite score as a self-consistency check
|
|
||||||
against the threshold, and drops structural failures unconditionally.
|
|
||||||
|
|
||||||
In practice you'd call this after joining requests + GPT completions
|
|
||||||
(build_finetune_dataset.py) to drop any GPT outputs that are malformed
|
|
||||||
or suspiciously short/low quality.
|
|
||||||
|
|
||||||
Returns (kept, dropped).
|
|
||||||
"""
|
|
||||||
input_path = Path(input_path)
|
|
||||||
output_path = Path(output_path)
|
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
kept = 0
|
|
||||||
dropped = 0
|
|
||||||
|
|
||||||
with input_path.open(encoding="utf-8") as fin, output_path.open("w", encoding="utf-8") as fout:
|
|
||||||
for line_num, raw_line in enumerate(fin, 1):
|
|
||||||
stripped = raw_line.strip()
|
|
||||||
if not stripped:
|
|
||||||
continue
|
|
||||||
|
|
||||||
example = json.loads(stripped)
|
|
||||||
messages = example.get("messages", [])
|
|
||||||
assistant_turns = [m for m in messages if m.get("role") == "assistant"]
|
|
||||||
|
|
||||||
if not assistant_turns:
|
|
||||||
logger.warning("Line %d: no assistant turn, dropping", line_num)
|
|
||||||
dropped += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
prediction = assistant_turns[-1].get("content", "")
|
|
||||||
|
|
||||||
# Guardrail: drop if any section header missing
|
|
||||||
has_all, missing = check_sections(prediction)
|
|
||||||
if not has_all:
|
|
||||||
logger.warning(
|
|
||||||
"Line %d: structural fail (missing: %s), dropping",
|
|
||||||
line_num,
|
|
||||||
", ".join(missing),
|
|
||||||
)
|
|
||||||
dropped += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Quality floor: score against itself isn't meaningful for filtering —
|
|
||||||
# instead just check composite score of prediction vs a simple
|
|
||||||
# word-count proxy. For filtering GPT outputs, structural check
|
|
||||||
# + a minimum word count is usually sufficient.
|
|
||||||
word_count = len(prediction.split())
|
|
||||||
if word_count < 80:
|
|
||||||
logger.warning(
|
|
||||||
"Line %d: too short (%d words), dropping", line_num, word_count
|
|
||||||
)
|
|
||||||
dropped += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
fout.write(json.dumps(example, ensure_ascii=False) + "\n")
|
|
||||||
kept += 1
|
|
||||||
|
|
||||||
logger.info("Filtered dataset: kept=%d dropped=%d -> %s", kept, dropped, output_path)
|
|
||||||
return kept, dropped
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Use case 2: compute_metrics hook for SFTTrainer
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def make_compute_metrics(tokenizer) -> Callable: # noqa: ANN001
|
|
||||||
"""Return a compute_metrics function compatible with HuggingFace Trainer.
|
|
||||||
|
|
||||||
Usage in finetune.py:
|
|
||||||
from summarization_eval import make_compute_metrics
|
|
||||||
trainer = SFTTrainer(
|
|
||||||
...
|
|
||||||
compute_metrics=make_compute_metrics(tokenizer),
|
|
||||||
)
|
|
||||||
|
|
||||||
Note: EvalPrediction.predictions are logits (or token ids if
|
|
||||||
include_inputs_for_metrics is False). This function handles both.
|
|
||||||
For SFTTrainer with packing=True, you may need to set
|
|
||||||
predict_with_generate=True in TrainingArguments to get decoded text.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def compute_metrics(eval_pred) -> dict[str, float]: # noqa: ANN001
|
|
||||||
predictions, labels = eval_pred
|
|
||||||
|
|
||||||
# If predictions are logits, take argmax
|
|
||||||
if predictions.ndim == 3:
|
|
||||||
predictions = np.argmax(predictions, axis=-1)
|
|
||||||
|
|
||||||
# Mask out -100 padding in labels
|
|
||||||
labels = np.where(labels == -100, tokenizer.pad_token_id, labels)
|
|
||||||
|
|
||||||
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
|
|
||||||
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
|
||||||
|
|
||||||
pairs = list(zip(decoded_preds, decoded_labels))
|
|
||||||
result = score_batch(pairs)
|
|
||||||
|
|
||||||
metrics = {
|
|
||||||
"eval_rouge1": result.rouge1_mean,
|
|
||||||
"eval_rouge2": result.rouge2_mean,
|
|
||||||
"eval_rougeL": result.rougeL_mean,
|
|
||||||
"eval_composite": result.composite_mean,
|
|
||||||
"eval_structural_fail_rate": (
|
|
||||||
result.n_structural_fail / result.n_total if result.n_total else 0.0
|
|
||||||
),
|
|
||||||
}
|
|
||||||
logger.info(
|
|
||||||
"Eval: composite=%.4f rouge1=%.4f rouge2=%.4f rougeL=%.4f structural_fail=%d/%d",
|
|
||||||
metrics["eval_composite"],
|
|
||||||
metrics["eval_rouge1"],
|
|
||||||
metrics["eval_rouge2"],
|
|
||||||
metrics["eval_rougeL"],
|
|
||||||
result.n_structural_fail,
|
|
||||||
result.n_total,
|
|
||||||
)
|
|
||||||
return metrics
|
|
||||||
|
|
||||||
return compute_metrics
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Use case 3: Inference eval against held-out references
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def evaluate_file(
|
|
||||||
predictions_path: Path | str,
|
|
||||||
references_path: Path | str,
|
|
||||||
output_path: Path | str | None = None,
|
|
||||||
) -> BatchResult:
|
|
||||||
"""Score a predictions JSONL against a references JSONL.
|
|
||||||
|
|
||||||
Both files should be line-matched: line N of predictions corresponds
|
|
||||||
to line N of references. Each line should be a plain JSON object with
|
|
||||||
a "text" or "content" key, or a ChatML messages dict.
|
|
||||||
|
|
||||||
If output_path is provided, writes per-example scores as JSONL.
|
|
||||||
"""
|
|
||||||
predictions_path = Path(predictions_path)
|
|
||||||
references_path = Path(references_path)
|
|
||||||
|
|
||||||
def extract_text(line: str) -> str:
|
|
||||||
obj = json.loads(line)
|
|
||||||
# Plain text field
|
|
||||||
if "text" in obj:
|
|
||||||
return obj["text"]
|
|
||||||
if "content" in obj:
|
|
||||||
return obj["content"]
|
|
||||||
# ChatML messages — take last assistant turn
|
|
||||||
messages = obj.get("messages", [])
|
|
||||||
for m in reversed(messages):
|
|
||||||
if m.get("role") == "assistant":
|
|
||||||
return m.get("content", "")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
preds = [extract_text(l) for l in predictions_path.read_text().splitlines() if l.strip()]
|
|
||||||
refs = [extract_text(l) for l in references_path.read_text().splitlines() if l.strip()]
|
|
||||||
|
|
||||||
if len(preds) != len(refs):
|
|
||||||
msg = f"Prediction count ({len(preds)}) != reference count ({len(refs)})"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
result = score_batch(list(zip(preds, refs)))
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Inference eval: n=%d structural_fails=%d composite=%.4f "
|
|
||||||
"rouge1=%.4f rouge2=%.4f rougeL=%.4f",
|
|
||||||
result.n_total,
|
|
||||||
result.n_structural_fail,
|
|
||||||
result.composite_mean,
|
|
||||||
result.rouge1_mean,
|
|
||||||
result.rouge2_mean,
|
|
||||||
result.rougeL_mean,
|
|
||||||
)
|
|
||||||
|
|
||||||
if output_path is not None:
|
|
||||||
output_path = Path(output_path)
|
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with output_path.open("w", encoding="utf-8") as fout:
|
|
||||||
for score in result.scores:
|
|
||||||
fout.write(json.dumps(score.as_dict(), ensure_ascii=False) + "\n")
|
|
||||||
summary_path = output_path.with_suffix(".summary.json")
|
|
||||||
summary_path.write_text(json.dumps(result.as_dict(), indent=2))
|
|
||||||
logger.info("Wrote per-example scores to %s", output_path)
|
|
||||||
logger.info("Wrote summary to %s", summary_path)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# CLI — quick sanity check / standalone use
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _cli() -> None:
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Evaluate bill summarization quality.")
|
|
||||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
|
||||||
|
|
||||||
# filter subcommand
|
|
||||||
fp = subparsers.add_parser("filter", help="Filter a fine-tuning JSONL dataset")
|
|
||||||
fp.add_argument("--input", required=True, type=Path)
|
|
||||||
fp.add_argument("--output", required=True, type=Path)
|
|
||||||
fp.add_argument("--threshold", type=float, default=FILTER_THRESHOLD)
|
|
||||||
|
|
||||||
# eval subcommand
|
|
||||||
ep = subparsers.add_parser("eval", help="Score predictions against references")
|
|
||||||
ep.add_argument("--predictions", required=True, type=Path)
|
|
||||||
ep.add_argument("--references", required=True, type=Path)
|
|
||||||
ep.add_argument("--output", type=Path, default=None)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
logging.basicConfig(level="INFO", format="%(asctime)s %(levelname)s: %(message)s")
|
|
||||||
|
|
||||||
if args.command == "filter":
|
|
||||||
kept, dropped = filter_dataset(args.input, args.output, threshold=args.threshold)
|
|
||||||
print(f"Kept: {kept} Dropped: {dropped}")
|
|
||||||
|
|
||||||
elif args.command == "eval":
|
|
||||||
result = evaluate_file(args.predictions, args.references, args.output)
|
|
||||||
print(f"\nResults ({result.n_scored} scored, {result.n_structural_fail} structural fails):")
|
|
||||||
print(f" ROUGE-1: {result.rouge1_mean:.4f}")
|
|
||||||
print(f" ROUGE-2: {result.rouge2_mean:.4f}")
|
|
||||||
print(f" ROUGE-L: {result.rougeL_mean:.4f}")
|
|
||||||
print(f" Composite: {result.composite_mean:.4f}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
_cli()
|
|
||||||
Reference in New Issue
Block a user