Compare commits
12 Commits
matt_ds
...
feature/v1
| Author | SHA1 | Date | |
|---|---|---|---|
| 9a4f71f471 | |||
| c8adf6914e | |||
| 87a7f5312f | |||
| 51d6240690 | |||
| 1426b797e5 | |||
| 4b768049c0 | |||
| 674edafe94 | |||
| be4b473a3c | |||
| e5ba089479 | |||
| 0c99a63347 | |||
| 07cd231609 | |||
| c8b61fc3c0 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -110,7 +110,7 @@ ipython_config.py
|
|||||||
|
|
||||||
# pdm
|
# pdm
|
||||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-pipelines.
|
||||||
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
||||||
# pdm.lock
|
# pdm.lock
|
||||||
# pdm.toml
|
# pdm.toml
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,60 @@
|
|||||||
|
"""adding FailedIngestion.
|
||||||
|
|
||||||
|
Revision ID: 2f43120e3ffc
|
||||||
|
Revises: f99be864fe69
|
||||||
|
Create Date: 2026-03-24 23:46:17.277897
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "2f43120e3ffc"
|
||||||
|
down_revision: str | None = "f99be864fe69"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"failed_ingestion",
|
||||||
|
sa.Column("raw_line", sa.Text(), nullable=False),
|
||||||
|
sa.Column("error", sa.Text(), nullable=False),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_failed_ingestion")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("failed_ingestion", schema=schema)
|
||||||
|
# ### end Alembic commands ###
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,79 @@
|
|||||||
|
"""Attach all partition tables to the posts parent table.
|
||||||
|
|
||||||
|
Alembic autogenerate creates partition tables as standalone tables but does not
|
||||||
|
emit the ALTER TABLE ... ATTACH PARTITION statements needed for PostgreSQL to
|
||||||
|
route inserts to the correct partition.
|
||||||
|
|
||||||
|
Revision ID: a1b2c3d4e5f6
|
||||||
|
Revises: 605b1794838f
|
||||||
|
Create Date: 2026-03-25 10:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
from pipelines.orm.data_science_dev.posts.partitions import (
|
||||||
|
PARTITION_END_YEAR,
|
||||||
|
PARTITION_START_YEAR,
|
||||||
|
iso_weeks_in_year,
|
||||||
|
week_bounds,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "a1b2c3d4e5f6"
|
||||||
|
down_revision: str | None = "605b1794838f"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
ALREADY_ATTACHED_QUERY = text("""
|
||||||
|
SELECT inhrelid::regclass::text
|
||||||
|
FROM pg_inherits
|
||||||
|
WHERE inhparent = :parent::regclass
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Attach all weekly partition tables to the posts parent table."""
|
||||||
|
connection = op.get_bind()
|
||||||
|
already_attached = {
|
||||||
|
row[0]
|
||||||
|
for row in connection.execute(
|
||||||
|
ALREADY_ATTACHED_QUERY, {"parent": f"{schema}.posts"}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
|
||||||
|
for week in range(1, iso_weeks_in_year(year) + 1):
|
||||||
|
table_name = f"posts_{year}_{week:02d}"
|
||||||
|
qualified_name = f"{schema}.{table_name}"
|
||||||
|
if qualified_name in already_attached:
|
||||||
|
continue
|
||||||
|
start, end = week_bounds(year, week)
|
||||||
|
start_str = start.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
end_str = end.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
op.execute(
|
||||||
|
f"ALTER TABLE {schema}.posts "
|
||||||
|
f"ATTACH PARTITION {qualified_name} "
|
||||||
|
f"FOR VALUES FROM ('{start_str}') TO ('{end_str}')"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Detach all weekly partition tables from the posts parent table."""
|
||||||
|
for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
|
||||||
|
for week in range(1, iso_weeks_in_year(year) + 1):
|
||||||
|
table_name = f"posts_{year}_{week:02d}"
|
||||||
|
op.execute(
|
||||||
|
f"ALTER TABLE {schema}.posts DETACH PARTITION {schema}.{table_name}"
|
||||||
|
)
|
||||||
@@ -0,0 +1,229 @@
|
|||||||
|
"""adding congress data.
|
||||||
|
|
||||||
|
Revision ID: 83bfc8af92d8
|
||||||
|
Revises: a1b2c3d4e5f6
|
||||||
|
Create Date: 2026-03-27 10:43:02.324510
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "83bfc8af92d8"
|
||||||
|
down_revision: str | None = "a1b2c3d4e5f6"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"bill",
|
||||||
|
sa.Column("congress", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("bill_type", sa.String(), nullable=False),
|
||||||
|
sa.Column("number", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("title", sa.String(), nullable=True),
|
||||||
|
sa.Column("title_short", sa.String(), nullable=True),
|
||||||
|
sa.Column("official_title", sa.String(), nullable=True),
|
||||||
|
sa.Column("status", sa.String(), nullable=True),
|
||||||
|
sa.Column("status_at", sa.Date(), nullable=True),
|
||||||
|
sa.Column("sponsor_bioguide_id", sa.String(), nullable=True),
|
||||||
|
sa.Column("subjects_top_term", sa.String(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"congress", "bill_type", "number", name="uq_bill_congress_type_number"
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_bill_congress", "bill", ["congress"], unique=False, schema=schema
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"legislator",
|
||||||
|
sa.Column("bioguide_id", sa.Text(), nullable=False),
|
||||||
|
sa.Column("thomas_id", sa.String(), nullable=True),
|
||||||
|
sa.Column("lis_id", sa.String(), nullable=True),
|
||||||
|
sa.Column("govtrack_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("opensecrets_id", sa.String(), nullable=True),
|
||||||
|
sa.Column("fec_ids", sa.String(), nullable=True),
|
||||||
|
sa.Column("first_name", sa.String(), nullable=False),
|
||||||
|
sa.Column("last_name", sa.String(), nullable=False),
|
||||||
|
sa.Column("official_full_name", sa.String(), nullable=True),
|
||||||
|
sa.Column("nickname", sa.String(), nullable=True),
|
||||||
|
sa.Column("birthday", sa.Date(), nullable=True),
|
||||||
|
sa.Column("gender", sa.String(), nullable=True),
|
||||||
|
sa.Column("current_party", sa.String(), nullable=True),
|
||||||
|
sa.Column("current_state", sa.String(), nullable=True),
|
||||||
|
sa.Column("current_district", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("current_chamber", sa.String(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_bioguide_id"),
|
||||||
|
"legislator",
|
||||||
|
["bioguide_id"],
|
||||||
|
unique=True,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"bill_text",
|
||||||
|
sa.Column("bill_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("version_code", sa.String(), nullable=False),
|
||||||
|
sa.Column("version_name", sa.String(), nullable=True),
|
||||||
|
sa.Column("text_content", sa.String(), nullable=True),
|
||||||
|
sa.Column("date", sa.Date(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_id"],
|
||||||
|
[f"{schema}.bill.id"],
|
||||||
|
name=op.f("fk_bill_text_bill_id_bill"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill_text")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"bill_id", "version_code", name="uq_bill_text_bill_id_version_code"
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"vote",
|
||||||
|
sa.Column("congress", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("chamber", sa.String(), nullable=False),
|
||||||
|
sa.Column("session", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("number", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("vote_type", sa.String(), nullable=True),
|
||||||
|
sa.Column("question", sa.String(), nullable=True),
|
||||||
|
sa.Column("result", sa.String(), nullable=True),
|
||||||
|
sa.Column("result_text", sa.String(), nullable=True),
|
||||||
|
sa.Column("vote_date", sa.Date(), nullable=False),
|
||||||
|
sa.Column("yea_count", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("nay_count", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("not_voting_count", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("present_count", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("bill_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_id"], [f"{schema}.bill.id"], name=op.f("fk_vote_bill_id_bill")
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_vote")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"congress",
|
||||||
|
"chamber",
|
||||||
|
"session",
|
||||||
|
"number",
|
||||||
|
name="uq_vote_congress_chamber_session_number",
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_vote_congress_chamber",
|
||||||
|
"vote",
|
||||||
|
["congress", "chamber"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index("ix_vote_date", "vote", ["vote_date"], unique=False, schema=schema)
|
||||||
|
op.create_table(
|
||||||
|
"vote_record",
|
||||||
|
sa.Column("vote_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("legislator_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("position", sa.String(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["legislator_id"],
|
||||||
|
[f"{schema}.legislator.id"],
|
||||||
|
name=op.f("fk_vote_record_legislator_id_legislator"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["vote_id"],
|
||||||
|
[f"{schema}.vote.id"],
|
||||||
|
name=op.f("fk_vote_record_vote_id_vote"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint(
|
||||||
|
"vote_id", "legislator_id", name=op.f("pk_vote_record")
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("vote_record", schema=schema)
|
||||||
|
op.drop_index("ix_vote_date", table_name="vote", schema=schema)
|
||||||
|
op.drop_index("ix_vote_congress_chamber", table_name="vote", schema=schema)
|
||||||
|
op.drop_table("vote", schema=schema)
|
||||||
|
op.drop_table("bill_text", schema=schema)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_bioguide_id"), table_name="legislator", schema=schema
|
||||||
|
)
|
||||||
|
op.drop_table("legislator", schema=schema)
|
||||||
|
op.drop_index("ix_bill_congress", table_name="bill", schema=schema)
|
||||||
|
op.drop_table("bill", schema=schema)
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
"""adding LegislatorSocialMedia.
|
||||||
|
|
||||||
|
Revision ID: 5cd7eee3549d
|
||||||
|
Revises: 83bfc8af92d8
|
||||||
|
Create Date: 2026-03-29 11:53:44.224799
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "5cd7eee3549d"
|
||||||
|
down_revision: str | None = "83bfc8af92d8"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"legislator_social_media",
|
||||||
|
sa.Column("legislator_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("platform", sa.String(), nullable=False),
|
||||||
|
sa.Column("account_name", sa.String(), nullable=False),
|
||||||
|
sa.Column("url", sa.String(), nullable=True),
|
||||||
|
sa.Column("source", sa.String(), nullable=False),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["legislator_id"],
|
||||||
|
[f"{schema}.legislator.id"],
|
||||||
|
name=op.f("fk_legislator_social_media_legislator_id_legislator"),
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator_social_media")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("legislator_social_media", schema=schema)
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -0,0 +1,245 @@
|
|||||||
|
"""adding LegislatorScore and BillTopic.
|
||||||
|
|
||||||
|
Revision ID: ef4bc5411176
|
||||||
|
Revises: 5cd7eee3549d
|
||||||
|
Create Date: 2026-04-21 11:35:18.977213
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "ef4bc5411176"
|
||||||
|
down_revision: str | None = "5cd7eee3549d"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"bill_topic",
|
||||||
|
sa.Column("bill_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("topic", sa.String(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"support_position",
|
||||||
|
sa.Enum("for", "against", name="bill_topic_position", native_enum=False),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_id"],
|
||||||
|
[f"{schema}.bill.id"],
|
||||||
|
name=op.f("fk_bill_topic_bill_id_bill"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill_topic")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"bill_id",
|
||||||
|
"topic",
|
||||||
|
"support_position",
|
||||||
|
name="uq_bill_topic_bill_id_topic_support_position",
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_bill_topic_topic", "bill_topic", ["topic"], unique=False, schema=schema
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"legislator_score",
|
||||||
|
sa.Column("legislator_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("year", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("topic", sa.String(), nullable=False),
|
||||||
|
sa.Column("score", sa.Float(), nullable=False),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["legislator_id"],
|
||||||
|
[f"{schema}.legislator.id"],
|
||||||
|
name=op.f("fk_legislator_score_legislator_id_legislator"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator_score")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"legislator_id",
|
||||||
|
"year",
|
||||||
|
"topic",
|
||||||
|
name="uq_legislator_score_legislator_id_year_topic",
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_score_legislator_id"),
|
||||||
|
"legislator_score",
|
||||||
|
["legislator_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_legislator_score_year_topic",
|
||||||
|
"legislator_score",
|
||||||
|
["year", "topic"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"legislator_bill_score",
|
||||||
|
sa.Column("bill_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("bill_topic_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("legislator_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("year", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("topic", sa.String(), nullable=False),
|
||||||
|
sa.Column("score", sa.Float(), nullable=False),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_id"],
|
||||||
|
[f"{schema}.bill.id"],
|
||||||
|
name=op.f("fk_legislator_bill_score_bill_id_bill"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_topic_id"],
|
||||||
|
[f"{schema}.bill_topic.id"],
|
||||||
|
name=op.f("fk_legislator_bill_score_bill_topic_id_bill_topic"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["legislator_id"],
|
||||||
|
[f"{schema}.legislator.id"],
|
||||||
|
name=op.f("fk_legislator_bill_score_legislator_id_legislator"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator_bill_score")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"bill_topic_id",
|
||||||
|
"legislator_id",
|
||||||
|
"year",
|
||||||
|
name="uq_legislator_bill_score_bill_topic_id_legislator_id_year",
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_bill_score_bill_id"),
|
||||||
|
"legislator_bill_score",
|
||||||
|
["bill_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_bill_score_bill_topic_id"),
|
||||||
|
"legislator_bill_score",
|
||||||
|
["bill_topic_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_bill_score_legislator_id"),
|
||||||
|
"legislator_bill_score",
|
||||||
|
["legislator_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_legislator_bill_score_year_topic",
|
||||||
|
"legislator_bill_score",
|
||||||
|
["year", "topic"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"bill",
|
||||||
|
sa.Column("score_processed_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"bill_text", sa.Column("summary", sa.String(), nullable=True), schema=schema
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column("bill_text", "summary", schema=schema)
|
||||||
|
op.drop_column("bill", "score_processed_at", schema=schema)
|
||||||
|
op.drop_index(
|
||||||
|
"ix_legislator_bill_score_year_topic",
|
||||||
|
table_name="legislator_bill_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_bill_score_legislator_id"),
|
||||||
|
table_name="legislator_bill_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_bill_score_bill_topic_id"),
|
||||||
|
table_name="legislator_bill_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_bill_score_bill_id"),
|
||||||
|
table_name="legislator_bill_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_table("legislator_bill_score", schema=schema)
|
||||||
|
op.drop_index(
|
||||||
|
"ix_legislator_score_year_topic", table_name="legislator_score", schema=schema
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_score_legislator_id"),
|
||||||
|
table_name="legislator_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_table("legislator_score", schema=schema)
|
||||||
|
op.drop_index("ix_bill_topic_topic", table_name="bill_topic", schema=schema)
|
||||||
|
op.drop_table("bill_topic", schema=schema)
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
"""removed LegislatorBillScore.
|
||||||
|
|
||||||
|
Revision ID: b63ed11d6775
|
||||||
|
Revises: 7d15f9b7c8a2
|
||||||
|
Create Date: 2026-04-21 22:46:48.058542
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "b63ed11d6775"
|
||||||
|
down_revision: str | None = "7d15f9b7c8a2"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_bill_score_bill_id"),
|
||||||
|
table_name="legislator_bill_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_bill_score_bill_topic_id"),
|
||||||
|
table_name="legislator_bill_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_bill_score_legislator_id"),
|
||||||
|
table_name="legislator_bill_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_legislator_bill_score_year_topic"),
|
||||||
|
table_name="legislator_bill_score",
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.drop_table("legislator_bill_score", schema=schema)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"legislator_bill_score",
|
||||||
|
sa.Column("bill_id", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column("bill_topic_id", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column("legislator_id", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column("year", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column("topic", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"score",
|
||||||
|
sa.DOUBLE_PRECISION(precision=53),
|
||||||
|
autoincrement=False,
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created",
|
||||||
|
postgresql.TIMESTAMP(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
autoincrement=False,
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated",
|
||||||
|
postgresql.TIMESTAMP(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
autoincrement=False,
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_id"],
|
||||||
|
[f"{schema}.bill.id"],
|
||||||
|
name=op.f("fk_legislator_bill_score_bill_id_bill"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_topic_id"],
|
||||||
|
[f"{schema}.bill_topic.id"],
|
||||||
|
name=op.f("fk_legislator_bill_score_bill_topic_id_bill_topic"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["legislator_id"],
|
||||||
|
[f"{schema}.legislator.id"],
|
||||||
|
name=op.f("fk_legislator_bill_score_legislator_id_legislator"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator_bill_score")),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"bill_topic_id",
|
||||||
|
"legislator_id",
|
||||||
|
"year",
|
||||||
|
name=op.f("uq_legislator_bill_score_bill_topic_id_legislator_id_year"),
|
||||||
|
postgresql_include=[],
|
||||||
|
postgresql_nulls_not_distinct=False,
|
||||||
|
),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_bill_score_year_topic"),
|
||||||
|
"legislator_bill_score",
|
||||||
|
["year", "topic"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_bill_score_legislator_id"),
|
||||||
|
"legislator_bill_score",
|
||||||
|
["legislator_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_bill_score_bill_topic_id"),
|
||||||
|
"legislator_bill_score",
|
||||||
|
["bill_topic_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_legislator_bill_score_bill_id"),
|
||||||
|
"legislator_bill_score",
|
||||||
|
["bill_id"],
|
||||||
|
unique=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
"""add bill_text summarization metadata.
|
||||||
|
|
||||||
|
Revision ID: 7d15f9b7c8a2
|
||||||
|
Revises: ef4bc5411176
|
||||||
|
Create Date: 2026-04-22 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from pipelines.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "7d15f9b7c8a2"
|
||||||
|
down_revision: str | None = "ef4bc5411176"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
op.add_column(
|
||||||
|
"bill_text",
|
||||||
|
sa.Column("summarization_model", sa.String(), nullable=True),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"bill_text",
|
||||||
|
sa.Column("summarization_user_prompt_version", sa.String(), nullable=True),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"bill_text",
|
||||||
|
sa.Column("summarization_system_prompt_version", sa.String(), nullable=True),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
op.drop_column(
|
||||||
|
"bill_text", "summarization_system_prompt_version", schema=schema
|
||||||
|
)
|
||||||
|
op.drop_column("bill_text", "summarization_user_prompt_version", schema=schema)
|
||||||
|
op.drop_column("bill_text", "summarization_model", schema=schema)
|
||||||
138
alembic/env.py
Normal file
138
alembic/env.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
"""Alembic."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
from alembic.script import write_hooks
|
||||||
|
from sqlalchemy.schema import CreateSchema
|
||||||
|
|
||||||
|
from pipelines.common import bash_wrapper
|
||||||
|
from pipelines.orm.common import get_postgres_engine
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import MutableMapping
|
||||||
|
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
base_class: type[DeclarativeBase] = config.attributes.get("base")
|
||||||
|
if base_class is None:
|
||||||
|
error = "No base class provided. Use the database CLI to run alembic commands."
|
||||||
|
raise RuntimeError(error)
|
||||||
|
|
||||||
|
target_metadata = base_class.metadata
|
||||||
|
logging.basicConfig(
|
||||||
|
level="DEBUG",
|
||||||
|
datefmt="%Y-%m-%dT%H:%M:%S%z",
|
||||||
|
format="%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s",
|
||||||
|
handlers=[logging.StreamHandler(sys.stdout)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@write_hooks.register("dynamic_schema")
|
||||||
|
def dynamic_schema(filename: str, _options: dict[Any, Any]) -> None:
|
||||||
|
"""Dynamic schema."""
|
||||||
|
original_file = Path(filename).read_text()
|
||||||
|
schema_name = base_class.schema_name
|
||||||
|
dynamic_schema_file_part1 = original_file.replace(
|
||||||
|
f"schema='{schema_name}'", "schema=schema"
|
||||||
|
)
|
||||||
|
dynamic_schema_file = dynamic_schema_file_part1.replace(
|
||||||
|
f"'{schema_name}.", "f'{schema}."
|
||||||
|
)
|
||||||
|
Path(filename).write_text(dynamic_schema_file)
|
||||||
|
|
||||||
|
|
||||||
|
@write_hooks.register("import_postgresql")
|
||||||
|
def import_postgresql(filename: str, _options: dict[Any, Any]) -> None:
|
||||||
|
"""Add postgresql dialect import when postgresql types are used."""
|
||||||
|
content = Path(filename).read_text()
|
||||||
|
if (
|
||||||
|
"postgresql." in content
|
||||||
|
and "from sqlalchemy.dialects import postgresql" not in content
|
||||||
|
):
|
||||||
|
content = content.replace(
|
||||||
|
"import sqlalchemy as sa\n",
|
||||||
|
"import sqlalchemy as sa\nfrom sqlalchemy.dialects import postgresql\n",
|
||||||
|
)
|
||||||
|
Path(filename).write_text(content)
|
||||||
|
|
||||||
|
|
||||||
|
@write_hooks.register("ruff")
|
||||||
|
def ruff_check_and_format(filename: str, _options: dict[Any, Any]) -> None:
|
||||||
|
"""Docstring for ruff_check_and_format."""
|
||||||
|
bash_wrapper(f"ruff check --fix {filename}")
|
||||||
|
bash_wrapper(f"ruff format {filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def include_name(
|
||||||
|
name: str | None,
|
||||||
|
type_: Literal[
|
||||||
|
"schema",
|
||||||
|
"table",
|
||||||
|
"column",
|
||||||
|
"index",
|
||||||
|
"unique_constraint",
|
||||||
|
"foreign_key_constraint",
|
||||||
|
],
|
||||||
|
_parent_names: MutableMapping[
|
||||||
|
Literal["schema_name", "table_name", "schema_qualified_table_name"], str | None
|
||||||
|
],
|
||||||
|
) -> bool:
|
||||||
|
"""Filter tables to be included in the migration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the table.
|
||||||
|
type_ (str): The type of the table.
|
||||||
|
_parent_names (MutableMapping): The names of the parent tables.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the table should be included, False otherwise.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if type_ == "schema":
|
||||||
|
# allows a database with multiple schemas to have separate alembic revisions
|
||||||
|
return name == target_metadata.schema
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
"""Run migrations in 'online' mode.
|
||||||
|
|
||||||
|
In this scenario we need to create an Engine
|
||||||
|
and associate a connection with the context.
|
||||||
|
|
||||||
|
"""
|
||||||
|
env_prefix = config.attributes.get("env_prefix", "POSTGRES")
|
||||||
|
connectable = get_postgres_engine(name=env_prefix)
|
||||||
|
|
||||||
|
with connectable.connect() as connection:
|
||||||
|
schema = base_class.schema_name
|
||||||
|
if not connectable.dialect.has_schema(connection, schema):
|
||||||
|
answer = input(f"Schema {schema!r} does not exist. Create it? [y/N] ")
|
||||||
|
if answer.lower() != "y":
|
||||||
|
error = f"Schema {schema!r} does not exist. Exiting."
|
||||||
|
raise SystemExit(error)
|
||||||
|
connection.execute(CreateSchema(schema))
|
||||||
|
connection.commit()
|
||||||
|
|
||||||
|
context.configure(
|
||||||
|
connection=connection,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
include_schemas=True,
|
||||||
|
version_table_schema=schema,
|
||||||
|
include_name=include_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
connection.commit()
|
||||||
|
|
||||||
|
|
||||||
|
run_migrations_online()
|
||||||
36
alembic/script.py.mako
Normal file
36
alembic/script.py.mako
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""${message}.
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from pipelines.orm import ${config.attributes["base"].__name__}
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: str | None = ${repr(down_revision)}
|
||||||
|
branch_labels: str | Sequence[str] | None = ${repr(branch_labels)}
|
||||||
|
depends_on: str | Sequence[str] | None = ${repr(depends_on)}
|
||||||
|
|
||||||
|
schema=${config.attributes["base"].__name__}.schema_name
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
123
database_cli.py
Normal file
123
database_cli.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
"""CLI wrapper around alembic for multi-database support.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
database <db_name> <command> [args...]
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
database van_inventory upgrade head
|
||||||
|
database van_inventory downgrade head-1
|
||||||
|
database van_inventory revision --autogenerate -m "add meals table"
|
||||||
|
database van_inventory check
|
||||||
|
database richie check
|
||||||
|
database richie upgrade head
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from importlib import import_module
|
||||||
|
from typing import TYPE_CHECKING, Annotated
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from alembic.config import CommandLine, Config
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DatabaseConfig:
|
||||||
|
"""Configuration for a database."""
|
||||||
|
|
||||||
|
env_prefix: str
|
||||||
|
version_location: str
|
||||||
|
base_module: str
|
||||||
|
base_class_name: str
|
||||||
|
models_module: str
|
||||||
|
script_location: str = "alembic"
|
||||||
|
file_template: str = "%%(year)d_%%(month).2d_%%(day).2d-%%(slug)s_%%(rev)s"
|
||||||
|
|
||||||
|
def get_base(self) -> type[DeclarativeBase]:
|
||||||
|
"""Import and return the Base class."""
|
||||||
|
module = import_module(self.base_module)
|
||||||
|
return getattr(module, self.base_class_name)
|
||||||
|
|
||||||
|
def import_models(self) -> None:
|
||||||
|
"""Import ORM models so alembic autogenerate can detect them."""
|
||||||
|
import_module(self.models_module)
|
||||||
|
|
||||||
|
def alembic_config(self) -> Config:
|
||||||
|
"""Build an alembic Config for this database."""
|
||||||
|
# Runtime import needed — Config is in TYPE_CHECKING for the return type annotation
|
||||||
|
from alembic.config import Config as AlembicConfig # noqa: PLC0415
|
||||||
|
|
||||||
|
cfg = AlembicConfig()
|
||||||
|
cfg.set_main_option("script_location", self.script_location)
|
||||||
|
cfg.set_main_option("file_template", self.file_template)
|
||||||
|
cfg.set_main_option("prepend_sys_path", ".")
|
||||||
|
cfg.set_main_option("version_path_separator", "os")
|
||||||
|
cfg.set_main_option("version_locations", self.version_location)
|
||||||
|
cfg.set_main_option("revision_environment", "true")
|
||||||
|
cfg.set_section_option(
|
||||||
|
"post_write_hooks", "hooks", "dynamic_schema,import_postgresql,ruff"
|
||||||
|
)
|
||||||
|
cfg.set_section_option(
|
||||||
|
"post_write_hooks", "dynamic_schema.type", "dynamic_schema"
|
||||||
|
)
|
||||||
|
cfg.set_section_option(
|
||||||
|
"post_write_hooks", "import_postgresql.type", "import_postgresql"
|
||||||
|
)
|
||||||
|
cfg.set_section_option("post_write_hooks", "ruff.type", "ruff")
|
||||||
|
cfg.attributes["base"] = self.get_base()
|
||||||
|
cfg.attributes["env_prefix"] = self.env_prefix
|
||||||
|
self.import_models()
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
DATABASES: dict[str, DatabaseConfig] = {
|
||||||
|
"data_science_dev": DatabaseConfig(
|
||||||
|
env_prefix="DATA_SCIENCE_DEV",
|
||||||
|
version_location="alembic/data_science_dev/versions",
|
||||||
|
base_module="pipelines.orm.data_science_dev.base",
|
||||||
|
base_class_name="DataScienceDevBase",
|
||||||
|
models_module="pipelines.orm.data_science_dev.models",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
app = typer.Typer(help="Multi-database alembic wrapper.")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command(
|
||||||
|
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
||||||
|
)
|
||||||
|
def main(
|
||||||
|
ctx: typer.Context,
|
||||||
|
db_name: Annotated[
|
||||||
|
str, typer.Argument(help=f"Database name. Options: {', '.join(DATABASES)}")
|
||||||
|
],
|
||||||
|
command: Annotated[
|
||||||
|
str,
|
||||||
|
typer.Argument(
|
||||||
|
help="Alembic command (upgrade, downgrade, revision, check, etc.)"
|
||||||
|
),
|
||||||
|
],
|
||||||
|
) -> None:
|
||||||
|
"""Run an alembic command against the specified database."""
|
||||||
|
db_config = DATABASES.get(db_name)
|
||||||
|
if not db_config:
|
||||||
|
typer.echo(
|
||||||
|
f"Unknown database: {db_name!r}. Available: {', '.join(DATABASES)}",
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
alembic_cfg = db_config.alembic_config()
|
||||||
|
|
||||||
|
cmd_line = CommandLine()
|
||||||
|
options = cmd_line.parser.parse_args([command, *ctx.args])
|
||||||
|
cmd_line.run_cmd(alembic_cfg, options)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
1
pipelines/__init__.py
Normal file
1
pipelines/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Init."""
|
||||||
72
pipelines/common.py
Normal file
72
pipelines/common.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""common."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from os import getenv
|
||||||
|
from subprocess import PIPE, Popen
|
||||||
|
|
||||||
|
from apprise import Apprise
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def configure_logger(level: str = "INFO") -> None:
|
||||||
|
"""Configure the logger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level (str, optional): The logging level. Defaults to "INFO".
|
||||||
|
"""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=level,
|
||||||
|
datefmt="%Y-%m-%dT%H:%M:%S%z",
|
||||||
|
format="%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s",
|
||||||
|
handlers=[logging.StreamHandler(sys.stdout)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def bash_wrapper(command: str) -> tuple[str, int]:
|
||||||
|
"""Execute a bash command and capture the output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command (str): The bash command to be executed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[str, int]: A tuple containing the output of the command (stdout) as a string,
|
||||||
|
the error output (stderr) as a string (optional), and the return code as an integer.
|
||||||
|
"""
|
||||||
|
# This is a acceptable risk
|
||||||
|
process = Popen(command.split(), stdout=PIPE, stderr=PIPE)
|
||||||
|
output, error = process.communicate()
|
||||||
|
if error:
|
||||||
|
logger.error(f"{error=}")
|
||||||
|
return error.decode(), process.returncode
|
||||||
|
|
||||||
|
return output.decode(), process.returncode
|
||||||
|
|
||||||
|
|
||||||
|
def signal_alert(body: str, title: str = "") -> None:
|
||||||
|
"""Send a signal alert.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
body (str): The body of the alert.
|
||||||
|
title (str, optional): The title of the alert. Defaults to "".
|
||||||
|
"""
|
||||||
|
apprise_client = Apprise()
|
||||||
|
|
||||||
|
from_phone = getenv("SIGNAL_ALERT_FROM_PHONE")
|
||||||
|
to_phone = getenv("SIGNAL_ALERT_TO_PHONE")
|
||||||
|
if not from_phone or not to_phone:
|
||||||
|
logger.info("SIGNAL_ALERT_FROM_PHONE or SIGNAL_ALERT_TO_PHONE not set")
|
||||||
|
return
|
||||||
|
|
||||||
|
apprise_client.add(f"signal://localhost:8989/{from_phone}/{to_phone}")
|
||||||
|
|
||||||
|
apprise_client.notify(title=title, body=body)
|
||||||
|
|
||||||
|
|
||||||
|
def utcnow() -> datetime:
|
||||||
|
"""Get the current UTC time."""
|
||||||
|
return datetime.now(tz=UTC)
|
||||||
@@ -69,8 +69,9 @@ class BenchmarkConfig:
|
|||||||
|
|
||||||
|
|
||||||
def get_config_dir() -> Path:
|
def get_config_dir() -> Path:
|
||||||
"""Get the path to the config file."""
|
"""Get the path to the config directory."""
|
||||||
return Path(__file__).resolve().parent.parent.parent / "config"
|
return Path(__file__).resolve().parents[2] / "config"
|
||||||
|
|
||||||
|
|
||||||
def default_config_path() -> Path:
|
def default_config_path() -> Path:
|
||||||
"""Get the path to the config file."""
|
"""Get the path to the config file."""
|
||||||
671
pipelines/ingest_congress.py
Normal file
671
pipelines/ingest_congress.py
Normal file
@@ -0,0 +1,671 @@
|
|||||||
|
"""Ingestion pipeline for loading congress data from unitedstates/congress JSON files.
|
||||||
|
|
||||||
|
Loads legislators, bills, votes, vote records, and bill text into the data_science_dev database.
|
||||||
|
Expects the parent directory to contain congress-tracker/ and congress-legislators/ as siblings.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ingest-congress /path/to/parent/
|
||||||
|
ingest-congress /path/to/parent/ --congress 118
|
||||||
|
ingest-congress /path/to/parent/ --congress 118 --only bills
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path # noqa: TC003 needed at runtime for typer CLI argument
|
||||||
|
from typing import TYPE_CHECKING, Annotated
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
import typer
|
||||||
|
import yaml
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from pipelines.pipelines.common import configure_logger
|
||||||
|
from pipelines.orm.common import get_postgres_engine
|
||||||
|
from pipelines.orm.data_science_dev.congress import (
|
||||||
|
Bill,
|
||||||
|
BillText,
|
||||||
|
Legislator,
|
||||||
|
LegislatorSocialMedia,
|
||||||
|
Vote,
|
||||||
|
VoteRecord,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
BATCH_SIZE = 10_000
|
||||||
|
|
||||||
|
app = typer.Typer(help="Ingest unitedstates/congress data into data_science_dev.")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def main(
|
||||||
|
parent_dir: Annotated[
|
||||||
|
Path,
|
||||||
|
typer.Argument(
|
||||||
|
help="Parent directory containing congress-tracker/ and congress-legislators/"
|
||||||
|
),
|
||||||
|
],
|
||||||
|
congress: Annotated[
|
||||||
|
int | None, typer.Option(help="Only ingest a specific congress number")
|
||||||
|
] = None,
|
||||||
|
only: Annotated[
|
||||||
|
str | None,
|
||||||
|
typer.Option(
|
||||||
|
help="Only run a specific step: legislators, social-media, bills, votes, bill-text"
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Ingest congress data from unitedstates/congress JSON files."""
|
||||||
|
configure_logger(level="INFO")
|
||||||
|
|
||||||
|
data_dir = parent_dir / "congress-tracker/congress/data/"
|
||||||
|
legislators_dir = parent_dir / "congress-legislators"
|
||||||
|
|
||||||
|
if not data_dir.is_dir():
|
||||||
|
typer.echo(f"Expected congress-tracker/ directory: {data_dir}", err=True)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
if not legislators_dir.is_dir():
|
||||||
|
typer.echo(
|
||||||
|
f"Expected congress-legislators/ directory: {legislators_dir}", err=True
|
||||||
|
)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||||
|
|
||||||
|
congress_dirs = _resolve_congress_dirs(data_dir, congress)
|
||||||
|
if not congress_dirs:
|
||||||
|
typer.echo("No congress directories found.", err=True)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
logger.info("Found %d congress directories to process", len(congress_dirs))
|
||||||
|
|
||||||
|
steps: dict[str, tuple] = {
|
||||||
|
"legislators": (ingest_legislators, (engine, legislators_dir)),
|
||||||
|
"legislators-social-media": (ingest_social_media, (engine, legislators_dir)),
|
||||||
|
"bills": (ingest_bills, (engine, congress_dirs)),
|
||||||
|
"votes": (ingest_votes, (engine, congress_dirs)),
|
||||||
|
"bill-text": (ingest_bill_text, (engine, congress_dirs)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if only:
|
||||||
|
if only not in steps:
|
||||||
|
typer.echo(
|
||||||
|
f"Unknown step: {only}. Choose from: {', '.join(steps)}", err=True
|
||||||
|
)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
steps = {only: steps[only]}
|
||||||
|
|
||||||
|
for step_name, (step_func, step_args) in steps.items():
|
||||||
|
logger.info("=== Starting step: %s ===", step_name)
|
||||||
|
step_func(*step_args)
|
||||||
|
logger.info("=== Finished step: %s ===", step_name)
|
||||||
|
|
||||||
|
logger.info("ingest-congress done")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_congress_dirs(data_dir: Path, congress: int | None) -> list[Path]:
|
||||||
|
"""Find congress number directories under data_dir."""
|
||||||
|
if congress is not None:
|
||||||
|
target = data_dir / str(congress)
|
||||||
|
return [target] if target.is_dir() else []
|
||||||
|
return sorted(
|
||||||
|
path for path in data_dir.iterdir() if path.is_dir() and path.name.isdigit()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _flush_batch(session: Session, batch: list[object], label: str) -> int:
|
||||||
|
"""Add a batch of ORM objects to the session and commit. Returns count added."""
|
||||||
|
if not batch:
|
||||||
|
return 0
|
||||||
|
session.add_all(batch)
|
||||||
|
session.commit()
|
||||||
|
count = len(batch)
|
||||||
|
logger.info("Committed %d %s", count, label)
|
||||||
|
batch.clear()
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Legislators — loaded from congress-legislators YAML files
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_legislators(engine: Engine, legislators_dir: Path) -> None:
|
||||||
|
"""Load legislators from congress-legislators YAML files."""
|
||||||
|
legislators_data = _load_legislators_yaml(legislators_dir)
|
||||||
|
logger.info("Loaded %d legislators from YAML files", len(legislators_data))
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
existing_legislators = {
|
||||||
|
legislator.bioguide_id: legislator
|
||||||
|
for legislator in session.scalars(select(Legislator)).all()
|
||||||
|
}
|
||||||
|
logger.info("Found %d existing legislators in DB", len(existing_legislators))
|
||||||
|
|
||||||
|
total_inserted = 0
|
||||||
|
total_updated = 0
|
||||||
|
for entry in legislators_data:
|
||||||
|
bioguide_id = entry.get("id", {}).get("bioguide")
|
||||||
|
if not bioguide_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
fields = _parse_legislator(entry)
|
||||||
|
if existing := existing_legislators.get(bioguide_id):
|
||||||
|
changed = False
|
||||||
|
for field, value in fields.items():
|
||||||
|
if value is not None and getattr(existing, field) != value:
|
||||||
|
setattr(existing, field, value)
|
||||||
|
changed = True
|
||||||
|
if changed:
|
||||||
|
total_updated += 1
|
||||||
|
else:
|
||||||
|
session.add(Legislator(bioguide_id=bioguide_id, **fields))
|
||||||
|
total_inserted += 1
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
logger.info(
|
||||||
|
"Inserted %d new legislators, updated %d existing",
|
||||||
|
total_inserted,
|
||||||
|
total_updated,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_legislators_yaml(legislators_dir: Path) -> list[dict]:
|
||||||
|
"""Load and combine legislators-current.yaml and legislators-historical.yaml."""
|
||||||
|
legislators: list[dict] = []
|
||||||
|
for filename in ("legislators-current.yaml", "legislators-historical.yaml"):
|
||||||
|
path = legislators_dir / filename
|
||||||
|
if not path.exists():
|
||||||
|
logger.warning("Legislators file not found: %s", path)
|
||||||
|
continue
|
||||||
|
with path.open() as file:
|
||||||
|
data = yaml.safe_load(file)
|
||||||
|
if isinstance(data, list):
|
||||||
|
legislators.extend(data)
|
||||||
|
return legislators
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_legislator(entry: dict) -> dict:
|
||||||
|
"""Extract Legislator fields from a congress-legislators YAML entry."""
|
||||||
|
ids = entry.get("id", {})
|
||||||
|
name = entry.get("name", {})
|
||||||
|
bio = entry.get("bio", {})
|
||||||
|
terms = entry.get("terms", [])
|
||||||
|
latest_term = terms[-1] if terms else {}
|
||||||
|
|
||||||
|
fec_ids = ids.get("fec")
|
||||||
|
fec_ids_joined = ",".join(fec_ids) if isinstance(fec_ids, list) else fec_ids
|
||||||
|
|
||||||
|
chamber = latest_term.get("type")
|
||||||
|
chamber_normalized = {"rep": "House", "sen": "Senate"}.get(chamber, chamber)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"thomas_id": ids.get("thomas"),
|
||||||
|
"lis_id": ids.get("lis"),
|
||||||
|
"govtrack_id": ids.get("govtrack"),
|
||||||
|
"opensecrets_id": ids.get("opensecrets"),
|
||||||
|
"fec_ids": fec_ids_joined,
|
||||||
|
"first_name": name.get("first"),
|
||||||
|
"last_name": name.get("last"),
|
||||||
|
"official_full_name": name.get("official_full"),
|
||||||
|
"nickname": name.get("nickname"),
|
||||||
|
"birthday": bio.get("birthday"),
|
||||||
|
"gender": bio.get("gender"),
|
||||||
|
"current_party": latest_term.get("party"),
|
||||||
|
"current_state": latest_term.get("state"),
|
||||||
|
"current_district": latest_term.get("district"),
|
||||||
|
"current_chamber": chamber_normalized,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Social Media — loaded from legislators-social-media.yaml
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
SOCIAL_MEDIA_PLATFORMS = {
|
||||||
|
"twitter": "https://twitter.com/{account}",
|
||||||
|
"facebook": "https://facebook.com/{account}",
|
||||||
|
"youtube": "https://youtube.com/{account}",
|
||||||
|
"instagram": "https://instagram.com/{account}",
|
||||||
|
"mastodon": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_social_media(engine: Engine, legislators_dir: Path) -> None:
|
||||||
|
"""Load social media accounts from legislators-social-media.yaml."""
|
||||||
|
social_media_path = legislators_dir / "legislators-social-media.yaml"
|
||||||
|
if not social_media_path.exists():
|
||||||
|
logger.warning("Social media file not found: %s", social_media_path)
|
||||||
|
return
|
||||||
|
|
||||||
|
with social_media_path.open() as file:
|
||||||
|
social_media_data = yaml.safe_load(file)
|
||||||
|
|
||||||
|
if not isinstance(social_media_data, list):
|
||||||
|
logger.warning("Unexpected format in %s", social_media_path)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Loaded %d entries from legislators-social-media.yaml", len(social_media_data)
|
||||||
|
)
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
legislator_map = _build_legislator_map(session)
|
||||||
|
existing_accounts = {
|
||||||
|
(account.legislator_id, account.platform)
|
||||||
|
for account in session.scalars(select(LegislatorSocialMedia)).all()
|
||||||
|
}
|
||||||
|
logger.info(
|
||||||
|
"Found %d existing social media accounts in DB", len(existing_accounts)
|
||||||
|
)
|
||||||
|
|
||||||
|
total_inserted = 0
|
||||||
|
total_updated = 0
|
||||||
|
for entry in social_media_data:
|
||||||
|
bioguide_id = entry.get("id", {}).get("bioguide")
|
||||||
|
if not bioguide_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
legislator_id = legislator_map.get(bioguide_id)
|
||||||
|
if legislator_id is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
social = entry.get("social", {})
|
||||||
|
for platform, url_template in SOCIAL_MEDIA_PLATFORMS.items():
|
||||||
|
account_name = social.get(platform)
|
||||||
|
if not account_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
url = (
|
||||||
|
url_template.format(account=account_name) if url_template else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if (legislator_id, platform) in existing_accounts:
|
||||||
|
total_updated += 1
|
||||||
|
else:
|
||||||
|
session.add(
|
||||||
|
LegislatorSocialMedia(
|
||||||
|
legislator_id=legislator_id,
|
||||||
|
platform=platform,
|
||||||
|
account_name=str(account_name),
|
||||||
|
url=url,
|
||||||
|
source="https://github.com/unitedstates/congress-legislators",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
existing_accounts.add((legislator_id, platform))
|
||||||
|
total_inserted += 1
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
logger.info(
|
||||||
|
"Inserted %d new social media accounts, updated %d existing",
|
||||||
|
total_inserted,
|
||||||
|
total_updated,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _iter_voters(position_group: object) -> Iterator[dict]:
|
||||||
|
"""Yield voter dicts from a vote position group (handles list, single dict, or string)."""
|
||||||
|
if isinstance(position_group, dict):
|
||||||
|
yield position_group
|
||||||
|
elif isinstance(position_group, list):
|
||||||
|
for voter in position_group:
|
||||||
|
if isinstance(voter, dict):
|
||||||
|
yield voter
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Bills
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_bills(engine: Engine, congress_dirs: list[Path]) -> None:
|
||||||
|
"""Load bill data.json files."""
|
||||||
|
with Session(engine) as session:
|
||||||
|
existing_bills = {
|
||||||
|
(bill.congress, bill.bill_type, bill.number)
|
||||||
|
for bill in session.scalars(select(Bill)).all()
|
||||||
|
}
|
||||||
|
logger.info("Found %d existing bills in DB", len(existing_bills))
|
||||||
|
|
||||||
|
total_inserted = 0
|
||||||
|
batch: list[Bill] = []
|
||||||
|
for congress_dir in congress_dirs:
|
||||||
|
bills_dir = congress_dir / "bills"
|
||||||
|
if not bills_dir.is_dir():
|
||||||
|
continue
|
||||||
|
logger.info("Scanning bills from %s", congress_dir.name)
|
||||||
|
for bill_file in bills_dir.rglob("data.json"):
|
||||||
|
data = _read_json(bill_file)
|
||||||
|
if data is None:
|
||||||
|
continue
|
||||||
|
bill = _parse_bill(data, existing_bills)
|
||||||
|
if bill is not None:
|
||||||
|
batch.append(bill)
|
||||||
|
if len(batch) >= BATCH_SIZE:
|
||||||
|
total_inserted += _flush_batch(session, batch, "bills")
|
||||||
|
|
||||||
|
total_inserted += _flush_batch(session, batch, "bills")
|
||||||
|
logger.info("Inserted %d new bills total", total_inserted)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_bill(data: dict, existing_bills: set[tuple[int, str, int]]) -> Bill | None:
|
||||||
|
"""Parse a bill data.json dict into a Bill ORM object, skipping existing."""
|
||||||
|
raw_congress = data.get("congress")
|
||||||
|
bill_type = data.get("bill_type")
|
||||||
|
raw_number = data.get("number")
|
||||||
|
if raw_congress is None or bill_type is None or raw_number is None:
|
||||||
|
return None
|
||||||
|
congress = int(raw_congress)
|
||||||
|
number = int(raw_number)
|
||||||
|
if (congress, bill_type, number) in existing_bills:
|
||||||
|
return None
|
||||||
|
|
||||||
|
sponsor_bioguide = None
|
||||||
|
sponsor = data.get("sponsor")
|
||||||
|
if sponsor:
|
||||||
|
sponsor_bioguide = sponsor.get("bioguide_id")
|
||||||
|
|
||||||
|
return Bill(
|
||||||
|
congress=congress,
|
||||||
|
bill_type=bill_type,
|
||||||
|
number=number,
|
||||||
|
title=data.get("short_title") or data.get("official_title"),
|
||||||
|
title_short=data.get("short_title"),
|
||||||
|
official_title=data.get("official_title"),
|
||||||
|
status=data.get("status"),
|
||||||
|
status_at=data.get("status_at"),
|
||||||
|
sponsor_bioguide_id=sponsor_bioguide,
|
||||||
|
subjects_top_term=data.get("subjects_top_term"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Votes (and vote records)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_votes(engine: Engine, congress_dirs: list[Path]) -> None:
|
||||||
|
"""Load vote data.json files with their vote records."""
|
||||||
|
with Session(engine) as session:
|
||||||
|
legislator_map = _build_legislator_map(session)
|
||||||
|
logger.info("Loaded %d legislators into lookup map", len(legislator_map))
|
||||||
|
bill_map = _build_bill_map(session)
|
||||||
|
logger.info("Loaded %d bills into lookup map", len(bill_map))
|
||||||
|
existing_votes = {
|
||||||
|
(vote.congress, vote.chamber, vote.session, vote.number)
|
||||||
|
for vote in session.scalars(select(Vote)).all()
|
||||||
|
}
|
||||||
|
logger.info("Found %d existing votes in DB", len(existing_votes))
|
||||||
|
|
||||||
|
total_inserted = 0
|
||||||
|
batch: list[Vote] = []
|
||||||
|
for congress_dir in congress_dirs:
|
||||||
|
votes_dir = congress_dir / "votes"
|
||||||
|
if not votes_dir.is_dir():
|
||||||
|
continue
|
||||||
|
logger.info("Scanning votes from %s", congress_dir.name)
|
||||||
|
for vote_file in votes_dir.rglob("data.json"):
|
||||||
|
data = _read_json(vote_file)
|
||||||
|
if data is None:
|
||||||
|
continue
|
||||||
|
vote = _parse_vote(data, legislator_map, bill_map, existing_votes)
|
||||||
|
if vote is not None:
|
||||||
|
batch.append(vote)
|
||||||
|
if len(batch) >= BATCH_SIZE:
|
||||||
|
total_inserted += _flush_batch(session, batch, "votes")
|
||||||
|
|
||||||
|
total_inserted += _flush_batch(session, batch, "votes")
|
||||||
|
logger.info("Inserted %d new votes total", total_inserted)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_legislator_map(session: Session) -> dict[str, int]:
|
||||||
|
"""Build a mapping of bioguide_id -> legislator.id."""
|
||||||
|
return {
|
||||||
|
legislator.bioguide_id: legislator.id
|
||||||
|
for legislator in session.scalars(select(Legislator)).all()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_bill_map(session: Session) -> dict[tuple[int, str, int], int]:
|
||||||
|
"""Build a mapping of (congress, bill_type, number) -> bill.id."""
|
||||||
|
return {
|
||||||
|
(bill.congress, bill.bill_type, bill.number): bill.id
|
||||||
|
for bill in session.scalars(select(Bill)).all()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_vote(
|
||||||
|
data: dict,
|
||||||
|
legislator_map: dict[str, int],
|
||||||
|
bill_map: dict[tuple[int, str, int], int],
|
||||||
|
existing_votes: set[tuple[int, str, int, int]],
|
||||||
|
) -> Vote | None:
|
||||||
|
"""Parse a vote data.json dict into a Vote ORM object with records."""
|
||||||
|
raw_congress = data.get("congress")
|
||||||
|
chamber = data.get("chamber")
|
||||||
|
raw_number = data.get("number")
|
||||||
|
vote_date = data.get("date")
|
||||||
|
if (
|
||||||
|
raw_congress is None
|
||||||
|
or chamber is None
|
||||||
|
or raw_number is None
|
||||||
|
or vote_date is None
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
raw_session = data.get("session")
|
||||||
|
if raw_session is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
congress = int(raw_congress)
|
||||||
|
number = int(raw_number)
|
||||||
|
session_number = int(raw_session)
|
||||||
|
|
||||||
|
# Normalize chamber from "h"/"s" to "House"/"Senate"
|
||||||
|
chamber_normalized = {"h": "House", "s": "Senate"}.get(chamber, chamber)
|
||||||
|
|
||||||
|
if (congress, chamber_normalized, session_number, number) in existing_votes:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Resolve linked bill
|
||||||
|
bill_id = None
|
||||||
|
bill_ref = data.get("bill")
|
||||||
|
if bill_ref:
|
||||||
|
bill_key = (
|
||||||
|
int(bill_ref.get("congress", congress)),
|
||||||
|
bill_ref.get("type"),
|
||||||
|
int(bill_ref.get("number", 0)),
|
||||||
|
)
|
||||||
|
bill_id = bill_map.get(bill_key)
|
||||||
|
|
||||||
|
raw_votes = data.get("votes", {})
|
||||||
|
vote_counts = _count_votes(raw_votes)
|
||||||
|
vote_records = _build_vote_records(raw_votes, legislator_map)
|
||||||
|
|
||||||
|
return Vote(
|
||||||
|
congress=congress,
|
||||||
|
chamber=chamber_normalized,
|
||||||
|
session=session_number,
|
||||||
|
number=number,
|
||||||
|
vote_type=data.get("type"),
|
||||||
|
question=data.get("question"),
|
||||||
|
result=data.get("result"),
|
||||||
|
result_text=data.get("result_text"),
|
||||||
|
vote_date=vote_date[:10] if isinstance(vote_date, str) else vote_date,
|
||||||
|
bill_id=bill_id,
|
||||||
|
vote_records=vote_records,
|
||||||
|
**vote_counts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_votes(raw_votes: dict) -> dict[str, int]:
|
||||||
|
"""Count voters per position category, correctly handling dict and list formats."""
|
||||||
|
yea_count = 0
|
||||||
|
nay_count = 0
|
||||||
|
not_voting_count = 0
|
||||||
|
present_count = 0
|
||||||
|
|
||||||
|
for position, position_group in raw_votes.items():
|
||||||
|
voter_count = sum(1 for _ in _iter_voters(position_group))
|
||||||
|
if position in ("Yea", "Aye"):
|
||||||
|
yea_count += voter_count
|
||||||
|
elif position in ("Nay", "No"):
|
||||||
|
nay_count += voter_count
|
||||||
|
elif position == "Not Voting":
|
||||||
|
not_voting_count += voter_count
|
||||||
|
elif position == "Present":
|
||||||
|
present_count += voter_count
|
||||||
|
|
||||||
|
return {
|
||||||
|
"yea_count": yea_count,
|
||||||
|
"nay_count": nay_count,
|
||||||
|
"not_voting_count": not_voting_count,
|
||||||
|
"present_count": present_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_vote_records(
|
||||||
|
raw_votes: dict, legislator_map: dict[str, int]
|
||||||
|
) -> list[VoteRecord]:
|
||||||
|
"""Build VoteRecord objects from raw vote data."""
|
||||||
|
records: list[VoteRecord] = []
|
||||||
|
for position, position_group in raw_votes.items():
|
||||||
|
for voter in _iter_voters(position_group):
|
||||||
|
bioguide_id = voter.get("id")
|
||||||
|
if not bioguide_id:
|
||||||
|
continue
|
||||||
|
legislator_id = legislator_map.get(bioguide_id)
|
||||||
|
if legislator_id is None:
|
||||||
|
continue
|
||||||
|
records.append(
|
||||||
|
VoteRecord(
|
||||||
|
legislator_id=legislator_id,
|
||||||
|
position=position,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return records
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Bill Text
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_bill_text(engine: Engine, congress_dirs: list[Path]) -> None:
|
||||||
|
"""Load bill text from text-versions directories."""
|
||||||
|
with Session(engine) as session:
|
||||||
|
bill_map = _build_bill_map(session)
|
||||||
|
logger.info("Loaded %d bills into lookup map", len(bill_map))
|
||||||
|
existing_bill_texts = {
|
||||||
|
(bill_text.bill_id, bill_text.version_code)
|
||||||
|
for bill_text in session.scalars(select(BillText)).all()
|
||||||
|
}
|
||||||
|
logger.info(
|
||||||
|
"Found %d existing bill text versions in DB", len(existing_bill_texts)
|
||||||
|
)
|
||||||
|
|
||||||
|
total_inserted = 0
|
||||||
|
batch: list[BillText] = []
|
||||||
|
for congress_dir in congress_dirs:
|
||||||
|
logger.info("Scanning bill texts from %s", congress_dir.name)
|
||||||
|
for bill_text in _iter_bill_texts(
|
||||||
|
congress_dir, bill_map, existing_bill_texts
|
||||||
|
):
|
||||||
|
batch.append(bill_text)
|
||||||
|
if len(batch) >= BATCH_SIZE:
|
||||||
|
total_inserted += _flush_batch(session, batch, "bill texts")
|
||||||
|
|
||||||
|
total_inserted += _flush_batch(session, batch, "bill texts")
|
||||||
|
logger.info("Inserted %d new bill text versions total", total_inserted)
|
||||||
|
|
||||||
|
|
||||||
|
def _iter_bill_texts(
|
||||||
|
congress_dir: Path,
|
||||||
|
bill_map: dict[tuple[int, str, int], int],
|
||||||
|
existing_bill_texts: set[tuple[int, str]],
|
||||||
|
) -> Iterator[BillText]:
|
||||||
|
"""Yield BillText objects for a single congress directory, skipping existing."""
|
||||||
|
bills_dir = congress_dir / "bills"
|
||||||
|
if not bills_dir.is_dir():
|
||||||
|
return
|
||||||
|
|
||||||
|
for bill_dir in bills_dir.rglob("text-versions"):
|
||||||
|
if not bill_dir.is_dir():
|
||||||
|
continue
|
||||||
|
bill_key = _bill_key_from_dir(bill_dir.parent, congress_dir)
|
||||||
|
if bill_key is None:
|
||||||
|
continue
|
||||||
|
bill_id = bill_map.get(bill_key)
|
||||||
|
if bill_id is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for version_dir in sorted(bill_dir.iterdir()):
|
||||||
|
if not version_dir.is_dir():
|
||||||
|
continue
|
||||||
|
if (bill_id, version_dir.name) in existing_bill_texts:
|
||||||
|
continue
|
||||||
|
text_content = _read_bill_text(version_dir)
|
||||||
|
version_data = _read_json(version_dir / "data.json")
|
||||||
|
yield BillText(
|
||||||
|
bill_id=bill_id,
|
||||||
|
version_code=version_dir.name,
|
||||||
|
version_name=version_data.get("version_name") if version_data else None,
|
||||||
|
date=version_data.get("issued_on") if version_data else None,
|
||||||
|
text_content=text_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _bill_key_from_dir(
|
||||||
|
bill_dir: Path, congress_dir: Path
|
||||||
|
) -> tuple[int, str, int] | None:
|
||||||
|
"""Extract (congress, bill_type, number) from directory structure."""
|
||||||
|
congress = int(congress_dir.name)
|
||||||
|
bill_type = bill_dir.parent.name
|
||||||
|
name = bill_dir.name
|
||||||
|
# Directory name is like "hr3590" — strip the type prefix to get the number
|
||||||
|
number_str = name[len(bill_type) :]
|
||||||
|
if not number_str.isdigit():
|
||||||
|
return None
|
||||||
|
return (congress, bill_type, int(number_str))
|
||||||
|
|
||||||
|
|
||||||
|
def _read_bill_text(version_dir: Path) -> str | None:
|
||||||
|
"""Read bill text from a version directory, preferring .txt over .xml."""
|
||||||
|
for extension in ("txt", "htm", "html", "xml"):
|
||||||
|
candidates = list(version_dir.glob(f"document.{extension}"))
|
||||||
|
if not candidates:
|
||||||
|
candidates = list(version_dir.glob(f"*.{extension}"))
|
||||||
|
if candidates:
|
||||||
|
try:
|
||||||
|
return candidates[0].read_text(encoding="utf-8")
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to read %s", candidates[0])
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _read_json(path: Path) -> dict | None:
|
||||||
|
"""Read and parse a JSON file, returning None on failure."""
|
||||||
|
try:
|
||||||
|
return orjson.loads(path.read_bytes())
|
||||||
|
except FileNotFoundError:
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to parse %s", path)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
281
pipelines/ingest_posts.py
Normal file
281
pipelines/ingest_posts.py
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
"""Ingestion pipeline for loading JSONL post files into the weekly-partitioned posts table.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ingest-posts /path/to/files/
|
||||||
|
ingest-posts /path/to/single_file.jsonl
|
||||||
|
ingest-posts /data/dir/ --workers 4 --batch-size 5000
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from pathlib import Path # noqa: TC003 this is needed for typer
|
||||||
|
from typing import TYPE_CHECKING, Annotated
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
import psycopg
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from pipelines.pipelines.common import configure_logger
|
||||||
|
from pipelines.orm.common import get_connection_info
|
||||||
|
from pipelines.pipelines.parallelize import parallelize_process
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
app = typer.Typer(help="Ingest JSONL post files into the partitioned posts table.")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def main(
|
||||||
|
path: Annotated[
|
||||||
|
Path,
|
||||||
|
typer.Argument(help="Directory containing JSONL files, or a single JSONL file"),
|
||||||
|
],
|
||||||
|
batch_size: Annotated[int, typer.Option(help="Rows per INSERT batch")] = 10000,
|
||||||
|
workers: Annotated[
|
||||||
|
int, typer.Option(help="Parallel workers for multi-file ingestion")
|
||||||
|
] = 4,
|
||||||
|
pattern: Annotated[
|
||||||
|
str, typer.Option(help="Glob pattern for JSONL files")
|
||||||
|
] = "*.jsonl",
|
||||||
|
) -> None:
|
||||||
|
"""Ingest JSONL post files into the weekly-partitioned posts table."""
|
||||||
|
configure_logger(level="INFO")
|
||||||
|
|
||||||
|
logger.info("starting ingest-posts")
|
||||||
|
logger.info(
|
||||||
|
"path=%s batch_size=%d workers=%d pattern=%s",
|
||||||
|
path,
|
||||||
|
batch_size,
|
||||||
|
workers,
|
||||||
|
pattern,
|
||||||
|
)
|
||||||
|
if path.is_file():
|
||||||
|
ingest_file(path, batch_size=batch_size)
|
||||||
|
elif path.is_dir():
|
||||||
|
ingest_directory(
|
||||||
|
path, batch_size=batch_size, max_workers=workers, pattern=pattern
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
typer.echo(f"Path does not exist: {path}", err=True)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
logger.info("ingest-posts done")
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_directory(
|
||||||
|
directory: Path,
|
||||||
|
*,
|
||||||
|
batch_size: int,
|
||||||
|
max_workers: int,
|
||||||
|
pattern: str = "*.jsonl",
|
||||||
|
) -> None:
|
||||||
|
"""Ingest all JSONL files in a directory using parallel workers."""
|
||||||
|
files = sorted(directory.glob(pattern))
|
||||||
|
if not files:
|
||||||
|
logger.warning("No JSONL files found in %s", directory)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Found %d JSONL files to ingest", len(files))
|
||||||
|
|
||||||
|
kwargs_list = [{"path": fp, "batch_size": batch_size} for fp in files]
|
||||||
|
parallelize_process(ingest_file, kwargs_list, max_workers=max_workers)
|
||||||
|
|
||||||
|
|
||||||
|
SCHEMA = "main"
|
||||||
|
|
||||||
|
COLUMNS = (
|
||||||
|
"post_id",
|
||||||
|
"user_id",
|
||||||
|
"instance",
|
||||||
|
"date",
|
||||||
|
"text",
|
||||||
|
"langs",
|
||||||
|
"like_count",
|
||||||
|
"reply_count",
|
||||||
|
"repost_count",
|
||||||
|
"reply_to",
|
||||||
|
"replied_author",
|
||||||
|
"thread_root",
|
||||||
|
"thread_root_author",
|
||||||
|
"repost_from",
|
||||||
|
"reposted_author",
|
||||||
|
"quotes",
|
||||||
|
"quoted_author",
|
||||||
|
"labels",
|
||||||
|
"sent_label",
|
||||||
|
"sent_score",
|
||||||
|
)
|
||||||
|
|
||||||
|
INSERT_FROM_STAGING = f"""
|
||||||
|
INSERT INTO {SCHEMA}.posts ({", ".join(COLUMNS)})
|
||||||
|
SELECT {", ".join(COLUMNS)} FROM pg_temp.staging
|
||||||
|
ON CONFLICT (post_id, date) DO NOTHING
|
||||||
|
""" # noqa: S608
|
||||||
|
|
||||||
|
FAILED_INSERT = f"""
|
||||||
|
INSERT INTO {SCHEMA}.failed_ingestion (raw_line, error)
|
||||||
|
VALUES (%(raw_line)s, %(error)s)
|
||||||
|
""" # noqa: S608
|
||||||
|
|
||||||
|
|
||||||
|
def get_psycopg_connection() -> psycopg.Connection:
|
||||||
|
"""Create a raw psycopg3 connection from environment variables."""
|
||||||
|
database, host, port, username, password = get_connection_info("DATA_SCIENCE_DEV")
|
||||||
|
return psycopg.connect(
|
||||||
|
dbname=database,
|
||||||
|
host=host,
|
||||||
|
port=int(port),
|
||||||
|
user=username,
|
||||||
|
password=password,
|
||||||
|
autocommit=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_file(path: Path, *, batch_size: int) -> None:
|
||||||
|
"""Ingest a single JSONL file into the posts table."""
|
||||||
|
log_trigger = max(100_000 // batch_size, 1)
|
||||||
|
failed_lines: list[dict] = []
|
||||||
|
try:
|
||||||
|
with get_psycopg_connection() as connection:
|
||||||
|
for index, batch in enumerate(
|
||||||
|
read_jsonl_batches(path, batch_size, failed_lines), 1
|
||||||
|
):
|
||||||
|
ingest_batch(connection, batch)
|
||||||
|
if index % log_trigger == 0:
|
||||||
|
logger.info(
|
||||||
|
"Ingested %d batches (%d rows) from %s",
|
||||||
|
index,
|
||||||
|
index * batch_size,
|
||||||
|
path,
|
||||||
|
)
|
||||||
|
|
||||||
|
if failed_lines:
|
||||||
|
logger.warning(
|
||||||
|
"Recording %d malformed lines from %s", len(failed_lines), path.name
|
||||||
|
)
|
||||||
|
with connection.cursor() as cursor:
|
||||||
|
cursor.executemany(FAILED_INSERT, failed_lines)
|
||||||
|
connection.commit()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to ingest file: %s", path)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_batch(connection: psycopg.Connection, batch: list[dict]) -> None:
|
||||||
|
"""COPY batch into a temp staging table, then INSERT ... ON CONFLICT into posts."""
|
||||||
|
if not batch:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
with connection.cursor() as cursor:
|
||||||
|
cursor.execute(f"""
|
||||||
|
CREATE TEMP TABLE IF NOT EXISTS staging
|
||||||
|
(LIKE {SCHEMA}.posts INCLUDING DEFAULTS)
|
||||||
|
ON COMMIT DELETE ROWS
|
||||||
|
""")
|
||||||
|
cursor.execute("TRUNCATE pg_temp.staging")
|
||||||
|
|
||||||
|
with cursor.copy(
|
||||||
|
f"COPY pg_temp.staging ({', '.join(COLUMNS)}) FROM STDIN"
|
||||||
|
) as copy:
|
||||||
|
for row in batch:
|
||||||
|
copy.write_row(tuple(row.get(column) for column in COLUMNS))
|
||||||
|
|
||||||
|
cursor.execute(INSERT_FROM_STAGING)
|
||||||
|
connection.commit()
|
||||||
|
except Exception as error:
|
||||||
|
connection.rollback()
|
||||||
|
|
||||||
|
if len(batch) == 1:
|
||||||
|
logger.exception("Skipping bad row post_id=%s", batch[0].get("post_id"))
|
||||||
|
with connection.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
FAILED_INSERT,
|
||||||
|
{
|
||||||
|
"raw_line": orjson.dumps(batch[0], default=str).decode(),
|
||||||
|
"error": str(error),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
connection.commit()
|
||||||
|
return
|
||||||
|
|
||||||
|
midpoint = len(batch) // 2
|
||||||
|
ingest_batch(connection, batch[:midpoint])
|
||||||
|
ingest_batch(connection, batch[midpoint:])
|
||||||
|
|
||||||
|
|
||||||
|
def read_jsonl_batches(
|
||||||
|
file_path: Path, batch_size: int, failed_lines: list[dict]
|
||||||
|
) -> Iterator[list[dict]]:
|
||||||
|
"""Stream a JSONL file and yield batches of transformed rows."""
|
||||||
|
batch: list[dict] = []
|
||||||
|
with file_path.open("r", encoding="utf-8") as handle:
|
||||||
|
for raw_line in handle:
|
||||||
|
line = raw_line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
batch.extend(parse_line(line, file_path, failed_lines))
|
||||||
|
if len(batch) >= batch_size:
|
||||||
|
yield batch
|
||||||
|
batch = []
|
||||||
|
if batch:
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
|
def parse_line(line: str, file_path: Path, failed_lines: list[dict]) -> Iterator[dict]:
|
||||||
|
"""Parse a JSONL line, handling concatenated JSON objects."""
|
||||||
|
try:
|
||||||
|
yield transform_row(orjson.loads(line))
|
||||||
|
except orjson.JSONDecodeError:
|
||||||
|
if "}{" not in line:
|
||||||
|
logger.warning(
|
||||||
|
"Skipping malformed line in %s: %s", file_path.name, line[:120]
|
||||||
|
)
|
||||||
|
failed_lines.append({"raw_line": line, "error": "malformed JSON"})
|
||||||
|
return
|
||||||
|
fragments = line.replace("}{", "}\n{").split("\n")
|
||||||
|
for fragment in fragments:
|
||||||
|
try:
|
||||||
|
yield transform_row(orjson.loads(fragment))
|
||||||
|
except (orjson.JSONDecodeError, KeyError, ValueError) as error:
|
||||||
|
logger.warning(
|
||||||
|
"Skipping malformed fragment in %s: %s",
|
||||||
|
file_path.name,
|
||||||
|
fragment[:120],
|
||||||
|
)
|
||||||
|
failed_lines.append({"raw_line": fragment, "error": str(error)})
|
||||||
|
except Exception as error:
|
||||||
|
logger.exception("Skipping bad row in %s: %s", file_path.name, line[:120])
|
||||||
|
failed_lines.append({"raw_line": line, "error": str(error)})
|
||||||
|
|
||||||
|
|
||||||
|
def transform_row(raw: dict) -> dict:
|
||||||
|
"""Transform a raw JSONL row into a dict matching the Posts table columns."""
|
||||||
|
raw["date"] = parse_date(raw["date"])
|
||||||
|
if raw.get("langs") is not None:
|
||||||
|
raw["langs"] = orjson.dumps(raw["langs"])
|
||||||
|
if raw.get("text") is not None:
|
||||||
|
raw["text"] = raw["text"].replace("\x00", "")
|
||||||
|
return raw
|
||||||
|
|
||||||
|
|
||||||
|
def parse_date(raw_date: int) -> datetime:
|
||||||
|
"""Parse compact YYYYMMDDHHmm integer into a naive datetime (input is UTC by spec)."""
|
||||||
|
return datetime(
|
||||||
|
raw_date // 100000000,
|
||||||
|
(raw_date // 1000000) % 100,
|
||||||
|
(raw_date // 10000) % 100,
|
||||||
|
(raw_date // 100) % 100,
|
||||||
|
raw_date % 100,
|
||||||
|
tzinfo=UTC,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
@@ -17,6 +17,10 @@ NAMING_CONVENTION = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseSetupError(RuntimeError):
|
||||||
|
"""Raised when database configuration is missing or invalid."""
|
||||||
|
|
||||||
|
|
||||||
def get_connection_info(name: str) -> tuple[str, str, str, str, str | None]:
|
def get_connection_info(name: str) -> tuple[str, str, str, str, str | None]:
|
||||||
"""Get connection info from environment variables."""
|
"""Get connection info from environment variables."""
|
||||||
database = getenv(f"{name}_DB")
|
database = getenv(f"{name}_DB")
|
||||||
@@ -27,11 +31,18 @@ def get_connection_info(name: str) -> tuple[str, str, str, str, str | None]:
|
|||||||
|
|
||||||
if None in (database, host, port, username):
|
if None in (database, host, port, username):
|
||||||
error = f"Missing environment variables for Postgres connection.\n{database=}\n{host=}\n{port=}\n{username=}\n"
|
error = f"Missing environment variables for Postgres connection.\n{database=}\n{host=}\n{port=}\n{username=}\n"
|
||||||
raise ValueError(error)
|
raise DatabaseSetupError(error)
|
||||||
return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password))
|
return cast(
|
||||||
|
"tuple[str, str, str, str, str | None]",
|
||||||
|
(database, host, port, username, password),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_postgres_engine(*, name: str = "POSTGRES", pool_pre_ping: bool = True) -> Engine:
|
def get_postgres_engine(
|
||||||
|
*,
|
||||||
|
name: str = "POSTGRES",
|
||||||
|
pool_pre_ping: bool = True,
|
||||||
|
) -> Engine:
|
||||||
"""Create a SQLAlchemy engine from environment variables."""
|
"""Create a SQLAlchemy engine from environment variables."""
|
||||||
database, host, port, username, password = get_connection_info(name)
|
database, host, port, username, password = get_connection_info(name)
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,14 @@
|
|||||||
"""init."""
|
"""Congress ORM models."""
|
||||||
|
|
||||||
from pipelines.orm.data_science_dev.congress.bill import Bill, BillText
|
from pipelines.orm.data_science_dev.congress.bill import (
|
||||||
|
Bill,
|
||||||
|
BillText,
|
||||||
|
BillTopic,
|
||||||
|
BillTopicPosition,
|
||||||
|
)
|
||||||
from pipelines.orm.data_science_dev.congress.legislator import (
|
from pipelines.orm.data_science_dev.congress.legislator import (
|
||||||
Legislator,
|
Legislator,
|
||||||
|
LegislatorScore,
|
||||||
LegislatorSocialMedia,
|
LegislatorSocialMedia,
|
||||||
)
|
)
|
||||||
from pipelines.orm.data_science_dev.congress.vote import Vote, VoteRecord
|
from pipelines.orm.data_science_dev.congress.vote import Vote, VoteRecord
|
||||||
@@ -10,7 +16,10 @@ from pipelines.orm.data_science_dev.congress.vote import Vote, VoteRecord
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"Bill",
|
"Bill",
|
||||||
"BillText",
|
"BillText",
|
||||||
|
"BillTopic",
|
||||||
|
"BillTopicPosition",
|
||||||
"Legislator",
|
"Legislator",
|
||||||
|
"LegislatorScore",
|
||||||
"LegislatorSocialMedia",
|
"LegislatorSocialMedia",
|
||||||
"Vote",
|
"Vote",
|
||||||
"VoteRecord",
|
"VoteRecord",
|
||||||
|
|||||||
@@ -2,10 +2,11 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import date
|
from datetime import date, datetime
|
||||||
|
from enum import StrEnum
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy import ForeignKey, Index, UniqueConstraint
|
from sqlalchemy import DateTime, Enum, ForeignKey, Index, UniqueConstraint
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
|
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
|
||||||
@@ -14,10 +15,23 @@ if TYPE_CHECKING:
|
|||||||
from pipelines.orm.data_science_dev.congress.vote import Vote
|
from pipelines.orm.data_science_dev.congress.vote import Vote
|
||||||
|
|
||||||
|
|
||||||
|
class BillTopicPosition(StrEnum):
|
||||||
|
"""Whether a yes vote on a bill is for or against a topic."""
|
||||||
|
|
||||||
|
FOR = "for"
|
||||||
|
AGAINST = "against"
|
||||||
|
|
||||||
|
|
||||||
class Bill(DataScienceDevTableBase):
|
class Bill(DataScienceDevTableBase):
|
||||||
"""Legislation with congress number, type, titles, status, and sponsor."""
|
"""Legislation with congress number, type, titles, status, and sponsor."""
|
||||||
|
|
||||||
__tablename__ = "bill"
|
__tablename__ = "bill"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"congress", "bill_type", "number", name="uq_bill_congress_type_number"
|
||||||
|
),
|
||||||
|
Index("ix_bill_congress", "congress"),
|
||||||
|
)
|
||||||
|
|
||||||
congress: Mapped[int]
|
congress: Mapped[int]
|
||||||
bill_type: Mapped[str]
|
bill_type: Mapped[str]
|
||||||
@@ -33,6 +47,7 @@ class Bill(DataScienceDevTableBase):
|
|||||||
sponsor_bioguide_id: Mapped[str | None]
|
sponsor_bioguide_id: Mapped[str | None]
|
||||||
|
|
||||||
subjects_top_term: Mapped[str | None]
|
subjects_top_term: Mapped[str | None]
|
||||||
|
score_processed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||||
|
|
||||||
votes: Mapped[list[Vote]] = relationship(
|
votes: Mapped[list[Vote]] = relationship(
|
||||||
"Vote",
|
"Vote",
|
||||||
@@ -43,12 +58,10 @@ class Bill(DataScienceDevTableBase):
|
|||||||
back_populates="bill",
|
back_populates="bill",
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
topics: Mapped[list[BillTopic]] = relationship(
|
||||||
__table_args__ = (
|
"BillTopic",
|
||||||
UniqueConstraint(
|
back_populates="bill",
|
||||||
"congress", "bill_type", "number", name="uq_bill_congress_type_number"
|
cascade="all, delete-orphan",
|
||||||
),
|
|
||||||
Index("ix_bill_congress", "congress"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -56,17 +69,49 @@ class BillText(DataScienceDevTableBase):
|
|||||||
"""Stores different text versions of a bill (introduced, enrolled, etc.)."""
|
"""Stores different text versions of a bill (introduced, enrolled, etc.)."""
|
||||||
|
|
||||||
__tablename__ = "bill_text"
|
__tablename__ = "bill_text"
|
||||||
|
|
||||||
bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
|
|
||||||
version_code: Mapped[str]
|
|
||||||
version_name: Mapped[str | None]
|
|
||||||
text_content: Mapped[str | None]
|
|
||||||
date: Mapped[date | None]
|
|
||||||
|
|
||||||
bill: Mapped[Bill] = relationship("Bill", back_populates="bill_texts")
|
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
UniqueConstraint(
|
UniqueConstraint(
|
||||||
"bill_id", "version_code", name="uq_bill_text_bill_id_version_code"
|
"bill_id", "version_code", name="uq_bill_text_bill_id_version_code"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
|
||||||
|
version_code: Mapped[str]
|
||||||
|
version_name: Mapped[str | None]
|
||||||
|
text_content: Mapped[str | None]
|
||||||
|
summary: Mapped[str | None]
|
||||||
|
summarization_model: Mapped[str | None]
|
||||||
|
summarization_user_prompt_version: Mapped[str | None]
|
||||||
|
summarization_system_prompt_version: Mapped[str | None]
|
||||||
|
date: Mapped[date | None]
|
||||||
|
|
||||||
|
bill: Mapped[Bill] = relationship("Bill", back_populates="bill_texts")
|
||||||
|
|
||||||
|
# suport multipu summary prer bill
|
||||||
|
|
||||||
|
class BillTopic(DataScienceDevTableBase):
|
||||||
|
"""One bill stance on one topic used to score roll-call votes."""
|
||||||
|
|
||||||
|
__tablename__ = "bill_topic"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"bill_id",
|
||||||
|
"topic",
|
||||||
|
"support_position",
|
||||||
|
name="uq_bill_topic_bill_id_topic_support_position",
|
||||||
|
),
|
||||||
|
Index("ix_bill_topic_topic", "topic"),
|
||||||
|
)
|
||||||
|
|
||||||
|
bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
|
||||||
|
topic: Mapped[str]
|
||||||
|
support_position: Mapped[BillTopicPosition] = mapped_column(
|
||||||
|
Enum(
|
||||||
|
BillTopicPosition,
|
||||||
|
values_callable=lambda enum_cls: [member.value for member in enum_cls],
|
||||||
|
native_enum=False,
|
||||||
|
name="bill_topic_position",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
bill: Mapped[Bill] = relationship("Bill", back_populates="topics")
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
from datetime import date
|
from datetime import date
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy import ForeignKey, Text
|
from sqlalchemy import ForeignKey, Index, Text, UniqueConstraint
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
|
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
|
||||||
@@ -50,6 +50,11 @@ class Legislator(DataScienceDevTableBase):
|
|||||||
back_populates="legislator",
|
back_populates="legislator",
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
scores: Mapped[list[LegislatorScore]] = relationship(
|
||||||
|
"LegislatorScore",
|
||||||
|
back_populates="legislator",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LegislatorSocialMedia(DataScienceDevTableBase):
|
class LegislatorSocialMedia(DataScienceDevTableBase):
|
||||||
@@ -66,3 +71,28 @@ class LegislatorSocialMedia(DataScienceDevTableBase):
|
|||||||
legislator: Mapped[Legislator] = relationship(
|
legislator: Mapped[Legislator] = relationship(
|
||||||
back_populates="social_media_accounts"
|
back_populates="social_media_accounts"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LegislatorScore(DataScienceDevTableBase):
|
||||||
|
"""Computed topic score for a legislator in one calendar year."""
|
||||||
|
|
||||||
|
__tablename__ = "legislator_score"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"legislator_id",
|
||||||
|
"year",
|
||||||
|
"topic",
|
||||||
|
name="uq_legislator_score_legislator_id_year_topic",
|
||||||
|
),
|
||||||
|
Index("ix_legislator_score_year_topic", "year", "topic"),
|
||||||
|
)
|
||||||
|
|
||||||
|
legislator_id: Mapped[int] = mapped_column(
|
||||||
|
ForeignKey("main.legislator.id", ondelete="CASCADE"),
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
year: Mapped[int]
|
||||||
|
topic: Mapped[str]
|
||||||
|
score: Mapped[float]
|
||||||
|
|
||||||
|
legislator: Mapped[Legislator] = relationship(back_populates="scores")
|
||||||
|
|||||||
@@ -44,6 +44,17 @@ class Vote(DataScienceDevTableBase):
|
|||||||
"""Roll call votes with counts and optional bill linkage."""
|
"""Roll call votes with counts and optional bill linkage."""
|
||||||
|
|
||||||
__tablename__ = "vote"
|
__tablename__ = "vote"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"congress",
|
||||||
|
"chamber",
|
||||||
|
"session",
|
||||||
|
"number",
|
||||||
|
name="uq_vote_congress_chamber_session_number",
|
||||||
|
),
|
||||||
|
Index("ix_vote_date", "vote_date"),
|
||||||
|
Index("ix_vote_congress_chamber", "congress", "chamber"),
|
||||||
|
)
|
||||||
|
|
||||||
congress: Mapped[int]
|
congress: Mapped[int]
|
||||||
chamber: Mapped[str]
|
chamber: Mapped[str]
|
||||||
@@ -71,14 +82,3 @@ class Vote(DataScienceDevTableBase):
|
|||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
UniqueConstraint(
|
|
||||||
"congress",
|
|
||||||
"chamber",
|
|
||||||
"session",
|
|
||||||
"number",
|
|
||||||
name="uq_vote_congress_chamber_session_number",
|
|
||||||
),
|
|
||||||
Index("ix_vote_date", "vote_date"),
|
|
||||||
Index("ix_vote_congress_chamber", "congress", "chamber"),
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -2,14 +2,26 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pipelines.orm.data_science_dev.congress import Bill, BillText, Legislator, Vote, VoteRecord
|
from pipelines.orm.data_science_dev.congress import (
|
||||||
|
Bill,
|
||||||
|
BillText,
|
||||||
|
BillTopic,
|
||||||
|
BillTopicPosition,
|
||||||
|
Legislator,
|
||||||
|
LegislatorScore,
|
||||||
|
Vote,
|
||||||
|
VoteRecord,
|
||||||
|
)
|
||||||
from pipelines.orm.data_science_dev.posts import partitions # noqa: F401 — registers partition classes in metadata
|
from pipelines.orm.data_science_dev.posts import partitions # noqa: F401 — registers partition classes in metadata
|
||||||
from pipelines.orm.data_science_dev.posts.tables import Posts
|
from pipelines.orm.data_science_dev.posts.tables import Posts
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Bill",
|
"Bill",
|
||||||
"BillText",
|
"BillText",
|
||||||
|
"BillTopic",
|
||||||
|
"BillTopicPosition",
|
||||||
"Legislator",
|
"Legislator",
|
||||||
|
"LegislatorScore",
|
||||||
"Posts",
|
"Posts",
|
||||||
"Vote",
|
"Vote",
|
||||||
"VoteRecord",
|
"VoteRecord",
|
||||||
|
|||||||
155
pipelines/parallelize.py
Normal file
155
pipelines/parallelize.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
"""Thing."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal, TypeVar
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Callable, Mapping, Sequence
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
modes = Literal["normal", "early_error"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExecutorResults[R]:
|
||||||
|
"""Dataclass to store the results and exceptions of the parallel execution."""
|
||||||
|
|
||||||
|
results: list[R]
|
||||||
|
exceptions: list[BaseException]
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""Return a string representation of the object."""
|
||||||
|
return f"results={self.results} exceptions={self.exceptions}"
|
||||||
|
|
||||||
|
|
||||||
|
def _parallelize_base[R](
|
||||||
|
executor_type: type[ThreadPoolExecutor | ProcessPoolExecutor],
|
||||||
|
func: Callable[..., R],
|
||||||
|
kwargs_list: Sequence[Mapping[str, Any]],
|
||||||
|
max_workers: int | None,
|
||||||
|
progress_tracker: int | None,
|
||||||
|
mode: modes,
|
||||||
|
) -> ExecutorResults:
|
||||||
|
total_work = len(kwargs_list)
|
||||||
|
|
||||||
|
with executor_type(max_workers=max_workers) as executor:
|
||||||
|
futures = [executor.submit(func, **kwarg) for kwarg in kwargs_list]
|
||||||
|
|
||||||
|
results = []
|
||||||
|
exceptions = []
|
||||||
|
for index, future in enumerate(futures, 1):
|
||||||
|
if exception := future.exception():
|
||||||
|
logger.error(f"{future} raised {exception.__class__.__name__}")
|
||||||
|
exceptions.append(exception)
|
||||||
|
if mode == "early_error":
|
||||||
|
executor.shutdown(wait=False)
|
||||||
|
raise exception
|
||||||
|
continue
|
||||||
|
|
||||||
|
results.append(future.result())
|
||||||
|
|
||||||
|
if progress_tracker and index % progress_tracker == 0:
|
||||||
|
logger.info(f"Progress: {index}/{total_work}")
|
||||||
|
|
||||||
|
return ExecutorResults(results, exceptions)
|
||||||
|
|
||||||
|
|
||||||
|
def parallelize_thread[R](
|
||||||
|
func: Callable[..., R],
|
||||||
|
kwargs_list: Sequence[Mapping[str, Any]],
|
||||||
|
max_workers: int | None = None,
|
||||||
|
progress_tracker: int | None = None,
|
||||||
|
mode: modes = "normal",
|
||||||
|
) -> ExecutorResults:
|
||||||
|
"""Generic function to run a function with multiple arguments in threads.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (Callable[..., R]): Function to run in threads.
|
||||||
|
kwargs_list (Sequence[Mapping[str, Any]]): List of dictionaries with the arguments for the function.
|
||||||
|
max_workers (int, optional): Number of workers to use. Defaults to 8.
|
||||||
|
progress_tracker (int, optional): Number of tasks to complete before logging progress.
|
||||||
|
mode (modes, optional): Mode to use. Defaults to "normal".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[list[R], list[Exception]]: List with the results and a list with the exceptions.
|
||||||
|
"""
|
||||||
|
return _parallelize_base(
|
||||||
|
executor_type=ThreadPoolExecutor,
|
||||||
|
func=func,
|
||||||
|
kwargs_list=kwargs_list,
|
||||||
|
max_workers=max_workers,
|
||||||
|
progress_tracker=progress_tracker,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parallelize_process[R](
|
||||||
|
func: Callable[..., R],
|
||||||
|
kwargs_list: Sequence[Mapping[str, Any]],
|
||||||
|
max_workers: int | None = None,
|
||||||
|
progress_tracker: int | None = None,
|
||||||
|
mode: modes = "normal",
|
||||||
|
) -> ExecutorResults:
|
||||||
|
"""Generic function to run a function with multiple arguments in process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (Callable[..., R]): Function to run in process.
|
||||||
|
kwargs_list (Sequence[Mapping[str, Any]]): List of dictionaries with the arguments for the function.
|
||||||
|
max_workers (int, optional): Number of workers to use. Defaults to 4.
|
||||||
|
progress_tracker (int, optional): Number of tasks to complete before logging progress.
|
||||||
|
mode (modes, optional): Mode to use. Defaults to "normal".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[list[R], list[Exception]]: List with the results and a list with the exceptions.
|
||||||
|
"""
|
||||||
|
if max_workers and max_workers > cpu_count():
|
||||||
|
error = f"max_workers must be less than or equal to {cpu_count()}"
|
||||||
|
raise RuntimeError(error)
|
||||||
|
|
||||||
|
return process_executor_unchecked(
|
||||||
|
func=func,
|
||||||
|
kwargs_list=kwargs_list,
|
||||||
|
max_workers=max_workers,
|
||||||
|
progress_tracker=progress_tracker,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def process_executor_unchecked[R](
|
||||||
|
func: Callable[..., R],
|
||||||
|
kwargs_list: Sequence[Mapping[str, Any]],
|
||||||
|
max_workers: int | None,
|
||||||
|
progress_tracker: int | None,
|
||||||
|
mode: modes = "normal",
|
||||||
|
) -> ExecutorResults:
|
||||||
|
"""Generic function to run a function with multiple arguments in parallel.
|
||||||
|
|
||||||
|
Note: this function does not check if the number of workers is greater than the number of CPUs.
|
||||||
|
This can cause the system to become unresponsive.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (Callable[..., R]): Function to run in parallel.
|
||||||
|
kwargs_list (Sequence[Mapping[str, Any]]): List of dictionaries with the arguments for the function.
|
||||||
|
max_workers (int, optional): Number of workers to use. Defaults to 8.
|
||||||
|
progress_tracker (int, optional): Number of tasks to complete before logging progress.
|
||||||
|
mode (modes, optional): Mode to use. Defaults to "normal".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[list[R], list[Exception]]: List with the results and a list with the exceptions.
|
||||||
|
"""
|
||||||
|
return _parallelize_base(
|
||||||
|
executor_type=ProcessPoolExecutor,
|
||||||
|
func=func,
|
||||||
|
kwargs_list=kwargs_list,
|
||||||
|
max_workers=max_workers,
|
||||||
|
progress_tracker=progress_tracker,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
@@ -23,14 +23,10 @@ import httpx
|
|||||||
import typer
|
import typer
|
||||||
from tiktoken import Encoding, get_encoding
|
from tiktoken import Encoding, get_encoding
|
||||||
|
|
||||||
|
from pipelines.config import get_config_dir
|
||||||
from pipelines.tools.bill_token_compression import compress_bill_text
|
from pipelines.tools.bill_token_compression import compress_bill_text
|
||||||
|
|
||||||
_PROMPTS_PATH = (
|
_PROMPTS_PATH = get_config_dir() / "prompts" / "summarization_prompts.toml"
|
||||||
Path(__file__).resolve().parents[2]
|
|
||||||
/ "config"
|
|
||||||
/ "prompts"
|
|
||||||
/ "summarization_prompts.toml"
|
|
||||||
)
|
|
||||||
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
||||||
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
||||||
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
||||||
|
|||||||
@@ -24,14 +24,10 @@ from typing import Annotated
|
|||||||
import httpx
|
import httpx
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
|
from pipelines.config import get_config_dir
|
||||||
from pipelines.tools.bill_token_compression import compress_bill_text
|
from pipelines.tools.bill_token_compression import compress_bill_text
|
||||||
|
|
||||||
_PROMPTS_PATH = (
|
_PROMPTS_PATH = get_config_dir() / "prompts" / "summarization_prompts.toml"
|
||||||
Path(__file__).resolve().parents[2]
|
|
||||||
/ "config"
|
|
||||||
/ "prompts"
|
|
||||||
/ "summarization_prompts.toml"
|
|
||||||
)
|
|
||||||
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
||||||
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
||||||
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ from datasets import Dataset
|
|||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import SFTTrainer
|
from trl import SFTTrainer
|
||||||
|
|
||||||
|
from pipelines.config import default_config_path
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -123,7 +125,7 @@ def main(
|
|||||||
config_path: Annotated[
|
config_path: Annotated[
|
||||||
Path,
|
Path,
|
||||||
typer.Option("--config", help="TOML config file"),
|
typer.Option("--config", help="TOML config file"),
|
||||||
] = Path(__file__).parent / "config.toml",
|
] = default_config_path(),
|
||||||
save_gguf: Annotated[
|
save_gguf: Annotated[
|
||||||
bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")
|
bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")
|
||||||
] = False,
|
] = False,
|
||||||
|
|||||||
@@ -1,34 +0,0 @@
|
|||||||
SUMMARIZATION_SYSTEM_PROMPT = """You are a legislative analyst extracting policy substance from Congressional bill text.
|
|
||||||
|
|
||||||
Your job is to compress a bill into a dense, neutral structured summary that captures every distinct policy action — including secondary effects that might be buried in subsections.
|
|
||||||
|
|
||||||
EXTRACTION RULES:
|
|
||||||
- IGNORE: whereas clauses, congressional findings that are purely political statements, recitals, preambles, citations of existing law by number alone, and procedural boilerplate.
|
|
||||||
- FOCUS ON: operative verbs — what the bill SHALL do, PROHIBIT, REQUIRE, AUTHORIZE, AMEND, APPROPRIATE, or ESTABLISH.
|
|
||||||
- SURFACE ALL THREADS: If the bill touches multiple policy areas, list each thread separately. Do not collapse them.
|
|
||||||
- BE CONCRETE: Name the affected population, the mechanism, and the direction (expands/restricts/maintains).
|
|
||||||
- STAY NEUTRAL: No political framing. Describe what the text does, not what its sponsors claim it does.
|
|
||||||
|
|
||||||
OUTPUT FORMAT — plain structured text, not JSON:
|
|
||||||
|
|
||||||
OPERATIVE ACTIONS:
|
|
||||||
[Numbered list of what the bill actually does, one action per line, max 20 words each]
|
|
||||||
|
|
||||||
AFFECTED POPULATIONS:
|
|
||||||
[Who gains something, who loses something, or whose behavior is regulated]
|
|
||||||
|
|
||||||
MECHANISMS:
|
|
||||||
[How it works: new funding, mandate, prohibition, amendment to existing statute, grant program, study commission, etc.]
|
|
||||||
|
|
||||||
POLICY THREADS:
|
|
||||||
[List each distinct policy domain this bill touches, even minor ones. Use plain language, not domain codes.]
|
|
||||||
|
|
||||||
SYMBOLIC/PROCEDURAL ONLY:
|
|
||||||
[Yes or No — is this bill primarily a resolution, designation, or awareness declaration with no operative effect?]
|
|
||||||
|
|
||||||
LENGTH TARGET: 150-250 words total. Be ruthless about cutting. Density over completeness."""
|
|
||||||
|
|
||||||
SUMMARIZATION_USER_TEMPLATE = """Summarize the following Congressional bill according to your instructions.
|
|
||||||
|
|
||||||
BILL TEXT:
|
|
||||||
{text_content}"""
|
|
||||||
268
pipelines/tools/summarize_bills.py
Normal file
268
pipelines/tools/summarize_bills.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
"""Summarize bill_text rows with GPT-5 and store results in the database."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import tomllib
|
||||||
|
from os import getenv
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import typer
|
||||||
|
from sqlalchemy import Select, or_, select
|
||||||
|
from sqlalchemy.orm import Session, selectinload
|
||||||
|
|
||||||
|
from pipelines.config import get_config_dir
|
||||||
|
from pipelines.orm.common import get_postgres_engine
|
||||||
|
from pipelines.orm.data_science_dev.congress import Bill, BillText
|
||||||
|
from pipelines.tools.bill_token_compression import compress_bill_text
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _find_prompts_path() -> Path:
|
||||||
|
return get_config_dir() / "prompts" / "summarization_prompts.toml"
|
||||||
|
|
||||||
|
|
||||||
|
def load_summarization_prompts(
|
||||||
|
section: str = "summarization",
|
||||||
|
) -> dict[str, str]:
|
||||||
|
return tomllib.loads(_find_prompts_path().read_text())[section]
|
||||||
|
|
||||||
|
|
||||||
|
class BillSummaryError(RuntimeError):
|
||||||
|
"""Raised when a bill summary request or response is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
def call_openai_summary(
|
||||||
|
*,
|
||||||
|
model: str,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
) -> str:
|
||||||
|
"""Call GPT and return the assistant message content."""
|
||||||
|
api_key = getenv("CLOSEDAI_TOKEN")
|
||||||
|
if not api_key:
|
||||||
|
msg = "CLOSEDAI_TOKEN is required"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
response = httpx.post(
|
||||||
|
"https://api.openai.com/v1/chat/completions",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"OpenAI-Project": OPENAI_PROJECT_ID,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
},
|
||||||
|
timeout=60,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return extract_message_content(response.json())
|
||||||
|
|
||||||
|
|
||||||
|
def build_bill_summary_messages(
|
||||||
|
*,
|
||||||
|
bill_text: BillText,
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
"""Build the GPT prompt messages for one bill text row."""
|
||||||
|
if not bill_text.text_content:
|
||||||
|
msg = f"bill_text id={bill_text.id} has no text_content"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
bill = bill_text.bill
|
||||||
|
compressed_text = compress_bill_text(bill_text.text_content)
|
||||||
|
if not compressed_text:
|
||||||
|
msg = f"bill_text id={bill_text.id} has no summarizable text_content"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
metadata = "\n".join(
|
||||||
|
(
|
||||||
|
f"Congress: {bill.congress}",
|
||||||
|
f"Bill: {bill.bill_type} {bill.number}",
|
||||||
|
f"Title: {bill.title_short or bill.title or bill.official_title or ''}",
|
||||||
|
f"Top subject term: {bill.subjects_top_term or ''}",
|
||||||
|
f"Text version: {bill_text.version_code}"
|
||||||
|
+ (f" ({bill_text.version_name})" if bill_text.version_name else ""),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
summarization_prompts = load_summarization_prompts()
|
||||||
|
user_prompt = "\n\n".join(
|
||||||
|
(
|
||||||
|
"BILL METADATA:",
|
||||||
|
metadata,
|
||||||
|
summarization_prompts["user_template"].format(text_content=compressed_text),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{"role": "system", "content": summarization_prompts["system_prompt"]},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": user_prompt,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def summarize_bill_text(
|
||||||
|
*,
|
||||||
|
model: str,
|
||||||
|
bill_text: BillText,
|
||||||
|
) -> str:
|
||||||
|
"""Generate and return a summary for one bill_text row."""
|
||||||
|
messages = build_bill_summary_messages(bill_text=bill_text)
|
||||||
|
summary = call_openai_summary(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
).strip()
|
||||||
|
if not summary:
|
||||||
|
msg = f"Model returned an empty summary for bill_text id={bill_text.id}"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
def store_bill_summary_result(
|
||||||
|
*,
|
||||||
|
bill_text: BillText,
|
||||||
|
summary: str,
|
||||||
|
model: str,
|
||||||
|
) -> None:
|
||||||
|
"""Store a generated summary and the prompt/model metadata that produced it."""
|
||||||
|
bill_text.summary = summary
|
||||||
|
bill_text.summarization_model = model
|
||||||
|
bill_text.summarization_system_prompt_version = "v1"
|
||||||
|
bill_text.summarization_user_prompt_version = "v1"
|
||||||
|
|
||||||
|
|
||||||
|
def create_select_bill_texts_for_summarization(
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: list[int] | None = None,
|
||||||
|
bill_text_ids: list[int] | None = None,
|
||||||
|
force: bool = False,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> Select:
|
||||||
|
"""Select bill_text rows that have source text and need summaries."""
|
||||||
|
stmt = (
|
||||||
|
select(BillText)
|
||||||
|
.join(Bill, Bill.id == BillText.bill_id)
|
||||||
|
.where(BillText.text_content.is_not(None), BillText.text_content != "")
|
||||||
|
.options(selectinload(BillText.bill))
|
||||||
|
.order_by(BillText.id)
|
||||||
|
)
|
||||||
|
if congress is not None:
|
||||||
|
stmt = stmt.where(Bill.congress == congress)
|
||||||
|
if bill_ids:
|
||||||
|
stmt = stmt.where(BillText.bill_id.in_(bill_ids))
|
||||||
|
if bill_text_ids:
|
||||||
|
stmt = stmt.where(BillText.id.in_(bill_text_ids))
|
||||||
|
if not force:
|
||||||
|
stmt = stmt.where(or_(BillText.summary.is_(None), BillText.summary == ""))
|
||||||
|
if limit is not None:
|
||||||
|
stmt = stmt.limit(limit)
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
|
def extract_message_content(data: dict[str, Any]) -> str:
|
||||||
|
"""Extract message content from a chat-completions response body."""
|
||||||
|
choices = data.get("choices")
|
||||||
|
if not isinstance(choices, list) or not choices:
|
||||||
|
msg = "Chat completion response did not contain choices"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
first = choices[0]
|
||||||
|
if not isinstance(first, dict):
|
||||||
|
msg = "Chat completion choice must be an object"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
message = first.get("message")
|
||||||
|
if isinstance(message, dict) and isinstance(message.get("content"), str):
|
||||||
|
return message["content"]
|
||||||
|
if isinstance(first.get("text"), str):
|
||||||
|
return first["text"]
|
||||||
|
|
||||||
|
msg = "Chat completion response did not contain message content"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
model: Annotated[str, typer.Option(help="OpenAI model id.")] = "gpt-5.4-mini",
|
||||||
|
congress: Annotated[
|
||||||
|
int | None, typer.Option(help="Only process one Congress.")
|
||||||
|
] = None,
|
||||||
|
bill_ids: Annotated[
|
||||||
|
list[int] | None,
|
||||||
|
typer.Option(
|
||||||
|
"--bill-id",
|
||||||
|
help="Only process one internal bill.id. Repeat for multiple bills.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
bill_text_ids: Annotated[
|
||||||
|
list[int] | None,
|
||||||
|
typer.Option(
|
||||||
|
"--bill-text-id",
|
||||||
|
help="Only process one internal bill_text.id. Repeat for multiple rows.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
|
||||||
|
force: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Regenerate summaries for rows that already have a summary."),
|
||||||
|
] = False,
|
||||||
|
dry_run: Annotated[
|
||||||
|
bool, typer.Option(help="Print summaries without writing them to the database.")
|
||||||
|
] = False,
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""CLI entrypoint for generating and storing bill summaries."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
if not getenv("CLOSEDAI_TOKEN"):
|
||||||
|
message = "CLOSEDAI_TOKEN is required"
|
||||||
|
raise typer.BadParameter(message)
|
||||||
|
|
||||||
|
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||||
|
with Session(engine) as session:
|
||||||
|
stmt = create_select_bill_texts_for_summarization(
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
bill_text_ids=bill_text_ids,
|
||||||
|
force=force,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
bill_texts = session.scalars(stmt).all()
|
||||||
|
logger.info("Selected %d bill_text rows for summarization", len(bill_texts))
|
||||||
|
|
||||||
|
written = 0
|
||||||
|
for index, bill_text in enumerate(bill_texts, 1):
|
||||||
|
summary = summarize_bill_text(
|
||||||
|
model=model,
|
||||||
|
bill_text=bill_text,
|
||||||
|
)
|
||||||
|
store_bill_summary_result(
|
||||||
|
bill_text=bill_text,
|
||||||
|
summary=summary,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
if index % 100 == 0:
|
||||||
|
session.commit()
|
||||||
|
written += 1
|
||||||
|
session.commit()
|
||||||
|
logger.info("Stored summary for bill_text id=%s", bill_text.id)
|
||||||
|
|
||||||
|
logger.info("Done: stored %d summaries", written)
|
||||||
|
|
||||||
|
|
||||||
|
def cli() -> None:
|
||||||
|
"""Typer entry point."""
|
||||||
|
typer.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
22
pyproject.toml
Normal file
22
pyproject.toml
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
[project]
|
||||||
|
name = "ds-testing-pipelines"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Data science pipeline tools and legislative dashboard."
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"fastapi",
|
||||||
|
"httpx",
|
||||||
|
"uvicorn[standard]",
|
||||||
|
"jinja2",
|
||||||
|
"sqlalchemy",
|
||||||
|
"psycopg",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
test = [
|
||||||
|
"pytest",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
pythonpath = ["."]
|
||||||
Reference in New Issue
Block a user