Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 36718bbce0 | |||
| f33a5c2233 | |||
| 2facb82bd4 | |||
| 8d5a6e202b | |||
| f32c895561 | |||
| 09f7f0187f |
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from os import getenv
|
||||
from pathlib import Path
|
||||
import tomllib
|
||||
|
||||
@@ -68,15 +69,54 @@ class BenchmarkConfig:
|
||||
return cls(**raw)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAIConfig:
|
||||
"""OpenAI API configuration."""
|
||||
|
||||
api_key: str
|
||||
openai_project_id: str
|
||||
openai_chat_completions_url: str
|
||||
model: str
|
||||
timeout_seconds: int
|
||||
|
||||
@classmethod
|
||||
def from_toml(cls, config_path: Path) -> OpenAIConfig:
|
||||
"""Load OpenAI config from a TOML file."""
|
||||
raw = tomllib.loads(config_path.read_text()).get("openai", {})
|
||||
api_key = getenv("CLOSEDAI_TOKEN")
|
||||
if not api_key:
|
||||
message = "CLOSEDAI_TOKEN is required"
|
||||
raise KeyError(message)
|
||||
return cls(
|
||||
api_key=api_key,
|
||||
openai_project_id=raw.get(
|
||||
"openai_project_id", "proj_fQBPEXFgnS87Fk6wZwploFwE"
|
||||
),
|
||||
openai_chat_completions_url=raw.get(
|
||||
"openai_chat_completions_url",
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
),
|
||||
model=raw.get("model", "gpt-5.4-mini"),
|
||||
timeout_seconds=raw.get("timeout_seconds", 60),
|
||||
)
|
||||
|
||||
|
||||
def get_config_dir() -> Path:
|
||||
"""Get the path to the config file."""
|
||||
return Path(__file__).resolve().parent.parent.parent / "config"
|
||||
"""Get the path to the config directory."""
|
||||
return Path(__file__).resolve().parents[2] / "config"
|
||||
|
||||
|
||||
def default_config_path() -> Path:
|
||||
"""Get the path to the config file."""
|
||||
return get_config_dir() / "config.toml"
|
||||
|
||||
|
||||
def get_openai_config(config_path: Path | None = None) -> OpenAIConfig:
|
||||
if config_path is None:
|
||||
config_path = default_config_path()
|
||||
return OpenAIConfig.from_toml(config_path)
|
||||
|
||||
|
||||
def get_finetune_config(config_path: Path | None = None) -> FinetuneConfig:
|
||||
if config_path is None:
|
||||
config_path = default_config_path()
|
||||
@@ -0,0 +1 @@
|
||||
"""Prompt benchmarking system for evaluating LLMs via vLLM."""
|
||||
@@ -0,0 +1,235 @@
|
||||
"""Docker container lifecycle management for BERTopic jobs on Jeeves."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import typer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
JOBMode = Literal["train", "infer"]
|
||||
IMAGE_NAME = "bert-topic:latest"
|
||||
REPO_DIR = Path(__file__).resolve().parents[3]
|
||||
DEFAULT_CACHE_ROOT = Path("/zfs/storage/main/ds_thing/models/bert_topic")
|
||||
DEFAULT_POSTGRES_SOCKET_DIR = Path("/run/postgresql")
|
||||
DB_ENV_VARS = (
|
||||
"DATA_SCIENCE_DEV_DB",
|
||||
"DATA_SCIENCE_DEV_HOST",
|
||||
"DATA_SCIENCE_DEV_PORT",
|
||||
"DATA_SCIENCE_DEV_USER",
|
||||
"DATA_SCIENCE_DEV_PASSWORD",
|
||||
)
|
||||
|
||||
app = typer.Typer(help="BERTopic container management.")
|
||||
|
||||
|
||||
def _container_name(mode: JOBMode) -> str:
|
||||
"""Return the Docker container name for the selected BERTopic job."""
|
||||
return f"bert-topic-{mode}"
|
||||
|
||||
|
||||
def _module_name(mode: JOBMode) -> str:
|
||||
"""Return the Python module to run inside the container."""
|
||||
return f"pipelines.bert_topic.{mode}"
|
||||
|
||||
|
||||
def _env_args(*, use_postgres_socket: bool) -> list[str]:
|
||||
"""Pass through database environment variables from the host shell."""
|
||||
required = [
|
||||
"DATA_SCIENCE_DEV_DB",
|
||||
"DATA_SCIENCE_DEV_PORT",
|
||||
"DATA_SCIENCE_DEV_USER",
|
||||
]
|
||||
if not use_postgres_socket:
|
||||
required.append("DATA_SCIENCE_DEV_HOST")
|
||||
missing = [name for name in required if not os.getenv(name)]
|
||||
if missing:
|
||||
message = "Missing required database environment variables: " + ", ".join(
|
||||
missing
|
||||
)
|
||||
raise RuntimeError(message)
|
||||
args: list[str] = []
|
||||
if use_postgres_socket:
|
||||
args.extend(["-e", f"DATA_SCIENCE_DEV_HOST={DEFAULT_POSTGRES_SOCKET_DIR}"])
|
||||
for name in DB_ENV_VARS:
|
||||
if use_postgres_socket and name == "DATA_SCIENCE_DEV_HOST":
|
||||
continue
|
||||
if os.getenv(name):
|
||||
args.extend(["-e", name])
|
||||
return args
|
||||
|
||||
|
||||
def build_image() -> None:
|
||||
"""Build the BERTopic Docker image."""
|
||||
dockerfile = REPO_DIR / "pipelines/containers/docker_files/Dockerfile.bert_topic"
|
||||
logger.info("Building BERTopic image: %s", IMAGE_NAME)
|
||||
result = subprocess.run(
|
||||
[
|
||||
"docker",
|
||||
"build",
|
||||
"--network",
|
||||
"host",
|
||||
"-f",
|
||||
str(dockerfile),
|
||||
"-t",
|
||||
IMAGE_NAME,
|
||||
str(REPO_DIR),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
message = (
|
||||
"Failed to build BERTopic image. "
|
||||
f"docker build stderr:\n{result.stderr.strip()}"
|
||||
)
|
||||
raise RuntimeError(message)
|
||||
logger.info("Image built: %s", IMAGE_NAME)
|
||||
|
||||
|
||||
def stop_job(*, mode: JOBMode) -> None:
|
||||
"""Stop and remove the BERTopic container for the selected mode."""
|
||||
container_name = _container_name(mode)
|
||||
logger.info("Stopping BERTopic container: %s", container_name)
|
||||
subprocess.run(["docker", "stop", container_name], capture_output=True, check=False)
|
||||
subprocess.run(
|
||||
["docker", "rm", "-f", container_name], capture_output=True, check=False
|
||||
)
|
||||
|
||||
|
||||
def start_job(
|
||||
*,
|
||||
mode: JOBMode,
|
||||
cache_root: Path = DEFAULT_CACHE_ROOT,
|
||||
postgres_socket_dir: Path = DEFAULT_POSTGRES_SOCKET_DIR,
|
||||
detach: bool = False,
|
||||
) -> None:
|
||||
"""Run BERTopic training or inference in Docker on Jeeves."""
|
||||
cache_root = cache_root.resolve()
|
||||
cache_root.mkdir(parents=True, exist_ok=True)
|
||||
postgres_socket_dir = postgres_socket_dir.resolve()
|
||||
stop_job(mode=mode)
|
||||
use_postgres_socket = postgres_socket_dir.exists()
|
||||
|
||||
command = [
|
||||
"docker",
|
||||
"run",
|
||||
"--name",
|
||||
_container_name(mode),
|
||||
"--ipc=host",
|
||||
"-v",
|
||||
f"{cache_root}:/cache",
|
||||
*_env_args(use_postgres_socket=use_postgres_socket),
|
||||
IMAGE_NAME,
|
||||
_module_name(mode),
|
||||
]
|
||||
if use_postgres_socket:
|
||||
command[7:7] = ["-v", f"{postgres_socket_dir}:{DEFAULT_POSTGRES_SOCKET_DIR}"]
|
||||
if detach:
|
||||
command.insert(2, "-d")
|
||||
|
||||
logger.info("Starting BERTopic %s container", mode)
|
||||
logger.info(" Cache root: %s", cache_root)
|
||||
if use_postgres_socket:
|
||||
logger.info(" Postgres socket: %s", postgres_socket_dir)
|
||||
result = subprocess.run(command, text=True, capture_output=detach, check=False)
|
||||
if result.returncode != 0:
|
||||
detail = (
|
||||
result.stderr.strip() if result.stderr else f"exit code {result.returncode}"
|
||||
)
|
||||
raise RuntimeError(f"BERTopic container failed to start: {detail}")
|
||||
if detach:
|
||||
logger.info("Container started: %s", result.stdout.strip()[:12])
|
||||
else:
|
||||
logger.info("BERTopic %s run complete", mode)
|
||||
|
||||
|
||||
def logs_job(*, mode: JOBMode) -> str | None:
|
||||
"""Return recent logs from the BERTopic container, or None if absent."""
|
||||
result = subprocess.run(
|
||||
["docker", "logs", "--tail", "100", _container_name(mode)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
return result.stdout + result.stderr
|
||||
|
||||
|
||||
@app.command()
|
||||
def build(
|
||||
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||
) -> None:
|
||||
"""Build the BERTopic Docker image."""
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
build_image()
|
||||
|
||||
|
||||
@app.command("run")
|
||||
def run_job_command(
|
||||
mode: Annotated[JOBMode, typer.Option(help="Which BERTopic job to run")] = "train",
|
||||
cache_root: Annotated[
|
||||
Path, typer.Option(help="Host path mounted to /cache for model and HF cache")
|
||||
] = DEFAULT_CACHE_ROOT,
|
||||
postgres_socket_dir: Annotated[
|
||||
Path, typer.Option(help="Host Postgres socket directory to mount into the container")
|
||||
] = DEFAULT_POSTGRES_SOCKET_DIR,
|
||||
detach: Annotated[
|
||||
bool, typer.Option(help="Start the container in the background")
|
||||
] = False,
|
||||
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||
) -> None:
|
||||
"""Run BERTopic training or inference inside Docker."""
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
start_job(
|
||||
mode=mode,
|
||||
cache_root=cache_root,
|
||||
postgres_socket_dir=postgres_socket_dir,
|
||||
detach=detach,
|
||||
)
|
||||
|
||||
|
||||
@app.command("stop")
|
||||
def stop_job_command(
|
||||
mode: Annotated[
|
||||
JOBMode, typer.Option(help="Which BERTopic container to stop")
|
||||
] = "train",
|
||||
) -> None:
|
||||
"""Stop and remove the BERTopic container."""
|
||||
stop_job(mode=mode)
|
||||
|
||||
|
||||
@app.command("logs")
|
||||
def logs_job_command(
|
||||
mode: Annotated[
|
||||
JOBMode, typer.Option(help="Which BERTopic container logs to show")
|
||||
] = "train",
|
||||
) -> None:
|
||||
"""Show recent logs from the BERTopic container."""
|
||||
output = logs_job(mode=mode)
|
||||
if output is None:
|
||||
typer.echo(f"No BERTopic container found for mode={mode}.")
|
||||
raise typer.Exit(code=1)
|
||||
typer.echo(output)
|
||||
|
||||
|
||||
def cli() -> None:
|
||||
"""Typer entry point."""
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
@@ -0,0 +1,38 @@
|
||||
FROM python:3.12-bookworm
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PIP_NO_CACHE_DIR=1
|
||||
|
||||
RUN apt-get update && apt-get install -y \
|
||||
build-essential \
|
||||
gcc \
|
||||
g++ \
|
||||
git \
|
||||
libgomp1 \
|
||||
libpq-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY pipelines ./pipelines
|
||||
|
||||
RUN python -m pip install --upgrade pip setuptools wheel && \
|
||||
python -m pip install \
|
||||
torch \
|
||||
--index-url https://download.pytorch.org/whl/cpu && \
|
||||
python -m pip install \
|
||||
typer \
|
||||
sqlalchemy \
|
||||
bertopic \
|
||||
sentence-transformers \
|
||||
scikit-learn \
|
||||
pandas \
|
||||
numpy \
|
||||
"psycopg[binary]"
|
||||
|
||||
ENV HF_HOME=/cache/huggingface
|
||||
ENV TRANSFORMERS_CACHE=/cache/huggingface
|
||||
|
||||
ENTRYPOINT ["python", "-m"]
|
||||
CMD ["pipelines.bert_topic.train"]
|
||||
@@ -0,0 +1,11 @@
|
||||
FROM ghcr.io/unslothai/unsloth:latest
|
||||
|
||||
RUN pip install --no-cache-dir typer
|
||||
|
||||
WORKDIR /workspace
|
||||
COPY python/prompt_bench/finetune.py python/prompt_bench/finetune.py
|
||||
COPY config/prompts/summarization_prompts.toml config/prompts/summarization_prompts.toml
|
||||
COPY python/prompt_bench/__init__.py python/prompt_bench/__init__.py
|
||||
COPY python/__init__.py python/__init__.py
|
||||
|
||||
ENTRYPOINT ["python", "-m", "pipelines.prompt_bench.finetune"]
|
||||
@@ -9,7 +9,7 @@ from typing import Annotated
|
||||
|
||||
import typer
|
||||
|
||||
from pipelines.tools.containers.lib import check_gpu_free
|
||||
from pipelines.pipelines.containers.lib import check_gpu_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -27,7 +27,7 @@ def build_image() -> None:
|
||||
"docker",
|
||||
"build",
|
||||
"-f",
|
||||
str(REPO_DIR / "python/prompt_bench/Dockerfile.finetune"),
|
||||
str(REPO_DIR / "pipelines/containers/docker_files/Dockerfile.finetune"),
|
||||
"-t",
|
||||
FINETUNE_IMAGE,
|
||||
".",
|
||||
@@ -0,0 +1,574 @@
|
||||
"""Calculate legislator topic scores from bill topics and roll-call votes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Sequence
|
||||
|
||||
import typer
|
||||
from sqlalchemy import (
|
||||
ColumnElement,
|
||||
Integer,
|
||||
Select,
|
||||
and_,
|
||||
case,
|
||||
cast,
|
||||
delete,
|
||||
extract,
|
||||
func,
|
||||
or_,
|
||||
select,
|
||||
tuple_,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from pipelines.congress_vote_context import create_score_run, finalize_score_run
|
||||
from pipelines.orm.common import get_postgres_engine
|
||||
from pipelines.orm.data_science_dev.congress import (
|
||||
BillTopic,
|
||||
BillTopicPosition,
|
||||
LegislatorScore,
|
||||
SubjectType,
|
||||
Vote,
|
||||
VoteClassification,
|
||||
VoteEffect,
|
||||
VoteMeasureLink,
|
||||
VoteMeasureRole,
|
||||
VotePositionMeaning,
|
||||
VoteRelationship,
|
||||
VoteRecord,
|
||||
)
|
||||
from pipelines.pipelines.jobs.extract_bill_topics import normalize_topic_label
|
||||
from pipelines.web.scoring import (
|
||||
OPPOSE_POSITIONS,
|
||||
SUPPORT_POSITIONS,
|
||||
normalized_position_expression,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DELETE_BATCH_SIZE = 5_000
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScoreDiagnostics:
|
||||
"""Counts for the input stages required to calculate legislator scores."""
|
||||
|
||||
bill_topic_rows: int
|
||||
linked_vote_rows: int
|
||||
vote_record_rows: int
|
||||
topic_vote_links: int
|
||||
scorable_vote_records: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LegislatorScoreInput:
|
||||
"""One aggregated score ready to store in legislator_score."""
|
||||
|
||||
legislator_id: int
|
||||
year: int
|
||||
topic: str
|
||||
score: float
|
||||
supportive: int
|
||||
opposed: int
|
||||
|
||||
|
||||
def create_legislator_score_query(
|
||||
*,
|
||||
congress: int | None = None,
|
||||
bill_ids: Sequence[int] | None = None,
|
||||
topics: Sequence[str] | None = None,
|
||||
) -> Select:
|
||||
"""Build the aggregate score query from extracted bill topics and vote records."""
|
||||
normalized_vote = normalized_position_expression(VoteRecord.position)
|
||||
supportive_vote = _supportive_vote_expression(normalized_vote)
|
||||
opposed_vote = _opposed_vote_expression(normalized_vote)
|
||||
supportive_count = func.sum(supportive_vote)
|
||||
opposed_count = func.sum(opposed_vote)
|
||||
total_count = supportive_count + opposed_count
|
||||
vote_year = cast(extract("year", Vote.vote_date), Integer)
|
||||
score = (100.0 * supportive_count / func.nullif(total_count, 0)).label("score")
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
VoteRecord.legislator_id.label("legislator_id"),
|
||||
vote_year.label("year"),
|
||||
BillTopic.topic.label("topic"),
|
||||
score,
|
||||
supportive_count.label("supportive"),
|
||||
opposed_count.label("opposed"),
|
||||
total_count.label("total"),
|
||||
)
|
||||
.select_from(VoteRecord)
|
||||
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
|
||||
.join(
|
||||
VoteMeasureLink,
|
||||
and_(
|
||||
VoteMeasureLink.vote_id == Vote.id,
|
||||
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||
),
|
||||
)
|
||||
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
|
||||
.where(
|
||||
*_eligible_vote_filters(),
|
||||
_is_scorable_position(normalized_vote),
|
||||
)
|
||||
.group_by(VoteRecord.legislator_id, vote_year, BillTopic.topic)
|
||||
.having(total_count > 0)
|
||||
.order_by(VoteRecord.legislator_id, vote_year, BillTopic.topic)
|
||||
)
|
||||
if congress is not None:
|
||||
stmt = stmt.where(Vote.congress == congress)
|
||||
if bill_ids:
|
||||
stmt = stmt.where(VoteMeasureLink.measure_id.in_(list(bill_ids)))
|
||||
|
||||
normalized_topics = _normalize_topics(topics)
|
||||
if normalized_topics:
|
||||
stmt = stmt.where(BillTopic.topic.in_(normalized_topics))
|
||||
|
||||
return stmt
|
||||
|
||||
|
||||
def collect_legislator_scores(
|
||||
session: Session,
|
||||
*,
|
||||
congress: int | None = None,
|
||||
bill_ids: Sequence[int] | None = None,
|
||||
topics: Sequence[str] | None = None,
|
||||
) -> list[LegislatorScoreInput]:
|
||||
"""Run the aggregate query and return score rows."""
|
||||
rows = session.execute(
|
||||
create_legislator_score_query(
|
||||
congress=congress,
|
||||
bill_ids=bill_ids,
|
||||
topics=topics,
|
||||
)
|
||||
)
|
||||
return [
|
||||
LegislatorScoreInput(
|
||||
legislator_id=int(row.legislator_id),
|
||||
year=int(row.year),
|
||||
topic=str(row.topic),
|
||||
score=float(row.score),
|
||||
supportive=int(row.supportive),
|
||||
opposed=int(row.opposed),
|
||||
)
|
||||
for row in rows
|
||||
if row.score is not None
|
||||
]
|
||||
|
||||
|
||||
def collect_score_diagnostics(
|
||||
session: Session,
|
||||
*,
|
||||
congress: int | None = None,
|
||||
bill_ids: Sequence[int] | None = None,
|
||||
topics: Sequence[str] | None = None,
|
||||
) -> ScoreDiagnostics:
|
||||
"""Count score pipeline inputs for explaining empty score runs."""
|
||||
normalized_topics = _normalize_topics(topics)
|
||||
vote_filters = _vote_scope_filters(congress=congress, bill_ids=bill_ids)
|
||||
topic_filters = _topic_scope_filters(bill_ids=bill_ids, topics=normalized_topics)
|
||||
normalized_vote = normalized_position_expression(VoteRecord.position)
|
||||
eligible_vote_filters = _eligible_vote_filters()
|
||||
|
||||
bill_topic_rows = session.scalar(
|
||||
select(func.count(BillTopic.id)).where(*topic_filters)
|
||||
)
|
||||
linked_vote_rows = session.scalar(
|
||||
select(func.count(func.distinct(Vote.id)))
|
||||
.select_from(Vote)
|
||||
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||
.join(
|
||||
VoteMeasureLink,
|
||||
and_(
|
||||
VoteMeasureLink.vote_id == Vote.id,
|
||||
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||
),
|
||||
)
|
||||
.where(*vote_filters, *eligible_vote_filters)
|
||||
)
|
||||
vote_record_rows = session.scalar(
|
||||
select(func.count())
|
||||
.select_from(VoteRecord)
|
||||
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||
.where(*vote_filters, *eligible_vote_filters)
|
||||
)
|
||||
topic_vote_links = session.scalar(
|
||||
select(func.count())
|
||||
.select_from(VoteRecord)
|
||||
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
|
||||
.join(
|
||||
VoteMeasureLink,
|
||||
and_(
|
||||
VoteMeasureLink.vote_id == Vote.id,
|
||||
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||
),
|
||||
)
|
||||
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
|
||||
.where(*vote_filters, *topic_filters, *eligible_vote_filters)
|
||||
)
|
||||
scorable_vote_records = session.scalar(
|
||||
select(func.count())
|
||||
.select_from(VoteRecord)
|
||||
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
|
||||
.join(
|
||||
VoteMeasureLink,
|
||||
and_(
|
||||
VoteMeasureLink.vote_id == Vote.id,
|
||||
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||
),
|
||||
)
|
||||
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
|
||||
.where(
|
||||
*vote_filters,
|
||||
*topic_filters,
|
||||
*eligible_vote_filters,
|
||||
_is_scorable_position(normalized_vote),
|
||||
)
|
||||
)
|
||||
|
||||
return ScoreDiagnostics(
|
||||
bill_topic_rows=bill_topic_rows or 0,
|
||||
linked_vote_rows=linked_vote_rows or 0,
|
||||
vote_record_rows=vote_record_rows or 0,
|
||||
topic_vote_links=topic_vote_links or 0,
|
||||
scorable_vote_records=scorable_vote_records or 0,
|
||||
)
|
||||
|
||||
|
||||
def store_legislator_scores(
|
||||
session: Session,
|
||||
rows: Sequence[LegislatorScoreInput],
|
||||
*,
|
||||
score_run_id: int | None,
|
||||
replace_all: bool = False,
|
||||
) -> int:
|
||||
"""Replace matching score rows and insert the newly calculated scores."""
|
||||
if replace_all:
|
||||
session.execute(delete(LegislatorScore))
|
||||
elif rows:
|
||||
keys = [
|
||||
(row.legislator_id, row.year, row.topic)
|
||||
for row in rows
|
||||
]
|
||||
for key_batch in _batched(keys, DELETE_BATCH_SIZE):
|
||||
session.execute(
|
||||
delete(LegislatorScore).where(
|
||||
tuple_(
|
||||
LegislatorScore.legislator_id,
|
||||
LegislatorScore.year,
|
||||
LegislatorScore.topic,
|
||||
).in_(key_batch)
|
||||
)
|
||||
)
|
||||
|
||||
session.add_all(
|
||||
[
|
||||
LegislatorScore(
|
||||
legislator_id=row.legislator_id,
|
||||
year=row.year,
|
||||
topic=row.topic,
|
||||
score=row.score,
|
||||
score_run_id=score_run_id,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
)
|
||||
return len(rows)
|
||||
|
||||
|
||||
def _supportive_vote_expression(
|
||||
normalized_vote: ColumnElement[str | None],
|
||||
) -> ColumnElement[int]:
|
||||
supports_text = _position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT)
|
||||
opposes_text = _position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT)
|
||||
return case(
|
||||
(
|
||||
and_(
|
||||
BillTopic.support_position == BillTopicPosition.FOR,
|
||||
supports_text,
|
||||
),
|
||||
1,
|
||||
),
|
||||
(
|
||||
and_(
|
||||
BillTopic.support_position == BillTopicPosition.AGAINST,
|
||||
opposes_text,
|
||||
),
|
||||
1,
|
||||
),
|
||||
else_=0,
|
||||
)
|
||||
|
||||
|
||||
def _opposed_vote_expression(
|
||||
normalized_vote: ColumnElement[str | None],
|
||||
) -> ColumnElement[int]:
|
||||
supports_text = _position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT)
|
||||
opposes_text = _position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT)
|
||||
return case(
|
||||
(
|
||||
and_(
|
||||
BillTopic.support_position == BillTopicPosition.FOR,
|
||||
opposes_text,
|
||||
),
|
||||
1,
|
||||
),
|
||||
(
|
||||
and_(
|
||||
BillTopic.support_position == BillTopicPosition.AGAINST,
|
||||
supports_text,
|
||||
),
|
||||
1,
|
||||
),
|
||||
else_=0,
|
||||
)
|
||||
|
||||
|
||||
def _position_matches_effect(
|
||||
normalized_vote: ColumnElement[str | None],
|
||||
effect: VoteEffect,
|
||||
) -> ColumnElement[bool]:
|
||||
return or_(
|
||||
and_(
|
||||
normalized_vote.in_(sorted(SUPPORT_POSITIONS)),
|
||||
VotePositionMeaning.yea_effect == effect,
|
||||
),
|
||||
and_(
|
||||
normalized_vote.in_(sorted(OPPOSE_POSITIONS)),
|
||||
VotePositionMeaning.nay_effect == effect,
|
||||
),
|
||||
and_(
|
||||
normalized_vote == "present",
|
||||
VotePositionMeaning.present_effect == effect,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _is_scorable_position(normalized_vote: ColumnElement[str | None]) -> ColumnElement[bool]:
|
||||
return or_(
|
||||
_position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT),
|
||||
_position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT),
|
||||
)
|
||||
|
||||
|
||||
def _normalize_topics(topics: Sequence[str] | None) -> list[str]:
|
||||
normalized: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for topic in topics or []:
|
||||
value = normalize_topic_label(topic)
|
||||
if value and value not in seen:
|
||||
normalized.append(value)
|
||||
seen.add(value)
|
||||
return normalized
|
||||
|
||||
|
||||
def _batched[T](items: Sequence[T], batch_size: int) -> list[Sequence[T]]:
|
||||
return [
|
||||
items[index : index + batch_size]
|
||||
for index in range(0, len(items), batch_size)
|
||||
]
|
||||
|
||||
|
||||
def _vote_scope_filters(
|
||||
*,
|
||||
congress: int | None,
|
||||
bill_ids: Sequence[int] | None,
|
||||
) -> list[ColumnElement[bool]]:
|
||||
filters: list[ColumnElement[bool]] = []
|
||||
if congress is not None:
|
||||
filters.append(Vote.congress == congress)
|
||||
if bill_ids:
|
||||
filters.append(VoteMeasureLink.measure_id.in_(list(bill_ids)))
|
||||
return filters
|
||||
|
||||
|
||||
def _topic_scope_filters(
|
||||
*,
|
||||
bill_ids: Sequence[int] | None,
|
||||
topics: Sequence[str],
|
||||
) -> list[ColumnElement[bool]]:
|
||||
filters: list[ColumnElement[bool]] = []
|
||||
if bill_ids:
|
||||
filters.append(BillTopic.bill_id.in_(list(bill_ids)))
|
||||
if topics:
|
||||
filters.append(BillTopic.topic.in_(list(topics)))
|
||||
return filters
|
||||
|
||||
|
||||
def _has_score_scope(
|
||||
*,
|
||||
congress: int | None,
|
||||
bill_ids: Sequence[int] | None,
|
||||
topics: Sequence[str] | None,
|
||||
) -> bool:
|
||||
return congress is not None or bool(bill_ids) or bool(topics)
|
||||
|
||||
|
||||
def _eligible_vote_filters() -> list[ColumnElement[bool]]:
|
||||
return [
|
||||
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||
VoteClassification.vote_relationship == VoteRelationship.DIRECT_TEXT_VOTE,
|
||||
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||
VoteClassification.is_special_rule.is_(False),
|
||||
]
|
||||
|
||||
|
||||
def main(
|
||||
congress: Annotated[
|
||||
int | None,
|
||||
typer.Option(help="Only score votes from one Congress."),
|
||||
] = None,
|
||||
bill_ids: Annotated[
|
||||
list[int] | None,
|
||||
typer.Option(
|
||||
"--bill-id",
|
||||
help="Only score votes linked to one internal bill.id. Repeatable.",
|
||||
),
|
||||
] = None,
|
||||
topics: Annotated[
|
||||
list[str] | None,
|
||||
typer.Option("--topic", help="Only score one normalized topic. Repeatable."),
|
||||
] = None,
|
||||
replace_all: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
help="Delete every existing legislator score before inserting. "
|
||||
"Unfiltered runs do this automatically."
|
||||
),
|
||||
] = False,
|
||||
dry_run: Annotated[
|
||||
bool,
|
||||
typer.Option(help="Calculate scores without writing to the database."),
|
||||
] = False,
|
||||
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
|
||||
diagnose: Annotated[
|
||||
bool,
|
||||
typer.Option(help="Log input-stage counts even when rows are calculated."),
|
||||
] = False,
|
||||
) -> None:
|
||||
"""CLI entrypoint for calculating and storing legislator topic scores."""
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
|
||||
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||
with Session(engine) as session:
|
||||
rows = collect_legislator_scores(
|
||||
session,
|
||||
congress=congress,
|
||||
bill_ids=bill_ids,
|
||||
topics=topics,
|
||||
)
|
||||
logger.info("Calculated %d legislator topic score rows", len(rows))
|
||||
if diagnose or not rows:
|
||||
diagnostics = collect_score_diagnostics(
|
||||
session,
|
||||
congress=congress,
|
||||
bill_ids=bill_ids,
|
||||
topics=topics,
|
||||
)
|
||||
_log_diagnostics(diagnostics)
|
||||
|
||||
if dry_run:
|
||||
session.rollback()
|
||||
return
|
||||
|
||||
score_run = create_score_run(session)
|
||||
should_replace_all = replace_all or not _has_score_scope(
|
||||
congress=congress,
|
||||
bill_ids=bill_ids,
|
||||
topics=topics,
|
||||
)
|
||||
written = store_legislator_scores(
|
||||
session,
|
||||
rows,
|
||||
score_run_id=score_run.id,
|
||||
replace_all=should_replace_all,
|
||||
)
|
||||
included_vote_count = session.scalar(
|
||||
select(func.count(func.distinct(Vote.id)))
|
||||
.select_from(VoteRecord)
|
||||
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
|
||||
.join(
|
||||
VoteMeasureLink,
|
||||
and_(
|
||||
VoteMeasureLink.vote_id == Vote.id,
|
||||
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||
),
|
||||
)
|
||||
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
|
||||
.where(
|
||||
*_vote_scope_filters(congress=congress, bill_ids=bill_ids),
|
||||
*_topic_scope_filters(bill_ids=bill_ids, topics=_normalize_topics(topics)),
|
||||
*_eligible_vote_filters(),
|
||||
_is_scorable_position(normalized_position_expression(VoteRecord.position)),
|
||||
)
|
||||
) or 0
|
||||
total_scoped_votes = session.scalar(
|
||||
select(func.count(func.distinct(Vote.id)))
|
||||
.select_from(Vote)
|
||||
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||
.join(
|
||||
VoteMeasureLink,
|
||||
and_(
|
||||
VoteMeasureLink.vote_id == Vote.id,
|
||||
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||
),
|
||||
)
|
||||
.where(*_vote_scope_filters(congress=congress, bill_ids=bill_ids))
|
||||
) or 0
|
||||
finalize_score_run(
|
||||
session,
|
||||
score_run=score_run,
|
||||
included_vote_count=included_vote_count,
|
||||
excluded_vote_count=max(total_scoped_votes - included_vote_count, 0),
|
||||
)
|
||||
session.commit()
|
||||
logger.info("Stored %d legislator topic score rows", written)
|
||||
|
||||
|
||||
def _log_diagnostics(diagnostics: ScoreDiagnostics) -> None:
|
||||
logger.info(
|
||||
"Score input diagnostics: bill_topic_rows=%d linked_vote_rows=%d "
|
||||
"vote_record_rows=%d topic_vote_links=%d scorable_vote_records=%d",
|
||||
diagnostics.bill_topic_rows,
|
||||
diagnostics.linked_vote_rows,
|
||||
diagnostics.vote_record_rows,
|
||||
diagnostics.topic_vote_links,
|
||||
diagnostics.scorable_vote_records,
|
||||
)
|
||||
if diagnostics.bill_topic_rows == 0:
|
||||
logger.warning(
|
||||
"No extracted bill topics matched the score scope. Run "
|
||||
"pipelines.tools.extract_bill_topics after bill summarization."
|
||||
)
|
||||
elif diagnostics.linked_vote_rows == 0:
|
||||
logger.warning("No direct substantive text votes matched the score scope.")
|
||||
elif diagnostics.vote_record_rows == 0:
|
||||
logger.warning("No individual vote records matched the score scope.")
|
||||
elif diagnostics.topic_vote_links == 0:
|
||||
logger.warning(
|
||||
"Bill topics exist, but none are attached to bills that have eligible scored votes."
|
||||
)
|
||||
elif diagnostics.scorable_vote_records == 0:
|
||||
logger.warning(
|
||||
"Topic-vote links exist, but no joined vote records had Yea/Aye/Yes/Nay/No positions."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
typer.run(main)
|
||||
@@ -0,0 +1,682 @@
|
||||
"""Extract bill topics from bill text using a configurable topic catalog."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Sequence
|
||||
|
||||
import httpx
|
||||
import typer
|
||||
from sqlalchemy import ColumnElement, Select, delete, exists, func, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from pipelines.config import OpenAIConfig, get_config_dir, get_openai_config
|
||||
from pipelines.orm.common import get_postgres_engine
|
||||
from pipelines.orm.data_science_dev.congress import (
|
||||
Bill,
|
||||
BillText,
|
||||
BillTopic,
|
||||
BillTopicPosition,
|
||||
SubjectType,
|
||||
VoteClassification,
|
||||
VoteRelationship,
|
||||
VoteTextTarget,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
|
||||
OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
|
||||
REQUEST_TIMEOUT_SECONDS = 60
|
||||
DEFAULT_TOPICS_PATH = get_config_dir() / "congressional_issues_comprehensive.json"
|
||||
|
||||
|
||||
class TopicExtractionError(RuntimeError):
|
||||
"""Raised when a topic extraction request or response is invalid."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TopicCatalog:
|
||||
"""Loaded topic catalog with categories for prompting and flat candidates."""
|
||||
|
||||
topics_by_category: dict[str, list[str]]
|
||||
candidate_topics: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TopicExtractionDiagnostics:
|
||||
"""Counts for the bill summary inputs needed by topic extraction."""
|
||||
|
||||
bill_rows: int
|
||||
bill_text_rows: int
|
||||
summarized_bill_text_rows: int
|
||||
bills_with_summaries: int
|
||||
bill_topic_rows: int
|
||||
selected_bills: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExtractedBillTopic:
|
||||
"""One extracted bill topic and yes-vote stance."""
|
||||
|
||||
topic: str
|
||||
support_position: BillTopicPosition
|
||||
confidence: float | None = None
|
||||
evidence: str | None = None
|
||||
|
||||
|
||||
def _select_bill_text_for_topic_extraction(bill: Bill) -> BillText | None:
|
||||
"""Pick one summarized bill_text row from the already-loaded relationship."""
|
||||
for bill_text in bill.bill_texts:
|
||||
if bill_text.summary and bill_text.summary.strip():
|
||||
return bill_text
|
||||
return None
|
||||
|
||||
|
||||
def normalize_topic_label(value: str) -> str:
|
||||
"""Normalize a topic label for storage, comparison, and de-duping."""
|
||||
normalized = value.strip().strip("\"'")
|
||||
normalized = normalized.strip().rstrip(".").strip()
|
||||
return re.sub(r"\s+", " ", normalized).lower()
|
||||
|
||||
|
||||
def load_topic_catalog(path: Path | None = None) -> TopicCatalog:
|
||||
"""Load, validate, normalize, and flatten the bill topic catalog."""
|
||||
topics_path = path or DEFAULT_TOPICS_PATH
|
||||
try:
|
||||
raw = json.loads(topics_path.read_text())
|
||||
except FileNotFoundError as exc:
|
||||
msg = f"Topic catalog not found: {topics_path}"
|
||||
raise TopicExtractionError(msg) from exc
|
||||
except json.JSONDecodeError as exc:
|
||||
msg = f"Topic catalog is not valid JSON: {topics_path}: {exc}"
|
||||
raise TopicExtractionError(msg) from exc
|
||||
|
||||
if not isinstance(raw, dict):
|
||||
msg = "Topic catalog root must be an object mapping category names to lists"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
topics_by_category: dict[str, list[str]] = {}
|
||||
candidate_topics: list[str] = []
|
||||
seen_topics: set[str] = set()
|
||||
|
||||
for category, topics in raw.items():
|
||||
if not isinstance(category, str) or not category.strip():
|
||||
msg = "Topic catalog category names must be non-empty strings"
|
||||
raise TopicExtractionError(msg)
|
||||
if not isinstance(topics, list):
|
||||
msg = f"Topic catalog category {category!r} must contain a list"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
normalized_topics: list[str] = []
|
||||
for topic in topics:
|
||||
if not isinstance(topic, str):
|
||||
msg = f"Topic catalog category {category!r} contains a non-string topic"
|
||||
raise TopicExtractionError(msg)
|
||||
normalized_topic = normalize_topic_label(topic)
|
||||
if not normalized_topic:
|
||||
msg = f"Topic catalog category {category!r} contains a blank topic"
|
||||
raise TopicExtractionError(msg)
|
||||
if normalized_topic in seen_topics:
|
||||
continue
|
||||
seen_topics.add(normalized_topic)
|
||||
normalized_topics.append(normalized_topic)
|
||||
candidate_topics.append(normalized_topic)
|
||||
|
||||
topics_by_category[category.strip()] = normalized_topics
|
||||
|
||||
return TopicCatalog(
|
||||
topics_by_category=topics_by_category,
|
||||
candidate_topics=candidate_topics,
|
||||
)
|
||||
|
||||
|
||||
def build_topic_extraction_messages(
|
||||
*,
|
||||
bill: Bill,
|
||||
bill_text: str,
|
||||
candidate_topics: Sequence[str],
|
||||
) -> list[dict[str, str]]:
|
||||
"""Build GPT messages for extracting a bill's scored topics."""
|
||||
normalized_candidates = [normalize_topic_label(topic) for topic in candidate_topics]
|
||||
candidate_list = "\n".join(f"- {topic}" for topic in normalized_candidates)
|
||||
metadata = "\n".join(
|
||||
(
|
||||
f"Congress: {bill.congress}",
|
||||
f"Bill: {bill.bill_type} {bill.number}",
|
||||
f"Title: {bill.title_short or bill.title or bill.official_title or ''}",
|
||||
f"Top subject term: {bill.subjects_top_term or ''}",
|
||||
)
|
||||
)
|
||||
|
||||
system_prompt = (
|
||||
"You extract policy topics from U.S. congressional bills.\n"
|
||||
'For each selected topic, decide whether a Yes/Yea vote on the bill is "for" or "against" that topic.\n'
|
||||
'Use "support_position": "for" when a Yes/Yea vote advances or supports the topic.\n'
|
||||
'Use "support_position": "against" when a Yes/Yea vote restricts, repeals, blocks, or opposes the topic.\n'
|
||||
"Select only topics from the provided candidate topic list.\n"
|
||||
"Omit topics that are not materially addressed by the bill.\n"
|
||||
"Return strict JSON only, with this shape:\n"
|
||||
'{"topics":[{"topic":"candidate topic","support_position":"for","confidence":0.0,"evidence":"short reason"}]}'
|
||||
)
|
||||
user_prompt = "\n\n".join(
|
||||
(
|
||||
"BILL METADATA:",
|
||||
metadata,
|
||||
"CANDIDATE TOPICS:",
|
||||
candidate_list,
|
||||
"BILL TEXT:",
|
||||
bill_text,
|
||||
)
|
||||
)
|
||||
return [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
|
||||
def call_openai_topic_extraction(
|
||||
*,
|
||||
openai_config: OpenAIConfig,
|
||||
messages: list[dict[str, str]],
|
||||
) -> str:
|
||||
"""Call GPT and return the assistant message content."""
|
||||
|
||||
response = httpx.post(
|
||||
openai_config.openai_chat_completions_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {openai_config.api_key}",
|
||||
"OpenAI-Project": openai_config.openai_project_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": "gpt-5.4-mini",
|
||||
"messages": messages,
|
||||
},
|
||||
timeout=openai_config.timeout_seconds,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return extract_message_content(response.json())
|
||||
|
||||
|
||||
def extract_message_content(data: dict[str, Any]) -> str:
|
||||
"""Extract message content from a chat-completions response body."""
|
||||
choices = data.get("choices")
|
||||
if not isinstance(choices, list) or not choices:
|
||||
msg = "Chat completion response did not contain choices"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
first = choices[0]
|
||||
if not isinstance(first, dict):
|
||||
msg = "Chat completion choice must be an object"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
message = first.get("message")
|
||||
if isinstance(message, dict) and isinstance(message.get("content"), str):
|
||||
return message["content"]
|
||||
if isinstance(first.get("text"), str):
|
||||
return first["text"]
|
||||
|
||||
msg = "Chat completion response did not contain message content"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
|
||||
def parse_topic_extraction_response(response_text: str) -> list[ExtractedBillTopic]:
|
||||
"""Parse, normalize, validate, and de-dupe a topic extraction response."""
|
||||
payload = _load_json_response(response_text)
|
||||
topics = payload.get("topics")
|
||||
if not isinstance(topics, list):
|
||||
msg = "Topic extraction response must contain a topics list"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
deduped: dict[tuple[str, BillTopicPosition], ExtractedBillTopic] = {}
|
||||
for item in topics:
|
||||
if not isinstance(item, dict):
|
||||
msg = "Topic extraction response topics must be objects"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
raw_topic = _extract_topic_label(item)
|
||||
topic = normalize_topic_label(raw_topic)
|
||||
if not topic:
|
||||
msg = "Topic extraction response topic must not be blank"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
raw_position = item.get("support_position")
|
||||
try:
|
||||
support_position = BillTopicPosition(raw_position)
|
||||
except ValueError as exc:
|
||||
msg = f"Invalid support_position: {raw_position!r}"
|
||||
raise TopicExtractionError(msg) from exc
|
||||
|
||||
confidence = _parse_confidence(item.get("confidence"))
|
||||
evidence = item.get("evidence")
|
||||
if evidence is not None and not isinstance(evidence, str):
|
||||
evidence = str(evidence)
|
||||
|
||||
extracted = ExtractedBillTopic(
|
||||
topic=topic,
|
||||
support_position=support_position,
|
||||
confidence=confidence,
|
||||
evidence=evidence,
|
||||
)
|
||||
key = (topic, support_position)
|
||||
existing = deduped.get(key)
|
||||
if existing is None or _confidence_rank(extracted) > _confidence_rank(existing):
|
||||
deduped[key] = extracted
|
||||
|
||||
return list(deduped.values())
|
||||
|
||||
|
||||
def extract_topics_for_bill_text(
|
||||
*,
|
||||
openai_config: OpenAIConfig,
|
||||
bill: Bill,
|
||||
text: str,
|
||||
candidate_topics: Sequence[str],
|
||||
) -> list[ExtractedBillTopic]:
|
||||
"""Extract accepted catalog topics for a bill text string."""
|
||||
normalized_candidates = {normalize_topic_label(topic) for topic in candidate_topics}
|
||||
messages = build_topic_extraction_messages(
|
||||
bill=bill,
|
||||
bill_text=text,
|
||||
candidate_topics=sorted(normalized_candidates),
|
||||
)
|
||||
response_text = call_openai_topic_extraction(
|
||||
openai_config=openai_config,
|
||||
messages=messages,
|
||||
)
|
||||
extracted_topics = parse_topic_extraction_response(response_text)
|
||||
return [topic for topic in extracted_topics if topic.topic in normalized_candidates]
|
||||
|
||||
|
||||
def store_bill_topic_result(
|
||||
*,
|
||||
session: Session,
|
||||
bill: Bill,
|
||||
topics: Sequence[ExtractedBillTopic],
|
||||
replace_existing: bool = True,
|
||||
) -> None:
|
||||
"""Store extracted topics for one bill."""
|
||||
if replace_existing:
|
||||
session.execute(delete(BillTopic).where(BillTopic.bill_id == bill.id))
|
||||
|
||||
for topic in topics:
|
||||
session.add(
|
||||
BillTopic(
|
||||
bill_id=bill.id,
|
||||
topic=normalize_topic_label(topic.topic),
|
||||
support_position=topic.support_position,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def create_select_bills_for_topic_extraction(
|
||||
congress: int | None = None,
|
||||
bill_ids: list[int] | None = None,
|
||||
bill_text_ids: list[int] | None = None,
|
||||
with_votes_only: bool = False,
|
||||
force: bool = False,
|
||||
limit: int | None = None,
|
||||
) -> Select[tuple[Bill]]:
|
||||
"""Select bill rows that have summarized bill_text rows for topic extraction."""
|
||||
has_summary = (BillText.summary.is_not(None), BillText.summary != "")
|
||||
summarized_text_filters: list[ColumnElement[bool]] = [
|
||||
BillText.bill_id == Bill.id,
|
||||
*has_summary,
|
||||
]
|
||||
if with_votes_only:
|
||||
summarized_text_filters.append(
|
||||
exists(
|
||||
select(VoteTextTarget.vote_id)
|
||||
.join(
|
||||
VoteClassification,
|
||||
VoteClassification.vote_id == VoteTextTarget.vote_id,
|
||||
)
|
||||
.where(
|
||||
VoteTextTarget.voted_text_version_id == BillText.id,
|
||||
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||
VoteClassification.vote_relationship
|
||||
== VoteRelationship.DIRECT_TEXT_VOTE,
|
||||
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||
VoteClassification.is_special_rule.is_(False),
|
||||
)
|
||||
)
|
||||
)
|
||||
summarized_text_exists = exists(select(BillText.id).where(*summarized_text_filters))
|
||||
stmt = (
|
||||
select(Bill)
|
||||
.where(summarized_text_exists)
|
||||
.options(selectinload(Bill.bill_texts.and_(*summarized_text_filters[1:])))
|
||||
.order_by(Bill.id)
|
||||
)
|
||||
if congress is not None:
|
||||
stmt = stmt.where(Bill.congress == congress)
|
||||
if bill_ids:
|
||||
stmt = stmt.where(Bill.id.in_(bill_ids))
|
||||
if bill_text_ids:
|
||||
selected_text_exists = exists(
|
||||
select(BillText.id).where(
|
||||
BillText.bill_id == Bill.id,
|
||||
BillText.id.in_(bill_text_ids),
|
||||
*summarized_text_filters[1:],
|
||||
)
|
||||
)
|
||||
stmt = stmt.where(selected_text_exists)
|
||||
if not force:
|
||||
stmt = stmt.where(
|
||||
~exists(select(BillTopic.id).where(BillTopic.bill_id == Bill.id))
|
||||
)
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
return stmt
|
||||
|
||||
|
||||
def collect_topic_extraction_diagnostics(
|
||||
session: Session,
|
||||
*,
|
||||
congress: int | None = None,
|
||||
bill_ids: list[int] | None = None,
|
||||
bill_text_ids: list[int] | None = None,
|
||||
with_votes_only: bool = False,
|
||||
force: bool = False,
|
||||
limit: int | None = None,
|
||||
) -> TopicExtractionDiagnostics:
|
||||
"""Count topic extraction inputs for explaining empty selections."""
|
||||
bill_filters = []
|
||||
bill_text_filters: list[ColumnElement[bool]] = []
|
||||
if congress is not None:
|
||||
bill_filters.append(Bill.congress == congress)
|
||||
if bill_ids:
|
||||
bill_filters.append(Bill.id.in_(bill_ids))
|
||||
bill_text_filters.append(BillText.bill_id.in_(bill_ids))
|
||||
if bill_text_ids:
|
||||
bill_text_filters.append(BillText.id.in_(bill_text_ids))
|
||||
if with_votes_only:
|
||||
bill_text_filters.append(
|
||||
exists(
|
||||
select(VoteTextTarget.vote_id)
|
||||
.join(
|
||||
VoteClassification,
|
||||
VoteClassification.vote_id == VoteTextTarget.vote_id,
|
||||
)
|
||||
.where(
|
||||
VoteTextTarget.voted_text_version_id == BillText.id,
|
||||
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||
VoteClassification.vote_relationship
|
||||
== VoteRelationship.DIRECT_TEXT_VOTE,
|
||||
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||
VoteClassification.is_special_rule.is_(False),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
has_summary = (BillText.summary.is_not(None), BillText.summary != "")
|
||||
summary_filters = [*bill_text_filters, *has_summary]
|
||||
|
||||
bills_with_summaries = session.scalar(
|
||||
select(func.count(func.distinct(Bill.id)))
|
||||
.select_from(Bill)
|
||||
.join(BillText, BillText.bill_id == Bill.id)
|
||||
.where(*bill_filters, *summary_filters)
|
||||
)
|
||||
selected_bills = session.scalar(
|
||||
select(func.count()).select_from(
|
||||
create_select_bills_for_topic_extraction(
|
||||
congress=congress,
|
||||
bill_ids=bill_ids,
|
||||
bill_text_ids=bill_text_ids,
|
||||
with_votes_only=with_votes_only,
|
||||
force=force,
|
||||
limit=limit,
|
||||
).subquery()
|
||||
)
|
||||
)
|
||||
|
||||
return TopicExtractionDiagnostics(
|
||||
bill_rows=session.scalar(select(func.count(Bill.id)).where(*bill_filters)) or 0,
|
||||
bill_text_rows=_count_bill_texts(
|
||||
session,
|
||||
bill_filters=bill_filters,
|
||||
bill_text_filters=bill_text_filters,
|
||||
),
|
||||
summarized_bill_text_rows=_count_bill_texts(
|
||||
session,
|
||||
bill_filters=bill_filters,
|
||||
bill_text_filters=summary_filters,
|
||||
),
|
||||
bills_with_summaries=bills_with_summaries or 0,
|
||||
bill_topic_rows=session.scalar(select(func.count(BillTopic.id))) or 0,
|
||||
selected_bills=selected_bills or 0,
|
||||
)
|
||||
|
||||
|
||||
def _load_json_response(response_text: str) -> dict[str, Any]:
|
||||
text = response_text.strip()
|
||||
fenced = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.DOTALL)
|
||||
if fenced:
|
||||
text = fenced.group(1).strip()
|
||||
|
||||
try:
|
||||
payload = json.loads(text)
|
||||
except json.JSONDecodeError as exc:
|
||||
msg = f"Topic extraction response is not valid JSON: {exc}"
|
||||
raise TopicExtractionError(msg) from exc
|
||||
if not isinstance(payload, dict):
|
||||
msg = "Topic extraction response must be a JSON object"
|
||||
raise TopicExtractionError(msg)
|
||||
return payload
|
||||
|
||||
|
||||
def _parse_confidence(raw: Any) -> float | None:
|
||||
if raw is None:
|
||||
return None
|
||||
try:
|
||||
return float(raw)
|
||||
except (TypeError, ValueError) as exc:
|
||||
msg = f"Invalid confidence: {raw!r}"
|
||||
raise TopicExtractionError(msg) from exc
|
||||
|
||||
|
||||
def _confidence_rank(topic: ExtractedBillTopic) -> tuple[int, float]:
|
||||
if topic.confidence is None:
|
||||
return (0, 0.0)
|
||||
return (1, topic.confidence)
|
||||
|
||||
|
||||
def _extract_topic_label(item: dict[str, Any]) -> str:
|
||||
raw_topic = item.get("topic")
|
||||
if isinstance(raw_topic, str):
|
||||
return raw_topic
|
||||
if isinstance(raw_topic, dict):
|
||||
for key in ("topic", "label", "name", "title"):
|
||||
value = raw_topic.get(key)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
msg = "Topic extraction response topic must be a string"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
|
||||
def _count_bill_texts(
|
||||
session: Session,
|
||||
*,
|
||||
bill_filters: Sequence[ColumnElement[bool]],
|
||||
bill_text_filters: Sequence[ColumnElement[bool]],
|
||||
) -> int:
|
||||
stmt = select(func.count(BillText.id))
|
||||
if bill_filters:
|
||||
stmt = stmt.join(Bill, Bill.id == BillText.bill_id).where(*bill_filters)
|
||||
return session.scalar(stmt.where(*bill_text_filters)) or 0
|
||||
|
||||
|
||||
def main(
|
||||
topics_path: Annotated[
|
||||
Path, typer.Option(help="Path to congressional issue topic JSON.")
|
||||
] = DEFAULT_TOPICS_PATH,
|
||||
congress: Annotated[
|
||||
int | None, typer.Option(help="Only process one Congress.")
|
||||
] = None,
|
||||
bill_ids: Annotated[
|
||||
list[int] | None,
|
||||
typer.Option(
|
||||
"--bill-id",
|
||||
help="Only process one internal bill.id. Repeat for multiple bills.",
|
||||
),
|
||||
] = None,
|
||||
bill_text_ids: Annotated[
|
||||
list[int] | None,
|
||||
typer.Option(
|
||||
"--bill-text-id",
|
||||
help="Only process one internal bill_text.id. Repeat for multiple rows.",
|
||||
),
|
||||
] = None,
|
||||
with_votes_only: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
"--with-votes-only",
|
||||
help="Only process summarized bill_text rows linked to at least one vote.",
|
||||
),
|
||||
] = True,
|
||||
limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
|
||||
force: Annotated[
|
||||
bool,
|
||||
typer.Option(help="Regenerate topics for bills that already have topics."),
|
||||
] = False,
|
||||
dry_run: Annotated[
|
||||
bool,
|
||||
typer.Option(help="Select bills and print diagnostics without calling OpenAI."),
|
||||
] = False,
|
||||
diagnose: Annotated[
|
||||
bool,
|
||||
typer.Option(help="Log input-stage counts before processing."),
|
||||
] = False,
|
||||
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
|
||||
) -> None:
|
||||
"""CLI entrypoint for generating and storing bill topics."""
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
|
||||
topic_catalog = load_topic_catalog(topics_path)
|
||||
logger.info(
|
||||
"Loaded %d candidate topics from %s",
|
||||
len(topic_catalog.candidate_topics),
|
||||
topics_path,
|
||||
)
|
||||
|
||||
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||
with Session(engine) as session:
|
||||
if diagnose or dry_run:
|
||||
diagnostics = collect_topic_extraction_diagnostics(
|
||||
session,
|
||||
congress=congress,
|
||||
bill_ids=bill_ids,
|
||||
bill_text_ids=bill_text_ids,
|
||||
with_votes_only=with_votes_only,
|
||||
force=force,
|
||||
limit=limit,
|
||||
)
|
||||
_log_topic_extraction_diagnostics(diagnostics)
|
||||
if dry_run:
|
||||
return
|
||||
|
||||
openai_config = get_openai_config()
|
||||
|
||||
stmt = create_select_bills_for_topic_extraction(
|
||||
congress=congress,
|
||||
bill_ids=bill_ids,
|
||||
bill_text_ids=bill_text_ids,
|
||||
with_votes_only=with_votes_only,
|
||||
force=force,
|
||||
limit=limit,
|
||||
)
|
||||
bills = session.scalars(stmt).all()
|
||||
logger.info("Selected %d bills for topic extraction", len(bills))
|
||||
|
||||
written = 0
|
||||
failed = 0
|
||||
for index, bill in enumerate(bills, 1):
|
||||
bill_text = _select_bill_text_for_topic_extraction(bill)
|
||||
if bill_text is None:
|
||||
logger.warning("Skipping bill id=%s: no usable summary", bill.id)
|
||||
continue
|
||||
summary = bill_text.summary.strip()
|
||||
|
||||
try:
|
||||
extracted_topics = extract_topics_for_bill_text(
|
||||
openai_config=openai_config,
|
||||
bill=bill,
|
||||
text=summary,
|
||||
candidate_topics=topic_catalog.candidate_topics,
|
||||
)
|
||||
except (httpx.HTTPError, TopicExtractionError):
|
||||
failed += 1
|
||||
logger.exception(
|
||||
"Skipping bill id=%s after topic extraction failure", bill.id
|
||||
)
|
||||
continue
|
||||
|
||||
store_bill_topic_result(
|
||||
session=session,
|
||||
bill=bill,
|
||||
topics=extracted_topics,
|
||||
replace_existing=True,
|
||||
)
|
||||
written += 1
|
||||
if index % 100 == 0:
|
||||
session.commit()
|
||||
logger.info(
|
||||
"Stored %d topics for bill id=%s",
|
||||
len(extracted_topics),
|
||||
bill.id,
|
||||
)
|
||||
|
||||
session.commit()
|
||||
logger.info(
|
||||
"Done: stored topic results for %d bills; failed %d bills",
|
||||
written,
|
||||
failed,
|
||||
)
|
||||
|
||||
|
||||
def _log_topic_extraction_diagnostics(
|
||||
diagnostics: TopicExtractionDiagnostics,
|
||||
) -> None:
|
||||
logger.info(
|
||||
"Topic extraction diagnostics: bill_rows=%d bill_text_rows=%d "
|
||||
"summarized_bill_text_rows=%d bills_with_summaries=%d "
|
||||
"bill_topic_rows=%d selected_bills=%d",
|
||||
diagnostics.bill_rows,
|
||||
diagnostics.bill_text_rows,
|
||||
diagnostics.summarized_bill_text_rows,
|
||||
diagnostics.bills_with_summaries,
|
||||
diagnostics.bill_topic_rows,
|
||||
diagnostics.selected_bills,
|
||||
)
|
||||
if diagnostics.bill_rows == 0:
|
||||
logger.warning("No bills matched the topic extraction scope.")
|
||||
elif diagnostics.bill_text_rows == 0:
|
||||
logger.warning("No bill_text rows matched the topic extraction scope.")
|
||||
elif diagnostics.summarized_bill_text_rows == 0:
|
||||
logger.warning(
|
||||
"No summarized bill_text rows matched the topic extraction scope. "
|
||||
"Run pipelines.tools.summarize_bills first."
|
||||
)
|
||||
elif diagnostics.selected_bills == 0 and diagnostics.bill_topic_rows > 0:
|
||||
logger.warning(
|
||||
"No bills selected because matching bills already have topics. "
|
||||
"Use --force to regenerate them."
|
||||
)
|
||||
elif diagnostics.selected_bills == 0:
|
||||
logger.warning("No bills selected for topic extraction.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
typer.run(main)
|
||||
@@ -0,0 +1,309 @@
|
||||
"""Summarize bill_text rows with GPT-5 and store results in the database."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import tomllib
|
||||
from os import getenv
|
||||
from typing import Annotated, Any
|
||||
|
||||
import httpx
|
||||
import typer
|
||||
from sqlalchemy import Select, exists, or_, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from tiktoken import get_encoding
|
||||
|
||||
|
||||
from pipelines.config import get_config_dir
|
||||
from pipelines.orm.common import get_postgres_engine
|
||||
from pipelines.orm.data_science_dev.congress import (
|
||||
Bill,
|
||||
BillText,
|
||||
SubjectType,
|
||||
VoteClassification,
|
||||
VoteRelationship,
|
||||
VoteTextTarget,
|
||||
)
|
||||
from pipelines.tools.bill_token_compression import compress_bill_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
|
||||
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
|
||||
REQUEST_TIMEOUT_SECONDS = 60
|
||||
|
||||
|
||||
def load_summarization_prompts(
|
||||
section: str = "summarization",
|
||||
) -> dict[str, str]:
|
||||
summarization_prompts = get_config_dir() / "prompts" / "summarization_prompts.toml"
|
||||
|
||||
return tomllib.loads(summarization_prompts.read_text())[section]
|
||||
|
||||
|
||||
class BillSummaryError(RuntimeError):
|
||||
"""Raised when a bill summary request or response is invalid."""
|
||||
|
||||
|
||||
def call_openai_summary(
|
||||
*,
|
||||
model: str,
|
||||
messages: list[dict[str, str]],
|
||||
) -> str:
|
||||
"""Call GPT and return the assistant message content."""
|
||||
api_key = getenv("CLOSEDAI_TOKEN")
|
||||
if not api_key:
|
||||
msg = "CLOSEDAI_TOKEN is required"
|
||||
raise BillSummaryError(msg)
|
||||
|
||||
response = httpx.post(
|
||||
OPENAI_CHAT_COMPLETIONS_URL,
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"OpenAI-Project": OPENAI_PROJECT_ID,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
},
|
||||
timeout=REQUEST_TIMEOUT_SECONDS,
|
||||
)
|
||||
logger.info(f"{response.text=}")
|
||||
response.raise_for_status()
|
||||
return extract_message_content(response.json())
|
||||
|
||||
|
||||
def build_bill_summary_messages(
|
||||
*,
|
||||
bill_text: BillText,
|
||||
summarization_prompts: dict[str, str],
|
||||
) -> list[dict[str, str]]:
|
||||
"""Build the GPT prompt messages plus compressed text and user prompt."""
|
||||
if not bill_text.text_content:
|
||||
msg = f"bill_text id={bill_text.id} has no text_content"
|
||||
raise BillSummaryError(msg)
|
||||
|
||||
compressed_text = compress_bill_text(bill_text.text_content)
|
||||
if not compressed_text:
|
||||
msg = f"bill_text id={bill_text.id} has no summarizable text_content"
|
||||
raise BillSummaryError(msg)
|
||||
|
||||
user_prompt = summarization_prompts["user_template"].format(
|
||||
text_content=compressed_text
|
||||
)
|
||||
|
||||
user_prompt_tokens = len(get_encoding("o200k_base").encode(user_prompt))
|
||||
logger.info(f"{user_prompt_tokens=}")
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": summarization_prompts["system_prompt"]},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
},
|
||||
]
|
||||
return messages, user_prompt_tokens
|
||||
|
||||
|
||||
def summarize_bill_text(
|
||||
*,
|
||||
model: str,
|
||||
bill_text: BillText,
|
||||
summarization_prompts: dict[str, str],
|
||||
) -> str:
|
||||
"""Generate and return a summary for one bill_text row."""
|
||||
messages, user_prompt_tokens = build_bill_summary_messages(
|
||||
bill_text=bill_text,
|
||||
summarization_prompts=summarization_prompts,
|
||||
)
|
||||
# This may only be for gpt-5.4 mini I need to read the docs
|
||||
if user_prompt_tokens > 272000:
|
||||
msg = f"Compressed bill_text id={bill_text.id} is too long for summarization ({user_prompt_tokens} tokens)"
|
||||
logger.warning(msg)
|
||||
return None
|
||||
|
||||
summary = call_openai_summary(
|
||||
model=model,
|
||||
messages=messages,
|
||||
).strip()
|
||||
if not summary:
|
||||
msg = f"Model returned an empty summary for bill_text id={bill_text.id}"
|
||||
raise BillSummaryError(msg)
|
||||
return summary
|
||||
|
||||
|
||||
def store_bill_summary_result(
|
||||
*,
|
||||
bill_text: BillText,
|
||||
summary: str,
|
||||
model: str,
|
||||
) -> None:
|
||||
"""Store a generated summary and the prompt/model metadata that produced it."""
|
||||
bill_text.summary = summary
|
||||
bill_text.summarization_model = model
|
||||
bill_text.summarization_system_prompt_version = "v1.2"
|
||||
bill_text.summarization_user_prompt_version = "v1"
|
||||
|
||||
|
||||
def create_select_bill_texts_for_summarization(
|
||||
congress: int | None = None,
|
||||
bill_ids: list[int] | None = None,
|
||||
bill_text_ids: list[int] | None = None,
|
||||
with_votes_only: bool = False,
|
||||
force: bool = False,
|
||||
limit: int | None = None,
|
||||
) -> Select:
|
||||
"""Select bill_text rows that have source text and need summaries."""
|
||||
stmt = (
|
||||
select(BillText)
|
||||
.join(Bill, Bill.id == BillText.bill_id)
|
||||
.where(BillText.text_content.is_not(None), BillText.text_content != "")
|
||||
.options(selectinload(BillText.bill))
|
||||
.order_by(BillText.id)
|
||||
)
|
||||
if congress is not None:
|
||||
stmt = stmt.where(Bill.congress == congress)
|
||||
if bill_ids:
|
||||
stmt = stmt.where(BillText.bill_id.in_(bill_ids))
|
||||
if bill_text_ids:
|
||||
stmt = stmt.where(BillText.id.in_(bill_text_ids))
|
||||
if with_votes_only:
|
||||
stmt = stmt.where(
|
||||
exists(
|
||||
select(VoteTextTarget.vote_id)
|
||||
.join(
|
||||
VoteClassification,
|
||||
VoteClassification.vote_id == VoteTextTarget.vote_id,
|
||||
)
|
||||
.where(
|
||||
VoteTextTarget.voted_text_version_id == BillText.id,
|
||||
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||
VoteClassification.vote_relationship
|
||||
== VoteRelationship.DIRECT_TEXT_VOTE,
|
||||
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||
VoteClassification.is_special_rule.is_(False),
|
||||
)
|
||||
)
|
||||
)
|
||||
if not force:
|
||||
stmt = stmt.where(or_(BillText.summary.is_(None), BillText.summary == ""))
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
return stmt
|
||||
|
||||
|
||||
def extract_message_content(data: dict[str, Any]) -> str:
|
||||
"""Extract message content from a chat-completions response body."""
|
||||
choices = data.get("choices")
|
||||
if not isinstance(choices, list) or not choices:
|
||||
msg = "Chat completion response did not contain choices"
|
||||
raise BillSummaryError(msg)
|
||||
|
||||
first = choices[0]
|
||||
if not isinstance(first, dict):
|
||||
msg = "Chat completion choice must be an object"
|
||||
raise BillSummaryError(msg)
|
||||
|
||||
message = first.get("message")
|
||||
if isinstance(message, dict) and isinstance(message.get("content"), str):
|
||||
return message["content"]
|
||||
if isinstance(first.get("text"), str):
|
||||
return first["text"]
|
||||
|
||||
msg = "Chat completion response did not contain message content"
|
||||
raise BillSummaryError(msg)
|
||||
|
||||
|
||||
def main(
|
||||
model: Annotated[str, typer.Option(help="OpenAI model id.")] = "gpt-5.4-mini",
|
||||
congress: Annotated[
|
||||
int | None, typer.Option(help="Only process one Congress.")
|
||||
] = None,
|
||||
bill_ids: Annotated[
|
||||
list[int] | None,
|
||||
typer.Option(
|
||||
"--bill-id",
|
||||
help="Only process one internal bill.id. Repeat for multiple bills.",
|
||||
),
|
||||
] = None,
|
||||
bill_text_ids: Annotated[
|
||||
list[int] | None,
|
||||
typer.Option(
|
||||
"--bill-text-id",
|
||||
help="Only process one internal bill_text.id. Repeat for multiple rows.",
|
||||
),
|
||||
] = None,
|
||||
with_votes_only: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
"--with-votes-only",
|
||||
help="Only process bill_text rows linked to at least one vote.",
|
||||
),
|
||||
] = False,
|
||||
limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
|
||||
force: Annotated[
|
||||
bool,
|
||||
typer.Option(help="Regenerate summaries for rows that already have a summary."),
|
||||
] = False,
|
||||
dry_run: Annotated[
|
||||
bool, typer.Option(help="Print summaries without writing them to the database.")
|
||||
] = False,
|
||||
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
|
||||
) -> None:
|
||||
"""CLI entrypoint for generating and storing bill summaries."""
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
if not getenv("CLOSEDAI_TOKEN"):
|
||||
message = "CLOSEDAI_TOKEN is required"
|
||||
raise typer.BadParameter(message)
|
||||
|
||||
summarization_prompts = load_summarization_prompts()
|
||||
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||
with Session(engine) as session:
|
||||
stmt = create_select_bill_texts_for_summarization(
|
||||
congress=congress,
|
||||
bill_ids=bill_ids,
|
||||
bill_text_ids=bill_text_ids,
|
||||
with_votes_only=with_votes_only,
|
||||
force=force,
|
||||
limit=limit,
|
||||
)
|
||||
bill_texts = session.scalars(stmt).all()
|
||||
logger.info("Selected %d bill_text rows for summarization", len(bill_texts))
|
||||
|
||||
written = 0
|
||||
for index, bill_text in enumerate(bill_texts, 1):
|
||||
summary = summarize_bill_text(
|
||||
model=model,
|
||||
bill_text=bill_text,
|
||||
summarization_prompts=summarization_prompts,
|
||||
)
|
||||
if summary is None:
|
||||
logger.warning("Skipping bill_text id=%s", bill_text.id)
|
||||
continue
|
||||
store_bill_summary_result(
|
||||
bill_text=bill_text,
|
||||
summary=summary,
|
||||
model=model,
|
||||
)
|
||||
if index % 100 == 0:
|
||||
session.commit()
|
||||
written += 1
|
||||
session.commit()
|
||||
logger.info("Stored summary for bill_text id=%s", bill_text.id)
|
||||
|
||||
logger.info("Done: stored %d summaries", written)
|
||||
|
||||
|
||||
def cli() -> None:
|
||||
"""Typer entry point."""
|
||||
typer.run(main)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
@@ -1,25 +0,0 @@
|
||||
# Unsloth fine-tuning container for Qwen 3.5 4B on RTX 3090.
|
||||
#
|
||||
# Build:
|
||||
# docker build -f python/prompt_bench/Dockerfile.finetune -t bill-finetune .
|
||||
#
|
||||
# Run:
|
||||
# docker run --rm --device=nvidia.com/gpu=all --ipc=host \
|
||||
# -v $(pwd)/output:/workspace/output \
|
||||
# -v $(pwd)/output/finetune_dataset.jsonl:/workspace/dataset.jsonl:ro \
|
||||
# -v /zfs/models/hf:/models \
|
||||
# bill-finetune \
|
||||
# --dataset /workspace/dataset.jsonl \
|
||||
# --output-dir /workspace/output/qwen-bill-summarizer
|
||||
|
||||
FROM ghcr.io/unslothai/unsloth:latest
|
||||
|
||||
RUN pip install --no-cache-dir typer
|
||||
|
||||
WORKDIR /workspace
|
||||
COPY python/prompt_bench/finetune.py python/prompt_bench/finetune.py
|
||||
COPY config/prompts/summarization_prompts.toml config/prompts/summarization_prompts.toml
|
||||
COPY python/prompt_bench/__init__.py python/prompt_bench/__init__.py
|
||||
COPY python/__init__.py python/__init__.py
|
||||
|
||||
ENTRYPOINT ["python", "-m", "pipelines.prompt_bench.finetune"]
|
||||
Reference in New Issue
Block a user