10 Commits

20 changed files with 4149 additions and 62 deletions
@@ -0,0 +1,211 @@
"""move bill text summaries into a child table.
Revision ID: 4b2e1c9d8f70
Revises: b9360b0b0c22
Create Date: 2026-05-03 00:00:00.000000
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from pipelines.orm import DataScienceDevBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "4b2e1c9d8f70"
# Parent revision that this migration is applied on top of.
down_revision: str | None = "b9360b0b0c22"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
# Schema that every operation in this migration targets, read from the ORM base.
schema = DataScienceDevBase.schema_name
def upgrade() -> None:
    """Move summary data from ``bill_text`` into the new ``bill_text_summary`` table.

    Steps, in order:
      1. create ``bill_text_summary`` and its two indexes,
      2. add ``bill_text.primary_summary_id`` with a foreign key to the new table,
      3. copy every non-blank ``bill_text.summary`` (plus its model/prompt
         metadata) into ``bill_text_summary``,
      4. drop the four summary columns from ``bill_text``.

    ``primary_summary_id`` is left NULL for all rows; this migration only adds
    the column and never populates it.
    """
    # 1. Child table: one row per generated summary of a bill text version.
    op.create_table(
        "bill_text_summary",
        sa.Column("bill_text_id", sa.Integer(), nullable=False),
        sa.Column("summary", sa.String(), nullable=False),
        sa.Column("summarization_model", sa.String(), nullable=True),
        sa.Column("summarization_user_prompt_version", sa.String(), nullable=True),
        sa.Column("summarization_system_prompt_version", sa.String(), nullable=True),
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column(
            "created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        # Deleting a bill_text row removes all of its summaries.
        sa.ForeignKeyConstraint(
            ["bill_text_id"],
            [f"{schema}.bill_text.id"],
            name=op.f("fk_bill_text_summary_bill_text_id_bill_text"),
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_bill_text_summary")),
        schema=schema,
    )
    op.create_index(
        "ix_bill_text_summary_bill_text_id",
        "bill_text_summary",
        ["bill_text_id"],
        unique=False,
        schema=schema,
    )
    op.create_index(
        "ix_bill_text_summary_bill_text_id_created",
        "bill_text_summary",
        ["bill_text_id", "created"],
        unique=False,
        schema=schema,
    )
    # 2. Optional pointer from a bill text to one of its summaries.
    op.add_column(
        "bill_text",
        sa.Column("primary_summary_id", sa.Integer(), nullable=True),
        schema=schema,
    )
    # Deleting the referenced summary clears the pointer instead of blocking.
    op.create_foreign_key(
        op.f("fk_bill_text_primary_summary_id_bill_text_summary"),
        "bill_text",
        "bill_text_summary",
        ["primary_summary_id"],
        ["id"],
        source_schema=schema,
        referent_schema=schema,
        ondelete="SET NULL",
    )
    # 3. Copy existing summaries. Rows whose summary is NULL or only whitespace
    # are skipped. Both timestamps of the new row are taken from the source
    # row's `updated`, falling back to `created`, then to now().
    op.execute(
        sa.text(
            f"""
