Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d235babcf3 | |||
| a956c4a973 | |||
| 45bdd7b629 | |||
| b5f2df6ae5 | |||
| 21448eb515 | |||
| 28993213af | |||
| d4c587362d | |||
| d0e865ffbd | |||
| 297d9ce89b | |||
| 72eb2d8c3d | |||
| e75c077e16 | |||
| 37fb68ac7e | |||
| e8bafbd589 | |||
| caff8724af | |||
| e1beffef12 | |||
| 2facb82bd4 |
@@ -0,0 +1,16 @@
|
|||||||
|
.git
|
||||||
|
.pytest_cache
|
||||||
|
.ruff_cache
|
||||||
|
__pycache__
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
.venv
|
||||||
|
venv
|
||||||
|
env
|
||||||
|
ENV
|
||||||
|
.env
|
||||||
|
dist
|
||||||
|
build
|
||||||
|
htmlcov
|
||||||
|
coverage.xml
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
# Postgres used by the FastAPI app
|
||||||
|
DATA_SCIENCE_DEV_DB=your_existing_database
|
||||||
|
DATA_SCIENCE_DEV_HOST=your_existing_postgres_host
|
||||||
|
DATA_SCIENCE_DEV_PORT=5432
|
||||||
|
DATA_SCIENCE_DEV_USER=your_existing_postgres_user
|
||||||
|
DATA_SCIENCE_DEV_PASSWORD=your_existing_postgres_password
|
||||||
|
|
||||||
|
# WorkOS AuthKit
|
||||||
|
WORKOS_API_KEY=sk_test_your_workos_api_key
|
||||||
|
WORKOS_CLIENT_ID=client_your_workos_client_id
|
||||||
|
WORKOS_COOKIE_PASSWORD=replace_with_a_long_random_secret_at_least_32_chars
|
||||||
|
WORKOS_ORGANIZATION_ID=org_your_workspace_org_id
|
||||||
|
WORKOS_REDIRECT_URI=http://localhost:8000/callback
|
||||||
|
WORKOS_LOGOUT_REDIRECT_URI=http://localhost:8000/
|
||||||
|
WORKOS_SESSION_COOKIE_NAME=workos_session
|
||||||
|
|
||||||
|
# Optional local port overrides for Docker Compose
|
||||||
|
WEB_PUBLISHED_PORT=8000
|
||||||
|
|
||||||
|
# Only used if you explicitly start the optional local Postgres profile
|
||||||
|
POSTGRES_PUBLISHED_PORT=5432
|
||||||
+25
@@ -0,0 +1,25 @@
|
|||||||
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE=1
|
||||||
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
RUN apt-get update \
|
||||||
|
&& apt-get install -y --no-install-recommends libpq5 \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
COPY pyproject.toml /app/pyproject.toml
|
||||||
|
COPY __init__.py /app/__init__.py
|
||||||
|
COPY alembic /app/alembic
|
||||||
|
COPY database_cli.py /app/database_cli.py
|
||||||
|
COPY pipelines /app/pipelines
|
||||||
|
COPY docker /app/docker
|
||||||
|
|
||||||
|
RUN pip install --no-cache-dir .
|
||||||
|
|
||||||
|
RUN chmod +x /app/docker/web-entrypoint.sh
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
CMD ["/app/docker/web-entrypoint.sh"]
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
services:
|
||||||
|
db:
|
||||||
|
image: postgres:16
|
||||||
|
profiles: ["localdb"]
|
||||||
|
restart: unless-stopped
|
||||||
|
environment:
|
||||||
|
POSTGRES_DB: ${DATA_SCIENCE_DEV_DB:-nornsight}
|
||||||
|
POSTGRES_USER: ${DATA_SCIENCE_DEV_USER:-nornsight}
|
||||||
|
POSTGRES_PASSWORD: ${DATA_SCIENCE_DEV_PASSWORD:-nornsight}
|
||||||
|
volumes:
|
||||||
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
healthcheck:
|
||||||
|
test:
|
||||||
|
[
|
||||||
|
"CMD-SHELL",
|
||||||
|
"pg_isready -U ${DATA_SCIENCE_DEV_USER:-nornsight} -d ${DATA_SCIENCE_DEV_DB:-nornsight}",
|
||||||
|
]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 20
|
||||||
|
start_period: 5s
|
||||||
|
ports:
|
||||||
|
- "${POSTGRES_PUBLISHED_PORT:-5432}:5432"
|
||||||
|
|
||||||
|
web:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
restart: unless-stopped
|
||||||
|
dns:
|
||||||
|
- ${WEB_DNS_1:-1.1.1.1}
|
||||||
|
- ${WEB_DNS_2:-8.8.8.8}
|
||||||
|
environment:
|
||||||
|
DATA_SCIENCE_DEV_DB: ${DATA_SCIENCE_DEV_DB}
|
||||||
|
DATA_SCIENCE_DEV_HOST: ${DATA_SCIENCE_DEV_HOST}
|
||||||
|
DATA_SCIENCE_DEV_PORT: ${DATA_SCIENCE_DEV_PORT}
|
||||||
|
DATA_SCIENCE_DEV_USER: ${DATA_SCIENCE_DEV_USER}
|
||||||
|
DATA_SCIENCE_DEV_PASSWORD: ${DATA_SCIENCE_DEV_PASSWORD}
|
||||||
|
WORKOS_API_KEY: ${WORKOS_API_KEY}
|
||||||
|
WORKOS_CLIENT_ID: ${WORKOS_CLIENT_ID}
|
||||||
|
WORKOS_COOKIE_PASSWORD: ${WORKOS_COOKIE_PASSWORD}
|
||||||
|
WORKOS_ORGANIZATION_ID: ${WORKOS_ORGANIZATION_ID}
|
||||||
|
WORKOS_REDIRECT_URI: ${WORKOS_REDIRECT_URI:-http://localhost:8000/callback}
|
||||||
|
WORKOS_LOGOUT_REDIRECT_URI: ${WORKOS_LOGOUT_REDIRECT_URI:-http://localhost:8000/}
|
||||||
|
WORKOS_SESSION_COOKIE_NAME: ${WORKOS_SESSION_COOKIE_NAME:-workos_session}
|
||||||
|
UVICORN_HOST: 0.0.0.0
|
||||||
|
UVICORN_PORT: 8000
|
||||||
|
ports:
|
||||||
|
- "${WEB_PUBLISHED_PORT:-8000}:8000"
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
#!/usr/bin/env sh
|
||||||
|
set -eu
|
||||||
|
|
||||||
|
python - <<'PY'
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import psycopg
|
||||||
|
|
||||||
|
db = os.environ["DATA_SCIENCE_DEV_DB"]
|
||||||
|
host = os.environ["DATA_SCIENCE_DEV_HOST"]
|
||||||
|
port = os.environ["DATA_SCIENCE_DEV_PORT"]
|
||||||
|
user = os.environ["DATA_SCIENCE_DEV_USER"]
|
||||||
|
password = os.environ.get("DATA_SCIENCE_DEV_PASSWORD", "")
|
||||||
|
|
||||||
|
dsn = f"dbname={db} host={host} port={port} user={user} password={password}"
|
||||||
|
|
||||||
|
for attempt in range(60):
|
||||||
|
try:
|
||||||
|
with psycopg.connect(dsn) as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute("CREATE SCHEMA IF NOT EXISTS main")
|
||||||
|
conn.commit()
|
||||||
|
break
|
||||||
|
except psycopg.OperationalError:
|
||||||
|
if attempt == 59:
|
||||||
|
raise
|
||||||
|
time.sleep(1)
|
||||||
|
PY
|
||||||
|
|
||||||
|
python /app/database_cli.py data_science_dev upgrade head
|
||||||
|
|
||||||
|
exec uvicorn pipelines.web.main:app --host "${UVICORN_HOST:-0.0.0.0}" --port "${UVICORN_PORT:-8000}"
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Init."""
|
||||||
@@ -0,0 +1,116 @@
|
|||||||
|
"""Nornsight — BERTopic POC Inference Script.
|
||||||
|
|
||||||
|
Loads the trained model and labels a small batch of posts,
|
||||||
|
writing results to main.post_topic for inspection.
|
||||||
|
|
||||||
|
POC: processes a single batch of 1k posts to validate the pipeline end-to-end.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections import Counter
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from bertopic import BERTopic
|
||||||
|
from sqlalchemy import Engine, func, insert, select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from pipelines.config import BertTopicInferConfig, get_bertopic_infer_config
|
||||||
|
from pipelines.orm.common import get_postgres_engine
|
||||||
|
from pipelines.orm.data_science_dev.posts import PostTopic, Posts
|
||||||
|
from pipelines.orm.data_science_dev.posts.lang_filters import ENGLISH_LANGS
|
||||||
|
from pipelines.pipelines.common import configure_logger
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""Run BERTopic inference against a sample of posts."""
|
||||||
|
configure_logger()
|
||||||
|
|
||||||
|
config = get_bertopic_infer_config()
|
||||||
|
run_inference(config)
|
||||||
|
logger.info(
|
||||||
|
"POC inference complete. Check main.post_topic in DBeaver to inspect results."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_inference(config: BertTopicInferConfig) -> None:
|
||||||
|
model_save_path = Path(config.model_save_path)
|
||||||
|
|
||||||
|
logger.info(f"Loading BERTopic model from {model_save_path}")
|
||||||
|
topic_model = BERTopic.load(str(model_save_path))
|
||||||
|
|
||||||
|
topic_info = topic_model.get_topic_info()
|
||||||
|
label_map: dict[int, str] = dict(zip(topic_info["Topic"], topic_info["Name"]))
|
||||||
|
logger.info(f"Model loaded with {len(label_map)} topics")
|
||||||
|
|
||||||
|
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||||
|
|
||||||
|
post_ids, texts = get_post_ids_and_test(engine, config)
|
||||||
|
|
||||||
|
logger.info(f"Fetched {len(texts)} posts")
|
||||||
|
|
||||||
|
logger.info("Running BERTopic transform")
|
||||||
|
start = time.perf_counter()
|
||||||
|
topics, _probabilities = topic_model.transform(texts)
|
||||||
|
elapsed = time.perf_counter() - start
|
||||||
|
logger.info(f"Transform complete in {elapsed:.1f}s")
|
||||||
|
|
||||||
|
# Write results to main.post_topic
|
||||||
|
records = [
|
||||||
|
{
|
||||||
|
"post_id": pid,
|
||||||
|
"topic_id": int(topic_id),
|
||||||
|
"topic_label": label_map.get(int(topic_id), "unknown"),
|
||||||
|
"model_version": config.model_version,
|
||||||
|
}
|
||||||
|
for pid, topic_id in zip(post_ids, topics)
|
||||||
|
]
|
||||||
|
with Session(engine) as session:
|
||||||
|
session.execute(insert(PostTopic), records)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
count_topics(records)
|
||||||
|
logger.info(f"Wrote {len(records)} topic labels to main.post_topic")
|
||||||
|
|
||||||
|
|
||||||
|
def get_post_ids_and_test(
|
||||||
|
engine: Engine,
|
||||||
|
config: BertTopicInferConfig,
|
||||||
|
) -> None | tuple[list[int], list[str]]:
|
||||||
|
with Session(engine) as session:
|
||||||
|
logger.info(f"Fetching {config.poc_batch_size} posts for inference")
|
||||||
|
# Pull a fresh batch for inference — distinct from training sample
|
||||||
|
# using a fixed seed offset so we're not re-labeling training posts
|
||||||
|
stmt = select(Posts).where(
|
||||||
|
Posts.text.is_not(None),
|
||||||
|
Posts.langs.in_(ENGLISH_LANGS),
|
||||||
|
func.length(Posts.text) > config.min_text_length,
|
||||||
|
)
|
||||||
|
if config.poc_batch_size > 0:
|
||||||
|
stmt = stmt.limit(config.poc_batch_size)
|
||||||
|
|
||||||
|
posts = session.scalars(stmt).all()
|
||||||
|
if not posts:
|
||||||
|
logger.warning("No posts were selected for inference")
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
post_ids = [post.post_id for post in posts]
|
||||||
|
texts = [post.text.strip() for post in posts]
|
||||||
|
|
||||||
|
return post_ids, texts
|
||||||
|
|
||||||
|
|
||||||
|
def count_topics(records: list[dict]) -> None:
|
||||||
|
topic_counts = Counter(record.get("topic_label", "unknown") for record in records)
|
||||||
|
|
||||||
|
logger.info("Topic distribution in this batch:")
|
||||||
|
for label, count in topic_counts.most_common(10):
|
||||||
|
logger.info(" %s: %d", label, count)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,119 @@
|
|||||||
|
"""Nornsight — BERTopic POC Training Script.
|
||||||
|
|
||||||
|
Pulls a small stratified sample (~11.5k posts) from main.posts,
|
||||||
|
trains BERTopic with MiniBatchKMeans on Jeeves, and saves the model locally.
|
||||||
|
|
||||||
|
POC sample rate: random() < 0.00005 (~0.005% of 230M = ~11.5k posts)
|
||||||
|
Full training rate will be: random() < 0.005 (~1.08M posts)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from bertopic import BERTopic
|
||||||
|
from sklearn.cluster import MiniBatchKMeans
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from pipelines.config import BertTopicTrainConfig, get_bertopic_train_config
|
||||||
|
from pipelines.orm.common import get_postgres_engine
|
||||||
|
from pipelines.orm.data_science_dev.posts import Posts
|
||||||
|
from pipelines.orm.data_science_dev.posts.lang_filters import ENGLISH_LANGS
|
||||||
|
from pipelines.pipelines.common import configure_logger
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""Train and persist the BERTopic model."""
|
||||||
|
configure_logger()
|
||||||
|
|
||||||
|
config = get_bertopic_train_config()
|
||||||
|
docs = load_sample(config)
|
||||||
|
if not docs:
|
||||||
|
logger.warning("No training documents were selected")
|
||||||
|
return
|
||||||
|
|
||||||
|
train(docs, config)
|
||||||
|
logger.info(f"Done. Model saved as version {config.model_version}")
|
||||||
|
logger.info("Next: run infer.py to label a sample of posts in the database")
|
||||||
|
|
||||||
|
|
||||||
|
def load_sample(config: BertTopicTrainConfig) -> list[str]:
|
||||||
|
logger.info("Connecting to PostgreSQL via SQLAlchemy")
|
||||||
|
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||||
|
|
||||||
|
logger.info(f"Pulling sample from main.posts (sample_rate={config.sample_rate})")
|
||||||
|
start = time.perf_counter()
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
texts = session.scalars(
|
||||||
|
select(Posts.text).where(
|
||||||
|
Posts.text.is_not(None),
|
||||||
|
Posts.langs.in_(ENGLISH_LANGS),
|
||||||
|
func.length(Posts.text) > config.min_text_length,
|
||||||
|
func.random() < config.sample_rate,
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
|
||||||
|
elapsed = time.perf_counter() - start
|
||||||
|
logger.info(f"Fetched {len(texts)} rows in {elapsed:.1f}s")
|
||||||
|
|
||||||
|
# Basic cleaning — strip whitespace and deduplicate
|
||||||
|
docs = list({text.strip() for text in texts})
|
||||||
|
logger.info(f"After cleaning and dedup: {len(docs)} posts")
|
||||||
|
|
||||||
|
return docs
|
||||||
|
|
||||||
|
|
||||||
|
def train(docs: list[str], config: BertTopicTrainConfig) -> None:
|
||||||
|
logger.info(
|
||||||
|
f"Initialising BERTopic with MiniBatchKMeans (n_topics={config.n_topics})"
|
||||||
|
)
|
||||||
|
|
||||||
|
cluster_model = MiniBatchKMeans(
|
||||||
|
n_clusters=config.n_topics,
|
||||||
|
random_state=42,
|
||||||
|
batch_size=1024,
|
||||||
|
n_init=3,
|
||||||
|
verbose=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
topic_model = BERTopic(
|
||||||
|
hdbscan_model=cluster_model,
|
||||||
|
language="english",
|
||||||
|
calculate_probabilities=False, # saves memory
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Starting fit_transform on {len(docs)} posts (CPU)")
|
||||||
|
start = time.perf_counter()
|
||||||
|
|
||||||
|
topic_model.fit_transform(docs)
|
||||||
|
|
||||||
|
elapsed = time.perf_counter() - start
|
||||||
|
logger.info(f"Training complete in {elapsed:.1f}s ({elapsed / 60:.1f} min)")
|
||||||
|
|
||||||
|
# Log topic summary for quick inspection
|
||||||
|
topic_info = topic_model.get_topic_info()
|
||||||
|
logger.info(f"Topics found: {len(topic_info)}")
|
||||||
|
logger.info(f"\n{topic_info.to_string()}")
|
||||||
|
|
||||||
|
model_save_path = Path(config.model_save_path)
|
||||||
|
model_save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
logger.info(f"Saving model to {model_save_path}")
|
||||||
|
|
||||||
|
topic_model.save(
|
||||||
|
str(model_save_path),
|
||||||
|
serialization="safetensors",
|
||||||
|
save_ctfidf=True,
|
||||||
|
save_embedding_model=True,
|
||||||
|
)
|
||||||
|
logger.info("Model saved")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -5,10 +5,8 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from os import getenv
|
|
||||||
from subprocess import PIPE, Popen
|
from subprocess import PIPE, Popen
|
||||||
|
|
||||||
from apprise import Apprise
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -47,25 +45,6 @@ def bash_wrapper(command: str) -> tuple[str, int]:
|
|||||||
return output.decode(), process.returncode
|
return output.decode(), process.returncode
|
||||||
|
|
||||||
|
|
||||||
def signal_alert(body: str, title: str = "") -> None:
|
|
||||||
"""Send a signal alert.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
body (str): The body of the alert.
|
|
||||||
title (str, optional): The title of the alert. Defaults to "".
|
|
||||||
"""
|
|
||||||
apprise_client = Apprise()
|
|
||||||
|
|
||||||
from_phone = getenv("SIGNAL_ALERT_FROM_PHONE")
|
|
||||||
to_phone = getenv("SIGNAL_ALERT_TO_PHONE")
|
|
||||||
if not from_phone or not to_phone:
|
|
||||||
logger.info("SIGNAL_ALERT_FROM_PHONE or SIGNAL_ALERT_TO_PHONE not set")
|
|
||||||
return
|
|
||||||
|
|
||||||
apprise_client.add(f"signal://localhost:8989/{from_phone}/{to_phone}")
|
|
||||||
|
|
||||||
apprise_client.notify(title=title, body=body)
|
|
||||||
|
|
||||||
|
|
||||||
def utcnow() -> datetime:
|
def utcnow() -> datetime:
|
||||||
"""Get the current UTC time."""
|
"""Get the current UTC time."""
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from os import getenv
|
from os import getenv
|
||||||
|
from datetime import date
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import tomllib
|
import tomllib
|
||||||
|
|
||||||
@@ -101,6 +102,44 @@ class OpenAIConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BertTopicTrainConfig:
|
||||||
|
"""BERTopic training configuration loaded from TOML."""
|
||||||
|
|
||||||
|
sample_rate: float
|
||||||
|
min_text_length: int
|
||||||
|
n_topics: int
|
||||||
|
model_save_path: str
|
||||||
|
model_version: str | None = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_toml(cls, config_path: Path) -> BertTopicTrainConfig:
|
||||||
|
"""Load BERTopic training config from a TOML file."""
|
||||||
|
raw = tomllib.loads(config_path.read_text())["bertopic"]["train"]
|
||||||
|
|
||||||
|
today = date.today().isoformat()
|
||||||
|
if raw.get("model_version") is None:
|
||||||
|
raw["model_version"] = (
|
||||||
|
f"{today}-{raw['sample_rate']}-{raw['min_text_length']}-{raw['n_topics']}"
|
||||||
|
)
|
||||||
|
return cls(**raw)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BertTopicInferConfig:
|
||||||
|
"""BERTopic inference configuration loaded from TOML."""
|
||||||
|
|
||||||
|
min_text_length: int
|
||||||
|
poc_batch_size: int
|
||||||
|
model_version: str
|
||||||
|
model_save_path: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_toml(cls, config_path: Path) -> BertTopicInferConfig:
|
||||||
|
"""Load BERTopic inference config from a TOML file."""
|
||||||
|
raw = tomllib.loads(config_path.read_text())["bertopic"]["infer"]
|
||||||
|
return cls(**raw)
|
||||||
|
|
||||||
|
|
||||||
def get_config_dir() -> Path:
|
def get_config_dir() -> Path:
|
||||||
"""Get the path to the config directory."""
|
"""Get the path to the config directory."""
|
||||||
return Path(__file__).resolve().parents[2] / "config"
|
return Path(__file__).resolve().parents[2] / "config"
|
||||||
@@ -127,3 +166,19 @@ def get_benchmark_config(config_path: Path | None = None) -> BenchmarkConfig:
|
|||||||
if config_path is None:
|
if config_path is None:
|
||||||
config_path = default_config_path()
|
config_path = default_config_path()
|
||||||
return BenchmarkConfig.from_toml(config_path)
|
return BenchmarkConfig.from_toml(config_path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_bertopic_train_config(
|
||||||
|
config_path: Path | None = None,
|
||||||
|
) -> BertTopicTrainConfig:
|
||||||
|
if config_path is None:
|
||||||
|
config_path = default_config_path()
|
||||||
|
return BertTopicTrainConfig.from_toml(config_path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_bertopic_infer_config(
|
||||||
|
config_path: Path | None = None,
|
||||||
|
) -> BertTopicInferConfig:
|
||||||
|
if config_path is None:
|
||||||
|
config_path = default_config_path()
|
||||||
|
return BertTopicInferConfig.from_toml(config_path)
|
||||||
|
|||||||
@@ -0,0 +1,197 @@
|
|||||||
|
"""Docker container lifecycle management for the web app stack."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
REPO_DIR = Path(__file__).resolve().parents[2]
|
||||||
|
COMPOSE_FILE = REPO_DIR / "docker-compose.yml"
|
||||||
|
EnvTarget = Literal["all", "web", "db"]
|
||||||
|
REQUIRED_WORKOS_ENV_VARS = (
|
||||||
|
"WORKOS_API_KEY",
|
||||||
|
"WORKOS_CLIENT_ID",
|
||||||
|
"WORKOS_COOKIE_PASSWORD",
|
||||||
|
"WORKOS_ORGANIZATION_ID",
|
||||||
|
)
|
||||||
|
|
||||||
|
app = typer.Typer(help="Web stack container management.")
|
||||||
|
|
||||||
|
|
||||||
|
def _compose_command(*args: str) -> list[str]:
|
||||||
|
"""Build a docker compose command for the repo-local stack."""
|
||||||
|
return ["docker", "compose", "-f", str(COMPOSE_FILE), *args]
|
||||||
|
|
||||||
|
|
||||||
|
def _run_compose(
|
||||||
|
*args: str,
|
||||||
|
capture_output: bool = False,
|
||||||
|
check: bool = True,
|
||||||
|
) -> subprocess.CompletedProcess[str]:
|
||||||
|
"""Run docker compose in the repository root."""
|
||||||
|
result = subprocess.run(
|
||||||
|
_compose_command(*args),
|
||||||
|
cwd=REPO_DIR,
|
||||||
|
text=True,
|
||||||
|
capture_output=capture_output,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
if check and result.returncode != 0:
|
||||||
|
detail = result.stderr.strip() if result.stderr else f"exit code {result.returncode}"
|
||||||
|
raise RuntimeError(f"docker compose {' '.join(args)} failed: {detail}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_workos_env() -> None:
|
||||||
|
"""Ensure the web app has the WorkOS env vars it needs before startup."""
|
||||||
|
missing = [name for name in REQUIRED_WORKOS_ENV_VARS if not os.getenv(name)]
|
||||||
|
if missing:
|
||||||
|
message = (
|
||||||
|
"Missing required WorkOS environment variables: "
|
||||||
|
+ ", ".join(missing)
|
||||||
|
+ ". Populate .env before running the web stack."
|
||||||
|
)
|
||||||
|
raise RuntimeError(message)
|
||||||
|
|
||||||
|
cookie_password = os.getenv("WORKOS_COOKIE_PASSWORD", "")
|
||||||
|
if len(cookie_password) < 32:
|
||||||
|
raise RuntimeError("WORKOS_COOKIE_PASSWORD must be at least 32 characters long.")
|
||||||
|
|
||||||
|
|
||||||
|
def build_stack() -> None:
|
||||||
|
"""Build the web app image."""
|
||||||
|
logger.info("Building web image from %s", COMPOSE_FILE)
|
||||||
|
_run_compose("build", "web", capture_output=False)
|
||||||
|
logger.info("Web image built")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_database_env() -> None:
|
||||||
|
"""Ensure the web app has the database env vars it needs before startup."""
|
||||||
|
required = (
|
||||||
|
"DATA_SCIENCE_DEV_DB",
|
||||||
|
"DATA_SCIENCE_DEV_HOST",
|
||||||
|
"DATA_SCIENCE_DEV_PORT",
|
||||||
|
"DATA_SCIENCE_DEV_USER",
|
||||||
|
)
|
||||||
|
missing = [name for name in required if not os.getenv(name)]
|
||||||
|
if missing:
|
||||||
|
message = (
|
||||||
|
"Missing required database environment variables: "
|
||||||
|
+ ", ".join(missing)
|
||||||
|
+ ". Populate .env before running the web stack."
|
||||||
|
)
|
||||||
|
raise RuntimeError(message)
|
||||||
|
|
||||||
|
|
||||||
|
def start_stack(
|
||||||
|
*, build: bool = False, detach: bool = False, with_local_db: bool = False
|
||||||
|
) -> None:
|
||||||
|
"""Start the web stack, using the existing DB by default."""
|
||||||
|
_validate_workos_env()
|
||||||
|
_validate_database_env()
|
||||||
|
command = ["up"]
|
||||||
|
if build:
|
||||||
|
command.append("--build")
|
||||||
|
if detach:
|
||||||
|
command.append("-d")
|
||||||
|
if with_local_db:
|
||||||
|
command.extend(["--profile", "localdb", "db", "web"])
|
||||||
|
else:
|
||||||
|
command.append("web")
|
||||||
|
logger.info(
|
||||||
|
"Starting web stack%s",
|
||||||
|
" with local Postgres" if with_local_db else " against existing Postgres",
|
||||||
|
)
|
||||||
|
_run_compose(*command, capture_output=False)
|
||||||
|
|
||||||
|
|
||||||
|
def stop_stack(*, drop_volumes: bool = False) -> None:
|
||||||
|
"""Stop and remove the web stack."""
|
||||||
|
logger.info("Stopping web stack")
|
||||||
|
command = ["down"]
|
||||||
|
if drop_volumes:
|
||||||
|
command.append("--volumes")
|
||||||
|
_run_compose(*command, capture_output=False)
|
||||||
|
|
||||||
|
|
||||||
|
def logs_stack(*, target: EnvTarget = "all", follow: bool = False, tail: int = 100) -> None:
|
||||||
|
"""Show docker compose logs for the web stack."""
|
||||||
|
command = ["logs", "--tail", str(tail)]
|
||||||
|
if follow:
|
||||||
|
command.append("--follow")
|
||||||
|
if target != "all":
|
||||||
|
command.append(target)
|
||||||
|
_run_compose(*command, capture_output=False)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def build(
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""Build the web Docker image."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
build_stack()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def run(
|
||||||
|
build: Annotated[
|
||||||
|
bool, typer.Option(help="Rebuild the web image before starting the stack")
|
||||||
|
] = False,
|
||||||
|
detach: Annotated[
|
||||||
|
bool, typer.Option(help="Start the stack in the background")
|
||||||
|
] = False,
|
||||||
|
with_local_db: Annotated[
|
||||||
|
bool, typer.Option(help="Also start the optional local Postgres container")
|
||||||
|
] = False,
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""Run the web + Postgres stack."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
start_stack(build=build, detach=detach, with_local_db=with_local_db)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def stop(
|
||||||
|
drop_volumes: Annotated[
|
||||||
|
bool, typer.Option(help="Also delete the Postgres volume")
|
||||||
|
] = False,
|
||||||
|
) -> None:
|
||||||
|
"""Stop and remove the web stack."""
|
||||||
|
stop_stack(drop_volumes=drop_volumes)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def logs(
|
||||||
|
target: Annotated[
|
||||||
|
EnvTarget, typer.Option(help="Which service logs to show")
|
||||||
|
] = "all",
|
||||||
|
follow: Annotated[
|
||||||
|
bool, typer.Option(help="Follow logs until interrupted")
|
||||||
|
] = False,
|
||||||
|
tail: Annotated[int, typer.Option(help="How many recent lines to show")] = 100,
|
||||||
|
) -> None:
|
||||||
|
"""Show recent logs from the web stack."""
|
||||||
|
logs_stack(target=target, follow=follow, tail=tail)
|
||||||
|
|
||||||
|
|
||||||
|
def cli() -> None:
|
||||||
|
"""Typer entry point."""
|
||||||
|
app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
@@ -0,0 +1,574 @@
|
|||||||
|
"""Calculate legislator topic scores from bill topics and roll-call votes."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Annotated, Sequence
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from sqlalchemy import (
|
||||||
|
ColumnElement,
|
||||||
|
Integer,
|
||||||
|
Select,
|
||||||
|
and_,
|
||||||
|
case,
|
||||||
|
cast,
|
||||||
|
delete,
|
||||||
|
extract,
|
||||||
|
func,
|
||||||
|
or_,
|
||||||
|
select,
|
||||||
|
tuple_,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from pipelines.jobs.congress_vote_context import create_score_run, finalize_score_run
|
||||||
|
from pipelines.orm.common import get_postgres_engine
|
||||||
|
from pipelines.orm.data_science_dev.congress import (
|
||||||
|
BillTopic,
|
||||||
|
BillTopicPosition,
|
||||||
|
LegislatorScore,
|
||||||
|
SubjectType,
|
||||||
|
Vote,
|
||||||
|
VoteClassification,
|
||||||
|
VoteEffect,
|
||||||
|
VoteMeasureLink,
|
||||||
|
VoteMeasureRole,
|
||||||
|
VotePositionMeaning,
|
||||||
|
VoteRelationship,
|
||||||
|
VoteRecord,
|
||||||
|
)
|
||||||
|
from pipelines.jobs.extract_bill_topics import normalize_topic_label
|
||||||
|
from pipelines.web.scoring import (
|
||||||
|
OPPOSE_POSITIONS,
|
||||||
|
SUPPORT_POSITIONS,
|
||||||
|
normalized_position_expression,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DELETE_BATCH_SIZE = 5_000
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ScoreDiagnostics:
|
||||||
|
"""Counts for the input stages required to calculate legislator scores."""
|
||||||
|
|
||||||
|
bill_topic_rows: int
|
||||||
|
linked_vote_rows: int
|
||||||
|
vote_record_rows: int
|
||||||
|
topic_vote_links: int
|
||||||
|
scorable_vote_records: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LegislatorScoreInput:
|
||||||
|
"""One aggregated score ready to store in legislator_score."""
|
||||||
|
|
||||||
|
legislator_id: int
|
||||||
|
year: int
|
||||||
|
topic: str
|
||||||
|
score: float
|
||||||
|
supportive: int
|
||||||
|
opposed: int
|
||||||
|
|
||||||
|
|
||||||
|
def create_legislator_score_query(
|
||||||
|
*,
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: Sequence[int] | None = None,
|
||||||
|
topics: Sequence[str] | None = None,
|
||||||
|
) -> Select:
|
||||||
|
"""Build the aggregate score query from extracted bill topics and vote records."""
|
||||||
|
normalized_vote = normalized_position_expression(VoteRecord.position)
|
||||||
|
supportive_vote = _supportive_vote_expression(normalized_vote)
|
||||||
|
opposed_vote = _opposed_vote_expression(normalized_vote)
|
||||||
|
supportive_count = func.sum(supportive_vote)
|
||||||
|
opposed_count = func.sum(opposed_vote)
|
||||||
|
total_count = supportive_count + opposed_count
|
||||||
|
vote_year = cast(extract("year", Vote.vote_date), Integer)
|
||||||
|
score = (100.0 * supportive_count / func.nullif(total_count, 0)).label("score")
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(
|
||||||
|
VoteRecord.legislator_id.label("legislator_id"),
|
||||||
|
vote_year.label("year"),
|
||||||
|
BillTopic.topic.label("topic"),
|
||||||
|
score,
|
||||||
|
supportive_count.label("supportive"),
|
||||||
|
opposed_count.label("opposed"),
|
||||||
|
total_count.label("total"),
|
||||||
|
)
|
||||||
|
.select_from(VoteRecord)
|
||||||
|
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
|
||||||
|
.where(
|
||||||
|
*_eligible_vote_filters(),
|
||||||
|
_is_scorable_position(normalized_vote),
|
||||||
|
)
|
||||||
|
.group_by(VoteRecord.legislator_id, vote_year, BillTopic.topic)
|
||||||
|
.having(total_count > 0)
|
||||||
|
.order_by(VoteRecord.legislator_id, vote_year, BillTopic.topic)
|
||||||
|
)
|
||||||
|
if congress is not None:
|
||||||
|
stmt = stmt.where(Vote.congress == congress)
|
||||||
|
if bill_ids:
|
||||||
|
stmt = stmt.where(VoteMeasureLink.measure_id.in_(list(bill_ids)))
|
||||||
|
|
||||||
|
normalized_topics = _normalize_topics(topics)
|
||||||
|
if normalized_topics:
|
||||||
|
stmt = stmt.where(BillTopic.topic.in_(normalized_topics))
|
||||||
|
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
|
def collect_legislator_scores(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: Sequence[int] | None = None,
|
||||||
|
topics: Sequence[str] | None = None,
|
||||||
|
) -> list[LegislatorScoreInput]:
|
||||||
|
"""Run the aggregate query and return score rows."""
|
||||||
|
rows = session.execute(
|
||||||
|
create_legislator_score_query(
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
topics=topics,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
LegislatorScoreInput(
|
||||||
|
legislator_id=int(row.legislator_id),
|
||||||
|
year=int(row.year),
|
||||||
|
topic=str(row.topic),
|
||||||
|
score=float(row.score),
|
||||||
|
supportive=int(row.supportive),
|
||||||
|
opposed=int(row.opposed),
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
if row.score is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def collect_score_diagnostics(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: Sequence[int] | None = None,
|
||||||
|
topics: Sequence[str] | None = None,
|
||||||
|
) -> ScoreDiagnostics:
|
||||||
|
"""Count score pipeline inputs for explaining empty score runs."""
|
||||||
|
normalized_topics = _normalize_topics(topics)
|
||||||
|
vote_filters = _vote_scope_filters(congress=congress, bill_ids=bill_ids)
|
||||||
|
topic_filters = _topic_scope_filters(bill_ids=bill_ids, topics=normalized_topics)
|
||||||
|
normalized_vote = normalized_position_expression(VoteRecord.position)
|
||||||
|
eligible_vote_filters = _eligible_vote_filters()
|
||||||
|
|
||||||
|
bill_topic_rows = session.scalar(
|
||||||
|
select(func.count(BillTopic.id)).where(*topic_filters)
|
||||||
|
)
|
||||||
|
linked_vote_rows = session.scalar(
|
||||||
|
select(func.count(func.distinct(Vote.id)))
|
||||||
|
.select_from(Vote)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.where(*vote_filters, *eligible_vote_filters)
|
||||||
|
)
|
||||||
|
vote_record_rows = session.scalar(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(VoteRecord)
|
||||||
|
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.where(*vote_filters, *eligible_vote_filters)
|
||||||
|
)
|
||||||
|
topic_vote_links = session.scalar(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(VoteRecord)
|
||||||
|
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
|
||||||
|
.where(*vote_filters, *topic_filters, *eligible_vote_filters)
|
||||||
|
)
|
||||||
|
scorable_vote_records = session.scalar(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(VoteRecord)
|
||||||
|
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
|
||||||
|
.where(
|
||||||
|
*vote_filters,
|
||||||
|
*topic_filters,
|
||||||
|
*eligible_vote_filters,
|
||||||
|
_is_scorable_position(normalized_vote),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return ScoreDiagnostics(
|
||||||
|
bill_topic_rows=bill_topic_rows or 0,
|
||||||
|
linked_vote_rows=linked_vote_rows or 0,
|
||||||
|
vote_record_rows=vote_record_rows or 0,
|
||||||
|
topic_vote_links=topic_vote_links or 0,
|
||||||
|
scorable_vote_records=scorable_vote_records or 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def store_legislator_scores(
|
||||||
|
session: Session,
|
||||||
|
rows: Sequence[LegislatorScoreInput],
|
||||||
|
*,
|
||||||
|
score_run_id: int | None,
|
||||||
|
replace_all: bool = False,
|
||||||
|
) -> int:
|
||||||
|
"""Replace matching score rows and insert the newly calculated scores."""
|
||||||
|
if replace_all:
|
||||||
|
session.execute(delete(LegislatorScore))
|
||||||
|
elif rows:
|
||||||
|
keys = [
|
||||||
|
(row.legislator_id, row.year, row.topic)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
for key_batch in _batched(keys, DELETE_BATCH_SIZE):
|
||||||
|
session.execute(
|
||||||
|
delete(LegislatorScore).where(
|
||||||
|
tuple_(
|
||||||
|
LegislatorScore.legislator_id,
|
||||||
|
LegislatorScore.year,
|
||||||
|
LegislatorScore.topic,
|
||||||
|
).in_(key_batch)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
LegislatorScore(
|
||||||
|
legislator_id=row.legislator_id,
|
||||||
|
year=row.year,
|
||||||
|
topic=row.topic,
|
||||||
|
score=row.score,
|
||||||
|
score_run_id=score_run_id,
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return len(rows)
|
||||||
|
|
||||||
|
|
||||||
|
def _supportive_vote_expression(
|
||||||
|
normalized_vote: ColumnElement[str | None],
|
||||||
|
) -> ColumnElement[int]:
|
||||||
|
supports_text = _position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT)
|
||||||
|
opposes_text = _position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT)
|
||||||
|
return case(
|
||||||
|
(
|
||||||
|
and_(
|
||||||
|
BillTopic.support_position == BillTopicPosition.FOR,
|
||||||
|
supports_text,
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
and_(
|
||||||
|
BillTopic.support_position == BillTopicPosition.AGAINST,
|
||||||
|
opposes_text,
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
else_=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _opposed_vote_expression(
|
||||||
|
normalized_vote: ColumnElement[str | None],
|
||||||
|
) -> ColumnElement[int]:
|
||||||
|
supports_text = _position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT)
|
||||||
|
opposes_text = _position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT)
|
||||||
|
return case(
|
||||||
|
(
|
||||||
|
and_(
|
||||||
|
BillTopic.support_position == BillTopicPosition.FOR,
|
||||||
|
opposes_text,
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
and_(
|
||||||
|
BillTopic.support_position == BillTopicPosition.AGAINST,
|
||||||
|
supports_text,
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
else_=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _position_matches_effect(
|
||||||
|
normalized_vote: ColumnElement[str | None],
|
||||||
|
effect: VoteEffect,
|
||||||
|
) -> ColumnElement[bool]:
|
||||||
|
return or_(
|
||||||
|
and_(
|
||||||
|
normalized_vote.in_(sorted(SUPPORT_POSITIONS)),
|
||||||
|
VotePositionMeaning.yea_effect == effect,
|
||||||
|
),
|
||||||
|
and_(
|
||||||
|
normalized_vote.in_(sorted(OPPOSE_POSITIONS)),
|
||||||
|
VotePositionMeaning.nay_effect == effect,
|
||||||
|
),
|
||||||
|
and_(
|
||||||
|
normalized_vote == "present",
|
||||||
|
VotePositionMeaning.present_effect == effect,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_scorable_position(normalized_vote: ColumnElement[str | None]) -> ColumnElement[bool]:
|
||||||
|
return or_(
|
||||||
|
_position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT),
|
||||||
|
_position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_topics(topics: Sequence[str] | None) -> list[str]:
|
||||||
|
normalized: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for topic in topics or []:
|
||||||
|
value = normalize_topic_label(topic)
|
||||||
|
if value and value not in seen:
|
||||||
|
normalized.append(value)
|
||||||
|
seen.add(value)
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def _batched[T](items: Sequence[T], batch_size: int) -> list[Sequence[T]]:
|
||||||
|
return [
|
||||||
|
items[index : index + batch_size]
|
||||||
|
for index in range(0, len(items), batch_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _vote_scope_filters(
|
||||||
|
*,
|
||||||
|
congress: int | None,
|
||||||
|
bill_ids: Sequence[int] | None,
|
||||||
|
) -> list[ColumnElement[bool]]:
|
||||||
|
filters: list[ColumnElement[bool]] = []
|
||||||
|
if congress is not None:
|
||||||
|
filters.append(Vote.congress == congress)
|
||||||
|
if bill_ids:
|
||||||
|
filters.append(VoteMeasureLink.measure_id.in_(list(bill_ids)))
|
||||||
|
return filters
|
||||||
|
|
||||||
|
|
||||||
|
def _topic_scope_filters(
|
||||||
|
*,
|
||||||
|
bill_ids: Sequence[int] | None,
|
||||||
|
topics: Sequence[str],
|
||||||
|
) -> list[ColumnElement[bool]]:
|
||||||
|
filters: list[ColumnElement[bool]] = []
|
||||||
|
if bill_ids:
|
||||||
|
filters.append(BillTopic.bill_id.in_(list(bill_ids)))
|
||||||
|
if topics:
|
||||||
|
filters.append(BillTopic.topic.in_(list(topics)))
|
||||||
|
return filters
|
||||||
|
|
||||||
|
|
||||||
|
def _has_score_scope(
|
||||||
|
*,
|
||||||
|
congress: int | None,
|
||||||
|
bill_ids: Sequence[int] | None,
|
||||||
|
topics: Sequence[str] | None,
|
||||||
|
) -> bool:
|
||||||
|
return congress is not None or bool(bill_ids) or bool(topics)
|
||||||
|
|
||||||
|
|
||||||
|
def _eligible_vote_filters() -> list[ColumnElement[bool]]:
|
||||||
|
return [
|
||||||
|
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||||
|
VoteClassification.vote_relationship == VoteRelationship.DIRECT_TEXT_VOTE,
|
||||||
|
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||||
|
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||||
|
VoteClassification.is_special_rule.is_(False),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
congress: Annotated[
|
||||||
|
int | None,
|
||||||
|
typer.Option(help="Only score votes from one Congress."),
|
||||||
|
] = None,
|
||||||
|
bill_ids: Annotated[
|
||||||
|
list[int] | None,
|
||||||
|
typer.Option(
|
||||||
|
"--bill-id",
|
||||||
|
help="Only score votes linked to one internal bill.id. Repeatable.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
topics: Annotated[
|
||||||
|
list[str] | None,
|
||||||
|
typer.Option("--topic", help="Only score one normalized topic. Repeatable."),
|
||||||
|
] = None,
|
||||||
|
replace_all: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(
|
||||||
|
help="Delete every existing legislator score before inserting. "
|
||||||
|
"Unfiltered runs do this automatically."
|
||||||
|
),
|
||||||
|
] = False,
|
||||||
|
dry_run: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Calculate scores without writing to the database."),
|
||||||
|
] = False,
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
|
||||||
|
diagnose: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Log input-stage counts even when rows are calculated."),
|
||||||
|
] = False,
|
||||||
|
) -> None:
|
||||||
|
"""CLI entrypoint for calculating and storing legislator topic scores."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||||
|
with Session(engine) as session:
|
||||||
|
rows = collect_legislator_scores(
|
||||||
|
session,
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
topics=topics,
|
||||||
|
)
|
||||||
|
logger.info("Calculated %d legislator topic score rows", len(rows))
|
||||||
|
if diagnose or not rows:
|
||||||
|
diagnostics = collect_score_diagnostics(
|
||||||
|
session,
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
topics=topics,
|
||||||
|
)
|
||||||
|
_log_diagnostics(diagnostics)
|
||||||
|
|
||||||
|
if dry_run:
|
||||||
|
session.rollback()
|
||||||
|
return
|
||||||
|
|
||||||
|
score_run = create_score_run(session)
|
||||||
|
should_replace_all = replace_all or not _has_score_scope(
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
topics=topics,
|
||||||
|
)
|
||||||
|
written = store_legislator_scores(
|
||||||
|
session,
|
||||||
|
rows,
|
||||||
|
score_run_id=score_run.id,
|
||||||
|
replace_all=should_replace_all,
|
||||||
|
)
|
||||||
|
included_vote_count = session.scalar(
|
||||||
|
select(func.count(func.distinct(Vote.id)))
|
||||||
|
.select_from(VoteRecord)
|
||||||
|
.join(Vote, Vote.id == VoteRecord.vote_id)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
|
||||||
|
.where(
|
||||||
|
*_vote_scope_filters(congress=congress, bill_ids=bill_ids),
|
||||||
|
*_topic_scope_filters(bill_ids=bill_ids, topics=_normalize_topics(topics)),
|
||||||
|
*_eligible_vote_filters(),
|
||||||
|
_is_scorable_position(normalized_position_expression(VoteRecord.position)),
|
||||||
|
)
|
||||||
|
) or 0
|
||||||
|
total_scoped_votes = session.scalar(
|
||||||
|
select(func.count(func.distinct(Vote.id)))
|
||||||
|
.select_from(Vote)
|
||||||
|
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
|
||||||
|
.join(
|
||||||
|
VoteMeasureLink,
|
||||||
|
and_(
|
||||||
|
VoteMeasureLink.vote_id == Vote.id,
|
||||||
|
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.where(*_vote_scope_filters(congress=congress, bill_ids=bill_ids))
|
||||||
|
) or 0
|
||||||
|
finalize_score_run(
|
||||||
|
session,
|
||||||
|
score_run=score_run,
|
||||||
|
included_vote_count=included_vote_count,
|
||||||
|
excluded_vote_count=max(total_scoped_votes - included_vote_count, 0),
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
logger.info("Stored %d legislator topic score rows", written)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_diagnostics(diagnostics: ScoreDiagnostics) -> None:
|
||||||
|
logger.info(
|
||||||
|
"Score input diagnostics: bill_topic_rows=%d linked_vote_rows=%d "
|
||||||
|
"vote_record_rows=%d topic_vote_links=%d scorable_vote_records=%d",
|
||||||
|
diagnostics.bill_topic_rows,
|
||||||
|
diagnostics.linked_vote_rows,
|
||||||
|
diagnostics.vote_record_rows,
|
||||||
|
diagnostics.topic_vote_links,
|
||||||
|
diagnostics.scorable_vote_records,
|
||||||
|
)
|
||||||
|
if diagnostics.bill_topic_rows == 0:
|
||||||
|
logger.warning(
|
||||||
|
"No extracted bill topics matched the score scope. Run "
|
||||||
|
"pipelines.tools.extract_bill_topics after bill summarization."
|
||||||
|
)
|
||||||
|
elif diagnostics.linked_vote_rows == 0:
|
||||||
|
logger.warning("No direct substantive text votes matched the score scope.")
|
||||||
|
elif diagnostics.vote_record_rows == 0:
|
||||||
|
logger.warning("No individual vote records matched the score scope.")
|
||||||
|
elif diagnostics.topic_vote_links == 0:
|
||||||
|
logger.warning(
|
||||||
|
"Bill topics exist, but none are attached to bills that have eligible scored votes."
|
||||||
|
)
|
||||||
|
elif diagnostics.scorable_vote_records == 0:
|
||||||
|
logger.warning(
|
||||||
|
"Topic-vote links exist, but no joined vote records had Yea/Aye/Yes/Nay/No positions."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
typer.run(main)
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,682 @@
|
|||||||
|
"""Extract bill topics from bill text using a configurable topic catalog."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated, Any, Sequence
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import typer
|
||||||
|
from sqlalchemy import ColumnElement, Select, delete, exists, func, select
|
||||||
|
from sqlalchemy.orm import Session, selectinload
|
||||||
|
|
||||||
|
from pipelines.config import OpenAIConfig, get_config_dir, get_openai_config
|
||||||
|
from pipelines.orm.common import get_postgres_engine
|
||||||
|
from pipelines.orm.data_science_dev.congress import (
|
||||||
|
Bill,
|
||||||
|
BillText,
|
||||||
|
BillTopic,
|
||||||
|
BillTopicPosition,
|
||||||
|
SubjectType,
|
||||||
|
VoteClassification,
|
||||||
|
VoteRelationship,
|
||||||
|
VoteTextTarget,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
|
||||||
|
OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
|
||||||
|
REQUEST_TIMEOUT_SECONDS = 60
|
||||||
|
DEFAULT_TOPICS_PATH = get_config_dir() / "congressional_issues_comprehensive.json"
|
||||||
|
|
||||||
|
|
||||||
|
class TopicExtractionError(RuntimeError):
|
||||||
|
"""Raised when a topic extraction request or response is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TopicCatalog:
|
||||||
|
"""Loaded topic catalog with categories for prompting and flat candidates."""
|
||||||
|
|
||||||
|
topics_by_category: dict[str, list[str]]
|
||||||
|
candidate_topics: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TopicExtractionDiagnostics:
|
||||||
|
"""Counts for the bill summary inputs needed by topic extraction."""
|
||||||
|
|
||||||
|
bill_rows: int
|
||||||
|
bill_text_rows: int
|
||||||
|
summarized_bill_text_rows: int
|
||||||
|
bills_with_summaries: int
|
||||||
|
bill_topic_rows: int
|
||||||
|
selected_bills: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ExtractedBillTopic:
|
||||||
|
"""One extracted bill topic and yes-vote stance."""
|
||||||
|
|
||||||
|
topic: str
|
||||||
|
support_position: BillTopicPosition
|
||||||
|
confidence: float | None = None
|
||||||
|
evidence: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _select_bill_text_for_topic_extraction(bill: Bill) -> BillText | None:
|
||||||
|
"""Pick one summarized bill_text row from the already-loaded relationship."""
|
||||||
|
for bill_text in bill.bill_texts:
|
||||||
|
if bill_text.summary and bill_text.summary.strip():
|
||||||
|
return bill_text
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_topic_label(value: str) -> str:
|
||||||
|
"""Normalize a topic label for storage, comparison, and de-duping."""
|
||||||
|
normalized = value.strip().strip("\"'")
|
||||||
|
normalized = normalized.strip().rstrip(".").strip()
|
||||||
|
return re.sub(r"\s+", " ", normalized).lower()
|
||||||
|
|
||||||
|
|
||||||
|
def load_topic_catalog(path: Path | None = None) -> TopicCatalog:
|
||||||
|
"""Load, validate, normalize, and flatten the bill topic catalog."""
|
||||||
|
topics_path = path or DEFAULT_TOPICS_PATH
|
||||||
|
try:
|
||||||
|
raw = json.loads(topics_path.read_text())
|
||||||
|
except FileNotFoundError as exc:
|
||||||
|
msg = f"Topic catalog not found: {topics_path}"
|
||||||
|
raise TopicExtractionError(msg) from exc
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
msg = f"Topic catalog is not valid JSON: {topics_path}: {exc}"
|
||||||
|
raise TopicExtractionError(msg) from exc
|
||||||
|
|
||||||
|
if not isinstance(raw, dict):
|
||||||
|
msg = "Topic catalog root must be an object mapping category names to lists"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
topics_by_category: dict[str, list[str]] = {}
|
||||||
|
candidate_topics: list[str] = []
|
||||||
|
seen_topics: set[str] = set()
|
||||||
|
|
||||||
|
for category, topics in raw.items():
|
||||||
|
if not isinstance(category, str) or not category.strip():
|
||||||
|
msg = "Topic catalog category names must be non-empty strings"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
if not isinstance(topics, list):
|
||||||
|
msg = f"Topic catalog category {category!r} must contain a list"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
normalized_topics: list[str] = []
|
||||||
|
for topic in topics:
|
||||||
|
if not isinstance(topic, str):
|
||||||
|
msg = f"Topic catalog category {category!r} contains a non-string topic"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
normalized_topic = normalize_topic_label(topic)
|
||||||
|
if not normalized_topic:
|
||||||
|
msg = f"Topic catalog category {category!r} contains a blank topic"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
if normalized_topic in seen_topics:
|
||||||
|
continue
|
||||||
|
seen_topics.add(normalized_topic)
|
||||||
|
normalized_topics.append(normalized_topic)
|
||||||
|
candidate_topics.append(normalized_topic)
|
||||||
|
|
||||||
|
topics_by_category[category.strip()] = normalized_topics
|
||||||
|
|
||||||
|
return TopicCatalog(
|
||||||
|
topics_by_category=topics_by_category,
|
||||||
|
candidate_topics=candidate_topics,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_topic_extraction_messages(
|
||||||
|
*,
|
||||||
|
bill: Bill,
|
||||||
|
bill_text: str,
|
||||||
|
candidate_topics: Sequence[str],
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
"""Build GPT messages for extracting a bill's scored topics."""
|
||||||
|
normalized_candidates = [normalize_topic_label(topic) for topic in candidate_topics]
|
||||||
|
candidate_list = "\n".join(f"- {topic}" for topic in normalized_candidates)
|
||||||
|
metadata = "\n".join(
|
||||||
|
(
|
||||||
|
f"Congress: {bill.congress}",
|
||||||
|
f"Bill: {bill.bill_type} {bill.number}",
|
||||||
|
f"Title: {bill.title_short or bill.title or bill.official_title or ''}",
|
||||||
|
f"Top subject term: {bill.subjects_top_term or ''}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
system_prompt = (
|
||||||
|
"You extract policy topics from U.S. congressional bills.\n"
|
||||||
|
'For each selected topic, decide whether a Yes/Yea vote on the bill is "for" or "against" that topic.\n'
|
||||||
|
'Use "support_position": "for" when a Yes/Yea vote advances or supports the topic.\n'
|
||||||
|
'Use "support_position": "against" when a Yes/Yea vote restricts, repeals, blocks, or opposes the topic.\n'
|
||||||
|
"Select only topics from the provided candidate topic list.\n"
|
||||||
|
"Omit topics that are not materially addressed by the bill.\n"
|
||||||
|
"Return strict JSON only, with this shape:\n"
|
||||||
|
'{"topics":[{"topic":"candidate topic","support_position":"for","confidence":0.0,"evidence":"short reason"}]}'
|
||||||
|
)
|
||||||
|
user_prompt = "\n\n".join(
|
||||||
|
(
|
||||||
|
"BILL METADATA:",
|
||||||
|
metadata,
|
||||||
|
"CANDIDATE TOPICS:",
|
||||||
|
candidate_list,
|
||||||
|
"BILL TEXT:",
|
||||||
|
bill_text,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": user_prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def call_openai_topic_extraction(
|
||||||
|
*,
|
||||||
|
openai_config: OpenAIConfig,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
) -> str:
|
||||||
|
"""Call GPT and return the assistant message content."""
|
||||||
|
|
||||||
|
response = httpx.post(
|
||||||
|
openai_config.openai_chat_completions_url,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {openai_config.api_key}",
|
||||||
|
"OpenAI-Project": openai_config.openai_project_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"model": "gpt-5.4-mini",
|
||||||
|
"messages": messages,
|
||||||
|
},
|
||||||
|
timeout=openai_config.timeout_seconds,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return extract_message_content(response.json())
|
||||||
|
|
||||||
|
|
||||||
|
def extract_message_content(data: dict[str, Any]) -> str:
|
||||||
|
"""Extract message content from a chat-completions response body."""
|
||||||
|
choices = data.get("choices")
|
||||||
|
if not isinstance(choices, list) or not choices:
|
||||||
|
msg = "Chat completion response did not contain choices"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
first = choices[0]
|
||||||
|
if not isinstance(first, dict):
|
||||||
|
msg = "Chat completion choice must be an object"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
message = first.get("message")
|
||||||
|
if isinstance(message, dict) and isinstance(message.get("content"), str):
|
||||||
|
return message["content"]
|
||||||
|
if isinstance(first.get("text"), str):
|
||||||
|
return first["text"]
|
||||||
|
|
||||||
|
msg = "Chat completion response did not contain message content"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_topic_extraction_response(response_text: str) -> list[ExtractedBillTopic]:
|
||||||
|
"""Parse, normalize, validate, and de-dupe a topic extraction response."""
|
||||||
|
payload = _load_json_response(response_text)
|
||||||
|
topics = payload.get("topics")
|
||||||
|
if not isinstance(topics, list):
|
||||||
|
msg = "Topic extraction response must contain a topics list"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
deduped: dict[tuple[str, BillTopicPosition], ExtractedBillTopic] = {}
|
||||||
|
for item in topics:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
msg = "Topic extraction response topics must be objects"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
raw_topic = _extract_topic_label(item)
|
||||||
|
topic = normalize_topic_label(raw_topic)
|
||||||
|
if not topic:
|
||||||
|
msg = "Topic extraction response topic must not be blank"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
raw_position = item.get("support_position")
|
||||||
|
try:
|
||||||
|
support_position = BillTopicPosition(raw_position)
|
||||||
|
except ValueError as exc:
|
||||||
|
msg = f"Invalid support_position: {raw_position!r}"
|
||||||
|
raise TopicExtractionError(msg) from exc
|
||||||
|
|
||||||
|
confidence = _parse_confidence(item.get("confidence"))
|
||||||
|
evidence = item.get("evidence")
|
||||||
|
if evidence is not None and not isinstance(evidence, str):
|
||||||
|
evidence = str(evidence)
|
||||||
|
|
||||||
|
extracted = ExtractedBillTopic(
|
||||||
|
topic=topic,
|
||||||
|
support_position=support_position,
|
||||||
|
confidence=confidence,
|
||||||
|
evidence=evidence,
|
||||||
|
)
|
||||||
|
key = (topic, support_position)
|
||||||
|
existing = deduped.get(key)
|
||||||
|
if existing is None or _confidence_rank(extracted) > _confidence_rank(existing):
|
||||||
|
deduped[key] = extracted
|
||||||
|
|
||||||
|
return list(deduped.values())
|
||||||
|
|
||||||
|
|
||||||
|
def extract_topics_for_bill_text(
|
||||||
|
*,
|
||||||
|
openai_config: OpenAIConfig,
|
||||||
|
bill: Bill,
|
||||||
|
text: str,
|
||||||
|
candidate_topics: Sequence[str],
|
||||||
|
) -> list[ExtractedBillTopic]:
|
||||||
|
"""Extract accepted catalog topics for a bill text string."""
|
||||||
|
normalized_candidates = {normalize_topic_label(topic) for topic in candidate_topics}
|
||||||
|
messages = build_topic_extraction_messages(
|
||||||
|
bill=bill,
|
||||||
|
bill_text=text,
|
||||||
|
candidate_topics=sorted(normalized_candidates),
|
||||||
|
)
|
||||||
|
response_text = call_openai_topic_extraction(
|
||||||
|
openai_config=openai_config,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
extracted_topics = parse_topic_extraction_response(response_text)
|
||||||
|
return [topic for topic in extracted_topics if topic.topic in normalized_candidates]
|
||||||
|
|
||||||
|
|
||||||
|
def store_bill_topic_result(
|
||||||
|
*,
|
||||||
|
session: Session,
|
||||||
|
bill: Bill,
|
||||||
|
topics: Sequence[ExtractedBillTopic],
|
||||||
|
replace_existing: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Store extracted topics for one bill."""
|
||||||
|
if replace_existing:
|
||||||
|
session.execute(delete(BillTopic).where(BillTopic.bill_id == bill.id))
|
||||||
|
|
||||||
|
for topic in topics:
|
||||||
|
session.add(
|
||||||
|
BillTopic(
|
||||||
|
bill_id=bill.id,
|
||||||
|
topic=normalize_topic_label(topic.topic),
|
||||||
|
support_position=topic.support_position,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_select_bills_for_topic_extraction(
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: list[int] | None = None,
|
||||||
|
bill_text_ids: list[int] | None = None,
|
||||||
|
with_votes_only: bool = False,
|
||||||
|
force: bool = False,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> Select[tuple[Bill]]:
|
||||||
|
"""Select bill rows that have summarized bill_text rows for topic extraction."""
|
||||||
|
has_summary = (BillText.summary.is_not(None), BillText.summary != "")
|
||||||
|
summarized_text_filters: list[ColumnElement[bool]] = [
|
||||||
|
BillText.bill_id == Bill.id,
|
||||||
|
*has_summary,
|
||||||
|
]
|
||||||
|
if with_votes_only:
|
||||||
|
summarized_text_filters.append(
|
||||||
|
exists(
|
||||||
|
select(VoteTextTarget.vote_id)
|
||||||
|
.join(
|
||||||
|
VoteClassification,
|
||||||
|
VoteClassification.vote_id == VoteTextTarget.vote_id,
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
VoteTextTarget.voted_text_version_id == BillText.id,
|
||||||
|
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||||
|
VoteClassification.vote_relationship
|
||||||
|
== VoteRelationship.DIRECT_TEXT_VOTE,
|
||||||
|
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||||
|
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||||
|
VoteClassification.is_special_rule.is_(False),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
summarized_text_exists = exists(select(BillText.id).where(*summarized_text_filters))
|
||||||
|
stmt = (
|
||||||
|
select(Bill)
|
||||||
|
.where(summarized_text_exists)
|
||||||
|
.options(selectinload(Bill.bill_texts.and_(*summarized_text_filters[1:])))
|
||||||
|
.order_by(Bill.id)
|
||||||
|
)
|
||||||
|
if congress is not None:
|
||||||
|
stmt = stmt.where(Bill.congress == congress)
|
||||||
|
if bill_ids:
|
||||||
|
stmt = stmt.where(Bill.id.in_(bill_ids))
|
||||||
|
if bill_text_ids:
|
||||||
|
selected_text_exists = exists(
|
||||||
|
select(BillText.id).where(
|
||||||
|
BillText.bill_id == Bill.id,
|
||||||
|
BillText.id.in_(bill_text_ids),
|
||||||
|
*summarized_text_filters[1:],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
stmt = stmt.where(selected_text_exists)
|
||||||
|
if not force:
|
||||||
|
stmt = stmt.where(
|
||||||
|
~exists(select(BillTopic.id).where(BillTopic.bill_id == Bill.id))
|
||||||
|
)
|
||||||
|
if limit is not None:
|
||||||
|
stmt = stmt.limit(limit)
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
|
def collect_topic_extraction_diagnostics(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: list[int] | None = None,
|
||||||
|
bill_text_ids: list[int] | None = None,
|
||||||
|
with_votes_only: bool = False,
|
||||||
|
force: bool = False,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> TopicExtractionDiagnostics:
|
||||||
|
"""Count topic extraction inputs for explaining empty selections."""
|
||||||
|
bill_filters = []
|
||||||
|
bill_text_filters: list[ColumnElement[bool]] = []
|
||||||
|
if congress is not None:
|
||||||
|
bill_filters.append(Bill.congress == congress)
|
||||||
|
if bill_ids:
|
||||||
|
bill_filters.append(Bill.id.in_(bill_ids))
|
||||||
|
bill_text_filters.append(BillText.bill_id.in_(bill_ids))
|
||||||
|
if bill_text_ids:
|
||||||
|
bill_text_filters.append(BillText.id.in_(bill_text_ids))
|
||||||
|
if with_votes_only:
|
||||||
|
bill_text_filters.append(
|
||||||
|
exists(
|
||||||
|
select(VoteTextTarget.vote_id)
|
||||||
|
.join(
|
||||||
|
VoteClassification,
|
||||||
|
VoteClassification.vote_id == VoteTextTarget.vote_id,
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
VoteTextTarget.voted_text_version_id == BillText.id,
|
||||||
|
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||||
|
VoteClassification.vote_relationship
|
||||||
|
== VoteRelationship.DIRECT_TEXT_VOTE,
|
||||||
|
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||||
|
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||||
|
VoteClassification.is_special_rule.is_(False),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
has_summary = (BillText.summary.is_not(None), BillText.summary != "")
|
||||||
|
summary_filters = [*bill_text_filters, *has_summary]
|
||||||
|
|
||||||
|
bills_with_summaries = session.scalar(
|
||||||
|
select(func.count(func.distinct(Bill.id)))
|
||||||
|
.select_from(Bill)
|
||||||
|
.join(BillText, BillText.bill_id == Bill.id)
|
||||||
|
.where(*bill_filters, *summary_filters)
|
||||||
|
)
|
||||||
|
selected_bills = session.scalar(
|
||||||
|
select(func.count()).select_from(
|
||||||
|
create_select_bills_for_topic_extraction(
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
bill_text_ids=bill_text_ids,
|
||||||
|
with_votes_only=with_votes_only,
|
||||||
|
force=force,
|
||||||
|
limit=limit,
|
||||||
|
).subquery()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return TopicExtractionDiagnostics(
|
||||||
|
bill_rows=session.scalar(select(func.count(Bill.id)).where(*bill_filters)) or 0,
|
||||||
|
bill_text_rows=_count_bill_texts(
|
||||||
|
session,
|
||||||
|
bill_filters=bill_filters,
|
||||||
|
bill_text_filters=bill_text_filters,
|
||||||
|
),
|
||||||
|
summarized_bill_text_rows=_count_bill_texts(
|
||||||
|
session,
|
||||||
|
bill_filters=bill_filters,
|
||||||
|
bill_text_filters=summary_filters,
|
||||||
|
),
|
||||||
|
bills_with_summaries=bills_with_summaries or 0,
|
||||||
|
bill_topic_rows=session.scalar(select(func.count(BillTopic.id))) or 0,
|
||||||
|
selected_bills=selected_bills or 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_json_response(response_text: str) -> dict[str, Any]:
|
||||||
|
text = response_text.strip()
|
||||||
|
fenced = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.DOTALL)
|
||||||
|
if fenced:
|
||||||
|
text = fenced.group(1).strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = json.loads(text)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
msg = f"Topic extraction response is not valid JSON: {exc}"
|
||||||
|
raise TopicExtractionError(msg) from exc
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
msg = "Topic extraction response must be a JSON object"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_confidence(raw: Any) -> float | None:
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return float(raw)
|
||||||
|
except (TypeError, ValueError) as exc:
|
||||||
|
msg = f"Invalid confidence: {raw!r}"
|
||||||
|
raise TopicExtractionError(msg) from exc
|
||||||
|
|
||||||
|
|
||||||
|
def _confidence_rank(topic: ExtractedBillTopic) -> tuple[int, float]:
|
||||||
|
if topic.confidence is None:
|
||||||
|
return (0, 0.0)
|
||||||
|
return (1, topic.confidence)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_topic_label(item: dict[str, Any]) -> str:
|
||||||
|
raw_topic = item.get("topic")
|
||||||
|
if isinstance(raw_topic, str):
|
||||||
|
return raw_topic
|
||||||
|
if isinstance(raw_topic, dict):
|
||||||
|
for key in ("topic", "label", "name", "title"):
|
||||||
|
value = raw_topic.get(key)
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value
|
||||||
|
|
||||||
|
msg = "Topic extraction response topic must be a string"
|
||||||
|
raise TopicExtractionError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_bill_texts(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
bill_filters: Sequence[ColumnElement[bool]],
|
||||||
|
bill_text_filters: Sequence[ColumnElement[bool]],
|
||||||
|
) -> int:
|
||||||
|
stmt = select(func.count(BillText.id))
|
||||||
|
if bill_filters:
|
||||||
|
stmt = stmt.join(Bill, Bill.id == BillText.bill_id).where(*bill_filters)
|
||||||
|
return session.scalar(stmt.where(*bill_text_filters)) or 0
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
topics_path: Annotated[
|
||||||
|
Path, typer.Option(help="Path to congressional issue topic JSON.")
|
||||||
|
] = DEFAULT_TOPICS_PATH,
|
||||||
|
congress: Annotated[
|
||||||
|
int | None, typer.Option(help="Only process one Congress.")
|
||||||
|
] = None,
|
||||||
|
bill_ids: Annotated[
|
||||||
|
list[int] | None,
|
||||||
|
typer.Option(
|
||||||
|
"--bill-id",
|
||||||
|
help="Only process one internal bill.id. Repeat for multiple bills.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
bill_text_ids: Annotated[
|
||||||
|
list[int] | None,
|
||||||
|
typer.Option(
|
||||||
|
"--bill-text-id",
|
||||||
|
help="Only process one internal bill_text.id. Repeat for multiple rows.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
with_votes_only: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(
|
||||||
|
"--with-votes-only",
|
||||||
|
help="Only process summarized bill_text rows linked to at least one vote.",
|
||||||
|
),
|
||||||
|
] = True,
|
||||||
|
limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
|
||||||
|
force: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Regenerate topics for bills that already have topics."),
|
||||||
|
] = False,
|
||||||
|
dry_run: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Select bills and print diagnostics without calling OpenAI."),
|
||||||
|
] = False,
|
||||||
|
diagnose: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Log input-stage counts before processing."),
|
||||||
|
] = False,
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""CLI entrypoint for generating and storing bill topics."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
topic_catalog = load_topic_catalog(topics_path)
|
||||||
|
logger.info(
|
||||||
|
"Loaded %d candidate topics from %s",
|
||||||
|
len(topic_catalog.candidate_topics),
|
||||||
|
topics_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||||
|
with Session(engine) as session:
|
||||||
|
if diagnose or dry_run:
|
||||||
|
diagnostics = collect_topic_extraction_diagnostics(
|
||||||
|
session,
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
bill_text_ids=bill_text_ids,
|
||||||
|
with_votes_only=with_votes_only,
|
||||||
|
force=force,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
_log_topic_extraction_diagnostics(diagnostics)
|
||||||
|
if dry_run:
|
||||||
|
return
|
||||||
|
|
||||||
|
openai_config = get_openai_config()
|
||||||
|
|
||||||
|
stmt = create_select_bills_for_topic_extraction(
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
bill_text_ids=bill_text_ids,
|
||||||
|
with_votes_only=with_votes_only,
|
||||||
|
force=force,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
bills = session.scalars(stmt).all()
|
||||||
|
logger.info("Selected %d bills for topic extraction", len(bills))
|
||||||
|
|
||||||
|
written = 0
|
||||||
|
failed = 0
|
||||||
|
for index, bill in enumerate(bills, 1):
|
||||||
|
bill_text = _select_bill_text_for_topic_extraction(bill)
|
||||||
|
if bill_text is None:
|
||||||
|
logger.warning("Skipping bill id=%s: no usable summary", bill.id)
|
||||||
|
continue
|
||||||
|
summary = bill_text.summary.strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
extracted_topics = extract_topics_for_bill_text(
|
||||||
|
openai_config=openai_config,
|
||||||
|
bill=bill,
|
||||||
|
text=summary,
|
||||||
|
candidate_topics=topic_catalog.candidate_topics,
|
||||||
|
)
|
||||||
|
except (httpx.HTTPError, TopicExtractionError):
|
||||||
|
failed += 1
|
||||||
|
logger.exception(
|
||||||
|
"Skipping bill id=%s after topic extraction failure", bill.id
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
store_bill_topic_result(
|
||||||
|
session=session,
|
||||||
|
bill=bill,
|
||||||
|
topics=extracted_topics,
|
||||||
|
replace_existing=True,
|
||||||
|
)
|
||||||
|
written += 1
|
||||||
|
if index % 100 == 0:
|
||||||
|
session.commit()
|
||||||
|
logger.info(
|
||||||
|
"Stored %d topics for bill id=%s",
|
||||||
|
len(extracted_topics),
|
||||||
|
bill.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
logger.info(
|
||||||
|
"Done: stored topic results for %d bills; failed %d bills",
|
||||||
|
written,
|
||||||
|
failed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_topic_extraction_diagnostics(
|
||||||
|
diagnostics: TopicExtractionDiagnostics,
|
||||||
|
) -> None:
|
||||||
|
logger.info(
|
||||||
|
"Topic extraction diagnostics: bill_rows=%d bill_text_rows=%d "
|
||||||
|
"summarized_bill_text_rows=%d bills_with_summaries=%d "
|
||||||
|
"bill_topic_rows=%d selected_bills=%d",
|
||||||
|
diagnostics.bill_rows,
|
||||||
|
diagnostics.bill_text_rows,
|
||||||
|
diagnostics.summarized_bill_text_rows,
|
||||||
|
diagnostics.bills_with_summaries,
|
||||||
|
diagnostics.bill_topic_rows,
|
||||||
|
diagnostics.selected_bills,
|
||||||
|
)
|
||||||
|
if diagnostics.bill_rows == 0:
|
||||||
|
logger.warning("No bills matched the topic extraction scope.")
|
||||||
|
elif diagnostics.bill_text_rows == 0:
|
||||||
|
logger.warning("No bill_text rows matched the topic extraction scope.")
|
||||||
|
elif diagnostics.summarized_bill_text_rows == 0:
|
||||||
|
logger.warning(
|
||||||
|
"No summarized bill_text rows matched the topic extraction scope. "
|
||||||
|
"Run pipelines.tools.summarize_bills first."
|
||||||
|
)
|
||||||
|
elif diagnostics.selected_bills == 0 and diagnostics.bill_topic_rows > 0:
|
||||||
|
logger.warning(
|
||||||
|
"No bills selected because matching bills already have topics. "
|
||||||
|
"Use --force to regenerate them."
|
||||||
|
)
|
||||||
|
elif diagnostics.selected_bills == 0:
|
||||||
|
logger.warning("No bills selected for topic extraction.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
typer.run(main)
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,281 @@
|
|||||||
|
"""Ingestion pipeline for loading JSONL post files into the weekly-partitioned posts table.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ingest-posts /path/to/files/
|
||||||
|
ingest-posts /path/to/single_file.jsonl
|
||||||
|
ingest-posts /data/dir/ --workers 4 --batch-size 5000
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from pathlib import Path # noqa: TC003 this is needed for typer
|
||||||
|
from typing import TYPE_CHECKING, Annotated
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
import psycopg
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from pipelines.pipelines.common import configure_logger
|
||||||
|
from pipelines.orm.common import get_connection_info
|
||||||
|
from pipelines.pipelines.parallelize import parallelize_process
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
app = typer.Typer(help="Ingest JSONL post files into the partitioned posts table.")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def main(
|
||||||
|
path: Annotated[
|
||||||
|
Path,
|
||||||
|
typer.Argument(help="Directory containing JSONL files, or a single JSONL file"),
|
||||||
|
],
|
||||||
|
batch_size: Annotated[int, typer.Option(help="Rows per INSERT batch")] = 10000,
|
||||||
|
workers: Annotated[
|
||||||
|
int, typer.Option(help="Parallel workers for multi-file ingestion")
|
||||||
|
] = 4,
|
||||||
|
pattern: Annotated[
|
||||||
|
str, typer.Option(help="Glob pattern for JSONL files")
|
||||||
|
] = "*.jsonl",
|
||||||
|
) -> None:
|
||||||
|
"""Ingest JSONL post files into the weekly-partitioned posts table."""
|
||||||
|
configure_logger(level="INFO")
|
||||||
|
|
||||||
|
logger.info("starting ingest-posts")
|
||||||
|
logger.info(
|
||||||
|
"path=%s batch_size=%d workers=%d pattern=%s",
|
||||||
|
path,
|
||||||
|
batch_size,
|
||||||
|
workers,
|
||||||
|
pattern,
|
||||||
|
)
|
||||||
|
if path.is_file():
|
||||||
|
ingest_file(path, batch_size=batch_size)
|
||||||
|
elif path.is_dir():
|
||||||
|
ingest_directory(
|
||||||
|
path, batch_size=batch_size, max_workers=workers, pattern=pattern
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
typer.echo(f"Path does not exist: {path}", err=True)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
logger.info("ingest-posts done")
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_directory(
|
||||||
|
directory: Path,
|
||||||
|
*,
|
||||||
|
batch_size: int,
|
||||||
|
max_workers: int,
|
||||||
|
pattern: str = "*.jsonl",
|
||||||
|
) -> None:
|
||||||
|
"""Ingest all JSONL files in a directory using parallel workers."""
|
||||||
|
files = sorted(directory.glob(pattern))
|
||||||
|
if not files:
|
||||||
|
logger.warning("No JSONL files found in %s", directory)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Found %d JSONL files to ingest", len(files))
|
||||||
|
|
||||||
|
kwargs_list = [{"path": fp, "batch_size": batch_size} for fp in files]
|
||||||
|
parallelize_process(ingest_file, kwargs_list, max_workers=max_workers)
|
||||||
|
|
||||||
|
|
||||||
|
SCHEMA = "main"
|
||||||
|
|
||||||
|
COLUMNS = (
|
||||||
|
"post_id",
|
||||||
|
"user_id",
|
||||||
|
"instance",
|
||||||
|
"date",
|
||||||
|
"text",
|
||||||
|
"langs",
|
||||||
|
"like_count",
|
||||||
|
"reply_count",
|
||||||
|
"repost_count",
|
||||||
|
"reply_to",
|
||||||
|
"replied_author",
|
||||||
|
"thread_root",
|
||||||
|
"thread_root_author",
|
||||||
|
"repost_from",
|
||||||
|
"reposted_author",
|
||||||
|
"quotes",
|
||||||
|
"quoted_author",
|
||||||
|
"labels",
|
||||||
|
"sent_label",
|
||||||
|
"sent_score",
|
||||||
|
)
|
||||||
|
|
||||||
|
INSERT_FROM_STAGING = f"""
|
||||||
|
INSERT INTO {SCHEMA}.posts ({", ".join(COLUMNS)})
|
||||||
|
SELECT {", ".join(COLUMNS)} FROM pg_temp.staging
|
||||||
|
ON CONFLICT (post_id, date) DO NOTHING
|
||||||
|
""" # noqa: S608
|
||||||
|
|
||||||
|
FAILED_INSERT = f"""
|
||||||
|
INSERT INTO {SCHEMA}.failed_ingestion (raw_line, error)
|
||||||
|
VALUES (%(raw_line)s, %(error)s)
|
||||||
|
""" # noqa: S608
|
||||||
|
|
||||||
|
|
||||||
|
def get_psycopg_connection() -> psycopg.Connection:
|
||||||
|
"""Create a raw psycopg3 connection from environment variables."""
|
||||||
|
database, host, port, username, password = get_connection_info("DATA_SCIENCE_DEV")
|
||||||
|
return psycopg.connect(
|
||||||
|
dbname=database,
|
||||||
|
host=host,
|
||||||
|
port=int(port),
|
||||||
|
user=username,
|
||||||
|
password=password,
|
||||||
|
autocommit=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_file(path: Path, *, batch_size: int) -> None:
|
||||||
|
"""Ingest a single JSONL file into the posts table."""
|
||||||
|
log_trigger = max(100_000 // batch_size, 1)
|
||||||
|
failed_lines: list[dict] = []
|
||||||
|
try:
|
||||||
|
with get_psycopg_connection() as connection:
|
||||||
|
for index, batch in enumerate(
|
||||||
|
read_jsonl_batches(path, batch_size, failed_lines), 1
|
||||||
|
):
|
||||||
|
ingest_batch(connection, batch)
|
||||||
|
if index % log_trigger == 0:
|
||||||
|
logger.info(
|
||||||
|
"Ingested %d batches (%d rows) from %s",
|
||||||
|
index,
|
||||||
|
index * batch_size,
|
||||||
|
path,
|
||||||
|
)
|
||||||
|
|
||||||
|
if failed_lines:
|
||||||
|
logger.warning(
|
||||||
|
"Recording %d malformed lines from %s", len(failed_lines), path.name
|
||||||
|
)
|
||||||
|
with connection.cursor() as cursor:
|
||||||
|
cursor.executemany(FAILED_INSERT, failed_lines)
|
||||||
|
connection.commit()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to ingest file: %s", path)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_batch(connection: psycopg.Connection, batch: list[dict]) -> None:
|
||||||
|
"""COPY batch into a temp staging table, then INSERT ... ON CONFLICT into posts."""
|
||||||
|
if not batch:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
with connection.cursor() as cursor:
|
||||||
|
cursor.execute(f"""
|
||||||
|
CREATE TEMP TABLE IF NOT EXISTS staging
|
||||||
|
(LIKE {SCHEMA}.posts INCLUDING DEFAULTS)
|
||||||
|
ON COMMIT DELETE ROWS
|
||||||
|
""")
|
||||||
|
cursor.execute("TRUNCATE pg_temp.staging")
|
||||||
|
|
||||||
|
with cursor.copy(
|
||||||
|
f"COPY pg_temp.staging ({', '.join(COLUMNS)}) FROM STDIN"
|
||||||
|
) as copy:
|
||||||
|
for row in batch:
|
||||||
|
copy.write_row(tuple(row.get(column) for column in COLUMNS))
|
||||||
|
|
||||||
|
cursor.execute(INSERT_FROM_STAGING)
|
||||||
|
connection.commit()
|
||||||
|
except Exception as error:
|
||||||
|
connection.rollback()
|
||||||
|
|
||||||
|
if len(batch) == 1:
|
||||||
|
logger.exception("Skipping bad row post_id=%s", batch[0].get("post_id"))
|
||||||
|
with connection.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
FAILED_INSERT,
|
||||||
|
{
|
||||||
|
"raw_line": orjson.dumps(batch[0], default=str).decode(),
|
||||||
|
"error": str(error),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
connection.commit()
|
||||||
|
return
|
||||||
|
|
||||||
|
midpoint = len(batch) // 2
|
||||||
|
ingest_batch(connection, batch[:midpoint])
|
||||||
|
ingest_batch(connection, batch[midpoint:])
|
||||||
|
|
||||||
|
|
||||||
|
def read_jsonl_batches(
|
||||||
|
file_path: Path, batch_size: int, failed_lines: list[dict]
|
||||||
|
) -> Iterator[list[dict]]:
|
||||||
|
"""Stream a JSONL file and yield batches of transformed rows."""
|
||||||
|
batch: list[dict] = []
|
||||||
|
with file_path.open("r", encoding="utf-8") as handle:
|
||||||
|
for raw_line in handle:
|
||||||
|
line = raw_line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
batch.extend(parse_line(line, file_path, failed_lines))
|
||||||
|
if len(batch) >= batch_size:
|
||||||
|
yield batch
|
||||||
|
batch = []
|
||||||
|
if batch:
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
|
def parse_line(line: str, file_path: Path, failed_lines: list[dict]) -> Iterator[dict]:
|
||||||
|
"""Parse a JSONL line, handling concatenated JSON objects."""
|
||||||
|
try:
|
||||||
|
yield transform_row(orjson.loads(line))
|
||||||
|
except orjson.JSONDecodeError:
|
||||||
|
if "}{" not in line:
|
||||||
|
logger.warning(
|
||||||
|
"Skipping malformed line in %s: %s", file_path.name, line[:120]
|
||||||
|
)
|
||||||
|
failed_lines.append({"raw_line": line, "error": "malformed JSON"})
|
||||||
|
return
|
||||||
|
fragments = line.replace("}{", "}\n{").split("\n")
|
||||||
|
for fragment in fragments:
|
||||||
|
try:
|
||||||
|
yield transform_row(orjson.loads(fragment))
|
||||||
|
except (orjson.JSONDecodeError, KeyError, ValueError) as error:
|
||||||
|
logger.warning(
|
||||||
|
"Skipping malformed fragment in %s: %s",
|
||||||
|
file_path.name,
|
||||||
|
fragment[:120],
|
||||||
|
)
|
||||||
|
failed_lines.append({"raw_line": fragment, "error": str(error)})
|
||||||
|
except Exception as error:
|
||||||
|
logger.exception("Skipping bad row in %s: %s", file_path.name, line[:120])
|
||||||
|
failed_lines.append({"raw_line": line, "error": str(error)})
|
||||||
|
|
||||||
|
|
||||||
|
def transform_row(raw: dict) -> dict:
|
||||||
|
"""Transform a raw JSONL row into a dict matching the Posts table columns."""
|
||||||
|
raw["date"] = parse_date(raw["date"])
|
||||||
|
if raw.get("langs") is not None:
|
||||||
|
raw["langs"] = orjson.dumps(raw["langs"])
|
||||||
|
if raw.get("text") is not None:
|
||||||
|
raw["text"] = raw["text"].replace("\x00", "")
|
||||||
|
return raw
|
||||||
|
|
||||||
|
|
||||||
|
def parse_date(raw_date: int) -> datetime:
|
||||||
|
"""Parse compact YYYYMMDDHHmm integer into a naive datetime (input is UTC by spec)."""
|
||||||
|
return datetime(
|
||||||
|
raw_date // 100000000,
|
||||||
|
(raw_date // 1000000) % 100,
|
||||||
|
(raw_date // 10000) % 100,
|
||||||
|
(raw_date // 100) % 100,
|
||||||
|
raw_date % 100,
|
||||||
|
tzinfo=UTC,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
@@ -0,0 +1,309 @@
|
|||||||
|
"""Summarize bill_text rows with GPT-5 and store results in the database."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import tomllib
|
||||||
|
from os import getenv
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import typer
|
||||||
|
from sqlalchemy import Select, exists, or_, select
|
||||||
|
from sqlalchemy.orm import Session, selectinload
|
||||||
|
|
||||||
|
from tiktoken import get_encoding
|
||||||
|
|
||||||
|
|
||||||
|
from pipelines.config import get_config_dir
|
||||||
|
from pipelines.orm.common import get_postgres_engine
|
||||||
|
from pipelines.orm.data_science_dev.congress import (
|
||||||
|
Bill,
|
||||||
|
BillText,
|
||||||
|
SubjectType,
|
||||||
|
VoteClassification,
|
||||||
|
VoteRelationship,
|
||||||
|
VoteTextTarget,
|
||||||
|
)
|
||||||
|
from pipelines.tools.bill_token_compression import compress_bill_text
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
|
||||||
|
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
|
||||||
|
REQUEST_TIMEOUT_SECONDS = 60
|
||||||
|
|
||||||
|
|
||||||
|
def load_summarization_prompts(
|
||||||
|
section: str = "summarization",
|
||||||
|
) -> dict[str, str]:
|
||||||
|
summarization_prompts = get_config_dir() / "prompts" / "summarization_prompts.toml"
|
||||||
|
|
||||||
|
return tomllib.loads(summarization_prompts.read_text())[section]
|
||||||
|
|
||||||
|
|
||||||
|
class BillSummaryError(RuntimeError):
|
||||||
|
"""Raised when a bill summary request or response is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
def call_openai_summary(
|
||||||
|
*,
|
||||||
|
model: str,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
) -> str:
|
||||||
|
"""Call GPT and return the assistant message content."""
|
||||||
|
api_key = getenv("CLOSEDAI_TOKEN")
|
||||||
|
if not api_key:
|
||||||
|
msg = "CLOSEDAI_TOKEN is required"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
response = httpx.post(
|
||||||
|
OPENAI_CHAT_COMPLETIONS_URL,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"OpenAI-Project": OPENAI_PROJECT_ID,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
},
|
||||||
|
timeout=REQUEST_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
logger.info(f"{response.text=}")
|
||||||
|
response.raise_for_status()
|
||||||
|
return extract_message_content(response.json())
|
||||||
|
|
||||||
|
|
||||||
|
def build_bill_summary_messages(
|
||||||
|
*,
|
||||||
|
bill_text: BillText,
|
||||||
|
summarization_prompts: dict[str, str],
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
"""Build the GPT prompt messages plus compressed text and user prompt."""
|
||||||
|
if not bill_text.text_content:
|
||||||
|
msg = f"bill_text id={bill_text.id} has no text_content"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
compressed_text = compress_bill_text(bill_text.text_content)
|
||||||
|
if not compressed_text:
|
||||||
|
msg = f"bill_text id={bill_text.id} has no summarizable text_content"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
user_prompt = summarization_prompts["user_template"].format(
|
||||||
|
text_content=compressed_text
|
||||||
|
)
|
||||||
|
|
||||||
|
user_prompt_tokens = len(get_encoding("o200k_base").encode(user_prompt))
|
||||||
|
logger.info(f"{user_prompt_tokens=}")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": summarization_prompts["system_prompt"]},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": user_prompt,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
return messages, user_prompt_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def summarize_bill_text(
|
||||||
|
*,
|
||||||
|
model: str,
|
||||||
|
bill_text: BillText,
|
||||||
|
summarization_prompts: dict[str, str],
|
||||||
|
) -> str:
|
||||||
|
"""Generate and return a summary for one bill_text row."""
|
||||||
|
messages, user_prompt_tokens = build_bill_summary_messages(
|
||||||
|
bill_text=bill_text,
|
||||||
|
summarization_prompts=summarization_prompts,
|
||||||
|
)
|
||||||
|
# This may only be for gpt-5.4 mini I need to read the docs
|
||||||
|
if user_prompt_tokens > 272000:
|
||||||
|
msg = f"Compressed bill_text id={bill_text.id} is too long for summarization ({user_prompt_tokens} tokens)"
|
||||||
|
logger.warning(msg)
|
||||||
|
return None
|
||||||
|
|
||||||
|
summary = call_openai_summary(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
).strip()
|
||||||
|
if not summary:
|
||||||
|
msg = f"Model returned an empty summary for bill_text id={bill_text.id}"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
def store_bill_summary_result(
|
||||||
|
*,
|
||||||
|
bill_text: BillText,
|
||||||
|
summary: str,
|
||||||
|
model: str,
|
||||||
|
) -> None:
|
||||||
|
"""Store a generated summary and the prompt/model metadata that produced it."""
|
||||||
|
bill_text.summary = summary
|
||||||
|
bill_text.summarization_model = model
|
||||||
|
bill_text.summarization_system_prompt_version = "v1.2"
|
||||||
|
bill_text.summarization_user_prompt_version = "v1"
|
||||||
|
|
||||||
|
|
||||||
|
def create_select_bill_texts_for_summarization(
|
||||||
|
congress: int | None = None,
|
||||||
|
bill_ids: list[int] | None = None,
|
||||||
|
bill_text_ids: list[int] | None = None,
|
||||||
|
with_votes_only: bool = False,
|
||||||
|
force: bool = False,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> Select:
|
||||||
|
"""Select bill_text rows that have source text and need summaries."""
|
||||||
|
stmt = (
|
||||||
|
select(BillText)
|
||||||
|
.join(Bill, Bill.id == BillText.bill_id)
|
||||||
|
.where(BillText.text_content.is_not(None), BillText.text_content != "")
|
||||||
|
.options(selectinload(BillText.bill))
|
||||||
|
.order_by(BillText.id)
|
||||||
|
)
|
||||||
|
if congress is not None:
|
||||||
|
stmt = stmt.where(Bill.congress == congress)
|
||||||
|
if bill_ids:
|
||||||
|
stmt = stmt.where(BillText.bill_id.in_(bill_ids))
|
||||||
|
if bill_text_ids:
|
||||||
|
stmt = stmt.where(BillText.id.in_(bill_text_ids))
|
||||||
|
if with_votes_only:
|
||||||
|
stmt = stmt.where(
|
||||||
|
exists(
|
||||||
|
select(VoteTextTarget.vote_id)
|
||||||
|
.join(
|
||||||
|
VoteClassification,
|
||||||
|
VoteClassification.vote_id == VoteTextTarget.vote_id,
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
VoteTextTarget.voted_text_version_id == BillText.id,
|
||||||
|
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||||
|
VoteClassification.vote_relationship
|
||||||
|
== VoteRelationship.DIRECT_TEXT_VOTE,
|
||||||
|
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||||
|
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||||
|
VoteClassification.is_special_rule.is_(False),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not force:
|
||||||
|
stmt = stmt.where(or_(BillText.summary.is_(None), BillText.summary == ""))
|
||||||
|
if limit is not None:
|
||||||
|
stmt = stmt.limit(limit)
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
|
def extract_message_content(data: dict[str, Any]) -> str:
|
||||||
|
"""Extract message content from a chat-completions response body."""
|
||||||
|
choices = data.get("choices")
|
||||||
|
if not isinstance(choices, list) or not choices:
|
||||||
|
msg = "Chat completion response did not contain choices"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
first = choices[0]
|
||||||
|
if not isinstance(first, dict):
|
||||||
|
msg = "Chat completion choice must be an object"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
message = first.get("message")
|
||||||
|
if isinstance(message, dict) and isinstance(message.get("content"), str):
|
||||||
|
return message["content"]
|
||||||
|
if isinstance(first.get("text"), str):
|
||||||
|
return first["text"]
|
||||||
|
|
||||||
|
msg = "Chat completion response did not contain message content"
|
||||||
|
raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
model: Annotated[str, typer.Option(help="OpenAI model id.")] = "gpt-5.4-mini",
|
||||||
|
congress: Annotated[
|
||||||
|
int | None, typer.Option(help="Only process one Congress.")
|
||||||
|
] = None,
|
||||||
|
bill_ids: Annotated[
|
||||||
|
list[int] | None,
|
||||||
|
typer.Option(
|
||||||
|
"--bill-id",
|
||||||
|
help="Only process one internal bill.id. Repeat for multiple bills.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
bill_text_ids: Annotated[
|
||||||
|
list[int] | None,
|
||||||
|
typer.Option(
|
||||||
|
"--bill-text-id",
|
||||||
|
help="Only process one internal bill_text.id. Repeat for multiple rows.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
with_votes_only: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(
|
||||||
|
"--with-votes-only",
|
||||||
|
help="Only process bill_text rows linked to at least one vote.",
|
||||||
|
),
|
||||||
|
] = False,
|
||||||
|
limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
|
||||||
|
force: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(help="Regenerate summaries for rows that already have a summary."),
|
||||||
|
] = False,
|
||||||
|
dry_run: Annotated[
|
||||||
|
bool, typer.Option(help="Print summaries without writing them to the database.")
|
||||||
|
] = False,
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""CLI entrypoint for generating and storing bill summaries."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
if not getenv("CLOSEDAI_TOKEN"):
|
||||||
|
message = "CLOSEDAI_TOKEN is required"
|
||||||
|
raise typer.BadParameter(message)
|
||||||
|
|
||||||
|
summarization_prompts = load_summarization_prompts()
|
||||||
|
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||||
|
with Session(engine) as session:
|
||||||
|
stmt = create_select_bill_texts_for_summarization(
|
||||||
|
congress=congress,
|
||||||
|
bill_ids=bill_ids,
|
||||||
|
bill_text_ids=bill_text_ids,
|
||||||
|
with_votes_only=with_votes_only,
|
||||||
|
force=force,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
bill_texts = session.scalars(stmt).all()
|
||||||
|
logger.info("Selected %d bill_text rows for summarization", len(bill_texts))
|
||||||
|
|
||||||
|
written = 0
|
||||||
|
for index, bill_text in enumerate(bill_texts, 1):
|
||||||
|
summary = summarize_bill_text(
|
||||||
|
model=model,
|
||||||
|
bill_text=bill_text,
|
||||||
|
summarization_prompts=summarization_prompts,
|
||||||
|
)
|
||||||
|
if summary is None:
|
||||||
|
logger.warning("Skipping bill_text id=%s", bill_text.id)
|
||||||
|
continue
|
||||||
|
store_bill_summary_result(
|
||||||
|
bill_text=bill_text,
|
||||||
|
summary=summary,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
if index % 100 == 0:
|
||||||
|
session.commit()
|
||||||
|
written += 1
|
||||||
|
session.commit()
|
||||||
|
logger.info("Stored summary for bill_text id=%s", bill_text.id)
|
||||||
|
|
||||||
|
logger.info("Done: stored %d summaries", written)
|
||||||
|
|
||||||
|
|
||||||
|
def cli() -> None:
|
||||||
|
"""Typer entry point."""
|
||||||
|
typer.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
@@ -23,14 +23,10 @@ import httpx
|
|||||||
import typer
|
import typer
|
||||||
from tiktoken import Encoding, get_encoding
|
from tiktoken import Encoding, get_encoding
|
||||||
|
|
||||||
|
from pipelines.config import get_config_dir
|
||||||
from pipelines.tools.bill_token_compression import compress_bill_text
|
from pipelines.tools.bill_token_compression import compress_bill_text
|
||||||
|
|
||||||
_PROMPTS_PATH = (
|
_PROMPTS_PATH = get_config_dir() / "prompts" / "summarization_prompts.toml"
|
||||||
Path(__file__).resolve().parents[2]
|
|
||||||
/ "config"
|
|
||||||
/ "prompts"
|
|
||||||
/ "summarization_prompts.toml"
|
|
||||||
)
|
|
||||||
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
||||||
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
||||||
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
||||||
|
|||||||
@@ -24,14 +24,10 @@ from typing import Annotated
|
|||||||
import httpx
|
import httpx
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
|
from pipelines.config import get_config_dir
|
||||||
from pipelines.tools.bill_token_compression import compress_bill_text
|
from pipelines.tools.bill_token_compression import compress_bill_text
|
||||||
|
|
||||||
_PROMPTS_PATH = (
|
_PROMPTS_PATH = get_config_dir() / "prompts" / "summarization_prompts.toml"
|
||||||
Path(__file__).resolve().parents[2]
|
|
||||||
/ "config"
|
|
||||||
/ "prompts"
|
|
||||||
/ "summarization_prompts.toml"
|
|
||||||
)
|
|
||||||
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
||||||
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
||||||
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ from datasets import Dataset
|
|||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import SFTTrainer
|
from trl import SFTTrainer
|
||||||
|
|
||||||
|
from pipelines.config import default_config_path
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -123,7 +125,7 @@ def main(
|
|||||||
config_path: Annotated[
|
config_path: Annotated[
|
||||||
Path,
|
Path,
|
||||||
typer.Option("--config", help="TOML config file"),
|
typer.Option("--config", help="TOML config file"),
|
||||||
] = Path(__file__).parent / "config.toml",
|
] = default_config_path(),
|
||||||
save_gguf: Annotated[
|
save_gguf: Annotated[
|
||||||
bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")
|
bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")
|
||||||
] = False,
|
] = False,
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ from typing import Annotated
|
|||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
from pipelines.tools.containers.lib import check_gpu_free
|
from pipelines.containers.lib import check_gpu_free
|
||||||
from pipelines.tools.containers.vllm import start_vllm, stop_vllm
|
from pipelines.containers.vllm import start_vllm, stop_vllm
|
||||||
from pipelines.tools.downloader import is_model_present
|
from pipelines.tools.downloader import is_model_present
|
||||||
from pipelines.tools.models import BenchmarkConfig
|
from pipelines.tools.models import BenchmarkConfig
|
||||||
from pipelines.tools.vllm_client import VLLMClient
|
from pipelines.tools.vllm_client import VLLMClient
|
||||||
|
|||||||
@@ -1,34 +0,0 @@
|
|||||||
SUMMARIZATION_SYSTEM_PROMPT = """You are a legislative analyst extracting policy substance from Congressional bill text.
|
|
||||||
|
|
||||||
Your job is to compress a bill into a dense, neutral structured summary that captures every distinct policy action — including secondary effects that might be buried in subsections.
|
|
||||||
|
|
||||||
EXTRACTION RULES:
|
|
||||||
- IGNORE: whereas clauses, congressional findings that are purely political statements, recitals, preambles, citations of existing law by number alone, and procedural boilerplate.
|
|
||||||
- FOCUS ON: operative verbs — what the bill SHALL do, PROHIBIT, REQUIRE, AUTHORIZE, AMEND, APPROPRIATE, or ESTABLISH.
|
|
||||||
- SURFACE ALL THREADS: If the bill touches multiple policy areas, list each thread separately. Do not collapse them.
|
|
||||||
- BE CONCRETE: Name the affected population, the mechanism, and the direction (expands/restricts/maintains).
|
|
||||||
- STAY NEUTRAL: No political framing. Describe what the text does, not what its sponsors claim it does.
|
|
||||||
|
|
||||||
OUTPUT FORMAT — plain structured text, not JSON:
|
|
||||||
|
|
||||||
OPERATIVE ACTIONS:
|
|
||||||
[Numbered list of what the bill actually does, one action per line, max 20 words each]
|
|
||||||
|
|
||||||
AFFECTED POPULATIONS:
|
|
||||||
[Who gains something, who loses something, or whose behavior is regulated]
|
|
||||||
|
|
||||||
MECHANISMS:
|
|
||||||
[How it works: new funding, mandate, prohibition, amendment to existing statute, grant program, study commission, etc.]
|
|
||||||
|
|
||||||
POLICY THREADS:
|
|
||||||
[List each distinct policy domain this bill touches, even minor ones. Use plain language, not domain codes.]
|
|
||||||
|
|
||||||
SYMBOLIC/PROCEDURAL ONLY:
|
|
||||||
[Yes or No — is this bill primarily a resolution, designation, or awareness declaration with no operative effect?]
|
|
||||||
|
|
||||||
LENGTH TARGET: 150-250 words total. Be ruthless about cutting. Density over completeness."""
|
|
||||||
|
|
||||||
SUMMARIZATION_USER_TEMPLATE = """Summarize the following Congressional bill according to your instructions.
|
|
||||||
|
|
||||||
BILL TEXT:
|
|
||||||
{text_content}"""
|
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""FastAPI HTMX front end for the legislative database."""
|
||||||
@@ -0,0 +1,208 @@
|
|||||||
|
"""WorkOS AuthKit helpers for the FastAPI web app."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import lru_cache
|
||||||
|
from os import getenv
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from workos import WorkOSClient
|
||||||
|
from workos.session import seal_session_from_auth_response
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AuthConfig:
|
||||||
|
"""Runtime configuration for WorkOS AuthKit."""
|
||||||
|
|
||||||
|
api_key: str
|
||||||
|
client_id: str
|
||||||
|
cookie_password: str
|
||||||
|
redirect_uri: str
|
||||||
|
logout_redirect_uri: str
|
||||||
|
session_cookie_name: str
|
||||||
|
organization_id: str
|
||||||
|
|
||||||
|
@property
|
||||||
|
def secure_cookies(self) -> bool:
|
||||||
|
return self.redirect_uri.startswith("https://")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AuthSession:
|
||||||
|
"""Normalized auth session passed through the app."""
|
||||||
|
|
||||||
|
user_id: str
|
||||||
|
email: str
|
||||||
|
first_name: str | None
|
||||||
|
last_name: str | None
|
||||||
|
role_slugs: set[str]
|
||||||
|
organization_id: str | None
|
||||||
|
raw_user: Any
|
||||||
|
raw_session: Any
|
||||||
|
|
||||||
|
@property
|
||||||
|
def display_name(self) -> str:
|
||||||
|
parts = [part for part in (self.first_name, self.last_name) if part]
|
||||||
|
return " ".join(parts) if parts else self.email
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_admin(self) -> bool:
|
||||||
|
return "admin" in self.role_slugs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CallbackResult:
|
||||||
|
"""Result of exchanging a WorkOS callback code."""
|
||||||
|
|
||||||
|
sealed_session: str
|
||||||
|
next_path: str
|
||||||
|
|
||||||
|
|
||||||
|
def safe_next_path(value: str | None, default: str = "/dashboard") -> str:
|
||||||
|
"""Allow only local relative redirect targets."""
|
||||||
|
if value and value.startswith("/") and not value.startswith("//"):
|
||||||
|
return value
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def build_authorization_url(next_path: str) -> str:
|
||||||
|
"""Build the WorkOS hosted login URL."""
|
||||||
|
config = get_auth_config()
|
||||||
|
return get_workos_client().user_management.get_authorization_url(
|
||||||
|
provider="authkit",
|
||||||
|
redirect_uri=config.redirect_uri,
|
||||||
|
state=safe_next_path(next_path),
|
||||||
|
organization_id=config.organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def exchange_code(request: Request) -> CallbackResult:
|
||||||
|
"""Exchange a WorkOS callback code for a sealed session cookie value."""
|
||||||
|
code = request.query_params.get("code")
|
||||||
|
if not code:
|
||||||
|
raise ValueError("Missing authentication code.")
|
||||||
|
|
||||||
|
config = get_auth_config()
|
||||||
|
auth_response = get_workos_client().user_management.authenticate_with_code(
|
||||||
|
code=code,
|
||||||
|
ip_address=_request_ip(request),
|
||||||
|
user_agent=request.headers.get("user-agent"),
|
||||||
|
)
|
||||||
|
sealed_session = seal_session_from_auth_response(
|
||||||
|
access_token=auth_response.access_token,
|
||||||
|
refresh_token=auth_response.refresh_token,
|
||||||
|
user=auth_response.user.to_dict(),
|
||||||
|
impersonator=auth_response.impersonator.to_dict()
|
||||||
|
if auth_response.impersonator is not None
|
||||||
|
else None,
|
||||||
|
cookie_password=config.cookie_password,
|
||||||
|
)
|
||||||
|
|
||||||
|
return CallbackResult(
|
||||||
|
sealed_session=sealed_session,
|
||||||
|
next_path=safe_next_path(request.query_params.get("state")),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_session(request: Request) -> AuthSession | None:
|
||||||
|
"""Load the current signed-in WorkOS session from the sealed cookie."""
|
||||||
|
cookie_name = getenv("WORKOS_SESSION_COOKIE_NAME", "workos_session")
|
||||||
|
sealed_session = request.cookies.get(cookie_name)
|
||||||
|
if not sealed_session:
|
||||||
|
return None
|
||||||
|
|
||||||
|
config = get_auth_config()
|
||||||
|
try:
|
||||||
|
session = get_workos_client().user_management.load_sealed_session(
|
||||||
|
session_data=sealed_session,
|
||||||
|
cookie_password=config.cookie_password,
|
||||||
|
)
|
||||||
|
auth_response = session.authenticate()
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
if not getattr(auth_response, "authenticated", False):
|
||||||
|
return None
|
||||||
|
|
||||||
|
user = auth_response.user or {}
|
||||||
|
organization_id = getattr(auth_response, "organization_id", None)
|
||||||
|
if config.organization_id and organization_id != config.organization_id:
|
||||||
|
return None
|
||||||
|
role_slugs = set(getattr(auth_response, "roles", None) or [])
|
||||||
|
role = getattr(auth_response, "role", None)
|
||||||
|
if role:
|
||||||
|
role_slugs.add(role)
|
||||||
|
|
||||||
|
return AuthSession(
|
||||||
|
user_id=_user_field(user, "id") or "",
|
||||||
|
email=_user_field(user, "email") or "",
|
||||||
|
first_name=_user_field(user, "first_name"),
|
||||||
|
last_name=_user_field(user, "last_name"),
|
||||||
|
role_slugs=role_slugs,
|
||||||
|
organization_id=organization_id,
|
||||||
|
raw_user=user,
|
||||||
|
raw_session=auth_response,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_logout_url(request: Request) -> str:
|
||||||
|
"""Return the WorkOS logout URL for the current sealed session."""
|
||||||
|
config = get_auth_config()
|
||||||
|
sealed_session = request.cookies.get(config.session_cookie_name)
|
||||||
|
if not sealed_session:
|
||||||
|
return config.logout_redirect_uri
|
||||||
|
|
||||||
|
try:
|
||||||
|
session = get_workos_client().user_management.load_sealed_session(
|
||||||
|
session_data=sealed_session,
|
||||||
|
cookie_password=config.cookie_password,
|
||||||
|
)
|
||||||
|
return session.get_logout_url(return_to=config.logout_redirect_uri)
|
||||||
|
except ValueError:
|
||||||
|
return config.logout_redirect_uri
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_auth_config() -> AuthConfig:
|
||||||
|
"""Load and validate WorkOS environment configuration."""
|
||||||
|
values = {
|
||||||
|
"WORKOS_API_KEY": getenv("WORKOS_API_KEY"),
|
||||||
|
"WORKOS_CLIENT_ID": getenv("WORKOS_CLIENT_ID"),
|
||||||
|
"WORKOS_COOKIE_PASSWORD": getenv("WORKOS_COOKIE_PASSWORD"),
|
||||||
|
"WORKOS_ORGANIZATION_ID": getenv("WORKOS_ORGANIZATION_ID"),
|
||||||
|
}
|
||||||
|
missing = [name for name, value in values.items() if not value]
|
||||||
|
if missing:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Missing WorkOS configuration: " + ", ".join(sorted(missing))
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthConfig(
|
||||||
|
api_key=values["WORKOS_API_KEY"] or "",
|
||||||
|
client_id=values["WORKOS_CLIENT_ID"] or "",
|
||||||
|
cookie_password=values["WORKOS_COOKIE_PASSWORD"] or "",
|
||||||
|
redirect_uri=getenv("WORKOS_REDIRECT_URI", "http://localhost:8000/callback"),
|
||||||
|
logout_redirect_uri=getenv("WORKOS_LOGOUT_REDIRECT_URI", "http://localhost:8000/"),
|
||||||
|
session_cookie_name=getenv("WORKOS_SESSION_COOKIE_NAME", "workos_session"),
|
||||||
|
organization_id=values["WORKOS_ORGANIZATION_ID"] or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_workos_client():
|
||||||
|
"""Create and cache the WorkOS SDK client."""
|
||||||
|
config = get_auth_config()
|
||||||
|
return WorkOSClient(api_key=config.api_key, client_id=config.client_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _request_ip(request: Request) -> str | None:
|
||||||
|
if request.client is None:
|
||||||
|
return None
|
||||||
|
return request.client.host
|
||||||
|
|
||||||
|
|
||||||
|
def _user_field(user: Any, key: str) -> Any:
|
||||||
|
if isinstance(user, dict):
|
||||||
|
return user.get(key)
|
||||||
|
return getattr(user, key, None)
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
"""Database access for the FastAPI web app."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from pipelines.orm.common import get_postgres_engine
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_engine() -> Engine:
|
||||||
|
"""Return the lazily-created DATA_SCIENCE_DEV SQLAlchemy engine."""
|
||||||
|
return get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_database_connection() -> None:
|
||||||
|
"""Fail fast if the configured DATA_SCIENCE_DEV database is unavailable."""
|
||||||
|
with get_engine().connect():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def session_scope() -> Iterator[Session]:
|
||||||
|
"""Yield a SQLAlchemy session for a read-only request."""
|
||||||
|
with Session(get_engine()) as session:
|
||||||
|
yield session
|
||||||
@@ -0,0 +1,609 @@
|
|||||||
|
"""FastAPI app for the HTMX legislative dashboard."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import logging
|
||||||
|
from os import getenv
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import Depends, FastAPI, HTTPException, Request, Response, status
|
||||||
|
from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from fastapi.templating import Jinja2Templates
|
||||||
|
|
||||||
|
from pipelines.web import auth, repository
|
||||||
|
from pipelines.web.db import session_scope, validate_database_connection
|
||||||
|
from pipelines.web.repository import Chamber, RankingResult
|
||||||
|
from pipelines.web.scoring import normalize_issues
|
||||||
|
from pipelines.web.svg import render_compare_radar_svg, render_score_history_svg
|
||||||
|
|
||||||
|
BASE_DIR = Path(__file__).resolve().parent
|
||||||
|
TEMPLATES_DIR = BASE_DIR / "templates"
|
||||||
|
STATIC_DIR = BASE_DIR / "static"
|
||||||
|
|
||||||
|
templates = Jinja2Templates(directory=TEMPLATES_DIR)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(_: FastAPI):
|
||||||
|
"""Validate database access when the CLI starts the web server."""
|
||||||
|
if getenv("PYTEST_CURRENT_TEST") is None:
|
||||||
|
validate_database_connection()
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(title="Nornsight Legislative Dashboard", lifespan=lifespan)
|
||||||
|
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DashboardState:
|
||||||
|
"""Dashboard query-string state."""
|
||||||
|
|
||||||
|
issues: list[str]
|
||||||
|
chamber: Chamber
|
||||||
|
congress: int | None
|
||||||
|
compare: list[int]
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/healthz", response_class=PlainTextResponse)
|
||||||
|
def healthz() -> str:
|
||||||
|
"""Return a simple liveness response."""
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/", response_class=HTMLResponse)
|
||||||
|
def home(request: Request) -> Response:
|
||||||
|
"""Render the public home page."""
|
||||||
|
current_user = auth.get_current_session(request)
|
||||||
|
return templates.TemplateResponse(
|
||||||
|
request,
|
||||||
|
"home.html",
|
||||||
|
{
|
||||||
|
**_auth_context(current_user),
|
||||||
|
"auth_error": request.query_params.get("auth_error") == "1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/login")
|
||||||
|
def login(request: Request) -> Response:
|
||||||
|
"""Start the WorkOS hosted login flow."""
|
||||||
|
next_path = auth.safe_next_path(request.query_params.get("next"))
|
||||||
|
current_user = auth.get_current_session(request)
|
||||||
|
if current_user is not None:
|
||||||
|
return RedirectResponse(next_path, status_code=status.HTTP_303_SEE_OTHER)
|
||||||
|
return RedirectResponse(
|
||||||
|
auth.build_authorization_url(next_path),
|
||||||
|
status_code=status.HTTP_303_SEE_OTHER,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/callback")
|
||||||
|
def callback(request: Request) -> Response:
|
||||||
|
"""Exchange the WorkOS code for a sealed session cookie."""
|
||||||
|
try:
|
||||||
|
result = auth.exchange_code(request)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("WorkOS callback exchange failed.")
|
||||||
|
response = RedirectResponse("/?auth_error=1", status_code=status.HTTP_303_SEE_OTHER)
|
||||||
|
_delete_auth_cookie(response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
config = auth.get_auth_config()
|
||||||
|
response = RedirectResponse(result.next_path, status_code=status.HTTP_303_SEE_OTHER)
|
||||||
|
response.set_cookie(
|
||||||
|
config.session_cookie_name,
|
||||||
|
result.sealed_session,
|
||||||
|
httponly=True,
|
||||||
|
samesite="lax",
|
||||||
|
secure=config.secure_cookies,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/logout")
|
||||||
|
def logout(request: Request) -> Response:
|
||||||
|
"""End the WorkOS session and clear the local sealed session cookie."""
|
||||||
|
response = RedirectResponse(auth.get_logout_url(request), status_code=status.HTTP_303_SEE_OTHER)
|
||||||
|
_delete_auth_cookie(response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def require_user(request: Request) -> auth.AuthSession:
|
||||||
|
"""Redirect unauthenticated users to the WorkOS sign-in flow."""
|
||||||
|
current_user = auth.get_current_session(request)
|
||||||
|
if current_user is not None:
|
||||||
|
return current_user
|
||||||
|
next_path = request.url.path
|
||||||
|
if request.url.query:
|
||||||
|
next_path = f"{next_path}?{request.url.query}"
|
||||||
|
login_url = request.url_for("login").include_query_params(next=next_path)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_303_SEE_OTHER,
|
||||||
|
headers={"Location": str(login_url)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def require_admin(current_user: auth.AuthSession = Depends(require_user)) -> auth.AuthSession:
|
||||||
|
"""Restrict a route to WorkOS users with the admin role."""
|
||||||
|
if current_user.is_admin:
|
||||||
|
return current_user
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required.")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/dashboard", response_class=HTMLResponse)
|
||||||
|
def dashboard(
|
||||||
|
request: Request, current_user: auth.AuthSession = Depends(require_user)
|
||||||
|
) -> Response:
|
||||||
|
"""Render the full dashboard page."""
|
||||||
|
context = {**_auth_context(current_user), **_dashboard_context(request)}
|
||||||
|
if request.headers.get("hx-request") == "true":
|
||||||
|
return templates.TemplateResponse(request, "partials/_dashboard.html", context)
|
||||||
|
return templates.TemplateResponse(request, "dashboard.html", context)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/partials/dashboard", response_class=HTMLResponse)
|
||||||
|
def dashboard_partial(request: Request, _: auth.AuthSession = Depends(require_user)) -> Response:
|
||||||
|
"""Render the filter-dependent dashboard body."""
|
||||||
|
context = _dashboard_context(request)
|
||||||
|
return templates.TemplateResponse(request, "partials/_dashboard.html", context)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/partials/issues", response_class=HTMLResponse)
|
||||||
|
def issues_partial(request: Request, _: auth.AuthSession = Depends(require_user)) -> Response:
|
||||||
|
"""Render only issue filters."""
|
||||||
|
context = _dashboard_context(request)
|
||||||
|
return templates.TemplateResponse(request, "partials/_issue_filters.html", context)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/partials/rankings", response_class=HTMLResponse)
|
||||||
|
def rankings_partial(request: Request, _: auth.AuthSession = Depends(require_user)) -> Response:
|
||||||
|
"""Render only ranking panels."""
|
||||||
|
context = _dashboard_context(request)
|
||||||
|
return templates.TemplateResponse(request, "partials/_rankings.html", context)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/partials/chart", response_class=HTMLResponse)
|
||||||
|
def chart_partial(request: Request, _: auth.AuthSession = Depends(require_user)) -> Response:
|
||||||
|
"""Render only the SVG chart panel."""
|
||||||
|
context = _dashboard_context(request)
|
||||||
|
return templates.TemplateResponse(request, "partials/_chart.html", context)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/legislators", response_class=HTMLResponse)
|
||||||
|
def legislators(
|
||||||
|
request: Request, current_user: auth.AuthSession = Depends(require_user)
|
||||||
|
) -> Response:
|
||||||
|
"""Render the legislator profile/search page."""
|
||||||
|
context = {**_auth_context(current_user), **_legislators_context(request)}
|
||||||
|
return templates.TemplateResponse(request, "legislators.html", context)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/partials/legislator-suggestions", response_class=HTMLResponse)
|
||||||
|
def legislator_suggestions_partial(
|
||||||
|
request: Request, _: auth.AuthSession = Depends(require_user)
|
||||||
|
) -> Response:
|
||||||
|
"""Render legislator search suggestions for the HTMX typeahead."""
|
||||||
|
query = request.query_params.get("q", "").strip()
|
||||||
|
context: dict[str, Any] = {
|
||||||
|
"q": query if len(query) >= 2 else "",
|
||||||
|
"matches": [],
|
||||||
|
"build_legislator_url": _build_legislator_url,
|
||||||
|
}
|
||||||
|
if len(query) >= 2:
|
||||||
|
with session_scope() as session:
|
||||||
|
context["matches"] = repository.search_legislators(
|
||||||
|
session, query=query, limit=8
|
||||||
|
)
|
||||||
|
return templates.TemplateResponse(
|
||||||
|
request, "partials/_legislator_suggestions.html", context
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/compare", response_class=HTMLResponse)
|
||||||
|
def compare(
|
||||||
|
request: Request, current_user: auth.AuthSession = Depends(require_user)
|
||||||
|
) -> Response:
|
||||||
|
"""Render the legislator radar comparison page."""
|
||||||
|
context = {**_auth_context(current_user), **_compare_context(request)}
|
||||||
|
return templates.TemplateResponse(request, "compare.html", context)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/admin", response_class=HTMLResponse)
|
||||||
|
def admin_page(
|
||||||
|
request: Request, current_user: auth.AuthSession = Depends(require_admin)
|
||||||
|
) -> Response:
|
||||||
|
"""Render the admin-only placeholder page."""
|
||||||
|
return templates.TemplateResponse(
|
||||||
|
request,
|
||||||
|
"admin.html",
|
||||||
|
{
|
||||||
|
**_auth_context(current_user),
|
||||||
|
"organization_id": auth.get_auth_config().organization_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _dashboard_context(request: Request) -> dict[str, Any]:
|
||||||
|
state = _parse_state(request)
|
||||||
|
base_context: dict[str, Any] = {
|
||||||
|
"state": state,
|
||||||
|
"issues": state.issues,
|
||||||
|
"selected_issue_label": " + ".join(state.issues) if state.issues else "",
|
||||||
|
"chamber": state.chamber,
|
||||||
|
"congress": state.congress,
|
||||||
|
"latest_score_year": None,
|
||||||
|
"last_updated": None,
|
||||||
|
"suggestions": [],
|
||||||
|
"rankings": RankingResult(supportive=[], opposed=[]),
|
||||||
|
"compare": [],
|
||||||
|
"chart_svg": render_score_history_svg([]),
|
||||||
|
"chart_series": [],
|
||||||
|
"has_votes": False,
|
||||||
|
"has_scores": False,
|
||||||
|
"empty_message": "",
|
||||||
|
"build_url": _build_url,
|
||||||
|
"build_dashboard_partial_url": _build_dashboard_partial_url,
|
||||||
|
"toggle_compare": _toggle_compare,
|
||||||
|
}
|
||||||
|
with session_scope() as session:
|
||||||
|
congress = state.congress or repository.latest_congress(session)
|
||||||
|
base_context["congress"] = congress
|
||||||
|
base_context["has_scores"] = repository.has_scores(session)
|
||||||
|
base_context["latest_score_year"] = repository.latest_score_year(session)
|
||||||
|
base_context["last_updated"] = repository.latest_vote_date(session, congress)
|
||||||
|
base_context["suggestions"] = repository.issue_suggestions(
|
||||||
|
session, congress=congress
|
||||||
|
)
|
||||||
|
|
||||||
|
if not base_context["has_scores"]:
|
||||||
|
base_context["empty_message"] = (
|
||||||
|
"No legislator scores are loaded yet. Run the score calculator first."
|
||||||
|
)
|
||||||
|
return base_context
|
||||||
|
|
||||||
|
if congress is None:
|
||||||
|
base_context["congress"] = "Computed"
|
||||||
|
|
||||||
|
if not state.issues:
|
||||||
|
base_context["empty_message"] = (
|
||||||
|
"Choose one or more issue areas to calculate roll-call support scores."
|
||||||
|
)
|
||||||
|
return base_context
|
||||||
|
|
||||||
|
rankings = repository.get_rankings(
|
||||||
|
session,
|
||||||
|
issues=state.issues,
|
||||||
|
chamber=state.chamber,
|
||||||
|
congress=congress,
|
||||||
|
)
|
||||||
|
base_context["rankings"] = rankings
|
||||||
|
compare = state.compare or [row.legislator_id for row in rankings.supportive[:2]]
|
||||||
|
base_context["compare"] = compare
|
||||||
|
if not rankings.supportive and not rankings.opposed:
|
||||||
|
base_context["empty_message"] = "No matching roll-call votes."
|
||||||
|
return base_context
|
||||||
|
|
||||||
|
history = repository.get_score_history(
|
||||||
|
session,
|
||||||
|
issues=state.issues,
|
||||||
|
chamber=state.chamber,
|
||||||
|
congress=congress,
|
||||||
|
legislator_ids=compare,
|
||||||
|
)
|
||||||
|
base_context["chart_series"] = history
|
||||||
|
base_context["chart_svg"] = render_score_history_svg(history)
|
||||||
|
return base_context
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_state(request: Request) -> DashboardState:
|
||||||
|
query = request.query_params
|
||||||
|
chamber = query.get("chamber", "senate").lower()
|
||||||
|
if chamber not in {"house", "senate", "all"}:
|
||||||
|
chamber = "senate"
|
||||||
|
congress = _parse_int(query.get("congress"))
|
||||||
|
compare = [
|
||||||
|
value
|
||||||
|
for value in (_parse_int(raw) for raw in query.getlist("compare"))
|
||||||
|
if value is not None
|
||||||
|
]
|
||||||
|
return DashboardState(
|
||||||
|
issues=normalize_issues(query.getlist("issues")),
|
||||||
|
chamber=chamber, # type: ignore[arg-type]
|
||||||
|
congress=congress,
|
||||||
|
compare=compare,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _legislators_context(request: Request) -> dict[str, Any]:
|
||||||
|
query = request.query_params.get("q", "").strip()
|
||||||
|
legislator_id = _parse_int(request.query_params.get("legislator_id"))
|
||||||
|
selected_topic = request.query_params.get("topic", "").strip()
|
||||||
|
per_page = _parse_per_page(request.query_params.get("per_page"))
|
||||||
|
page = max(_parse_int(request.query_params.get("page")) or 1, 1)
|
||||||
|
base_context: dict[str, Any] = {
|
||||||
|
"q": query,
|
||||||
|
"profile": None,
|
||||||
|
"matches": [],
|
||||||
|
"result_count": 0,
|
||||||
|
"page": page,
|
||||||
|
"per_page": per_page,
|
||||||
|
"per_page_options": [10, 25, 50],
|
||||||
|
"total_pages": 1,
|
||||||
|
"previous_page": None,
|
||||||
|
"next_page": None,
|
||||||
|
"selected_topic": selected_topic,
|
||||||
|
"history_svg": render_score_history_svg([]),
|
||||||
|
"history_series": [],
|
||||||
|
"build_legislator_url": _build_legislator_url,
|
||||||
|
"build_legislator_search_url": _build_legislator_search_url,
|
||||||
|
}
|
||||||
|
with session_scope() as session:
|
||||||
|
result_count = repository.count_legislators(session, query=query) if query else 0
|
||||||
|
total_pages = max((result_count + per_page - 1) // per_page, 1)
|
||||||
|
if page > total_pages:
|
||||||
|
page = total_pages
|
||||||
|
base_context["page"] = page
|
||||||
|
matches = (
|
||||||
|
repository.search_legislators(
|
||||||
|
session,
|
||||||
|
query=query,
|
||||||
|
limit=per_page,
|
||||||
|
offset=(page - 1) * per_page,
|
||||||
|
)
|
||||||
|
if query
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
profile = repository.get_legislator_profile(
|
||||||
|
session, legislator_id=legislator_id, query=None
|
||||||
|
)
|
||||||
|
base_context["profile"] = profile
|
||||||
|
base_context["matches"] = matches
|
||||||
|
base_context["result_count"] = result_count
|
||||||
|
base_context["total_pages"] = total_pages
|
||||||
|
base_context["previous_page"] = page - 1 if page > 1 else None
|
||||||
|
base_context["next_page"] = page + 1 if page < total_pages else None
|
||||||
|
if profile is None:
|
||||||
|
return base_context
|
||||||
|
if not selected_topic:
|
||||||
|
if profile.bottom_topics:
|
||||||
|
selected_topic = profile.bottom_topics[0].topic
|
||||||
|
elif profile.top_topics:
|
||||||
|
selected_topic = profile.top_topics[0].topic
|
||||||
|
base_context["selected_topic"] = selected_topic
|
||||||
|
if selected_topic:
|
||||||
|
history = repository.get_single_legislator_history(
|
||||||
|
session,
|
||||||
|
legislator_id=profile.legislator.legislator_id,
|
||||||
|
topic=selected_topic,
|
||||||
|
)
|
||||||
|
base_context["history_series"] = history
|
||||||
|
base_context["history_svg"] = render_score_history_svg(history)
|
||||||
|
return base_context
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_context(request: Request) -> dict[str, Any]:
|
||||||
|
selected_legislators = _parse_int_list(
|
||||||
|
request.query_params.getlist("legislator_id")
|
||||||
|
or request.query_params.getlist("compare")
|
||||||
|
)[:4]
|
||||||
|
topics = normalize_issues(
|
||||||
|
request.query_params.getlist("topic") or request.query_params.getlist("issues")
|
||||||
|
)[:8]
|
||||||
|
query = request.query_params.get("q", "").strip()
|
||||||
|
base_context: dict[str, Any] = {
|
||||||
|
"selected_legislators": selected_legislators,
|
||||||
|
"selected_legislator_options": [],
|
||||||
|
"topics": topics,
|
||||||
|
"q": query,
|
||||||
|
"series": [],
|
||||||
|
"radar_svg": render_compare_radar_svg([], []),
|
||||||
|
"legislator_options": [],
|
||||||
|
"topic_options": [],
|
||||||
|
"build_compare_url": _build_compare_url,
|
||||||
|
}
|
||||||
|
with session_scope() as session:
|
||||||
|
default_legislators, default_topics = repository.get_compare_defaults(session)
|
||||||
|
if not selected_legislators and not query:
|
||||||
|
selected_legislators = default_legislators[:3]
|
||||||
|
if not topics:
|
||||||
|
topics = default_topics[:6]
|
||||||
|
selected_legislator_options = repository.get_legislator_options(
|
||||||
|
session, selected_legislators
|
||||||
|
)
|
||||||
|
series = repository.get_compare_radar_series(
|
||||||
|
session, legislator_ids=selected_legislators, topics=topics
|
||||||
|
)
|
||||||
|
base_context.update(
|
||||||
|
{
|
||||||
|
"selected_legislators": selected_legislators,
|
||||||
|
"selected_legislator_options": selected_legislator_options,
|
||||||
|
"topics": topics,
|
||||||
|
"q": query,
|
||||||
|
"series": series,
|
||||||
|
"radar_svg": render_compare_radar_svg(topics, series),
|
||||||
|
"legislator_options": repository.search_legislators(
|
||||||
|
session, query=query or None, limit=12
|
||||||
|
),
|
||||||
|
"topic_options": repository.issue_suggestions(
|
||||||
|
session, congress=None, limit=12
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return base_context
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_int(value: str | None) -> int | None:
|
||||||
|
if value is None or value == "":
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_int_list(values: list[str]) -> list[int]:
|
||||||
|
parsed: list[int] = []
|
||||||
|
seen: set[int] = set()
|
||||||
|
for value in values:
|
||||||
|
integer = _parse_int(value)
|
||||||
|
if integer is not None and integer not in seen:
|
||||||
|
parsed.append(integer)
|
||||||
|
seen.add(integer)
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_per_page(value: str | None) -> int:
|
||||||
|
parsed = _parse_int(value)
|
||||||
|
return parsed if parsed in {10, 25, 50} else 10
|
||||||
|
|
||||||
|
|
||||||
|
def _build_url(
|
||||||
|
request: Request,
|
||||||
|
*,
|
||||||
|
issues: list[str] | None = None,
|
||||||
|
chamber: str | None = None,
|
||||||
|
congress: int | None = None,
|
||||||
|
compare: list[int] | None = None,
|
||||||
|
) -> str:
|
||||||
|
params: list[tuple[str, str]] = []
|
||||||
|
chosen_issues = (
|
||||||
|
issues
|
||||||
|
if issues is not None
|
||||||
|
else normalize_issues(request.query_params.getlist("issues"))
|
||||||
|
)
|
||||||
|
chosen_chamber = (
|
||||||
|
chamber
|
||||||
|
if chamber is not None
|
||||||
|
else request.query_params.get("chamber", "senate")
|
||||||
|
)
|
||||||
|
chosen_congress = (
|
||||||
|
congress
|
||||||
|
if congress is not None
|
||||||
|
else _parse_int(request.query_params.get("congress"))
|
||||||
|
)
|
||||||
|
chosen_compare = (
|
||||||
|
compare
|
||||||
|
if compare is not None
|
||||||
|
else [
|
||||||
|
value
|
||||||
|
for value in (
|
||||||
|
_parse_int(raw) for raw in request.query_params.getlist("compare")
|
||||||
|
)
|
||||||
|
if value is not None
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for issue in chosen_issues:
|
||||||
|
params.append(("issues", issue))
|
||||||
|
params.append(("chamber", chosen_chamber))
|
||||||
|
if chosen_congress is not None:
|
||||||
|
params.append(("congress", str(chosen_congress)))
|
||||||
|
for legislator_id in chosen_compare:
|
||||||
|
params.append(("compare", str(legislator_id)))
|
||||||
|
if not params:
|
||||||
|
return "/dashboard"
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
return f"/dashboard?{urlencode(params, doseq=True)}"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_dashboard_partial_url(
|
||||||
|
request: Request,
|
||||||
|
*,
|
||||||
|
issues: list[str] | None = None,
|
||||||
|
chamber: str | None = None,
|
||||||
|
congress: int | None = None,
|
||||||
|
compare: list[int] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Return the HTMX endpoint matching the current dashboard query state."""
|
||||||
|
dashboard_url = _build_url(
|
||||||
|
request,
|
||||||
|
issues=issues,
|
||||||
|
chamber=chamber,
|
||||||
|
congress=congress,
|
||||||
|
compare=compare,
|
||||||
|
)
|
||||||
|
return dashboard_url.replace("/dashboard", "/partials/dashboard", 1)
|
||||||
|
|
||||||
|
|
||||||
|
def _toggle_compare(compare: list[int], legislator_id: int) -> list[int]:
|
||||||
|
"""Return compare IDs with the legislator added or removed."""
|
||||||
|
if legislator_id in compare:
|
||||||
|
return [value for value in compare if value != legislator_id]
|
||||||
|
return [*compare, legislator_id]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_legislator_url(
|
||||||
|
*,
|
||||||
|
legislator_id: int | None = None,
|
||||||
|
q: str | None = None,
|
||||||
|
topic: str | None = None,
|
||||||
|
per_page: int | None = None,
|
||||||
|
) -> str:
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
params: list[tuple[str, str]] = []
|
||||||
|
if legislator_id is not None:
|
||||||
|
params.append(("legislator_id", str(legislator_id)))
|
||||||
|
if q:
|
||||||
|
params.append(("q", q))
|
||||||
|
if topic:
|
||||||
|
params.append(("topic", topic))
|
||||||
|
if per_page in {10, 25, 50} and per_page != 10:
|
||||||
|
params.append(("per_page", str(per_page)))
|
||||||
|
return f"/legislators?{urlencode(params)}" if params else "/legislators"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_legislator_search_url(
|
||||||
|
*,
|
||||||
|
q: str,
|
||||||
|
per_page: int,
|
||||||
|
page: int = 1,
|
||||||
|
) -> str:
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
params: list[tuple[str, str]] = []
|
||||||
|
if q:
|
||||||
|
params.append(("q", q))
|
||||||
|
params.append(("per_page", str(per_page)))
|
||||||
|
if page > 1:
|
||||||
|
params.append(("page", str(page)))
|
||||||
|
return f"/legislators?{urlencode(params)}" if params else "/legislators"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_compare_url(
|
||||||
|
*,
|
||||||
|
legislator_ids: list[int],
|
||||||
|
topics: list[str],
|
||||||
|
q: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
params: list[tuple[str, str]] = []
|
||||||
|
for legislator_id in legislator_ids[:4]:
|
||||||
|
params.append(("legislator_id", str(legislator_id)))
|
||||||
|
for topic in topics[:8]:
|
||||||
|
params.append(("topic", topic))
|
||||||
|
if q:
|
||||||
|
params.append(("q", q))
|
||||||
|
return f"/compare?{urlencode(params, doseq=True)}" if params else "/compare"
|
||||||
|
|
||||||
|
|
||||||
|
def _auth_context(current_user: auth.AuthSession | None) -> dict[str, Any]:
|
||||||
|
"""Shared template context for auth-aware navigation."""
|
||||||
|
return {
|
||||||
|
"is_authenticated": current_user is not None,
|
||||||
|
"is_admin": current_user.is_admin if current_user is not None else False,
|
||||||
|
"current_user_name": current_user.display_name if current_user is not None else "",
|
||||||
|
"current_user_email": current_user.email if current_user is not None else "",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _delete_auth_cookie(response: Response) -> None:
|
||||||
|
"""Delete the sealed WorkOS session cookie."""
|
||||||
|
response.delete_cookie(getenv("WORKOS_SESSION_COOKIE_NAME", "workos_session"))
|
||||||
@@ -0,0 +1,670 @@
|
|||||||
|
"""Congress database queries for the web dashboard."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import date
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from sqlalchemy import ColumnElement, Select, case, desc, false, func, or_, select, true
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from pipelines.orm.data_science_dev.congress import (
|
||||||
|
BillTopic,
|
||||||
|
Legislator,
|
||||||
|
LegislatorScore,
|
||||||
|
Vote,
|
||||||
|
)
|
||||||
|
from pipelines.web.scoring import normalize_issues
|
||||||
|
|
||||||
|
Chamber = Literal["house", "senate", "all"]
|
||||||
|
|
||||||
|
STATE_ALIASES = {
|
||||||
|
"alabama": "AL",
|
||||||
|
"alaska": "AK",
|
||||||
|
"arizona": "AZ",
|
||||||
|
"arkansas": "AR",
|
||||||
|
"california": "CA",
|
||||||
|
"colorado": "CO",
|
||||||
|
"connecticut": "CT",
|
||||||
|
"delaware": "DE",
|
||||||
|
"florida": "FL",
|
||||||
|
"georgia": "GA",
|
||||||
|
"hawaii": "HI",
|
||||||
|
"idaho": "ID",
|
||||||
|
"illinois": "IL",
|
||||||
|
"indiana": "IN",
|
||||||
|
"iowa": "IA",
|
||||||
|
"kansas": "KS",
|
||||||
|
"kentucky": "KY",
|
||||||
|
"louisiana": "LA",
|
||||||
|
"maine": "ME",
|
||||||
|
"maryland": "MD",
|
||||||
|
"massachusetts": "MA",
|
||||||
|
"michigan": "MI",
|
||||||
|
"minnesota": "MN",
|
||||||
|
"mississippi": "MS",
|
||||||
|
"missouri": "MO",
|
||||||
|
"montana": "MT",
|
||||||
|
"nebraska": "NE",
|
||||||
|
"nevada": "NV",
|
||||||
|
"new hampshire": "NH",
|
||||||
|
"new jersey": "NJ",
|
||||||
|
"new mexico": "NM",
|
||||||
|
"new york": "NY",
|
||||||
|
"north carolina": "NC",
|
||||||
|
"north dakota": "ND",
|
||||||
|
"ohio": "OH",
|
||||||
|
"oklahoma": "OK",
|
||||||
|
"oregon": "OR",
|
||||||
|
"pennsylvania": "PA",
|
||||||
|
"rhode island": "RI",
|
||||||
|
"south carolina": "SC",
|
||||||
|
"south dakota": "SD",
|
||||||
|
"tennessee": "TN",
|
||||||
|
"texas": "TX",
|
||||||
|
"utah": "UT",
|
||||||
|
"vermont": "VT",
|
||||||
|
"virginia": "VA",
|
||||||
|
"washington": "WA",
|
||||||
|
"west virginia": "WV",
|
||||||
|
"wisconsin": "WI",
|
||||||
|
"wyoming": "WY",
|
||||||
|
"district of columbia": "DC",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RankingRow:
|
||||||
|
"""A legislator support score row."""
|
||||||
|
|
||||||
|
legislator_id: int
|
||||||
|
display_name: str
|
||||||
|
party: str | None
|
||||||
|
state: str | None
|
||||||
|
chamber: str | None
|
||||||
|
score: float | None
|
||||||
|
supportive: int
|
||||||
|
opposed: int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total(self) -> int:
|
||||||
|
return self.supportive + self.opposed
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RankingResult:
|
||||||
|
"""Supportive and opposed ranking lists."""
|
||||||
|
|
||||||
|
supportive: list[RankingRow]
|
||||||
|
opposed: list[RankingRow]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TimePoint:
|
||||||
|
"""One yearly chart point."""
|
||||||
|
|
||||||
|
year: int
|
||||||
|
score: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ChartSeries:
|
||||||
|
"""One legislator score-history series."""
|
||||||
|
|
||||||
|
legislator_id: int
|
||||||
|
label: str
|
||||||
|
party: str | None
|
||||||
|
state: str | None
|
||||||
|
points: list[TimePoint]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TopicScore:
|
||||||
|
"""Average score for one topic."""
|
||||||
|
|
||||||
|
topic: str
|
||||||
|
score: float
|
||||||
|
count: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LegislatorOption:
|
||||||
|
"""Compact legislator metadata for search and comparison controls."""
|
||||||
|
|
||||||
|
legislator_id: int
|
||||||
|
display_name: str
|
||||||
|
party: str | None
|
||||||
|
state: str | None
|
||||||
|
chamber: str | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LegislatorProfile:
|
||||||
|
"""Legislator metadata plus issue score summary."""
|
||||||
|
|
||||||
|
legislator: LegislatorOption
|
||||||
|
overall_score: float | None
|
||||||
|
serving_since: int | None
|
||||||
|
top_topics: list[TopicScore]
|
||||||
|
bottom_topics: list[TopicScore]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RadarSeries:
|
||||||
|
"""One legislator polygon for the compare radar chart."""
|
||||||
|
|
||||||
|
legislator: LegislatorOption
|
||||||
|
average_score: float | None
|
||||||
|
scores_by_topic: dict[str, float]
|
||||||
|
|
||||||
|
|
||||||
|
def latest_congress(session: Session) -> int | None:
|
||||||
|
"""Return the latest congress number in the vote table."""
|
||||||
|
return session.scalar(select(func.max(Vote.congress)))
|
||||||
|
|
||||||
|
|
||||||
|
def latest_vote_date(session: Session, congress: int | None = None) -> date | None:
|
||||||
|
"""Return the most recent vote date, optionally scoped to a congress."""
|
||||||
|
stmt = select(func.max(Vote.vote_date))
|
||||||
|
if congress is not None:
|
||||||
|
stmt = stmt.where(Vote.congress == congress)
|
||||||
|
return session.scalar(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
def latest_score_year(session: Session) -> int | None:
|
||||||
|
"""Return the latest year in the precomputed legislator score table."""
|
||||||
|
return session.scalar(select(func.max(LegislatorScore.year)))
|
||||||
|
|
||||||
|
|
||||||
|
def has_scores(session: Session) -> bool:
|
||||||
|
"""Return True when the database has at least one precomputed score."""
|
||||||
|
return session.scalar(select(LegislatorScore.id).limit(1)) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def issue_suggestions(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
congress: int | None,
|
||||||
|
limit: int = 12,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Return common precomputed score topics for issue filter suggestions."""
|
||||||
|
stmt = (
|
||||||
|
select(LegislatorScore.topic, func.count(LegislatorScore.id).label("score_count"))
|
||||||
|
.where(LegislatorScore.topic != "")
|
||||||
|
.group_by(LegislatorScore.topic)
|
||||||
|
.order_by(desc("score_count"), LegislatorScore.topic)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
suggestions = [row[0] for row in session.execute(stmt).all()]
|
||||||
|
if suggestions:
|
||||||
|
return suggestions
|
||||||
|
|
||||||
|
fallback = (
|
||||||
|
select(BillTopic.topic, func.count(BillTopic.id).label("topic_count"))
|
||||||
|
.where(BillTopic.topic != "")
|
||||||
|
.group_by(BillTopic.topic)
|
||||||
|
.order_by(desc("topic_count"), BillTopic.topic)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
return [row[0] for row in session.execute(fallback).all()]
|
||||||
|
|
||||||
|
|
||||||
|
def ranking_query(
|
||||||
|
*,
|
||||||
|
issues: list[str],
|
||||||
|
chamber: Chamber,
|
||||||
|
congress: int,
|
||||||
|
) -> Select:
|
||||||
|
"""Build the aggregate ranking query from precomputed scores."""
|
||||||
|
average_score = func.avg(LegislatorScore.score).label("score")
|
||||||
|
supportive = func.sum(case((LegislatorScore.score >= 50, 1), else_=0)).label(
|
||||||
|
"supportive"
|
||||||
|
)
|
||||||
|
opposed = func.sum(case((LegislatorScore.score < 50, 1), else_=0)).label("opposed")
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(
|
||||||
|
Legislator.id,
|
||||||
|
Legislator.official_full_name,
|
||||||
|
Legislator.last_name,
|
||||||
|
Legislator.current_party,
|
||||||
|
Legislator.current_state,
|
||||||
|
Legislator.current_chamber,
|
||||||
|
average_score,
|
||||||
|
supportive,
|
||||||
|
opposed,
|
||||||
|
)
|
||||||
|
.join(LegislatorScore, LegislatorScore.legislator_id == Legislator.id)
|
||||||
|
.where(_score_topic_match_condition(issues))
|
||||||
|
.group_by(
|
||||||
|
Legislator.id,
|
||||||
|
Legislator.official_full_name,
|
||||||
|
Legislator.last_name,
|
||||||
|
Legislator.current_party,
|
||||||
|
Legislator.current_state,
|
||||||
|
Legislator.current_chamber,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if chamber != "all":
|
||||||
|
stmt = stmt.where(Legislator.current_chamber == _db_chamber(chamber))
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
|
def get_rankings(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
issues: list[str],
|
||||||
|
chamber: Chamber,
|
||||||
|
congress: int,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> RankingResult:
|
||||||
|
"""Return top supportive and opposed legislators from precomputed scores."""
|
||||||
|
rows = [
|
||||||
|
_ranking_row(row)
|
||||||
|
for row in session.execute(
|
||||||
|
ranking_query(issues=issues, chamber=chamber, congress=congress)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
scored = [row for row in rows if row.score is not None]
|
||||||
|
supportive = sorted(
|
||||||
|
scored, key=lambda row: (-float(row.score), -row.total, row.display_name)
|
||||||
|
)[:limit]
|
||||||
|
opposed = sorted(
|
||||||
|
scored, key=lambda row: (float(row.score), -row.total, row.display_name)
|
||||||
|
)[:limit]
|
||||||
|
return RankingResult(supportive=supportive, opposed=opposed)
|
||||||
|
|
||||||
|
|
||||||
|
def get_score_history(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
issues: list[str],
|
||||||
|
chamber: Chamber,
|
||||||
|
congress: int,
|
||||||
|
legislator_ids: list[int],
|
||||||
|
) -> list[ChartSeries]:
|
||||||
|
"""Return yearly score history from precomputed scores."""
|
||||||
|
if not legislator_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
average_score = func.avg(LegislatorScore.score).label("score")
|
||||||
|
stmt = (
|
||||||
|
select(
|
||||||
|
Legislator.id,
|
||||||
|
Legislator.official_full_name,
|
||||||
|
Legislator.last_name,
|
||||||
|
Legislator.current_party,
|
||||||
|
Legislator.current_state,
|
||||||
|
LegislatorScore.year,
|
||||||
|
average_score,
|
||||||
|
)
|
||||||
|
.join(LegislatorScore, LegislatorScore.legislator_id == Legislator.id)
|
||||||
|
.where(
|
||||||
|
Legislator.id.in_(legislator_ids),
|
||||||
|
_score_topic_match_condition(issues),
|
||||||
|
)
|
||||||
|
.group_by(
|
||||||
|
Legislator.id,
|
||||||
|
Legislator.official_full_name,
|
||||||
|
Legislator.last_name,
|
||||||
|
Legislator.current_party,
|
||||||
|
Legislator.current_state,
|
||||||
|
LegislatorScore.year,
|
||||||
|
)
|
||||||
|
.order_by(Legislator.id, LegislatorScore.year)
|
||||||
|
)
|
||||||
|
if chamber != "all":
|
||||||
|
stmt = stmt.where(Legislator.current_chamber == _db_chamber(chamber))
|
||||||
|
|
||||||
|
by_legislator: dict[int, ChartSeries] = {}
|
||||||
|
for row in session.execute(stmt):
|
||||||
|
if row.score is None:
|
||||||
|
continue
|
||||||
|
series = by_legislator.setdefault(
|
||||||
|
row.id,
|
||||||
|
ChartSeries(
|
||||||
|
legislator_id=row.id,
|
||||||
|
label=_display_name(row.official_full_name, row.last_name),
|
||||||
|
party=row.current_party,
|
||||||
|
state=row.current_state,
|
||||||
|
points=[],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
series.points.append(TimePoint(year=int(row.year), score=float(row.score)))
|
||||||
|
return list(by_legislator.values())
|
||||||
|
|
||||||
|
|
||||||
|
def _ranking_row(row: object) -> RankingRow:
|
||||||
|
return RankingRow(
|
||||||
|
legislator_id=row.id,
|
||||||
|
display_name=_display_name(row.official_full_name, row.last_name),
|
||||||
|
party=row.current_party,
|
||||||
|
state=row.current_state,
|
||||||
|
chamber=row.current_chamber,
|
||||||
|
score=float(row.score) if row.score is not None else None,
|
||||||
|
supportive=row.supportive or 0,
|
||||||
|
opposed=row.opposed or 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _score_topic_match_condition(
|
||||||
|
issues: list[str] | tuple[str, ...],
|
||||||
|
) -> ColumnElement[bool]:
|
||||||
|
normalized = normalize_issues(list(issues))
|
||||||
|
if not normalized:
|
||||||
|
return false()
|
||||||
|
return or_(*(LegislatorScore.topic.ilike(f"%{issue}%") for issue in normalized))
|
||||||
|
|
||||||
|
|
||||||
|
def search_legislators(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
query: str | None = None,
|
||||||
|
limit: int = 12,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[LegislatorOption]:
|
||||||
|
"""Search ingested legislators, preferring those with computed scores."""
|
||||||
|
return [
|
||||||
|
_legislator_option(row)
|
||||||
|
for row in session.execute(
|
||||||
|
legislator_search_query(query=query, limit=limit, offset=offset)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def count_legislators(session: Session, *, query: str | None = None) -> int:
|
||||||
|
"""Return the total number of legislators matching a search query."""
|
||||||
|
return int(
|
||||||
|
session.scalar(
|
||||||
|
select(func.count(Legislator.id)).where(_legislator_search_condition(query))
|
||||||
|
)
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_legislator_options(
|
||||||
|
session: Session, legislator_ids: list[int]
|
||||||
|
) -> list[LegislatorOption]:
|
||||||
|
"""Return legislator options in the same order as the selected IDs."""
|
||||||
|
options = {
|
||||||
|
option.legislator_id: option
|
||||||
|
for option in (
|
||||||
|
_get_legislator_option(session, legislator_id)
|
||||||
|
for legislator_id in legislator_ids
|
||||||
|
)
|
||||||
|
if option is not None
|
||||||
|
}
|
||||||
|
return [
|
||||||
|
options[legislator_id]
|
||||||
|
for legislator_id in legislator_ids
|
||||||
|
if legislator_id in options
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def legislator_search_query(
|
||||||
|
*,
|
||||||
|
query: str | None = None,
|
||||||
|
limit: int = 12,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> Select:
|
||||||
|
"""Build the legislator search query used by profile and compare controls."""
|
||||||
|
score_count = func.count(LegislatorScore.id).label("score_count")
|
||||||
|
stmt = (
|
||||||
|
select(
|
||||||
|
Legislator.id,
|
||||||
|
Legislator.official_full_name,
|
||||||
|
Legislator.last_name,
|
||||||
|
Legislator.current_party,
|
||||||
|
Legislator.current_state,
|
||||||
|
Legislator.current_chamber,
|
||||||
|
score_count,
|
||||||
|
)
|
||||||
|
.outerjoin(LegislatorScore, LegislatorScore.legislator_id == Legislator.id)
|
||||||
|
.group_by(
|
||||||
|
Legislator.id,
|
||||||
|
Legislator.official_full_name,
|
||||||
|
Legislator.first_name,
|
||||||
|
Legislator.last_name,
|
||||||
|
Legislator.current_party,
|
||||||
|
Legislator.current_state,
|
||||||
|
Legislator.current_chamber,
|
||||||
|
Legislator.bioguide_id,
|
||||||
|
)
|
||||||
|
.order_by(desc("score_count"), Legislator.last_name, Legislator.first_name)
|
||||||
|
.limit(limit)
|
||||||
|
.offset(offset)
|
||||||
|
)
|
||||||
|
return stmt.where(_legislator_search_condition(query))
|
||||||
|
|
||||||
|
|
||||||
|
def _legislator_search_condition(query: str | None) -> ColumnElement[bool]:
|
||||||
|
cleaned_query = query.strip() if query else ""
|
||||||
|
if not cleaned_query:
|
||||||
|
return true()
|
||||||
|
|
||||||
|
pattern = f"%{cleaned_query}%"
|
||||||
|
state_alias = _state_alias(cleaned_query)
|
||||||
|
conditions: list[ColumnElement[bool]] = [
|
||||||
|
Legislator.official_full_name.ilike(pattern),
|
||||||
|
Legislator.first_name.ilike(pattern),
|
||||||
|
Legislator.last_name.ilike(pattern),
|
||||||
|
Legislator.current_state.ilike(pattern),
|
||||||
|
Legislator.bioguide_id.ilike(pattern),
|
||||||
|
]
|
||||||
|
if state_alias is not None:
|
||||||
|
conditions.append(Legislator.current_state == state_alias)
|
||||||
|
return or_(*conditions)
|
||||||
|
|
||||||
|
|
||||||
|
def _state_alias(query: str) -> str | None:
|
||||||
|
normalized = " ".join(query.lower().replace(".", "").split())
|
||||||
|
if len(normalized) == 2 and normalized.isalpha():
|
||||||
|
return normalized.upper()
|
||||||
|
return STATE_ALIASES.get(normalized)
|
||||||
|
|
||||||
|
|
||||||
|
def get_legislator_profile(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
legislator_id: int | None = None,
|
||||||
|
query: str | None = None,
|
||||||
|
) -> LegislatorProfile | None:
|
||||||
|
"""Return the selected legislator profile and top/bottom topic scores."""
|
||||||
|
selected = _get_legislator_option(session, legislator_id)
|
||||||
|
cleaned_query = query.strip() if query else ""
|
||||||
|
if selected is None and cleaned_query:
|
||||||
|
matches = search_legislators(session, query=query, limit=1)
|
||||||
|
selected = matches[0] if matches else None
|
||||||
|
if selected is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
topic_scores = get_legislator_topic_scores(
|
||||||
|
session, legislator_id=selected.legislator_id
|
||||||
|
)
|
||||||
|
top_topics = sorted(topic_scores, key=lambda item: (-item.score, item.topic))[:3]
|
||||||
|
bottom_topics = sorted(topic_scores, key=lambda item: (item.score, item.topic))[:3]
|
||||||
|
overall_score = session.scalar(
|
||||||
|
select(func.avg(LegislatorScore.score)).where(
|
||||||
|
LegislatorScore.legislator_id == selected.legislator_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
serving_since = session.scalar(
|
||||||
|
select(func.min(LegislatorScore.year)).where(
|
||||||
|
LegislatorScore.legislator_id == selected.legislator_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return LegislatorProfile(
|
||||||
|
legislator=selected,
|
||||||
|
overall_score=float(overall_score) if overall_score is not None else None,
|
||||||
|
serving_since=int(serving_since) if serving_since is not None else None,
|
||||||
|
top_topics=top_topics,
|
||||||
|
bottom_topics=bottom_topics,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_legislator_topic_scores(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
legislator_id: int,
|
||||||
|
) -> list[TopicScore]:
|
||||||
|
"""Return all average topic scores for one legislator."""
|
||||||
|
rows = session.execute(
|
||||||
|
select(
|
||||||
|
LegislatorScore.topic,
|
||||||
|
func.avg(LegislatorScore.score).label("score"),
|
||||||
|
func.count(LegislatorScore.id).label("count"),
|
||||||
|
)
|
||||||
|
.where(LegislatorScore.legislator_id == legislator_id)
|
||||||
|
.group_by(LegislatorScore.topic)
|
||||||
|
.order_by(LegislatorScore.topic)
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
TopicScore(topic=row.topic, score=float(row.score), count=row.count)
|
||||||
|
for row in rows
|
||||||
|
if row.score is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_single_legislator_history(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
legislator_id: int,
|
||||||
|
topic: str,
|
||||||
|
) -> list[ChartSeries]:
|
||||||
|
"""Return score history for one legislator/topic pair."""
|
||||||
|
option = _get_legislator_option(session, legislator_id)
|
||||||
|
if option is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
rows = session.execute(
|
||||||
|
select(
|
||||||
|
LegislatorScore.year,
|
||||||
|
func.avg(LegislatorScore.score).label("score"),
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
LegislatorScore.legislator_id == legislator_id,
|
||||||
|
LegislatorScore.topic == topic,
|
||||||
|
)
|
||||||
|
.group_by(LegislatorScore.year)
|
||||||
|
.order_by(LegislatorScore.year)
|
||||||
|
)
|
||||||
|
points = [
|
||||||
|
TimePoint(year=int(row.year), score=float(row.score))
|
||||||
|
for row in rows
|
||||||
|
if row.score is not None
|
||||||
|
]
|
||||||
|
return [
|
||||||
|
ChartSeries(
|
||||||
|
legislator_id=option.legislator_id,
|
||||||
|
label=option.display_name,
|
||||||
|
party=option.party,
|
||||||
|
state=option.state,
|
||||||
|
points=points,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_compare_defaults(session: Session) -> tuple[list[int], list[str]]:
|
||||||
|
"""Return default compare legislators and topics."""
|
||||||
|
legislators = search_legislators(session, limit=3)
|
||||||
|
topics = issue_suggestions(session, congress=None, limit=6)
|
||||||
|
return [item.legislator_id for item in legislators], topics
|
||||||
|
|
||||||
|
|
||||||
|
def get_compare_radar_series(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
legislator_ids: list[int],
|
||||||
|
topics: list[str],
|
||||||
|
) -> list[RadarSeries]:
|
||||||
|
"""Return radar chart scores for selected legislators and topics."""
|
||||||
|
if not legislator_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
options = {
|
||||||
|
option.legislator_id: option
|
||||||
|
for option in (
|
||||||
|
_get_legislator_option(session, legislator_id)
|
||||||
|
for legislator_id in legislator_ids
|
||||||
|
)
|
||||||
|
if option is not None
|
||||||
|
}
|
||||||
|
if not options:
|
||||||
|
return []
|
||||||
|
|
||||||
|
scores: dict[int, dict[str, float]] = {
|
||||||
|
legislator_id: {} for legislator_id in options
|
||||||
|
}
|
||||||
|
if topics:
|
||||||
|
rows = session.execute(
|
||||||
|
select(
|
||||||
|
LegislatorScore.legislator_id,
|
||||||
|
LegislatorScore.topic,
|
||||||
|
func.avg(LegislatorScore.score).label("score"),
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
LegislatorScore.legislator_id.in_(list(options)),
|
||||||
|
LegislatorScore.topic.in_(topics),
|
||||||
|
)
|
||||||
|
.group_by(LegislatorScore.legislator_id, LegislatorScore.topic)
|
||||||
|
)
|
||||||
|
for row in rows:
|
||||||
|
scores[row.legislator_id][row.topic] = float(row.score)
|
||||||
|
|
||||||
|
series: list[RadarSeries] = []
|
||||||
|
for legislator_id in legislator_ids:
|
||||||
|
option = options.get(legislator_id)
|
||||||
|
if option is None:
|
||||||
|
continue
|
||||||
|
topic_scores = scores.get(legislator_id, {})
|
||||||
|
values = list(topic_scores.values())
|
||||||
|
series.append(
|
||||||
|
RadarSeries(
|
||||||
|
legislator=option,
|
||||||
|
average_score=sum(values) / len(values) if values else None,
|
||||||
|
scores_by_topic=topic_scores,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return series
|
||||||
|
|
||||||
|
|
||||||
|
def _display_name(official_full_name: str | None, last_name: str | None) -> str:
|
||||||
|
if official_full_name:
|
||||||
|
parts = official_full_name.split()
|
||||||
|
if len(parts) > 1:
|
||||||
|
return f"{parts[-1]}, {' '.join(parts[:-1])}"
|
||||||
|
return official_full_name
|
||||||
|
return last_name or "Unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def _legislator_option(row: object) -> LegislatorOption:
|
||||||
|
return LegislatorOption(
|
||||||
|
legislator_id=row.id,
|
||||||
|
display_name=_display_name(row.official_full_name, row.last_name),
|
||||||
|
party=row.current_party,
|
||||||
|
state=row.current_state,
|
||||||
|
chamber=row.current_chamber,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_legislator_option(
|
||||||
|
session: Session, legislator_id: int | None
|
||||||
|
) -> LegislatorOption | None:
|
||||||
|
if legislator_id is None:
|
||||||
|
return None
|
||||||
|
row = session.execute(
|
||||||
|
select(
|
||||||
|
Legislator.id,
|
||||||
|
Legislator.official_full_name,
|
||||||
|
Legislator.last_name,
|
||||||
|
Legislator.current_party,
|
||||||
|
Legislator.current_state,
|
||||||
|
Legislator.current_chamber,
|
||||||
|
).where(Legislator.id == legislator_id)
|
||||||
|
).first()
|
||||||
|
return _legislator_option(row) if row is not None else None
|
||||||
|
|
||||||
|
|
||||||
|
def _db_chamber(chamber: Chamber) -> str:
|
||||||
|
return {"house": "House", "senate": "Senate", "all": "all"}[chamber]
|
||||||
@@ -0,0 +1,100 @@
|
|||||||
|
"""Issue matching and voting score helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from sqlalchemy import ColumnElement, false, func, or_
|
||||||
|
from sqlalchemy.sql.elements import BinaryExpression
|
||||||
|
|
||||||
|
from pipelines.orm.data_science_dev.congress import Bill, BillTopicPosition, Vote
|
||||||
|
|
||||||
|
SUPPORT_POSITIONS = frozenset({"yea", "aye", "yes"})
|
||||||
|
OPPOSE_POSITIONS = frozenset({"nay", "no"})
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ScoreCounts:
|
||||||
|
"""Support/opposition counts for one legislator or time bucket."""
|
||||||
|
|
||||||
|
supportive: int
|
||||||
|
opposed: int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total(self) -> int:
|
||||||
|
return self.supportive + self.opposed
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_position(position: str | None) -> str | None:
|
||||||
|
"""Normalize a raw roll-call position into support/oppose/ignore buckets."""
|
||||||
|
if position is None:
|
||||||
|
return None
|
||||||
|
value = position.strip().lower()
|
||||||
|
if value in SUPPORT_POSITIONS:
|
||||||
|
return "support"
|
||||||
|
if value in OPPOSE_POSITIONS:
|
||||||
|
return "oppose"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def score_vote_position(
|
||||||
|
position: str | None,
|
||||||
|
support_position: BillTopicPosition | str,
|
||||||
|
) -> str | None:
|
||||||
|
"""Score a raw vote as support/opposition for an extracted bill topic."""
|
||||||
|
normalized_vote = normalize_position(position)
|
||||||
|
if normalized_vote is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
topic_position = BillTopicPosition(support_position)
|
||||||
|
if topic_position is BillTopicPosition.FOR:
|
||||||
|
return normalized_vote
|
||||||
|
if normalized_vote == "support":
|
||||||
|
return "oppose"
|
||||||
|
return "support"
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_score(counts: ScoreCounts) -> int | None:
|
||||||
|
"""Calculate the 0-100 support score, or None when there are no scored votes."""
|
||||||
|
if counts.total == 0:
|
||||||
|
return None
|
||||||
|
return round(100 * counts.supportive / counts.total)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_issues(issues: list[str] | tuple[str, ...]) -> list[str]:
|
||||||
|
"""Trim, de-duplicate, and preserve issue order for display and queries."""
|
||||||
|
normalized: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for issue in issues:
|
||||||
|
value = issue.strip()
|
||||||
|
key = value.casefold()
|
||||||
|
if value and key not in seen:
|
||||||
|
normalized.append(value)
|
||||||
|
seen.add(key)
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def issue_match_condition(issues: list[str] | tuple[str, ...]) -> ColumnElement[bool]:
|
||||||
|
"""Build the SQLAlchemy condition for issue text matching."""
|
||||||
|
normalized = normalize_issues(list(issues))
|
||||||
|
if not normalized:
|
||||||
|
return false()
|
||||||
|
|
||||||
|
fields: tuple[ColumnElement[str | None], ...] = (
|
||||||
|
Bill.subjects_top_term,
|
||||||
|
Bill.title,
|
||||||
|
Bill.title_short,
|
||||||
|
Bill.official_title,
|
||||||
|
Vote.question,
|
||||||
|
Vote.result_text,
|
||||||
|
)
|
||||||
|
terms: list[BinaryExpression[bool]] = []
|
||||||
|
for issue in normalized:
|
||||||
|
pattern = f"%{issue}%"
|
||||||
|
terms.extend(field.ilike(pattern) for field in fields)
|
||||||
|
return or_(*terms)
|
||||||
|
|
||||||
|
|
||||||
|
def normalized_position_expression(column: ColumnElement[str]) -> ColumnElement[str | None]:
|
||||||
|
"""Lowercase and trim a SQL column containing raw vote positions."""
|
||||||
|
return func.lower(func.trim(column))
|
||||||
File diff suppressed because it is too large
Load Diff
+1
File diff suppressed because one or more lines are too long
@@ -0,0 +1,231 @@
|
|||||||
|
"""Inline SVG rendering helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from html import escape
|
||||||
|
from math import cos, pi, sin
|
||||||
|
|
||||||
|
from pipelines.web.repository import ChartSeries, RadarSeries
|
||||||
|
|
||||||
|
SERIES_STYLES = (
|
||||||
|
{
|
||||||
|
"color": "#009e73",
|
||||||
|
"dasharray": None,
|
||||||
|
"marker": "circle",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"color": "#0072b2",
|
||||||
|
"dasharray": "10 6",
|
||||||
|
"marker": "square",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"color": "#e69f00",
|
||||||
|
"dasharray": "4 5",
|
||||||
|
"marker": "diamond",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"color": "#cc79a7",
|
||||||
|
"dasharray": "14 5 3 5",
|
||||||
|
"marker": "triangle",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def render_score_history_svg(series: list[ChartSeries]) -> str:
|
||||||
|
"""Render a responsive inline SVG score history chart."""
|
||||||
|
width = 880
|
||||||
|
height = 300
|
||||||
|
margin_left = 70
|
||||||
|
margin_right = 28
|
||||||
|
margin_top = 24
|
||||||
|
margin_bottom = 48
|
||||||
|
plot_width = width - margin_left - margin_right
|
||||||
|
plot_height = height - margin_top - margin_bottom
|
||||||
|
|
||||||
|
all_years = sorted({point.year for item in series for point in item.points})
|
||||||
|
if not all_years:
|
||||||
|
return _empty_svg(width, height, "No score history for this selection")
|
||||||
|
|
||||||
|
min_year = min(all_years)
|
||||||
|
max_year = max(all_years)
|
||||||
|
year_span = max(max_year - min_year, 1)
|
||||||
|
|
||||||
|
def x_for(year: int) -> float:
|
||||||
|
return margin_left + ((year - min_year) / year_span) * plot_width
|
||||||
|
|
||||||
|
def y_for(score: int) -> float:
|
||||||
|
return margin_top + ((100 - score) / 100) * plot_height
|
||||||
|
|
||||||
|
parts: list[str] = [
|
||||||
|
f'<svg viewBox="0 0 {width} {height}" role="img" aria-label="Score history chart" class="score-chart">',
|
||||||
|
'<rect width="100%" height="100%" fill="transparent" />',
|
||||||
|
]
|
||||||
|
|
||||||
|
for score in (0, 25, 50, 75, 100):
|
||||||
|
y = y_for(score)
|
||||||
|
parts.append(
|
||||||
|
f'<line x1="{margin_left}" y1="{y:.2f}" x2="{width - margin_right}" y2="{y:.2f}" class="chart-grid" />'
|
||||||
|
)
|
||||||
|
parts.append(
|
||||||
|
f'<text x="{margin_left - 16}" y="{y + 4:.2f}" text-anchor="end" class="chart-axis-label">{score}</text>'
|
||||||
|
)
|
||||||
|
|
||||||
|
tick_years = _tick_years(all_years)
|
||||||
|
for year in tick_years:
|
||||||
|
x = x_for(year)
|
||||||
|
parts.append(
|
||||||
|
f'<line x1="{x:.2f}" y1="{margin_top}" x2="{x:.2f}" y2="{height - margin_bottom}" class="chart-year-line" />'
|
||||||
|
)
|
||||||
|
parts.append(
|
||||||
|
f'<text x="{x:.2f}" y="{height - 18}" text-anchor="middle" class="chart-axis-label">{year}</text>'
|
||||||
|
)
|
||||||
|
|
||||||
|
parts.append(
|
||||||
|
f'<line x1="{margin_left}" y1="{height - margin_bottom}" x2="{width - margin_right}" y2="{height - margin_bottom}" class="chart-axis" />'
|
||||||
|
)
|
||||||
|
parts.append(
|
||||||
|
f'<line x1="{margin_left}" y1="{margin_top}" x2="{margin_left}" y2="{height - margin_bottom}" class="chart-axis" />'
|
||||||
|
)
|
||||||
|
|
||||||
|
for index, item in enumerate(series):
|
||||||
|
points = sorted(item.points, key=lambda point: point.year)
|
||||||
|
if not points:
|
||||||
|
continue
|
||||||
|
style = SERIES_STYLES[index % len(SERIES_STYLES)]
|
||||||
|
color = style["color"]
|
||||||
|
path = " ".join(
|
||||||
|
f"{'M' if point_index == 0 else 'L'} {x_for(point.year):.2f} {y_for(point.score):.2f}"
|
||||||
|
for point_index, point in enumerate(points)
|
||||||
|
)
|
||||||
|
label = escape(item.label)
|
||||||
|
dash_attr = (
|
||||||
|
f' stroke-dasharray="{style["dasharray"]}"'
|
||||||
|
if style["dasharray"]
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
parts.append(
|
||||||
|
f'<path d="{path}" fill="none" stroke="{color}" stroke-width="3.5" stroke-linecap="round" stroke-linejoin="round"{dash_attr}>'
|
||||||
|
f"<title>{label}</title></path>"
|
||||||
|
)
|
||||||
|
for point in points:
|
||||||
|
parts.append(
|
||||||
|
_point_marker(
|
||||||
|
marker=style["marker"],
|
||||||
|
x=x_for(point.year),
|
||||||
|
y=y_for(point.score),
|
||||||
|
color=color,
|
||||||
|
label=f"{label}: {point.score:.0f} in {point.year}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
last = points[-1]
|
||||||
|
parts.append(
|
||||||
|
f'<text x="{x_for(last.year) - 10:.2f}" y="{y_for(last.score) + 4:.2f}" text-anchor="end" class="chart-series-label" fill="{color}">'
|
||||||
|
f"{last.score:.0f}</text>"
|
||||||
|
)
|
||||||
|
|
||||||
|
parts.append("</svg>")
|
||||||
|
return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _empty_svg(width: int, height: int, message: str) -> str:
|
||||||
|
return (
|
||||||
|
f'<svg viewBox="0 0 {width} {height}" role="img" aria-label="{escape(message)}" class="score-chart">'
|
||||||
|
'<rect width="100%" height="100%" fill="transparent" />'
|
||||||
|
f'<text x="{width / 2}" y="{height / 2}" text-anchor="middle" class="chart-empty">{escape(message)}</text>'
|
||||||
|
"</svg>"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _tick_years(years: list[int]) -> list[int]:
|
||||||
|
first = years[0]
|
||||||
|
last = years[-1]
|
||||||
|
start = first - (first % 5)
|
||||||
|
tick_years = {year for year in range(start, last + 1, 5) if first <= year <= last}
|
||||||
|
tick_years.add(first)
|
||||||
|
tick_years.add(last)
|
||||||
|
return sorted(tick_years)
|
||||||
|
|
||||||
|
|
||||||
|
def render_compare_radar_svg(topics: list[str], series: list[RadarSeries]) -> str:
|
||||||
|
"""Render a server-side radar chart for legislator comparison."""
|
||||||
|
width = 720
|
||||||
|
height = 560
|
||||||
|
center_x = 285
|
||||||
|
center_y = 280
|
||||||
|
radius = 200
|
||||||
|
if len(topics) < 3 or not series:
|
||||||
|
return _empty_svg(width, height, "Choose at least 3 axes and 1 legislator")
|
||||||
|
|
||||||
|
axis_count = len(topics)
|
||||||
|
|
||||||
|
def point_for(index: int, score: float) -> tuple[float, float]:
|
||||||
|
angle = -pi / 2 + (2 * pi * index / axis_count)
|
||||||
|
distance = radius * max(0, min(score, 100)) / 100
|
||||||
|
return center_x + cos(angle) * distance, center_y + sin(angle) * distance
|
||||||
|
|
||||||
|
def ring_points(score: float) -> str:
|
||||||
|
return " ".join(
|
||||||
|
f"{point_for(index, score)[0]:.2f},{point_for(index, score)[1]:.2f}"
|
||||||
|
for index in range(axis_count)
|
||||||
|
)
|
||||||
|
|
||||||
|
parts: list[str] = [
|
||||||
|
f'<svg viewBox="0 0 {width} {height}" role="img" aria-label="Compare legislators radar chart" class="radar-chart">',
|
||||||
|
'<rect width="100%" height="100%" fill="transparent" />',
|
||||||
|
]
|
||||||
|
for ring in (25, 50, 75, 100):
|
||||||
|
parts.append(f'<polygon points="{ring_points(ring)}" class="radar-ring" />')
|
||||||
|
for index, topic in enumerate(topics):
|
||||||
|
outer_x, outer_y = point_for(index, 100)
|
||||||
|
label_x, label_y = point_for(index, 113)
|
||||||
|
parts.append(
|
||||||
|
f'<line x1="{center_x}" y1="{center_y}" x2="{outer_x:.2f}" y2="{outer_y:.2f}" class="radar-axis" />'
|
||||||
|
)
|
||||||
|
anchor = "middle"
|
||||||
|
if label_x < center_x - 24:
|
||||||
|
anchor = "end"
|
||||||
|
elif label_x > center_x + 24:
|
||||||
|
anchor = "start"
|
||||||
|
parts.append(
|
||||||
|
f'<text x="{label_x:.2f}" y="{label_y:.2f}" text-anchor="{anchor}" class="radar-label">{escape(topic)}</text>'
|
||||||
|
)
|
||||||
|
|
||||||
|
for index, item in enumerate(series):
|
||||||
|
color = SERIES_STYLES[index % len(SERIES_STYLES)]["color"]
|
||||||
|
points = " ".join(
|
||||||
|
f"{point_for(topic_index, item.scores_by_topic.get(topic, 50.0))[0]:.2f},"
|
||||||
|
f"{point_for(topic_index, item.scores_by_topic.get(topic, 50.0))[1]:.2f}"
|
||||||
|
for topic_index, topic in enumerate(topics)
|
||||||
|
)
|
||||||
|
label = escape(item.legislator.display_name)
|
||||||
|
parts.append(
|
||||||
|
f'<polygon points="{points}" fill="{color}" fill-opacity="0.14" stroke="{color}" stroke-width="3" class="radar-series">'
|
||||||
|
f"<title>{label}</title></polygon>"
|
||||||
|
)
|
||||||
|
parts.append("</svg>")
|
||||||
|
return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _point_marker(*, marker: str, x: float, y: float, color: str, label: str) -> str:
|
||||||
|
title = f"<title>{escape(label)}</title>"
|
||||||
|
if marker == "square":
|
||||||
|
return (
|
||||||
|
f'<rect x="{x - 4.25:.2f}" y="{y - 4.25:.2f}" width="8.5" height="8.5" '
|
||||||
|
f'fill="{color}" rx="1.5" ry="1.5">{title}</rect>'
|
||||||
|
)
|
||||||
|
if marker == "diamond":
|
||||||
|
points = (
|
||||||
|
f"{x:.2f},{y - 5.2:.2f} "
|
||||||
|
f"{x + 5.2:.2f},{y:.2f} "
|
||||||
|
f"{x:.2f},{y + 5.2:.2f} "
|
||||||
|
f"{x - 5.2:.2f},{y:.2f}"
|
||||||
|
)
|
||||||
|
return f'<polygon points="{points}" fill="{color}">{title}</polygon>'
|
||||||
|
if marker == "triangle":
|
||||||
|
points = (
|
||||||
|
f"{x:.2f},{y - 5.5:.2f} "
|
||||||
|
f"{x + 5.5:.2f},{y + 4.5:.2f} "
|
||||||
|
f"{x - 5.5:.2f},{y + 4.5:.2f}"
|
||||||
|
)
|
||||||
|
return f'<polygon points="{points}" fill="{color}">{title}</polygon>'
|
||||||
|
return f'<circle cx="{x:.2f}" cy="{y:.2f}" r="4.5" fill="{color}">{title}</circle>'
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
{% extends "base.html" %}
|
||||||
|
|
||||||
|
{% block title %}Admin Settings{% endblock %}
|
||||||
|
|
||||||
|
{% block body %}
|
||||||
|
<main class="shell">
|
||||||
|
<section class="page-heading stacked-heading">
|
||||||
|
<div>
|
||||||
|
<h1>Admin settings</h1>
|
||||||
|
<p>Admin-only operational controls for the Nornsight workspace.</p>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<section class="admin-card">
|
||||||
|
<h2>WorkOS-managed access</h2>
|
||||||
|
<p>
|
||||||
|
Invitations, Google access, and role assignments are managed in the WorkOS dashboard.
|
||||||
|
This page confirms that app-level admin gating is active.
|
||||||
|
</p>
|
||||||
|
<dl class="admin-meta">
|
||||||
|
<div>
|
||||||
|
<dt>Workspace organization</dt>
|
||||||
|
<dd><code>{{ organization_id }}</code></dd>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<dt>Current administrator</dt>
|
||||||
|
<dd>{{ current_user_email }}</dd>
|
||||||
|
</div>
|
||||||
|
</dl>
|
||||||
|
<div class="admin-actions">
|
||||||
|
<a href="/dashboard">Return to dashboard</a>
|
||||||
|
<a href="https://dashboard.workos.com/" rel="noreferrer" target="_blank">Open WorkOS dashboard</a>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
</main>
|
||||||
|
{% endblock %}
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
<!doctype html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||||
|
<title>{% block title %}Nornsight{% endblock %}</title>
|
||||||
|
<link rel="stylesheet" href="{{ url_for('static', path='styles.css') }}">
|
||||||
|
<script src="{{ url_for('static', path='vendor/htmx.min.js') }}" defer></script>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<header class="topbar">
|
||||||
|
<a class="brand" href="/">
|
||||||
|
<span class="brand-mark">N</span>
|
||||||
|
<span>Nornsight</span>
|
||||||
|
</a>
|
||||||
|
{% if show_primary_nav|default(true) %}
|
||||||
|
<nav class="primary-nav" aria-label="Primary">
|
||||||
|
{% if is_authenticated|default(false) %}
|
||||||
|
<a href="/dashboard">Dashboard</a>
|
||||||
|
<a href="/legislators">Legislators</a>
|
||||||
|
<a href="/compare">Compare</a>
|
||||||
|
{% if is_admin|default(false) %}
|
||||||
|
<a href="/admin">Admin</a>
|
||||||
|
{% endif %}
|
||||||
|
{% else %}
|
||||||
|
<a href="/">Overview</a>
|
||||||
|
{% endif %}
|
||||||
|
</nav>
|
||||||
|
{% endif %}
|
||||||
|
<nav class="account-nav" aria-label="Account">
|
||||||
|
<a href="#" aria-disabled="true">Help</a>
|
||||||
|
{% if is_authenticated|default(false) %}
|
||||||
|
<details class="account-menu">
|
||||||
|
<summary>{{ current_user_name or "My account" }}</summary>
|
||||||
|
<div class="account-menu-panel">
|
||||||
|
<span class="account-email">{{ current_user_email }}</span>
|
||||||
|
<a href="#" aria-disabled="true">Account settings</a>
|
||||||
|
<form action="/logout" method="post">
|
||||||
|
<button class="sign-out" type="submit">Sign out</button>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
{% else %}
|
||||||
|
<a class="sign-in" href="/login?next=/dashboard">Sign in</a>
|
||||||
|
{% endif %}
|
||||||
|
</nav>
|
||||||
|
</header>
|
||||||
|
{% block body %}{% endblock %}
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@@ -0,0 +1,87 @@
|
|||||||
|
{% extends "base.html" %}
|
||||||
|
|
||||||
|
{% block title %}Compare Legislators{% endblock %}
|
||||||
|
|
||||||
|
{% block body %}
|
||||||
|
<main class="shell">
|
||||||
|
<section class="page-heading stacked-heading">
|
||||||
|
<div>
|
||||||
|
<h1>Compare legislators</h1>
|
||||||
|
<p>Up to 4 legislators · up to 8 issue axes · each polygon = one legislator's full issue profile</p>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<section class="compare-controls">
|
||||||
|
<form class="wide-search compare-search" action="/compare" method="get">
|
||||||
|
<label class="sr-only" for="compare-legislator-search">Search legislators</label>
|
||||||
|
{% for legislator_id in selected_legislators %}
|
||||||
|
<input type="hidden" name="legislator_id" value="{{ legislator_id }}">
|
||||||
|
{% endfor %}
|
||||||
|
{% for topic in topics %}
|
||||||
|
<input type="hidden" name="topic" value="{{ topic }}">
|
||||||
|
{% endfor %}
|
||||||
|
<input
|
||||||
|
id="compare-legislator-search"
|
||||||
|
type="search"
|
||||||
|
name="q"
|
||||||
|
value="{{ q }}"
|
||||||
|
placeholder="Search legislators to add"
|
||||||
|
autocomplete="off">
|
||||||
|
<button type="submit">Search</button>
|
||||||
|
</form>
|
||||||
|
|
||||||
|
<h2>Legislators ({{ selected_legislator_options|length }} / 4)</h2>
|
||||||
|
<div class="result-chips">
|
||||||
|
{% for legislator in selected_legislator_options %}
|
||||||
|
{% set without = selected_legislators | reject('equalto', legislator.legislator_id) | list %}
|
||||||
|
<a href="{{ build_compare_url(legislator_ids=without, topics=topics, q=q) }}"><span class="legend-dot dot-{{ loop.index0 }}"></span>{{ legislator.display_name }}{% if legislator.state %} — {{ legislator.state }}{% endif %} ×</a>
|
||||||
|
{% endfor %}
|
||||||
|
{% if selected_legislator_options|length < 4 %}
|
||||||
|
{% for option in legislator_options %}
|
||||||
|
{% if option.legislator_id not in selected_legislators %}
|
||||||
|
<a class="dashed-chip" href="{{ build_compare_url(legislator_ids=selected_legislators + [option.legislator_id], topics=topics, q=q) }}">+ {{ option.display_name }}</a>
|
||||||
|
{% endif %}
|
||||||
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<h2>Issue axes ({{ topics|length }} / 8)</h2>
|
||||||
|
<div class="axis-chips">
|
||||||
|
{% for topic in topics %}
|
||||||
|
{% set without_topic = topics[:loop.index0] + topics[loop.index:] %}
|
||||||
|
<a href="{{ build_compare_url(legislator_ids=selected_legislators, topics=without_topic, q=q) }}">{{ topic }} ×</a>
|
||||||
|
{% endfor %}
|
||||||
|
{% if topics|length < 8 %}
|
||||||
|
{% for topic in topic_options %}
|
||||||
|
{% if topic not in topics %}
|
||||||
|
<a href="{{ build_compare_url(legislator_ids=selected_legislators, topics=topics + [topic], q=q) }}">{{ topic }}</a>
|
||||||
|
{% endif %}
|
||||||
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<section class="compare-card">
|
||||||
|
<div class="radar-frame">{{ radar_svg | safe }}</div>
|
||||||
|
<aside class="compare-legend">
|
||||||
|
<h2>Legend</h2>
|
||||||
|
{% for item in series %}
|
||||||
|
<div class="legend-row">
|
||||||
|
<span class="legend-line line-{{ loop.index0 }}"></span>
|
||||||
|
<div>
|
||||||
|
<strong>{{ item.legislator.display_name }}</strong>
|
||||||
|
<small>{{ item.legislator.state or "US" }} · {{ item.legislator.party or "—" }} · avg {{ item.average_score|round(0) if item.average_score is not none else "—" }}</small>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{% endfor %}
|
||||||
|
<p>Outer ring = 100% support. Each axis is scored independently against full roll-call record.</p>
|
||||||
|
<p><em>Max 4 legislators · max 8 axes</em></p>
|
||||||
|
</aside>
|
||||||
|
</section>
|
||||||
|
</main>
|
||||||
|
<footer class="footer">
|
||||||
|
<span>Actual record, not rhetoric</span>
|
||||||
|
<span>Source: congressional roll-call votes</span>
|
||||||
|
<span>Not affiliated with any political party or organization</span>
|
||||||
|
</footer>
|
||||||
|
{% endblock %}
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
{% extends "base.html" %}
|
||||||
|
|
||||||
|
{% block title %}Legislative Accountability{% endblock %}
|
||||||
|
|
||||||
|
{% block body %}
|
||||||
|
<main class="shell">
|
||||||
|
<section class="page-heading">
|
||||||
|
<div>
|
||||||
|
<h1>Legislative accountability</h1>
|
||||||
|
<p>US legislative accountability · precomputed legislator topic scores{% if latest_score_year %} through {{ latest_score_year }}{% endif %}</p>
|
||||||
|
</div>
|
||||||
|
<div class="heading-actions">
|
||||||
|
<span>{{ current_user_email }}</span>
|
||||||
|
<a href="#" aria-disabled="true">Methodology</a>
|
||||||
|
<a href="#" aria-disabled="true">Data sources</a>
|
||||||
|
<span>Last updated: {{ last_updated.strftime("%b %Y") if last_updated else "Unavailable" }}</span>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<div class="notice">Choose one or more score topics, then select lawmakers to compare computed records over time.</div>
|
||||||
|
|
||||||
|
<div id="dashboard-body">
|
||||||
|
{% include "partials/_dashboard.html" %}
|
||||||
|
</div>
|
||||||
|
</main>
|
||||||
|
<footer class="footer">
|
||||||
|
<span>Actual record, not rhetoric</span>
|
||||||
|
<span>Source: congressional roll-call votes</span>
|
||||||
|
<span>Not affiliated with any political party or organization</span>
|
||||||
|
</footer>
|
||||||
|
{% endblock %}
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
{% extends "base.html" %}
|
||||||
|
|
||||||
|
{% block title %}Nornsight | Legislative Accountability{% endblock %}
|
||||||
|
|
||||||
|
{% block body %}
|
||||||
|
<main class="shell home-shell">
|
||||||
|
{% if auth_error %}
|
||||||
|
<div class="notice auth-notice">Authentication failed. Try signing in again.</div>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
<section class="hero-panel">
|
||||||
|
<div class="hero-copy">
|
||||||
|
<p class="eyebrow">Invite-only access</p>
|
||||||
|
<h1>Track legislative behavior with role-aware access and shared WorkOS sign-in.</h1>
|
||||||
|
<p class="hero-text">
|
||||||
|
Nornsight turns roll-call data into issue-level accountability views for your invited team.
|
||||||
|
Use the public home page as the front door, then move signed-in users into the dashboard,
|
||||||
|
legislator search, and comparison tools.
|
||||||
|
</p>
|
||||||
|
<div class="hero-actions">
|
||||||
|
{% if is_authenticated %}
|
||||||
|
<a class="hero-primary" href="/dashboard">Open dashboard</a>
|
||||||
|
{% if is_admin %}
|
||||||
|
<a class="hero-secondary" href="/admin">Admin settings</a>
|
||||||
|
{% endif %}
|
||||||
|
{% else %}
|
||||||
|
<a class="hero-primary" href="/login?next=/dashboard">Sign in</a>
|
||||||
|
<a class="hero-secondary" href="#access-model">How access works</a>
|
||||||
|
{% endif %}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<aside class="hero-card">
|
||||||
|
<h2>Launch access model</h2>
|
||||||
|
<ul>
|
||||||
|
<li>Public landing page at <code>/</code></li>
|
||||||
|
<li>Invite-only AuthKit login with Email + Password and Google</li>
|
||||||
|
<li><code>viewer</code> role for dashboard, legislators, and compare</li>
|
||||||
|
<li><code>admin</code> role for settings and account administration</li>
|
||||||
|
</ul>
|
||||||
|
</aside>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<section id="access-model" class="home-grid">
|
||||||
|
<article class="home-card">
|
||||||
|
<h2>For invited users</h2>
|
||||||
|
<p>View the dashboard, inspect legislator profiles, and compare issue scoring without sharing a local password.</p>
|
||||||
|
</article>
|
||||||
|
<article class="home-card">
|
||||||
|
<h2>For admins</h2>
|
||||||
|
<p>Manage invitations and role assignments in WorkOS while the app enforces role-based route access.</p>
|
||||||
|
</article>
|
||||||
|
<article class="home-card">
|
||||||
|
<h2>For rollout</h2>
|
||||||
|
<p>Authentication is centralized, sessions are sealed, and the old hard-coded admin login is removed.</p>
|
||||||
|
</article>
|
||||||
|
</section>
|
||||||
|
</main>
|
||||||
|
{% endblock %}
|
||||||
@@ -0,0 +1,148 @@
|
|||||||
|
{% extends "base.html" %}
|
||||||
|
|
||||||
|
{% block title %}Legislator Profiles{% endblock %}
|
||||||
|
|
||||||
|
{% block body %}
|
||||||
|
<main class="shell">
|
||||||
|
<section class="page-heading stacked-heading">
|
||||||
|
<div>
|
||||||
|
<h1>Legislator profiles</h1>
|
||||||
|
<p>Full issue taxonomy · search any current legislator</p>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<form class="wide-search legislator-search-form" action="/legislators" method="get">
|
||||||
|
<label class="sr-only" for="legislator-search">Search legislators</label>
|
||||||
|
<input
|
||||||
|
id="legislator-search"
|
||||||
|
type="search"
|
||||||
|
name="q"
|
||||||
|
value="{{ q }}"
|
||||||
|
placeholder="Search by name or state"
|
||||||
|
autocomplete="off"
|
||||||
|
hx-get="/partials/legislator-suggestions"
|
||||||
|
hx-trigger="input changed delay:200ms, search"
|
||||||
|
hx-target="#legislator-suggestions"
|
||||||
|
hx-swap="innerHTML">
|
||||||
|
<label class="sr-only" for="legislator-page-size">Results per page</label>
|
||||||
|
<select id="legislator-page-size" name="per_page" aria-label="Results per page">
|
||||||
|
{% for option in per_page_options %}
|
||||||
|
<option value="{{ option }}" {{ "selected" if option == per_page else "" }}>{{ option }}</option>
|
||||||
|
{% endfor %}
|
||||||
|
</select>
|
||||||
|
<button type="submit">Search</button>
|
||||||
|
</form>
|
||||||
|
|
||||||
|
<div id="legislator-suggestions" aria-live="polite"></div>
|
||||||
|
|
||||||
|
{% if q %}
|
||||||
|
<section class="phonebook-results" aria-label="Matching legislators">
|
||||||
|
<header>
|
||||||
|
<h2>Matching legislators</h2>
|
||||||
|
<span>{{ result_count }} result{{ "" if result_count == 1 else "s" }}</span>
|
||||||
|
</header>
|
||||||
|
{% if matches %}
|
||||||
|
<ol class="phonebook-list" start="{{ ((page - 1) * per_page) + 1 }}">
|
||||||
|
{% for option in matches %}
|
||||||
|
<li>
|
||||||
|
<a href="{{ build_legislator_url(legislator_id=option.legislator_id, q=q, per_page=per_page) }}">
|
||||||
|
<span class="phonebook-name">{{ option.display_name }}</span>
|
||||||
|
<span class="phonebook-meta">
|
||||||
|
{{ option.state or "US" }}{% if option.party %} · {{ option.party }}{% endif %}{% if option.chamber %} · {{ option.chamber }}{% endif %}
|
||||||
|
</span>
|
||||||
|
</a>
|
||||||
|
</li>
|
||||||
|
{% endfor %}
|
||||||
|
</ol>
|
||||||
|
<nav class="pagination" aria-label="Legislator results pages">
|
||||||
|
{% if previous_page %}
|
||||||
|
<a href="{{ build_legislator_search_url(q=q, per_page=per_page, page=previous_page) }}">Previous</a>
|
||||||
|
{% else %}
|
||||||
|
<span>Previous</span>
|
||||||
|
{% endif %}
|
||||||
|
<strong>Page {{ page }} of {{ total_pages }}</strong>
|
||||||
|
{% if next_page %}
|
||||||
|
<a href="{{ build_legislator_search_url(q=q, per_page=per_page, page=next_page) }}">Next</a>
|
||||||
|
{% else %}
|
||||||
|
<span>Next</span>
|
||||||
|
{% endif %}
|
||||||
|
</nav>
|
||||||
|
{% else %}
|
||||||
|
<p class="empty-state">No legislators matched this search.</p>
|
||||||
|
{% endif %}
|
||||||
|
</section>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% if profile %}
|
||||||
|
<section class="profile-card">
|
||||||
|
<header class="profile-header">
|
||||||
|
<div class="profile-identity">
|
||||||
|
<span class="avatar">{{ profile.legislator.display_name.split(',')[0][:1] }}{{ profile.legislator.display_name.split(',')[-1].strip()[:1] }}</span>
|
||||||
|
<div>
|
||||||
|
<h2>{{ profile.legislator.display_name }} <span class="party-pill">{{ profile.legislator.chamber or "LEG" }}</span></h2>
|
||||||
|
<p>{{ profile.legislator.state or "US" }} · {{ profile.legislator.party or "Unaffiliated" }}{% if profile.serving_since %} · Serving since {{ profile.serving_since }}{% endif %}</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="overall-score">
|
||||||
|
<span>Overall avg</span>
|
||||||
|
<strong>{{ profile.overall_score|round(0) if profile.overall_score is not none else "—" }}</strong>
|
||||||
|
<small>/ 100</small>
|
||||||
|
</div>
|
||||||
|
</header>
|
||||||
|
|
||||||
|
{% if profile.top_topics or profile.bottom_topics %}
|
||||||
|
<div class="topic-panels">
|
||||||
|
<article>
|
||||||
|
<h3>Most important issues for</h3>
|
||||||
|
{% for item in profile.top_topics %}
|
||||||
|
<a class="topic-row" href="{{ build_legislator_url(legislator_id=profile.legislator.legislator_id, topic=item.topic) }}">
|
||||||
|
<strong class="score positive">{{ item.score|round(0) }}</strong>
|
||||||
|
<span>{{ item.topic }}</span>
|
||||||
|
<i style="width: {{ item.score }}%"></i>
|
||||||
|
</a>
|
||||||
|
{% endfor %}
|
||||||
|
</article>
|
||||||
|
<article>
|
||||||
|
<h3 class="opposed-heading">Most important issues against</h3>
|
||||||
|
{% for item in profile.bottom_topics %}
|
||||||
|
<a class="topic-row {{ 'active' if item.topic == selected_topic else '' }}" href="{{ build_legislator_url(legislator_id=profile.legislator.legislator_id, topic=item.topic) }}">
|
||||||
|
<strong class="score negative">{{ item.score|round(0) }}</strong>
|
||||||
|
<span>{{ item.topic }}</span>
|
||||||
|
<i style="width: {{ item.score }}%"></i>
|
||||||
|
</a>
|
||||||
|
{% endfor %}
|
||||||
|
</article>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<section class="profile-history">
|
||||||
|
<h3>{{ selected_topic or "Topic" }} — score history</h3>
|
||||||
|
<div class="chart-frame">{{ history_svg | safe }}</div>
|
||||||
|
{% if history_series %}
|
||||||
|
<div class="chart-legend compact" aria-label="Chart legend">
|
||||||
|
{% for item in history_series %}
|
||||||
|
<div class="chart-legend-row">
|
||||||
|
<span class="chart-legend-line line-0"></span>
|
||||||
|
<span class="chart-legend-marker marker-0"></span>
|
||||||
|
<div class="chart-legend-copy">
|
||||||
|
<span class="chart-legend-label">{{ item.label }}</span>
|
||||||
|
<span class="chart-legend-meta">
|
||||||
|
{% if item.party %}{{ item.party }}{% endif %}{% if item.party and item.state %} · {% endif %}{% if item.state %}{{ item.state }}{% endif %}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{% endfor %}
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
</section>
|
||||||
|
{% else %}
|
||||||
|
<p class="empty-state">No issue scores are available for this legislator yet.</p>
|
||||||
|
{% endif %}
|
||||||
|
</section>
|
||||||
|
{% endif %}
|
||||||
|
</main>
|
||||||
|
<footer class="footer">
|
||||||
|
<span>Actual record, not rhetoric</span>
|
||||||
|
<span>Source: congressional roll-call votes</span>
|
||||||
|
<span>Not affiliated with any political party or organization</span>
|
||||||
|
</footer>
|
||||||
|
{% endblock %}
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
{% extends "base.html" %}
|
||||||
|
|
||||||
|
{% block title %}Sign in | Nornsight{% endblock %}
|
||||||
|
|
||||||
|
{% block body %}
|
||||||
|
<main class="login-shell">
|
||||||
|
<section class="login-panel" aria-labelledby="login-title">
|
||||||
|
<div class="login-copy">
|
||||||
|
<p class="eyebrow">Admin access</p>
|
||||||
|
<h1 id="login-title">Sign in to Nornsight</h1>
|
||||||
|
<p>Use the dashboard account to review rankings, profiles, and legislator comparisons.</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<form class="login-form" action="/login" method="post">
|
||||||
|
<input type="hidden" name="next" value="{{ next_path }}">
|
||||||
|
|
||||||
|
{% if error %}
|
||||||
|
<p class="form-error" role="alert">{{ error }}</p>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
<label for="username">Username</label>
|
||||||
|
<input
|
||||||
|
id="username"
|
||||||
|
name="username"
|
||||||
|
type="text"
|
||||||
|
autocomplete="username"
|
||||||
|
value="{{ username }}"
|
||||||
|
required
|
||||||
|
autofocus
|
||||||
|
>
|
||||||
|
|
||||||
|
<label for="password">Password</label>
|
||||||
|
<input
|
||||||
|
id="password"
|
||||||
|
name="password"
|
||||||
|
type="password"
|
||||||
|
autocomplete="current-password"
|
||||||
|
required
|
||||||
|
>
|
||||||
|
|
||||||
|
<button type="submit">Sign in</button>
|
||||||
|
</form>
|
||||||
|
</section>
|
||||||
|
</main>
|
||||||
|
{% endblock %}
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
<section class="chart-card">
|
||||||
|
<header>
|
||||||
|
<h2>Score history{% if selected_issue_label %} — {{ selected_issue_label }}{% endif %}</h2>
|
||||||
|
<a href="{{ build_url(request, compare=[]) }}"
|
||||||
|
hx-get="{{ build_dashboard_partial_url(request, compare=[]) }}"
|
||||||
|
hx-target="#dashboard-body"
|
||||||
|
hx-push-url="{{ build_url(request, compare=[]) }}">Clear comparison</a>
|
||||||
|
</header>
|
||||||
|
<div class="chart-frame">
|
||||||
|
{{ chart_svg | safe }}
|
||||||
|
</div>
|
||||||
|
{% if chart_series %}
|
||||||
|
<div class="chart-legend" aria-label="Chart legend">
|
||||||
|
{% for item in chart_series %}
|
||||||
|
{% set style_index = loop.index0 % 4 %}
|
||||||
|
<div class="chart-legend-row">
|
||||||
|
<span class="chart-legend-line line-{{ style_index }}"></span>
|
||||||
|
<span class="chart-legend-marker marker-{{ style_index }}"></span>
|
||||||
|
<div class="chart-legend-copy">
|
||||||
|
<span class="chart-legend-label">{{ item.label }}</span>
|
||||||
|
<span class="chart-legend-meta">
|
||||||
|
{% if item.party %}{{ item.party }}{% endif %}{% if item.party and item.state %} · {% endif %}{% if item.state %}{{ item.state }}{% endif %}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{% endfor %}
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
<p class="score-note">Scores reflect averaged precomputed topic rows per year. Sparse years are omitted rather than plotted as zero.</p>
|
||||||
|
</section>
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
<section class="controls-grid">
|
||||||
|
{% include "partials/_issue_filters.html" %}
|
||||||
|
<div class="chamber-card">
|
||||||
|
<a class="segment {{ 'active' if chamber == 'house' else '' }}"
|
||||||
|
href="{{ build_url(request, chamber='house') }}"
|
||||||
|
hx-get="{{ build_dashboard_partial_url(request, chamber='house') }}"
|
||||||
|
hx-target="#dashboard-body"
|
||||||
|
hx-push-url="{{ build_url(request, chamber='house') }}">House</a>
|
||||||
|
<a class="segment {{ 'active' if chamber == 'senate' else '' }}"
|
||||||
|
href="{{ build_url(request, chamber='senate') }}"
|
||||||
|
hx-get="{{ build_dashboard_partial_url(request, chamber='senate') }}"
|
||||||
|
hx-target="#dashboard-body"
|
||||||
|
hx-push-url="{{ build_url(request, chamber='senate') }}">Senate</a>
|
||||||
|
<a class="segment {{ 'active' if chamber == 'all' else '' }}"
|
||||||
|
href="{{ build_url(request, chamber='all') }}"
|
||||||
|
hx-get="{{ build_dashboard_partial_url(request, chamber='all') }}"
|
||||||
|
hx-target="#dashboard-body"
|
||||||
|
hx-push-url="{{ build_url(request, chamber='all') }}">All</a>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<p class="score-note">Support score: 1-100 precomputed from bill topic stance and roll-call votes. Higher means more aligned with the topic.</p>
|
||||||
|
|
||||||
|
{% include "partials/_rankings.html" %}
|
||||||
|
{% include "partials/_chart.html" %}
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
<section class="filter-card">
|
||||||
|
<h2>Issue filters</h2>
|
||||||
|
<form class="issue-form"
|
||||||
|
method="get"
|
||||||
|
action="/dashboard"
|
||||||
|
hx-get="/partials/dashboard"
|
||||||
|
hx-target="#dashboard-body"
|
||||||
|
hx-push-url="/dashboard">
|
||||||
|
<input type="hidden" name="chamber" value="{{ chamber }}">
|
||||||
|
{% if congress %}
|
||||||
|
<input type="hidden" name="congress" value="{{ congress }}">
|
||||||
|
{% endif %}
|
||||||
|
{% for legislator_id in compare %}
|
||||||
|
<input type="hidden" name="compare" value="{{ legislator_id }}">
|
||||||
|
{% endfor %}
|
||||||
|
{% for issue in issues %}
|
||||||
|
<span class="chip">
|
||||||
|
{{ issue }}
|
||||||
|
<a href="{{ build_url(request, issues=issues[:loop.index0] + issues[loop.index:]) }}"
|
||||||
|
hx-get="{{ build_dashboard_partial_url(request, issues=issues[:loop.index0] + issues[loop.index:]) }}"
|
||||||
|
hx-target="#dashboard-body"
|
||||||
|
hx-push-url="{{ build_url(request, issues=issues[:loop.index0] + issues[loop.index:]) }}"
|
||||||
|
aria-label="Remove {{ issue }}">×</a>
|
||||||
|
</span>
|
||||||
|
<input type="hidden" name="issues" value="{{ issue }}">
|
||||||
|
{% endfor %}
|
||||||
|
<label class="search-box">
|
||||||
|
<span class="sr-only">Search issue areas</span>
|
||||||
|
<input type="search" name="issues" placeholder="Search issue areas" autocomplete="off">
|
||||||
|
</label>
|
||||||
|
<button type="submit">Apply</button>
|
||||||
|
</form>
|
||||||
|
|
||||||
|
{% if suggestions %}
|
||||||
|
<div class="suggestions" aria-label="Issue suggestions">
|
||||||
|
{% for suggestion in suggestions %}
|
||||||
|
{% if suggestion not in issues %}
|
||||||
|
<a href="{{ build_url(request, issues=issues + [suggestion]) }}"
|
||||||
|
hx-get="{{ build_dashboard_partial_url(request, issues=issues + [suggestion]) }}"
|
||||||
|
hx-target="#dashboard-body"
|
||||||
|
hx-push-url="{{ build_url(request, issues=issues + [suggestion]) }}">{{ suggestion }}</a>
|
||||||
|
{% endif %}
|
||||||
|
{% endfor %}
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
</section>
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
{% if matches %}
|
||||||
|
<div class="result-chips" aria-label="Search suggestions">
|
||||||
|
{% for option in matches %}
|
||||||
|
<a href="{{ build_legislator_url(legislator_id=option.legislator_id) }}">
|
||||||
|
{{ option.display_name }}{% if option.state %} · {{ option.state }}{% endif %}
|
||||||
|
</a>
|
||||||
|
{% endfor %}
|
||||||
|
</div>
|
||||||
|
{% elif q %}
|
||||||
|
<p class="suggestion-empty">No matches</p>
|
||||||
|
{% endif %}
|
||||||
@@ -0,0 +1,61 @@
|
|||||||
|
<section class="rankings-grid">
|
||||||
|
<article class="ranking-card">
|
||||||
|
<header>
|
||||||
|
<h2>Most supportive</h2>
|
||||||
|
<span>Top 10</span>
|
||||||
|
</header>
|
||||||
|
{% if rankings.supportive %}
|
||||||
|
<ol class="ranking-list">
|
||||||
|
{% for row in rankings.supportive %}
|
||||||
|
{% set next_compare = toggle_compare(compare, row.legislator_id) %}
|
||||||
|
<li class="{{ 'selected' if row.legislator_id in compare else '' }}">
|
||||||
|
<a href="{{ build_url(request, compare=next_compare) }}"
|
||||||
|
hx-get="{{ build_dashboard_partial_url(request, compare=next_compare) }}"
|
||||||
|
hx-target="#dashboard-body"
|
||||||
|
hx-push-url="{{ build_url(request, compare=next_compare) }}">
|
||||||
|
<span class="rank">{{ loop.index }}</span>
|
||||||
|
<strong class="score positive">{{ row.score|round(1) }}</strong>
|
||||||
|
<span class="member">
|
||||||
|
<strong>{{ row.display_name }}</strong>
|
||||||
|
<small>{{ row.state or "US" }}{% if row.party %} · {{ row.party[:1] }}{% endif %}</small>
|
||||||
|
</span>
|
||||||
|
<span class="votes">{{ row.total }} rows</span>
|
||||||
|
</a>
|
||||||
|
</li>
|
||||||
|
{% endfor %}
|
||||||
|
</ol>
|
||||||
|
{% else %}
|
||||||
|
<p class="empty-state">{{ empty_message }}</p>
|
||||||
|
{% endif %}
|
||||||
|
</article>
|
||||||
|
|
||||||
|
<article class="ranking-card">
|
||||||
|
<header>
|
||||||
|
<h2>Most opposed</h2>
|
||||||
|
<span>Bottom 10</span>
|
||||||
|
</header>
|
||||||
|
{% if rankings.opposed %}
|
||||||
|
<ol class="ranking-list">
|
||||||
|
{% for row in rankings.opposed %}
|
||||||
|
{% set next_compare = toggle_compare(compare, row.legislator_id) %}
|
||||||
|
<li class="{{ 'selected' if row.legislator_id in compare else '' }}">
|
||||||
|
<a href="{{ build_url(request, compare=next_compare) }}"
|
||||||
|
hx-get="{{ build_dashboard_partial_url(request, compare=next_compare) }}"
|
||||||
|
hx-target="#dashboard-body"
|
||||||
|
hx-push-url="{{ build_url(request, compare=next_compare) }}">
|
||||||
|
<span class="rank">{{ loop.index }}</span>
|
||||||
|
<strong class="score negative">{{ row.score|round(1) }}</strong>
|
||||||
|
<span class="member">
|
||||||
|
<strong>{{ row.display_name }}</strong>
|
||||||
|
<small>{{ row.state or "US" }}{% if row.party %} · {{ row.party[:1] }}{% endif %}</small>
|
||||||
|
</span>
|
||||||
|
<span class="votes">{{ row.total }} rows</span>
|
||||||
|
</a>
|
||||||
|
</li>
|
||||||
|
{% endfor %}
|
||||||
|
</ol>
|
||||||
|
{% else %}
|
||||||
|
<p class="empty-state">{{ empty_message }}</p>
|
||||||
|
{% endif %}
|
||||||
|
</article>
|
||||||
|
</section>
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
{% extends "base.html" %}
|
||||||
|
|
||||||
|
{% block title %}Database Setup Required{% endblock %}
|
||||||
|
|
||||||
|
{% block body %}
|
||||||
|
<main class="shell">
|
||||||
|
<section class="page-heading stacked-heading">
|
||||||
|
<div>
|
||||||
|
<h1>Database setup required</h1>
|
||||||
|
<p>Configure DATA_SCIENCE_DEV before opening the dashboard.</p>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
<pre class="setup-error">{{ error }}</pre>
|
||||||
|
</main>
|
||||||
|
{% endblock %}
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
[project]
|
||||||
|
name = "ds-testing-pipelines"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Data science pipeline tools and legislative dashboard."
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"alembic",
|
||||||
|
"fastapi",
|
||||||
|
"httpx",
|
||||||
|
"jinja2",
|
||||||
|
"psycopg",
|
||||||
|
"sqlalchemy",
|
||||||
|
"typer",
|
||||||
|
"uvicorn[standard]",
|
||||||
|
"workos",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
test = [
|
||||||
|
"pytest",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
pythonpath = ["."]
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
include = ["pipelines*"]
|
||||||
@@ -0,0 +1,370 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from pipelines.web import auth, main
|
||||||
|
from pipelines.web.repository import (
|
||||||
|
ChartSeries,
|
||||||
|
LegislatorOption,
|
||||||
|
RadarSeries,
|
||||||
|
RankingResult,
|
||||||
|
RankingRow,
|
||||||
|
TimePoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_healthz() -> None:
|
||||||
|
client = TestClient(main.app)
|
||||||
|
response = client.get("/healthz")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.text == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
def test_public_home_page_renders() -> None:
|
||||||
|
client = TestClient(main.app)
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "Invite-only access" in response.text
|
||||||
|
assert "Sign in" in response.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_dashboard_redirects_to_login() -> None:
|
||||||
|
client = TestClient(main.app)
|
||||||
|
response = client.get("/dashboard?issues=Health", follow_redirects=False)
|
||||||
|
assert response.status_code == 303
|
||||||
|
assert response.headers["location"].endswith(
|
||||||
|
"/login?next=%2Fdashboard%3Fissues%3DHealth"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_other_protected_routes_redirect_when_unauthenticated() -> None:
|
||||||
|
client = TestClient(main.app)
|
||||||
|
for path in ["/legislators", "/compare", "/admin"]:
|
||||||
|
response = client.get(path, follow_redirects=False)
|
||||||
|
assert response.status_code == 303
|
||||||
|
assert response.headers["location"].endswith(f"/login?next={path.replace('/', '%2F', 1)}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_redirects_to_workos(monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(main.auth, "get_current_session", lambda request: None)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
main.auth,
|
||||||
|
"build_authorization_url",
|
||||||
|
lambda next_path: f"https://auth.example/login?state={next_path}",
|
||||||
|
)
|
||||||
|
|
||||||
|
client = TestClient(main.app)
|
||||||
|
response = client.get("/login?next=/compare", follow_redirects=False)
|
||||||
|
assert response.status_code == 303
|
||||||
|
assert response.headers["location"] == "https://auth.example/login?state=/compare"
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_redirects_authenticated_user(monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(main.auth, "get_current_session", lambda request: _viewer_session())
|
||||||
|
|
||||||
|
client = TestClient(main.app)
|
||||||
|
response = client.get("/login?next=/compare", follow_redirects=False)
|
||||||
|
assert response.status_code == 303
|
||||||
|
assert response.headers["location"] == "/compare"
|
||||||
|
|
||||||
|
|
||||||
|
def test_callback_sets_session_cookie(monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(
|
||||||
|
main.auth,
|
||||||
|
"exchange_code",
|
||||||
|
lambda request: auth.CallbackResult(
|
||||||
|
sealed_session="sealed-session-value", next_path="/dashboard"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(main.auth, "get_auth_config", _fake_auth_config)
|
||||||
|
|
||||||
|
client = TestClient(main.app)
|
||||||
|
response = client.get("/callback?code=abc&state=/dashboard", follow_redirects=False)
|
||||||
|
assert response.status_code == 303
|
||||||
|
assert response.headers["location"] == "/dashboard"
|
||||||
|
assert "workos_session=sealed-session-value" in response.headers["set-cookie"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_callback_failure_redirects_home_and_clears_cookie(monkeypatch) -> None:
|
||||||
|
def raise_exchange_error(request):
|
||||||
|
raise RuntimeError("bad code")
|
||||||
|
|
||||||
|
monkeypatch.setattr(main.auth, "exchange_code", raise_exchange_error)
|
||||||
|
|
||||||
|
client = TestClient(main.app)
|
||||||
|
response = client.get("/callback?code=bad", follow_redirects=False)
|
||||||
|
assert response.status_code == 303
|
||||||
|
assert response.headers["location"] == "/?auth_error=1"
|
||||||
|
assert "workos_session=" in response.headers["set-cookie"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_logout_redirects_to_workos_and_clears_cookie(monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(
|
||||||
|
main.auth,
|
||||||
|
"get_logout_url",
|
||||||
|
lambda request: "https://auth.example/logout",
|
||||||
|
)
|
||||||
|
|
||||||
|
client = TestClient(main.app)
|
||||||
|
response = client.post("/logout", follow_redirects=False)
|
||||||
|
assert response.status_code == 303
|
||||||
|
assert response.headers["location"] == "https://auth.example/logout"
|
||||||
|
assert "workos_session=" in response.headers["set-cookie"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_logout_with_invalid_session_cookie_clears_cookie(monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(main.auth, "get_auth_config", _fake_auth_config)
|
||||||
|
monkeypatch.setattr(main.auth, "get_workos_client", _invalid_workos_client)
|
||||||
|
|
||||||
|
client = TestClient(main.app)
|
||||||
|
client.cookies.set("workos_session", "bad-session-cookie")
|
||||||
|
response = client.post("/logout", follow_redirects=False)
|
||||||
|
|
||||||
|
assert response.status_code == 303
|
||||||
|
assert response.headers["location"] == "http://localhost:8000/"
|
||||||
|
assert "workos_session=" in response.headers["set-cookie"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_session_cookie_is_treated_as_unauthenticated(monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(main.auth, "get_auth_config", _fake_auth_config)
|
||||||
|
monkeypatch.setattr(main.auth, "get_workos_client", _invalid_workos_client)
|
||||||
|
|
||||||
|
client = TestClient(main.app)
|
||||||
|
client.cookies.set("workos_session", "bad-session-cookie")
|
||||||
|
response = client.get("/")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "Sign in" in response.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_dashboard_route_renders_with_stubbed_repository(monkeypatch) -> None:
|
||||||
|
_patch_authenticated_dashboard(monkeypatch, current_user=_viewer_session())
|
||||||
|
|
||||||
|
client = TestClient(main.app)
|
||||||
|
response = client.get("/dashboard?issues=Health&chamber=senate")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "Legislative accountability" in response.text
|
||||||
|
assert "Most supportive" in response.text
|
||||||
|
assert "viewer@nornsight.test" in response.text
|
||||||
|
assert "/admin" not in response.text
|
||||||
|
assert '/partials/dashboard?issues=Health&chamber=house' in response.text
|
||||||
|
assert "/partials/dashboarddashboard?" not in response.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_admin_route_forbids_viewer(monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(main.auth, "get_current_session", lambda request: _viewer_session())
|
||||||
|
|
||||||
|
client = TestClient(main.app)
|
||||||
|
response = client.get("/admin")
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json()["detail"] == "Admin access required."
|
||||||
|
|
||||||
|
|
||||||
|
def test_admin_route_renders_for_admin(monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(main.auth, "get_current_session", lambda request: _admin_session())
|
||||||
|
monkeypatch.setattr(main.auth, "get_auth_config", _fake_auth_config)
|
||||||
|
|
||||||
|
client = TestClient(main.app)
|
||||||
|
response = client.get("/admin")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "Admin settings" in response.text
|
||||||
|
assert "admin@nornsight.test" in response.text
|
||||||
|
assert "org_test_123" in response.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_compare_page_renders_for_authenticated_user(monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(main.auth, "get_current_session", lambda request: _viewer_session())
|
||||||
|
_patch_compare_page_data(monkeypatch)
|
||||||
|
|
||||||
|
client = TestClient(main.app)
|
||||||
|
response = client.get("/compare")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "Compare legislators" in response.text
|
||||||
|
assert "Sanders, B." in response.text
|
||||||
|
|
||||||
|
|
||||||
|
def _viewer_session() -> auth.AuthSession:
|
||||||
|
return auth.AuthSession(
|
||||||
|
user_id="user_viewer",
|
||||||
|
email="viewer@nornsight.test",
|
||||||
|
first_name="Viewer",
|
||||||
|
last_name="User",
|
||||||
|
role_slugs={"viewer"},
|
||||||
|
organization_id="org_test_123",
|
||||||
|
raw_user=None,
|
||||||
|
raw_session=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _admin_session() -> auth.AuthSession:
|
||||||
|
return auth.AuthSession(
|
||||||
|
user_id="user_admin",
|
||||||
|
email="admin@nornsight.test",
|
||||||
|
first_name="Admin",
|
||||||
|
last_name="User",
|
||||||
|
role_slugs={"admin", "viewer"},
|
||||||
|
organization_id="org_test_123",
|
||||||
|
raw_user=None,
|
||||||
|
raw_session=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_auth_config() -> auth.AuthConfig:
|
||||||
|
return auth.AuthConfig(
|
||||||
|
api_key="sk_test",
|
||||||
|
client_id="client_test",
|
||||||
|
cookie_password="x" * 32,
|
||||||
|
redirect_uri="http://localhost:8000/callback",
|
||||||
|
logout_redirect_uri="http://localhost:8000/",
|
||||||
|
session_cookie_name="workos_session",
|
||||||
|
organization_id="org_test_123",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _invalid_workos_client():
|
||||||
|
class InvalidSession:
|
||||||
|
def authenticate(self):
|
||||||
|
raise ValueError("invalid session")
|
||||||
|
|
||||||
|
def get_logout_url(self, *, return_to: str) -> str:
|
||||||
|
raise ValueError("invalid session")
|
||||||
|
|
||||||
|
class DummyUserManagement:
|
||||||
|
def load_sealed_session(self, *, session_data: str, cookie_password: str):
|
||||||
|
return InvalidSession()
|
||||||
|
|
||||||
|
class DummyClient:
|
||||||
|
user_management = DummyUserManagement()
|
||||||
|
|
||||||
|
return DummyClient()
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_authenticated_dashboard(monkeypatch, *, current_user: auth.AuthSession) -> None:
|
||||||
|
monkeypatch.setattr(main.auth, "get_current_session", lambda request: current_user)
|
||||||
|
|
||||||
|
class DummySession:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class DummyScope:
|
||||||
|
def __enter__(self):
|
||||||
|
return DummySession()
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
rankings = RankingResult(
|
||||||
|
supportive=[
|
||||||
|
RankingRow(
|
||||||
|
legislator_id=1,
|
||||||
|
display_name="Sanders, B.",
|
||||||
|
party="I",
|
||||||
|
state="VT",
|
||||||
|
chamber="senate",
|
||||||
|
score=78.0,
|
||||||
|
supportive=7,
|
||||||
|
opposed=2,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
opposed=[
|
||||||
|
RankingRow(
|
||||||
|
legislator_id=2,
|
||||||
|
display_name="Cruz, T.",
|
||||||
|
party="R",
|
||||||
|
state="TX",
|
||||||
|
chamber="senate",
|
||||||
|
score=22.0,
|
||||||
|
supportive=2,
|
||||||
|
opposed=7,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
history = [
|
||||||
|
ChartSeries(
|
||||||
|
legislator_id=1,
|
||||||
|
label="Sanders, B.",
|
||||||
|
party="I",
|
||||||
|
state="VT",
|
||||||
|
points=[TimePoint(year=2024, score=74.0), TimePoint(year=2025, score=78.0)],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
monkeypatch.setattr(main, "session_scope", lambda: DummyScope())
|
||||||
|
monkeypatch.setattr(main.repository, "latest_congress", lambda session: 119)
|
||||||
|
monkeypatch.setattr(main.repository, "has_scores", lambda session: True)
|
||||||
|
monkeypatch.setattr(main.repository, "latest_score_year", lambda session: 2026)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
main.repository, "latest_vote_date", lambda session, congress: date(2026, 1, 15)
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
main.repository,
|
||||||
|
"issue_suggestions",
|
||||||
|
lambda session, congress=None, limit=12: ["Health", "Taxation"],
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
main.repository,
|
||||||
|
"get_rankings",
|
||||||
|
lambda session, *, issues, chamber, congress: rankings,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
main.repository,
|
||||||
|
"get_score_history",
|
||||||
|
lambda session, *, issues, chamber, congress, legislator_ids: history,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_compare_page_data(monkeypatch) -> None:
|
||||||
|
class DummySession:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class DummyScope:
|
||||||
|
def __enter__(self):
|
||||||
|
return DummySession()
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
legislator = LegislatorOption(
|
||||||
|
legislator_id=1,
|
||||||
|
display_name="Sanders, B.",
|
||||||
|
party="I",
|
||||||
|
state="VT",
|
||||||
|
chamber="senate",
|
||||||
|
)
|
||||||
|
topics = ["Health", "Taxation", "Energy"]
|
||||||
|
series = [
|
||||||
|
RadarSeries(
|
||||||
|
legislator=legislator,
|
||||||
|
average_score=77.0,
|
||||||
|
scores_by_topic={"Health": 82.0, "Taxation": 71.0, "Energy": 78.0},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
monkeypatch.setattr(main, "session_scope", lambda: DummyScope())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
main.repository,
|
||||||
|
"get_compare_defaults",
|
||||||
|
lambda session: ([1], topics),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
main.repository,
|
||||||
|
"get_legislator_options",
|
||||||
|
lambda session, selected_legislators: [legislator],
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
main.repository,
|
||||||
|
"get_compare_radar_series",
|
||||||
|
lambda session, *, legislator_ids, topics: series,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
main.repository,
|
||||||
|
"search_legislators",
|
||||||
|
lambda session, query=None, limit=12: [legislator],
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
main.repository,
|
||||||
|
"issue_suggestions",
|
||||||
|
lambda session, congress=None, limit=12: topics,
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user