5 Commits

12 changed files with 775 additions and 25 deletions
@@ -0,0 +1,211 @@
"""move bill text summaries into a child table.
Revision ID: 4b2e1c9d8f70
Revises: b9360b0b0c22
Create Date: 2026-05-03 00:00:00.000000
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from pipelines.orm import DataScienceDevBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "4b2e1c9d8f70"
down_revision: str | None = "b9360b0b0c22"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = DataScienceDevBase.schema_name
def upgrade() -> None:
"""Upgrade."""
op.create_table(
"bill_text_summary",
sa.Column("bill_text_id", sa.Integer(), nullable=False),
sa.Column("summary", sa.String(), nullable=False),
sa.Column("summarization_model", sa.String(), nullable=True),
sa.Column("summarization_user_prompt_version", sa.String(), nullable=True),
sa.Column("summarization_system_prompt_version", sa.String(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["bill_text_id"],
[f"{schema}.bill_text.id"],
name=op.f("fk_bill_text_summary_bill_text_id_bill_text"),
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill_text_summary")),
schema=schema,
)
op.create_index(
"ix_bill_text_summary_bill_text_id",
"bill_text_summary",
["bill_text_id"],
unique=False,
schema=schema,
)
op.create_index(
"ix_bill_text_summary_bill_text_id_created",
"bill_text_summary",
["bill_text_id", "created"],
unique=False,
schema=schema,
)
op.add_column(
"bill_text",
sa.Column("primary_summary_id", sa.Integer(), nullable=True),
schema=schema,
)
op.create_foreign_key(
op.f("fk_bill_text_primary_summary_id_bill_text_summary"),
"bill_text",
"bill_text_summary",
["primary_summary_id"],
["id"],
source_schema=schema,
referent_schema=schema,
ondelete="SET NULL",
)
op.execute(
sa.text(
f"""
INSERT INTO {schema}.bill_text_summary (
bill_text_id,
summary,
summarization_model,
summarization_user_prompt_version,
summarization_system_prompt_version,
created,
updated
)
SELECT
bill_text.id,
bill_text.summary,
bill_text.summarization_model,
bill_text.summarization_user_prompt_version,
bill_text.summarization_system_prompt_version,
COALESCE(bill_text.updated, bill_text.created, now()),
COALESCE(bill_text.updated, bill_text.created, now())
FROM {schema}.bill_text
WHERE bill_text.summary IS NOT NULL
AND btrim(bill_text.summary) <> ''
"""
)
)
op.drop_column("bill_text", "summary", schema=schema)
op.drop_column("bill_text", "summarization_model", schema=schema)
op.drop_column("bill_text", "summarization_user_prompt_version", schema=schema)
op.drop_column("bill_text", "summarization_system_prompt_version", schema=schema)
def downgrade() -> None:
"""Downgrade."""
op.add_column(
"bill_text",
sa.Column("summarization_system_prompt_version", sa.String(), nullable=True),
schema=schema,
)
op.add_column(
"bill_text",
sa.Column("summarization_user_prompt_version", sa.String(), nullable=True),
schema=schema,
)
op.add_column(
"bill_text",
sa.Column("summarization_model", sa.String(), nullable=True),
schema=schema,
)
op.add_column(
"bill_text",
sa.Column("summary", sa.String(), nullable=True),
schema=schema,
)
op.execute(
sa.text(
f"""
WITH ranked AS (
SELECT
bts.*,
row_number() OVER (
PARTITION BY bts.bill_text_id
ORDER BY bts.created DESC, bts.id DESC
) AS rn
FROM {schema}.bill_text_summary AS bts
),
chosen AS (
SELECT
bill_text.id AS bill_text_id,
COALESCE(ps.summary, ls.summary) AS summary,
COALESCE(
ps.summarization_model,
ls.summarization_model
) AS summarization_model,
COALESCE(
ps.summarization_user_prompt_version,
ls.summarization_user_prompt_version
) AS summarization_user_prompt_version,
COALESCE(
ps.summarization_system_prompt_version,
ls.summarization_system_prompt_version
) AS summarization_system_prompt_version
FROM {schema}.bill_text
LEFT JOIN {schema}.bill_text_summary AS ps
ON ps.id = bill_text.primary_summary_id
LEFT JOIN ranked AS ls
ON ls.bill_text_id = bill_text.id
AND ls.rn = 1
)
UPDATE {schema}.bill_text
SET
summary = chosen.summary,
summarization_model = chosen.summarization_model,
summarization_user_prompt_version = chosen.summarization_user_prompt_version,
summarization_system_prompt_version = chosen.summarization_system_prompt_version
FROM chosen
WHERE chosen.bill_text_id = bill_text.id
"""
)
)
op.drop_constraint(
op.f("fk_bill_text_primary_summary_id_bill_text_summary"),
"bill_text",
schema=schema,
type_="foreignkey",
)
op.drop_column("bill_text", "primary_summary_id", schema=schema)
op.drop_index(
"ix_bill_text_summary_bill_text_id_created",
table_name="bill_text_summary",
schema=schema,
)
op.drop_index(
"ix_bill_text_summary_bill_text_id",
table_name="bill_text_summary",
schema=schema,
)
op.drop_table("bill_text_summary", schema=schema)
+116
View File
@@ -0,0 +1,116 @@
"""Nornsight — BERTopic POC Inference Script.
Loads the trained model and labels a small batch of posts,
writing results to main.post_topic for inspection.
POC: processes a single batch of 1k posts to validate the pipeline end-to-end.
"""
from __future__ import annotations
import logging
import time
from collections import Counter
from pathlib import Path
from bertopic import BERTopic
from sqlalchemy import Engine, func, insert, select
from sqlalchemy.orm import Session
from pipelines.config import BertTopicInferConfig, get_bertopic_infer_config
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.posts import PostTopic, Posts
from pipelines.orm.data_science_dev.posts.lang_filters import ENGLISH_LANGS
from pipelines.pipelines.common import configure_logger
logger = logging.getLogger(__name__)
def main() -> None:
"""Run BERTopic inference against a sample of posts."""
configure_logger()
config = get_bertopic_infer_config()
run_inference(config)
logger.info(
"POC inference complete. Check main.post_topic in DBeaver to inspect results."
)
def run_inference(config: BertTopicInferConfig) -> None:
model_save_path = Path(config.model_save_path)
logger.info(f"Loading BERTopic model from {model_save_path}")
topic_model = BERTopic.load(str(model_save_path))
topic_info = topic_model.get_topic_info()
label_map: dict[int, str] = dict(zip(topic_info["Topic"], topic_info["Name"]))
logger.info(f"Model loaded with {len(label_map)} topics")
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
post_ids, texts = get_post_ids_and_test(engine, config)
logger.info(f"Fetched {len(texts)} posts")
logger.info("Running BERTopic transform")
start = time.perf_counter()
topics, _probabilities = topic_model.transform(texts)
elapsed = time.perf_counter() - start
logger.info(f"Transform complete in {elapsed:.1f}s")
# Write results to main.post_topic
records = [
{
"post_id": pid,
"topic_id": int(topic_id),
"topic_label": label_map.get(int(topic_id), "unknown"),
"model_version": config.model_version,
}
for pid, topic_id in zip(post_ids, topics)
]
with Session(engine) as session:
session.execute(insert(PostTopic), records)
session.commit()
count_topics(records)
logger.info(f"Wrote {len(records)} topic labels to main.post_topic")
def get_post_ids_and_test(
engine: Engine,
config: BertTopicInferConfig,
) -> None | tuple[list[int], list[str]]:
with Session(engine) as session:
logger.info(f"Fetching {config.poc_batch_size} posts for inference")
# Pull a fresh batch for inference — distinct from training sample
# using a fixed seed offset so we're not re-labeling training posts
stmt = select(Posts).where(
Posts.text.is_not(None),
Posts.langs.in_(ENGLISH_LANGS),
func.length(Posts.text) > config.min_text_length,
)
if config.poc_batch_size > 0:
stmt = stmt.limit(config.poc_batch_size)
posts = session.scalars(stmt).all()
if not posts:
logger.warning("No posts were selected for inference")
return [], []
post_ids = [post.post_id for post in posts]
texts = [post.text.strip() for post in posts]
return post_ids, texts
def count_topics(records: list[dict]) -> None:
topic_counts = Counter(record.get("topic_label", "unknown") for record in records)
logger.info("Topic distribution in this batch:")
for label, count in topic_counts.most_common(10):
logger.info(" %s: %d", label, count)
if __name__ == "__main__":
main()
+119
View File
@@ -0,0 +1,119 @@
"""Nornsight — BERTopic POC Training Script.
Pulls a small stratified sample (~11.5k posts) from main.posts,
trains BERTopic with MiniBatchKMeans on Jeeves, and saves the model locally.
POC sample rate: random() < 0.00005 (~0.005% of 230M = ~11.5k posts)
Full training rate will be: random() < 0.005 (~1.08M posts)
"""
from __future__ import annotations
import logging
import time
from pathlib import Path
from bertopic import BERTopic
from sklearn.cluster import MiniBatchKMeans
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from pipelines.config import BertTopicTrainConfig, get_bertopic_train_config
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.posts import Posts
from pipelines.orm.data_science_dev.posts.lang_filters import ENGLISH_LANGS
from pipelines.pipelines.common import configure_logger
logger = logging.getLogger(__name__)
def main() -> None:
"""Train and persist the BERTopic model."""
configure_logger()
config = get_bertopic_train_config()
docs = load_sample(config)
if not docs:
logger.warning("No training documents were selected")
return
train(docs, config)
logger.info(f"Done. Model saved as version {config.model_version}")
logger.info("Next: run infer.py to label a sample of posts in the database")
def load_sample(config: BertTopicTrainConfig) -> list[str]:
logger.info("Connecting to PostgreSQL via SQLAlchemy")
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
logger.info(f"Pulling sample from main.posts (sample_rate={config.sample_rate})")
start = time.perf_counter()
with Session(engine) as session:
texts = session.scalars(
select(Posts.text).where(
Posts.text.is_not(None),
Posts.langs.in_(ENGLISH_LANGS),
func.length(Posts.text) > config.min_text_length,
func.random() < config.sample_rate,
)
).all()
elapsed = time.perf_counter() - start
logger.info(f"Fetched {len(texts)} rows in {elapsed:.1f}s")
# Basic cleaning — strip whitespace and deduplicate
docs = list({text.strip() for text in texts})
logger.info(f"After cleaning and dedup: {len(docs)} posts")
return docs
def train(docs: list[str], config: BertTopicTrainConfig) -> None:
logger.info(
f"Initialising BERTopic with MiniBatchKMeans (n_topics={config.n_topics})"
)
cluster_model = MiniBatchKMeans(
n_clusters=config.n_topics,
random_state=42,
batch_size=1024,
n_init=3,
verbose=1,
)
topic_model = BERTopic(
hdbscan_model=cluster_model,
language="english",
calculate_probabilities=False, # saves memory
verbose=True,
)
logger.info(f"Starting fit_transform on {len(docs)} posts (CPU)")
start = time.perf_counter()
topic_model.fit_transform(docs)
elapsed = time.perf_counter() - start
logger.info(f"Training complete in {elapsed:.1f}s ({elapsed / 60:.1f} min)")
# Log topic summary for quick inspection
topic_info = topic_model.get_topic_info()
logger.info(f"Topics found: {len(topic_info)}")
logger.info(f"\n{topic_info.to_string()}")
model_save_path = Path(config.model_save_path)
model_save_path.mkdir(parents=True, exist_ok=True)
logger.info(f"Saving model to {model_save_path}")
topic_model.save(
str(model_save_path),
serialization="safetensors",
save_ctfidf=True,
save_embedding_model=True,
)
logger.info("Model saved")
if __name__ == "__main__":
main()
+57
View File
@@ -2,6 +2,7 @@ from __future__ import annotations
from dataclasses import dataclass
from os import getenv
from datetime import date
from pathlib import Path
import tomllib
@@ -50,6 +51,7 @@ class FinetuneConfig:
)
@dataclass
class BenchmarkConfig:
"""Top-level benchmark configuration loaded from TOML."""
@@ -101,6 +103,45 @@ class OpenAIConfig:
)
@dataclass
class BertTopicTrainConfig:
"""BERTopic training configuration loaded from TOML."""
sample_rate: float
min_text_length: int
n_topics: int
model_save_path: str
model_version: str | None = None
@classmethod
def from_toml(cls, config_path: Path) -> BertTopicTrainConfig:
"""Load BERTopic training config from a TOML file."""
raw = tomllib.loads(config_path.read_text())["bertopic"]["train"]
today = date.today().isoformat()
if raw.get("model_version") is None:
raw["model_version"] = (
f"{today}-{raw['sample_rate']}-{raw['min_text_length']}-{raw['n_topics']}"
)
return cls(**raw)
@dataclass
class BertTopicInferConfig:
"""BERTopic inference configuration loaded from TOML."""
min_text_length: int
poc_batch_size: int
model_version: str
model_save_path: str
@classmethod
def from_toml(cls, config_path: Path) -> BertTopicInferConfig:
"""Load BERTopic inference config from a TOML file."""
raw = tomllib.loads(config_path.read_text())["bertopic"]["infer"]
return cls(**raw)
def get_config_dir() -> Path:
"""Get the path to the config directory."""
return Path(__file__).resolve().parents[2] / "config"
@@ -127,3 +168,19 @@ def get_benchmark_config(config_path: Path | None = None) -> BenchmarkConfig:
if config_path is None:
config_path = default_config_path()
return BenchmarkConfig.from_toml(config_path)
def get_bertopic_train_config(
config_path: Path | None = None,
) -> BertTopicTrainConfig:
if config_path is None:
config_path = default_config_path()
return BertTopicTrainConfig.from_toml(config_path)
def get_bertopic_infer_config(
config_path: Path | None = None,
) -> BertTopicInferConfig:
if config_path is None:
config_path = default_config_path()
return BertTopicInferConfig.from_toml(config_path)
+26 -12
View File
@@ -19,6 +19,7 @@ from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.congress import (
Bill,
BillText,
BillTextSummary,
BillTopic,
BillTopicPosition,
SubjectType,
@@ -72,11 +73,19 @@ class ExtractedBillTopic:
def _select_bill_text_for_topic_extraction(bill: Bill) -> BillText | None:
"""Pick one summarized bill_text row from the already-loaded relationship."""
for bill_text in bill.bill_texts:
if bill_text.summary and bill_text.summary.strip():
summary_row = bill_text.default_summary()
if summary_row and summary_row.summary.strip():
return bill_text
return None
def _bill_text_has_summary_clause() -> ColumnElement[bool]:
"""Return a correlated EXISTS clause for bill texts with at least one summary."""
return exists(
select(BillTextSummary.id).where(BillTextSummary.bill_text_id == BillText.id)
)
def normalize_topic_label(value: str) -> str:
"""Normalize a topic label for storage, comparison, and de-duping."""
normalized = value.strip().strip("\"'")
@@ -323,11 +332,7 @@ def create_select_bills_for_topic_extraction(
limit: int | None = None,
) -> Select[tuple[Bill]]:
"""Select bill rows that have summarized bill_text rows for topic extraction."""
has_summary = (BillText.summary.is_not(None), BillText.summary != "")
summarized_text_filters: list[ColumnElement[bool]] = [
BillText.bill_id == Bill.id,
*has_summary,
]
summarized_text_filters: list[ColumnElement[bool]] = [_bill_text_has_summary_clause()]
if with_votes_only:
summarized_text_filters.append(
exists(
@@ -347,11 +352,17 @@ def create_select_bills_for_topic_extraction(
)
)
)
summarized_text_exists = exists(select(BillText.id).where(*summarized_text_filters))
summarized_text_exists = exists(
select(BillText.id).where(BillText.bill_id == Bill.id, *summarized_text_filters)
)
bill_text_loader = selectinload(Bill.bill_texts.and_(*summarized_text_filters))
stmt = (
select(Bill)
.where(summarized_text_exists)
.options(selectinload(Bill.bill_texts.and_(*summarized_text_filters[1:])))
.options(
bill_text_loader.selectinload(BillText.summaries),
bill_text_loader.selectinload(BillText.primary_summary),
)
.order_by(Bill.id)
)
if congress is not None:
@@ -363,7 +374,7 @@ def create_select_bills_for_topic_extraction(
select(BillText.id).where(
BillText.bill_id == Bill.id,
BillText.id.in_(bill_text_ids),
*summarized_text_filters[1:],
*summarized_text_filters,
)
)
stmt = stmt.where(selected_text_exists)
@@ -416,8 +427,7 @@ def collect_topic_extraction_diagnostics(
)
)
has_summary = (BillText.summary.is_not(None), BillText.summary != "")
summary_filters = [*bill_text_filters, *has_summary]
summary_filters = [*bill_text_filters, _bill_text_has_summary_clause()]
bills_with_summaries = session.scalar(
select(func.count(func.distinct(Bill.id)))
@@ -607,7 +617,11 @@ def main(
if bill_text is None:
logger.warning("Skipping bill id=%s: no usable summary", bill.id)
continue
summary = bill_text.summary.strip()
summary_row = bill_text.default_summary()
if summary_row is None:
logger.warning("Skipping bill id=%s: no default summary", bill.id)
continue
summary = summary_row.summary.strip()
try:
extracted_topics = extract_topics_for_bill_text(
+23 -9
View File
@@ -9,7 +9,7 @@ from typing import Annotated, Any
import httpx
import typer
from sqlalchemy import Select, exists, or_, select
from sqlalchemy import Select, exists, select
from sqlalchemy.orm import Session, selectinload
from tiktoken import get_encoding
@@ -20,6 +20,7 @@ from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.congress import (
Bill,
BillText,
BillTextSummary,
SubjectType,
VoteClassification,
VoteRelationship,
@@ -112,7 +113,7 @@ def summarize_bill_text(
model: str,
bill_text: BillText,
summarization_prompts: dict[str, str],
) -> str:
) -> str | None:
"""Generate and return a summary for one bill_text row."""
messages, user_prompt_tokens = build_bill_summary_messages(
bill_text=bill_text,
@@ -136,15 +137,21 @@ def summarize_bill_text(
def store_bill_summary_result(
*,
session: Session,
bill_text: BillText,
summary: str,
model: str,
) -> None:
) -> BillTextSummary:
"""Store a generated summary and the prompt/model metadata that produced it."""
bill_text.summary = summary
bill_text.summarization_model = model
bill_text.summarization_system_prompt_version = "v1.2"
bill_text.summarization_user_prompt_version = "v1"
summary_row = BillTextSummary(
bill_text=bill_text,
summary=summary,
summarization_model=model,
summarization_system_prompt_version="v1.2",
summarization_user_prompt_version="v1",
)
session.add(summary_row)
return summary_row
def create_select_bill_texts_for_summarization(
@@ -154,7 +161,7 @@ def create_select_bill_texts_for_summarization(
with_votes_only: bool = False,
force: bool = False,
limit: int | None = None,
) -> Select:
) -> Select[tuple[BillText]]:
"""Select bill_text rows that have source text and need summaries."""
stmt = (
select(BillText)
@@ -189,7 +196,13 @@ def create_select_bill_texts_for_summarization(
)
)
if not force:
stmt = stmt.where(or_(BillText.summary.is_(None), BillText.summary == ""))
stmt = stmt.where(
~exists(
select(BillTextSummary.id).where(
BillTextSummary.bill_text_id == BillText.id
)
)
)
if limit is not None:
stmt = stmt.limit(limit)
return stmt
@@ -287,6 +300,7 @@ def main(
logger.warning("Skipping bill_text id=%s", bill_text.id)
continue
store_bill_summary_result(
session=session,
bill_text=bill_text,
summary=summary,
model=model,
@@ -6,6 +6,7 @@ from pipelines.orm.data_science_dev.congress.bill import (
BillActionRecordedVote,
BillRelation,
BillText,
BillTextSummary,
BillTopic,
BillTopicPosition,
)
@@ -54,6 +55,7 @@ __all__ = [
"BillActionRecordedVote",
"BillRelation",
"BillText",
"BillTextSummary",
"BillTopic",
"BillTopicPosition",
"ClassificationMethod",
@@ -105,13 +105,12 @@ class BillText(DataScienceDevTableBase):
)
bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
primary_summary_id: Mapped[int | None] = mapped_column(
ForeignKey("main.bill_text_summary.id", ondelete="SET NULL")
)
version_code: Mapped[str]
version_name: Mapped[str | None]
text_content: Mapped[str | None]
summary: Mapped[str | None]
summarization_model: Mapped[str | None]
summarization_user_prompt_version: Mapped[str | None]
summarization_system_prompt_version: Mapped[str | None]
date: Mapped[date | None]
source_datetime_raw: Mapped[str | None]
text_url_xml: Mapped[str | None]
@@ -122,6 +121,57 @@ class BillText(DataScienceDevTableBase):
)
bill: Mapped[Bill] = relationship("Bill", back_populates="bill_texts")
summaries: Mapped[list[BillTextSummary]] = relationship(
"BillTextSummary",
back_populates="bill_text",
cascade="all, delete-orphan",
foreign_keys="BillTextSummary.bill_text_id",
order_by=lambda: (
BillTextSummary.created.desc(),
BillTextSummary.id.desc(),
),
)
primary_summary: Mapped[BillTextSummary | None] = relationship(
"BillTextSummary",
foreign_keys=[primary_summary_id],
post_update=True,
)
def latest_summary(self) -> BillTextSummary | None:
"""Return the newest summary row for this bill text."""
return self.summaries[0] if self.summaries else None
def default_summary(self) -> BillTextSummary | None:
"""Return the primary summary when set, otherwise the newest summary."""
return self.primary_summary or self.latest_summary()
class BillTextSummary(DataScienceDevTableBase):
"""Stores one generated summary for a bill text version."""
__tablename__ = "bill_text_summary"
__table_args__ = (
Index("ix_bill_text_summary_bill_text_id", "bill_text_id"),
Index(
"ix_bill_text_summary_bill_text_id_created",
"bill_text_id",
"created",
),
)
bill_text_id: Mapped[int] = mapped_column(
ForeignKey("main.bill_text.id", ondelete="CASCADE")
)
summary: Mapped[str]
summarization_model: Mapped[str | None]
summarization_user_prompt_version: Mapped[str | None]
summarization_system_prompt_version: Mapped[str | None]
bill_text: Mapped[BillText] = relationship(
"BillText",
back_populates="summaries",
foreign_keys=[bill_text_id],
)
class BillAction(DataScienceDevTableBase):
+2
View File
@@ -11,6 +11,7 @@ from pipelines.orm.data_science_dev.congress import (
BillActionRecordedVote,
BillRelation,
BillText,
BillTextSummary,
BillTopic,
BillTopicPosition,
ClassificationMethod,
@@ -51,6 +52,7 @@ __all__ = [
"BillActionRecordedVote",
"BillRelation",
"BillText",
"BillTextSummary",
"BillTopic",
"BillTopicPosition",
"ClassificationMethod",
+36
View File
@@ -0,0 +1,36 @@
from pipelines.orm.data_science_dev.congress import BillText, BillTextSummary
def test_default_summary_prefers_primary_summary() -> None:
primary_summary = BillTextSummary(id=1, bill_text_id=10, summary="primary")
latest_summary = BillTextSummary(id=2, bill_text_id=10, summary="latest")
bill_text = BillText(
id=10,
bill_id=5,
version_code="ih",
summaries=[latest_summary],
primary_summary=primary_summary,
)
assert bill_text.default_summary() is primary_summary
def test_default_summary_falls_back_to_latest_summary() -> None:
latest_summary = BillTextSummary(id=2, bill_text_id=10, summary="latest")
older_summary = BillTextSummary(id=1, bill_text_id=10, summary="older")
bill_text = BillText(
id=10,
bill_id=5,
version_code="ih",
summaries=[latest_summary, older_summary],
)
assert bill_text.latest_summary() is latest_summary
assert bill_text.default_summary() is latest_summary
def test_default_summary_is_none_without_summaries() -> None:
bill_text = BillText(id=10, bill_id=5, version_code="ih")
assert bill_text.latest_summary() is None
assert bill_text.default_summary() is None
+71
View File
@@ -0,0 +1,71 @@
from sqlalchemy.dialects import postgresql
from pipelines.jobs.extract_bill_topics import (
_select_bill_text_for_topic_extraction,
create_select_bills_for_topic_extraction,
)
from pipelines.orm.data_science_dev.congress import Bill, BillText, BillTextSummary
def _compile_sql(statement: object) -> str:
return str(
statement.compile(
dialect=postgresql.dialect(),
compile_kwargs={"literal_binds": True},
)
)
def test_select_bill_text_for_topic_extraction_uses_primary_summary() -> None:
primary_summary = BillTextSummary(id=1, bill_text_id=10, summary="primary")
newest_summary = BillTextSummary(id=2, bill_text_id=10, summary="newest")
bill_text = BillText(
id=10,
bill_id=5,
version_code="ih",
summaries=[newest_summary],
primary_summary=primary_summary,
)
bill = Bill(
id=5,
congress=119,
bill_type="hr",
number=1,
bill_texts=[bill_text],
)
selected = _select_bill_text_for_topic_extraction(bill)
assert selected is bill_text
assert selected.default_summary() is primary_summary
def test_select_bill_text_for_topic_extraction_uses_latest_summary_without_primary() -> None:
newest_summary = BillTextSummary(id=2, bill_text_id=10, summary="newest")
older_summary = BillTextSummary(id=1, bill_text_id=10, summary="older")
bill_text = BillText(
id=10,
bill_id=5,
version_code="ih",
summaries=[newest_summary, older_summary],
)
bill = Bill(
id=5,
congress=119,
bill_type="hr",
number=1,
bill_texts=[bill_text],
)
selected = _select_bill_text_for_topic_extraction(bill)
assert selected is bill_text
assert selected.default_summary() is newest_summary
def test_create_select_bills_for_topic_extraction_uses_summary_exists_subquery() -> None:
sql = _compile_sql(create_select_bills_for_topic_extraction())
assert "bill_text_summary" in sql
assert "EXISTS" in sql
assert "bill_text.summary" not in sql
+58
View File
@@ -0,0 +1,58 @@
from sqlalchemy.dialects import postgresql
from pipelines.jobs.summarize_bills import (
create_select_bill_texts_for_summarization,
store_bill_summary_result,
)
from pipelines.orm.data_science_dev.congress import BillText, BillTextSummary
class FakeSession:
def __init__(self) -> None:
self.added: list[object] = []
def add(self, value: object) -> None:
self.added.append(value)
def _compile_sql(statement: object) -> str:
return str(
statement.compile(
dialect=postgresql.dialect(),
compile_kwargs={"literal_binds": True},
)
)
def test_store_bill_summary_result_creates_summary_row() -> None:
session = FakeSession()
bill_text = BillText(id=10, bill_id=5, version_code="ih")
summary_row = store_bill_summary_result(
session=session,
bill_text=bill_text,
summary="A summary",
model="gpt-5.4-mini",
)
assert session.added == [summary_row]
assert isinstance(summary_row, BillTextSummary)
assert summary_row.bill_text is bill_text
assert summary_row.summary == "A summary"
assert summary_row.summarization_model == "gpt-5.4-mini"
assert summary_row.summarization_system_prompt_version == "v1.2"
assert summary_row.summarization_user_prompt_version == "v1"
def test_create_select_bill_texts_for_summarization_excludes_existing_summaries() -> None:
sql = _compile_sql(create_select_bill_texts_for_summarization(force=False))
assert "bill_text_summary" in sql
assert "NOT (EXISTS" in sql or "NOT EXISTS" in sql
assert "bill_text.summary" not in sql
def test_create_select_bill_texts_for_summarization_force_skips_summary_filter() -> None:
sql = _compile_sql(create_select_bill_texts_for_summarization(force=True))
assert "bill_text_summary" not in sql