INSERT INTO {schema}.bill_text_summary (
bill_text_id,
summary,
summarization_model,
summarization_user_prompt_version,
summarization_system_prompt_version,
created,
updated
)
SELECT
bill_text.id,
bill_text.summary,
bill_text.summarization_model,
bill_text.summarization_user_prompt_version,
bill_text.summarization_system_prompt_version,
COALESCE(bill_text.updated, bill_text.created, now()),
COALESCE(bill_text.updated, bill_text.created, now())
FROM {schema}.bill_text
WHERE bill_text.summary IS NOT NULL
AND btrim(bill_text.summary) <> ''
"""
        )
    )
    # 4. The data now lives in bill_text_summary; remove the old columns.
    op.drop_column("bill_text", "summary", schema=schema)
    op.drop_column("bill_text", "summarization_model", schema=schema)
    op.drop_column("bill_text", "summarization_user_prompt_version", schema=schema)
    op.drop_column("bill_text", "summarization_system_prompt_version", schema=schema)
def downgrade() -> None:
    """Fold ``bill_text_summary`` back into columns on ``bill_text``.

    Steps, in order:
      1. re-add the four summary columns to ``bill_text`` (all nullable),
      2. for each bill text, copy ONE summary row back: the row referenced by
         ``primary_summary_id`` when set, otherwise the newest row
         (``created DESC, id DESC``),
      3. drop ``primary_summary_id`` and its foreign key,
      4. drop the indexes and the ``bill_text_summary`` table.

    All four restored columns come from the same summary row. A per-column
    ``COALESCE(primary.x, latest.x)`` would pair the primary summary's text
    with the latest summary's model/prompt versions whenever the primary row
    has NULL metadata. Bill texts with no summary are left untouched; their
    freshly added columns are already NULL.

    Any additional (non-chosen) summaries are lost when the table is dropped.
    """
    # 1. Restore the legacy columns.
    op.add_column(
        "bill_text",
        sa.Column("summarization_system_prompt_version", sa.String(), nullable=True),
        schema=schema,
    )
    op.add_column(
        "bill_text",
        sa.Column("summarization_user_prompt_version", sa.String(), nullable=True),
        schema=schema,
    )
    op.add_column(
        "bill_text",
        sa.Column("summarization_model", sa.String(), nullable=True),
        schema=schema,
    )
    op.add_column(
        "bill_text",
        sa.Column("summary", sa.String(), nullable=True),
        schema=schema,
    )
    # 2. Copy one whole summary row per bill text. `src` is resolved by id so
    # every restored column is guaranteed to come from the same row; the inner
    # join drops bill texts that have neither a primary nor any summary.
    op.execute(
        sa.text(
            f"""
            WITH ranked AS (
                SELECT
                    bts.id,
                    bts.bill_text_id,
                    row_number() OVER (
                        PARTITION BY bts.bill_text_id
                        ORDER BY bts.created DESC, bts.id DESC
                    ) AS rn
                FROM {schema}.bill_text_summary AS bts
            ),
            chosen AS (
                SELECT
                    bill_text.id AS bill_text_id,
                    src.summary,
                    src.summarization_model,
                    src.summarization_user_prompt_version,
                    src.summarization_system_prompt_version
                FROM {schema}.bill_text
                LEFT JOIN ranked AS ls
                    ON ls.bill_text_id = bill_text.id
                    AND ls.rn = 1
                JOIN {schema}.bill_text_summary AS src
                    ON src.id = COALESCE(bill_text.primary_summary_id, ls.id)
            )
            UPDATE {schema}.bill_text
            SET
                summary = chosen.summary,
                summarization_model = chosen.summarization_model,
                summarization_user_prompt_version = chosen.summarization_user_prompt_version,
                summarization_system_prompt_version = chosen.summarization_system_prompt_version
            FROM chosen
            WHERE chosen.bill_text_id = bill_text.id
            """
        )
    )
    # 3. Remove the pointer column; the FK must go first.
    op.drop_constraint(
        op.f("fk_bill_text_primary_summary_id_bill_text_summary"),
        "bill_text",
        schema=schema,
        type_="foreignkey",
    )
    op.drop_column("bill_text", "primary_summary_id", schema=schema)
    # 4. Remove the child table and its indexes.
    op.drop_index(
        "ix_bill_text_summary_bill_text_id_created",
        table_name="bill_text_summary",
        schema=schema,
    )
    op.drop_index(
        "ix_bill_text_summary_bill_text_id",
        table_name="bill_text_summary",
        schema=schema,
    )
    op.drop_table("bill_text_summary", schema=schema)
+1 -1
View File
@@ -1 +1 @@
"""Prompt benchmarking system for evaluating LLMs via vLLM.""" """Init."""
+116
View File
@@ -0,0 +1,116 @@
"""Nornsight — BERTopic POC Inference Script.
Loads the trained model and labels a small batch of posts,
writing results to main.post_topic for inspection.
POC: processes a single batch of 1k posts to validate the pipeline end-to-end.
"""
from __future__ import annotations
import logging
import time
from collections import Counter
from pathlib import Path
from bertopic import BERTopic
from sqlalchemy import Engine, func, insert, select
from sqlalchemy.orm import Session
from pipelines.config import BertTopicInferConfig, get_bertopic_infer_config
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.posts import PostTopic, Posts
from pipelines.orm.data_science_dev.posts.lang_filters import ENGLISH_LANGS
from pipelines.pipelines.common import configure_logger
logger = logging.getLogger(__name__)
def main() -> None:
    """Entry point: configure logging, load the inference config and label posts."""
    configure_logger()
    run_inference(get_bertopic_infer_config())
    logger.info(
        "POC inference complete. Check main.post_topic in DBeaver to inspect results."
    )
def run_inference(config: BertTopicInferConfig) -> None:
    """Label one batch of posts with a saved BERTopic model and store the result.

    Loads the model from ``config.model_save_path``, fetches a batch of posts,
    runs ``transform`` on their text and inserts one ``PostTopic`` row per post
    (topic id, topic label, ``config.model_version``).

    When no posts are selected the function logs a warning and returns without
    calling ``transform`` or touching the database.

    Raises:
        ValueError: if the model returns a different number of topics than the
            number of posts passed in (guarded by ``zip(..., strict=True)``).
    """
    model_save_path = Path(config.model_save_path)
    logger.info(f"Loading BERTopic model from {model_save_path}")
    topic_model = BERTopic.load(str(model_save_path))
    topic_info = topic_model.get_topic_info()
    # Topic id -> human-readable topic name, as reported by the model.
    label_map: dict[int, str] = dict(zip(topic_info["Topic"], topic_info["Name"]))
    logger.info(f"Model loaded with {len(label_map)} topics")

    engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
    post_ids, texts = get_post_ids_and_test(engine, config)
    logger.info(f"Fetched {len(texts)} posts")
    if not texts:
        # Nothing to label: avoid transforming an empty corpus and issuing an
        # INSERT with an empty parameter list.
        logger.warning("Skipping inference: no posts to label")
        return

    logger.info("Running BERTopic transform")
    start = time.perf_counter()
    topics, _probabilities = topic_model.transform(texts)
    elapsed = time.perf_counter() - start
    logger.info(f"Transform complete in {elapsed:.1f}s")

    # Write results to main.post_topic
    records = [
        {
            "post_id": pid,
            "topic_id": int(topic_id),
            "topic_label": label_map.get(int(topic_id), "unknown"),
            "model_version": config.model_version,
        }
        # strict=True: a length mismatch would otherwise silently drop posts.
        for pid, topic_id in zip(post_ids, topics, strict=True)
    ]
    with Session(engine) as session:
        session.execute(insert(PostTopic), records)
        session.commit()
    count_topics(records)
    logger.info(f"Wrote {len(records)} topic labels to main.post_topic")
def get_post_ids_and_test(
    engine: Engine,
    config: BertTopicInferConfig,
) -> tuple[list[int], list[str]]:
    """Fetch the ids and stripped text of the posts to run inference on.

    Selects posts with non-null text, a ``langs`` value in ``ENGLISH_LANGS``
    and text longer than ``config.min_text_length``. When
    ``config.poc_batch_size`` is positive it is applied as a LIMIT.

    Returns:
        ``(post_ids, texts)`` as two parallel lists; both are empty when no
        post matches. The function never returns ``None``.
    """
    logger.info(f"Fetching {config.poc_batch_size} posts for inference")
    # Only the two columns that are actually used are selected, rather than
    # full Posts entities.
    # NOTE(review): there is no ORDER BY and no exclusion of the training
    # sample, so the batch is an arbitrary set of matching rows and may overlap
    # with posts the model was trained on.
    stmt = select(Posts.post_id, Posts.text).where(
        Posts.text.is_not(None),
        Posts.langs.in_(ENGLISH_LANGS),
        func.length(Posts.text) > config.min_text_length,
    )
    if config.poc_batch_size > 0:
        stmt = stmt.limit(config.poc_batch_size)

    with Session(engine) as session:
        rows = session.execute(stmt).all()

    if not rows:
        logger.warning("No posts were selected for inference")
        return [], []

    post_ids = [post_id for post_id, _text in rows]
    texts = [text.strip() for _post_id, text in rows]
    return post_ids, texts
def count_topics(records: list[dict]) -> None:
    """Log the ten most frequent topic labels among ``records``.

    Records without a ``topic_label`` key are counted under ``"unknown"``.
    """
    labels = [record.get("topic_label", "unknown") for record in records]
    tally = Counter(labels)
    logger.info("Topic distribution in this batch:")
    for entry in tally.most_common(10):
        logger.info(" %s: %d", *entry)
# Run inference only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
+119
View File
@@ -0,0 +1,119 @@
"""Nornsight — BERTopic POC Training Script.
Pulls a small stratified sample (~11.5k posts) from main.posts,
trains BERTopic with MiniBatchKMeans on Jeeves, and saves the model locally.
POC sample rate: random() < 0.00005 (~0.005% of 230M = ~11.5k posts)
Full training rate will be: random() < 0.005 (~1.08M posts)
"""
from __future__ import annotations
import logging
import time
from pathlib import Path
from bertopic import BERTopic
from sklearn.cluster import MiniBatchKMeans
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from pipelines.config import BertTopicTrainConfig, get_bertopic_train_config
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.posts import Posts
from pipelines.orm.data_science_dev.posts.lang_filters import ENGLISH_LANGS
from pipelines.pipelines.common import configure_logger
logger = logging.getLogger(__name__)
def main() -> None:
    """Entry point: sample posts, train BERTopic on them and save the model.

    Logs a warning and does nothing else when the sample is empty.
    """
    configure_logger()
    config = get_bertopic_train_config()
    docs = load_sample(config)
    if docs:
        train(docs, config)
        logger.info(f"Done. Model saved as version {config.model_version}")
        logger.info("Next: run infer.py to label a sample of posts in the database")
    else:
        logger.warning("No training documents were selected")
def load_sample(config: BertTopicTrainConfig) -> list[str]:
    """Pull a random sample of post texts and return them cleaned and de-duplicated.

    Rows are kept when the text is non-null, ``langs`` is in ``ENGLISH_LANGS``,
    the text is longer than ``config.min_text_length`` and
    ``random() < config.sample_rate``.

    Returns:
        The stripped, unique, non-empty texts in the order the database
        returned them (first occurrence wins).
    """
    logger.info("Connecting to PostgreSQL via SQLAlchemy")
    engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
    logger.info(f"Pulling sample from main.posts (sample_rate={config.sample_rate})")
    start = time.perf_counter()
    with Session(engine) as session:
        texts = session.scalars(
            select(Posts.text).where(
                Posts.text.is_not(None),
                Posts.langs.in_(ENGLISH_LANGS),
                func.length(Posts.text) > config.min_text_length,
                func.random() < config.sample_rate,
            )
        ).all()
    elapsed = time.perf_counter() - start
    logger.info(f"Fetched {len(texts)} rows in {elapsed:.1f}s")

    # Basic cleaning: strip whitespace, drop texts that are empty after
    # stripping, and deduplicate. dict.fromkeys keeps first-seen order, whereas
    # a set would reorder the documents differently on every run (string
    # hashing is randomized per process).
    stripped = (text.strip() for text in texts)
    docs = list(dict.fromkeys(doc for doc in stripped if doc))
    logger.info(f"After cleaning and dedup: {len(docs)} posts")
    return docs
def train(docs: list[str], config: BertTopicTrainConfig) -> None:
    """Fit BERTopic on ``docs`` and save the fitted model to disk.

    A ``MiniBatchKMeans`` instance with ``config.n_topics`` clusters is passed
    to BERTopic as its clustering model. After fitting, the topic table is
    logged and the model is written to ``config.model_save_path`` (created if
    missing) using safetensors serialization.
    """
    logger.info(
        f"Initialising BERTopic with MiniBatchKMeans (n_topics={config.n_topics})"
    )
    kmeans = MiniBatchKMeans(
        n_clusters=config.n_topics,
        random_state=42,
        batch_size=1024,
        n_init=3,
        verbose=1,
    )
    bertopic_options = {
        "hdbscan_model": kmeans,
        "language": "english",
        "calculate_probabilities": False,  # saves memory
        "verbose": True,
    }
    topic_model = BERTopic(**bertopic_options)

    logger.info(f"Starting fit_transform on {len(docs)} posts (CPU)")
    started_at = time.perf_counter()
    topic_model.fit_transform(docs)
    elapsed = time.perf_counter() - started_at
    logger.info(f"Training complete in {elapsed:.1f}s ({elapsed / 60:.1f} min)")

    # Dump the topic table to the log for a quick look at what was learned.
    topic_info = topic_model.get_topic_info()
    logger.info(f"Topics found: {len(topic_info)}")
    logger.info(f"\n{topic_info.to_string()}")

    model_save_path = Path(config.model_save_path)
    model_save_path.mkdir(parents=True, exist_ok=True)
    logger.info(f"Saving model to {model_save_path}")
    save_options = {
        "serialization": "safetensors",
        "save_ctfidf": True,
        "save_embedding_model": True,
    }
    topic_model.save(str(model_save_path), **save_options)
    logger.info("Model saved")
# Train only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
+57
View File
@@ -2,6 +2,7 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from os import getenv from os import getenv
from datetime import date
from pathlib import Path from pathlib import Path
import tomllib import tomllib
@@ -50,6 +51,7 @@ class FinetuneConfig:
) )
@dataclass
class BenchmarkConfig: class BenchmarkConfig:
"""Top-level benchmark configuration loaded from TOML.""" """Top-level benchmark configuration loaded from TOML."""
@@ -101,6 +103,45 @@ class OpenAIConfig:
) )
@dataclass
class BertTopicTrainConfig:
"""BERTopic training configuration loaded from TOML."""
sample_rate: float
min_text_length: int
n_topics: int
model_save_path: str
model_version: str | None = None
@classmethod
def from_toml(cls, config_path: Path) -> BertTopicTrainConfig:
"""Load BERTopic training config from a TOML file."""
raw = tomllib.loads(config_path.read_text())["bertopic"]["train"]
today = date.today().isoformat()
if raw.get("model_version") is None:
raw["model_version"] = (
f"{today}-{raw['sample_rate']}-{raw['min_text_length']}-{raw['n_topics']}"
)
return cls(**raw)
@dataclass
class BertTopicInferConfig:
"""BERTopic inference configuration loaded from TOML."""
min_text_length: int
poc_batch_size: int
model_version: str
model_save_path: str
@classmethod
def from_toml(cls, config_path: Path) -> BertTopicInferConfig:
"""Load BERTopic inference config from a TOML file."""
raw = tomllib.loads(config_path.read_text())["bertopic"]["infer"]
return cls(**raw)
def get_config_dir() -> Path: def get_config_dir() -> Path:
"""Get the path to the config directory.""" """Get the path to the config directory."""
return Path(__file__).resolve().parents[2] / "config" return Path(__file__).resolve().parents[2] / "config"
@@ -127,3 +168,19 @@ def get_benchmark_config(config_path: Path | None = None) -> BenchmarkConfig:
if config_path is None: if config_path is None:
config_path = default_config_path() config_path = default_config_path()
return BenchmarkConfig.from_toml(config_path) return BenchmarkConfig.from_toml(config_path)
def get_bertopic_train_config(
    config_path: Path | None = None,
) -> BertTopicTrainConfig:
    """Load the BERTopic training config, using the default config path when none is given."""
    resolved_path = default_config_path() if config_path is None else config_path
    return BertTopicTrainConfig.from_toml(resolved_path)
def get_bertopic_infer_config(
    config_path: Path | None = None,
) -> BertTopicInferConfig:
    """Load the BERTopic inference config, using the default config path when none is given."""
    resolved_path = default_config_path() if config_path is None else config_path
    return BertTopicInferConfig.from_toml(resolved_path)
@@ -23,7 +23,7 @@ from sqlalchemy import (
) )
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from pipelines.congress_vote_context import create_score_run, finalize_score_run from pipelines.jobs.congress_vote_context import create_score_run, finalize_score_run
from pipelines.orm.common import get_postgres_engine from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.congress import ( from pipelines.orm.data_science_dev.congress import (
BillTopic, BillTopic,
@@ -39,7 +39,7 @@ from pipelines.orm.data_science_dev.congress import (
VoteRelationship, VoteRelationship,
VoteRecord, VoteRecord,
) )
from pipelines.pipelines.jobs.extract_bill_topics import normalize_topic_label from pipelines.jobs.extract_bill_topics import normalize_topic_label
from pipelines.web.scoring import ( from pipelines.web.scoring import (
OPPOSE_POSITIONS, OPPOSE_POSITIONS,
SUPPORT_POSITIONS, SUPPORT_POSITIONS,
File diff suppressed because it is too large Load Diff
+26 -12
View File
@@ -19,6 +19,7 @@ from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.congress import ( from pipelines.orm.data_science_dev.congress import (
Bill, Bill,
BillText, BillText,
BillTextSummary,
BillTopic, BillTopic,
BillTopicPosition, BillTopicPosition,
SubjectType, SubjectType,
@@ -72,11 +73,19 @@ class ExtractedBillTopic:
def _select_bill_text_for_topic_extraction(bill: Bill) -> BillText | None: def _select_bill_text_for_topic_extraction(bill: Bill) -> BillText | None:
"""Pick one summarized bill_text row from the already-loaded relationship.""" """Pick one summarized bill_text row from the already-loaded relationship."""
for bill_text in bill.bill_texts: for bill_text in bill.bill_texts:
if bill_text.summary and bill_text.summary.strip(): summary_row = bill_text.default_summary()
if summary_row and summary_row.summary.strip():
return bill_text return bill_text
return None return None
def _bill_text_has_summary_clause() -> ColumnElement[bool]:
"""Return a correlated EXISTS clause for bill texts with at least one summary."""
return exists(
select(BillTextSummary.id).where(BillTextSummary.bill_text_id == BillText.id)
)
def normalize_topic_label(value: str) -> str: def normalize_topic_label(value: str) -> str:
"""Normalize a topic label for storage, comparison, and de-duping.""" """Normalize a topic label for storage, comparison, and de-duping."""
normalized = value.strip().strip("\"'") normalized = value.strip().strip("\"'")
@@ -323,11 +332,7 @@ def create_select_bills_for_topic_extraction(
limit: int | None = None, limit: int | None = None,
) -> Select[tuple[Bill]]: ) -> Select[tuple[Bill]]:
"""Select bill rows that have summarized bill_text rows for topic extraction.""" """Select bill rows that have summarized bill_text rows for topic extraction."""
has_summary = (BillText.summary.is_not(None), BillText.summary != "") summarized_text_filters: list[ColumnElement[bool]] = [_bill_text_has_summary_clause()]
summarized_text_filters: list[ColumnElement[bool]] = [
BillText.bill_id == Bill.id,
*has_summary,
]
if with_votes_only: if with_votes_only:
summarized_text_filters.append( summarized_text_filters.append(
exists( exists(
@@ -347,11 +352,17 @@ def create_select_bills_for_topic_extraction(
) )
) )
) )
summarized_text_exists = exists(select(BillText.id).where(*summarized_text_filters)) summarized_text_exists = exists(
select(BillText.id).where(BillText.bill_id == Bill.id, *summarized_text_filters)
)
bill_text_loader = selectinload(Bill.bill_texts.and_(*summarized_text_filters))
stmt = ( stmt = (
select(Bill) select(Bill)
.where(summarized_text_exists) .where(summarized_text_exists)
.options(selectinload(Bill.bill_texts.and_(*summarized_text_filters[1:]))) .options(
bill_text_loader.selectinload(BillText.summaries),
bill_text_loader.selectinload(BillText.primary_summary),
)
.order_by(Bill.id) .order_by(Bill.id)
) )
if congress is not None: if congress is not None:
@@ -363,7 +374,7 @@ def create_select_bills_for_topic_extraction(
select(BillText.id).where( select(BillText.id).where(
BillText.bill_id == Bill.id, BillText.bill_id == Bill.id,
BillText.id.in_(bill_text_ids), BillText.id.in_(bill_text_ids),
*summarized_text_filters[1:], *summarized_text_filters,
) )
) )
stmt = stmt.where(selected_text_exists) stmt = stmt.where(selected_text_exists)
@@ -416,8 +427,7 @@ def collect_topic_extraction_diagnostics(
) )
) )
has_summary = (BillText.summary.is_not(None), BillText.summary != "") summary_filters = [*bill_text_filters, _bill_text_has_summary_clause()]
summary_filters = [*bill_text_filters, *has_summary]
bills_with_summaries = session.scalar( bills_with_summaries = session.scalar(
select(func.count(func.distinct(Bill.id))) select(func.count(func.distinct(Bill.id)))
@@ -607,7 +617,11 @@ def main(
if bill_text is None: if bill_text is None:
logger.warning("Skipping bill id=%s: no usable summary", bill.id) logger.warning("Skipping bill id=%s: no usable summary", bill.id)
continue continue
summary = bill_text.summary.strip() summary_row = bill_text.default_summary()
if summary_row is None:
logger.warning("Skipping bill id=%s: no default summary", bill.id)
continue
summary = summary_row.summary.strip()
try: try:
extracted_topics = extract_topics_for_bill_text( extracted_topics = extract_topics_for_bill_text(
File diff suppressed because it is too large Load Diff
+281
View File
@@ -0,0 +1,281 @@
"""Ingestion pipeline for loading JSONL post files into the weekly-partitioned posts table.
Usage:
ingest-posts /path/to/files/
ingest-posts /path/to/single_file.jsonl
ingest-posts /data/dir/ --workers 4 --batch-size 5000
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime
from pathlib import Path # noqa: TC003 this is needed for typer
from typing import TYPE_CHECKING, Annotated
import orjson
import psycopg
import typer
from pipelines.pipelines.common import configure_logger
from pipelines.orm.common import get_connection_info
from pipelines.pipelines.parallelize import parallelize_process
if TYPE_CHECKING:
from collections.abc import Iterator
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Typer application; `main` below is registered on it as a command.
app = typer.Typer(help="Ingest JSONL post files into the partitioned posts table.")
@app.command()
def main(
    path: Annotated[
        Path,
        typer.Argument(help="Directory containing JSONL files, or a single JSONL file"),
    ],
    batch_size: Annotated[int, typer.Option(help="Rows per INSERT batch")] = 10000,
    workers: Annotated[
        int, typer.Option(help="Parallel workers for multi-file ingestion")
    ] = 4,
    pattern: Annotated[
        str, typer.Option(help="Glob pattern for JSONL files")
    ] = "*.jsonl",
) -> None:
    """Ingest JSONL post files into the weekly-partitioned posts table."""
    configure_logger(level="INFO")
    logger.info("starting ingest-posts")
    logger.info(
        "path=%s batch_size=%d workers=%d pattern=%s",
        path,
        batch_size,
        workers,
        pattern,
    )

    # A single file is ingested in-process; a directory is fanned out to workers.
    if path.is_file():
        ingest_file(path, batch_size=batch_size)
    elif path.is_dir():
        directory_options = {
            "batch_size": batch_size,
            "max_workers": workers,
            "pattern": pattern,
        }
        ingest_directory(path, **directory_options)
    else:
        # Neither a regular file nor a directory: report and exit non-zero.
        typer.echo(f"Path does not exist: {path}", err=True)
        raise typer.Exit(code=1)

    logger.info("ingest-posts done")
