From f33a5c22334e59a940ad104002ad53c23a068533 Mon Sep 17 00:00:00 2001 From: Richie Cahill Date: Tue, 28 Apr 2026 22:44:29 -0400 Subject: [PATCH] adding calculate_legislator_scores.py summarize_bills.py and extract_bill_topics.py --- pipelines/__init__.py | 1 + pipelines/jobs/calculate_legislator_scores.py | 574 +++++++++++++++ pipelines/jobs/extract_bill_topics.py | 682 ++++++++++++++++++ pipelines/jobs/summarize_bills.py | 309 ++++++++ 4 files changed, 1566 insertions(+) create mode 100644 pipelines/__init__.py create mode 100644 pipelines/jobs/calculate_legislator_scores.py create mode 100644 pipelines/jobs/extract_bill_topics.py create mode 100644 pipelines/jobs/summarize_bills.py diff --git a/pipelines/__init__.py b/pipelines/__init__.py new file mode 100644 index 0000000..dc58a44 --- /dev/null +++ b/pipelines/__init__.py @@ -0,0 +1 @@ +"""Prompt benchmarking system for evaluating LLMs via vLLM.""" diff --git a/pipelines/jobs/calculate_legislator_scores.py b/pipelines/jobs/calculate_legislator_scores.py new file mode 100644 index 0000000..2a31361 --- /dev/null +++ b/pipelines/jobs/calculate_legislator_scores.py @@ -0,0 +1,574 @@ +"""Calculate legislator topic scores from bill topics and roll-call votes.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Annotated, Sequence + +import typer +from sqlalchemy import ( + ColumnElement, + Integer, + Select, + and_, + case, + cast, + delete, + extract, + func, + or_, + select, + tuple_, +) +from sqlalchemy.orm import Session + +from pipelines.congress_vote_context import create_score_run, finalize_score_run +from pipelines.orm.common import get_postgres_engine +from pipelines.orm.data_science_dev.congress import ( + BillTopic, + BillTopicPosition, + LegislatorScore, + SubjectType, + Vote, + VoteClassification, + VoteEffect, + VoteMeasureLink, + VoteMeasureRole, + VotePositionMeaning, + VoteRelationship, + VoteRecord, +) +from 
@dataclass(frozen=True)
class ScoreDiagnostics:
    """Counts for the input stages required to calculate legislator scores.

    The fields are ordered from the broadest input to the narrowest, so the
    first zero identifies the stage at which a score run ran out of data.
    """

    # bill_topic rows matching the requested bill/topic scope.
    bill_topic_rows: int
    # Distinct eligible votes that have a VOTED_ON measure link.
    linked_vote_rows: int
    # Individual vote records belonging to eligible votes.
    vote_record_rows: int
    # Vote-record rows remaining after joining to bill topics.
    topic_vote_links: int
    # Joined rows whose position maps to a supports/opposes-text effect.
    scorable_vote_records: int


@dataclass(frozen=True)
class LegislatorScoreInput:
    """One aggregated score ready to store in legislator_score."""

    legislator_id: int
    # Calendar year extracted from the vote date.
    year: int
    topic: str
    # 100 * supportive / (supportive + opposed).
    score: float
    # Number of votes counted as supporting the topic.
    supportive: int
    # Number of votes counted as opposing the topic.
    opposed: int
def collect_legislator_scores(
    session: Session,
    *,
    congress: int | None = None,
    bill_ids: Sequence[int] | None = None,
    topics: Sequence[str] | None = None,
) -> list[LegislatorScoreInput]:
    """Run the aggregate score query and return typed score rows.

    Args:
        session: Open SQLAlchemy session used to execute the query.
        congress: Optional Congress number to restrict votes to.
        bill_ids: Optional internal bill ids to restrict linked measures to.
        topics: Optional topic labels; normalized by the query builder.

    Returns:
        One ``LegislatorScoreInput`` per (legislator, year, topic) group.
        Rows whose score is NULL are dropped.
    """
    rows = session.execute(
        create_legislator_score_query(
            congress=congress,
            bill_ids=bill_ids,
            topics=topics,
        )
    )
    return [
        LegislatorScoreInput(
            legislator_id=int(row.legislator_id),
            year=int(row.year),
            topic=str(row.topic),
            score=float(row.score),
            supportive=int(row.supportive),
            opposed=int(row.opposed),
        )
        for row in rows
        if row.score is not None
    ]


def collect_score_diagnostics(
    session: Session,
    *,
    congress: int | None = None,
    bill_ids: Sequence[int] | None = None,
    topics: Sequence[str] | None = None,
) -> ScoreDiagnostics:
    """Count score pipeline inputs for explaining empty score runs.

    Args:
        session: Open SQLAlchemy session used to run the count queries.
        congress: Optional Congress number to restrict votes to.
        bill_ids: Optional internal bill ids to restrict the scope to.
        topics: Optional topic labels; normalized before filtering.

    Returns:
        A ``ScoreDiagnostics`` with one count per pipeline stage. Missing
        counts (``None`` from the database) are reported as 0.
    """
    normalized_topics = _normalize_topics(topics)
    vote_filters = _vote_scope_filters(congress=congress, bill_ids=bill_ids)
    topic_filters = _topic_scope_filters(bill_ids=bill_ids, topics=normalized_topics)
    normalized_vote = normalized_position_expression(VoteRecord.position)
    eligible_vote_filters = _eligible_vote_filters()
    voted_on_link = and_(
        VoteMeasureLink.vote_id == Vote.id,
        VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
    )

    # The vote-record count does not join VoteMeasureLink, so it must not
    # reuse ``vote_filters``: a bare ``VoteMeasureLink.measure_id IN (...)``
    # in WHERE would add the table as an uncorrelated FROM entry and
    # multiply the count. Scope by bill through a subquery instead, which
    # also counts each vote record exactly once.
    record_scope_filters: list[ColumnElement[bool]] = []
    if congress is not None:
        record_scope_filters.append(Vote.congress == congress)
    if bill_ids:
        record_scope_filters.append(
            Vote.id.in_(
                select(VoteMeasureLink.vote_id).where(
                    VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
                    VoteMeasureLink.measure_id.in_(list(bill_ids)),
                )
            )
        )

    bill_topic_rows = session.scalar(
        select(func.count(BillTopic.id)).where(*topic_filters)
    )
    linked_vote_rows = session.scalar(
        select(func.count(func.distinct(Vote.id)))
        .select_from(Vote)
        .join(VoteClassification, VoteClassification.vote_id == Vote.id)
        .join(VoteMeasureLink, voted_on_link)
        .where(*vote_filters, *eligible_vote_filters)
    )
    vote_record_rows = session.scalar(
        select(func.count())
        .select_from(VoteRecord)
        .join(Vote, Vote.id == VoteRecord.vote_id)
        .join(VoteClassification, VoteClassification.vote_id == Vote.id)
        .where(*record_scope_filters, *eligible_vote_filters)
    )

    # The last two stages share one join chain; the scorable count only adds
    # the position filter on top of it.
    topic_vote_count = (
        select(func.count())
        .select_from(VoteRecord)
        .join(Vote, Vote.id == VoteRecord.vote_id)
        .join(VoteClassification, VoteClassification.vote_id == Vote.id)
        .join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
        .join(VoteMeasureLink, voted_on_link)
        .join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
        .where(*vote_filters, *topic_filters, *eligible_vote_filters)
    )
    topic_vote_links = session.scalar(topic_vote_count)
    scorable_vote_records = session.scalar(
        topic_vote_count.where(_is_scorable_position(normalized_vote))
    )

    return ScoreDiagnostics(
        bill_topic_rows=bill_topic_rows or 0,
        linked_vote_rows=linked_vote_rows or 0,
        vote_record_rows=vote_record_rows or 0,
        topic_vote_links=topic_vote_links or 0,
        scorable_vote_records=scorable_vote_records or 0,
    )


def store_legislator_scores(
    session: Session,
    rows: Sequence[LegislatorScoreInput],
    *,
    score_run_id: int | None,
    replace_all: bool = False,
) -> int:
    """Replace matching score rows and insert the newly calculated scores.

    Args:
        session: Open SQLAlchemy session; the caller commits.
        rows: Newly calculated scores to insert.
        score_run_id: Id of the score run stamped on each inserted row.
        replace_all: When true, delete every existing score first instead of
            only the (legislator, year, topic) keys present in ``rows``.

    Returns:
        The number of rows added to the session.
    """
    if replace_all:
        session.execute(delete(LegislatorScore))
    elif rows:
        keys = [
            (row.legislator_id, row.year, row.topic)
            for row in rows
        ]
        # Delete in batches to keep each IN list bounded.
        for key_batch in _batched(keys, DELETE_BATCH_SIZE):
            session.execute(
                delete(LegislatorScore).where(
                    tuple_(
                        LegislatorScore.legislator_id,
                        LegislatorScore.year,
                        LegislatorScore.topic,
                    ).in_(key_batch)
                )
            )

    session.add_all(
        [
            LegislatorScore(
                legislator_id=row.legislator_id,
                year=row.year,
                topic=row.topic,
                score=row.score,
                score_run_id=score_run_id,
            )
            for row in rows
        ]
    )
    return len(rows)


def _supportive_vote_expression(
    normalized_vote: ColumnElement[str | None],
) -> ColumnElement[int]:
    """Return a 0/1 CASE that is 1 when the vote favors the bill topic.

    A vote favors the topic when the bill is "for" the topic and the vote
    supports the text, or the bill is "against" it and the vote opposes the
    text.
    """
    supports_text = _position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT)
    opposes_text = _position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT)
    return case(
        (
            and_(
                BillTopic.support_position == BillTopicPosition.FOR,
                supports_text,
            ),
            1,
        ),
        (
            and_(
                BillTopic.support_position == BillTopicPosition.AGAINST,
                opposes_text,
            ),
            1,
        ),
        else_=0,
    )


def _opposed_vote_expression(
    normalized_vote: ColumnElement[str | None],
) -> ColumnElement[int]:
    """Return a 0/1 CASE that is 1 when the vote goes against the bill topic.

    Mirror image of ``_supportive_vote_expression``: opposing a "for" bill or
    supporting an "against" bill counts as opposed.
    """
    supports_text = _position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT)
    opposes_text = _position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT)
    return case(
        (
            and_(
                BillTopic.support_position == BillTopicPosition.FOR,
                opposes_text,
            ),
            1,
        ),
        (
            and_(
                BillTopic.support_position == BillTopicPosition.AGAINST,
                supports_text,
            ),
            1,
        ),
        else_=0,
    )
VoteEffect, +) -> ColumnElement[bool]: + return or_( + and_( + normalized_vote.in_(sorted(SUPPORT_POSITIONS)), + VotePositionMeaning.yea_effect == effect, + ), + and_( + normalized_vote.in_(sorted(OPPOSE_POSITIONS)), + VotePositionMeaning.nay_effect == effect, + ), + and_( + normalized_vote == "present", + VotePositionMeaning.present_effect == effect, + ), + ) + + +def _is_scorable_position(normalized_vote: ColumnElement[str | None]) -> ColumnElement[bool]: + return or_( + _position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT), + _position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT), + ) + + +def _normalize_topics(topics: Sequence[str] | None) -> list[str]: + normalized: list[str] = [] + seen: set[str] = set() + for topic in topics or []: + value = normalize_topic_label(topic) + if value and value not in seen: + normalized.append(value) + seen.add(value) + return normalized + + +def _batched[T](items: Sequence[T], batch_size: int) -> list[Sequence[T]]: + return [ + items[index : index + batch_size] + for index in range(0, len(items), batch_size) + ] + + +def _vote_scope_filters( + *, + congress: int | None, + bill_ids: Sequence[int] | None, +) -> list[ColumnElement[bool]]: + filters: list[ColumnElement[bool]] = [] + if congress is not None: + filters.append(Vote.congress == congress) + if bill_ids: + filters.append(VoteMeasureLink.measure_id.in_(list(bill_ids))) + return filters + + +def _topic_scope_filters( + *, + bill_ids: Sequence[int] | None, + topics: Sequence[str], +) -> list[ColumnElement[bool]]: + filters: list[ColumnElement[bool]] = [] + if bill_ids: + filters.append(BillTopic.bill_id.in_(list(bill_ids))) + if topics: + filters.append(BillTopic.topic.in_(list(topics))) + return filters + + +def _has_score_scope( + *, + congress: int | None, + bill_ids: Sequence[int] | None, + topics: Sequence[str] | None, +) -> bool: + return congress is not None or bool(bill_ids) or bool(topics) + + +def _eligible_vote_filters() -> 
list[ColumnElement[bool]]: + return [ + VoteClassification.subject_type == SubjectType.MEASURE, + VoteClassification.vote_relationship == VoteRelationship.DIRECT_TEXT_VOTE, + VoteClassification.is_direct_vote_on_legislative_text.is_(True), + VoteClassification.is_substantive_policy_vote.is_(True), + VoteClassification.is_special_rule.is_(False), + ] + + +def main( + congress: Annotated[ + int | None, + typer.Option(help="Only score votes from one Congress."), + ] = None, + bill_ids: Annotated[ + list[int] | None, + typer.Option( + "--bill-id", + help="Only score votes linked to one internal bill.id. Repeatable.", + ), + ] = None, + topics: Annotated[ + list[str] | None, + typer.Option("--topic", help="Only score one normalized topic. Repeatable."), + ] = None, + replace_all: Annotated[ + bool, + typer.Option( + help="Delete every existing legislator score before inserting. " + "Unfiltered runs do this automatically." + ), + ] = False, + dry_run: Annotated[ + bool, + typer.Option(help="Calculate scores without writing to the database."), + ] = False, + log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO", + diagnose: Annotated[ + bool, + typer.Option(help="Log input-stage counts even when rows are calculated."), + ] = False, +) -> None: + """CLI entrypoint for calculating and storing legislator topic scores.""" + logging.basicConfig( + level=log_level, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + + engine = get_postgres_engine(name="DATA_SCIENCE_DEV") + with Session(engine) as session: + rows = collect_legislator_scores( + session, + congress=congress, + bill_ids=bill_ids, + topics=topics, + ) + logger.info("Calculated %d legislator topic score rows", len(rows)) + if diagnose or not rows: + diagnostics = collect_score_diagnostics( + session, + congress=congress, + bill_ids=bill_ids, + topics=topics, + ) + _log_diagnostics(diagnostics) + + if dry_run: + session.rollback() + return + + score_run = create_score_run(session) + 
def _log_diagnostics(diagnostics: ScoreDiagnostics) -> None:
    """Log the stage counts and warn about the first stage that came up empty.

    The counts are always logged at INFO. At most one WARNING follows, for
    the earliest stage with a zero count, since later stages depend on it.
    """
    logger.info(
        "Score input diagnostics: bill_topic_rows=%d linked_vote_rows=%d "
        "vote_record_rows=%d topic_vote_links=%d scorable_vote_records=%d",
        diagnostics.bill_topic_rows,
        diagnostics.linked_vote_rows,
        diagnostics.vote_record_rows,
        diagnostics.topic_vote_links,
        diagnostics.scorable_vote_records,
    )
    if diagnostics.bill_topic_rows == 0:
        # The extraction job lives under pipelines/jobs, so point operators
        # at the module path they can actually run.
        logger.warning(
            "No extracted bill topics matched the score scope. Run "
            "pipelines.jobs.extract_bill_topics after bill summarization."
        )
    elif diagnostics.linked_vote_rows == 0:
        logger.warning("No direct substantive text votes matched the score scope.")
    elif diagnostics.vote_record_rows == 0:
        logger.warning("No individual vote records matched the score scope.")
    elif diagnostics.topic_vote_links == 0:
        logger.warning(
            "Bill topics exist, but none are attached to bills that have eligible scored votes."
        )
    elif diagnostics.scorable_vote_records == 0:
        logger.warning(
            "Topic-vote links exist, but no joined vote records had Yea/Aye/Yes/Nay/No positions."
        )
class TopicExtractionError(RuntimeError):
    """Raised when a topic extraction request or response is invalid."""


@dataclass(frozen=True)
class TopicCatalog:
    """Loaded topic catalog with categories for prompting and flat candidates."""

    # Normalized topics grouped under their (stripped) category name.
    topics_by_category: dict[str, list[str]]
    # Every normalized topic once, in first-seen order across categories.
    candidate_topics: list[str]


@dataclass(frozen=True)
class TopicExtractionDiagnostics:
    """Counts for the bill summary inputs needed by topic extraction."""

    bill_rows: int
    bill_text_rows: int
    summarized_bill_text_rows: int
    bills_with_summaries: int
    bill_topic_rows: int
    selected_bills: int


@dataclass(frozen=True)
class ExtractedBillTopic:
    """One extracted bill topic and yes-vote stance."""

    topic: str
    support_position: BillTopicPosition
    confidence: float | None = None
    evidence: str | None = None


def _select_bill_text_for_topic_extraction(bill: Bill) -> BillText | None:
    """Pick one summarized bill_text row from the already-loaded relationship.

    Returns the first row whose summary is non-blank, or ``None`` when the
    bill has no usable summary.
    """
    for bill_text in bill.bill_texts:
        if bill_text.summary and bill_text.summary.strip():
            return bill_text
    return None


def normalize_topic_label(value: str) -> str:
    """Normalize a topic label for storage, comparison, and de-duping.

    Strips surrounding whitespace and quote characters, removes trailing
    periods, collapses internal whitespace runs to one space, and lowercases.
    """
    normalized = value.strip().strip("\"'")
    normalized = normalized.strip().rstrip(".").strip()
    return re.sub(r"\s+", " ", normalized).lower()


def load_topic_catalog(path: Path | None = None) -> TopicCatalog:
    """Load, validate, normalize, and flatten the bill topic catalog.

    Args:
        path: Catalog JSON file; falls back to ``DEFAULT_TOPICS_PATH``.

    Returns:
        A ``TopicCatalog`` whose topics are normalized and de-duplicated
        across categories (the first category to mention a topic keeps it).

    Raises:
        TopicExtractionError: If the file is missing, is not valid UTF-8
            JSON, or does not map category names to lists of non-blank
            strings.
    """
    topics_path = path or DEFAULT_TOPICS_PATH
    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        raw = json.loads(topics_path.read_text(encoding="utf-8"))
    except FileNotFoundError as exc:
        msg = f"Topic catalog not found: {topics_path}"
        raise TopicExtractionError(msg) from exc
    except UnicodeDecodeError as exc:
        msg = f"Topic catalog is not valid UTF-8: {topics_path}: {exc}"
        raise TopicExtractionError(msg) from exc
    except json.JSONDecodeError as exc:
        msg = f"Topic catalog is not valid JSON: {topics_path}: {exc}"
        raise TopicExtractionError(msg) from exc

    if not isinstance(raw, dict):
        msg = "Topic catalog root must be an object mapping category names to lists"
        raise TopicExtractionError(msg)

    topics_by_category: dict[str, list[str]] = {}
    candidate_topics: list[str] = []
    seen_topics: set[str] = set()

    for category, topics in raw.items():
        if not isinstance(category, str) or not category.strip():
            msg = "Topic catalog category names must be non-empty strings"
            raise TopicExtractionError(msg)
        if not isinstance(topics, list):
            msg = f"Topic catalog category {category!r} must contain a list"
            raise TopicExtractionError(msg)

        normalized_topics: list[str] = []
        for topic in topics:
            if not isinstance(topic, str):
                msg = f"Topic catalog category {category!r} contains a non-string topic"
                raise TopicExtractionError(msg)
            normalized_topic = normalize_topic_label(topic)
            if not normalized_topic:
                msg = f"Topic catalog category {category!r} contains a blank topic"
                raise TopicExtractionError(msg)
            # A topic repeated anywhere in the catalog is kept only once.
            if normalized_topic in seen_topics:
                continue
            seen_topics.add(normalized_topic)
            normalized_topics.append(normalized_topic)
            candidate_topics.append(normalized_topic)

        topics_by_category[category.strip()] = normalized_topics

    return TopicCatalog(
        topics_by_category=topics_by_category,
        candidate_topics=candidate_topics,
    )
def call_openai_topic_extraction(
    *,
    openai_config: OpenAIConfig,
    messages: list[dict[str, str]],
    model: str = "gpt-5.4-mini",
) -> str:
    """Call the chat-completions endpoint and return the assistant content.

    Args:
        openai_config: Supplies the endpoint URL, API key, project id, and
            request timeout.
        messages: Chat messages to send.
        model: Model name to request. Defaults to the previously hard-coded
            value so existing callers are unaffected.

    Returns:
        The assistant message text.

    Raises:
        httpx.HTTPStatusError: If the API responds with an error status.
        TopicExtractionError: If the response body has no message content.
    """
    response = httpx.post(
        openai_config.openai_chat_completions_url,
        headers={
            "Authorization": f"Bearer {openai_config.api_key}",
            "OpenAI-Project": openai_config.openai_project_id,
            "Content-Type": "application/json",
        },
        json={
            "model": model,
            "messages": messages,
        },
        timeout=openai_config.timeout_seconds,
    )
    response.raise_for_status()
    return extract_message_content(response.json())


def extract_message_content(data: dict[str, Any]) -> str:
    """Extract message content from a chat-completions response body.

    Prefers ``choices[0].message.content`` and falls back to
    ``choices[0].text``.

    Raises:
        TopicExtractionError: If the body is not an object or carries no
            usable content.
    """
    # A JSON array or scalar body would otherwise fail with AttributeError.
    if not isinstance(data, dict):
        msg = "Chat completion response must be a JSON object"
        raise TopicExtractionError(msg)

    choices = data.get("choices")
    if not isinstance(choices, list) or not choices:
        msg = "Chat completion response did not contain choices"
        raise TopicExtractionError(msg)

    first = choices[0]
    if not isinstance(first, dict):
        msg = "Chat completion choice must be an object"
        raise TopicExtractionError(msg)

    message = first.get("message")
    if isinstance(message, dict) and isinstance(message.get("content"), str):
        return message["content"]
    if isinstance(first.get("text"), str):
        return first["text"]

    msg = "Chat completion response did not contain message content"
    raise TopicExtractionError(msg)


def parse_topic_extraction_response(response_text: str) -> list[ExtractedBillTopic]:
    """Parse, normalize, validate, and de-dupe a topic extraction response.

    Args:
        response_text: Raw assistant text expected to hold a JSON object with
            a ``topics`` list.

    Returns:
        One ``ExtractedBillTopic`` per distinct (topic, support_position)
        pair, keeping the entry with the highest confidence.

    Raises:
        TopicExtractionError: If the payload or any topic entry is malformed.
    """
    payload = _load_json_response(response_text)
    topics = payload.get("topics")
    if not isinstance(topics, list):
        msg = "Topic extraction response must contain a topics list"
        raise TopicExtractionError(msg)

    deduped: dict[tuple[str, BillTopicPosition], ExtractedBillTopic] = {}
    for item in topics:
        if not isinstance(item, dict):
            msg = "Topic extraction response topics must be objects"
            raise TopicExtractionError(msg)

        raw_topic = _extract_topic_label(item)
        topic = normalize_topic_label(raw_topic)
        if not topic:
            msg = "Topic extraction response topic must not be blank"
            raise TopicExtractionError(msg)

        raw_position = item.get("support_position")
        try:
            support_position = BillTopicPosition(raw_position)
        except ValueError as exc:
            msg = f"Invalid support_position: {raw_position!r}"
            raise TopicExtractionError(msg) from exc

        confidence = _parse_confidence(item.get("confidence"))
        evidence = item.get("evidence")
        if evidence is not None and not isinstance(evidence, str):
            evidence = str(evidence)

        extracted = ExtractedBillTopic(
            topic=topic,
            support_position=support_position,
            confidence=confidence,
            evidence=evidence,
        )
        # NOTE(review): the key includes the stance, so one topic can be kept
        # as both "for" and "against" -- confirm that is intended.
        key = (topic, support_position)
        existing = deduped.get(key)
        if existing is None or _confidence_rank(extracted) > _confidence_rank(existing):
            deduped[key] = extracted

    return list(deduped.values())
def store_bill_topic_result(
    *,
    session: Session,
    bill: Bill,
    topics: Sequence[ExtractedBillTopic],
    replace_existing: bool = True,
) -> None:
    """Store extracted topics for one bill.

    Args:
        session: Open SQLAlchemy session; the caller commits.
        bill: Bill the topics belong to.
        topics: Extracted topics to insert. Only the normalized topic and the
            support position are written.
        replace_existing: When true, delete the bill's current topic rows
            before inserting.
    """
    if replace_existing:
        session.execute(delete(BillTopic).where(BillTopic.bill_id == bill.id))

    for topic in topics:
        session.add(
            BillTopic(
                bill_id=bill.id,
                topic=normalize_topic_label(topic.topic),
                support_position=topic.support_position,
            )
        )


def create_select_bills_for_topic_extraction(
    congress: int | None = None,
    bill_ids: list[int] | None = None,
    bill_text_ids: list[int] | None = None,
    with_votes_only: bool = False,
    force: bool = False,
    limit: int | None = None,
) -> Select[tuple[Bill]]:
    """Select bill rows that have summarized bill_text rows for topic extraction.

    Args:
        congress: Only bills from this Congress.
        bill_ids: Only these internal bill ids.
        bill_text_ids: Only bills owning one of these bill_text rows; the
            eagerly loaded ``bill_texts`` are limited to the same rows.
        with_votes_only: Only bill_text rows targeted by an eligible direct
            substantive vote.
        force: Include bills that already have topic rows.
        limit: Maximum number of bills to return.

    Returns:
        A ``select(Bill)`` ordered by ``Bill.id`` that eager-loads only the
        bill_text rows satisfying the same criteria used for selection.
    """
    # One criteria list drives both the EXISTS test and the eager loader so
    # the text picked later is always one that qualified the bill.
    text_criteria: list[ColumnElement[bool]] = [
        BillText.summary.is_not(None),
        BillText.summary != "",
    ]
    if bill_text_ids:
        # Previously this filtered bills only, so a different summarized
        # version of the bill could be loaded and used for extraction.
        text_criteria.append(BillText.id.in_(bill_text_ids))
    if with_votes_only:
        text_criteria.append(
            exists(
                select(VoteTextTarget.vote_id)
                .join(
                    VoteClassification,
                    VoteClassification.vote_id == VoteTextTarget.vote_id,
                )
                .where(
                    VoteTextTarget.voted_text_version_id == BillText.id,
                    VoteClassification.subject_type == SubjectType.MEASURE,
                    VoteClassification.vote_relationship
                    == VoteRelationship.DIRECT_TEXT_VOTE,
                    VoteClassification.is_direct_vote_on_legislative_text.is_(True),
                    VoteClassification.is_substantive_policy_vote.is_(True),
                    VoteClassification.is_special_rule.is_(False),
                )
            )
        )

    qualifying_text_exists = exists(
        select(BillText.id).where(BillText.bill_id == Bill.id, *text_criteria)
    )
    stmt = (
        select(Bill)
        .where(qualifying_text_exists)
        .options(selectinload(Bill.bill_texts.and_(*text_criteria)))
        .order_by(Bill.id)
    )
    if congress is not None:
        stmt = stmt.where(Bill.congress == congress)
    if bill_ids:
        stmt = stmt.where(Bill.id.in_(bill_ids))
    if not force:
        # Skip bills that already have at least one stored topic.
        stmt = stmt.where(
            ~exists(select(BillTopic.id).where(BillTopic.bill_id == Bill.id))
        )
    if limit is not None:
        stmt = stmt.limit(limit)
    return stmt
VoteRelationship.DIRECT_TEXT_VOTE, + VoteClassification.is_direct_vote_on_legislative_text.is_(True), + VoteClassification.is_substantive_policy_vote.is_(True), + VoteClassification.is_special_rule.is_(False), + ) + ) + ) + + has_summary = (BillText.summary.is_not(None), BillText.summary != "") + summary_filters = [*bill_text_filters, *has_summary] + + bills_with_summaries = session.scalar( + select(func.count(func.distinct(Bill.id))) + .select_from(Bill) + .join(BillText, BillText.bill_id == Bill.id) + .where(*bill_filters, *summary_filters) + ) + selected_bills = session.scalar( + select(func.count()).select_from( + create_select_bills_for_topic_extraction( + congress=congress, + bill_ids=bill_ids, + bill_text_ids=bill_text_ids, + with_votes_only=with_votes_only, + force=force, + limit=limit, + ).subquery() + ) + ) + + return TopicExtractionDiagnostics( + bill_rows=session.scalar(select(func.count(Bill.id)).where(*bill_filters)) or 0, + bill_text_rows=_count_bill_texts( + session, + bill_filters=bill_filters, + bill_text_filters=bill_text_filters, + ), + summarized_bill_text_rows=_count_bill_texts( + session, + bill_filters=bill_filters, + bill_text_filters=summary_filters, + ), + bills_with_summaries=bills_with_summaries or 0, + bill_topic_rows=session.scalar(select(func.count(BillTopic.id))) or 0, + selected_bills=selected_bills or 0, + ) + + +def _load_json_response(response_text: str) -> dict[str, Any]: + text = response_text.strip() + fenced = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.DOTALL) + if fenced: + text = fenced.group(1).strip() + + try: + payload = json.loads(text) + except json.JSONDecodeError as exc: + msg = f"Topic extraction response is not valid JSON: {exc}" + raise TopicExtractionError(msg) from exc + if not isinstance(payload, dict): + msg = "Topic extraction response must be a JSON object" + raise TopicExtractionError(msg) + return payload + + +def _parse_confidence(raw: Any) -> float | None: + if raw is None: + return None 
+ try: + return float(raw) + except (TypeError, ValueError) as exc: + msg = f"Invalid confidence: {raw!r}" + raise TopicExtractionError(msg) from exc + + +def _confidence_rank(topic: ExtractedBillTopic) -> tuple[int, float]: + if topic.confidence is None: + return (0, 0.0) + return (1, topic.confidence) + + +def _extract_topic_label(item: dict[str, Any]) -> str: + raw_topic = item.get("topic") + if isinstance(raw_topic, str): + return raw_topic + if isinstance(raw_topic, dict): + for key in ("topic", "label", "name", "title"): + value = raw_topic.get(key) + if isinstance(value, str): + return value + + msg = "Topic extraction response topic must be a string" + raise TopicExtractionError(msg) + + +def _count_bill_texts( + session: Session, + *, + bill_filters: Sequence[ColumnElement[bool]], + bill_text_filters: Sequence[ColumnElement[bool]], +) -> int: + stmt = select(func.count(BillText.id)) + if bill_filters: + stmt = stmt.join(Bill, Bill.id == BillText.bill_id).where(*bill_filters) + return session.scalar(stmt.where(*bill_text_filters)) or 0 + + +def main( + topics_path: Annotated[ + Path, typer.Option(help="Path to congressional issue topic JSON.") + ] = DEFAULT_TOPICS_PATH, + congress: Annotated[ + int | None, typer.Option(help="Only process one Congress.") + ] = None, + bill_ids: Annotated[ + list[int] | None, + typer.Option( + "--bill-id", + help="Only process one internal bill.id. Repeat for multiple bills.", + ), + ] = None, + bill_text_ids: Annotated[ + list[int] | None, + typer.Option( + "--bill-text-id", + help="Only process one internal bill_text.id. 
def main(
    topics_path: Annotated[
        Path, typer.Option(help="Path to congressional issue topic JSON.")
    ] = DEFAULT_TOPICS_PATH,
    congress: Annotated[
        int | None, typer.Option(help="Only process one Congress.")
    ] = None,
    bill_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-id",
            help="Only process one internal bill.id. Repeat for multiple bills.",
        ),
    ] = None,
    bill_text_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-text-id",
            help="Only process one internal bill_text.id. Repeat for multiple rows.",
        ),
    ] = None,
    with_votes_only: Annotated[
        bool,
        typer.Option(
            # The default is True, so a negative flag is required; with only
            # "--with-votes-only" declared the filter could never be disabled.
            "--with-votes-only/--no-with-votes-only",
            help="Only process summarized bill_text rows linked to at least one vote.",
        ),
    ] = True,
    limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
    force: Annotated[
        bool,
        typer.Option(help="Regenerate topics for bills that already have topics."),
    ] = False,
    dry_run: Annotated[
        bool,
        typer.Option(help="Select bills and print diagnostics without calling OpenAI."),
    ] = False,
    diagnose: Annotated[
        bool,
        typer.Option(help="Log input-stage counts before processing."),
    ] = False,
    log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
) -> None:
    """CLI entrypoint for generating and storing bill topics.

    Loads the topic catalog, optionally logs input-stage diagnostics, selects
    the bills in scope, extracts topics for each bill's summary and stores the
    result. Bills whose extraction fails with an HTTP or topic-extraction
    error are logged and counted, not fatal. With ``dry_run`` the function
    returns after logging diagnostics, before any OpenAI configuration is read.
    """
    commit_every = 100  # bills stored between intermediate commits

    logging.basicConfig(
        # logging only recognises upper-case level names; accept "info" etc.
        level=log_level.upper(),
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    topic_catalog = load_topic_catalog(topics_path)
    logger.info(
        "Loaded %d candidate topics from %s",
        len(topic_catalog.candidate_topics),
        topics_path,
    )

    # The diagnostics and the real selection must use the same scope.
    selection = {
        "congress": congress,
        "bill_ids": bill_ids,
        "bill_text_ids": bill_text_ids,
        "with_votes_only": with_votes_only,
        "force": force,
        "limit": limit,
    }

    engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
    with Session(engine) as session:
        if diagnose or dry_run:
            diagnostics = collect_topic_extraction_diagnostics(session, **selection)
            _log_topic_extraction_diagnostics(diagnostics)
        if dry_run:
            return

        openai_config = get_openai_config()

        stmt = create_select_bills_for_topic_extraction(**selection)
        bills = session.scalars(stmt).all()
        logger.info("Selected %d bills for topic extraction", len(bills))

        written = 0
        failed = 0
        for bill in bills:
            bill_text = _select_bill_text_for_topic_extraction(bill)
            if bill_text is None:
                logger.warning("Skipping bill id=%s: no usable summary", bill.id)
                continue
            summary = bill_text.summary.strip()

            try:
                extracted_topics = extract_topics_for_bill_text(
                    openai_config=openai_config,
                    bill=bill,
                    text=summary,
                    candidate_topics=topic_catalog.candidate_topics,
                )
            except (httpx.HTTPError, TopicExtractionError):
                failed += 1
                logger.exception(
                    "Skipping bill id=%s after topic extraction failure", bill.id
                )
                continue

            store_bill_topic_result(
                session=session,
                bill=bill,
                topics=extracted_topics,
                replace_existing=True,
            )
            written += 1
            logger.info(
                "Stored %d topics for bill id=%s",
                len(extracted_topics),
                bill.id,
            )
            # Keyed on stored results rather than the loop position, so a
            # skipped or failed bill cannot cause a commit point to be missed.
            if written % commit_every == 0:
                session.commit()

        session.commit()
        logger.info(
            "Done: stored topic results for %d bills; failed %d bills",
            written,
            failed,
        )
def _log_topic_extraction_diagnostics(
    diagnostics: TopicExtractionDiagnostics,
) -> None:
    """Log the input-stage counts and a hint about the first empty stage.

    One INFO line reports every counter. At most one WARNING follows, naming
    the earliest stage (bills, bill_text rows, summaries, selection) that
    produced no rows.
    """
    logger.info(
        "Topic extraction diagnostics: bill_rows=%d bill_text_rows=%d "
        "summarized_bill_text_rows=%d bills_with_summaries=%d "
        "bill_topic_rows=%d selected_bills=%d",
        diagnostics.bill_rows,
        diagnostics.bill_text_rows,
        diagnostics.summarized_bill_text_rows,
        diagnostics.bills_with_summaries,
        diagnostics.bill_topic_rows,
        diagnostics.selected_bills,
    )
    if diagnostics.bill_rows == 0:
        logger.warning("No bills matched the topic extraction scope.")
    elif diagnostics.bill_text_rows == 0:
        logger.warning("No bill_text rows matched the topic extraction scope.")
    elif diagnostics.summarized_bill_text_rows == 0:
        # The summarization job lives at pipelines/jobs/summarize_bills.py.
        logger.warning(
            "No summarized bill_text rows matched the topic extraction scope. "
            "Run pipelines.jobs.summarize_bills first."
        )
    elif diagnostics.selected_bills == 0 and diagnostics.bill_topic_rows > 0:
        # NOTE(review): bill_topic_rows looks like a table-wide count rather
        # than one scoped to the selection, so this hint may be shown even
        # when the bills in scope have no topics yet -- confirm.
        logger.warning(
            "No bills selected because matching bills already have topics. "
            "Use --force to regenerate them."
        )
    elif diagnostics.selected_bills == 0:
        logger.warning("No bills selected for topic extraction.")
+ ) + elif diagnostics.selected_bills == 0: + logger.warning("No bills selected for topic extraction.") + + +if __name__ == "__main__": + typer.run(main) diff --git a/pipelines/jobs/summarize_bills.py b/pipelines/jobs/summarize_bills.py new file mode 100644 index 0000000..871d4a2 --- /dev/null +++ b/pipelines/jobs/summarize_bills.py @@ -0,0 +1,309 @@ +"""Summarize bill_text rows with GPT-5 and store results in the database.""" + +from __future__ import annotations + +import logging +import tomllib +from os import getenv +from typing import Annotated, Any + +import httpx +import typer +from sqlalchemy import Select, exists, or_, select +from sqlalchemy.orm import Session, selectinload + +from tiktoken import get_encoding + + +from pipelines.config import get_config_dir +from pipelines.orm.common import get_postgres_engine +from pipelines.orm.data_science_dev.congress import ( + Bill, + BillText, + SubjectType, + VoteClassification, + VoteRelationship, + VoteTextTarget, +) +from pipelines.tools.bill_token_compression import compress_bill_text + +logger = logging.getLogger(__name__) + +OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions" +OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE" +REQUEST_TIMEOUT_SECONDS = 60 + + +def load_summarization_prompts( + section: str = "summarization", +) -> dict[str, str]: + summarization_prompts = get_config_dir() / "prompts" / "summarization_prompts.toml" + + return tomllib.loads(summarization_prompts.read_text())[section] + + +class BillSummaryError(RuntimeError): + """Raised when a bill summary request or response is invalid.""" + + +def call_openai_summary( + *, + model: str, + messages: list[dict[str, str]], +) -> str: + """Call GPT and return the assistant message content.""" + api_key = getenv("CLOSEDAI_TOKEN") + if not api_key: + msg = "CLOSEDAI_TOKEN is required" + raise BillSummaryError(msg) + + response = httpx.post( + OPENAI_CHAT_COMPLETIONS_URL, + headers={ + "Authorization": f"Bearer 
def build_bill_summary_messages(
    *,
    bill_text: BillText,
    summarization_prompts: dict[str, str],
) -> tuple[list[dict[str, str]], int]:
    """Build the chat messages for one bill_text row and count the prompt tokens.

    The row's ``text_content`` is compressed with ``compress_bill_text`` and
    substituted into the ``user_template`` prompt. The system message comes
    from ``system_prompt``.

    Args:
        bill_text: Row whose ``text_content`` is summarized.
        summarization_prompts: Mapping with ``system_prompt`` and
            ``user_template`` entries.

    Returns:
        A ``(messages, user_prompt_tokens)`` pair, where the token count is the
        length of the user prompt under the ``o200k_base`` encoding.

    Raises:
        BillSummaryError: If the row has no text, or nothing is left after
            compression.
    """
    if not bill_text.text_content:
        msg = f"bill_text id={bill_text.id} has no text_content"
        raise BillSummaryError(msg)

    compressed_text = compress_bill_text(bill_text.text_content)
    if not compressed_text:
        msg = f"bill_text id={bill_text.id} has no summarizable text_content"
        raise BillSummaryError(msg)

    user_prompt = summarization_prompts["user_template"].format(
        text_content=compressed_text
    )

    user_prompt_tokens = len(get_encoding("o200k_base").encode(user_prompt))
    logger.info("user_prompt_tokens=%d", user_prompt_tokens)

    messages = [
        {"role": "system", "content": summarization_prompts["system_prompt"]},
        {"role": "user", "content": user_prompt},
    ]
    return messages, user_prompt_tokens


def summarize_bill_text(
    *,
    model: str,
    bill_text: BillText,
    summarization_prompts: dict[str, str],
    max_prompt_tokens: int = 272_000,
) -> str | None:
    """Generate a summary for one bill_text row.

    Args:
        model: Model id passed to ``call_openai_summary``.
        bill_text: Row to summarize.
        summarization_prompts: Prompt templates, see
            ``build_bill_summary_messages``.
        max_prompt_tokens: Largest user prompt, in tokens, that is sent to the
            model.

    Returns:
        The stripped summary, or ``None`` when the prompt exceeds
        ``max_prompt_tokens`` and the row is skipped.

    Raises:
        BillSummaryError: If the row has no usable text or the model returns
            an empty summary.
    """
    messages, user_prompt_tokens = build_bill_summary_messages(
        bill_text=bill_text,
        summarization_prompts=summarization_prompts,
    )
    # TODO: the 272,000-token default may be specific to gpt-5.4-mini; confirm
    # against the model documentation.
    if user_prompt_tokens > max_prompt_tokens:
        logger.warning(
            "Compressed bill_text id=%s is too long for summarization (%s tokens)",
            bill_text.id,
            user_prompt_tokens,
        )
        return None

    summary = call_openai_summary(
        model=model,
        messages=messages,
    ).strip()
    if not summary:
        msg = f"Model returned an empty summary for bill_text id={bill_text.id}"
        raise BillSummaryError(msg)
    return summary
def store_bill_summary_result(
    *,
    bill_text: BillText,
    summary: str,
    model: str,
) -> None:
    """Record a generated summary on the row together with its provenance.

    Sets the summary text, the model id, and the fixed system/user prompt
    version tags on ``bill_text``. Nothing is flushed or committed here.
    """
    updates = {
        "summary": summary,
        "summarization_model": model,
        "summarization_system_prompt_version": "v1.2",
        "summarization_user_prompt_version": "v1",
    }
    for attribute, value in updates.items():
        setattr(bill_text, attribute, value)


def create_select_bill_texts_for_summarization(
    congress: int | None = None,
    bill_ids: list[int] | None = None,
    bill_text_ids: list[int] | None = None,
    with_votes_only: bool = False,
    force: bool = False,
    limit: int | None = None,
) -> Select:
    """Build the query for bill_text rows that have source text and need a summary.

    Rows are ordered by ``BillText.id`` and always require non-empty
    ``text_content``. Each optional argument narrows the selection:

    * ``congress`` / ``bill_ids`` / ``bill_text_ids`` restrict the scope.
    * ``with_votes_only`` keeps rows targeted by at least one direct,
      substantive, non-special-rule vote on a measure.
    * unless ``force`` is set, rows that already have a summary are excluded.
    * ``limit`` caps the number of rows.
    """
    criteria = [BillText.text_content.is_not(None), BillText.text_content != ""]

    if congress is not None:
        criteria.append(Bill.congress == congress)
    if bill_ids:
        criteria.append(BillText.bill_id.in_(bill_ids))
    if bill_text_ids:
        criteria.append(BillText.id.in_(bill_text_ids))
    if with_votes_only:
        qualifying_vote = (
            select(VoteTextTarget.vote_id)
            .join(
                VoteClassification,
                VoteClassification.vote_id == VoteTextTarget.vote_id,
            )
            .where(
                VoteTextTarget.voted_text_version_id == BillText.id,
                VoteClassification.subject_type == SubjectType.MEASURE,
                VoteClassification.vote_relationship
                == VoteRelationship.DIRECT_TEXT_VOTE,
                VoteClassification.is_direct_vote_on_legislative_text.is_(True),
                VoteClassification.is_substantive_policy_vote.is_(True),
                VoteClassification.is_special_rule.is_(False),
            )
        )
        criteria.append(exists(qualifying_vote))
    if not force:
        criteria.append(or_(BillText.summary.is_(None), BillText.summary == ""))

    query = (
        select(BillText)
        .join(Bill, Bill.id == BillText.bill_id)
        .where(*criteria)
        .options(selectinload(BillText.bill))
        .order_by(BillText.id)
    )
    return query if limit is None else query.limit(limit)
def extract_message_content(data: dict[str, Any]) -> str:
    """Pull the generated text out of a chat-completions response body.

    Only the first entry of ``choices`` is inspected. Its
    ``message["content"]`` is preferred; a string ``text`` field on the choice
    is used as a fallback.

    Raises:
        BillSummaryError: If ``choices`` is missing or empty, the first choice
            is not an object, or neither field holds a string.
    """
    choices = data.get("choices")
    if not (isinstance(choices, list) and choices):
        msg = "Chat completion response did not contain choices"
        raise BillSummaryError(msg)

    head = choices[0]
    if not isinstance(head, dict):
        msg = "Chat completion choice must be an object"
        raise BillSummaryError(msg)

    message = head.get("message")
    content = message.get("content") if isinstance(message, dict) else None
    if isinstance(content, str):
        return content

    fallback_text = head.get("text")
    if isinstance(fallback_text, str):
        return fallback_text

    msg = "Chat completion response did not contain message content"
    raise BillSummaryError(msg)
def main(
    model: Annotated[str, typer.Option(help="OpenAI model id.")] = "gpt-5.4-mini",
    congress: Annotated[
        int | None, typer.Option(help="Only process one Congress.")
    ] = None,
    bill_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-id",
            help="Only process one internal bill.id. Repeat for multiple bills.",
        ),
    ] = None,
    bill_text_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-text-id",
            help="Only process one internal bill_text.id. Repeat for multiple rows.",
        ),
    ] = None,
    with_votes_only: Annotated[
        bool,
        typer.Option(
            "--with-votes-only",
            help="Only process bill_text rows linked to at least one vote.",
        ),
    ] = False,
    limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
    force: Annotated[
        bool,
        typer.Option(help="Regenerate summaries for rows that already have a summary."),
    ] = False,
    dry_run: Annotated[
        bool, typer.Option(help="Print summaries without writing them to the database.")
    ] = False,
    log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
) -> None:
    """CLI entrypoint for generating and storing bill summaries.

    Selects the bill_text rows in scope, summarizes each one and stores the
    result, committing periodically. A row whose request fails with an HTTP
    or summary error is logged and counted instead of aborting the run. With
    ``dry_run`` the summaries are printed and nothing is written.

    Raises:
        typer.BadParameter: If ``CLOSEDAI_TOKEN`` is not set.
    """
    commit_every = 100  # summaries stored between intermediate commits

    logging.basicConfig(
        # logging only recognises upper-case level names; accept "info" etc.
        level=log_level.upper(),
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    if not getenv("CLOSEDAI_TOKEN"):
        message = "CLOSEDAI_TOKEN is required"
        raise typer.BadParameter(message)

    summarization_prompts = load_summarization_prompts()
    engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
    with Session(engine) as session:
        stmt = create_select_bill_texts_for_summarization(
            congress=congress,
            bill_ids=bill_ids,
            bill_text_ids=bill_text_ids,
            with_votes_only=with_votes_only,
            force=force,
            limit=limit,
        )
        bill_texts = session.scalars(stmt).all()
        logger.info("Selected %d bill_text rows for summarization", len(bill_texts))

        written = 0
        failed = 0
        for bill_text in bill_texts:
            try:
                summary = summarize_bill_text(
                    model=model,
                    bill_text=bill_text,
                    summarization_prompts=summarization_prompts,
                )
            except (httpx.HTTPError, BillSummaryError):
                # One bad row or transient API failure must not discard the
                # rest of the batch.
                failed += 1
                logger.exception(
                    "Skipping bill_text id=%s after summarization failure",
                    bill_text.id,
                )
                continue

            if summary is None:
                logger.warning("Skipping bill_text id=%s", bill_text.id)
                continue

            if dry_run:
                # Honour --dry-run: show the result, leave the row untouched.
                typer.echo(f"bill_text id={bill_text.id}:\n{summary}\n")
                continue

            store_bill_summary_result(
                bill_text=bill_text,
                summary=summary,
                model=model,
            )
            written += 1
            logger.info("Stored summary for bill_text id=%s", bill_text.id)
            # Keyed on stored rows so skipped rows cannot cause a commit
            # point to be missed.
            if written % commit_every == 0:
                session.commit()

        session.commit()

    if failed:
        logger.warning("Failed to summarize %d bill_text rows", failed)
    logger.info("Done: stored %d summaries", written)


def cli() -> None:
    """Typer entry point."""
    typer.run(main)


if __name__ == "__main__":
    cli()