From d3fe6dba563fab380668d134ac170d2773e87aac Mon Sep 17 00:00:00 2001 From: Richie Cahill Date: Fri, 8 May 2026 18:30:07 -0400 Subject: [PATCH] allowing multiple summaries per bill text --- ...03-bill_text_multi_summary_4b2e1c9d8f70.py | 211 ++++++++++++++++++ pipelines/jobs/extract_bill_topics.py | 38 +++- pipelines/jobs/summarize_bills.py | 32 ++- .../orm/data_science_dev/congress/__init__.py | 2 + .../orm/data_science_dev/congress/bill.py | 58 ++++- pipelines/orm/data_science_dev/models.py | 2 + tests/test_bill_text_summary_model.py | 36 +++ tests/test_extract_bill_topics.py | 71 ++++++ tests/test_summarize_bills.py | 58 +++++ 9 files changed, 483 insertions(+), 25 deletions(-) create mode 100644 alembic/data_science_dev/versions/2026_05_03-bill_text_multi_summary_4b2e1c9d8f70.py create mode 100644 tests/test_bill_text_summary_model.py create mode 100644 tests/test_extract_bill_topics.py create mode 100644 tests/test_summarize_bills.py diff --git a/alembic/data_science_dev/versions/2026_05_03-bill_text_multi_summary_4b2e1c9d8f70.py b/alembic/data_science_dev/versions/2026_05_03-bill_text_multi_summary_4b2e1c9d8f70.py new file mode 100644 index 0000000..8f065a0 --- /dev/null +++ b/alembic/data_science_dev/versions/2026_05_03-bill_text_multi_summary_4b2e1c9d8f70.py @@ -0,0 +1,211 @@ +"""move bill text summaries into a child table. + +Revision ID: 4b2e1c9d8f70 +Revises: b9360b0b0c22 +Create Date: 2026-05-03 00:00:00.000000 + +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import sqlalchemy as sa + +from alembic import op +from pipelines.orm import DataScienceDevBase + +if TYPE_CHECKING: + from collections.abc import Sequence + +# revision identifiers, used by Alembic. +revision: str = "4b2e1c9d8f70" +down_revision: str | None = "b9360b0b0c22" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +schema = DataScienceDevBase.schema_name + + +def upgrade() -> None: + """Upgrade.""" + op.create_table( + "bill_text_summary", + sa.Column("bill_text_id", sa.Integer(), nullable=False), + sa.Column("summary", sa.String(), nullable=False), + sa.Column("summarization_model", sa.String(), nullable=True), + sa.Column("summarization_user_prompt_version", sa.String(), nullable=True), + sa.Column("summarization_system_prompt_version", sa.String(), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "created", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["bill_text_id"], + [f"{schema}.bill_text.id"], + name=op.f("fk_bill_text_summary_bill_text_id_bill_text"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_bill_text_summary")), + schema=schema, + ) + op.create_index( + "ix_bill_text_summary_bill_text_id", + "bill_text_summary", + ["bill_text_id"], + unique=False, + schema=schema, + ) + op.create_index( + "ix_bill_text_summary_bill_text_id_created", + "bill_text_summary", + ["bill_text_id", "created"], + unique=False, + schema=schema, + ) + op.add_column( + "bill_text", + sa.Column("primary_summary_id", sa.Integer(), nullable=True), + schema=schema, + ) + op.create_foreign_key( + op.f("fk_bill_text_primary_summary_id_bill_text_summary"), + "bill_text", + "bill_text_summary", + ["primary_summary_id"], + ["id"], + source_schema=schema, + referent_schema=schema, + ondelete="SET NULL", + ) + + op.execute( + sa.text( + f""" + INSERT INTO {schema}.bill_text_summary ( + bill_text_id, + summary, + summarization_model, + summarization_user_prompt_version, + summarization_system_prompt_version, + created, + updated + ) + SELECT + bill_text.id, + bill_text.summary, + bill_text.summarization_model, + bill_text.summarization_user_prompt_version, + bill_text.summarization_system_prompt_version, + COALESCE(bill_text.updated, bill_text.created, now()), + COALESCE(bill_text.updated, bill_text.created, now()) + FROM {schema}.bill_text + WHERE bill_text.summary IS NOT NULL + AND btrim(bill_text.summary) <> '' + """ + ) + ) + + op.drop_column("bill_text", "summary", schema=schema) + op.drop_column("bill_text", "summarization_model", schema=schema) + op.drop_column("bill_text", "summarization_user_prompt_version", schema=schema) + op.drop_column("bill_text", "summarization_system_prompt_version", schema=schema) + + +def downgrade() -> None: + """Downgrade.""" + op.add_column( + "bill_text", + sa.Column("summarization_system_prompt_version", sa.String(), nullable=True), + schema=schema, + ) + op.add_column( + "bill_text", + sa.Column("summarization_user_prompt_version", sa.String(), nullable=True), + schema=schema, + ) + op.add_column( + "bill_text", + sa.Column("summarization_model", sa.String(), nullable=True), + schema=schema, + ) + op.add_column( + "bill_text", + sa.Column("summary", sa.String(), nullable=True), + schema=schema, + ) + + op.execute( + sa.text( + f""" + WITH ranked AS ( + SELECT + bts.*, + row_number() OVER ( + PARTITION BY bts.bill_text_id + ORDER BY bts.created DESC, bts.id DESC + ) AS rn + FROM {schema}.bill_text_summary AS bts + ), + chosen AS ( + SELECT + bill_text.id AS bill_text_id, + COALESCE(ps.summary, ls.summary) AS summary, + COALESCE( + ps.summarization_model, + ls.summarization_model + ) AS summarization_model, + COALESCE( + ps.summarization_user_prompt_version, + ls.summarization_user_prompt_version + ) AS summarization_user_prompt_version, + COALESCE( + ps.summarization_system_prompt_version, + ls.summarization_system_prompt_version + ) AS summarization_system_prompt_version + FROM {schema}.bill_text + LEFT JOIN {schema}.bill_text_summary AS ps + ON ps.id = bill_text.primary_summary_id + LEFT JOIN ranked AS ls + ON ls.bill_text_id = bill_text.id + AND ls.rn = 1 + ) + UPDATE {schema}.bill_text + SET + summary = chosen.summary, + summarization_model = chosen.summarization_model, + summarization_user_prompt_version = chosen.summarization_user_prompt_version, + summarization_system_prompt_version = chosen.summarization_system_prompt_version + FROM chosen + WHERE chosen.bill_text_id = bill_text.id + """ + ) + ) + + op.drop_constraint( + op.f("fk_bill_text_primary_summary_id_bill_text_summary"), + "bill_text", + schema=schema, + type_="foreignkey", + ) + op.drop_column("bill_text", "primary_summary_id", schema=schema) + op.drop_index( + "ix_bill_text_summary_bill_text_id_created", + table_name="bill_text_summary", + schema=schema, + ) + op.drop_index( + "ix_bill_text_summary_bill_text_id", + table_name="bill_text_summary", + schema=schema, + ) + op.drop_table("bill_text_summary", schema=schema) diff --git a/pipelines/jobs/extract_bill_topics.py b/pipelines/jobs/extract_bill_topics.py index c57562e..af05b23 100644 --- a/pipelines/jobs/extract_bill_topics.py +++ b/pipelines/jobs/extract_bill_topics.py @@ -19,6 +19,7 @@ from pipelines.orm.common import get_postgres_engine from pipelines.orm.data_science_dev.congress import ( Bill, BillText, + BillTextSummary, BillTopic, BillTopicPosition, SubjectType, @@ -72,11 +73,19 @@ class ExtractedBillTopic: def _select_bill_text_for_topic_extraction(bill: Bill) -> BillText | None: """Pick one summarized bill_text row from the already-loaded relationship.""" for bill_text in bill.bill_texts: - if bill_text.summary and bill_text.summary.strip(): + summary_row = bill_text.default_summary() + if summary_row and summary_row.summary.strip(): return bill_text return None +def _bill_text_has_summary_clause() -> ColumnElement[bool]: + """Return a correlated EXISTS clause for bill texts with at least one summary.""" + return exists( + select(BillTextSummary.id).where(BillTextSummary.bill_text_id == BillText.id) + ) + + def normalize_topic_label(value: str) -> str: """Normalize a topic label for storage, comparison, and de-duping.""" normalized = value.strip().strip("\"'") @@ -323,11 +332,7 @@ def create_select_bills_for_topic_extraction( limit: int | None = None, ) -> Select[tuple[Bill]]: """Select bill rows that have summarized bill_text rows for topic extraction.""" - has_summary = (BillText.summary.is_not(None), BillText.summary != "") - summarized_text_filters: list[ColumnElement[bool]] = [ - BillText.bill_id == Bill.id, - *has_summary, - ] + summarized_text_filters: list[ColumnElement[bool]] = [_bill_text_has_summary_clause()] if with_votes_only: summarized_text_filters.append( exists( @@ -347,11 +352,17 @@ def create_select_bills_for_topic_extraction( ) ) ) - summarized_text_exists = exists(select(BillText.id).where(*summarized_text_filters)) + summarized_text_exists = exists( + select(BillText.id).where(BillText.bill_id == Bill.id, *summarized_text_filters) + ) + bill_text_loader = selectinload(Bill.bill_texts.and_(*summarized_text_filters)) stmt = ( select(Bill) .where(summarized_text_exists) - .options(selectinload(Bill.bill_texts.and_(*summarized_text_filters[1:]))) + .options( + bill_text_loader.selectinload(BillText.summaries), + bill_text_loader.selectinload(BillText.primary_summary), + ) .order_by(Bill.id) ) if congress is not None: @@ -363,7 +374,7 @@ def create_select_bills_for_topic_extraction( select(BillText.id).where( BillText.bill_id == Bill.id, BillText.id.in_(bill_text_ids), - *summarized_text_filters[1:], + *summarized_text_filters, ) ) stmt = stmt.where(selected_text_exists) @@ -416,8 +427,7 @@ def collect_topic_extraction_diagnostics( ) ) - has_summary = (BillText.summary.is_not(None), BillText.summary != "") - summary_filters = [*bill_text_filters, *has_summary] + summary_filters = [*bill_text_filters, _bill_text_has_summary_clause()] bills_with_summaries = session.scalar( select(func.count(func.distinct(Bill.id))) @@ -607,7 +617,11 @@ def main( if bill_text is None: logger.warning("Skipping bill id=%s: no usable summary", bill.id) continue - summary = bill_text.summary.strip() + summary_row = bill_text.default_summary() + if summary_row is None: + logger.warning("Skipping bill id=%s: no default summary", bill.id) + continue + summary = summary_row.summary.strip() try: extracted_topics = extract_topics_for_bill_text( diff --git a/pipelines/jobs/summarize_bills.py b/pipelines/jobs/summarize_bills.py index 871d4a2..59bb1fb 100644 --- a/pipelines/jobs/summarize_bills.py +++ b/pipelines/jobs/summarize_bills.py @@ -9,7 +9,7 @@ from typing import Annotated, Any import httpx import typer -from sqlalchemy import Select, exists, or_, select +from sqlalchemy import Select, exists, select from sqlalchemy.orm import Session, selectinload from tiktoken import get_encoding @@ -20,6 +20,7 @@ from pipelines.orm.common import get_postgres_engine from pipelines.orm.data_science_dev.congress import ( Bill, BillText, + BillTextSummary, SubjectType, VoteClassification, VoteRelationship, @@ -112,7 +113,7 @@ def summarize_bill_text( model: str, bill_text: BillText, summarization_prompts: dict[str, str], -) -> str: +) -> str | None: """Generate and return a summary for one bill_text row.""" messages, user_prompt_tokens = build_bill_summary_messages( bill_text=bill_text, @@ -136,15 +137,21 @@ def summarize_bill_text( def store_bill_summary_result( *, + session: Session, bill_text: BillText, summary: str, model: str, -) -> None: +) -> BillTextSummary: """Store a generated summary and the prompt/model metadata that produced it.""" - bill_text.summary = summary - bill_text.summarization_model = model - bill_text.summarization_system_prompt_version = "v1.2" - bill_text.summarization_user_prompt_version = "v1" + summary_row = BillTextSummary( + bill_text=bill_text, + summary=summary, + summarization_model=model, + summarization_system_prompt_version="v1.2", + summarization_user_prompt_version="v1", + ) + session.add(summary_row) + return summary_row def create_select_bill_texts_for_summarization( @@ -154,7 +161,7 @@ def create_select_bill_texts_for_summarization( with_votes_only: bool = False, force: bool = False, limit: int | None = None, -) -> Select: +) -> Select[tuple[BillText]]: """Select bill_text rows that have source text and need summaries.""" stmt = ( select(BillText) @@ -189,7 +196,13 @@ def create_select_bill_texts_for_summarization( ) ) if not force: - stmt = stmt.where(or_(BillText.summary.is_(None), BillText.summary == "")) + stmt = stmt.where( + ~exists( + select(BillTextSummary.id).where( + BillTextSummary.bill_text_id == BillText.id + ) + ) + ) if limit is not None: stmt = stmt.limit(limit) return stmt @@ -287,6 +300,7 @@ def main( logger.warning("Skipping bill_text id=%s", bill_text.id) continue store_bill_summary_result( + session=session, bill_text=bill_text, summary=summary, model=model, diff --git a/pipelines/orm/data_science_dev/congress/__init__.py b/pipelines/orm/data_science_dev/congress/__init__.py index a881b85..3baee05 100644 --- a/pipelines/orm/data_science_dev/congress/__init__.py +++ b/pipelines/orm/data_science_dev/congress/__init__.py @@ -6,6 +6,7 @@ from pipelines.orm.data_science_dev.congress.bill import ( BillActionRecordedVote, BillRelation, BillText, + BillTextSummary, BillTopic, BillTopicPosition, ) @@ -54,6 +55,7 @@ __all__ = [ "BillActionRecordedVote", "BillRelation", "BillText", + "BillTextSummary", "BillTopic", "BillTopicPosition", "ClassificationMethod", diff --git a/pipelines/orm/data_science_dev/congress/bill.py b/pipelines/orm/data_science_dev/congress/bill.py index c7ad0ad..0e0311a 100644 --- a/pipelines/orm/data_science_dev/congress/bill.py +++ b/pipelines/orm/data_science_dev/congress/bill.py @@ -105,13 +105,12 @@ class BillText(DataScienceDevTableBase): ) bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE")) + primary_summary_id: Mapped[int | None] = mapped_column( + ForeignKey("main.bill_text_summary.id", ondelete="SET NULL") + ) version_code: Mapped[str] version_name: Mapped[str | None] text_content: Mapped[str | None] - summary: Mapped[str | None] - summarization_model: Mapped[str | None] - summarization_user_prompt_version: Mapped[str | None] - summarization_system_prompt_version: Mapped[str | None] date: Mapped[date | None] source_datetime_raw: Mapped[str | None] text_url_xml: Mapped[str | None] @@ -122,6 +121,57 @@ class BillText(DataScienceDevTableBase): ) bill: Mapped[Bill] = relationship("Bill", back_populates="bill_texts") + summaries: Mapped[list[BillTextSummary]] = relationship( + "BillTextSummary", + back_populates="bill_text", + cascade="all, delete-orphan", + foreign_keys="BillTextSummary.bill_text_id", + order_by=lambda: ( + BillTextSummary.created.desc(), + BillTextSummary.id.desc(), + ), + ) + primary_summary: Mapped[BillTextSummary | None] = relationship( + "BillTextSummary", + foreign_keys=[primary_summary_id], + post_update=True, + ) + + def latest_summary(self) -> BillTextSummary | None: + """Return the newest summary row for this bill text.""" + return self.summaries[0] if self.summaries else None + + def default_summary(self) -> BillTextSummary | None: + """Return the primary summary when set, otherwise the newest summary.""" + return self.primary_summary or self.latest_summary() + + +class BillTextSummary(DataScienceDevTableBase): + """Stores one generated summary for a bill text version.""" + + __tablename__ = "bill_text_summary" + __table_args__ = ( + Index("ix_bill_text_summary_bill_text_id", "bill_text_id"), + Index( + "ix_bill_text_summary_bill_text_id_created", + "bill_text_id", + "created", + ), + ) + + bill_text_id: Mapped[int] = mapped_column( + ForeignKey("main.bill_text.id", ondelete="CASCADE") + ) + summary: Mapped[str] + summarization_model: Mapped[str | None] + summarization_user_prompt_version: Mapped[str | None] + summarization_system_prompt_version: Mapped[str | None] + + bill_text: Mapped[BillText] = relationship( + "BillText", + back_populates="summaries", + foreign_keys=[bill_text_id], + ) class BillAction(DataScienceDevTableBase): diff --git a/pipelines/orm/data_science_dev/models.py b/pipelines/orm/data_science_dev/models.py index bc2a214..e56a13b 100644 --- a/pipelines/orm/data_science_dev/models.py +++ b/pipelines/orm/data_science_dev/models.py @@ -11,6 +11,7 @@ from pipelines.orm.data_science_dev.congress import ( BillActionRecordedVote, BillRelation, BillText, + BillTextSummary, BillTopic, BillTopicPosition, ClassificationMethod, @@ -51,6 +52,7 @@ __all__ = [ "BillActionRecordedVote", "BillRelation", "BillText", + "BillTextSummary", "BillTopic", "BillTopicPosition", "ClassificationMethod", diff --git a/tests/test_bill_text_summary_model.py b/tests/test_bill_text_summary_model.py new file mode 100644 index 0000000..e2a9ba8 --- /dev/null +++ b/tests/test_bill_text_summary_model.py @@ -0,0 +1,36 @@ +from pipelines.orm.data_science_dev.congress import BillText, BillTextSummary + + +def test_default_summary_prefers_primary_summary() -> None: + primary_summary = BillTextSummary(id=1, bill_text_id=10, summary="primary") + latest_summary = BillTextSummary(id=2, bill_text_id=10, summary="latest") + bill_text = BillText( + id=10, + bill_id=5, + version_code="ih", + summaries=[latest_summary], + primary_summary=primary_summary, + ) + + assert bill_text.default_summary() is primary_summary + + +def test_default_summary_falls_back_to_latest_summary() -> None: + latest_summary = BillTextSummary(id=2, bill_text_id=10, summary="latest") + older_summary = BillTextSummary(id=1, bill_text_id=10, summary="older") + bill_text = BillText( + id=10, + bill_id=5, + version_code="ih", + summaries=[latest_summary, older_summary], + ) + + assert bill_text.latest_summary() is latest_summary + assert bill_text.default_summary() is latest_summary + + +def test_default_summary_is_none_without_summaries() -> None: + bill_text = BillText(id=10, bill_id=5, version_code="ih") + + assert bill_text.latest_summary() is None + assert bill_text.default_summary() is None diff --git a/tests/test_extract_bill_topics.py b/tests/test_extract_bill_topics.py new file mode 100644 index 0000000..b1f0a2a --- /dev/null +++ b/tests/test_extract_bill_topics.py @@ -0,0 +1,71 @@ +from sqlalchemy.dialects import postgresql + +from pipelines.jobs.extract_bill_topics import ( + _select_bill_text_for_topic_extraction, + create_select_bills_for_topic_extraction, +) +from pipelines.orm.data_science_dev.congress import Bill, BillText, BillTextSummary + + +def _compile_sql(statement: object) -> str: + return str( + statement.compile( + dialect=postgresql.dialect(), + compile_kwargs={"literal_binds": True}, + ) + ) + + +def test_select_bill_text_for_topic_extraction_uses_primary_summary() -> None: + primary_summary = BillTextSummary(id=1, bill_text_id=10, summary="primary") + newest_summary = BillTextSummary(id=2, bill_text_id=10, summary="newest") + bill_text = BillText( + id=10, + bill_id=5, + version_code="ih", + summaries=[newest_summary], + primary_summary=primary_summary, + ) + bill = Bill( + id=5, + congress=119, + bill_type="hr", + number=1, + bill_texts=[bill_text], + ) + + selected = _select_bill_text_for_topic_extraction(bill) + + assert selected is bill_text + assert selected.default_summary() is primary_summary + + +def test_select_bill_text_for_topic_extraction_uses_latest_summary_without_primary() -> None: + newest_summary = BillTextSummary(id=2, bill_text_id=10, summary="newest") + older_summary = BillTextSummary(id=1, bill_text_id=10, summary="older") + bill_text = BillText( + id=10, + bill_id=5, + version_code="ih", + summaries=[newest_summary, older_summary], + ) + bill = Bill( + id=5, + congress=119, + bill_type="hr", + number=1, + bill_texts=[bill_text], + ) + + selected = _select_bill_text_for_topic_extraction(bill) + + assert selected is bill_text + assert selected.default_summary() is newest_summary + + +def test_create_select_bills_for_topic_extraction_uses_summary_exists_subquery() -> None: + sql = _compile_sql(create_select_bills_for_topic_extraction()) + + assert "bill_text_summary" in sql + assert "EXISTS" in sql + assert "bill_text.summary" not in sql diff --git a/tests/test_summarize_bills.py b/tests/test_summarize_bills.py new file mode 100644 index 0000000..bf22e6c --- /dev/null +++ b/tests/test_summarize_bills.py @@ -0,0 +1,58 @@ +from sqlalchemy.dialects import postgresql + +from pipelines.jobs.summarize_bills import ( + create_select_bill_texts_for_summarization, + store_bill_summary_result, +) +from pipelines.orm.data_science_dev.congress import BillText, BillTextSummary + + +class FakeSession: + def __init__(self) -> None: + self.added: list[object] = [] + + def add(self, value: object) -> None: + self.added.append(value) + + +def _compile_sql(statement: object) -> str: + return str( + statement.compile( + dialect=postgresql.dialect(), + compile_kwargs={"literal_binds": True}, + ) + ) + + +def test_store_bill_summary_result_creates_summary_row() -> None: + session = FakeSession() + bill_text = BillText(id=10, bill_id=5, version_code="ih") + + summary_row = store_bill_summary_result( + session=session, + bill_text=bill_text, + summary="A summary", + model="gpt-5.4-mini", + ) + + assert session.added == [summary_row] + assert isinstance(summary_row, BillTextSummary) + assert summary_row.bill_text is bill_text + assert summary_row.summary == "A summary" + assert summary_row.summarization_model == "gpt-5.4-mini" + assert summary_row.summarization_system_prompt_version == "v1.2" + assert summary_row.summarization_user_prompt_version == "v1" + + +def test_create_select_bill_texts_for_summarization_excludes_existing_summaries() -> None: + sql = _compile_sql(create_select_bill_texts_for_summarization(force=False)) + + assert "bill_text_summary" in sql + assert "NOT (EXISTS" in sql or "NOT EXISTS" in sql + assert "bill_text.summary" not in sql + + +def test_create_select_bill_texts_for_summarization_force_skips_summary_filter() -> None: + sql = _compile_sql(create_select_bill_texts_for_summarization(force=True)) + + assert "bill_text_summary" not in sql