def ingest_directory(
    directory: Path,
    *,
    batch_size: int,
    max_workers: int,
    pattern: str = "*.jsonl",
) -> None:
    """Ingest every file in ``directory`` matching ``pattern``, in parallel.

    Files are processed in sorted path order, one ``ingest_file`` call per
    file, dispatched through ``parallelize_process`` with ``max_workers``.
    Logs a warning and returns when nothing matches.
    """
    jsonl_files = sorted(directory.glob(pattern))
    if jsonl_files:
        logger.info("Found %d JSONL files to ingest", len(jsonl_files))
        jobs = [
            {"path": jsonl_file, "batch_size": batch_size} for jsonl_file in jsonl_files
        ]
        parallelize_process(ingest_file, jobs, max_workers=max_workers)
        return
    logger.warning("No JSONL files found in %s", directory)
# Schema that holds both the posts table and the failed_ingestion table.
SCHEMA = "main"
# Columns copied from each parsed JSONL row, in the order used by the COPY
# into the staging table and by the INSERT into posts.
COLUMNS = (
    "post_id",
    "user_id",
    "instance",
    "date",
    "text",
    "langs",
    "like_count",
    "reply_count",
    "repost_count",
    "reply_to",
    "replied_author",
    "thread_root",
    "thread_root_author",
    "repost_from",
    "reposted_author",
    "quotes",
    "quoted_author",
    "labels",
    "sent_label",
    "sent_score",
)
# Moves staged rows into posts; rows that collide on (post_id, date) are
# silently skipped, so re-ingesting the same data does not raise.
INSERT_FROM_STAGING = f"""
INSERT INTO {SCHEMA}.posts ({", ".join(COLUMNS)})
SELECT {", ".join(COLUMNS)} FROM pg_temp.staging
ON CONFLICT (post_id, date) DO NOTHING
"""  # noqa: S608
# Records one rejected input line together with the reason it was rejected.
# Executed with a mapping that provides `raw_line` and `error`.
FAILED_INSERT = f"""
INSERT INTO {SCHEMA}.failed_ingestion (raw_line, error)
VALUES (%(raw_line)s, %(error)s)
"""  # noqa: S608
def get_psycopg_connection() -> psycopg.Connection:
    """Open a raw psycopg3 connection to the ``DATA_SCIENCE_DEV`` database.

    Connection details come from ``get_connection_info``. Autocommit is off,
    so callers control transaction boundaries themselves.
    """
    database, host, port, username, password = get_connection_info("DATA_SCIENCE_DEV")
    connect_options = {
        "dbname": database,
        "host": host,
        "port": int(port),
        "user": username,
        "password": password,
        "autocommit": False,
    }
    return psycopg.connect(**connect_options)
