Compare commits

..

46 Commits

Author SHA1 Message Date
ac02d407eb adding math to bob 2026-04-11 19:40:26 -04:00
9a77eda471 added config.toml to git ignore 2026-04-10 22:05:15 -04:00
26105b7daa updated BenchmarkConfig to have from_toml 2026-04-10 21:55:18 -04:00
0d81f2d17b setup FinetuneConfig 2026-04-10 21:40:17 -04:00
1409e9c63e deleted train.sh 2026-04-10 20:58:26 -04:00
259e952afc added containers dir 2026-04-10 20:48:24 -04:00
4a10a80ba0 converted to summarization_prompts 2026-04-10 18:57:21 -04:00
03208a1ab2 moved renamed container.py to vllm_container.py 2026-04-10 13:16:18 -04:00
721526022b created working finetuning pipeline 2026-04-10 12:56:57 -04:00
921a397b1c added data dir for training 2026-04-10 12:51:41 -04:00
b867e809cd updated spell check 2026-04-10 12:43:24 -04:00
54eb46a63e added storage pool 2026-04-10 12:42:58 -04:00
67131e7b68 added tiktoken 2026-04-10 12:42:35 -04:00
88dae310b6 added summarization_prompts.py to store the prompts 2026-04-10 12:40:36 -04:00
24f0e8693a added tools dir for on off scripts i used 2026-04-10 12:37:14 -04:00
ced78fe516 added batch_bill_summarizer.py
batch bill  summarizer sends a batch api call to gpt
2026-04-10 12:36:39 -04:00
d281d070a3 decreased root_pool/models snapshot life 2026-04-10 08:51:03 -04:00
251da6c14a added bill_token_compression.py
tested on sample size of 100 bills matching the distribution of our data
Compression saves ~11.5% on prompt tokens; completion/reasoning are roughly equal across the two sets.
prompt	completion	reasoning	total
compressed	349,460	157,110	112,128	506,570
uncompressed	394,948	154,710	110,080	549,658
delta	−45,488	+2,400	+2,048	−43,088
2026-04-09 18:41:13 -04:00
d17c883476 created main prompt bench 2026-04-08 09:08:25 -04:00
d358f0fbec fixed sunshine.nix 2026-04-08 00:18:34 -04:00
c150fc8612 converting bob to a server 2026-04-08 00:18:17 -04:00
9c8013d69d creating prompt_bench downloader 2026-04-07 19:15:43 -04:00
af365fce9a setup sunshine.nix 2026-04-03 17:12:24 -04:00
6430049e92 updated postgres snapshot settings 2026-03-30 14:07:08 -04:00
26e4620f8f fixed systemd sandboxing 2026-03-30 14:07:08 -04:00
93fc700fa2 removed preStart step 2026-03-30 14:07:08 -04:00
8d1c1fc628 added mountpoint= to postgres zfs create 2026-03-30 14:07:08 -04:00
dda318753b improving postgres wal 2026-03-30 14:07:08 -04:00
261ff139f7 removed ds table from richie DB 2026-03-29 15:54:54 -04:00
ba8ff35109 updated ingest_congress to use congress-legislators for legislator info 2026-03-29 15:54:54 -04:00
e368402eea adding LegislatorSocialMedia 2026-03-29 15:54:54 -04:00
dd9329d218 fixed tests 2026-03-29 15:54:54 -04:00
89f6627bed converted session.execute(select to session.scalars(select 2026-03-29 15:54:54 -04:00
c5babf8bad ran treefmt 2026-03-29 15:54:54 -04:00
dae38ffd9b added ingest_congress.py 2026-03-29 15:54:54 -04:00
ca62cc36a7 adding congress data to new DS DB 2026-03-29 15:54:54 -04:00
035410f39e adding nemotron-3-nano 2026-03-29 15:54:54 -04:00
e40ab757ca making more generic exception handling 2026-03-29 15:54:54 -04:00
345ba94a59 ran ingest_posts 2026-03-29 15:54:54 -04:00
f2084206b6 adding tables for 2023 2026-03-29 15:54:54 -04:00
50e764146a added ingest_posts.py 2026-03-29 15:54:54 -04:00
ea97b5eb19 adding 2026 partitions 2026-03-29 15:54:54 -04:00
1ef2512daa adding post table 2026-03-29 15:54:54 -04:00
f9a9e5395c added media/temp for fast dir when working with data 2026-03-29 15:54:54 -04:00
d8e166a340 adding data_science_dev 2026-03-29 15:54:54 -04:00
c266ba79f4 updated snapshot_config.toml 2026-03-29 14:12:06 -04:00
68 changed files with 9374 additions and 340 deletions

4
.gitignore vendored
View File

@@ -169,3 +169,7 @@ test.*
# Frontend build output
frontend/dist/
frontend/node_modules/
# data dir for training, validation, and testing
data/
config.toml

View File

@@ -308,6 +308,7 @@
"usernamehw",
"userprefs",
"vaninventory",
"vdev",
"vfat",
"victron",
"virt",

View File

@@ -1,26 +1 @@
# dotfiles
<!-- LINE-COUNT-START -->
This repo has **20,055** lines of technical debt.
| File Type | Lines | Percentage |
|-----------|------:|-----------:|
| .py | 11,441 | 57.0% |
| .nix | 4,471 | 22.3% |
| .yaml | 1,121 | 5.6% |
| .html | 1,009 | 5.0% |
| .json | 555 | 2.8% |
| .yml | 479 | 2.4% |
| .toml | 290 | 1.4% |
| .css | 212 | 1.1% |
| .gitignore | 199 | 1.0% |
| .md | 75 | 0.4% |
| .cfg | 73 | 0.4% |
| .sh | 48 | 0.2% |
| .mako | 36 | 0.2% |
| .LICENSE | 21 | 0.1% |
| .conf | 17 | 0.1% |
| .Gemfile | 4 | 0.0% |
| .svg | 3 | 0.0% |
| .new | 1 | 0.0% |
<!-- LINE-COUNT-END -->

View File

@@ -24,7 +24,9 @@
fastapi
fastapi-cli
httpx
huggingface-hub
mypy
orjson
polars
psycopg
pydantic
@@ -40,6 +42,7 @@
sqlalchemy
tenacity
textual
tiktoken
tinytuya
typer
websockets

View File

@@ -12,6 +12,7 @@ dependencies = [
"alembic",
"apprise",
"apscheduler",
"huggingface-hub",
"httpx",
"python-multipart",
"polars",
@@ -26,6 +27,11 @@ dependencies = [
[project.scripts]
database = "python.database_cli:app"
van-inventory = "python.van_inventory.main:serve"
prompt-bench = "python.prompt_bench.main:cli"
prompt-bench-download = "python.prompt_bench.downloader:cli"
finetune = "python.prompt_bench.finetune:cli"
finetune-container = "python.prompt_bench.finetune_container:cli"
build-finetune-dataset = "python.prompt_bench.build_finetune_dataset:cli"
[dependency-groups]
dev = [
@@ -81,6 +87,11 @@ lint.ignore = [
"python/eval_warnings/**" = [
"S607", # (perm) gh and git are expected on PATH in the runner environment
]
"python/prompt_bench/**" = [
"FBT002", # (perm) typer requires boolean defaults for --flag/--no-flag options
"PLR0913", # (perm) typer CLIs naturally have many parameters
"S607", # (perm) docker and nvidia-smi are expected on PATH
]
"python/alembic/**" = [
"INP001", # (perm) this creates LSP issues for alembic
]

View File

@@ -0,0 +1,50 @@
"""adding FailedIngestion.
Revision ID: 2f43120e3ffc
Revises: f99be864fe69
Create Date: 2026-03-24 23:46:17.277897
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from python.orm import DataScienceDevBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "2f43120e3ffc"
down_revision: str | None = "f99be864fe69"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = DataScienceDevBase.schema_name
def upgrade() -> None:
    """Upgrade: create the failed_ingestion table.

    Stores raw input lines that could not be ingested alongside the error
    text, so bad records can be inspected and replayed later.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "failed_ingestion",
        sa.Column("raw_line", sa.Text(), nullable=False),
        sa.Column("error", sa.Text(), nullable=False),
        sa.Column("id", sa.Integer(), nullable=False),
        # created/updated are stamped server-side via now().
        sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
        sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_failed_ingestion")),
        schema=schema,
    )
    # ### end Alembic commands ###
def downgrade() -> None:
    """Downgrade: drop the failed_ingestion table (data is lost)."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table("failed_ingestion", schema=schema)
    # ### end Alembic commands ###

View File

@@ -0,0 +1,72 @@
"""Attach all partition tables to the posts parent table.
Alembic autogenerate creates partition tables as standalone tables but does not
emit the ALTER TABLE ... ATTACH PARTITION statements needed for PostgreSQL to
route inserts to the correct partition.
Revision ID: a1b2c3d4e5f6
Revises: 605b1794838f
Create Date: 2026-03-25 10:00:00.000000
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from alembic import op
from sqlalchemy import text
from python.orm import DataScienceDevBase
from python.orm.data_science_dev.posts.partitions import (
PARTITION_END_YEAR,
PARTITION_START_YEAR,
iso_weeks_in_year,
week_bounds,
)
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "a1b2c3d4e5f6"
down_revision: str | None = "605b1794838f"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = DataScienceDevBase.schema_name
# Lists the child tables already attached to the given parent (via the
# pg_inherits catalog) so upgrade() can skip ATTACH statements that would
# otherwise fail when the migration is re-run from a partially applied state.
ALREADY_ATTACHED_QUERY = text("""
SELECT inhrelid::regclass::text
FROM pg_inherits
WHERE inhparent = :parent::regclass
""")
def upgrade() -> None:
    """Attach all weekly partition tables to the posts parent table."""
    connection = op.get_bind()
    # Skip partitions that are already attached so the migration is re-runnable.
    already_attached = {row[0] for row in connection.execute(ALREADY_ATTACHED_QUERY, {"parent": f"{schema}.posts"})}
    for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
        for week in range(1, iso_weeks_in_year(year) + 1):
            table_name = f"posts_{year}_{week:02d}"
            qualified_name = f"{schema}.{table_name}"
            if qualified_name in already_attached:
                continue
            # Half-open [start, end) bounds for the ISO week.
            start, end = week_bounds(year, week)
            start_str = start.strftime("%Y-%m-%d %H:%M:%S")
            end_str = end.strftime("%Y-%m-%d %H:%M:%S")
            # Bounds are interpolated into the DDL string; safe here because
            # they are derived from ints/dates generated above, not user input.
            op.execute(
                f"ALTER TABLE {schema}.posts "
                f"ATTACH PARTITION {qualified_name} "
                f"FOR VALUES FROM ('{start_str}') TO ('{end_str}')"
            )
def downgrade() -> None:
    """Detach all weekly partition tables from the posts parent table."""
    # NOTE(review): unlike upgrade(), there is no already-attached check here.
    # DETACH on a table that is not attached will raise — this assumes
    # downgrade only runs from a fully attached state; confirm.
    for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
        for week in range(1, iso_weeks_in_year(year) + 1):
            table_name = f"posts_{year}_{week:02d}"
            op.execute(f"ALTER TABLE {schema}.posts DETACH PARTITION {schema}.{table_name}")

View File

@@ -0,0 +1,153 @@
"""adding congress data.
Revision ID: 83bfc8af92d8
Revises: a1b2c3d4e5f6
Create Date: 2026-03-27 10:43:02.324510
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from python.orm import DataScienceDevBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "83bfc8af92d8"
down_revision: str | None = "a1b2c3d4e5f6"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = DataScienceDevBase.schema_name
def upgrade() -> None:
    """Upgrade: create the congress data tables (bill, legislator, bill_text, vote, vote_record)."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Bills, uniquely identified by (congress, bill_type, number).
    op.create_table(
        "bill",
        sa.Column("congress", sa.Integer(), nullable=False),
        sa.Column("bill_type", sa.String(), nullable=False),
        sa.Column("number", sa.Integer(), nullable=False),
        sa.Column("title", sa.String(), nullable=True),
        sa.Column("title_short", sa.String(), nullable=True),
        sa.Column("official_title", sa.String(), nullable=True),
        sa.Column("status", sa.String(), nullable=True),
        sa.Column("status_at", sa.Date(), nullable=True),
        sa.Column("sponsor_bioguide_id", sa.String(), nullable=True),
        sa.Column("subjects_top_term", sa.String(), nullable=True),
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
        sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_bill")),
        sa.UniqueConstraint("congress", "bill_type", "number", name="uq_bill_congress_type_number"),
        schema=schema,
    )
    op.create_index("ix_bill_congress", "bill", ["congress"], unique=False, schema=schema)
    # Legislators, keyed externally by bioguide_id (unique index below).
    op.create_table(
        "legislator",
        sa.Column("bioguide_id", sa.Text(), nullable=False),
        sa.Column("thomas_id", sa.String(), nullable=True),
        sa.Column("lis_id", sa.String(), nullable=True),
        sa.Column("govtrack_id", sa.Integer(), nullable=True),
        sa.Column("opensecrets_id", sa.String(), nullable=True),
        sa.Column("fec_ids", sa.String(), nullable=True),
        sa.Column("first_name", sa.String(), nullable=False),
        sa.Column("last_name", sa.String(), nullable=False),
        sa.Column("official_full_name", sa.String(), nullable=True),
        sa.Column("nickname", sa.String(), nullable=True),
        sa.Column("birthday", sa.Date(), nullable=True),
        sa.Column("gender", sa.String(), nullable=True),
        sa.Column("current_party", sa.String(), nullable=True),
        sa.Column("current_state", sa.String(), nullable=True),
        sa.Column("current_district", sa.Integer(), nullable=True),
        sa.Column("current_chamber", sa.String(), nullable=True),
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
        sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator")),
        schema=schema,
    )
    op.create_index(op.f("ix_legislator_bioguide_id"), "legislator", ["bioguide_id"], unique=True, schema=schema)
    # Bill text versions; one row per (bill, version_code), cascaded on bill delete.
    op.create_table(
        "bill_text",
        sa.Column("bill_id", sa.Integer(), nullable=False),
        sa.Column("version_code", sa.String(), nullable=False),
        sa.Column("version_name", sa.String(), nullable=True),
        sa.Column("text_content", sa.String(), nullable=True),
        sa.Column("date", sa.Date(), nullable=True),
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
        sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
        sa.ForeignKeyConstraint(
            ["bill_id"], [f"{schema}.bill.id"], name=op.f("fk_bill_text_bill_id_bill"), ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_bill_text")),
        sa.UniqueConstraint("bill_id", "version_code", name="uq_bill_text_bill_id_version_code"),
        schema=schema,
    )
    # Roll-call votes, unique per (congress, chamber, session, number);
    # bill_id is nullable since some votes are not tied to a bill.
    op.create_table(
        "vote",
        sa.Column("congress", sa.Integer(), nullable=False),
        sa.Column("chamber", sa.String(), nullable=False),
        sa.Column("session", sa.Integer(), nullable=False),
        sa.Column("number", sa.Integer(), nullable=False),
        sa.Column("vote_type", sa.String(), nullable=True),
        sa.Column("question", sa.String(), nullable=True),
        sa.Column("result", sa.String(), nullable=True),
        sa.Column("result_text", sa.String(), nullable=True),
        sa.Column("vote_date", sa.Date(), nullable=False),
        sa.Column("yea_count", sa.Integer(), nullable=True),
        sa.Column("nay_count", sa.Integer(), nullable=True),
        sa.Column("not_voting_count", sa.Integer(), nullable=True),
        sa.Column("present_count", sa.Integer(), nullable=True),
        sa.Column("bill_id", sa.Integer(), nullable=True),
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
        sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
        sa.ForeignKeyConstraint(["bill_id"], [f"{schema}.bill.id"], name=op.f("fk_vote_bill_id_bill")),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_vote")),
        sa.UniqueConstraint("congress", "chamber", "session", "number", name="uq_vote_congress_chamber_session_number"),
        schema=schema,
    )
    op.create_index("ix_vote_congress_chamber", "vote", ["congress", "chamber"], unique=False, schema=schema)
    op.create_index("ix_vote_date", "vote", ["vote_date"], unique=False, schema=schema)
    # Per-legislator positions on a vote; association table keyed by
    # (vote_id, legislator_id), cascaded from both parents.
    op.create_table(
        "vote_record",
        sa.Column("vote_id", sa.Integer(), nullable=False),
        sa.Column("legislator_id", sa.Integer(), nullable=False),
        sa.Column("position", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(
            ["legislator_id"],
            [f"{schema}.legislator.id"],
            name=op.f("fk_vote_record_legislator_id_legislator"),
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["vote_id"], [f"{schema}.vote.id"], name=op.f("fk_vote_record_vote_id_vote"), ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("vote_id", "legislator_id", name=op.f("pk_vote_record")),
        schema=schema,
    )
    # ### end Alembic commands ###
def downgrade() -> None:
    """Downgrade: drop the congress tables (children before parents to satisfy FKs)."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table("vote_record", schema=schema)
    op.drop_index("ix_vote_date", table_name="vote", schema=schema)
    op.drop_index("ix_vote_congress_chamber", table_name="vote", schema=schema)
    op.drop_table("vote", schema=schema)
    op.drop_table("bill_text", schema=schema)
    op.drop_index(op.f("ix_legislator_bioguide_id"), table_name="legislator", schema=schema)
    op.drop_table("legislator", schema=schema)
    op.drop_index("ix_bill_congress", table_name="bill", schema=schema)
    op.drop_table("bill", schema=schema)
    # ### end Alembic commands ###

View File

@@ -0,0 +1,58 @@
"""adding LegislatorSocialMedia.
Revision ID: 5cd7eee3549d
Revises: 83bfc8af92d8
Create Date: 2026-03-29 11:53:44.224799
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from python.orm import DataScienceDevBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "5cd7eee3549d"
down_revision: str | None = "83bfc8af92d8"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = DataScienceDevBase.schema_name
def upgrade() -> None:
    """Upgrade: create the legislator_social_media table.

    One row per social account; tied to a legislator by FK (no cascade —
    legislator rows cannot be deleted while accounts reference them).
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "legislator_social_media",
        sa.Column("legislator_id", sa.Integer(), nullable=False),
        sa.Column("platform", sa.String(), nullable=False),
        sa.Column("account_name", sa.String(), nullable=False),
        sa.Column("url", sa.String(), nullable=True),
        # Provenance of the account data (e.g. the upstream dataset URL).
        sa.Column("source", sa.String(), nullable=False),
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
        sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
        sa.ForeignKeyConstraint(
            ["legislator_id"],
            [f"{schema}.legislator.id"],
            name=op.f("fk_legislator_social_media_legislator_id_legislator"),
        ),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator_social_media")),
        schema=schema,
    )
    # ### end Alembic commands ###
def downgrade() -> None:
    """Downgrade: drop the legislator_social_media table."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table("legislator_social_media", schema=schema)
    # ### end Alembic commands ###

View File

@@ -81,6 +81,7 @@ def include_name(
"""
if type_ == "schema":
# allows a database with multiple schemas to have separate alembic revisions
return name == target_metadata.schema
return True

View File

@@ -0,0 +1,187 @@
"""removed ds table from richie DB.
Revision ID: c8a794340928
Revises: 6b275323f435
Create Date: 2026-03-29 15:29:23.643146
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
from python.orm import RichieBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "c8a794340928"
down_revision: str | None = "6b275323f435"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = RichieBase.schema_name
def upgrade() -> None:
    """Upgrade: drop the congress/data-science tables from the richie schema.

    These tables were apparently moved to the data_science_dev schema (see
    the sibling migrations); children are dropped before parents for FKs.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table("vote_record", schema=schema)
    op.drop_index(op.f("ix_vote_congress_chamber"), table_name="vote", schema=schema)
    op.drop_index(op.f("ix_vote_date"), table_name="vote", schema=schema)
    op.drop_index(op.f("ix_legislator_bioguide_id"), table_name="legislator", schema=schema)
    op.drop_table("legislator", schema=schema)
    op.drop_table("vote", schema=schema)
    op.drop_index(op.f("ix_bill_congress"), table_name="bill", schema=schema)
    op.drop_table("bill", schema=schema)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Downgrade: recreate the dropped tables in the richie schema (empty).

    Column types use the reflected PostgreSQL dialect forms (VARCHAR,
    TIMESTAMP, ...) that Alembic autogenerate emitted. Data dropped by
    upgrade() is NOT restored.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # NOTE(review): vote is created before bill, yet it carries an FK to
    # bill.id — this presumably relies on deferred constraint creation or
    # ordering tolerance; confirm the downgrade actually runs cleanly.
    op.create_table(
        "vote",
        sa.Column("congress", sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column("chamber", sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.Column("session", sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column("number", sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column("vote_type", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("question", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("result", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("result_text", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("vote_date", sa.DATE(), autoincrement=False, nullable=False),
        sa.Column("yea_count", sa.INTEGER(), autoincrement=False, nullable=True),
        sa.Column("nay_count", sa.INTEGER(), autoincrement=False, nullable=True),
        sa.Column("not_voting_count", sa.INTEGER(), autoincrement=False, nullable=True),
        sa.Column("present_count", sa.INTEGER(), autoincrement=False, nullable=True),
        sa.Column("bill_id", sa.INTEGER(), autoincrement=False, nullable=True),
        sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
        sa.Column(
            "created",
            postgresql.TIMESTAMP(timezone=True),
            server_default=sa.text("now()"),
            autoincrement=False,
            nullable=False,
        ),
        sa.Column(
            "updated",
            postgresql.TIMESTAMP(timezone=True),
            server_default=sa.text("now()"),
            autoincrement=False,
            nullable=False,
        ),
        sa.ForeignKeyConstraint(["bill_id"], [f"{schema}.bill.id"], name=op.f("fk_vote_bill_id_bill")),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_vote")),
        sa.UniqueConstraint(
            "congress",
            "chamber",
            "session",
            "number",
            name=op.f("uq_vote_congress_chamber_session_number"),
            postgresql_include=[],
            postgresql_nulls_not_distinct=False,
        ),
        schema=schema,
    )
    op.create_index(op.f("ix_vote_date"), "vote", ["vote_date"], unique=False, schema=schema)
    op.create_index(op.f("ix_vote_congress_chamber"), "vote", ["congress", "chamber"], unique=False, schema=schema)
    op.create_table(
        "vote_record",
        sa.Column("vote_id", sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column("legislator_id", sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column("position", sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.ForeignKeyConstraint(
            ["legislator_id"],
            [f"{schema}.legislator.id"],
            name=op.f("fk_vote_record_legislator_id_legislator"),
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["vote_id"], [f"{schema}.vote.id"], name=op.f("fk_vote_record_vote_id_vote"), ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("vote_id", "legislator_id", name=op.f("pk_vote_record")),
        schema=schema,
    )
    op.create_table(
        "legislator",
        sa.Column("bioguide_id", sa.TEXT(), autoincrement=False, nullable=False),
        sa.Column("thomas_id", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("lis_id", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("govtrack_id", sa.INTEGER(), autoincrement=False, nullable=True),
        sa.Column("opensecrets_id", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("fec_ids", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("first_name", sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.Column("last_name", sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.Column("official_full_name", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("nickname", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("birthday", sa.DATE(), autoincrement=False, nullable=True),
        sa.Column("gender", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("current_party", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("current_state", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("current_district", sa.INTEGER(), autoincrement=False, nullable=True),
        sa.Column("current_chamber", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
        sa.Column(
            "created",
            postgresql.TIMESTAMP(timezone=True),
            server_default=sa.text("now()"),
            autoincrement=False,
            nullable=False,
        ),
        sa.Column(
            "updated",
            postgresql.TIMESTAMP(timezone=True),
            server_default=sa.text("now()"),
            autoincrement=False,
            nullable=False,
        ),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator")),
        schema=schema,
    )
    op.create_index(op.f("ix_legislator_bioguide_id"), "legislator", ["bioguide_id"], unique=True, schema=schema)
    op.create_table(
        "bill",
        sa.Column("congress", sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column("bill_type", sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.Column("number", sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column("title", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("title_short", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("official_title", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("status", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("status_at", sa.DATE(), autoincrement=False, nullable=True),
        sa.Column("sponsor_bioguide_id", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("subjects_top_term", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
        sa.Column(
            "created",
            postgresql.TIMESTAMP(timezone=True),
            server_default=sa.text("now()"),
            autoincrement=False,
            nullable=False,
        ),
        sa.Column(
            "updated",
            postgresql.TIMESTAMP(timezone=True),
            server_default=sa.text("now()"),
            autoincrement=False,
            nullable=False,
        ),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_bill")),
        sa.UniqueConstraint(
            "congress",
            "bill_type",
            "number",
            name=op.f("uq_bill_congress_type_number"),
            postgresql_include=[],
            postgresql_nulls_not_distinct=False,
        ),
        schema=schema,
    )
    op.create_index(op.f("ix_bill_congress"), "bill", ["congress"], unique=False, schema=schema)
    # ### end Alembic commands ###

View File

@@ -0,0 +1,3 @@
"""Data science CLI tools."""
from __future__ import annotations

View File

@@ -0,0 +1,613 @@
"""Ingestion pipeline for loading congress data from unitedstates/congress JSON files.
Loads legislators, bills, votes, vote records, and bill text into the data_science_dev database.
Expects the parent directory to contain congress-tracker/ and congress-legislators/ as siblings.
Usage:
ingest-congress /path/to/parent/
ingest-congress /path/to/parent/ --congress 118
ingest-congress /path/to/parent/ --congress 118 --only bills
"""
from __future__ import annotations
import logging
from pathlib import Path # noqa: TC003 needed at runtime for typer CLI argument
from typing import TYPE_CHECKING, Annotated
import orjson
import typer
import yaml
from sqlalchemy import select
from sqlalchemy.orm import Session
from python.common import configure_logger
from python.orm.common import get_postgres_engine
from python.orm.data_science_dev.congress import Bill, BillText, Legislator, LegislatorSocialMedia, Vote, VoteRecord
if TYPE_CHECKING:
from collections.abc import Iterator
from sqlalchemy.engine import Engine
logger = logging.getLogger(__name__)
BATCH_SIZE = 10_000
app = typer.Typer(help="Ingest unitedstates/congress data into data_science_dev.")
@app.command()
def main(
    parent_dir: Annotated[
        Path,
        typer.Argument(help="Parent directory containing congress-tracker/ and congress-legislators/"),
    ],
    congress: Annotated[int | None, typer.Option(help="Only ingest a specific congress number")] = None,
    only: Annotated[
        str | None,
        # BUG FIX: the help text advertised "social-media", but the actual
        # step key is "legislators-social-media" — the advertised value was
        # always rejected at runtime.
        typer.Option(help="Only run a specific step: legislators, legislators-social-media, bills, votes, bill-text"),
    ] = None,
) -> None:
    """Ingest congress data from unitedstates/congress JSON files.

    Validates the expected directory layout, resolves which congress
    directories to process, then runs every ingestion step in order
    (or a single step when --only is given). Exits with code 1 on any
    validation failure.
    """
    configure_logger(level="INFO")
    data_dir = parent_dir / "congress-tracker/congress/data/"
    legislators_dir = parent_dir / "congress-legislators"
    if not data_dir.is_dir():
        typer.echo(f"Expected congress-tracker/ directory: {data_dir}", err=True)
        raise typer.Exit(code=1)
    if not legislators_dir.is_dir():
        typer.echo(f"Expected congress-legislators/ directory: {legislators_dir}", err=True)
        raise typer.Exit(code=1)
    engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
    congress_dirs = _resolve_congress_dirs(data_dir, congress)
    if not congress_dirs:
        typer.echo("No congress directories found.", err=True)
        raise typer.Exit(code=1)
    logger.info("Found %d congress directories to process", len(congress_dirs))
    # Step registry: name -> (callable, args). Insertion order is execution
    # order; legislators are loaded before steps that reference them.
    steps: dict[str, tuple] = {
        "legislators": (ingest_legislators, (engine, legislators_dir)),
        "legislators-social-media": (ingest_social_media, (engine, legislators_dir)),
        "bills": (ingest_bills, (engine, congress_dirs)),
        "votes": (ingest_votes, (engine, congress_dirs)),
        "bill-text": (ingest_bill_text, (engine, congress_dirs)),
    }
    if only:
        if only not in steps:
            typer.echo(f"Unknown step: {only}. Choose from: {', '.join(steps)}", err=True)
            raise typer.Exit(code=1)
        steps = {only: steps[only]}
    for step_name, (step_func, step_args) in steps.items():
        logger.info("=== Starting step: %s ===", step_name)
        step_func(*step_args)
        logger.info("=== Finished step: %s ===", step_name)
    logger.info("ingest-congress done")
def _resolve_congress_dirs(data_dir: Path, congress: int | None) -> list[Path]:
"""Find congress number directories under data_dir."""
if congress is not None:
target = data_dir / str(congress)
return [target] if target.is_dir() else []
return sorted(path for path in data_dir.iterdir() if path.is_dir() and path.name.isdigit())
def _flush_batch(session: Session, batch: list[object], label: str) -> int:
    """Persist a batch of ORM objects, log it, and empty the list.

    A no-op for an empty batch. Returns the number of rows committed.
    """
    count = len(batch)
    if count == 0:
        return 0
    session.add_all(batch)
    session.commit()
    logger.info("Committed %d %s", count, label)
    batch.clear()
    return count
# ---------------------------------------------------------------------------
# Legislators — loaded from congress-legislators YAML files
# ---------------------------------------------------------------------------
def ingest_legislators(engine: Engine, legislators_dir: Path) -> None:
    """Load legislators from congress-legislators YAML files.

    Upserts keyed by bioguide_id: existing rows are updated field-by-field
    (only with non-None incoming values), new rows are inserted.
    """
    legislators_data = _load_legislators_yaml(legislators_dir)
    logger.info("Loaded %d legislators from YAML files", len(legislators_data))
    with Session(engine) as session:
        # Map every existing row by bioguide_id up front so each YAML entry
        # is a dict lookup instead of a per-entry query.
        existing_legislators = {
            legislator.bioguide_id: legislator for legislator in session.scalars(select(Legislator)).all()
        }
        logger.info("Found %d existing legislators in DB", len(existing_legislators))
        total_inserted = 0
        total_updated = 0
        for entry in legislators_data:
            bioguide_id = entry.get("id", {}).get("bioguide")
            if not bioguide_id:
                # Entries without a bioguide id cannot be keyed; skip them.
                continue
            fields = _parse_legislator(entry)
            if existing := existing_legislators.get(bioguide_id):
                # Only overwrite with non-None values so a sparse YAML entry
                # never clears data already present in the DB.
                changed = False
                for field, value in fields.items():
                    if value is not None and getattr(existing, field) != value:
                        setattr(existing, field, value)
                        changed = True
                if changed:
                    total_updated += 1
            else:
                session.add(Legislator(bioguide_id=bioguide_id, **fields))
                total_inserted += 1
        session.commit()
        logger.info("Inserted %d new legislators, updated %d existing", total_inserted, total_updated)
def _load_legislators_yaml(legislators_dir: Path) -> list[dict]:
    """Load and combine legislators-current.yaml and legislators-historical.yaml.

    Missing files are logged and skipped; non-list payloads are ignored.
    """
    combined: list[dict] = []
    for filename in ("legislators-current.yaml", "legislators-historical.yaml"):
        yaml_path = legislators_dir / filename
        if not yaml_path.exists():
            logger.warning("Legislators file not found: %s", yaml_path)
            continue
        with yaml_path.open() as file:
            parsed = yaml.safe_load(file)
        if isinstance(parsed, list):
            combined.extend(parsed)
    return combined
def _parse_legislator(entry: dict) -> dict:
    """Extract Legislator fields from a congress-legislators YAML entry.

    Missing sub-dicts and keys all fall through to None values.
    """
    ids = entry.get("id", {})
    name = entry.get("name", {})
    bio = entry.get("bio", {})
    terms = entry.get("terms", [])
    latest_term = terms[-1] if terms else {}

    # FEC ids may arrive as a list; store them as one comma-joined string.
    raw_fec = ids.get("fec")
    if isinstance(raw_fec, list):
        fec_value = ",".join(raw_fec)
    else:
        fec_value = raw_fec

    # Map the term type code to a chamber label; unknown codes pass through.
    raw_chamber = latest_term.get("type")
    chamber_value = {"rep": "House", "sen": "Senate"}.get(raw_chamber, raw_chamber)

    return {
        "thomas_id": ids.get("thomas"),
        "lis_id": ids.get("lis"),
        "govtrack_id": ids.get("govtrack"),
        "opensecrets_id": ids.get("opensecrets"),
        "fec_ids": fec_value,
        "first_name": name.get("first"),
        "last_name": name.get("last"),
        "official_full_name": name.get("official_full"),
        "nickname": name.get("nickname"),
        "birthday": bio.get("birthday"),
        "gender": bio.get("gender"),
        "current_party": latest_term.get("party"),
        "current_state": latest_term.get("state"),
        "current_district": latest_term.get("district"),
        "current_chamber": chamber_value,
    }
# ---------------------------------------------------------------------------
# Social Media — loaded from legislators-social-media.yaml
# ---------------------------------------------------------------------------
# Platforms recognized in legislators-social-media.yaml, mapped to a URL
# template with an "{account}" placeholder. A None template means no URL is
# derived and the stored url stays NULL — presumably because mastodon
# accounts live on arbitrary instances; confirm.
SOCIAL_MEDIA_PLATFORMS = {
    "twitter": "https://twitter.com/{account}",
    "facebook": "https://facebook.com/{account}",
    "youtube": "https://youtube.com/{account}",
    "instagram": "https://instagram.com/{account}",
    "mastodon": None,
}
def ingest_social_media(engine: Engine, legislators_dir: Path) -> None:
    """Load social media accounts from legislators-social-media.yaml.

    Inserts one LegislatorSocialMedia row per new (legislator, platform)
    pair. NOTE(review): pairs that already exist are only counted in
    total_updated — no column is actually rewritten for them; confirm
    whether real updates are intended here.
    """
    social_media_path = legislators_dir / "legislators-social-media.yaml"
    if not social_media_path.exists():
        logger.warning("Social media file not found: %s", social_media_path)
        return
    with social_media_path.open() as file:
        social_media_data = yaml.safe_load(file)
    if not isinstance(social_media_data, list):
        logger.warning("Unexpected format in %s", social_media_path)
        return
    logger.info("Loaded %d entries from legislators-social-media.yaml", len(social_media_data))
    with Session(engine) as session:
        # Map bioguide ids to primary keys and snapshot existing
        # (legislator_id, platform) pairs so re-runs are idempotent.
        legislator_map = _build_legislator_map(session)
        existing_accounts = {
            (account.legislator_id, account.platform)
            for account in session.scalars(select(LegislatorSocialMedia)).all()
        }
        logger.info("Found %d existing social media accounts in DB", len(existing_accounts))
        total_inserted = 0
        total_updated = 0
        for entry in social_media_data:
            bioguide_id = entry.get("id", {}).get("bioguide")
            if not bioguide_id:
                continue
            legislator_id = legislator_map.get(bioguide_id)
            if legislator_id is None:
                # Entry references a legislator that was never ingested.
                continue
            social = entry.get("social", {})
            for platform, url_template in SOCIAL_MEDIA_PLATFORMS.items():
                account_name = social.get(platform)
                if not account_name:
                    continue
                # mastodon has no template (None) -> stored url stays None.
                url = url_template.format(account=account_name) if url_template else None
                if (legislator_id, platform) in existing_accounts:
                    total_updated += 1
                else:
                    session.add(
                        LegislatorSocialMedia(
                            legislator_id=legislator_id,
                            platform=platform,
                            account_name=str(account_name),
                            url=url,
                            source="https://github.com/unitedstates/congress-legislators",
                        )
                    )
                    # Track the new pair so duplicate entries within the
                    # same file are not inserted twice.
                    existing_accounts.add((legislator_id, platform))
                    total_inserted += 1
        session.commit()
    logger.info("Inserted %d new social media accounts, updated %d existing", total_inserted, total_updated)
def _iter_voters(position_group: object) -> Iterator[dict]:
"""Yield voter dicts from a vote position group (handles list, single dict, or string)."""
if isinstance(position_group, dict):
yield position_group
elif isinstance(position_group, list):
for voter in position_group:
if isinstance(voter, dict):
yield voter
# ---------------------------------------------------------------------------
# Bills
# ---------------------------------------------------------------------------


def ingest_bills(engine: Engine, congress_dirs: list[Path]) -> None:
    """Load bill data.json files.

    Walks each congress directory's bills/ tree, parses every data.json
    into a Bill, and inserts new rows in batches of BATCH_SIZE.
    """
    with Session(engine) as session:
        # Snapshot existing (congress, bill_type, number) keys so re-runs skip them.
        existing_bills = {(bill.congress, bill.bill_type, bill.number) for bill in session.scalars(select(Bill)).all()}
        logger.info("Found %d existing bills in DB", len(existing_bills))
        total_inserted = 0
        batch: list[Bill] = []
        for congress_dir in congress_dirs:
            bills_dir = congress_dir / "bills"
            if not bills_dir.is_dir():
                continue
            logger.info("Scanning bills from %s", congress_dir.name)
            for bill_file in bills_dir.rglob("data.json"):
                data = _read_json(bill_file)
                if data is None:
                    continue
                bill = _parse_bill(data, existing_bills)
                if bill is not None:
                    batch.append(bill)
                if len(batch) >= BATCH_SIZE:
                    # _flush_batch (defined elsewhere in this file) presumably
                    # commits and empties the batch list — TODO confirm.
                    total_inserted += _flush_batch(session, batch, "bills")
        # Flush whatever remains after the last directory.
        total_inserted += _flush_batch(session, batch, "bills")
    logger.info("Inserted %d new bills total", total_inserted)
def _parse_bill(data: dict, existing_bills: set[tuple[int, str, int]]) -> Bill | None:
"""Parse a bill data.json dict into a Bill ORM object, skipping existing."""
raw_congress = data.get("congress")
bill_type = data.get("bill_type")
raw_number = data.get("number")
if raw_congress is None or bill_type is None or raw_number is None:
return None
congress = int(raw_congress)
number = int(raw_number)
if (congress, bill_type, number) in existing_bills:
return None
sponsor_bioguide = None
sponsor = data.get("sponsor")
if sponsor:
sponsor_bioguide = sponsor.get("bioguide_id")
return Bill(
congress=congress,
bill_type=bill_type,
number=number,
title=data.get("short_title") or data.get("official_title"),
title_short=data.get("short_title"),
official_title=data.get("official_title"),
status=data.get("status"),
status_at=data.get("status_at"),
sponsor_bioguide_id=sponsor_bioguide,
subjects_top_term=data.get("subjects_top_term"),
)
# ---------------------------------------------------------------------------
# Votes (and vote records)
# ---------------------------------------------------------------------------


def ingest_votes(engine: Engine, congress_dirs: list[Path]) -> None:
    """Load vote data.json files with their vote records.

    Walks each congress directory's votes/ tree, parses every data.json
    into a Vote (with nested VoteRecords), and inserts new rows in
    batches of BATCH_SIZE.
    """
    with Session(engine) as session:
        # Lookup maps to resolve bioguide ids and bill keys to primary keys.
        legislator_map = _build_legislator_map(session)
        logger.info("Loaded %d legislators into lookup map", len(legislator_map))
        bill_map = _build_bill_map(session)
        logger.info("Loaded %d bills into lookup map", len(bill_map))
        # Snapshot existing vote keys so re-runs are idempotent.
        existing_votes = {
            (vote.congress, vote.chamber, vote.session, vote.number) for vote in session.scalars(select(Vote)).all()
        }
        logger.info("Found %d existing votes in DB", len(existing_votes))
        total_inserted = 0
        batch: list[Vote] = []
        for congress_dir in congress_dirs:
            votes_dir = congress_dir / "votes"
            if not votes_dir.is_dir():
                continue
            logger.info("Scanning votes from %s", congress_dir.name)
            for vote_file in votes_dir.rglob("data.json"):
                data = _read_json(vote_file)
                if data is None:
                    continue
                vote = _parse_vote(data, legislator_map, bill_map, existing_votes)
                if vote is not None:
                    batch.append(vote)
                if len(batch) >= BATCH_SIZE:
                    total_inserted += _flush_batch(session, batch, "votes")
        # Flush the final partial batch.
        total_inserted += _flush_batch(session, batch, "votes")
    logger.info("Inserted %d new votes total", total_inserted)
def _build_legislator_map(session: Session) -> dict[str, int]:
    """Build a mapping of bioguide_id -> legislator.id."""
    mapping: dict[str, int] = {}
    for row in session.scalars(select(Legislator)).all():
        mapping[row.bioguide_id] = row.id
    return mapping
def _build_bill_map(session: Session) -> dict[tuple[int, str, int], int]:
    """Build a mapping of (congress, bill_type, number) -> bill.id."""
    mapping: dict[tuple[int, str, int], int] = {}
    for row in session.scalars(select(Bill)).all():
        mapping[(row.congress, row.bill_type, row.number)] = row.id
    return mapping
def _parse_vote(
    data: dict,
    legislator_map: dict[str, int],
    bill_map: dict[tuple[int, str, int], int],
    existing_votes: set[tuple[int, str, int, int]],
) -> Vote | None:
    """Parse a vote data.json dict into a Vote ORM object with records.

    Returns None when any required identifying field (congress, chamber,
    number, date, session) is missing, or when the vote already exists.
    Per-position tallies and per-legislator VoteRecords are built from
    the raw "votes" mapping.
    """
    raw_congress = data.get("congress")
    chamber = data.get("chamber")
    raw_number = data.get("number")
    vote_date = data.get("date")
    if raw_congress is None or chamber is None or raw_number is None or vote_date is None:
        return None
    raw_session = data.get("session")
    if raw_session is None:
        return None
    congress = int(raw_congress)
    number = int(raw_number)
    session_number = int(raw_session)
    # Normalize chamber from "h"/"s" to "House"/"Senate"; other values pass through.
    chamber_normalized = {"h": "House", "s": "Senate"}.get(chamber, chamber)
    # existing_votes keys use the normalized chamber, matching stored rows.
    if (congress, chamber_normalized, session_number, number) in existing_votes:
        return None
    # Resolve linked bill; a failed lookup simply leaves bill_id as None.
    bill_id = None
    bill_ref = data.get("bill")
    if bill_ref:
        bill_key = (
            # Default to the vote's own congress when the ref omits one.
            int(bill_ref.get("congress", congress)),
            bill_ref.get("type"),
            int(bill_ref.get("number", 0)),
        )
        bill_id = bill_map.get(bill_key)
    raw_votes = data.get("votes", {})
    vote_counts = _count_votes(raw_votes)
    vote_records = _build_vote_records(raw_votes, legislator_map)
    return Vote(
        congress=congress,
        chamber=chamber_normalized,
        session=session_number,
        number=number,
        vote_type=data.get("type"),
        question=data.get("question"),
        result=data.get("result"),
        result_text=data.get("result_text"),
        # Keep only the YYYY-MM-DD prefix of ISO datetime strings.
        vote_date=vote_date[:10] if isinstance(vote_date, str) else vote_date,
        bill_id=bill_id,
        vote_records=vote_records,
        **vote_counts,
    )
def _count_votes(raw_votes: dict) -> dict[str, int]:
    """Count voters per position category, correctly handling dict and list formats.

    "Yea"/"Aye" and "Nay"/"No" are merged; position labels outside the
    four known buckets are ignored.
    """
    # Map each raw position label onto its count bucket.
    bucket_for = {
        "Yea": "yea_count",
        "Aye": "yea_count",
        "Nay": "nay_count",
        "No": "nay_count",
        "Not Voting": "not_voting_count",
        "Present": "present_count",
    }
    counts = {"yea_count": 0, "nay_count": 0, "not_voting_count": 0, "present_count": 0}
    for position, position_group in raw_votes.items():
        bucket = bucket_for.get(position)
        if bucket is None:
            continue
        counts[bucket] += sum(1 for _ in _iter_voters(position_group))
    return counts
def _build_vote_records(raw_votes: dict, legislator_map: dict[str, int]) -> list[VoteRecord]:
    """Build VoteRecord objects from raw vote data.

    Voters without a bioguide id, or whose id is not in legislator_map,
    are skipped.
    """
    built: list[VoteRecord] = []
    for position, position_group in raw_votes.items():
        for voter in _iter_voters(position_group):
            bioguide = voter.get("id")
            if not bioguide:
                continue
            resolved_id = legislator_map.get(bioguide)
            if resolved_id is None:
                continue
            built.append(VoteRecord(legislator_id=resolved_id, position=position))
    return built
# ---------------------------------------------------------------------------
# Bill Text
# ---------------------------------------------------------------------------


def ingest_bill_text(engine: Engine, congress_dirs: list[Path]) -> None:
    """Load bill text from text-versions directories.

    Each bill may have several text versions; new (bill_id, version_code)
    pairs are inserted in batches of BATCH_SIZE.
    """
    with Session(engine) as session:
        bill_map = _build_bill_map(session)
        logger.info("Loaded %d bills into lookup map", len(bill_map))
        # Snapshot existing (bill_id, version_code) pairs for idempotent re-runs.
        existing_bill_texts = {
            (bill_text.bill_id, bill_text.version_code) for bill_text in session.scalars(select(BillText)).all()
        }
        logger.info("Found %d existing bill text versions in DB", len(existing_bill_texts))
        total_inserted = 0
        batch: list[BillText] = []
        for congress_dir in congress_dirs:
            logger.info("Scanning bill texts from %s", congress_dir.name)
            for bill_text in _iter_bill_texts(congress_dir, bill_map, existing_bill_texts):
                batch.append(bill_text)
                if len(batch) >= BATCH_SIZE:
                    total_inserted += _flush_batch(session, batch, "bill texts")
        # Flush the final partial batch.
        total_inserted += _flush_batch(session, batch, "bill texts")
    logger.info("Inserted %d new bill text versions total", total_inserted)
def _iter_bill_texts(
    congress_dir: Path,
    bill_map: dict[tuple[int, str, int], int],
    existing_bill_texts: set[tuple[int, str]],
) -> Iterator[BillText]:
    """Yield BillText objects for a single congress directory, skipping existing.

    Expects the layout bills/<type>/<type><number>/text-versions/<version>/,
    where each version directory holds the text document plus a data.json
    with version metadata.
    """
    bills_dir = congress_dir / "bills"
    if not bills_dir.is_dir():
        return
    for bill_dir in bills_dir.rglob("text-versions"):
        if not bill_dir.is_dir():
            continue
        # Parent of text-versions is the bill directory (e.g. "hr3590").
        bill_key = _bill_key_from_dir(bill_dir.parent, congress_dir)
        if bill_key is None:
            continue
        bill_id = bill_map.get(bill_key)
        if bill_id is None:
            # Bill was never ingested; nothing to attach the text to.
            continue
        for version_dir in sorted(bill_dir.iterdir()):
            if not version_dir.is_dir():
                continue
            if (bill_id, version_dir.name) in existing_bill_texts:
                continue
            text_content = _read_bill_text(version_dir)
            version_data = _read_json(version_dir / "data.json")
            yield BillText(
                bill_id=bill_id,
                # Version code is the directory name (e.g. "ih", "enr").
                version_code=version_dir.name,
                version_name=version_data.get("version_name") if version_data else None,
                date=version_data.get("issued_on") if version_data else None,
                text_content=text_content,
            )
def _bill_key_from_dir(bill_dir: Path, congress_dir: Path) -> tuple[int, str, int] | None:
"""Extract (congress, bill_type, number) from directory structure."""
congress = int(congress_dir.name)
bill_type = bill_dir.parent.name
name = bill_dir.name
# Directory name is like "hr3590" — strip the type prefix to get the number
number_str = name[len(bill_type) :]
if not number_str.isdigit():
return None
return (congress, bill_type, int(number_str))
def _read_bill_text(version_dir: Path) -> str | None:
"""Read bill text from a version directory, preferring .txt over .xml."""
for extension in ("txt", "htm", "html", "xml"):
candidates = list(version_dir.glob(f"document.{extension}"))
if not candidates:
candidates = list(version_dir.glob(f"*.{extension}"))
if candidates:
try:
return candidates[0].read_text(encoding="utf-8")
except Exception:
logger.exception("Failed to read %s", candidates[0])
return None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _read_json(path: Path) -> dict | None:
    """Read and parse a JSON file, returning None on failure.

    A missing file is silently treated as None; any other problem (bad
    JSON, unreadable file) is logged before returning None.
    """
    try:
        return orjson.loads(path.read_bytes())
    except FileNotFoundError:
        return None
    except Exception:
        logger.exception("Failed to parse %s", path)
    return None
if __name__ == "__main__":
    # Typer application entry point when run as a script.
    app()

View File

@@ -0,0 +1,247 @@
"""Ingestion pipeline for loading JSONL post files into the weekly-partitioned posts table.
Usage:
ingest-posts /path/to/files/
ingest-posts /path/to/single_file.jsonl
ingest-posts /data/dir/ --workers 4 --batch-size 5000
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime
from pathlib import Path # noqa: TC003 this is needed for typer
from typing import TYPE_CHECKING, Annotated
import orjson
import psycopg
import typer
from python.common import configure_logger
from python.orm.common import get_connection_info
from python.parallelize import parallelize_process
if TYPE_CHECKING:
from collections.abc import Iterator
logger = logging.getLogger(__name__)
app = typer.Typer(help="Ingest JSONL post files into the partitioned posts table.")
@app.command()
def main(
    path: Annotated[Path, typer.Argument(help="Directory containing JSONL files, or a single JSONL file")],
    batch_size: Annotated[int, typer.Option(help="Rows per INSERT batch")] = 10000,
    workers: Annotated[int, typer.Option(help="Parallel workers for multi-file ingestion")] = 4,
    pattern: Annotated[str, typer.Option(help="Glob pattern for JSONL files")] = "*.jsonl",
) -> None:
    """Ingest JSONL post files into the weekly-partitioned posts table.

    CLI entry point: a file path ingests that single file; a directory
    ingests every file matching `pattern` across `workers` processes; any
    other path exits with code 1.
    """
    configure_logger(level="INFO")
    logger.info("starting ingest-posts")
    logger.info("path=%s batch_size=%d workers=%d pattern=%s", path, batch_size, workers, pattern)
    if path.is_file():
        ingest_file(path, batch_size=batch_size)
    elif path.is_dir():
        ingest_directory(path, batch_size=batch_size, max_workers=workers, pattern=pattern)
    else:
        typer.echo(f"Path does not exist: {path}", err=True)
        raise typer.Exit(code=1)
    logger.info("ingest-posts done")
def ingest_directory(
    directory: Path,
    *,
    batch_size: int,
    max_workers: int,
    pattern: str = "*.jsonl",
) -> None:
    """Ingest all JSONL files in a directory using parallel workers.

    Files are matched with `pattern`, sorted for deterministic ordering,
    and fanned out across `max_workers` processes.
    """
    matched_files = sorted(directory.glob(pattern))
    if not matched_files:
        logger.warning("No JSONL files found in %s", directory)
        return
    logger.info("Found %d JSONL files to ingest", len(matched_files))
    jobs = [{"path": matched, "batch_size": batch_size} for matched in matched_files]
    parallelize_process(ingest_file, jobs, max_workers=max_workers)
# Schema containing the posts and failed_ingestion tables.
SCHEMA = "main"

# Column order shared by the staging COPY and the INSERT ... SELECT below;
# must stay in sync with the posts table definition.
COLUMNS = (
    "post_id",
    "user_id",
    "instance",
    "date",
    "text",
    "langs",
    "like_count",
    "reply_count",
    "repost_count",
    "reply_to",
    "replied_author",
    "thread_root",
    "thread_root_author",
    "repost_from",
    "reposted_author",
    "quotes",
    "quoted_author",
    "labels",
    "sent_label",
    "sent_score",
)

# Move staged rows into the partitioned posts table, silently dropping
# duplicates on the (post_id, date) conflict target.
INSERT_FROM_STAGING = f"""
INSERT INTO {SCHEMA}.posts ({", ".join(COLUMNS)})
SELECT {", ".join(COLUMNS)} FROM pg_temp.staging
ON CONFLICT (post_id, date) DO NOTHING
"""  # noqa: S608

# Record a raw line and its error message for later inspection.
FAILED_INSERT = f"""
INSERT INTO {SCHEMA}.failed_ingestion (raw_line, error)
VALUES (%(raw_line)s, %(error)s)
"""  # noqa: S608
def get_psycopg_connection() -> psycopg.Connection:
    """Create a raw psycopg3 connection from environment variables.

    Connection parameters come from the DATA_SCIENCE_DEV environment
    prefix; autocommit stays off so COPY + INSERT run in one transaction.
    """
    database, host, port, username, password = get_connection_info("DATA_SCIENCE_DEV")
    return psycopg.connect(
        dbname=database,
        host=host,
        port=int(port),
        user=username,
        password=password,
        autocommit=False,
    )
def ingest_file(path: Path, *, batch_size: int) -> None:
    """Ingest a single JSONL file into the posts table.

    Streams the file in batches, COPYing each into the database, then
    records any malformed lines in the failed_ingestion table. Logs and
    re-raises on failure so parallel callers see the error.
    """
    # Log roughly every 100k rows regardless of the chosen batch size.
    log_trigger = max(100_000 // batch_size, 1)
    failed_lines: list[dict] = []
    try:
        with get_psycopg_connection() as connection:
            for index, batch in enumerate(read_jsonl_batches(path, batch_size, failed_lines), 1):
                ingest_batch(connection, batch)
                if index % log_trigger == 0:
                    logger.info("Ingested %d batches (%d rows) from %s", index, index * batch_size, path)
            if failed_lines:
                logger.warning("Recording %d malformed lines from %s", len(failed_lines), path.name)
                with connection.cursor() as cursor:
                    cursor.executemany(FAILED_INSERT, failed_lines)
                connection.commit()
    except Exception:
        logger.exception("Failed to ingest file: %s", path)
        raise
def ingest_batch(connection: psycopg.Connection, batch: list[dict]) -> None:
    """COPY batch into a temp staging table, then INSERT ... ON CONFLICT into posts.

    On any failure the transaction is rolled back and the batch is
    bisected and retried recursively; a single failing row is recorded in
    failed_ingestion instead of raising, so one bad row cannot sink the
    whole batch.
    """
    if not batch:
        return
    try:
        with connection.cursor() as cursor:
            # Session-scoped staging table: ON COMMIT DELETE ROWS empties it
            # each transaction, and the explicit TRUNCATE guards against
            # leftovers from an earlier aborted transaction on this session.
            cursor.execute(f"""
CREATE TEMP TABLE IF NOT EXISTS staging
(LIKE {SCHEMA}.posts INCLUDING DEFAULTS)
ON COMMIT DELETE ROWS
""")
            cursor.execute("TRUNCATE pg_temp.staging")
            with cursor.copy(f"COPY pg_temp.staging ({', '.join(COLUMNS)}) FROM STDIN") as copy:
                for row in batch:
                    # Missing keys become NULLs via row.get.
                    copy.write_row(tuple(row.get(column) for column in COLUMNS))
            cursor.execute(INSERT_FROM_STAGING)
        connection.commit()
    except Exception as error:
        connection.rollback()
        if len(batch) == 1:
            # Base case: isolate and record the single bad row.
            logger.exception("Skipping bad row post_id=%s", batch[0].get("post_id"))
            with connection.cursor() as cursor:
                cursor.execute(
                    FAILED_INSERT,
                    {
                        "raw_line": orjson.dumps(batch[0], default=str).decode(),
                        "error": str(error),
                    },
                )
            connection.commit()
            return
        # Binary search for the offending row(s).
        midpoint = len(batch) // 2
        ingest_batch(connection, batch[:midpoint])
        ingest_batch(connection, batch[midpoint:])
def read_jsonl_batches(file_path: Path, batch_size: int, failed_lines: list[dict]) -> Iterator[list[dict]]:
    """Stream a JSONL file and yield batches of transformed rows.

    Blank lines are skipped; malformed lines are appended to failed_lines
    by parse_line rather than raising. The final partial batch is yielded
    as well.
    """
    pending: list[dict] = []
    with file_path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if not stripped:
                continue
            pending.extend(parse_line(stripped, file_path, failed_lines))
            if len(pending) >= batch_size:
                yield pending
                pending = []
    if pending:
        yield pending
def parse_line(line: str, file_path: Path, failed_lines: list[dict]) -> Iterator[dict]:
    """Parse a JSONL line, handling concatenated JSON objects.

    Yields zero or more transformed rows; anything that cannot be parsed
    is appended to failed_lines instead of raising.
    """
    try:
        yield transform_row(orjson.loads(line))
    except orjson.JSONDecodeError:
        if "}{" not in line:
            logger.warning("Skipping malformed line in %s: %s", file_path.name, line[:120])
            failed_lines.append({"raw_line": line, "error": "malformed JSON"})
            return
        # Split concatenated objects like {...}{...} into fragments.
        # NOTE(review): this also splits on "}{" occurring inside string
        # values, corrupting such rows — they then land in failed_lines.
        fragments = line.replace("}{", "}\n{").split("\n")
        for fragment in fragments:
            try:
                yield transform_row(orjson.loads(fragment))
            except (orjson.JSONDecodeError, KeyError, ValueError) as error:
                logger.warning("Skipping malformed fragment in %s: %s", file_path.name, fragment[:120])
                failed_lines.append({"raw_line": fragment, "error": str(error)})
    except Exception as error:
        # Parse succeeded but transform_row failed (e.g. bad date field).
        logger.exception("Skipping bad row in %s: %s", file_path.name, line[:120])
        failed_lines.append({"raw_line": line, "error": str(error)})
def transform_row(raw: dict) -> dict:
    """Transform a raw JSONL row into a dict matching the Posts table columns.

    Mutates `raw` in place and returns it: parses the compact integer
    date, JSON-encodes the langs value, and strips NUL bytes from text
    (PostgreSQL text columns reject \\x00).
    """
    # Raises KeyError/ValueError on a missing or malformed date — callers
    # (parse_line) catch that and record the row as failed.
    raw["date"] = parse_date(raw["date"])
    if raw.get("langs") is not None:
        # NOTE(review): orjson.dumps returns bytes, not str — confirm the
        # COPY path accepts bytes for this text column.
        raw["langs"] = orjson.dumps(raw["langs"])
    if raw.get("text") is not None:
        raw["text"] = raw["text"].replace("\x00", "")
    return raw
def parse_date(raw_date: int) -> datetime:
"""Parse compact YYYYMMDDHHmm integer into a naive datetime (input is UTC by spec)."""
return datetime(
raw_date // 100000000,
(raw_date // 1000000) % 100,
(raw_date // 10000) % 100,
(raw_date // 100) % 100,
raw_date % 100,
tzinfo=UTC,
)
if __name__ == "__main__":
    # Typer application entry point when run as a script.
    app()

View File

@@ -90,6 +90,13 @@ DATABASES: dict[str, DatabaseConfig] = {
base_class_name="SignalBotBase",
models_module="python.orm.signal_bot.models",
),
"data_science_dev": DatabaseConfig(
env_prefix="DATA_SCIENCE_DEV",
version_location="python/alembic/data_science_dev/versions",
base_module="python.orm.data_science_dev.base",
base_class_name="DataScienceDevBase",
models_module="python.orm.data_science_dev.models",
),
}

View File

@@ -1,10 +1,12 @@
"""ORM package exports."""
from python.orm.data_science_dev.base import DataScienceDevBase
from python.orm.richie.base import RichieBase
from python.orm.signal_bot.base import SignalBotBase
from python.orm.van_inventory.base import VanInventoryBase
__all__ = [
"DataScienceDevBase",
"RichieBase",
"SignalBotBase",
"VanInventoryBase",

View File

@@ -0,0 +1,11 @@
"""Data science dev database ORM exports."""
from __future__ import annotations
from python.orm.data_science_dev.base import DataScienceDevBase, DataScienceDevTableBase, DataScienceDevTableBaseBig
__all__ = [
"DataScienceDevBase",
"DataScienceDevTableBase",
"DataScienceDevTableBaseBig",
]

View File

@@ -0,0 +1,52 @@
"""Data science dev database ORM base."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import BigInteger, DateTime, MetaData, func
from sqlalchemy.ext.declarative import AbstractConcreteBase
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from python.orm.common import NAMING_CONVENTION
class DataScienceDevBase(DeclarativeBase):
    """Base class for data_science_dev database ORM models.

    All tables live in the "main" schema and share the project-wide
    constraint naming convention (keeps Alembic autogenerate deterministic).
    """

    schema_name = "main"
    metadata = MetaData(
        schema=schema_name,
        naming_convention=NAMING_CONVENTION,
    )
class _TableMixin:
    """Shared timestamp columns for all table bases."""

    # Set once by the database at INSERT time.
    created: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    # Server default on INSERT; onupdate refreshes it on ORM UPDATEs
    # (client-side — not a database trigger).
    updated: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
class DataScienceDevTableBase(_TableMixin, AbstractConcreteBase, DataScienceDevBase):
    """Table with Integer primary key."""

    __abstract__ = True

    # Surrogate auto-increment primary key.
    id: Mapped[int] = mapped_column(primary_key=True)
class DataScienceDevTableBaseBig(_TableMixin, AbstractConcreteBase, DataScienceDevBase):
    """Table with BigInteger primary key."""

    __abstract__ = True

    # BIGINT surrogate key for high-volume tables.
    id: Mapped[int] = mapped_column(BigInteger, primary_key=True)

View File

@@ -0,0 +1,14 @@
"""init."""
from python.orm.data_science_dev.congress.bill import Bill, BillText
from python.orm.data_science_dev.congress.legislator import Legislator, LegislatorSocialMedia
from python.orm.data_science_dev.congress.vote import Vote, VoteRecord
__all__ = [
"Bill",
"BillText",
"Legislator",
"LegislatorSocialMedia",
"Vote",
"VoteRecord",
]

View File

@@ -0,0 +1,66 @@
"""Bill model - legislation introduced in Congress."""
from __future__ import annotations
from datetime import date
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Index, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.data_science_dev.base import DataScienceDevTableBase
if TYPE_CHECKING:
from python.orm.data_science_dev.congress.vote import Vote
class Bill(DataScienceDevTableBase):
    """Legislation with congress number, type, titles, status, and sponsor."""

    __tablename__ = "bill"

    # Natural key: (congress, bill_type, number) — enforced unique below.
    congress: Mapped[int]
    bill_type: Mapped[str]
    number: Mapped[int]
    # Display title: the short title when present, else the official title.
    title: Mapped[str | None]
    title_short: Mapped[str | None]
    official_title: Mapped[str | None]
    status: Mapped[str | None]
    status_at: Mapped[date | None]
    # Sponsor stored as a bioguide id string, not a Legislator FK.
    sponsor_bioguide_id: Mapped[str | None]
    subjects_top_term: Mapped[str | None]

    votes: Mapped[list[Vote]] = relationship(
        "Vote",
        back_populates="bill",
    )
    # Text versions are owned by the bill and deleted with it.
    bill_texts: Mapped[list[BillText]] = relationship(
        "BillText",
        back_populates="bill",
        cascade="all, delete-orphan",
    )

    __table_args__ = (
        UniqueConstraint("congress", "bill_type", "number", name="uq_bill_congress_type_number"),
        Index("ix_bill_congress", "congress"),
    )
class BillText(DataScienceDevTableBase):
    """Stores different text versions of a bill (introduced, enrolled, etc.)."""

    __tablename__ = "bill_text"

    bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
    # Version code — presumably the text-versions directory name (e.g. "ih"); confirm against ingester.
    version_code: Mapped[str]
    version_name: Mapped[str | None]
    text_content: Mapped[str | None]
    # Issue date of this text version.
    date: Mapped[date | None]

    bill: Mapped[Bill] = relationship("Bill", back_populates="bill_texts")

    # One row per text version of a given bill.
    __table_args__ = (UniqueConstraint("bill_id", "version_code", name="uq_bill_text_bill_id_version_code"),)

View File

@@ -0,0 +1,66 @@
"""Legislator model - members of Congress."""
from __future__ import annotations
from datetime import date
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.data_science_dev.base import DataScienceDevTableBase
if TYPE_CHECKING:
from python.orm.data_science_dev.congress.vote import VoteRecord
class Legislator(DataScienceDevTableBase):
    """Members of Congress with identification and current term info."""

    __tablename__ = "legislator"

    # Natural key — the bioguide ID is the authoritative identifier.
    bioguide_id: Mapped[str] = mapped_column(Text, unique=True, index=True)
    # Cross-reference ids from other datasets.
    thomas_id: Mapped[str | None]
    lis_id: Mapped[str | None]
    govtrack_id: Mapped[int | None]
    opensecrets_id: Mapped[str | None]
    # Comma-joined list of FEC candidate ids.
    fec_ids: Mapped[str | None]
    first_name: Mapped[str]
    last_name: Mapped[str]
    official_full_name: Mapped[str | None]
    nickname: Mapped[str | None]
    birthday: Mapped[date | None]
    gender: Mapped[str | None]
    # Denormalized from the latest term for query convenience.
    current_party: Mapped[str | None]
    current_state: Mapped[str | None]
    current_district: Mapped[int | None]
    current_chamber: Mapped[str | None]

    social_media_accounts: Mapped[list[LegislatorSocialMedia]] = relationship(
        "LegislatorSocialMedia",
        back_populates="legislator",
        cascade="all, delete-orphan",
    )
    vote_records: Mapped[list[VoteRecord]] = relationship(
        "VoteRecord",
        back_populates="legislator",
        cascade="all, delete-orphan",
    )
class LegislatorSocialMedia(DataScienceDevTableBase):
    """Social media account linked to a legislator."""

    __tablename__ = "legislator_social_media"

    legislator_id: Mapped[int] = mapped_column(ForeignKey("main.legislator.id"))
    # Platform key (e.g. "twitter", "mastodon").
    platform: Mapped[str]
    account_name: Mapped[str]
    # Derived profile URL; None when no URL template exists for the platform.
    url: Mapped[str | None]
    # Provenance of the record (source dataset URL).
    source: Mapped[str]

    legislator: Mapped[Legislator] = relationship(back_populates="social_media_accounts")

View File

@@ -0,0 +1,79 @@
"""Vote model - roll call votes in Congress."""
from __future__ import annotations
from datetime import date
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Index, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.data_science_dev.base import DataScienceDevBase, DataScienceDevTableBase
if TYPE_CHECKING:
from python.orm.data_science_dev.congress.bill import Bill
from python.orm.data_science_dev.congress.legislator import Legislator
from python.orm.data_science_dev.congress.vote import Vote
class VoteRecord(DataScienceDevBase):
    """Links a vote to a legislator with their position (Yea, Nay, etc.).

    Pure association table: it derives from DataScienceDevBase directly,
    so the composite (vote_id, legislator_id) key replaces the surrogate
    id and timestamp columns the table bases would add.
    """

    __tablename__ = "vote_record"

    vote_id: Mapped[int] = mapped_column(
        ForeignKey("main.vote.id", ondelete="CASCADE"),
        primary_key=True,
    )
    legislator_id: Mapped[int] = mapped_column(
        ForeignKey("main.legislator.id", ondelete="CASCADE"),
        primary_key=True,
    )
    # Raw position label from the vote JSON (e.g. "Yea", "Not Voting").
    position: Mapped[str]

    vote: Mapped[Vote] = relationship("Vote", back_populates="vote_records")
    legislator: Mapped[Legislator] = relationship("Legislator", back_populates="vote_records")
class Vote(DataScienceDevTableBase):
    """Roll call votes with counts and optional bill linkage."""

    __tablename__ = "vote"

    # Natural key: (congress, chamber, session, number) — enforced unique below.
    congress: Mapped[int]
    # Normalized to "House"/"Senate" at ingest time.
    chamber: Mapped[str]
    session: Mapped[int]
    number: Mapped[int]
    vote_type: Mapped[str | None]
    question: Mapped[str | None]
    result: Mapped[str | None]
    result_text: Mapped[str | None]
    vote_date: Mapped[date]
    # Aggregated tallies computed from the raw vote JSON.
    yea_count: Mapped[int | None]
    nay_count: Mapped[int | None]
    not_voting_count: Mapped[int | None]
    present_count: Mapped[int | None]
    # Optional link to the bill being voted on.
    bill_id: Mapped[int | None] = mapped_column(ForeignKey("main.bill.id"))

    bill: Mapped[Bill | None] = relationship("Bill", back_populates="votes")
    vote_records: Mapped[list[VoteRecord]] = relationship(
        "VoteRecord",
        back_populates="vote",
        cascade="all, delete-orphan",
    )

    __table_args__ = (
        UniqueConstraint(
            "congress",
            "chamber",
            "session",
            "number",
            name="uq_vote_congress_chamber_session_number",
        ),
        Index("ix_vote_date", "vote_date"),
        Index("ix_vote_congress_chamber", "congress", "chamber"),
    )

View File

@@ -0,0 +1,16 @@
"""Data science dev database ORM models."""
from __future__ import annotations
from python.orm.data_science_dev.congress import Bill, BillText, Legislator, Vote, VoteRecord
from python.orm.data_science_dev.posts import partitions # noqa: F401 — registers partition classes in metadata
from python.orm.data_science_dev.posts.tables import Posts
__all__ = [
"Bill",
"BillText",
"Legislator",
"Posts",
"Vote",
"VoteRecord",
]

View File

@@ -0,0 +1,11 @@
"""Posts module — weekly-partitioned posts table and partition ORM models."""
from __future__ import annotations
from python.orm.data_science_dev.posts.failed_ingestion import FailedIngestion
from python.orm.data_science_dev.posts.tables import Posts
__all__ = [
"FailedIngestion",
"Posts",
]

View File

@@ -0,0 +1,33 @@
"""Shared column definitions for the posts partitioned table family."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import BigInteger, SmallInteger, Text
from sqlalchemy.orm import Mapped, mapped_column
class PostsColumns:
    """Mixin providing all posts columns. Used by both the parent table and partitions."""

    # Composite primary key: the partition key (date) must be part of the
    # primary key on a partitioned table.
    post_id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
    user_id: Mapped[int] = mapped_column(BigInteger)
    instance: Mapped[str]
    date: Mapped[datetime] = mapped_column(primary_key=True)
    text: Mapped[str] = mapped_column(Text)
    # JSON-encoded language list (serialized at ingest time).
    langs: Mapped[str | None]
    like_count: Mapped[int]
    reply_count: Mapped[int]
    repost_count: Mapped[int]
    # Post/author ids forming the reply, thread, repost, and quote graph.
    reply_to: Mapped[int | None] = mapped_column(BigInteger)
    replied_author: Mapped[int | None] = mapped_column(BigInteger)
    thread_root: Mapped[int | None] = mapped_column(BigInteger)
    thread_root_author: Mapped[int | None] = mapped_column(BigInteger)
    repost_from: Mapped[int | None] = mapped_column(BigInteger)
    reposted_author: Mapped[int | None] = mapped_column(BigInteger)
    quotes: Mapped[int | None] = mapped_column(BigInteger)
    quoted_author: Mapped[int | None] = mapped_column(BigInteger)
    labels: Mapped[str | None]
    # Sentiment label/score — presumably model output; confirm semantics upstream.
    sent_label: Mapped[int | None] = mapped_column(SmallInteger)
    sent_score: Mapped[float | None]

View File

@@ -0,0 +1,17 @@
"""Table for storing JSONL lines that failed during post ingestion."""
from __future__ import annotations
from sqlalchemy import Text
from sqlalchemy.orm import Mapped, mapped_column
from python.orm.data_science_dev.base import DataScienceDevTableBase
class FailedIngestion(DataScienceDevTableBase):
    """Stores raw JSONL lines and their error messages when ingestion fails."""

    __tablename__ = "failed_ingestion"

    # The offending line (or JSON-serialized row) exactly as received.
    raw_line: Mapped[str] = mapped_column(Text)
    error: Mapped[str] = mapped_column(Text)

View File

@@ -0,0 +1,71 @@
"""Dynamically generated ORM classes for each weekly partition of the posts table.
Each class maps to a PostgreSQL partition table (e.g. posts_2024_01).
These are real ORM models tracked by Alembic autogenerate.
Uses ISO week numbering (datetime.isocalendar().week). ISO years can have
52 or 53 weeks, and week boundaries are always Monday to Monday.
"""
from __future__ import annotations
import sys
from datetime import UTC, datetime
from python.orm.data_science_dev.base import DataScienceDevBase
from python.orm.data_science_dev.posts.columns import PostsColumns
# Inclusive year range for which weekly partition classes are generated.
PARTITION_START_YEAR = 2023
PARTITION_END_YEAR = 2026

# Handle to this module so the generated classes can be attached to it.
_current_module = sys.modules[__name__]
def iso_weeks_in_year(year: int) -> int:
"""Return the number of ISO weeks in a given year (52 or 53)."""
dec_28 = datetime(year, 12, 28, tzinfo=UTC)
return dec_28.isocalendar().week
def week_bounds(year: int, week: int) -> tuple[datetime, datetime]:
"""Return (start, end) datetimes for an ISO week.
Start = Monday 00:00:00 UTC of the given ISO week.
End = Monday 00:00:00 UTC of the following ISO week.
"""
start = datetime.fromisocalendar(year, week, 1).replace(tzinfo=UTC)
if week < iso_weeks_in_year(year):
end = datetime.fromisocalendar(year, week + 1, 1).replace(tzinfo=UTC)
else:
end = datetime.fromisocalendar(year + 1, 1, 1).replace(tzinfo=UTC)
return start, end
def _build_partition_classes() -> dict[str, type]:
    """Generate one ORM class per ISO week partition.

    Class PostsWeek<year>W<week> maps to table posts_<year>_<week>.
    implicit_returning is disabled on each partition table.
    """
    classes: dict[str, type] = {}
    for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
        # ISO years contain 52 or 53 weeks; generate exactly that many.
        for week in range(1, iso_weeks_in_year(year) + 1):
            class_name = f"PostsWeek{year}W{week:02d}"
            table_name = f"posts_{year}_{week:02d}"
            partition_class = type(
                class_name,
                (PostsColumns, DataScienceDevBase),
                {
                    "__tablename__": table_name,
                    "__table_args__": ({"implicit_returning": False},),
                },
            )
            classes[class_name] = partition_class
    return classes
# Generate all partition classes and register them on this module so that
# anything importing this module (e.g. Alembic autogenerate via the models
# package) sees every partition class as a module attribute.
_partition_classes = _build_partition_classes()
for _name, _cls in _partition_classes.items():
    setattr(_current_module, _name, _cls)
# Export exactly the generated class names.
__all__ = list(_partition_classes.keys())

View File

@@ -0,0 +1,13 @@
"""Posts parent table with PostgreSQL weekly range partitioning on date column."""
from __future__ import annotations
from python.orm.data_science_dev.base import DataScienceDevBase
from python.orm.data_science_dev.posts.columns import PostsColumns
class Posts(PostsColumns, DataScienceDevBase):
    """Parent partitioned table for posts, partitioned by week on `date`.

    Declares only the parent table; the per-week child tables are
    generated separately as their own ORM classes.
    """
    __tablename__ = "posts"
    # DDL hint: PostgreSQL native range partitioning keyed on `date`.
    __table_args__ = ({"postgresql_partition_by": "RANGE (date)"},)

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
from python.orm.richie.base import RichieBase, TableBase, TableBaseBig, TableBaseSmall
from python.orm.richie.congress import Bill, Legislator, Vote, VoteRecord
from python.orm.richie.contact import (
Contact,
ContactNeed,
@@ -13,17 +12,13 @@ from python.orm.richie.contact import (
)
__all__ = [
"Bill",
"Contact",
"ContactNeed",
"ContactRelationship",
"Legislator",
"Need",
"RelationshipType",
"RichieBase",
"TableBase",
"TableBaseBig",
"TableBaseSmall",
"Vote",
"VoteRecord",
]

View File

@@ -1,150 +0,0 @@
"""Congress Tracker database models."""
from __future__ import annotations
from datetime import date
from sqlalchemy import ForeignKey, Index, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.richie.base import RichieBase, TableBase
class Legislator(TableBase):
    """Legislator model - members of Congress.

    One row per member. `bioguide_id` is the authoritative natural key;
    the other external IDs are retained for cross-referencing outside
    datasets. The current_* columns denormalize the latest term.
    """
    __tablename__ = "legislator"
    # Natural key - bioguide ID is the authoritative identifier
    bioguide_id: Mapped[str] = mapped_column(Text, unique=True, index=True)
    # Other IDs for cross-referencing
    thomas_id: Mapped[str | None]
    lis_id: Mapped[str | None]
    govtrack_id: Mapped[int | None]
    opensecrets_id: Mapped[str | None]
    fec_ids: Mapped[str | None]  # JSON array stored as string
    # Name info
    first_name: Mapped[str]
    last_name: Mapped[str]
    official_full_name: Mapped[str | None]
    nickname: Mapped[str | None]
    # Bio
    birthday: Mapped[date | None]
    gender: Mapped[str | None]  # M/F
    # Current term info (denormalized for query efficiency)
    current_party: Mapped[str | None]
    current_state: Mapped[str | None]
    current_district: Mapped[int | None]  # House only
    current_chamber: Mapped[str | None]  # rep/sen
    # Relationships
    # Deleting a legislator cascades to their vote records.
    vote_records: Mapped[list[VoteRecord]] = relationship(
        "VoteRecord",
        back_populates="legislator",
        cascade="all, delete-orphan",
    )
class Bill(TableBase):
    """Bill model - legislation introduced in Congress.

    Uniquely identified by the composite (congress, bill_type, number).
    """
    __tablename__ = "bill"
    # Composite natural key: congress + bill_type + number
    congress: Mapped[int]
    bill_type: Mapped[str]  # hr, s, hres, sres, hjres, sjres
    number: Mapped[int]
    # Bill info
    title: Mapped[str | None]
    title_short: Mapped[str | None]
    official_title: Mapped[str | None]
    # Status
    status: Mapped[str | None]
    status_at: Mapped[date | None]
    # Sponsor
    sponsor_bioguide_id: Mapped[str | None]
    # Subjects
    subjects_top_term: Mapped[str | None]
    # Relationships
    # Votes referencing this bill (a vote need not be tied to a bill).
    votes: Mapped[list[Vote]] = relationship(
        "Vote",
        back_populates="bill",
    )
    __table_args__ = (
        UniqueConstraint("congress", "bill_type", "number", name="uq_bill_congress_type_number"),
        Index("ix_bill_congress", "congress"),
    )
class Vote(TableBase):
    """Vote model - roll call votes in Congress.

    Uniquely identified by the composite (congress, chamber, session,
    number); optionally references the bill being voted on.
    """
    __tablename__ = "vote"
    # Composite natural key: congress + chamber + session + number
    congress: Mapped[int]
    chamber: Mapped[str]  # house/senate
    session: Mapped[int]
    number: Mapped[int]
    # Vote details
    vote_type: Mapped[str | None]
    question: Mapped[str | None]
    result: Mapped[str | None]
    result_text: Mapped[str | None]
    # Timing
    vote_date: Mapped[date]
    # Vote counts (denormalized for efficiency)
    yea_count: Mapped[int | None]
    nay_count: Mapped[int | None]
    not_voting_count: Mapped[int | None]
    present_count: Mapped[int | None]
    # Related bill (optional - not all votes are on bills)
    bill_id: Mapped[int | None] = mapped_column(ForeignKey("main.bill.id"))
    # Relationships
    bill: Mapped[Bill | None] = relationship("Bill", back_populates="votes")
    # Deleting a vote cascades to its per-legislator records.
    vote_records: Mapped[list[VoteRecord]] = relationship(
        "VoteRecord",
        back_populates="vote",
        cascade="all, delete-orphan",
    )
    __table_args__ = (
        UniqueConstraint("congress", "chamber", "session", "number", name="uq_vote_congress_chamber_session_number"),
        Index("ix_vote_date", "vote_date"),
        Index("ix_vote_congress_chamber", "congress", "chamber"),
    )
class VoteRecord(RichieBase):
    """Association table: Vote <-> Legislator with position.

    Composite primary key (vote_id, legislator_id); rows are removed at
    the database level when either side is deleted (ON DELETE CASCADE).
    """
    __tablename__ = "vote_record"
    vote_id: Mapped[int] = mapped_column(
        ForeignKey("main.vote.id", ondelete="CASCADE"),
        primary_key=True,
    )
    legislator_id: Mapped[int] = mapped_column(
        ForeignKey("main.legislator.id", ondelete="CASCADE"),
        primary_key=True,
    )
    position: Mapped[str]  # Yea, Nay, Not Voting, Present
    # Relationships
    vote: Mapped[Vote] = relationship("Vote", back_populates="vote_records")
    legislator: Mapped[Legislator] = relationship("Legislator", back_populates="vote_records")

View File

@@ -0,0 +1,25 @@
# Unsloth fine-tuning container for Qwen 3.5 4B on RTX 3090.
#
# Build:
#   docker build -f python/prompt_bench/Dockerfile.finetune -t bill-finetune .
#
# Run:
#   docker run --rm --device=nvidia.com/gpu=all --ipc=host \
#     -v $(pwd)/output:/workspace/output \
#     -v $(pwd)/output/finetune_dataset.jsonl:/workspace/dataset.jsonl:ro \
#     -v /zfs/models/hf:/models \
#     bill-finetune \
#     --dataset /workspace/dataset.jsonl \
#     --output-dir /workspace/output/qwen-bill-summarizer
FROM ghcr.io/unslothai/unsloth:latest
# typer is the only extra dependency needed by the CLI entry point.
RUN pip install --no-cache-dir typer
WORKDIR /workspace
# Copy the modules with their package layout intact so that
# `python -m python.prompt_bench.finetune` resolves inside the container.
COPY python/prompt_bench/finetune.py python/prompt_bench/finetune.py
COPY python/prompt_bench/summarization_prompts.py python/prompt_bench/summarization_prompts.py
COPY python/prompt_bench/__init__.py python/prompt_bench/__init__.py
COPY python/__init__.py python/__init__.py
ENTRYPOINT ["python", "-m", "python.prompt_bench.finetune"]

View File

@@ -0,0 +1 @@
"""Prompt benchmarking system for evaluating LLMs via vLLM."""

View File

@@ -0,0 +1,233 @@
"""Submit an OpenAI Batch API bill-summarization job over compressed text.
Reads the first N bills from a CSV with a `text_content` column, compresses
each via `bill_token_compression.compress_bill_text`, builds a JSONL file of
summarization requests, and submits it as an asynchronous Batch API job
against `/v1/chat/completions`. Also writes a CSV of per-bill pre/post-
compression token counts.
"""
from __future__ import annotations
import csv
import json
import logging
import re
import sys
from os import getenv
from pathlib import Path
from typing import Annotated
import httpx
import typer
from tiktoken import Encoding, get_encoding
from python.prompt_bench.bill_token_compression import compress_bill_text
from python.prompt_bench.summarization_prompts import SUMMARIZATION_SYSTEM_PROMPT, SUMMARIZATION_USER_TEMPLATE
logger = logging.getLogger(__name__)
OPENAI_API_BASE = "https://api.openai.com/v1"
def load_bills(csv_path: Path, count: int = 0) -> list[tuple[str, str]]:
    """Collect (unique_id, text_content) pairs from the bills CSV.

    Rows whose `text_content` is empty or whitespace are skipped. The id
    is `<bill_id>-<version_code>` when a version code is present,
    otherwise the bill id alone (falling back to `id`, then to a
    positional `row-<n>` label). A `count` of 0 or less means all rows.
    """
    csv.field_size_limit(sys.maxsize)  # bill text exceeds the default field cap
    collected: list[tuple[str, str]] = []
    with csv_path.open(newline="", encoding="utf-8") as handle:
        for row in csv.DictReader(handle):
            body = (row.get("text_content") or "").strip()
            if not body:
                continue
            base_id = row.get("bill_id") or row.get("id") or f"row-{len(collected)}"
            version = row.get("version_code") or ""
            collected.append((f"{base_id}-{version}" if version else base_id, body))
            if 0 < count <= len(collected):
                break
    return collected
def safe_filename(value: str) -> str:
    """Sanitize `value` for use as a filename or batch custom_id.

    Runs of characters outside [A-Za-z0-9._-] collapse to one
    underscore; a value that strips down to nothing becomes "unnamed".
    """
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", value)
    return cleaned.strip("_") or "unnamed"
def build_request(custom_id: str, model: str, bill_text: str) -> dict:
    """Assemble a single Batch API JSONL line for one bill."""
    messages = [
        {"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT},
        {"role": "user", "content": SUMMARIZATION_USER_TEMPLATE.format(text_content=bill_text)},
    ]
    return {
        "custom_id": custom_id,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {"model": model, "messages": messages},
    }
def write_jsonl(path: Path, lines: list[dict]) -> None:
    """Serialize each dict in `lines` as one JSON object per line."""
    with path.open("w", encoding="utf-8") as handle:
        handle.writelines(json.dumps(entry, ensure_ascii=False) + "\n" for entry in lines)
def upload_file(client: httpx.Client, path: Path) -> str:
    """Upload a JSONL file to the OpenAI Files API and return its file id.

    The file is uploaded with purpose="batch" so it can back a Batch job.
    """
    with path.open("rb") as handle:
        response = client.post(
            f"{OPENAI_API_BASE}/files",
            data={"purpose": "batch"},
            files={"file": (path.name, handle, "application/jsonl")},
        )
    response.raise_for_status()
    return response.json()["id"]
def prepare_requests(
    bills: list[tuple[str, str]],
    *,
    model: str,
    encoder: Encoding,
) -> tuple[list[dict], list[dict]]:
    """Compress each bill and build batch request lines plus token stats.

    Returns (request_lines, token_rows): one Batch API request per bill
    over the compressed text, and one stats row per bill with char and
    token counts before/after compression. `token_ratio` is None when
    the raw text tokenizes to zero tokens.
    """
    request_lines: list[dict] = []
    token_rows: list[dict] = []
    for bill_id, raw_text in bills:
        compressed = compress_bill_text(raw_text)
        raw_tokens = len(encoder.encode(raw_text))
        compressed_tokens = len(encoder.encode(compressed))
        token_rows.append(
            {
                "bill_id": bill_id,
                "raw_chars": len(raw_text),
                "compressed_chars": len(compressed),
                "raw_tokens": raw_tokens,
                "compressed_tokens": compressed_tokens,
                "token_ratio": compressed_tokens / raw_tokens if raw_tokens else None,
            },
        )
        request_lines.append(build_request(safe_filename(bill_id), model, compressed))
    return request_lines, token_rows
def write_token_csv(path: Path, token_rows: list[dict]) -> tuple[int, int]:
    """Write one CSV row of token stats per bill.

    Returns the (raw, compressed) token totals summed over all rows.
    """
    columns = ["bill_id", "raw_chars", "compressed_chars", "raw_tokens", "compressed_tokens", "token_ratio"]
    with path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows(token_rows)
    return (
        sum(row["raw_tokens"] for row in token_rows),
        sum(row["compressed_tokens"] for row in token_rows),
    )
def create_batch(client: httpx.Client, input_file_id: str, description: str) -> dict:
    """Create a 24h-window batch job over the uploaded file.

    Returns the full batch payload from the API.
    """
    body = {
        "input_file_id": input_file_id,
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
        "metadata": {"description": description},
    }
    response = client.post(f"{OPENAI_API_BASE}/batches", json=body)
    response.raise_for_status()
    return response.json()
def main(
    csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path("bills.csv"),
    output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to write JSONL + metadata")] = Path(
        "output/openai_batch",
    ),
    model: Annotated[str, typer.Option(help="OpenAI model id")] = "gpt-5-mini",
    count: Annotated[int, typer.Option(help="Max bills to process, 0 = all")] = 0,
    log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
    """Submit an OpenAI Batch job of compressed bill summaries.

    Pipeline: load bills from `csv_path`, compress each, write a per-bill
    token-count CSV and a JSONL of batch requests under `output_dir`,
    upload the JSONL to the Files API, create the batch job, and persist
    the job metadata to `output_dir`/batch.json.

    Raises:
        typer.BadParameter: If no API key env var is set or the CSV is missing.
    """
    logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
    # CLOSEDAI_TOKEN takes precedence; OPENAI_API_KEY is the fallback.
    api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY")
    if not api_key:
        message = "Neither CLOSEDAI_TOKEN nor OPENAI_API_KEY is set"
        raise typer.BadParameter(message)
    if not csv_path.is_file():
        message = f"CSV not found: {csv_path}"
        raise typer.BadParameter(message)
    output_dir.mkdir(parents=True, exist_ok=True)
    logger.info("Loading %d bills from %s", count, csv_path)
    bills = load_bills(csv_path, count)
    if len(bills) < count:
        logger.warning("Only %d bills available (requested %d)", len(bills), count)
    # Token counts are computed with the o200k_base encoding.
    encoder = get_encoding("o200k_base")
    request_lines, token_rows = prepare_requests(bills, model=model, encoder=encoder)
    token_csv_path = output_dir / "token_counts.csv"
    raw_tokens_total, compressed_tokens_total = write_token_csv(token_csv_path, token_rows)
    logger.info(
        "Token counts: raw=%d compressed=%d ratio=%.3f -> %s",
        raw_tokens_total,
        compressed_tokens_total,
        (compressed_tokens_total / raw_tokens_total) if raw_tokens_total else 0.0,
        token_csv_path,
    )
    jsonl_path = output_dir / "requests.jsonl"
    write_jsonl(jsonl_path, request_lines)
    logger.info("Wrote %s (%d bills)", jsonl_path, len(request_lines))
    headers = {"Authorization": f"Bearer {api_key}"}
    with httpx.Client(headers=headers, timeout=httpx.Timeout(300.0)) as client:
        logger.info("Uploading JSONL")
        file_id = upload_file(client, jsonl_path)
        logger.info("Uploaded: %s", file_id)
        logger.info("Creating batch")
        batch = create_batch(client, file_id, f"compressed bill summaries x{len(request_lines)} ({model})")
        logger.info("Batch created: %s", batch["id"])
    # Persist everything needed to poll/retrieve the batch later.
    metadata = {
        "model": model,
        "count": len(bills),
        "jsonl": str(jsonl_path),
        "input_file_id": file_id,
        "batch_id": batch["id"],
        "raw_tokens_total": raw_tokens_total,
        "compressed_tokens_total": compressed_tokens_total,
        "batch": batch,
    }
    metadata_path = output_dir / "batch.json"
    metadata_path.write_text(json.dumps(metadata, indent=2))
    logger.info("Wrote metadata to %s", metadata_path)
def cli() -> None:
    """Typer entry point: run `main` as a single-command CLI."""
    typer.run(main)
if __name__ == "__main__":
    cli()

View File

@@ -0,0 +1,162 @@
"""Lossless-ish text compression for Congressional bill text."""
from __future__ import annotations
import re
# Canonical title-case names for the 50 states plus DC and territories.
STATES = (
    "Alabama",
    "Alaska",
    "Arizona",
    "Arkansas",
    "California",
    "Colorado",
    "Connecticut",
    "Delaware",
    "Florida",
    "Georgia",
    "Hawaii",
    "Idaho",
    "Illinois",
    "Indiana",
    "Iowa",
    "Kansas",
    "Kentucky",
    "Louisiana",
    "Maine",
    "Maryland",
    "Massachusetts",
    "Michigan",
    "Minnesota",
    "Mississippi",
    "Missouri",
    "Montana",
    "Nebraska",
    "Nevada",
    "New Hampshire",
    "New Jersey",
    "New Mexico",
    "New York",
    "North Carolina",
    "North Dakota",
    "Ohio",
    "Oklahoma",
    "Oregon",
    "Pennsylvania",
    "Rhode Island",
    "South Carolina",
    "South Dakota",
    "Tennessee",
    "Texas",
    "Utah",
    "Vermont",
    "Virginia",
    "Washington",
    "West Virginia",
    "Wisconsin",
    "Wyoming",
    "Puerto Rico",
    "Guam",
    "American Samoa",
    "District of Columbia",
    "US Virgin Islands",
)
# One pre-compiled case-insensitive pattern per name, paired with its
# canonical replacement; compiled once at import time.
STATE_PATTERNS = [(re.compile(re.escape(state), re.IGNORECASE), state) for state in STATES]
def normalize_state_names(text: str) -> str:
    """Rewrite every state/territory name to its canonical title-case form."""
    for compiled, canonical in STATE_PATTERNS:
        text = compiled.sub(canonical, text)
    return text
def strip_number_commas(text: str) -> str:
    """Drop thousands-separator commas, e.g. "1,234,567" -> "1234567"."""
    def _drop_commas(match: re.Match[str]) -> str:
        return match.group().replace(",", "")

    return re.sub(r"(\d{1,3}(?:,\d{3})+)", _drop_commas, text)
def strip_horizontal_rules(text: str) -> str:
    """Blank out lines made of three-plus underscores/dashes/equals/asterisks."""
    rule_line = re.compile(r"^\s*[_\-=\*]{3,}\s*$", re.MULTILINE)
    return rule_line.sub("", text)
def collapse_double_dashes(text: str) -> str:
    """Turn each ``--`` pair into a single space so it doesn't tokenize oddly."""
    return " ".join(text.split("--"))
def collapse_inline_whitespace(text: str) -> str:
    """Squeeze runs of horizontal whitespace to one space; newlines survive."""
    horizontal_ws = re.compile(r"[^\S\n]+")
    return horizontal_ws.sub(" ", text)
def collapse_blank_lines(text: str) -> str:
    """Reduce three-or-more consecutive newlines to exactly two."""
    newline_runs = re.compile(r"\n{3,}")
    return newline_runs.sub("\n\n", text)
def trim_line_edges(text: str) -> str:
    """Delete spaces sitting directly before or after a newline."""
    without_trailing = re.sub(r" +\n", "\n", text)
    return re.sub(r"\n +", "\n", without_trailing)
def shorten_section_markers(text: str) -> str:
    """Rewrite headings like ``Sec. 12.`` to the shorter ``SEC 12``."""
    heading = re.compile(r"sec\.\s*(\d+[a-zA-Z]?)\.", re.IGNORECASE)
    return heading.sub(r"SEC \1", text)
def unwrap_parens(text: str) -> str:
    """Drop parentheses around short alphanumeric labels like ``(a)`` or ``(12)``."""
    return re.sub(r"\(([a-zA-Z0-9]+)\)", lambda match: match.group(1), text)
def strip_typeset_quotes(text: str) -> str:
    """Delete the `` and '' typeset quote markers used in the GPO bill format."""
    for marker in ("``", "''"):
        text = text.replace(marker, "")
    return text
def normalize_usc_acronym(text: str) -> str:
    """Shorten every ``U.S.C.`` citation to ``USC``."""
    return "USC".join(text.split("U.S.C."))
def normalize_us_acronym(text: str) -> str:
    """Collapse the spelled-out ``U. S.``/``U.S.`` forms (any case) to ``US``.

    The ``U.S. `` forms require a trailing space, so ``U.S.C.`` (handled
    by a separate transform) is left untouched.
    """
    for spelled in ("U. S.", "u. s.", "U.S. ", "u.s. "):
        text = text.replace(spelled, "US ")
    return text
def collapse_ellipses(text: str) -> str:
    """Reduce any run of two-plus periods to a single period."""
    dot_runs = re.compile(r"\.{2,}")
    return dot_runs.sub(".", text)
# Ordered pipeline: layout cleanups first, then token-level
# normalizations; state-name normalization runs last over cleaned text.
COMPRESSION_STEPS = (
    strip_horizontal_rules,
    collapse_double_dashes,
    collapse_inline_whitespace,
    collapse_blank_lines,
    trim_line_edges,
    shorten_section_markers,
    unwrap_parens,
    strip_typeset_quotes,
    normalize_usc_acronym,
    normalize_us_acronym,
    strip_number_commas,
    collapse_ellipses,
    normalize_state_names,
)


def compress_bill_text(text: str) -> str:
    """Apply the full "lossless-ish" compression pipeline to bill text.

    Runs every transform in :data:`COMPRESSION_STEPS` in declaration
    order, then strips leading/trailing whitespace from the result.
    """
    result = text
    for transform in COMPRESSION_STEPS:
        result = transform(result)
    return result.strip()

View File

@@ -0,0 +1,236 @@
"""Run two interactive OpenAI chat-completion sweeps over bill text.
Reads the first N bills from a CSV with a `text_content` column and sends two
sweeps through `/v1/chat/completions` concurrently — one with the raw bill
text, one with the compressed bill text. Each request's prompt is saved to
disk alongside the OpenAI response id so the prompts and responses can be
correlated later.
"""
from __future__ import annotations
import csv
import json
import logging
import re
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from os import getenv
from pathlib import Path
from typing import Annotated
import httpx
import typer
from python.prompt_bench.bill_token_compression import compress_bill_text
from python.prompt_bench.summarization_prompts import SUMMARIZATION_SYSTEM_PROMPT, SUMMARIZATION_USER_TEMPLATE
logger = logging.getLogger(__name__)
OPENAI_API_BASE = "https://api.openai.com/v1"
DEFAULT_MODEL = "gpt-5.4-mini"
DEFAULT_COUNT = 100
SEED = 42
def load_bills(csv_path: Path, count: int) -> list[tuple[str, str]]:
    """Read up to `count` (unique_id, text_content) pairs from the CSV.

    Rows whose `text_content` is empty or whitespace are skipped. The id
    is `<bill_id>-<version_code>` when a version code exists, otherwise
    the bill id alone (falling back to `id`, then a positional label).
    """
    csv.field_size_limit(sys.maxsize)  # bill text exceeds the default field cap
    selected: list[tuple[str, str]] = []
    with csv_path.open(newline="", encoding="utf-8") as handle:
        for row in csv.DictReader(handle):
            body = (row.get("text_content") or "").strip()
            if not body:
                continue
            base_id = row.get("bill_id") or row.get("id") or f"row-{len(selected)}"
            version = row.get("version_code") or ""
            selected.append((f"{base_id}-{version}" if version else base_id, body))
            if len(selected) >= count:
                break
    return selected
def build_messages(bill_text: str) -> list[dict]:
    """Assemble the system + user chat message pair for one bill."""
    user_content = SUMMARIZATION_USER_TEMPLATE.format(text_content=bill_text)
    return [
        {"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]
def safe_filename(value: str) -> str:
    """Collapse characters unsafe for filenames into underscores.

    A value that strips down to nothing becomes "unnamed".
    """
    sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("_")
    return sanitized if sanitized else "unnamed"
def run_one_request(
    client: httpx.Client,
    *,
    bill_id: str,
    label: str,
    bill_text: str,
    model: str,
    output_path: Path,
) -> tuple[bool, float, str | None]:
    """Send one chat-completion request and persist prompt + response.

    A JSON record (prompt, model, seed, and either the response or the
    error) is written to `output_path` on every path — success and
    failure alike — so runs can always be audited afterwards.

    Returns (success, elapsed_seconds, response_id); response_id is None
    on failure.
    """
    messages = build_messages(bill_text)
    payload = {
        "model": model,
        "messages": messages,
        "seed": SEED,  # fixed seed so the two sweeps are comparable
    }
    start = time.monotonic()
    record: dict = {
        "bill_id": bill_id,
        "label": label,
        "model": model,
        "seed": SEED,
        "input_chars": len(bill_text),
        "messages": messages,
    }
    try:
        response = client.post(f"{OPENAI_API_BASE}/chat/completions", json=payload)
        response.raise_for_status()
        body = response.json()
    except httpx.HTTPStatusError as error:
        # Non-2xx: keep the status and body for post-mortem analysis.
        elapsed = time.monotonic() - start
        record["error"] = {
            "status_code": error.response.status_code,
            "body": error.response.text,
            "elapsed_seconds": elapsed,
        }
        output_path.write_text(json.dumps(record, ensure_ascii=False, indent=2))
        logger.exception("HTTP error for %s/%s after %.2fs", label, bill_id, elapsed)
        return False, elapsed, None
    except Exception as error:
        # Broad catch is deliberate: a worker thread must never bring the
        # sweep down; the failure is recorded and surfaced in the summary.
        elapsed = time.monotonic() - start
        record["error"] = {"message": str(error), "elapsed_seconds": elapsed}
        output_path.write_text(json.dumps(record, ensure_ascii=False, indent=2))
        logger.exception("Failed: %s/%s after %.2fs", label, bill_id, elapsed)
        return False, elapsed, None
    elapsed = time.monotonic() - start
    response_id = body.get("id")
    record["response_id"] = response_id
    record["elapsed_seconds"] = elapsed
    record["usage"] = body.get("usage")
    record["response"] = body
    output_path.write_text(json.dumps(record, ensure_ascii=False, indent=2))
    logger.info("Done: %s/%s id=%s in %.2fs", label, bill_id, response_id, elapsed)
    return True, elapsed, response_id
def main(
    csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path("bills.csv"),
    output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to write per-request JSON")] = Path(
        "output/openai_runs",
    ),
    model: Annotated[str, typer.Option(help="OpenAI model id")] = DEFAULT_MODEL,
    count: Annotated[int, typer.Option(help="Number of bills per set")] = DEFAULT_COUNT,
    concurrency: Annotated[int, typer.Option(help="Concurrent in-flight requests")] = 16,
    log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
    """Run two interactive OpenAI sweeps (compressed + uncompressed) over bill text.

    Every bill is submitted twice — once raw, once compressed — and each
    request's prompt/response is written to its own JSON file. A
    summary.json with per-request outcomes and timings is written last.

    Raises:
        typer.BadParameter: If no API key env var is set or the CSV is missing.
    """
    logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
    # CLOSEDAI_TOKEN takes precedence; OPENAI_API_KEY is the fallback.
    api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY")
    if not api_key:
        message = "Neither CLOSEDAI_TOKEN nor OPENAI_API_KEY is set"
        raise typer.BadParameter(message)
    if not csv_path.is_file():
        message = f"CSV not found: {csv_path}"
        raise typer.BadParameter(message)
    compressed_dir = output_dir / "compressed"
    uncompressed_dir = output_dir / "uncompressed"
    compressed_dir.mkdir(parents=True, exist_ok=True)
    uncompressed_dir.mkdir(parents=True, exist_ok=True)
    logger.info("Loading %d bills from %s", count, csv_path)
    bills = load_bills(csv_path, count)
    if len(bills) < count:
        logger.warning("Only %d bills available (requested %d)", len(bills), count)
    # Two tasks per bill: (bill_id, label, prompt text, output path).
    tasks: list[tuple[str, str, str, Path]] = []
    for bill_id, text_content in bills:
        filename = f"{safe_filename(bill_id)}.json"
        tasks.append((bill_id, "compressed", compress_bill_text(text_content), compressed_dir / filename))
        tasks.append((bill_id, "uncompressed", text_content, uncompressed_dir / filename))
    logger.info("Submitting %d requests at concurrency=%d", len(tasks), concurrency)
    headers = {"Authorization": f"Bearer {api_key}"}
    completed = 0
    failed = 0
    index: list[dict] = []
    wall_start = time.monotonic()
    # One shared httpx.Client across worker threads; results are
    # collected as each future completes.
    with (
        httpx.Client(headers=headers, timeout=httpx.Timeout(300.0)) as client,
        ThreadPoolExecutor(
            max_workers=concurrency,
        ) as executor,
    ):
        future_to_task = {
            executor.submit(
                run_one_request,
                client,
                bill_id=bill_id,
                label=label,
                bill_text=bill_text,
                model=model,
                output_path=output_path,
            ): (bill_id, label, output_path)
            for bill_id, label, bill_text, output_path in tasks
        }
        for future in as_completed(future_to_task):
            bill_id, label, output_path = future_to_task[future]
            success, elapsed, response_id = future.result()
            if success:
                completed += 1
            else:
                failed += 1
            index.append(
                {
                    "bill_id": bill_id,
                    "label": label,
                    "response_id": response_id,
                    "elapsed_seconds": elapsed,
                    "success": success,
                    "path": str(output_path),
                },
            )
    wall_elapsed = time.monotonic() - wall_start
    summary = {
        "model": model,
        "count": len(bills),
        "completed": completed,
        "failed": failed,
        "wall_seconds": wall_elapsed,
        "concurrency": concurrency,
        "results": index,
    }
    summary_path = output_dir / "summary.json"
    summary_path.write_text(json.dumps(summary, indent=2))
    logger.info(
        "Done: completed=%d failed=%d wall=%.1fs summary=%s",
        completed,
        failed,
        wall_elapsed,
        summary_path,
    )
def cli() -> None:
    """Typer entry point: run `main` as a single-command CLI."""
    typer.run(main)
if __name__ == "__main__":
    cli()

View File

@@ -0,0 +1 @@
"""Prompt benchmarking system for evaluating LLMs via vLLM."""

View File

@@ -0,0 +1,165 @@
"""Docker container lifecycle management for Unsloth fine-tuning."""
from __future__ import annotations
import logging
import subprocess
from pathlib import Path
from typing import Annotated
import typer
from python.prompt_bench.containers.lib import check_gpu_free
logger = logging.getLogger(__name__)
CONTAINER_NAME = "bill-finetune"
FINETUNE_IMAGE = "bill-finetune:latest"
DOCKERFILE_PATH = "/home/richie/dotfiles/python/prompt_bench/Dockerfile.finetune"
DEFAULT_HF_CACHE = Path("/zfs/models/hf")
def build_image() -> None:
    """Build the fine-tuning Docker image from DOCKERFILE_PATH.

    Raises:
        RuntimeError: If `docker build` exits non-zero.
    """
    logger.info("Building fine-tuning image: %s", FINETUNE_IMAGE)
    build_command = ["docker", "build", "-f", DOCKERFILE_PATH, "-t", FINETUNE_IMAGE, "."]
    completed = subprocess.run(build_command, text=True, check=False)
    if completed.returncode != 0:
        message = "Failed to build fine-tuning image"
        raise RuntimeError(message)
    logger.info("Image built: %s", FINETUNE_IMAGE)
def start_finetune(
    *,
    dataset_path: Path,
    output_dir: Path,
    hf_cache: Path = DEFAULT_HF_CACHE,
) -> None:
    """Run the fine-tuning container; blocks until the container exits.

    Args:
        dataset_path: Host path to the fine-tuning JSONL dataset.
        output_dir: Host path where the trained model will be saved.
        hf_cache: Host path to HuggingFace model cache (bind-mounted to avoid re-downloading).

    Raises:
        FileNotFoundError: If `dataset_path` does not exist.
        RuntimeError: If the container exits with a non-zero status.
    """
    dataset_path = dataset_path.resolve()
    output_dir = output_dir.resolve()
    if not dataset_path.is_file():
        message = f"Dataset not found: {dataset_path}"
        raise FileNotFoundError(message)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Remove any stale container still holding the name before starting.
    stop_finetune()
    hf_cache = hf_cache.resolve()
    hf_cache.mkdir(parents=True, exist_ok=True)
    command = [
        "docker",
        "run",
        "--name",
        CONTAINER_NAME,
        "--device=nvidia.com/gpu=all",
        "--ipc=host",
        "-v",
        f"{hf_cache}:/root/.cache/huggingface",
        "-v",
        f"{output_dir}:/workspace/output/qwen-bill-summarizer",
        "-v",
        f"{dataset_path}:/workspace/dataset.jsonl:ro",
        FINETUNE_IMAGE,
        "--dataset",
        "/workspace/dataset.jsonl",
        "--output-dir",
        "/workspace/output/qwen-bill-summarizer",
    ]
    logger.info("Starting fine-tuning container")
    logger.info(" Dataset: %s", dataset_path)
    logger.info(" Output: %s", output_dir)
    result = subprocess.run(command, text=True, check=False)
    if result.returncode != 0:
        message = f"Fine-tuning container exited with code {result.returncode}"
        raise RuntimeError(message)
    logger.info("Fine-tuning complete. Model saved to %s", output_dir)
def stop_finetune() -> None:
    """Best-effort stop and removal of the fine-tuning container."""
    logger.info("Stopping fine-tuning container")
    cleanup = (
        ["docker", "stop", CONTAINER_NAME],
        ["docker", "rm", "-f", CONTAINER_NAME],
    )
    for command in cleanup:
        subprocess.run(command, capture_output=True, check=False)
def logs_finetune() -> str | None:
    """Return the last 50 log lines of the container, or None when unavailable."""
    completed = subprocess.run(
        ["docker", "logs", "--tail", "50", CONTAINER_NAME],
        capture_output=True,
        text=True,
        check=False,
    )
    if completed.returncode == 0:
        return completed.stdout + completed.stderr
    return None
# Typer application exposing the build/run/stop/logs subcommands.
app = typer.Typer(help="Fine-tuning container management.")
@app.command()
def build() -> None:
    """Build the fine-tuning Docker image."""
    build_image()
@app.command()
def run(
    dataset: Annotated[Path, typer.Option(help="Fine-tuning JSONL")] = Path(
        "/home/richie/dotfiles/data/finetune_dataset.jsonl"
    ),
    output_dir: Annotated[Path, typer.Option(help="Where to save the trained model")] = Path(
        "/home/richie/dotfiles/data/output/qwen-bill-summarizer",
    ),
    hf_cache: Annotated[Path, typer.Option(help="Host path to HuggingFace model cache")] = DEFAULT_HF_CACHE,
    log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
    """Run fine-tuning inside a Docker container (blocks until done)."""
    logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
    # Warn (without failing) if something else already occupies the GPU.
    check_gpu_free()
    start_finetune(
        dataset_path=dataset,
        output_dir=output_dir,
        hf_cache=hf_cache,
    )
@app.command()
def stop() -> None:
    """Stop and remove the fine-tuning container."""
    stop_finetune()
@app.command()
def logs() -> None:
    """Show recent logs from the fine-tuning container."""
    output = logs_finetune()
    if output is None:
        typer.echo("No running fine-tuning container found.")
        raise typer.Exit(code=1)
    typer.echo(output)
def cli() -> None:
    """Typer entry point."""
    app()
if __name__ == "__main__":
    cli()

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
import logging
import subprocess
logger = logging.getLogger(__name__)
def check_gpu_free() -> None:
    """Log a warning when compute processes are already on the GPU.

    Never raises: a failed `nvidia-smi` query is itself only a warning.
    """
    query = subprocess.run(
        ["nvidia-smi", "--query-compute-apps=pid,process_name", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=False,
    )
    if query.returncode != 0:
        logger.warning("Could not query GPU processes: %s", query.stderr.strip())
        return
    active = query.stdout.strip()
    if not active:
        return
    logger.warning("GPU processes detected:\n%s", active)
    logger.warning("Consider stopping Ollama (sudo systemctl stop ollama) before benchmarking")

View File

@@ -0,0 +1,70 @@
"""Docker container lifecycle management for vLLM."""
from __future__ import annotations
import logging
import subprocess
logger = logging.getLogger(__name__)
CONTAINER_NAME = "vllm-bench"
VLLM_IMAGE = "vllm/vllm-openai:v0.19.0"
def start_vllm(
    *,
    model: str,
    port: int,
    model_dir: str,
    gpu_memory_utilization: float,
) -> None:
    """Launch a detached vLLM container serving `model`.

    Args:
        model: HuggingFace model directory name (relative to model_dir).
        port: Host port bound to the container's port 8000.
        model_dir: Host path containing HuggingFace model directories.
        gpu_memory_utilization: Fraction of GPU memory to use (0-1).

    Raises:
        RuntimeError: If `docker run` exits non-zero.
    """
    docker_args = [
        "docker",
        "run",
        "-d",
        "--name",
        CONTAINER_NAME,
        "--device=nvidia.com/gpu=all",
        "--ipc=host",
        "-v",
        f"{model_dir}:/models",
        "-p",
        f"{port}:8000",
        VLLM_IMAGE,
    ]
    vllm_args = [
        "--model",
        f"/models/{model}",
        "--served-model-name",
        model,
        "--gpu-memory-utilization",
        str(gpu_memory_utilization),
        "--max-model-len",
        "4096",
    ]
    logger.info("Starting vLLM container with model: %s", model)
    stop_vllm()  # clear any stale container holding the name/port
    result = subprocess.run(docker_args + vllm_args, capture_output=True, text=True, check=False)
    if result.returncode != 0:
        msg = f"Failed to start vLLM container: {result.stderr.strip()}"
        raise RuntimeError(msg)
    logger.info("vLLM container started: %s", result.stdout.strip()[:12])
def stop_vllm() -> None:
    """Best-effort teardown of the vLLM benchmark container."""
    logger.info("Stopping vLLM container")
    cleanup_commands = (
        ["docker", "stop", CONTAINER_NAME],
        ["docker", "rm", "-f", CONTAINER_NAME],
        # Detach from the bridge network too, in case a half-removed
        # container keeps the endpoint reserved.
        ["docker", "network", "disconnect", "-f", "bridge", CONTAINER_NAME],
    )
    for command in cleanup_commands:
        subprocess.run(command, capture_output=True, check=False)
    logger.info("vLLM container stopped and removed")

View File

@@ -0,0 +1,75 @@
"""HuggingFace model downloader."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Annotated
import typer
from huggingface_hub import snapshot_download
from python.prompt_bench.models import BenchmarkConfig
logger = logging.getLogger(__name__)
def local_model_path(repo: str, model_dir: str) -> Path:
    """Return the local directory path for a HuggingFace repo."""
    return Path(model_dir) / repo


def is_model_present(repo: str, model_dir: str) -> bool:
    """Check if a model has already been downloaded.

    True only when the model path is a directory containing at least one
    entry. Using `is_dir()` (rather than `exists()`) guards against a
    stray file sitting at the model path, which would previously make
    `iterdir()` raise NotADirectoryError.
    """
    path = local_model_path(repo, model_dir)
    return path.is_dir() and any(path.iterdir())
def download_model(repo: str, model_dir: str) -> Path:
    """Download a HuggingFace model into the local model directory.

    The download is skipped when the target directory already exists and
    contains files.
    """
    target = local_model_path(repo, model_dir)
    if is_model_present(repo, model_dir):
        logger.info("Model already exists: %s", target)
        return target
    logger.info("Downloading model: %s -> %s", repo, target)
    snapshot_download(repo_id=repo, local_dir=str(target))
    logger.info("Download complete: %s", repo)
    return target
def download_all(config: BenchmarkConfig) -> None:
    """Download every model listed in the config, in listed order."""
    for repository in config.models:
        download_model(repository, config.model_dir)
def main(
    config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"),
    log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
    """Download all models listed in the benchmark config."""
    logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
    # Fail fast with a CLI-friendly error when the config is missing.
    if not config.is_file():
        msg = f"Config file does not exist: {config}"
        raise typer.BadParameter(msg)
    download_all(BenchmarkConfig.from_toml(config))
def cli() -> None:
    """Typer entry point: wraps main() for use as a console script."""
    typer.run(main)
if __name__ == "__main__":
cli()

View File

@@ -0,0 +1,214 @@
"""Fine-tune Qwen 3.5 4B on bill summarization data using Unsloth.
Loads a ChatML-style JSONL dataset (system/user/assistant messages),
applies QLoRA with 4-bit quantization, and saves the merged model
in HuggingFace format. Designed for a single RTX 3090 (24GB).
Usage:
python -m python.prompt_bench.finetune \
--dataset output/finetune_dataset.jsonl \
--output-dir output/qwen-bill-summarizer
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated
import tomllib
import typer
from unsloth import FastLanguageModel
from datasets import Dataset
from transformers import TrainingArguments
from trl import SFTTrainer
logger = logging.getLogger(__name__)
@dataclass
class LoraConfig:
    """LoRA adapter hyperparameters."""

    # Adapter rank (dimension of the low-rank update matrices).
    rank: int
    # LoRA alpha (scaling) hyperparameter.
    alpha: int
    # Dropout probability applied to the LoRA layers.
    dropout: float
    # Module names to attach adapters to (passed as target_modules).
    targets: list[str]
@dataclass
class TrainingConfig:
    """Training loop hyperparameters."""

    learning_rate: float  # Peak learning rate.
    epochs: int  # Number of passes over the training set.
    batch_size: int  # Per-device train batch size.
    gradient_accumulation: int  # Steps accumulated before each optimizer update.
    max_seq_length: int  # Maximum token length for training sequences.
    warmup_ratio: float  # Fraction of total steps used for LR warmup.
    weight_decay: float  # Optimizer weight decay.
    logging_steps: int  # Log metrics every N steps.
    save_steps: int  # Checkpoint (and evaluate) every N steps.
@dataclass
class FinetuneConfig:
    """Top-level finetune configuration."""

    base_model: str
    lora: LoraConfig
    training: TrainingConfig

    @classmethod
    def from_toml(cls, config_path: Path) -> FinetuneConfig:
        """Load finetune config from the [finetune] table of a TOML file."""
        document = tomllib.loads(config_path.read_text())
        section = document["finetune"]
        lora_config = LoraConfig(**section["lora"])
        training_config = TrainingConfig(**section["training"])
        return cls(
            base_model=section["base_model"],
            lora=lora_config,
            training=training_config,
        )
def _messages_to_chatml(messages: list[dict]) -> str:
    r"""Convert a message list to Qwen ChatML format.

    Produces:
        <|im_start|>system\n...\n<|im_end|>
        <|im_start|>user\n...\n<|im_end|>
        <|im_start|>assistant\n...\n<|im_end|>
    """
    return "\n".join(
        f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>"
        for message in messages
    )
def load_dataset_from_jsonl(path: Path) -> Dataset:
    """Load a ChatML JSONL file into a HuggingFace Dataset.

    Each line must have {"messages": [{"role": ..., "content": ...}, ...]}.
    Pre-formats into a `text` column with the Qwen ChatML template applied,
    which SFTTrainer consumes directly.
    """
    examples = []
    with path.open(encoding="utf-8") as handle:
        for line in handle:
            payload = line.strip()
            if not payload:
                continue  # Tolerate blank lines in the JSONL.
            parsed = json.loads(payload)
            examples.append({"text": _messages_to_chatml(parsed["messages"])})
    logger.info("Loaded %d examples from %s", len(examples), path)
    return Dataset.from_list(examples)
def main(
    dataset_path: Annotated[Path, typer.Option("--dataset", help="Fine-tuning JSONL")] = Path(
        "output/finetune_dataset.jsonl",
    ),
    validation_split: Annotated[float, typer.Option("--val-split", help="Fraction held out for validation")] = 0.1,
    output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to save the merged model")] = Path(
        "output/qwen-bill-summarizer",
    ),
    config_path: Annotated[
        Path,
        typer.Option("--config", help="TOML config file"),
    ] = Path(__file__).parent / "config.toml",
    save_gguf: Annotated[bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")] = False,
) -> None:
    """Fine-tune Qwen 3.5 4B on bill summarization with Unsloth + QLoRA."""
    logging.basicConfig(level="INFO", format="%(asctime)s %(levelname)s %(name)s: %(message)s")
    if not dataset_path.is_file():
        message = f"Dataset not found: {dataset_path}"
        raise typer.BadParameter(message)
    config = FinetuneConfig.from_toml(config_path)
    # Load the base model 4-bit quantized (QLoRA).
    # dtype=None defers the dtype choice to the library — TODO confirm.
    logger.info("Loading base model: %s", config.base_model)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=config.base_model,
        max_seq_length=config.training.max_seq_length,
        load_in_4bit=True,
        dtype=None,
    )
    logger.info("Applying LoRA (rank=%d, alpha=%d)", config.lora.rank, config.lora.alpha)
    model = FastLanguageModel.get_peft_model(
        model,
        r=config.lora.rank,
        lora_alpha=config.lora.alpha,
        lora_dropout=config.lora.dropout,
        target_modules=config.lora.targets,
        bias="none",
        # Unsloth's gradient-checkpointing variant, per its docs.
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )
    full_dataset = load_dataset_from_jsonl(dataset_path)
    # Fixed seed keeps the train/validation split reproducible across runs.
    split = full_dataset.train_test_split(test_size=validation_split, seed=42)
    train_dataset = split["train"]
    validation_dataset = split["test"]
    logger.info("Split: %d train, %d validation", len(train_dataset), len(validation_dataset))
    training_args = TrainingArguments(
        output_dir=str(output_dir / "checkpoints"),
        num_train_epochs=config.training.epochs,
        per_device_train_batch_size=config.training.batch_size,
        gradient_accumulation_steps=config.training.gradient_accumulation,
        learning_rate=config.training.learning_rate,
        warmup_ratio=config.training.warmup_ratio,
        weight_decay=config.training.weight_decay,
        lr_scheduler_type="cosine",
        logging_steps=config.training.logging_steps,
        save_steps=config.training.save_steps,
        save_total_limit=3,
        # Evaluate on the same cadence as checkpointing so
        # load_best_model_at_end can compare checkpoints.
        eval_strategy="steps",
        eval_steps=config.training.save_steps,
        load_best_model_at_end=True,
        bf16=True,
        optim="adamw_8bit",
        seed=42,
        report_to="none",
    )
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        args=training_args,
        max_seq_length=config.training.max_seq_length,
        # Packing concatenates short examples into full-length sequences.
        packing=True,
    )
    logger.info(
        "Starting training: %d train, %d val, %d epochs",
        len(train_dataset),
        len(validation_dataset),
        config.training.epochs,
    )
    trainer.train()
    # Merge the LoRA weights into the base model and save in HF format.
    merged_path = str(output_dir / "merged")
    logger.info("Saving merged model to %s", merged_path)
    model.save_pretrained_merged(merged_path, tokenizer, save_method="merged_16bit")
    if save_gguf:
        gguf_path = str(output_dir / "gguf")
        logger.info("Saving GGUF to %s", gguf_path)
        model.save_pretrained_gguf(gguf_path, tokenizer, quantization_method="q4_k_m")
    logger.info("Done! Model saved to %s", output_dir)
def cli() -> None:
    """Typer entry point: wraps main() for use as a console script."""
    typer.run(main)
if __name__ == "__main__":
cli()

215
python/prompt_bench/main.py Normal file
View File

@@ -0,0 +1,215 @@
"""CLI entry point for the prompt benchmarking system."""
from __future__ import annotations
import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Annotated
import typer
from python.prompt_bench.containers.lib import check_gpu_free
from python.prompt_bench.containers.vllm import start_vllm, stop_vllm
from python.prompt_bench.downloader import is_model_present
from python.prompt_bench.models import BenchmarkConfig
from python.prompt_bench.vllm_client import VLLMClient
logger = logging.getLogger(__name__)
def discover_prompts(input_dir: Path) -> list[Path]:
    """Find all .txt files in the input directory.

    Raises:
        FileNotFoundError: If the directory contains no .txt files.
    """
    found = list(input_dir.glob("*.txt"))
    if found:
        return found
    message = f"No .txt files found in {input_dir}"
    raise FileNotFoundError(message)
def _run_prompt(
    client: VLLMClient,
    prompt_path: Path,
    *,
    repo: str,
    model_dir_name: str,
    model_output: Path,
    temperature: float,
) -> tuple[bool, float]:
    """Run a single prompt against the model and persist the result.

    Args:
        client: Connected vLLM client.
        prompt_path: Input .txt prompt file.
        repo: HuggingFace repo id (used for logging).
        model_dir_name: Model name as served by vLLM.
        model_output: Directory receiving the output (or .error) file.
        temperature: Sampling temperature.

    Returns:
        (success, elapsed_seconds).
    """
    filename = prompt_path.name
    output_path = model_output / filename
    start = time.monotonic()
    # Broad except is deliberate: one failing prompt must not abort the
    # rest of the batch; the failure is logged and recorded on disk.
    try:
        prompt_text = prompt_path.read_text()
        response = client.complete(prompt_text, model_dir_name, temperature=temperature)
        output_path.write_text(response)
        elapsed = time.monotonic() - start
        logger.info("Completed: %s / %s in %.2fs", repo, filename, elapsed)
    except Exception:
        elapsed = time.monotonic() - start
        # Fix: the error marker previously used a hard-coded "(unknown)"
        # name (f-strings with no placeholders), so failures from
        # different prompts collided in one file and lost the prompt name.
        error_path = model_output / f"{filename}.error"
        logger.exception("Failed: %s / %s after %.2fs", repo, filename, elapsed)
        error_path.write_text(f"Error processing {filename}")
        return False, elapsed
    return True, elapsed
def benchmark_model(
    client: VLLMClient,
    prompts: list[Path],
    *,
    repo: str,
    model_dir_name: str,
    model_output: Path,
    temperature: float,
    concurrency: int,
) -> tuple[int, int]:
    """Run all prompts against a single model in parallel.

    vLLM batches concurrent requests internally, so submitting many at once is
    significantly faster than running them serially.

    Args:
        client: Connected vLLM client.
        prompts: Candidate prompt files; ones with existing output are skipped.
        repo: HuggingFace repo id (for logging and the timing record).
        model_dir_name: Model name as served by vLLM.
        model_output: Directory for per-prompt outputs and _timing.json.
        temperature: Sampling temperature.
        concurrency: Maximum in-flight requests.

    Returns:
        (completed, failed) counts for this model.
    """
    # Resume support: skip prompts whose output file already exists.
    pending = [prompt for prompt in prompts if not (model_output / prompt.name).exists()]
    skipped = len(prompts) - len(pending)
    if skipped:
        logger.info("Skipping %d prompts with existing output for %s", skipped, repo)
    if not pending:
        logger.info("Nothing to do for %s", repo)
        return 0, 0
    completed = 0
    failed = 0
    latencies: list[float] = []
    wall_start = time.monotonic()
    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        futures = [
            executor.submit(
                _run_prompt,
                client,
                prompt_path,
                repo=repo,
                model_dir_name=model_dir_name,
                model_output=model_output,
                temperature=temperature,
            )
            for prompt_path in pending
        ]
        # Tally results as they finish, in completion order.
        for future in as_completed(futures):
            success, elapsed = future.result()
            latencies.append(elapsed)
            if success:
                completed += 1
            else:
                failed += 1
    wall_elapsed = time.monotonic() - wall_start
    attempted = completed + failed
    # attempted >= 1 here: the empty-pending case returned early above.
    avg_latency = sum(latencies) / attempted
    throughput = attempted / wall_elapsed if wall_elapsed > 0 else 0.0
    # Persist a machine-readable timing record alongside the outputs.
    timing = {
        "repo": repo,
        "wall_seconds": wall_elapsed,
        "attempted": attempted,
        "completed": completed,
        "failed": failed,
        "avg_latency_seconds": avg_latency,
        "throughput_prompts_per_second": throughput,
        "concurrency": concurrency,
    }
    timing_path = model_output / "_timing.json"
    timing_path.write_text(json.dumps(timing, indent=2))
    return completed, failed
def run_benchmark(
    config: BenchmarkConfig,
    input_dir: Path,
    output_dir: Path,
) -> None:
    """Execute the benchmark across all models and prompts.

    For each downloaded model: start a fresh vLLM container, wait for it
    to become ready, run every prompt, then tear the container down.

    Args:
        config: Benchmark configuration (models, port, limits).
        input_dir: Directory of .txt prompt files.
        output_dir: Root output directory; one subdirectory per model repo.

    Raises:
        FileNotFoundError: If input_dir contains no .txt files.
    """
    prompts = discover_prompts(input_dir)
    logger.info("Found %d prompts in %s", len(prompts), input_dir)
    # Fail fast before starting containers if the GPU is busy.
    check_gpu_free()
    total_completed = 0
    total_failed = 0
    for repo in config.models:
        if not is_model_present(repo, config.model_dir):
            logger.warning("Skipping (not downloaded): %s", repo)
            continue
        model_output = output_dir / repo
        model_output.mkdir(parents=True, exist_ok=True)
        logger.info("=== Benchmarking model: %s ===", repo)
        # Remove any stale container from a previous (possibly crashed) run.
        stop_vllm()
        try:
            start_vllm(
                model=repo,
                port=config.port,
                model_dir=config.model_dir,
                gpu_memory_utilization=config.gpu_memory_utilization,
            )
        except RuntimeError:
            # Startup failure skips this model but keeps the run going.
            logger.exception("Failed to start vLLM for %s, skipping", repo)
            continue
        logger.info("vLLM started for %s", repo)
        try:
            with VLLMClient(port=config.port, timeout=config.timeout) as client:
                client.wait_ready(max_wait=config.vllm_startup_timeout)
                completed, failed = benchmark_model(
                    client,
                    prompts,
                    repo=repo,
                    model_dir_name=repo,
                    model_output=model_output,
                    temperature=config.temperature,
                    concurrency=config.concurrency,
                )
                total_completed += completed
                total_failed += failed
        finally:
            # Always tear down the container, even on errors/timeouts.
            stop_vllm()
    logger.info("=== Benchmark complete ===")
    logger.info("Completed: %d | Failed: %d", total_completed, total_failed)
def main(
    input_dir: Annotated[Path, typer.Argument(help="Directory containing input .txt prompt files")],
    config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"),
    output_dir: Annotated[Path, typer.Option(help="Output directory for results")] = Path("output"),
    log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
    """Run prompts through multiple LLMs via vLLM and save results."""
    logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
    # Validate CLI inputs before doing any work.
    if not input_dir.is_dir():
        raise typer.BadParameter(f"Input directory does not exist: {input_dir}")
    if not config.is_file():
        raise typer.BadParameter(f"Config file does not exist: {config}")
    benchmark_config = BenchmarkConfig.from_toml(config)
    output_dir.mkdir(parents=True, exist_ok=True)
    run_benchmark(benchmark_config, input_dir, output_dir)
def cli() -> None:
    """Typer entry point: wraps main() for use as a console script."""
    typer.run(main)
if __name__ == "__main__":
cli()

View File

@@ -0,0 +1,30 @@
"""Pydantic models for benchmark configuration."""
from __future__ import annotations
import tomllib
from typing import TYPE_CHECKING
from pydantic import BaseModel
if TYPE_CHECKING:
from pathlib import Path
class BenchmarkConfig(BaseModel):
    """Top-level benchmark configuration loaded from TOML."""

    # HuggingFace repo ids to benchmark, in run order.
    models: list[str]
    # Host directory containing downloaded model directories.
    model_dir: str = "/zfs/models/hf"
    # Host port the vLLM server is exposed on.
    port: int = 8000
    # Fraction of GPU memory vLLM may use (0-1).
    gpu_memory_utilization: float = 0.90
    # Sampling temperature for completions.
    temperature: float = 0.0
    # Per-request timeout in seconds.
    timeout: int = 300
    # Number of prompts submitted to vLLM concurrently.
    concurrency: int = 4
    # Seconds to wait for the vLLM server to become ready.
    vllm_startup_timeout: int = 900

    @classmethod
    def from_toml(cls, config_path: Path) -> BenchmarkConfig:
        """Load benchmark config from the [bench] table of a TOML file."""
        raw = tomllib.loads(config_path.read_text())["bench"]
        return cls(**raw)

View File

@@ -0,0 +1,34 @@
# System prompt sent with each bill-summarization request. Runtime text:
# any wording change alters model behavior.
SUMMARIZATION_SYSTEM_PROMPT = """You are a legislative analyst extracting policy substance from Congressional bill text.
Your job is to compress a bill into a dense, neutral structured summary that captures every distinct policy action — including secondary effects that might be buried in subsections.
EXTRACTION RULES:
- IGNORE: whereas clauses, congressional findings that are purely political statements, recitals, preambles, citations of existing law by number alone, and procedural boilerplate.
- FOCUS ON: operative verbs — what the bill SHALL do, PROHIBIT, REQUIRE, AUTHORIZE, AMEND, APPROPRIATE, or ESTABLISH.
- SURFACE ALL THREADS: If the bill touches multiple policy areas, list each thread separately. Do not collapse them.
- BE CONCRETE: Name the affected population, the mechanism, and the direction (expands/restricts/maintains).
- STAY NEUTRAL: No political framing. Describe what the text does, not what its sponsors claim it does.
OUTPUT FORMAT — plain structured text, not JSON:
OPERATIVE ACTIONS:
[Numbered list of what the bill actually does, one action per line, max 20 words each]
AFFECTED POPULATIONS:
[Who gains something, who loses something, or whose behavior is regulated]
MECHANISMS:
[How it works: new funding, mandate, prohibition, amendment to existing statute, grant program, study commission, etc.]
POLICY THREADS:
[List each distinct policy domain this bill touches, even minor ones. Use plain language, not domain codes.]
SYMBOLIC/PROCEDURAL ONLY:
[Yes or No — is this bill primarily a resolution, designation, or awareness declaration with no operative effect?]
LENGTH TARGET: 150-250 words total. Be ruthless about cutting. Density over completeness."""
# User-message template; format() with text_content=<bill text>.
SUMMARIZATION_USER_TEMPLATE = """Summarize the following Congressional bill according to your instructions.
BILL TEXT:
{text_content}"""

View File

@@ -0,0 +1,114 @@
"""Build a fine-tuning JSONL dataset from batch request + output files.
Joins the original request JSONL (system + user messages) with the batch
output JSONL (assistant completions) by custom_id to produce a ChatML-style
messages JSONL suitable for fine-tuning.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Annotated
import typer
logger = logging.getLogger(__name__)
HTTP_OK = 200
def load_requests(path: Path) -> dict[str, list[dict]]:
    """Parse request JSONL into {custom_id: messages}.

    Blank lines are tolerated and skipped.
    """
    with path.open(encoding="utf-8") as handle:
        return {
            record["custom_id"]: record["body"]["messages"]
            for line in handle
            if line.strip()
            for record in (json.loads(line),)
        }
def load_completions(path: Path) -> dict[str, str]:
    """Parse batch output JSONL into {custom_id: assistant_content}.

    Records with a non-200 status, no choices, or empty content are
    logged with their line number and skipped.
    """
    results: dict[str, str] = {}
    with path.open(encoding="utf-8") as handle:
        for line_number, raw_line in enumerate(handle, 1):
            payload = raw_line.strip()
            if not payload:
                continue
            record = json.loads(payload)
            custom_id = record["custom_id"]
            response = record.get("response", {})
            status = response.get("status_code")
            if status != HTTP_OK:
                logger.warning("Skipping %s (line %d): status %s", custom_id, line_number, status)
                continue
            choices = response.get("body", {}).get("choices", [])
            if not choices:
                logger.warning("Skipping %s (line %d): no choices", custom_id, line_number)
                continue
            content = choices[0].get("message", {}).get("content", "")
            if not content:
                logger.warning("Skipping %s (line %d): empty content", custom_id, line_number)
                continue
            results[custom_id] = content
    return results
def main(
    requests_path: Annotated[Path, typer.Option("--requests", help="Batch request JSONL")] = Path(
        "output/openai_batch/requests.jsonl",
    ),
    batch_output: Annotated[Path, typer.Option("--batch-output", help="Batch output JSONL")] = Path(
        "batch_69d84558d91c819091d53f08d78f9fd6_output.jsonl",
    ),
    output_path: Annotated[Path, typer.Option("--output", help="Fine-tuning JSONL output")] = Path(
        "output/finetune_dataset.jsonl",
    ),
    log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
    """Build fine-tuning dataset by joining request and output JSONL files.

    Joins on custom_id: each request's messages get the batch completion
    appended as the assistant turn. Requests without a matching
    (successful) completion are counted and skipped.
    """
    logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
    logger.info("Loading requests from %s", requests_path)
    requests = load_requests(requests_path)
    logger.info("Loaded %d requests", len(requests))
    logger.info("Loading completions from %s", batch_output)
    completions = load_completions(batch_output)
    logger.info("Loaded %d completions", len(completions))
    output_path.parent.mkdir(parents=True, exist_ok=True)
    matched = 0
    skipped = 0
    with output_path.open("w", encoding="utf-8") as handle:
        for custom_id, messages in requests.items():
            assistant_content = completions.get(custom_id)
            if assistant_content is None:
                # No successful completion for this request; drop it.
                skipped += 1
                continue
            example = {
                "messages": [*messages, {"role": "assistant", "content": assistant_content}],
            }
            # ensure_ascii=False keeps non-ASCII text readable in the output.
            handle.write(json.dumps(example, ensure_ascii=False))
            handle.write("\n")
            matched += 1
    logger.info("Wrote %d examples to %s (skipped %d unmatched)", matched, output_path, skipped)
def cli() -> None:
    """Typer entry point: wraps main() for use as a console script."""
    typer.run(main)
if __name__ == "__main__":
cli()

View File

@@ -0,0 +1,97 @@
"""Sum token usage across compressed and uncompressed run directories."""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Annotated
import typer
logger = logging.getLogger(__name__)
@dataclass
class UsageTotals:
    """Aggregate usage counters for a directory of run records."""

    # Number of JSON files read.
    files: int = 0
    # Files missing a "usage" object.
    errors: int = 0
    prompt_tokens: int = 0
    cached_tokens: int = 0
    completion_tokens: int = 0
    reasoning_tokens: int = 0
    total_tokens: int = 0
    # Per-file rows: (filename, prompt_tokens, completion_tokens, total_tokens).
    per_file: list[tuple[str, int, int, int]] = field(default_factory=list)
def tally_directory(directory: Path) -> UsageTotals:
    """Return aggregated usage stats for every JSON record in a directory."""
    totals = UsageTotals()
    decoder = json.JSONDecoder()
    for record_path in sorted(directory.glob("*.json")):
        # raw_decode parses the first JSON object and ignores any
        # trailing data after it.
        record, _ = decoder.raw_decode(record_path.read_text().lstrip())
        totals.files += 1
        usage = record.get("usage")
        if not usage:
            totals.errors += 1
            continue
        prompt = usage.get("prompt_tokens", 0)
        completion = usage.get("completion_tokens", 0)
        total = usage.get("total_tokens", 0)
        # The *_details sub-objects may be absent or null.
        cached = (usage.get("prompt_tokens_details") or {}).get("cached_tokens", 0)
        reasoning = (usage.get("completion_tokens_details") or {}).get("reasoning_tokens", 0)
        totals.prompt_tokens += prompt
        totals.completion_tokens += completion
        totals.total_tokens += total
        totals.cached_tokens += cached
        totals.reasoning_tokens += reasoning
        totals.per_file.append((record_path.name, prompt, completion, total))
    return totals
def log_totals(label: str, totals: UsageTotals) -> None:
    """Log a one-block summary for a directory.

    Args:
        label: Section heading (e.g. "compressed").
        totals: Aggregated counters to report.
    """
    # Only files that actually carried usage data count toward the average.
    counted = totals.files - totals.errors
    average_total = totals.total_tokens / counted if counted else 0
    logger.info("[%s]", label)
    logger.info(" files : %d (with usage: %d, errors: %d)", totals.files, counted, totals.errors)
    logger.info(" prompt tokens : %d", totals.prompt_tokens)
    logger.info(" cached tokens : %d", totals.cached_tokens)
    logger.info(" completion tok : %d", totals.completion_tokens)
    logger.info(" reasoning tok : %d", totals.reasoning_tokens)
    logger.info(" total tokens : %d", totals.total_tokens)
    logger.info(" avg total/file : %.1f", average_total)
def main(
    runs_dir: Annotated[Path, typer.Option("--runs-dir")] = Path("output/openai_runs_temp_1"),
    log_level: Annotated[str, typer.Option("--log-level")] = "INFO",
) -> None:
    """Print token usage totals for the compressed and uncompressed run directories."""
    logging.basicConfig(level=log_level, format="%(message)s")
    grand = UsageTotals()
    counter_fields = (
        "files",
        "errors",
        "prompt_tokens",
        "cached_tokens",
        "completion_tokens",
        "reasoning_tokens",
        "total_tokens",
    )
    for label in ("compressed", "uncompressed"):
        directory = runs_dir / label
        if not directory.is_dir():
            logger.warning("%s: directory not found at %s", label, directory)
            continue
        totals = tally_directory(directory)
        log_totals(label, totals)
        # Fold this directory's counters into the grand total
        # (per_file is intentionally not aggregated).
        for name in counter_fields:
            setattr(grand, name, getattr(grand, name) + getattr(totals, name))
    log_totals("grand total", grand)
if __name__ == "__main__":
typer.run(main)

View File

@@ -0,0 +1,68 @@
"""OpenAI-compatible client for vLLM's API."""
from __future__ import annotations
import logging
import time
from typing import Self
import httpx
logger = logging.getLogger(__name__)
READY_POLL_INTERVAL = 2.0
class VLLMClient:
    """Talk to a vLLM server via its OpenAI-compatible API.

    Args:
        host: vLLM host.
        port: vLLM port.
        timeout: Per-request timeout in seconds.
    """

    def __init__(self, *, host: str = "localhost", port: int = 8000, timeout: int = 300) -> None:
        """Create a client connected to a vLLM server."""
        base_url = f"http://{host}:{port}"
        self._client = httpx.Client(base_url=base_url, timeout=timeout)

    def wait_ready(self, max_wait: int) -> None:
        """Poll /v1/models until the server is ready or timeout."""
        deadline = time.monotonic() + max_wait
        while time.monotonic() < deadline:
            try:
                if self._client.get("/v1/models").is_success:
                    logger.info("vLLM server is ready")
                    return
            except httpx.TransportError:
                # Server not accepting connections yet; keep polling.
                pass
            time.sleep(READY_POLL_INTERVAL)
        msg = f"vLLM server not ready after {max_wait}s"
        raise TimeoutError(msg)

    def complete(self, prompt: str, model: str, *, temperature: float = 0.0, max_tokens: int = 4096) -> str:
        """Send a prompt to /v1/completions and return the response text."""
        body = {
            "model": model,
            "prompt": prompt,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        logger.info("Sending prompt to %s (%d chars)", model, len(prompt))
        reply = self._client.post("/v1/completions", json=body)
        reply.raise_for_status()
        return reply.json()["choices"][0]["text"]

    def close(self) -> None:
        """Close the HTTP client."""
        self._client.close()

    def __enter__(self) -> Self:
        """Enter the context manager."""
        return self

    def __exit__(self, *args: object) -> None:
        """Close the HTTP client on exit."""
        self.close()

View File

@@ -63,9 +63,9 @@ class DeviceRegistry:
return
with Session(self.engine) as session:
device = session.execute(
device = session.scalars(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
).one_or_none()
if device:
if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED:
@@ -99,9 +99,9 @@ class DeviceRegistry:
Returns True if the device was found and verified.
"""
with Session(self.engine) as session:
device = session.execute(
device = session.scalars(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
).one_or_none()
if not device:
logger.warning(f"Cannot verify unknown device: {phone_number}")
@@ -139,9 +139,9 @@ class DeviceRegistry:
def grant_role(self, phone_number: str, role: Role) -> bool:
"""Add a role to a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = session.execute(
device = session.scalars(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
).one_or_none()
if not device:
logger.warning(f"Cannot grant role for unknown device: {phone_number}")
@@ -150,7 +150,7 @@ class DeviceRegistry:
if any(record.name == role for record in device.roles):
return True
role_record = session.execute(select(RoleRecord).where(RoleRecord.name == role)).scalar_one_or_none()
role_record = session.scalars(select(RoleRecord).where(RoleRecord.name == role)).one_or_none()
if not role_record:
logger.warning(f"Unknown role: {role}")
@@ -165,9 +165,9 @@ class DeviceRegistry:
def revoke_role(self, phone_number: str, role: Role) -> bool:
"""Remove a role from a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = session.execute(
device = session.scalars(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
).one_or_none()
if not device:
logger.warning(f"Cannot revoke role for unknown device: {phone_number}")
@@ -182,16 +182,16 @@ class DeviceRegistry:
def set_roles(self, phone_number: str, roles: list[Role]) -> bool:
"""Replace all roles for a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = session.execute(
device = session.scalars(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
).one_or_none()
if not device:
logger.warning(f"Cannot set roles for unknown device: {phone_number}")
return False
role_names = [str(role) for role in roles]
records = list(session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all())
records = session.scalars(select(RoleRecord).where(RoleRecord.name.in_(role_names))).all()
device.roles = records
session.commit()
self._update_cache(phone_number, device)
@@ -203,7 +203,7 @@ class DeviceRegistry:
def list_devices(self) -> list[SignalDevice]:
"""Return all known devices."""
with Session(self.engine) as session:
return list(session.execute(select(SignalDevice)).scalars().all())
return list(session.scalars(select(SignalDevice)).all())
def sync_identities(self) -> None:
"""Pull identity list from signal-cli and record any new ones."""
@@ -226,9 +226,7 @@ class DeviceRegistry:
def _load_device(self, phone_number: str) -> SignalDevice | None:
"""Fetch a device by phone number (with joined roles)."""
with Session(self.engine) as session:
return session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
return session.scalars(select(SignalDevice).where(SignalDevice.phone_number == phone_number)).one_or_none()
def _update_cache(self, phone_number: str, device: SignalDevice) -> None:
"""Refresh the cache entry for a device."""
@@ -244,9 +242,9 @@ class DeviceRegistry:
def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool:
"""Update the trust level for a device."""
with Session(self.engine) as session:
device = session.execute(
device = session.scalars(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
).one_or_none()
if not device:
return False
@@ -269,7 +267,7 @@ def sync_roles(engine: Engine) -> None:
expected = {role.value for role in Role}
with Session(engine) as session:
existing = {record.name for record in session.execute(select(RoleRecord)).scalars().all()}
existing = set(session.scalars(select(RoleRecord.name)).all())
to_add = expected - existing
to_remove = existing - expected

View File

@@ -34,8 +34,9 @@ def main(config_file: Path) -> None:
logger.error(msg)
signal_alert(msg)
continue
get_snapshots_to_delete(dataset, get_count_lookup(config_file, dataset.name))
count_lookup = get_count_lookup(config_file, dataset.name)
logger.info(f"using {count_lookup} for {dataset.name}")
get_snapshots_to_delete(dataset, count_lookup)
except Exception:
logger.exception("snapshot_manager failed")
signal_alert("snapshot_manager failed")
@@ -99,6 +100,7 @@ def get_snapshots_to_delete(
"""
snapshots = dataset.get_snapshots()
logger.info(f"calculating snapshots for {dataset.name} to be deleted")
if not snapshots:
logger.info(f"{dataset.name} has no snapshots")
return

View File

@@ -1,9 +1,9 @@
{ inputs, ... }:
{ inputs, pkgs, ... }:
{
imports = [
"${inputs.self}/users/richie"
"${inputs.self}/common/global"
"${inputs.self}/common/optional/desktop.nix"
"${inputs.self}/users/math"
"${inputs.self}/common/optional/docker.nix"
"${inputs.self}/common/optional/scanner.nix"
"${inputs.self}/common/optional/steam.nix"
@@ -12,13 +12,17 @@
"${inputs.self}/common/optional/update.nix"
"${inputs.self}/common/optional/yubikey.nix"
"${inputs.self}/common/optional/zerotier.nix"
"${inputs.self}/common/optional/brain_substituter.nix"
"${inputs.self}/common/optional/nvidia.nix"
./hardware.nix
./syncthing.nix
./llms.nix
];
boot = {
kernelPackages = pkgs.linuxPackages_6_18;
zfs.package = pkgs.zfs_2_4;
};
networking = {
hostName = "bob";
hostId = "7c678a41";

View File

@@ -23,6 +23,7 @@
"magistral:24b"
"ministral-3:14b"
"nemotron-3-nano:30b"
"nemotron-3-nano:4b"
"nemotron-cascade-2:30b"
"qwen3-coder:30b"
"qwen3-embedding:0.6b"

View File

@@ -0,0 +1,11 @@
#!/bin/bash
# zpools
# storage
sudo zpool create -f -o ashift=12 -O acltype=posixacl -O atime=off -O dnodesize=auto -O xattr=sa -O compression=zstd -m /zfs/storage storage mirror
sudo zpool create -o ashift=12 -O acltype=posixacl -O atime=off -O dnodesize=auto -O xattr=sa -O compression=zstd -m /zfs/storage storage
# storage datasets
sudo zfs create storage/models -o recordsize=1M

View File

@@ -24,6 +24,6 @@ monthly = 0
["root_pool/models"]
15_min = 4
hourly = 24
hourly = 2
daily = 0
monthly = 0

View File

@@ -15,18 +15,20 @@ sudo zpool add storage -o ashift=12 logs mirror
sudo zpool create scratch -o ashift=12 -O acltype=posixacl -O atime=off -O dnodesize=auto -O xattr=sa -O compression=zstd -O encryption=aes-256-gcm -O keyformat=hex -O keylocation=file:///key -m /zfs/scratch
# media datasets
sudo zfs create media/temp -o sync=disabled -o redundant_metadata=none
sudo zfs create media/secure -o encryption=aes-256-gcm -o keyformat=hex -o keylocation=file:///root/zfs.key
sudo zfs create media/secure/docker -o compression=zstd-9
sudo zfs create media/secure/github-runners -o compression=zstd-9 -o sync=disabled
sudo zfs create media/secure/home_assistant -o compression=zstd-19
sudo zfs create media/secure/notes -o copies=2
sudo zfs create media/secure/postgres -o recordsize=16k -o primarycache=metadata
sudo zfs create media/secure/postgres -o mountpoint=/zfs/media/database/postgres -o recordsize=16k -o primarycache=metadata
sudo zfs create media/secure/postgres-wal -o mountpoint=/zfs/media/database/postgres-wal -o recordsize=32k -o primarycache=metadata -o special_small_blocks=32K -o compression=lz4 -o secondarycache=none -o logbias=latency
sudo zfs create media/secure/services -o compression=zstd-9
sudo zfs create media/secure/share -o mountpoint=/zfs/media/share -o exec=off
# scratch datasets
sudo zfs create scratch/kafka -o mountpoint=/zfs/scratch/kafka -o recordsize=1M
sudo zfs create scratch/transmission -o mountpoint=/zfs/scratch/transmission -o recordsize=16k -o sync=disabled
sudo zfs create scratch/transmission -o mountpoint=/zfs/scratch/transmission -o recordsize=16k -o sync=disabled -o redundant_metadata=none
# storage datasets
sudo zfs create storage/ollama -o recordsize=1M -o compression=zstd-19 -o sync=disabled

View File

@@ -5,6 +5,10 @@ in
{
networking.firewall.allowedTCPPorts = [ 5432 ];
# Symlink pg_wal to a ZFS dataset on the special (metadata) vdev for fast WAL writes
# this is required for systemd sandboxing
systemd.services.postgresql.serviceConfig.ReadWritePaths = [ "/zfs/media/database/postgres-wal" ];
services.postgresql = {
enable = true;
package = pkgs.postgresql_17_jit;
@@ -39,6 +43,10 @@ in
host postgres math ::1/128 trust
host postgres math 192.168.90.1/24 trust
local data_science_dev math trust
host data_science_dev math 127.0.0.1/32 trust
host data_science_dev math ::1/128 trust
host data_science_dev math 192.168.90.1/24 trust
'';
identMap = ''
@@ -110,6 +118,7 @@ in
}
];
ensureDatabases = [
"data_science_dev"
"hass"
"gitea"
"math"

View File

@@ -4,6 +4,7 @@ hourly = 24
daily = 0
monthly = 0
# root_pool
["root_pool/home"]
15_min = 8
hourly = 24
@@ -27,57 +28,96 @@ monthly = 0
hourly = 24
daily = 30
monthly = 6
# storage
["storage/ollama"]
15_min = 2
hourly = 0
daily = 0
monthly = 0
["storage/plex"]
["storage/secure"]
15_min = 0
hourly = 0
daily = 0
monthly = 0
["storage/secure/plex"]
15_min = 6
hourly = 2
daily = 1
monthly = 0
["media/plex"]
15_min = 6
hourly = 2
daily = 1
["storage/secure/transmission"]
15_min = 4
hourly = 0
daily = 0
monthly = 0
["media/notes"]
["storage/secure/secrets"]
15_min = 8
hourly = 24
daily = 30
monthly = 12
["media/docker"]
15_min = 3
hourly = 12
daily = 14
monthly = 2
["media/services"]
15_min = 3
hourly = 12
daily = 14
monthly = 2
["media/home_assistant"]
# media
["media/temp"]
15_min = 2
hourly = 0
daily = 0
monthly = 0
["media/secure"]
15_min = 0
hourly = 0
daily = 0
monthly = 0
["media/secure/plex"]
15_min = 6
hourly = 2
daily = 1
monthly = 0
["media/secure/postgres-wal"]
15_min = 4
hourly = 2
daily = 0
monthly = 0
["media/secure/postgres"]
15_min = 8
hourly = 24
daily = 7
monthly = 0
["media/secure/share"]
15_min = 4
hourly = 0
daily = 0
monthly = 0
["media/secure/github-runners"]
15_min = 6
hourly = 2
daily = 1
monthly = 0
["media/secure/notes"]
15_min = 8
hourly = 24
daily = 30
monthly = 12
["media/secure/docker"]
15_min = 3
hourly = 12
daily = 14
monthly = 2
# scratch
["scratch/transmission"]
15_min = 0
hourly = 0
daily = 0
monthly = 0
["storage/transmission"]
15_min = 0
hourly = 0
daily = 0
monthly = 0
["storage/ollama"]
15_min = 0
15_min = 2
hourly = 0
daily = 0
monthly = 0

View File

@@ -14,6 +14,7 @@
./llms.nix
./open_webui.nix
./qmk.nix
./sunshine.nix
./syncthing.nix
inputs.nixos-hardware.nixosModules.framework-13-7040-amd
];

Binary file not shown.

View File

@@ -8,9 +8,8 @@
"deepscaler:1.5b"
"deepseek-r1:8b"
"gemma3:12b"
"gemma3:27b"
"gpt-oss:20b"
"lfm2:24b"
"nemotron-3-nano:4b"
"qwen3:14b"
"qwen3.5:27b"
];

View File

@@ -0,0 +1,24 @@
{ pkgs, ... }:
{
services.sunshine = {
enable = true;
openFirewall = true;
capSysAdmin = true;
};
environment.systemPackages = [ pkgs.kdePackages.libkscreen ];
boot.kernelParams = [
"drm.edid_firmware=DP-4:edid/virtual-display.bin"
"video=DP-4:e"
];
hardware = {
firmwareCompression = "none";
firmware = [
(pkgs.runCommandLocal "virtual-display-edid" { } ''
mkdir -p $out/lib/firmware/edid
cp ${./edid/virtual-display.bin} $out/lib/firmware/edid/virtual-display.bin
'')
];
};
}

View File

@@ -210,9 +210,9 @@ class TestContactCache:
mock_session_cls.return_value.__exit__ = MagicMock(return_value=False)
mock_device = MagicMock()
mock_device.trust_level = TrustLevel.UNVERIFIED
mock_session.execute.return_value.scalar_one_or_none.return_value = mock_device
mock_session.scalars.return_value.one_or_none.return_value = mock_device
registry.record_contact("+1234", "abc")
mock_session.execute.assert_called_once()
mock_session.scalars.assert_called_once()
class TestLocationCommand:

View File

@@ -1,86 +0,0 @@
"""Count lines of code in the repository, grouped by file type."""
from __future__ import annotations
import subprocess
from collections import defaultdict
from pathlib import Path
def get_tracked_files() -> list[str]:
    """Return every file path currently tracked by git.

    Runs ``git ls-files`` in the current working directory and splits the
    output into one path per entry, dropping any blank lines.
    """
    proc = subprocess.run(
        ["git", "ls-files"],  # noqa: S603, S607
        capture_output=True,
        text=True,
        check=True,
    )
    tracked = proc.stdout.strip().splitlines()
    return [path for path in tracked if path]
def count_lines(filepath: str) -> int:
    """Return the number of lines in *filepath*.

    Files that cannot be decoded as UTF-8 (binary blobs) or cannot be
    read at all are counted as 0 rather than raising.
    """
    try:
        text = Path(filepath).read_text(encoding="utf-8")
    except (UnicodeDecodeError, OSError):
        return 0
    return len(text.splitlines())
def count_lines_by_type() -> dict[str, int]:
    """Tally tracked lines per file extension, largest first.

    Extension-less files are keyed by their full filename.  Known
    binary/generated types are dropped from the result.
    """
    totals: dict[str, int] = defaultdict(int)
    for path in get_tracked_files():
        p = Path(path)
        key = p.suffix.lstrip(".") or p.name
        totals[key] += count_lines(path)
    # Exclude binary/non-code files
    totals.pop("png", None)
    totals.pop("lock", None)
    ranked = sorted(totals.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked)
def format_report() -> str:
    """Render the per-extension line counts as a markdown table.

    The table is preceded by a one-line total; each row shows a file
    type, its line count, and its share of the total.
    """
    counts = count_lines_by_type()
    total = sum(counts.values())
    out = [
        f"This repo has **{total:,}** lines of technical debt.",
        "",
        "| File Type | Lines | Percentage |",
        "|-----------|------:|-----------:|",
    ]
    for ext, n in counts.items():
        if n <= 0:
            continue
        share = n / total * 100
        # Display extension-less keys with a leading dot for consistency.
        dot = "" if ext.startswith(".") else "."
        out.append(f"| {dot}{ext} | {n:,} | {share:.1f}% |")
    return "\n".join(out)
def update_readme() -> None:
    """Write the line-count report into README.md between marker comments.

    If the ``LINE-COUNT-START``/``LINE-COUNT-END`` markers already exist,
    the text between them is replaced in place; otherwise the marked
    section is appended to the end of the file.
    """
    readme_path = Path("README.md")
    start_marker = "<!-- LINE-COUNT-START -->"
    end_marker = "<!-- LINE-COUNT-END -->"
    section = f"{start_marker}\n{format_report()}\n{end_marker}"
    text = readme_path.read_text(encoding="utf-8")
    if start_marker in text:
        begin = text.index(start_marker)
        finish = text.index(end_marker) + len(end_marker)
        text = text[:begin] + section + text[finish:]
    else:
        text = text.rstrip() + "\n\n" + section + "\n"
    readme_path.write_text(text, encoding="utf-8")
# Script entry point: regenerate the README line-count section in place.
if __name__ == "__main__":
    update_readme()

View File

@@ -0,0 +1,5 @@
{
imports = [
../home/global.nix
];
}

View File

@@ -20,15 +20,15 @@
// turns off all sounds and announcements
"accessibility.signals.terminalCommandFailed": {
"sound": "off",
"announcement": "off"
"announcement": "off",
},
"accessibility.signals.terminalQuickFix": {
"sound": "off",
"announcement": "off"
"announcement": "off",
},
"accessibility.signals.terminalBell": {
"sound": "off",
"announcement": "off"
"announcement": "off",
},
// database settings
@@ -41,8 +41,8 @@
"driver": "PostgreSQL",
"name": "main",
"database": "postgres",
"username": "richie"
}
"username": "richie",
},
],
// formatters
@@ -55,7 +55,7 @@
"[yaml]": { "editor.defaultFormatter": "redhat.vscode-yaml" },
"[javascriptreact]": { "editor.defaultFormatter": "esbenp.prettier-vscode" },
"[github-actions-workflow]": {
"editor.defaultFormatter": "redhat.vscode-yaml"
"editor.defaultFormatter": "redhat.vscode-yaml",
},
"[dockercompose]": {
"editor.insertSpaces": true,
@@ -64,9 +64,9 @@
"editor.quickSuggestions": {
"other": true,
"comments": false,
"strings": true
"strings": true,
},
"editor.defaultFormatter": "redhat.vscode-yaml"
"editor.defaultFormatter": "redhat.vscode-yaml",
},
// spell check
@@ -78,7 +78,9 @@
"Corvidae",
"drivername",
"fastapi",
"syncthing"
"Qwen",
"sandboxing",
"syncthing",
],
// nix
@@ -96,5 +98,6 @@
// new
"hediet.vscode-drawio.resizeImages": null,
"hediet.vscode-drawio.appearance": "automatic",
"claudeCode.preferredLocation": "panel"
"claudeCode.preferredLocation": "panel",
"docker.extension.enableComposeLanguageServer": false,
}

View File

@@ -1,6 +1,5 @@
{
imports = [
../home/global.nix
../home/gui
];
}