4 Commits

19 changed files with 1861 additions and 334 deletions
-57
View File
@@ -1,7 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import date
from pathlib import Path
import tomllib
@@ -69,50 +68,10 @@ class BenchmarkConfig:
return cls(**raw)
@dataclass
class BertTopicTrainConfig:
"""BERTopic training configuration loaded from TOML."""
sample_rate: float
min_text_length: int
n_topics: int
model_save_path: str
model_version: str | None = None
@classmethod
def from_toml(cls, config_path: Path) -> BertTopicTrainConfig:
"""Load BERTopic training config from a TOML file."""
raw = tomllib.loads(config_path.read_text())["bertopic"]["train"]
today = date.today().isoformat()
if raw.get("model_version") is None:
raw["model_version"] = (
f"{today}-{raw['sample_rate']}-{raw['min_text_length']}-{raw['n_topics']}"
)
return cls(**raw)
@dataclass
class BertTopicInferConfig:
"""BERTopic inference configuration loaded from TOML."""
min_text_length: int
poc_batch_size: int
model_version: str
model_save_path: str
@classmethod
def from_toml(cls, config_path: Path) -> BertTopicInferConfig:
"""Load BERTopic inference config from a TOML file."""
raw = tomllib.loads(config_path.read_text())["bertopic"]["infer"]
return cls(**raw)
def get_config_dir() -> Path:
"""Get the path to the config file."""
return Path(__file__).resolve().parent.parent.parent / "config"
def default_config_path() -> Path:
"""Get the path to the config file."""
return get_config_dir() / "config.toml"
@@ -128,19 +87,3 @@ def get_benchmark_config(config_path: Path | None = None) -> BenchmarkConfig:
if config_path is None:
config_path = default_config_path()
return BenchmarkConfig.from_toml(config_path)
def get_bertopic_train_config(
config_path: Path | None = None,
) -> BertTopicTrainConfig:
if config_path is None:
config_path = default_config_path()
return BertTopicTrainConfig.from_toml(config_path)
def get_bertopic_infer_config(
config_path: Path | None = None,
) -> BertTopicInferConfig:
if config_path is None:
config_path = default_config_path()
return BertTopicInferConfig.from_toml(config_path)
-116
View File
@@ -1,116 +0,0 @@
"""Nornsight — BERTopic POC Inference Script.
Loads the trained model and labels a small batch of posts,
writing results to main.post_topic for inspection.
POC: processes a single batch of 1k posts to validate the pipeline end-to-end.
"""
from __future__ import annotations
import logging
import time
from collections import Counter
from pathlib import Path
from bertopic import BERTopic
from sqlalchemy import Engine, func, insert, select
from sqlalchemy.orm import Session
from pipelines.config import BertTopicInferConfig, get_bertopic_infer_config
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.posts import PostTopic, Posts
from pipelines.orm.data_science_dev.posts.lang_filters import ENGLISH_LANGS
from pipelines.pipelines.common import configure_logger
logger = logging.getLogger(__name__)
def main() -> None:
"""Run BERTopic inference against a sample of posts."""
configure_logger()
config = get_bertopic_infer_config()
run_inference(config)
logger.info(
"POC inference complete. Check main.post_topic in DBeaver to inspect results."
)
def run_inference(config: BertTopicInferConfig) -> None:
model_save_path = Path(config.model_save_path)
logger.info(f"Loading BERTopic model from {model_save_path}")
topic_model = BERTopic.load(str(model_save_path))
topic_info = topic_model.get_topic_info()
label_map: dict[int, str] = dict(zip(topic_info["Topic"], topic_info["Name"]))
logger.info(f"Model loaded with {len(label_map)} topics")
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
post_ids, texts = get_post_ids_and_test(engine, config)
logger.info(f"Fetched {len(texts)} posts")
logger.info("Running BERTopic transform")
start = time.perf_counter()
topics, _probabilities = topic_model.transform(texts)
elapsed = time.perf_counter() - start
logger.info(f"Transform complete in {elapsed:.1f}s")
# Write results to main.post_topic
records = [
{
"post_id": pid,
"topic_id": int(topic_id),
"topic_label": label_map.get(int(topic_id), "unknown"),
"model_version": config.model_version,
}
for pid, topic_id in zip(post_ids, topics)
]
with Session(engine) as session:
session.execute(insert(PostTopic), records)
session.commit()
count_topics(records)
logger.info(f"Wrote {len(records)} topic labels to main.post_topic")
def get_post_ids_and_test(
engine: Engine,
config: BertTopicInferConfig,
) -> None | tuple[list[int], list[str]]:
with Session(engine) as session:
logger.info(f"Fetching {config.poc_batch_size} posts for inference")
# Pull a fresh batch for inference — distinct from training sample
# using a fixed seed offset so we're not re-labeling training posts
stmt = select(Posts).where(
Posts.text.is_not(None),
Posts.langs.in_(ENGLISH_LANGS),
func.length(Posts.text) > config.min_text_length,
)
if config.poc_batch_size > 0:
stmt = stmt.limit(config.poc_batch_size)
posts = session.scalars(stmt).all()
if not posts:
logger.warning("No posts were selected for inference")
return [], []
post_ids = [post.post_id for post in posts]
texts = [post.text.strip() for post in posts]
return post_ids, texts
def count_topics(records: list[dict]) -> None:
topic_counts = Counter(record.get("topic_label", "unknown") for record in records)
logger.info("Topic distribution in this batch:")
for label, count in topic_counts.most_common(10):
logger.info(" %s: %d", label, count)
if __name__ == "__main__":
main()
-119
View File
@@ -1,119 +0,0 @@
"""Nornsight — BERTopic POC Training Script.
Pulls a small stratified sample (~11.5k posts) from main.posts,
trains BERTopic with MiniBatchKMeans on Jeeves, and saves the model locally.
POC sample rate: random() < 0.00005 (~0.005% of 230M = ~11.5k posts)
Full training rate will be: random() < 0.005 (~1.08M posts)
"""
from __future__ import annotations
import logging
import time
from pathlib import Path
from bertopic import BERTopic
from sklearn.cluster import MiniBatchKMeans
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from pipelines.config import BertTopicTrainConfig, get_bertopic_train_config
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.posts import Posts
from pipelines.orm.data_science_dev.posts.lang_filters import ENGLISH_LANGS
from pipelines.pipelines.common import configure_logger
logger = logging.getLogger(__name__)
def main() -> None:
"""Train and persist the BERTopic model."""
configure_logger()
config = get_bertopic_train_config()
docs = load_sample(config)
if not docs:
logger.warning("No training documents were selected")
return
train(docs, config)
logger.info(f"Done. Model saved as version {config.model_version}")
logger.info("Next: run infer.py to label a sample of posts in the database")
def load_sample(config: BertTopicTrainConfig) -> list[str]:
logger.info("Connecting to PostgreSQL via SQLAlchemy")
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
logger.info(f"Pulling sample from main.posts (sample_rate={config.sample_rate})")
start = time.perf_counter()
with Session(engine) as session:
texts = session.scalars(
select(Posts.text).where(
Posts.text.is_not(None),
Posts.langs.in_(ENGLISH_LANGS),
func.length(Posts.text) > config.min_text_length,
func.random() < config.sample_rate,
)
).all()
elapsed = time.perf_counter() - start
logger.info(f"Fetched {len(texts)} rows in {elapsed:.1f}s")
# Basic cleaning — strip whitespace and deduplicate
docs = list({text.strip() for text in texts})
logger.info(f"After cleaning and dedup: {len(docs)} posts")
return docs
def train(docs: list[str], config: BertTopicTrainConfig) -> None:
logger.info(
f"Initialising BERTopic with MiniBatchKMeans (n_topics={config.n_topics})"
)
cluster_model = MiniBatchKMeans(
n_clusters=config.n_topics,
random_state=42,
batch_size=1024,
n_init=3,
verbose=1,
)
topic_model = BERTopic(
hdbscan_model=cluster_model,
language="english",
calculate_probabilities=False, # saves memory
verbose=True,
)
logger.info(f"Starting fit_transform on {len(docs)} posts (CPU)")
start = time.perf_counter()
topic_model.fit_transform(docs)
elapsed = time.perf_counter() - start
logger.info(f"Training complete in {elapsed:.1f}s ({elapsed / 60:.1f} min)")
# Log topic summary for quick inspection
topic_info = topic_model.get_topic_info()
logger.info(f"Topics found: {len(topic_info)}")
logger.info(f"\n{topic_info.to_string()}")
model_save_path = Path(config.model_save_path)
model_save_path.mkdir(parents=True, exist_ok=True)
logger.info(f"Saving model to {model_save_path}")
topic_model.save(
str(model_save_path),
serialization="safetensors",
save_ctfidf=True,
save_embedding_model=True,
)
logger.info("Model saved")
if __name__ == "__main__":
main()
+1
View File
@@ -0,0 +1 @@
"""Prompt benchmarking system for evaluating LLMs via vLLM."""
+235
View File
@@ -0,0 +1,235 @@
"""Docker container lifecycle management for BERTopic jobs on Jeeves."""
from __future__ import annotations
import logging
import os
import subprocess
from pathlib import Path
from typing import Annotated, Literal
import typer
logger = logging.getLogger(__name__)
JOBMode = Literal["train", "infer"]
IMAGE_NAME = "bert-topic:latest"
REPO_DIR = Path(__file__).resolve().parents[3]
DEFAULT_CACHE_ROOT = Path("/zfs/storage/main/ds_thing/models/bert_topic")
DEFAULT_POSTGRES_SOCKET_DIR = Path("/run/postgresql")
DB_ENV_VARS = (
"DATA_SCIENCE_DEV_DB",
"DATA_SCIENCE_DEV_HOST",
"DATA_SCIENCE_DEV_PORT",
"DATA_SCIENCE_DEV_USER",
"DATA_SCIENCE_DEV_PASSWORD",
)
app = typer.Typer(help="BERTopic container management.")
def _container_name(mode: JOBMode) -> str:
"""Return the Docker container name for the selected BERTopic job."""
return f"bert-topic-{mode}"
def _module_name(mode: JOBMode) -> str:
"""Return the Python module to run inside the container."""
return f"pipelines.bert_topic.{mode}"
def _env_args(*, use_postgres_socket: bool) -> list[str]:
"""Pass through database environment variables from the host shell."""
required = [
"DATA_SCIENCE_DEV_DB",
"DATA_SCIENCE_DEV_PORT",
"DATA_SCIENCE_DEV_USER",
]
if not use_postgres_socket:
required.append("DATA_SCIENCE_DEV_HOST")
missing = [name for name in required if not os.getenv(name)]
if missing:
message = "Missing required database environment variables: " + ", ".join(
missing
)
raise RuntimeError(message)
args: list[str] = []
if use_postgres_socket:
args.extend(["-e", f"DATA_SCIENCE_DEV_HOST={DEFAULT_POSTGRES_SOCKET_DIR}"])
for name in DB_ENV_VARS:
if use_postgres_socket and name == "DATA_SCIENCE_DEV_HOST":
continue
if os.getenv(name):
args.extend(["-e", name])
return args
def build_image() -> None:
"""Build the BERTopic Docker image."""
dockerfile = REPO_DIR / "pipelines/containers/docker_files/Dockerfile.bert_topic"
logger.info("Building BERTopic image: %s", IMAGE_NAME)
result = subprocess.run(
[
"docker",
"build",
"--network",
"host",
"-f",
str(dockerfile),
"-t",
IMAGE_NAME,
str(REPO_DIR),
],
capture_output=True,
text=True,
check=False,
)
if result.returncode != 0:
message = (
"Failed to build BERTopic image. "
f"docker build stderr:\n{result.stderr.strip()}"
)
raise RuntimeError(message)
logger.info("Image built: %s", IMAGE_NAME)
def stop_job(*, mode: JOBMode) -> None:
"""Stop and remove the BERTopic container for the selected mode."""
container_name = _container_name(mode)
logger.info("Stopping BERTopic container: %s", container_name)
subprocess.run(["docker", "stop", container_name], capture_output=True, check=False)
subprocess.run(
["docker", "rm", "-f", container_name], capture_output=True, check=False
)
def start_job(
*,
mode: JOBMode,
cache_root: Path = DEFAULT_CACHE_ROOT,
postgres_socket_dir: Path = DEFAULT_POSTGRES_SOCKET_DIR,
detach: bool = False,
) -> None:
"""Run BERTopic training or inference in Docker on Jeeves."""
cache_root = cache_root.resolve()
cache_root.mkdir(parents=True, exist_ok=True)
postgres_socket_dir = postgres_socket_dir.resolve()
stop_job(mode=mode)
use_postgres_socket = postgres_socket_dir.exists()
command = [
"docker",
"run",
"--name",
_container_name(mode),
"--ipc=host",
"-v",
f"{cache_root}:/cache",
*_env_args(use_postgres_socket=use_postgres_socket),
IMAGE_NAME,
_module_name(mode),
]
if use_postgres_socket:
command[7:7] = ["-v", f"{postgres_socket_dir}:{DEFAULT_POSTGRES_SOCKET_DIR}"]
if detach:
command.insert(2, "-d")
logger.info("Starting BERTopic %s container", mode)
logger.info(" Cache root: %s", cache_root)
if use_postgres_socket:
logger.info(" Postgres socket: %s", postgres_socket_dir)
result = subprocess.run(command, text=True, capture_output=detach, check=False)
if result.returncode != 0:
detail = (
result.stderr.strip() if result.stderr else f"exit code {result.returncode}"
)
raise RuntimeError(f"BERTopic container failed to start: {detail}")
if detach:
logger.info("Container started: %s", result.stdout.strip()[:12])
else:
logger.info("BERTopic %s run complete", mode)
def logs_job(*, mode: JOBMode) -> str | None:
"""Return recent logs from the BERTopic container, or None if absent."""
result = subprocess.run(
["docker", "logs", "--tail", "100", _container_name(mode)],
capture_output=True,
text=True,
check=False,
)
if result.returncode != 0:
return None
return result.stdout + result.stderr
@app.command()
def build(
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
"""Build the BERTopic Docker image."""
logging.basicConfig(
level=log_level,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
build_image()
@app.command("run")
def run_job_command(
mode: Annotated[JOBMode, typer.Option(help="Which BERTopic job to run")] = "train",
cache_root: Annotated[
Path, typer.Option(help="Host path mounted to /cache for model and HF cache")
] = DEFAULT_CACHE_ROOT,
postgres_socket_dir: Annotated[
Path, typer.Option(help="Host Postgres socket directory to mount into the container")
] = DEFAULT_POSTGRES_SOCKET_DIR,
detach: Annotated[
bool, typer.Option(help="Start the container in the background")
] = False,
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
"""Run BERTopic training or inference inside Docker."""
logging.basicConfig(
level=log_level,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
start_job(
mode=mode,
cache_root=cache_root,
postgres_socket_dir=postgres_socket_dir,
detach=detach,
)
@app.command("stop")
def stop_job_command(
mode: Annotated[
JOBMode, typer.Option(help="Which BERTopic container to stop")
] = "train",
) -> None:
"""Stop and remove the BERTopic container."""
stop_job(mode=mode)
@app.command("logs")
def logs_job_command(
mode: Annotated[
JOBMode, typer.Option(help="Which BERTopic container logs to show")
] = "train",
) -> None:
"""Show recent logs from the BERTopic container."""
output = logs_job(mode=mode)
if output is None:
typer.echo(f"No BERTopic container found for mode={mode}.")
raise typer.Exit(code=1)
typer.echo(output)
def cli() -> None:
"""Typer entry point."""
app()
if __name__ == "__main__":
cli()
@@ -0,0 +1,38 @@
FROM python:3.12-bookworm
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV PIP_NO_CACHE_DIR=1
RUN apt-get update && apt-get install -y \
build-essential \
gcc \
g++ \
git \
libgomp1 \
libpq-dev \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY pipelines ./pipelines
RUN python -m pip install --upgrade pip setuptools wheel && \
python -m pip install \
torch \
--index-url https://download.pytorch.org/whl/cpu && \
python -m pip install \
typer \
sqlalchemy \
bertopic \
sentence-transformers \
scikit-learn \
pandas \
numpy \
"psycopg[binary]"
ENV HF_HOME=/cache/huggingface
ENV TRANSFORMERS_CACHE=/cache/huggingface
ENTRYPOINT ["python", "-m"]
CMD ["pipelines.bert_topic.train"]
@@ -0,0 +1,11 @@
FROM ghcr.io/unslothai/unsloth:latest
RUN pip install --no-cache-dir typer
WORKDIR /workspace
COPY python/prompt_bench/finetune.py python/prompt_bench/finetune.py
COPY config/prompts/summarization_prompts.toml config/prompts/summarization_prompts.toml
COPY python/prompt_bench/__init__.py python/prompt_bench/__init__.py
COPY python/__init__.py python/__init__.py
ENTRYPOINT ["python", "-m", "pipelines.prompt_bench.finetune"]
@@ -9,7 +9,7 @@ from typing import Annotated
import typer
from pipelines.tools.containers.lib import check_gpu_free
from pipelines.pipelines.containers.lib import check_gpu_free
logger = logging.getLogger(__name__)
@@ -27,7 +27,7 @@ def build_image() -> None:
"docker",
"build",
"-f",
str(REPO_DIR / "python/prompt_bench/Dockerfile.finetune"),
str(REPO_DIR / "pipelines/containers/docker_files/Dockerfile.finetune"),
"-t",
FINETUNE_IMAGE,
".",
@@ -0,0 +1,574 @@
"""Calculate legislator topic scores from bill topics and roll-call votes."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Annotated, Sequence
import typer
from sqlalchemy import (
ColumnElement,
Integer,
Select,
and_,
case,
cast,
delete,
extract,
func,
or_,
select,
tuple_,
)
from sqlalchemy.orm import Session
from pipelines.congress_vote_context import create_score_run, finalize_score_run
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.congress import (
BillTopic,
BillTopicPosition,
LegislatorScore,
SubjectType,
Vote,
VoteClassification,
VoteEffect,
VoteMeasureLink,
VoteMeasureRole,
VotePositionMeaning,
VoteRelationship,
VoteRecord,
)
from pipelines.pipelines.jobs.extract_bill_topics import normalize_topic_label
from pipelines.web.scoring import (
OPPOSE_POSITIONS,
SUPPORT_POSITIONS,
normalized_position_expression,
)
logger = logging.getLogger(__name__)
DELETE_BATCH_SIZE = 5_000
@dataclass(frozen=True)
class ScoreDiagnostics:
"""Counts for the input stages required to calculate legislator scores."""
bill_topic_rows: int
linked_vote_rows: int
vote_record_rows: int
topic_vote_links: int
scorable_vote_records: int
@dataclass(frozen=True)
class LegislatorScoreInput:
"""One aggregated score ready to store in legislator_score."""
legislator_id: int
year: int
topic: str
score: float
supportive: int
opposed: int
def create_legislator_score_query(
*,
congress: int | None = None,
bill_ids: Sequence[int] | None = None,
topics: Sequence[str] | None = None,
) -> Select:
"""Build the aggregate score query from extracted bill topics and vote records."""
normalized_vote = normalized_position_expression(VoteRecord.position)
supportive_vote = _supportive_vote_expression(normalized_vote)
opposed_vote = _opposed_vote_expression(normalized_vote)
supportive_count = func.sum(supportive_vote)
opposed_count = func.sum(opposed_vote)
total_count = supportive_count + opposed_count
vote_year = cast(extract("year", Vote.vote_date), Integer)
score = (100.0 * supportive_count / func.nullif(total_count, 0)).label("score")
stmt = (
select(
VoteRecord.legislator_id.label("legislator_id"),
vote_year.label("year"),
BillTopic.topic.label("topic"),
score,
supportive_count.label("supportive"),
opposed_count.label("opposed"),
total_count.label("total"),
)
.select_from(VoteRecord)
.join(Vote, Vote.id == VoteRecord.vote_id)
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
.join(
VoteMeasureLink,
and_(
VoteMeasureLink.vote_id == Vote.id,
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
),
)
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
.where(
*_eligible_vote_filters(),
_is_scorable_position(normalized_vote),
)
.group_by(VoteRecord.legislator_id, vote_year, BillTopic.topic)
.having(total_count > 0)
.order_by(VoteRecord.legislator_id, vote_year, BillTopic.topic)
)
if congress is not None:
stmt = stmt.where(Vote.congress == congress)
if bill_ids:
stmt = stmt.where(VoteMeasureLink.measure_id.in_(list(bill_ids)))
normalized_topics = _normalize_topics(topics)
if normalized_topics:
stmt = stmt.where(BillTopic.topic.in_(normalized_topics))
return stmt
def collect_legislator_scores(
session: Session,
*,
congress: int | None = None,
bill_ids: Sequence[int] | None = None,
topics: Sequence[str] | None = None,
) -> list[LegislatorScoreInput]:
"""Run the aggregate query and return score rows."""
rows = session.execute(
create_legislator_score_query(
congress=congress,
bill_ids=bill_ids,
topics=topics,
)
)
return [
LegislatorScoreInput(
legislator_id=int(row.legislator_id),
year=int(row.year),
topic=str(row.topic),
score=float(row.score),
supportive=int(row.supportive),
opposed=int(row.opposed),
)
for row in rows
if row.score is not None
]
def collect_score_diagnostics(
session: Session,
*,
congress: int | None = None,
bill_ids: Sequence[int] | None = None,
topics: Sequence[str] | None = None,
) -> ScoreDiagnostics:
"""Count score pipeline inputs for explaining empty score runs."""
normalized_topics = _normalize_topics(topics)
vote_filters = _vote_scope_filters(congress=congress, bill_ids=bill_ids)
topic_filters = _topic_scope_filters(bill_ids=bill_ids, topics=normalized_topics)
normalized_vote = normalized_position_expression(VoteRecord.position)
eligible_vote_filters = _eligible_vote_filters()
bill_topic_rows = session.scalar(
select(func.count(BillTopic.id)).where(*topic_filters)
)
linked_vote_rows = session.scalar(
select(func.count(func.distinct(Vote.id)))
.select_from(Vote)
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
.join(
VoteMeasureLink,
and_(
VoteMeasureLink.vote_id == Vote.id,
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
),
)
.where(*vote_filters, *eligible_vote_filters)
)
vote_record_rows = session.scalar(
select(func.count())
.select_from(VoteRecord)
.join(Vote, Vote.id == VoteRecord.vote_id)
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
.where(*vote_filters, *eligible_vote_filters)
)
topic_vote_links = session.scalar(
select(func.count())
.select_from(VoteRecord)
.join(Vote, Vote.id == VoteRecord.vote_id)
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
.join(
VoteMeasureLink,
and_(
VoteMeasureLink.vote_id == Vote.id,
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
),
)
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
.where(*vote_filters, *topic_filters, *eligible_vote_filters)
)
scorable_vote_records = session.scalar(
select(func.count())
.select_from(VoteRecord)
.join(Vote, Vote.id == VoteRecord.vote_id)
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
.join(
VoteMeasureLink,
and_(
VoteMeasureLink.vote_id == Vote.id,
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
),
)
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
.where(
*vote_filters,
*topic_filters,
*eligible_vote_filters,
_is_scorable_position(normalized_vote),
)
)
return ScoreDiagnostics(
bill_topic_rows=bill_topic_rows or 0,
linked_vote_rows=linked_vote_rows or 0,
vote_record_rows=vote_record_rows or 0,
topic_vote_links=topic_vote_links or 0,
scorable_vote_records=scorable_vote_records or 0,
)
def store_legislator_scores(
session: Session,
rows: Sequence[LegislatorScoreInput],
*,
score_run_id: int | None,
replace_all: bool = False,
) -> int:
"""Replace matching score rows and insert the newly calculated scores."""
if replace_all:
session.execute(delete(LegislatorScore))
elif rows:
keys = [
(row.legislator_id, row.year, row.topic)
for row in rows
]
for key_batch in _batched(keys, DELETE_BATCH_SIZE):
session.execute(
delete(LegislatorScore).where(
tuple_(
LegislatorScore.legislator_id,
LegislatorScore.year,
LegislatorScore.topic,
).in_(key_batch)
)
)
session.add_all(
[
LegislatorScore(
legislator_id=row.legislator_id,
year=row.year,
topic=row.topic,
score=row.score,
score_run_id=score_run_id,
)
for row in rows
]
)
return len(rows)
def _supportive_vote_expression(
normalized_vote: ColumnElement[str | None],
) -> ColumnElement[int]:
supports_text = _position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT)
opposes_text = _position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT)
return case(
(
and_(
BillTopic.support_position == BillTopicPosition.FOR,
supports_text,
),
1,
),
(
and_(
BillTopic.support_position == BillTopicPosition.AGAINST,
opposes_text,
),
1,
),
else_=0,
)
def _opposed_vote_expression(
normalized_vote: ColumnElement[str | None],
) -> ColumnElement[int]:
supports_text = _position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT)
opposes_text = _position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT)
return case(
(
and_(
BillTopic.support_position == BillTopicPosition.FOR,
opposes_text,
),
1,
),
(
and_(
BillTopic.support_position == BillTopicPosition.AGAINST,
supports_text,
),
1,
),
else_=0,
)
def _position_matches_effect(
normalized_vote: ColumnElement[str | None],
effect: VoteEffect,
) -> ColumnElement[bool]:
return or_(
and_(
normalized_vote.in_(sorted(SUPPORT_POSITIONS)),
VotePositionMeaning.yea_effect == effect,
),
and_(
normalized_vote.in_(sorted(OPPOSE_POSITIONS)),
VotePositionMeaning.nay_effect == effect,
),
and_(
normalized_vote == "present",
VotePositionMeaning.present_effect == effect,
),
)
def _is_scorable_position(normalized_vote: ColumnElement[str | None]) -> ColumnElement[bool]:
return or_(
_position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT),
_position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT),
)
def _normalize_topics(topics: Sequence[str] | None) -> list[str]:
normalized: list[str] = []
seen: set[str] = set()
for topic in topics or []:
value = normalize_topic_label(topic)
if value and value not in seen:
normalized.append(value)
seen.add(value)
return normalized
def _batched[T](items: Sequence[T], batch_size: int) -> list[Sequence[T]]:
return [
items[index : index + batch_size]
for index in range(0, len(items), batch_size)
]
def _vote_scope_filters(
*,
congress: int | None,
bill_ids: Sequence[int] | None,
) -> list[ColumnElement[bool]]:
filters: list[ColumnElement[bool]] = []
if congress is not None:
filters.append(Vote.congress == congress)
if bill_ids:
filters.append(VoteMeasureLink.measure_id.in_(list(bill_ids)))
return filters
def _topic_scope_filters(
*,
bill_ids: Sequence[int] | None,
topics: Sequence[str],
) -> list[ColumnElement[bool]]:
filters: list[ColumnElement[bool]] = []
if bill_ids:
filters.append(BillTopic.bill_id.in_(list(bill_ids)))
if topics:
filters.append(BillTopic.topic.in_(list(topics)))
return filters
def _has_score_scope(
*,
congress: int | None,
bill_ids: Sequence[int] | None,
topics: Sequence[str] | None,
) -> bool:
return congress is not None or bool(bill_ids) or bool(topics)
def _eligible_vote_filters() -> list[ColumnElement[bool]]:
return [
VoteClassification.subject_type == SubjectType.MEASURE,
VoteClassification.vote_relationship == VoteRelationship.DIRECT_TEXT_VOTE,
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
VoteClassification.is_substantive_policy_vote.is_(True),
VoteClassification.is_special_rule.is_(False),
]
def main(
congress: Annotated[
int | None,
typer.Option(help="Only score votes from one Congress."),
] = None,
bill_ids: Annotated[
list[int] | None,
typer.Option(
"--bill-id",
help="Only score votes linked to one internal bill.id. Repeatable.",
),
] = None,
topics: Annotated[
list[str] | None,
typer.Option("--topic", help="Only score one normalized topic. Repeatable."),
] = None,
replace_all: Annotated[
bool,
typer.Option(
help="Delete every existing legislator score before inserting. "
"Unfiltered runs do this automatically."
),
] = False,
dry_run: Annotated[
bool,
typer.Option(help="Calculate scores without writing to the database."),
] = False,
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
diagnose: Annotated[
bool,
typer.Option(help="Log input-stage counts even when rows are calculated."),
] = False,
) -> None:
"""CLI entrypoint for calculating and storing legislator topic scores."""
logging.basicConfig(
level=log_level,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
with Session(engine) as session:
rows = collect_legislator_scores(
session,
congress=congress,
bill_ids=bill_ids,
topics=topics,
)
logger.info("Calculated %d legislator topic score rows", len(rows))
if diagnose or not rows:
diagnostics = collect_score_diagnostics(
session,
congress=congress,
bill_ids=bill_ids,
topics=topics,
)
_log_diagnostics(diagnostics)
if dry_run:
session.rollback()
return
score_run = create_score_run(session)
should_replace_all = replace_all or not _has_score_scope(
congress=congress,
bill_ids=bill_ids,
topics=topics,
)
written = store_legislator_scores(
session,
rows,
score_run_id=score_run.id,
replace_all=should_replace_all,
)
included_vote_count = session.scalar(
select(func.count(func.distinct(Vote.id)))
.select_from(VoteRecord)
.join(Vote, Vote.id == VoteRecord.vote_id)
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
.join(
VoteMeasureLink,
and_(
VoteMeasureLink.vote_id == Vote.id,
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
),
)
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
.where(
*_vote_scope_filters(congress=congress, bill_ids=bill_ids),
*_topic_scope_filters(bill_ids=bill_ids, topics=_normalize_topics(topics)),
*_eligible_vote_filters(),
_is_scorable_position(normalized_position_expression(VoteRecord.position)),
)
) or 0
total_scoped_votes = session.scalar(
select(func.count(func.distinct(Vote.id)))
.select_from(Vote)
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
.join(
VoteMeasureLink,
and_(
VoteMeasureLink.vote_id == Vote.id,
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
),
)
.where(*_vote_scope_filters(congress=congress, bill_ids=bill_ids))
) or 0
finalize_score_run(
session,
score_run=score_run,
included_vote_count=included_vote_count,
excluded_vote_count=max(total_scoped_votes - included_vote_count, 0),
)
session.commit()
logger.info("Stored %d legislator topic score rows", written)
def _log_diagnostics(diagnostics: ScoreDiagnostics) -> None:
logger.info(
"Score input diagnostics: bill_topic_rows=%d linked_vote_rows=%d "
"vote_record_rows=%d topic_vote_links=%d scorable_vote_records=%d",
diagnostics.bill_topic_rows,
diagnostics.linked_vote_rows,
diagnostics.vote_record_rows,
diagnostics.topic_vote_links,
diagnostics.scorable_vote_records,
)
if diagnostics.bill_topic_rows == 0:
logger.warning(
"No extracted bill topics matched the score scope. Run "
"pipelines.tools.extract_bill_topics after bill summarization."
)
elif diagnostics.linked_vote_rows == 0:
logger.warning("No direct substantive text votes matched the score scope.")
elif diagnostics.vote_record_rows == 0:
logger.warning("No individual vote records matched the score scope.")
elif diagnostics.topic_vote_links == 0:
logger.warning(
"Bill topics exist, but none are attached to bills that have eligible scored votes."
)
elif diagnostics.scorable_vote_records == 0:
logger.warning(
"Topic-vote links exist, but no joined vote records had Yea/Aye/Yes/Nay/No positions."
)
if __name__ == "__main__":
typer.run(main)
+682
View File
@@ -0,0 +1,682 @@
"""Extract bill topics from bill text using a configurable topic catalog."""
from __future__ import annotations
import json
import logging
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated, Any, Sequence
import httpx
import typer
from sqlalchemy import ColumnElement, Select, delete, exists, func, select
from sqlalchemy.orm import Session, selectinload
from pipelines.config import OpenAIConfig, get_config_dir, get_openai_config
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.congress import (
Bill,
BillText,
BillTopic,
BillTopicPosition,
SubjectType,
VoteClassification,
VoteRelationship,
VoteTextTarget,
)
logger = logging.getLogger(__name__)
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
REQUEST_TIMEOUT_SECONDS = 60
DEFAULT_TOPICS_PATH = get_config_dir() / "congressional_issues_comprehensive.json"
class TopicExtractionError(RuntimeError):
"""Raised when a topic extraction request or response is invalid."""
@dataclass(frozen=True)
class TopicCatalog:
"""Loaded topic catalog with categories for prompting and flat candidates."""
topics_by_category: dict[str, list[str]]
candidate_topics: list[str]
@dataclass(frozen=True)
class TopicExtractionDiagnostics:
"""Counts for the bill summary inputs needed by topic extraction."""
bill_rows: int
bill_text_rows: int
summarized_bill_text_rows: int
bills_with_summaries: int
bill_topic_rows: int
selected_bills: int
@dataclass(frozen=True)
class ExtractedBillTopic:
"""One extracted bill topic and yes-vote stance."""
topic: str
support_position: BillTopicPosition
confidence: float | None = None
evidence: str | None = None
def _select_bill_text_for_topic_extraction(bill: Bill) -> BillText | None:
"""Pick one summarized bill_text row from the already-loaded relationship."""
for bill_text in bill.bill_texts:
if bill_text.summary and bill_text.summary.strip():
return bill_text
return None
def normalize_topic_label(value: str) -> str:
"""Normalize a topic label for storage, comparison, and de-duping."""
normalized = value.strip().strip("\"'")
normalized = normalized.strip().rstrip(".").strip()
return re.sub(r"\s+", " ", normalized).lower()
def load_topic_catalog(path: Path | None = None) -> TopicCatalog:
"""Load, validate, normalize, and flatten the bill topic catalog."""
topics_path = path or DEFAULT_TOPICS_PATH
try:
raw = json.loads(topics_path.read_text())
except FileNotFoundError as exc:
msg = f"Topic catalog not found: {topics_path}"
raise TopicExtractionError(msg) from exc
except json.JSONDecodeError as exc:
msg = f"Topic catalog is not valid JSON: {topics_path}: {exc}"
raise TopicExtractionError(msg) from exc
if not isinstance(raw, dict):
msg = "Topic catalog root must be an object mapping category names to lists"
raise TopicExtractionError(msg)
topics_by_category: dict[str, list[str]] = {}
candidate_topics: list[str] = []
seen_topics: set[str] = set()
for category, topics in raw.items():
if not isinstance(category, str) or not category.strip():
msg = "Topic catalog category names must be non-empty strings"
raise TopicExtractionError(msg)
if not isinstance(topics, list):
msg = f"Topic catalog category {category!r} must contain a list"
raise TopicExtractionError(msg)
normalized_topics: list[str] = []
for topic in topics:
if not isinstance(topic, str):
msg = f"Topic catalog category {category!r} contains a non-string topic"
raise TopicExtractionError(msg)
normalized_topic = normalize_topic_label(topic)
if not normalized_topic:
msg = f"Topic catalog category {category!r} contains a blank topic"
raise TopicExtractionError(msg)
if normalized_topic in seen_topics:
continue
seen_topics.add(normalized_topic)
normalized_topics.append(normalized_topic)
candidate_topics.append(normalized_topic)
topics_by_category[category.strip()] = normalized_topics
return TopicCatalog(
topics_by_category=topics_by_category,
candidate_topics=candidate_topics,
)
def build_topic_extraction_messages(
*,
bill: Bill,
bill_text: str,
candidate_topics: Sequence[str],
) -> list[dict[str, str]]:
"""Build GPT messages for extracting a bill's scored topics."""
normalized_candidates = [normalize_topic_label(topic) for topic in candidate_topics]
candidate_list = "\n".join(f"- {topic}" for topic in normalized_candidates)
metadata = "\n".join(
(
f"Congress: {bill.congress}",
f"Bill: {bill.bill_type} {bill.number}",
f"Title: {bill.title_short or bill.title or bill.official_title or ''}",
f"Top subject term: {bill.subjects_top_term or ''}",
)
)
system_prompt = (
"You extract policy topics from U.S. congressional bills.\n"
'For each selected topic, decide whether a Yes/Yea vote on the bill is "for" or "against" that topic.\n'
'Use "support_position": "for" when a Yes/Yea vote advances or supports the topic.\n'
'Use "support_position": "against" when a Yes/Yea vote restricts, repeals, blocks, or opposes the topic.\n'
"Select only topics from the provided candidate topic list.\n"
"Omit topics that are not materially addressed by the bill.\n"
"Return strict JSON only, with this shape:\n"
'{"topics":[{"topic":"candidate topic","support_position":"for","confidence":0.0,"evidence":"short reason"}]}'
)
user_prompt = "\n\n".join(
(
"BILL METADATA:",
metadata,
"CANDIDATE TOPICS:",
candidate_list,
"BILL TEXT:",
bill_text,
)
)
return [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
def call_openai_topic_extraction(
*,
openai_config: OpenAIConfig,
messages: list[dict[str, str]],
) -> str:
"""Call GPT and return the assistant message content."""
response = httpx.post(
openai_config.openai_chat_completions_url,
headers={
"Authorization": f"Bearer {openai_config.api_key}",
"OpenAI-Project": openai_config.openai_project_id,
"Content-Type": "application/json",
},
json={
"model": "gpt-5.4-mini",
"messages": messages,
},
timeout=openai_config.timeout_seconds,
)
response.raise_for_status()
return extract_message_content(response.json())
def extract_message_content(data: dict[str, Any]) -> str:
"""Extract message content from a chat-completions response body."""
choices = data.get("choices")
if not isinstance(choices, list) or not choices:
msg = "Chat completion response did not contain choices"
raise TopicExtractionError(msg)
first = choices[0]
if not isinstance(first, dict):
msg = "Chat completion choice must be an object"
raise TopicExtractionError(msg)
message = first.get("message")
if isinstance(message, dict) and isinstance(message.get("content"), str):
return message["content"]
if isinstance(first.get("text"), str):
return first["text"]
msg = "Chat completion response did not contain message content"
raise TopicExtractionError(msg)
def parse_topic_extraction_response(response_text: str) -> list[ExtractedBillTopic]:
"""Parse, normalize, validate, and de-dupe a topic extraction response."""
payload = _load_json_response(response_text)
topics = payload.get("topics")
if not isinstance(topics, list):
msg = "Topic extraction response must contain a topics list"
raise TopicExtractionError(msg)
deduped: dict[tuple[str, BillTopicPosition], ExtractedBillTopic] = {}
for item in topics:
if not isinstance(item, dict):
msg = "Topic extraction response topics must be objects"
raise TopicExtractionError(msg)
raw_topic = _extract_topic_label(item)
topic = normalize_topic_label(raw_topic)
if not topic:
msg = "Topic extraction response topic must not be blank"
raise TopicExtractionError(msg)
raw_position = item.get("support_position")
try:
support_position = BillTopicPosition(raw_position)
except ValueError as exc:
msg = f"Invalid support_position: {raw_position!r}"
raise TopicExtractionError(msg) from exc
confidence = _parse_confidence(item.get("confidence"))
evidence = item.get("evidence")
if evidence is not None and not isinstance(evidence, str):
evidence = str(evidence)
extracted = ExtractedBillTopic(
topic=topic,
support_position=support_position,
confidence=confidence,
evidence=evidence,
)
key = (topic, support_position)
existing = deduped.get(key)
if existing is None or _confidence_rank(extracted) > _confidence_rank(existing):
deduped[key] = extracted
return list(deduped.values())
def extract_topics_for_bill_text(
*,
openai_config: OpenAIConfig,
bill: Bill,
text: str,
candidate_topics: Sequence[str],
) -> list[ExtractedBillTopic]:
"""Extract accepted catalog topics for a bill text string."""
normalized_candidates = {normalize_topic_label(topic) for topic in candidate_topics}
messages = build_topic_extraction_messages(
bill=bill,
bill_text=text,
candidate_topics=sorted(normalized_candidates),
)
response_text = call_openai_topic_extraction(
openai_config=openai_config,
messages=messages,
)
extracted_topics = parse_topic_extraction_response(response_text)
return [topic for topic in extracted_topics if topic.topic in normalized_candidates]
def store_bill_topic_result(
*,
session: Session,
bill: Bill,
topics: Sequence[ExtractedBillTopic],
replace_existing: bool = True,
) -> None:
"""Store extracted topics for one bill."""
if replace_existing:
session.execute(delete(BillTopic).where(BillTopic.bill_id == bill.id))
for topic in topics:
session.add(
BillTopic(
bill_id=bill.id,
topic=normalize_topic_label(topic.topic),
support_position=topic.support_position,
)
)
def create_select_bills_for_topic_extraction(
congress: int | None = None,
bill_ids: list[int] | None = None,
bill_text_ids: list[int] | None = None,
with_votes_only: bool = False,
force: bool = False,
limit: int | None = None,
) -> Select[tuple[Bill]]:
"""Select bill rows that have summarized bill_text rows for topic extraction."""
has_summary = (BillText.summary.is_not(None), BillText.summary != "")
summarized_text_filters: list[ColumnElement[bool]] = [
BillText.bill_id == Bill.id,
*has_summary,
]
if with_votes_only:
summarized_text_filters.append(
exists(
select(VoteTextTarget.vote_id)
.join(
VoteClassification,
VoteClassification.vote_id == VoteTextTarget.vote_id,
)
.where(
VoteTextTarget.voted_text_version_id == BillText.id,
VoteClassification.subject_type == SubjectType.MEASURE,
VoteClassification.vote_relationship
== VoteRelationship.DIRECT_TEXT_VOTE,
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
VoteClassification.is_substantive_policy_vote.is_(True),
VoteClassification.is_special_rule.is_(False),
)
)
)
summarized_text_exists = exists(select(BillText.id).where(*summarized_text_filters))
stmt = (
select(Bill)
.where(summarized_text_exists)
.options(selectinload(Bill.bill_texts.and_(*summarized_text_filters[1:])))
.order_by(Bill.id)
)
if congress is not None:
stmt = stmt.where(Bill.congress == congress)
if bill_ids:
stmt = stmt.where(Bill.id.in_(bill_ids))
if bill_text_ids:
selected_text_exists = exists(
select(BillText.id).where(
BillText.bill_id == Bill.id,
BillText.id.in_(bill_text_ids),
*summarized_text_filters[1:],
)
)
stmt = stmt.where(selected_text_exists)
if not force:
stmt = stmt.where(
~exists(select(BillTopic.id).where(BillTopic.bill_id == Bill.id))
)
if limit is not None:
stmt = stmt.limit(limit)
return stmt
def collect_topic_extraction_diagnostics(
session: Session,
*,
congress: int | None = None,
bill_ids: list[int] | None = None,
bill_text_ids: list[int] | None = None,
with_votes_only: bool = False,
force: bool = False,
limit: int | None = None,
) -> TopicExtractionDiagnostics:
"""Count topic extraction inputs for explaining empty selections."""
bill_filters = []
bill_text_filters: list[ColumnElement[bool]] = []
if congress is not None:
bill_filters.append(Bill.congress == congress)
if bill_ids:
bill_filters.append(Bill.id.in_(bill_ids))
bill_text_filters.append(BillText.bill_id.in_(bill_ids))
if bill_text_ids:
bill_text_filters.append(BillText.id.in_(bill_text_ids))
if with_votes_only:
bill_text_filters.append(
exists(
select(VoteTextTarget.vote_id)
.join(
VoteClassification,
VoteClassification.vote_id == VoteTextTarget.vote_id,
)
.where(
VoteTextTarget.voted_text_version_id == BillText.id,
VoteClassification.subject_type == SubjectType.MEASURE,
VoteClassification.vote_relationship
== VoteRelationship.DIRECT_TEXT_VOTE,
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
VoteClassification.is_substantive_policy_vote.is_(True),
VoteClassification.is_special_rule.is_(False),
)
)
)
has_summary = (BillText.summary.is_not(None), BillText.summary != "")
summary_filters = [*bill_text_filters, *has_summary]
bills_with_summaries = session.scalar(
select(func.count(func.distinct(Bill.id)))
.select_from(Bill)
.join(BillText, BillText.bill_id == Bill.id)
.where(*bill_filters, *summary_filters)
)
selected_bills = session.scalar(
select(func.count()).select_from(
create_select_bills_for_topic_extraction(
congress=congress,
bill_ids=bill_ids,
bill_text_ids=bill_text_ids,
with_votes_only=with_votes_only,
force=force,
limit=limit,
).subquery()
)
)
return TopicExtractionDiagnostics(
bill_rows=session.scalar(select(func.count(Bill.id)).where(*bill_filters)) or 0,
bill_text_rows=_count_bill_texts(
session,
bill_filters=bill_filters,
bill_text_filters=bill_text_filters,
),
summarized_bill_text_rows=_count_bill_texts(
session,
bill_filters=bill_filters,
bill_text_filters=summary_filters,
),
bills_with_summaries=bills_with_summaries or 0,
bill_topic_rows=session.scalar(select(func.count(BillTopic.id))) or 0,
selected_bills=selected_bills or 0,
)
def _load_json_response(response_text: str) -> dict[str, Any]:
text = response_text.strip()
fenced = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.DOTALL)
if fenced:
text = fenced.group(1).strip()
try:
payload = json.loads(text)
except json.JSONDecodeError as exc:
msg = f"Topic extraction response is not valid JSON: {exc}"
raise TopicExtractionError(msg) from exc
if not isinstance(payload, dict):
msg = "Topic extraction response must be a JSON object"
raise TopicExtractionError(msg)
return payload
def _parse_confidence(raw: Any) -> float | None:
if raw is None:
return None
try:
return float(raw)
except (TypeError, ValueError) as exc:
msg = f"Invalid confidence: {raw!r}"
raise TopicExtractionError(msg) from exc
def _confidence_rank(topic: ExtractedBillTopic) -> tuple[int, float]:
if topic.confidence is None:
return (0, 0.0)
return (1, topic.confidence)
def _extract_topic_label(item: dict[str, Any]) -> str:
raw_topic = item.get("topic")
if isinstance(raw_topic, str):
return raw_topic
if isinstance(raw_topic, dict):
for key in ("topic", "label", "name", "title"):
value = raw_topic.get(key)
if isinstance(value, str):
return value
msg = "Topic extraction response topic must be a string"
raise TopicExtractionError(msg)
def _count_bill_texts(
session: Session,
*,
bill_filters: Sequence[ColumnElement[bool]],
bill_text_filters: Sequence[ColumnElement[bool]],
) -> int:
stmt = select(func.count(BillText.id))
if bill_filters:
stmt = stmt.join(Bill, Bill.id == BillText.bill_id).where(*bill_filters)
return session.scalar(stmt.where(*bill_text_filters)) or 0
def main(
topics_path: Annotated[
Path, typer.Option(help="Path to congressional issue topic JSON.")
] = DEFAULT_TOPICS_PATH,
congress: Annotated[
int | None, typer.Option(help="Only process one Congress.")
] = None,
bill_ids: Annotated[
list[int] | None,
typer.Option(
"--bill-id",
help="Only process one internal bill.id. Repeat for multiple bills.",
),
] = None,
bill_text_ids: Annotated[
list[int] | None,
typer.Option(
"--bill-text-id",
help="Only process one internal bill_text.id. Repeat for multiple rows.",
),
] = None,
with_votes_only: Annotated[
bool,
typer.Option(
"--with-votes-only",
help="Only process summarized bill_text rows linked to at least one vote.",
),
] = True,
limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
force: Annotated[
bool,
typer.Option(help="Regenerate topics for bills that already have topics."),
] = False,
dry_run: Annotated[
bool,
typer.Option(help="Select bills and print diagnostics without calling OpenAI."),
] = False,
diagnose: Annotated[
bool,
typer.Option(help="Log input-stage counts before processing."),
] = False,
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
) -> None:
"""CLI entrypoint for generating and storing bill topics."""
logging.basicConfig(
level=log_level,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
topic_catalog = load_topic_catalog(topics_path)
logger.info(
"Loaded %d candidate topics from %s",
len(topic_catalog.candidate_topics),
topics_path,
)
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
with Session(engine) as session:
if diagnose or dry_run:
diagnostics = collect_topic_extraction_diagnostics(
session,
congress=congress,
bill_ids=bill_ids,
bill_text_ids=bill_text_ids,
with_votes_only=with_votes_only,
force=force,
limit=limit,
)
_log_topic_extraction_diagnostics(diagnostics)
if dry_run:
return
openai_config = get_openai_config()
stmt = create_select_bills_for_topic_extraction(
congress=congress,
bill_ids=bill_ids,
bill_text_ids=bill_text_ids,
with_votes_only=with_votes_only,
force=force,
limit=limit,
)
bills = session.scalars(stmt).all()
logger.info("Selected %d bills for topic extraction", len(bills))
written = 0
failed = 0
for index, bill in enumerate(bills, 1):
bill_text = _select_bill_text_for_topic_extraction(bill)
if bill_text is None:
logger.warning("Skipping bill id=%s: no usable summary", bill.id)
continue
summary = bill_text.summary.strip()
try:
extracted_topics = extract_topics_for_bill_text(
openai_config=openai_config,
bill=bill,
text=summary,
candidate_topics=topic_catalog.candidate_topics,
)
except (httpx.HTTPError, TopicExtractionError):
failed += 1
logger.exception(
"Skipping bill id=%s after topic extraction failure", bill.id
)
continue
store_bill_topic_result(
session=session,
bill=bill,
topics=extracted_topics,
replace_existing=True,
)
written += 1
if index % 100 == 0:
session.commit()
logger.info(
"Stored %d topics for bill id=%s",
len(extracted_topics),
bill.id,
)
session.commit()
logger.info(
"Done: stored topic results for %d bills; failed %d bills",
written,
failed,
)
def _log_topic_extraction_diagnostics(
diagnostics: TopicExtractionDiagnostics,
) -> None:
logger.info(
"Topic extraction diagnostics: bill_rows=%d bill_text_rows=%d "
"summarized_bill_text_rows=%d bills_with_summaries=%d "
"bill_topic_rows=%d selected_bills=%d",
diagnostics.bill_rows,
diagnostics.bill_text_rows,
diagnostics.summarized_bill_text_rows,
diagnostics.bills_with_summaries,
diagnostics.bill_topic_rows,
diagnostics.selected_bills,
)
if diagnostics.bill_rows == 0:
logger.warning("No bills matched the topic extraction scope.")
elif diagnostics.bill_text_rows == 0:
logger.warning("No bill_text rows matched the topic extraction scope.")
elif diagnostics.summarized_bill_text_rows == 0:
logger.warning(
"No summarized bill_text rows matched the topic extraction scope. "
"Run pipelines.tools.summarize_bills first."
)
elif diagnostics.selected_bills == 0 and diagnostics.bill_topic_rows > 0:
logger.warning(
"No bills selected because matching bills already have topics. "
"Use --force to regenerate them."
)
elif diagnostics.selected_bills == 0:
logger.warning("No bills selected for topic extraction.")
if __name__ == "__main__":
typer.run(main)
+309
View File
@@ -0,0 +1,309 @@
"""Summarize bill_text rows with GPT-5 and store results in the database."""
from __future__ import annotations
import logging
import tomllib
from os import getenv
from typing import Annotated, Any
import httpx
import typer
from sqlalchemy import Select, exists, or_, select
from sqlalchemy.orm import Session, selectinload
from tiktoken import get_encoding
from pipelines.config import get_config_dir
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.congress import (
Bill,
BillText,
SubjectType,
VoteClassification,
VoteRelationship,
VoteTextTarget,
)
from pipelines.tools.bill_token_compression import compress_bill_text
logger = logging.getLogger(__name__)
OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
REQUEST_TIMEOUT_SECONDS = 60
def load_summarization_prompts(
section: str = "summarization",
) -> dict[str, str]:
summarization_prompts = get_config_dir() / "prompts" / "summarization_prompts.toml"
return tomllib.loads(summarization_prompts.read_text())[section]
class BillSummaryError(RuntimeError):
"""Raised when a bill summary request or response is invalid."""
def call_openai_summary(
*,
model: str,
messages: list[dict[str, str]],
) -> str:
"""Call GPT and return the assistant message content."""
api_key = getenv("CLOSEDAI_TOKEN")
if not api_key:
msg = "CLOSEDAI_TOKEN is required"
raise BillSummaryError(msg)
response = httpx.post(
OPENAI_CHAT_COMPLETIONS_URL,
headers={
"Authorization": f"Bearer {api_key}",
"OpenAI-Project": OPENAI_PROJECT_ID,
"Content-Type": "application/json",
},
json={
"model": model,
"messages": messages,
},
timeout=REQUEST_TIMEOUT_SECONDS,
)
logger.info(f"{response.text=}")
response.raise_for_status()
return extract_message_content(response.json())
def build_bill_summary_messages(
*,
bill_text: BillText,
summarization_prompts: dict[str, str],
) -> list[dict[str, str]]:
"""Build the GPT prompt messages plus compressed text and user prompt."""
if not bill_text.text_content:
msg = f"bill_text id={bill_text.id} has no text_content"
raise BillSummaryError(msg)
compressed_text = compress_bill_text(bill_text.text_content)
if not compressed_text:
msg = f"bill_text id={bill_text.id} has no summarizable text_content"
raise BillSummaryError(msg)
user_prompt = summarization_prompts["user_template"].format(
text_content=compressed_text
)
user_prompt_tokens = len(get_encoding("o200k_base").encode(user_prompt))
logger.info(f"{user_prompt_tokens=}")
messages = [
{"role": "system", "content": summarization_prompts["system_prompt"]},
{
"role": "user",
"content": user_prompt,
},
]
return messages, user_prompt_tokens
def summarize_bill_text(
*,
model: str,
bill_text: BillText,
summarization_prompts: dict[str, str],
) -> str:
"""Generate and return a summary for one bill_text row."""
messages, user_prompt_tokens = build_bill_summary_messages(
bill_text=bill_text,
summarization_prompts=summarization_prompts,
)
# This may only be for gpt-5.4 mini I need to read the docs
if user_prompt_tokens > 272000:
msg = f"Compressed bill_text id={bill_text.id} is too long for summarization ({user_prompt_tokens} tokens)"
logger.warning(msg)
return None
summary = call_openai_summary(
model=model,
messages=messages,
).strip()
if not summary:
msg = f"Model returned an empty summary for bill_text id={bill_text.id}"
raise BillSummaryError(msg)
return summary
def store_bill_summary_result(
*,
bill_text: BillText,
summary: str,
model: str,
) -> None:
"""Store a generated summary and the prompt/model metadata that produced it."""
bill_text.summary = summary
bill_text.summarization_model = model
bill_text.summarization_system_prompt_version = "v1.2"
bill_text.summarization_user_prompt_version = "v1"
def create_select_bill_texts_for_summarization(
congress: int | None = None,
bill_ids: list[int] | None = None,
bill_text_ids: list[int] | None = None,
with_votes_only: bool = False,
force: bool = False,
limit: int | None = None,
) -> Select:
"""Select bill_text rows that have source text and need summaries."""
stmt = (
select(BillText)
.join(Bill, Bill.id == BillText.bill_id)
.where(BillText.text_content.is_not(None), BillText.text_content != "")
.options(selectinload(BillText.bill))
.order_by(BillText.id)
)
if congress is not None:
stmt = stmt.where(Bill.congress == congress)
if bill_ids:
stmt = stmt.where(BillText.bill_id.in_(bill_ids))
if bill_text_ids:
stmt = stmt.where(BillText.id.in_(bill_text_ids))
if with_votes_only:
stmt = stmt.where(
exists(
select(VoteTextTarget.vote_id)
.join(
VoteClassification,
VoteClassification.vote_id == VoteTextTarget.vote_id,
)
.where(
VoteTextTarget.voted_text_version_id == BillText.id,
VoteClassification.subject_type == SubjectType.MEASURE,
VoteClassification.vote_relationship
== VoteRelationship.DIRECT_TEXT_VOTE,
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
VoteClassification.is_substantive_policy_vote.is_(True),
VoteClassification.is_special_rule.is_(False),
)
)
)
if not force:
stmt = stmt.where(or_(BillText.summary.is_(None), BillText.summary == ""))
if limit is not None:
stmt = stmt.limit(limit)
return stmt
def extract_message_content(data: dict[str, Any]) -> str:
"""Extract message content from a chat-completions response body."""
choices = data.get("choices")
if not isinstance(choices, list) or not choices:
msg = "Chat completion response did not contain choices"
raise BillSummaryError(msg)
first = choices[0]
if not isinstance(first, dict):
msg = "Chat completion choice must be an object"
raise BillSummaryError(msg)
message = first.get("message")
if isinstance(message, dict) and isinstance(message.get("content"), str):
return message["content"]
if isinstance(first.get("text"), str):
return first["text"]
msg = "Chat completion response did not contain message content"
raise BillSummaryError(msg)
def main(
model: Annotated[str, typer.Option(help="OpenAI model id.")] = "gpt-5.4-mini",
congress: Annotated[
int | None, typer.Option(help="Only process one Congress.")
] = None,
bill_ids: Annotated[
list[int] | None,
typer.Option(
"--bill-id",
help="Only process one internal bill.id. Repeat for multiple bills.",
),
] = None,
bill_text_ids: Annotated[
list[int] | None,
typer.Option(
"--bill-text-id",
help="Only process one internal bill_text.id. Repeat for multiple rows.",
),
] = None,
with_votes_only: Annotated[
bool,
typer.Option(
"--with-votes-only",
help="Only process bill_text rows linked to at least one vote.",
),
] = False,
limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
force: Annotated[
bool,
typer.Option(help="Regenerate summaries for rows that already have a summary."),
] = False,
dry_run: Annotated[
bool, typer.Option(help="Print summaries without writing them to the database.")
] = False,
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
) -> None:
"""CLI entrypoint for generating and storing bill summaries."""
logging.basicConfig(
level=log_level,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
if not getenv("CLOSEDAI_TOKEN"):
message = "CLOSEDAI_TOKEN is required"
raise typer.BadParameter(message)
summarization_prompts = load_summarization_prompts()
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
with Session(engine) as session:
stmt = create_select_bill_texts_for_summarization(
congress=congress,
bill_ids=bill_ids,
bill_text_ids=bill_text_ids,
with_votes_only=with_votes_only,
force=force,
limit=limit,
)
bill_texts = session.scalars(stmt).all()
logger.info("Selected %d bill_text rows for summarization", len(bill_texts))
written = 0
for index, bill_text in enumerate(bill_texts, 1):
summary = summarize_bill_text(
model=model,
bill_text=bill_text,
summarization_prompts=summarization_prompts,
)
if summary is None:
logger.warning("Skipping bill_text id=%s", bill_text.id)
continue
store_bill_summary_result(
bill_text=bill_text,
summary=summary,
model=model,
)
if index % 100 == 0:
session.commit()
written += 1
session.commit()
logger.info("Stored summary for bill_text id=%s", bill_text.id)
logger.info("Done: stored %d summaries", written)
def cli() -> None:
"""Typer entry point."""
typer.run(main)
if __name__ == "__main__":
cli()
-25
View File
@@ -1,25 +0,0 @@
# Unsloth fine-tuning container for Qwen 3.5 4B on RTX 3090.
#
# Build:
# docker build -f python/prompt_bench/Dockerfile.finetune -t bill-finetune .
#
# Run:
# docker run --rm --device=nvidia.com/gpu=all --ipc=host \
# -v $(pwd)/output:/workspace/output \
# -v $(pwd)/output/finetune_dataset.jsonl:/workspace/dataset.jsonl:ro \
# -v /zfs/models/hf:/models \
# bill-finetune \
# --dataset /workspace/dataset.jsonl \
# --output-dir /workspace/output/qwen-bill-summarizer
FROM ghcr.io/unslothai/unsloth:latest
RUN pip install --no-cache-dir typer
WORKDIR /workspace
COPY python/prompt_bench/finetune.py python/prompt_bench/finetune.py
COPY config/prompts/summarization_prompts.toml config/prompts/summarization_prompts.toml
COPY python/prompt_bench/__init__.py python/prompt_bench/__init__.py
COPY python/__init__.py python/__init__.py
ENTRYPOINT ["python", "-m", "pipelines.prompt_bench.finetune"]
+2 -6
View File
@@ -23,14 +23,10 @@ import httpx
import typer
from tiktoken import Encoding, get_encoding
from pipelines.config import get_config_dir
from pipelines.tools.bill_token_compression import compress_bill_text
_PROMPTS_PATH = (
Path(__file__).resolve().parents[2]
/ "config"
/ "prompts"
/ "summarization_prompts.toml"
)
_PROMPTS_PATH = get_config_dir() / "prompts" / "summarization_prompts.toml"
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
+2 -6
View File
@@ -24,14 +24,10 @@ from typing import Annotated
import httpx
import typer
from pipelines.config import get_config_dir
from pipelines.tools.bill_token_compression import compress_bill_text
_PROMPTS_PATH = (
Path(__file__).resolve().parents[2]
/ "config"
/ "prompts"
/ "summarization_prompts.toml"
)
_PROMPTS_PATH = get_config_dir() / "prompts" / "summarization_prompts.toml"
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
+3 -1
View File
@@ -25,6 +25,8 @@ from datasets import Dataset
from transformers import TrainingArguments
from trl import SFTTrainer
from pipelines.config import default_config_path
logger = logging.getLogger(__name__)
@@ -123,7 +125,7 @@ def main(
config_path: Annotated[
Path,
typer.Option("--config", help="TOML config file"),
] = Path(__file__).parent / "config.toml",
] = default_config_path(),
save_gguf: Annotated[
bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")
] = False,
+2 -2
View File
@@ -11,8 +11,8 @@ from typing import Annotated
import typer
from pipelines.tools.containers.lib import check_gpu_free
from pipelines.tools.containers.vllm import start_vllm, stop_vllm
from pipelines.containers.lib import check_gpu_free
from pipelines.containers.vllm import start_vllm, stop_vllm
from pipelines.tools.downloader import is_model_present
from pipelines.tools.models import BenchmarkConfig
from pipelines.tools.vllm_client import VLLMClient