def ingest_file(path: Path, *, batch_size: int) -> None:
    """Ingest one JSONL file into the posts table over a single connection.

    Batches are streamed from the file and handed to ``ingest_batch``. Lines
    that could not be parsed are collected while reading and written to the
    failed-ingestion table once the whole file has been processed.

    Raises:
        Exception: any error is logged with the file path and re-raised.
    """
    # Log progress roughly every 100k rows, but at least every batch.
    batches_per_log = max(100_000 // batch_size, 1)
    malformed: list[dict] = []
    try:
        with get_psycopg_connection() as connection:
            batches = read_jsonl_batches(path, batch_size, malformed)
            for batch_number, rows in enumerate(batches, 1):
                ingest_batch(connection, rows)
                if batch_number % batches_per_log != 0:
                    continue
                # Row count is an estimate: it assumes every batch was full.
                logger.info(
                    "Ingested %d batches (%d rows) from %s",
                    batch_number,
                    batch_number * batch_size,
                    path,
                )

            if malformed:
                logger.warning(
                    "Recording %d malformed lines from %s", len(malformed), path.name
                )
                with connection.cursor() as cursor:
                    cursor.executemany(FAILED_INSERT, malformed)
                connection.commit()
    except Exception:
        logger.exception("Failed to ingest file: %s", path)
        raise
def ingest_batch(connection: psycopg.Connection, batch: list[dict]) -> None:
    """COPY batch into a temp staging table, then INSERT ... ON CONFLICT into posts.

    On any failure the transaction is rolled back and the batch is split in
    half and retried recursively. When a failing batch is down to a single
    row, that row is written to the failed-ingestion table (serialized as
    JSON, together with the error text) and skipped, so one bad row does not
    block the rest of the batch.

    Args:
        connection: Open psycopg connection; this function commits and rolls
            back on it.
        batch: Rows to insert; each dict is read by the names in ``COLUMNS``,
            with missing keys sent as NULL. An empty batch is a no-op.
    """
    if not batch:
        return
    try:
        with connection.cursor() as cursor:
            # The staging table is created on first use and reused afterwards.
            cursor.execute(f"""
CREATE TEMP TABLE IF NOT EXISTS staging
(LIKE {SCHEMA}.posts INCLUDING DEFAULTS)
ON COMMIT DELETE ROWS
""")
            # Start from an empty staging table within this transaction.
            cursor.execute("TRUNCATE pg_temp.staging")
            with cursor.copy(
                f"COPY pg_temp.staging ({', '.join(COLUMNS)}) FROM STDIN"
            ) as copy:
                for row in batch:
                    # Values are emitted in COLUMNS order.
                    copy.write_row(tuple(row.get(column) for column in COLUMNS))
            cursor.execute(INSERT_FROM_STAGING)
        connection.commit()
    except Exception as error:
        # Discard the failed transaction before recording or retrying anything.
        connection.rollback()
        if len(batch) == 1:
            # The failure is isolated to this row: record it and move on.
            logger.exception("Skipping bad row post_id=%s", batch[0].get("post_id"))
            with connection.cursor() as cursor:
                cursor.execute(
                    FAILED_INSERT,
                    {
                        "raw_line": orjson.dumps(batch[0], default=str).decode(),
                        "error": str(error),
                    },
                )
            connection.commit()
            return
        # Bisect: retry each half on its own to narrow down the offending row(s).
        midpoint = len(batch) // 2
        ingest_batch(connection, batch[:midpoint])
        ingest_batch(connection, batch[midpoint:])
def read_jsonl_batches(
    file_path: Path, batch_size: int, failed_lines: list[dict]
) -> Iterator[list[dict]]:
    """Stream ``file_path`` and yield lists of transformed rows.

    Blank lines are skipped. A batch is yielded as soon as it holds at least
    ``batch_size`` rows, so it can be slightly larger when one line expands to
    several rows. Lines that fail to parse are appended to ``failed_lines`` by
    ``parse_line``. A final partial batch is yielded at end of file.
    """
    pending: list[dict] = []
    with file_path.open("r", encoding="utf-8") as handle:
        for stripped in (raw_line.strip() for raw_line in handle):
            if stripped:
                pending.extend(parse_line(stripped, file_path, failed_lines))
                if len(pending) >= batch_size:
                    yield pending
                    pending = []
    if pending:
        yield pending
def parse_line(line: str, file_path: Path, failed_lines: list[dict]) -> Iterator[dict]:
    """Parse a JSONL line, handling concatenated JSON objects.

    Yields one transformed row for a well-formed line. When the line is not
    valid JSON but contains ``}{``, it is treated as several objects written
    back to back: it is split at every ``}{`` and each fragment is parsed on
    its own.

    Nothing is raised for bad input. Every line or fragment that cannot be
    parsed or transformed is logged and appended to ``failed_lines`` as
    ``{"raw_line": ..., "error": ...}``. This includes fragments that decode
    to JSON but have the wrong shape or types (e.g. a non-integer ``date``),
    which are handled the same way as a whole line with that problem instead
    of aborting the file.
    """
    try:
        row = transform_row(orjson.loads(line))
    except orjson.JSONDecodeError:
        if "}{" not in line:
            logger.warning(
                "Skipping malformed line in %s: %s", file_path.name, line[:120]
            )
            failed_lines.append({"raw_line": line, "error": "malformed JSON"})
            return
        # NOTE: the split is purely textual, so a `}{` inside a string value is
        # also split; the resulting fragments then fail and are recorded.
        fragments = line.replace("}{", "}\n{").split("\n")
    except Exception as error:
        logger.exception("Skipping bad row in %s: %s", file_path.name, line[:120])
        failed_lines.append({"raw_line": line, "error": str(error)})
        return
    else:
        yield row
        return

    for fragment in fragments:
        try:
            row = transform_row(orjson.loads(fragment))
        except Exception as error:  # noqa: BLE001 - record any bad fragment, never abort the file
            logger.warning(
                "Skipping malformed fragment in %s: %s",
                file_path.name,
                fragment[:120],
            )
            failed_lines.append({"raw_line": fragment, "error": str(error)})
            continue
        yield row
def transform_row(raw: dict) -> dict:
    """Normalize a raw JSONL row in place and return it.

    - ``date`` is replaced by the datetime produced by ``parse_date``
      (a missing ``date`` key raises ``KeyError``),
    - ``langs``, when present and not None, is replaced by its ``orjson.dumps``
      encoding,
    - ``text``, when present and not None, has NUL characters removed.

    The input dict itself is mutated; the returned object is the same dict.
    """
    raw["date"] = parse_date(raw["date"])

    langs = raw.get("langs")
    if langs is not None:
        # NOTE(review): orjson.dumps returns bytes, not str - confirm this is
        # what the langs column expects when the row is COPY'd.
        raw["langs"] = orjson.dumps(langs)

    text = raw.get("text")
    if text is not None:
        raw["text"] = text.replace("\x00", "")

    return raw
def parse_date(raw_date: int) -> datetime:
"""Parse compact YYYYMMDDHHmm integer into a naive datetime (input is UTC by spec)."""
return datetime(
raw_date // 100000000,
(raw_date // 1000000) % 100,
(raw_date // 10000) % 100,
(raw_date // 100) % 100,
raw_date % 100,
tzinfo=UTC,
)
# Start the Typer CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    app()
+23 -9
View File
@@ -9,7 +9,7 @@ from typing import Annotated, Any
import httpx import httpx
import typer import typer
from sqlalchemy import Select, exists, or_, select from sqlalchemy import Select, exists, select
from sqlalchemy.orm import Session, selectinload from sqlalchemy.orm import Session, selectinload
from tiktoken import get_encoding from tiktoken import get_encoding
@@ -20,6 +20,7 @@ from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.congress import ( from pipelines.orm.data_science_dev.congress import (
Bill, Bill,
BillText, BillText,
BillTextSummary,
SubjectType, SubjectType,
VoteClassification, VoteClassification,
VoteRelationship, VoteRelationship,
@@ -112,7 +113,7 @@ def summarize_bill_text(
model: str, model: str,
bill_text: BillText, bill_text: BillText,
summarization_prompts: dict[str, str], summarization_prompts: dict[str, str],
) -> str: ) -> str | None:
"""Generate and return a summary for one bill_text row.""" """Generate and return a summary for one bill_text row."""
messages, user_prompt_tokens = build_bill_summary_messages( messages, user_prompt_tokens = build_bill_summary_messages(
bill_text=bill_text, bill_text=bill_text,
@@ -136,15 +137,21 @@ def summarize_bill_text(
def store_bill_summary_result( def store_bill_summary_result(
*, *,
session: Session,
bill_text: BillText, bill_text: BillText,
summary: str, summary: str,
model: str, model: str,
) -> None: ) -> BillTextSummary:
"""Store a generated summary and the prompt/model metadata that produced it.""" """Store a generated summary and the prompt/model metadata that produced it."""
bill_text.summary = summary summary_row = BillTextSummary(
bill_text.summarization_model = model bill_text=bill_text,
bill_text.summarization_system_prompt_version = "v1.2" summary=summary,
bill_text.summarization_user_prompt_version = "v1" summarization_model=model,
summarization_system_prompt_version="v1.2",
summarization_user_prompt_version="v1",
)
session.add(summary_row)
return summary_row
def create_select_bill_texts_for_summarization( def create_select_bill_texts_for_summarization(
@@ -154,7 +161,7 @@ def create_select_bill_texts_for_summarization(
with_votes_only: bool = False, with_votes_only: bool = False,
force: bool = False, force: bool = False,
limit: int | None = None, limit: int | None = None,
) -> Select: ) -> Select[tuple[BillText]]:
"""Select bill_text rows that have source text and need summaries.""" """Select bill_text rows that have source text and need summaries."""
stmt = ( stmt = (
select(BillText) select(BillText)
@@ -189,7 +196,13 @@ def create_select_bill_texts_for_summarization(
) )
) )
if not force: if not force:
stmt = stmt.where(or_(BillText.summary.is_(None), BillText.summary == "")) stmt = stmt.where(
~exists(
select(BillTextSummary.id).where(
BillTextSummary.bill_text_id == BillText.id
)
)
)
if limit is not None: if limit is not None:
stmt = stmt.limit(limit) stmt = stmt.limit(limit)
return stmt return stmt
@@ -287,6 +300,7 @@ def main(
logger.warning("Skipping bill_text id=%s", bill_text.id) logger.warning("Skipping bill_text id=%s", bill_text.id)
continue continue
store_bill_summary_result( store_bill_summary_result(
session=session,
bill_text=bill_text, bill_text=bill_text,
summary=summary, summary=summary,
model=model, model=model,
@@ -6,6 +6,7 @@ from pipelines.orm.data_science_dev.congress.bill import (
BillActionRecordedVote, BillActionRecordedVote,
BillRelation, BillRelation,
BillText, BillText,
BillTextSummary,
BillTopic, BillTopic,
BillTopicPosition, BillTopicPosition,
) )
@@ -54,6 +55,7 @@ __all__ = [
"BillActionRecordedVote", "BillActionRecordedVote",
"BillRelation", "BillRelation",
"BillText", "BillText",
"BillTextSummary",
"BillTopic", "BillTopic",
"BillTopicPosition", "BillTopicPosition",
"ClassificationMethod", "ClassificationMethod",
@@ -105,13 +105,12 @@ class BillText(DataScienceDevTableBase):
) )
bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE")) bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
primary_summary_id: Mapped[int | None] = mapped_column(
ForeignKey("main.bill_text_summary.id", ondelete="SET NULL")
)
version_code: Mapped[str] version_code: Mapped[str]
version_name: Mapped[str | None] version_name: Mapped[str | None]
text_content: Mapped[str | None] text_content: Mapped[str | None]
summary: Mapped[str | None]
summarization_model: Mapped[str | None]
summarization_user_prompt_version: Mapped[str | None]
summarization_system_prompt_version: Mapped[str | None]
date: Mapped[date | None] date: Mapped[date | None]
source_datetime_raw: Mapped[str | None] source_datetime_raw: Mapped[str | None]
text_url_xml: Mapped[str | None] text_url_xml: Mapped[str | None]
@@ -122,6 +121,57 @@ class BillText(DataScienceDevTableBase):
) )
bill: Mapped[Bill] = relationship("Bill", back_populates="bill_texts") bill: Mapped[Bill] = relationship("Bill", back_populates="bill_texts")
summaries: Mapped[list[BillTextSummary]] = relationship(
"BillTextSummary",
back_populates="bill_text",
cascade="all, delete-orphan",
foreign_keys="BillTextSummary.bill_text_id",
order_by=lambda: (
BillTextSummary.created.desc(),
BillTextSummary.id.desc(),
),
)
primary_summary: Mapped[BillTextSummary | None] = relationship(
"BillTextSummary",
foreign_keys=[primary_summary_id],
post_update=True,
)
def latest_summary(self) -> BillTextSummary | None:
    """Return the first entry of ``summaries``, or ``None`` when there are none.

    The ``summaries`` relationship is declared with a newest-first ordering,
    so the first entry is the most recent summary row.
    """
    if not self.summaries:
        return None
    return self.summaries[0]
def default_summary(self) -> BillTextSummary | None:
    """Pick the summary to use by default for this bill text.

    Returns:
        ``primary_summary`` when it is set (truthy); otherwise whatever
        ``latest_summary()`` returns, which may be ``None``.
    """
    chosen = self.primary_summary
    if chosen:
        return chosen
    return self.latest_summary()
class BillTextSummary(DataScienceDevTableBase):
    """A single generated summary attached to one bill text version.

    Each row records the summary text plus optional provenance: the model
    name and the user/system prompt versions used to produce it.
    The ``created`` column referenced by the composite index is not declared
    here; it is presumably supplied by ``DataScienceDevTableBase`` (confirm).
    """

    __tablename__ = "bill_text_summary"
    __table_args__ = (
        # Lookup of all summaries for a bill text.
        Index("ix_bill_text_summary_bill_text_id", "bill_text_id"),
        # Lookup of a bill text's summaries ordered by creation time.
        Index("ix_bill_text_summary_bill_text_id_created", "bill_text_id", "created"),
    )

    # Owning bill text; the database removes summaries when that row is deleted.
    bill_text_id: Mapped[int] = mapped_column(ForeignKey("main.bill_text.id", ondelete="CASCADE"))
    # The generated summary text (required).
    summary: Mapped[str]
    # Optional provenance of the generation run.
    summarization_model: Mapped[str | None]
    summarization_user_prompt_version: Mapped[str | None]
    summarization_system_prompt_version: Mapped[str | None]

    # Parent side of BillText.summaries; foreign_keys is explicit because
    # BillText also points back at this table through primary_summary_id.
    bill_text: Mapped[BillText] = relationship(
        "BillText",
        foreign_keys=[bill_text_id],
        back_populates="summaries",
    )
class BillAction(DataScienceDevTableBase): class BillAction(DataScienceDevTableBase):
+2
View File
@@ -11,6 +11,7 @@ from pipelines.orm.data_science_dev.congress import (
BillActionRecordedVote, BillActionRecordedVote,
BillRelation, BillRelation,
BillText, BillText,
BillTextSummary,
BillTopic, BillTopic,
BillTopicPosition, BillTopicPosition,
ClassificationMethod, ClassificationMethod,
@@ -51,6 +52,7 @@ __all__ = [
"BillActionRecordedVote", "BillActionRecordedVote",
"BillRelation", "BillRelation",
"BillText", "BillText",
"BillTextSummary",
"BillTopic", "BillTopic",
"BillTopicPosition", "BillTopicPosition",
"ClassificationMethod", "ClassificationMethod",
-34
View File
@@ -1,34 +0,0 @@
# System prompt for bill summarization: it tells the model to act as a neutral
# legislative analyst, lists extraction rules, fixes a plain-text output layout
# with five labelled sections, and sets a 150-250 word length target.
SUMMARIZATION_SYSTEM_PROMPT = """You are a legislative analyst extracting policy substance from Congressional bill text.
Your job is to compress a bill into a dense, neutral structured summary that captures every distinct policy action including secondary effects that might be buried in subsections.
EXTRACTION RULES:
- IGNORE: whereas clauses, congressional findings that are purely political statements, recitals, preambles, citations of existing law by number alone, and procedural boilerplate.
- FOCUS ON: operative verbs what the bill SHALL do, PROHIBIT, REQUIRE, AUTHORIZE, AMEND, APPROPRIATE, or ESTABLISH.
- SURFACE ALL THREADS: If the bill touches multiple policy areas, list each thread separately. Do not collapse them.
- BE CONCRETE: Name the affected population, the mechanism, and the direction (expands/restricts/maintains).
- STAY NEUTRAL: No political framing. Describe what the text does, not what its sponsors claim it does.
OUTPUT FORMAT plain structured text, not JSON:
OPERATIVE ACTIONS:
[Numbered list of what the bill actually does, one action per line, max 20 words each]
AFFECTED POPULATIONS:
[Who gains something, who loses something, or whose behavior is regulated]
MECHANISMS:
[How it works: new funding, mandate, prohibition, amendment to existing statute, grant program, study commission, etc.]
POLICY THREADS:
[List each distinct policy domain this bill touches, even minor ones. Use plain language, not domain codes.]
SYMBOLIC/PROCEDURAL ONLY:
[Yes or No is this bill primarily a resolution, designation, or awareness declaration with no operative effect?]
LENGTH TARGET: 150-250 words total. Be ruthless about cutting. Density over completeness."""
# User-message template for bill summarization. It contains a single
# {text_content} placeholder — presumably filled with the bill text via
# str.format by the caller; confirm at the call site.
SUMMARIZATION_USER_TEMPLATE = """Summarize the following Congressional bill according to your instructions.
BILL TEXT:
{text_content}"""
View File
+22
View File
@@ -0,0 +1,22 @@
[project]
name = "ds-testing-pipelines"
version = "0.1.0"
description = "Data science pipeline tools and legislative dashboard."
requires-python = ">=3.12"
dependencies = [
"fastapi",
"httpx",
"uvicorn[standard]",
"jinja2",
"sqlalchemy",
"psycopg",
]
[project.optional-dependencies]
test = [
"pytest",
]
[tool.pytest.ini_options]
testpaths = ["tests"]
pythonpath = ["."]
+36
View File
@@ -0,0 +1,36 @@
from pipelines.orm.data_science_dev.congress import BillText, BillTextSummary
def test_default_summary_prefers_primary_summary() -> None:
    """The primary summary is returned even when another summary is listed."""
    pinned = BillTextSummary(id=1, bill_text_id=10, summary="primary")
    newest = BillTextSummary(id=2, bill_text_id=10, summary="latest")

    text_version = BillText(
        id=10,
        bill_id=5,
        version_code="ih",
        summaries=[newest],
        primary_summary=pinned,
    )

    result = text_version.default_summary()
    assert result is pinned
def test_default_summary_falls_back_to_latest_summary() -> None:
    """Without a primary summary, the first listed summary is used."""
    newest = BillTextSummary(id=2, bill_text_id=10, summary="latest")
    previous = BillTextSummary(id=1, bill_text_id=10, summary="older")

    text_version = BillText(
        id=10,
        bill_id=5,
        version_code="ih",
        summaries=[newest, previous],
    )

    # Both accessors must resolve to the first entry of the summaries list.
    assert text_version.latest_summary() is newest
    assert text_version.default_summary() is newest
def test_default_summary_is_none_without_summaries() -> None:
    """A bill text with no summaries yields ``None`` from both accessors."""
    bare_text = BillText(id=10, bill_id=5, version_code="ih")

    for accessor in (bare_text.latest_summary, bare_text.default_summary):
        assert accessor() is None
+71
View File
@@ -0,0 +1,71 @@
from sqlalchemy.dialects import postgresql
from pipelines.jobs.extract_bill_topics import (
_select_bill_text_for_topic_extraction,
create_select_bills_for_topic_extraction,
)
from pipelines.orm.data_science_dev.congress import Bill, BillText, BillTextSummary
def _compile_sql(statement: object) -> str:
    """Render ``statement`` as PostgreSQL SQL text with bound values inlined."""
    compiled = statement.compile(
        dialect=postgresql.dialect(),
        compile_kwargs={"literal_binds": True},
    )
    return str(compiled)
def test_select_bill_text_for_topic_extraction_uses_primary_summary() -> None:
    """The selected bill text exposes its primary summary as the default."""
    pinned = BillTextSummary(id=1, bill_text_id=10, summary="primary")
    newest = BillTextSummary(id=2, bill_text_id=10, summary="newest")
    text_version = BillText(
        id=10,
        bill_id=5,
        version_code="ih",
        summaries=[newest],
        primary_summary=pinned,
    )
    parent_bill = Bill(
        id=5,
        congress=119,
        bill_type="hr",
        number=1,
        bill_texts=[text_version],
    )

    picked = _select_bill_text_for_topic_extraction(parent_bill)

    assert picked is text_version
    assert picked.default_summary() is pinned
def test_select_bill_text_for_topic_extraction_uses_latest_summary_without_primary() -> None:
    """With no primary summary, the selected bill text defaults to its first summary."""
    newest = BillTextSummary(id=2, bill_text_id=10, summary="newest")
    previous = BillTextSummary(id=1, bill_text_id=10, summary="older")
    text_version = BillText(
        id=10,
        bill_id=5,
        version_code="ih",
        summaries=[newest, previous],
    )
    parent_bill = Bill(
        id=5,
        congress=119,
        bill_type="hr",
        number=1,
        bill_texts=[text_version],
    )

    picked = _select_bill_text_for_topic_extraction(parent_bill)

    assert picked is text_version
    assert picked.default_summary() is newest
def test_create_select_bills_for_topic_extraction_uses_summary_exists_subquery() -> None:
    """The bill selection query checks the summary table via EXISTS, not the old column."""
    rendered = _compile_sql(create_select_bills_for_topic_extraction())

    for fragment in ("bill_text_summary", "EXISTS"):
        assert fragment in rendered
    # The removed bill_text.summary column must no longer be referenced.
    assert "bill_text.summary" not in rendered
+58
View File
@@ -0,0 +1,58 @@
from sqlalchemy.dialects import postgresql
from pipelines.jobs.summarize_bills import (
create_select_bill_texts_for_summarization,
store_bill_summary_result,
)
from pipelines.orm.data_science_dev.congress import BillText, BillTextSummary
class FakeSession:
    """Minimal session stand-in that only remembers what was added to it."""

    def __init__(self) -> None:
        """Start with an empty record of added objects."""
        self.added: list[object] = list()

    def add(self, value: object) -> None:
        """Record ``value`` at the end of ``added``, preserving call order."""
        self.added += [value]
def _compile_sql(statement: object) -> str:
    """Render ``statement`` as PostgreSQL SQL text with bound values inlined."""
    compiled = statement.compile(
        dialect=postgresql.dialect(),
        compile_kwargs={"literal_binds": True},
    )
    return str(compiled)
def test_store_bill_summary_result_creates_summary_row() -> None:
    """Storing a result adds one summary row linked to the bill text.

    No prompt versions are passed in, so the version assertions pin whatever
    defaults ``store_bill_summary_result`` applies.
    """
    fake_session = FakeSession()
    text_version = BillText(id=10, bill_id=5, version_code="ih")

    stored = store_bill_summary_result(
        session=fake_session,
        bill_text=text_version,
        summary="A summary",
        model="gpt-5.4-mini",
    )

    # The returned row is the only object handed to the session.
    assert fake_session.added == [stored]
    assert isinstance(stored, BillTextSummary)
    assert stored.bill_text is text_version
    assert stored.summary == "A summary"
    assert stored.summarization_model == "gpt-5.4-mini"
    assert stored.summarization_system_prompt_version == "v1.2"
    assert stored.summarization_user_prompt_version == "v1"
def test_create_select_bill_texts_for_summarization_excludes_existing_summaries() -> None:
    """Without ``force``, the query filters out bill texts that already have a summary row."""
    rendered = _compile_sql(create_select_bill_texts_for_summarization(force=False))

    assert "bill_text_summary" in rendered
    # Accept either rendering of a negated EXISTS clause.
    assert any(marker in rendered for marker in ("NOT (EXISTS", "NOT EXISTS"))
    assert "bill_text.summary" not in rendered
def test_create_select_bill_texts_for_summarization_force_skips_summary_filter() -> None:
    """With ``force=True``, the query does not reference the summary table at all."""
    forced_query = create_select_bill_texts_for_summarization(force=True)
    rendered = _compile_sql(forced_query)

    assert "bill_text_summary" not in rendered