allowing multiple summaries per bill text
This commit is contained in:
@@ -9,7 +9,7 @@ from typing import Annotated, Any
|
||||
|
||||
import httpx
|
||||
import typer
|
||||
-from sqlalchemy import Select, exists, or_, select
+from sqlalchemy import Select, exists, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from tiktoken import get_encoding
|
||||
@@ -20,6 +20,7 @@ from pipelines.orm.common import get_postgres_engine
|
||||
from pipelines.orm.data_science_dev.congress import (
|
||||
Bill,
|
||||
BillText,
|
||||
+BillTextSummary,
|
||||
SubjectType,
|
||||
VoteClassification,
|
||||
VoteRelationship,
|
||||
@@ -112,7 +113,7 @@ def summarize_bill_text(
|
||||
model: str,
|
||||
bill_text: BillText,
|
||||
summarization_prompts: dict[str, str],
|
||||
-) -> str:
+) -> str | None:
|
||||
"""Generate and return a summary for one bill_text row."""
|
||||
messages, user_prompt_tokens = build_bill_summary_messages(
|
||||
bill_text=bill_text,
|
||||
@@ -136,15 +137,21 @@ def summarize_bill_text(
|
||||
|
||||
def store_bill_summary_result(
|
||||
*,
|
||||
+session: Session,
|
||||
bill_text: BillText,
|
||||
summary: str,
|
||||
model: str,
|
||||
-) -> None:
+) -> BillTextSummary:
|
||||
"""Store a generated summary and the prompt/model metadata that produced it."""
|
||||
-bill_text.summary = summary
-bill_text.summarization_model = model
-bill_text.summarization_system_prompt_version = "v1.2"
-bill_text.summarization_user_prompt_version = "v1"
+summary_row = BillTextSummary(
+bill_text=bill_text,
+summary=summary,
+summarization_model=model,
+summarization_system_prompt_version="v1.2",
+summarization_user_prompt_version="v1",
+)
+session.add(summary_row)
+return summary_row
|
||||
|
||||
|
||||
def create_select_bill_texts_for_summarization(
|
||||
@@ -154,7 +161,7 @@ def create_select_bill_texts_for_summarization(
|
||||
with_votes_only: bool = False,
|
||||
force: bool = False,
|
||||
limit: int | None = None,
|
||||
-) -> Select:
+) -> Select[tuple[BillText]]:
|
||||
"""Select bill_text rows that have source text and need summaries."""
|
||||
stmt = (
|
||||
select(BillText)
|
||||
@@ -189,7 +196,13 @@ def create_select_bill_texts_for_summarization(
|
||||
)
|
||||
)
|
||||
if not force:
|
||||
-stmt = stmt.where(or_(BillText.summary.is_(None), BillText.summary == ""))
+stmt = stmt.where(
+~exists(
+select(BillTextSummary.id).where(
+BillTextSummary.bill_text_id == BillText.id
+)
+)
+)
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
return stmt
|
||||
@@ -287,6 +300,7 @@ def main(
|
||||
logger.warning("Skipping bill_text id=%s", bill_text.id)
|
||||
continue
|
||||
store_bill_summary_result(
|
||||
+session=session,
|
||||
bill_text=bill_text,
|
||||
summary=summary,
|
||||
model=model,
|
||||
|
||||
Reference in New Issue
Block a user