adding calculate_legislator_scores.py summarize_bills.py and extract_bill_topics.py
This commit is contained in:
@@ -0,0 +1,682 @@
|
||||
"""Extract bill topics from bill text using a configurable topic catalog."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Sequence
|
||||
|
||||
import httpx
|
||||
import typer
|
||||
from sqlalchemy import ColumnElement, Select, delete, exists, func, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from pipelines.config import OpenAIConfig, get_config_dir, get_openai_config
|
||||
from pipelines.orm.common import get_postgres_engine
|
||||
from pipelines.orm.data_science_dev.congress import (
|
||||
Bill,
|
||||
BillText,
|
||||
BillTopic,
|
||||
BillTopicPosition,
|
||||
SubjectType,
|
||||
VoteClassification,
|
||||
VoteRelationship,
|
||||
VoteTextTarget,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
|
||||
OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
|
||||
REQUEST_TIMEOUT_SECONDS = 60
|
||||
DEFAULT_TOPICS_PATH = get_config_dir() / "congressional_issues_comprehensive.json"
|
||||
|
||||
|
||||
class TopicExtractionError(RuntimeError):
|
||||
"""Raised when a topic extraction request or response is invalid."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TopicCatalog:
|
||||
"""Loaded topic catalog with categories for prompting and flat candidates."""
|
||||
|
||||
topics_by_category: dict[str, list[str]]
|
||||
candidate_topics: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TopicExtractionDiagnostics:
|
||||
"""Counts for the bill summary inputs needed by topic extraction."""
|
||||
|
||||
bill_rows: int
|
||||
bill_text_rows: int
|
||||
summarized_bill_text_rows: int
|
||||
bills_with_summaries: int
|
||||
bill_topic_rows: int
|
||||
selected_bills: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExtractedBillTopic:
|
||||
"""One extracted bill topic and yes-vote stance."""
|
||||
|
||||
topic: str
|
||||
support_position: BillTopicPosition
|
||||
confidence: float | None = None
|
||||
evidence: str | None = None
|
||||
|
||||
|
||||
def _select_bill_text_for_topic_extraction(bill: Bill) -> BillText | None:
|
||||
"""Pick one summarized bill_text row from the already-loaded relationship."""
|
||||
for bill_text in bill.bill_texts:
|
||||
if bill_text.summary and bill_text.summary.strip():
|
||||
return bill_text
|
||||
return None
|
||||
|
||||
|
||||
def normalize_topic_label(value: str) -> str:
|
||||
"""Normalize a topic label for storage, comparison, and de-duping."""
|
||||
normalized = value.strip().strip("\"'")
|
||||
normalized = normalized.strip().rstrip(".").strip()
|
||||
return re.sub(r"\s+", " ", normalized).lower()
|
||||
|
||||
|
||||
def load_topic_catalog(path: Path | None = None) -> TopicCatalog:
|
||||
"""Load, validate, normalize, and flatten the bill topic catalog."""
|
||||
topics_path = path or DEFAULT_TOPICS_PATH
|
||||
try:
|
||||
raw = json.loads(topics_path.read_text())
|
||||
except FileNotFoundError as exc:
|
||||
msg = f"Topic catalog not found: {topics_path}"
|
||||
raise TopicExtractionError(msg) from exc
|
||||
except json.JSONDecodeError as exc:
|
||||
msg = f"Topic catalog is not valid JSON: {topics_path}: {exc}"
|
||||
raise TopicExtractionError(msg) from exc
|
||||
|
||||
if not isinstance(raw, dict):
|
||||
msg = "Topic catalog root must be an object mapping category names to lists"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
topics_by_category: dict[str, list[str]] = {}
|
||||
candidate_topics: list[str] = []
|
||||
seen_topics: set[str] = set()
|
||||
|
||||
for category, topics in raw.items():
|
||||
if not isinstance(category, str) or not category.strip():
|
||||
msg = "Topic catalog category names must be non-empty strings"
|
||||
raise TopicExtractionError(msg)
|
||||
if not isinstance(topics, list):
|
||||
msg = f"Topic catalog category {category!r} must contain a list"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
normalized_topics: list[str] = []
|
||||
for topic in topics:
|
||||
if not isinstance(topic, str):
|
||||
msg = f"Topic catalog category {category!r} contains a non-string topic"
|
||||
raise TopicExtractionError(msg)
|
||||
normalized_topic = normalize_topic_label(topic)
|
||||
if not normalized_topic:
|
||||
msg = f"Topic catalog category {category!r} contains a blank topic"
|
||||
raise TopicExtractionError(msg)
|
||||
if normalized_topic in seen_topics:
|
||||
continue
|
||||
seen_topics.add(normalized_topic)
|
||||
normalized_topics.append(normalized_topic)
|
||||
candidate_topics.append(normalized_topic)
|
||||
|
||||
topics_by_category[category.strip()] = normalized_topics
|
||||
|
||||
return TopicCatalog(
|
||||
topics_by_category=topics_by_category,
|
||||
candidate_topics=candidate_topics,
|
||||
)
|
||||
|
||||
|
||||
def build_topic_extraction_messages(
|
||||
*,
|
||||
bill: Bill,
|
||||
bill_text: str,
|
||||
candidate_topics: Sequence[str],
|
||||
) -> list[dict[str, str]]:
|
||||
"""Build GPT messages for extracting a bill's scored topics."""
|
||||
normalized_candidates = [normalize_topic_label(topic) for topic in candidate_topics]
|
||||
candidate_list = "\n".join(f"- {topic}" for topic in normalized_candidates)
|
||||
metadata = "\n".join(
|
||||
(
|
||||
f"Congress: {bill.congress}",
|
||||
f"Bill: {bill.bill_type} {bill.number}",
|
||||
f"Title: {bill.title_short or bill.title or bill.official_title or ''}",
|
||||
f"Top subject term: {bill.subjects_top_term or ''}",
|
||||
)
|
||||
)
|
||||
|
||||
system_prompt = (
|
||||
"You extract policy topics from U.S. congressional bills.\n"
|
||||
'For each selected topic, decide whether a Yes/Yea vote on the bill is "for" or "against" that topic.\n'
|
||||
'Use "support_position": "for" when a Yes/Yea vote advances or supports the topic.\n'
|
||||
'Use "support_position": "against" when a Yes/Yea vote restricts, repeals, blocks, or opposes the topic.\n'
|
||||
"Select only topics from the provided candidate topic list.\n"
|
||||
"Omit topics that are not materially addressed by the bill.\n"
|
||||
"Return strict JSON only, with this shape:\n"
|
||||
'{"topics":[{"topic":"candidate topic","support_position":"for","confidence":0.0,"evidence":"short reason"}]}'
|
||||
)
|
||||
user_prompt = "\n\n".join(
|
||||
(
|
||||
"BILL METADATA:",
|
||||
metadata,
|
||||
"CANDIDATE TOPICS:",
|
||||
candidate_list,
|
||||
"BILL TEXT:",
|
||||
bill_text,
|
||||
)
|
||||
)
|
||||
return [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
|
||||
def call_openai_topic_extraction(
|
||||
*,
|
||||
openai_config: OpenAIConfig,
|
||||
messages: list[dict[str, str]],
|
||||
) -> str:
|
||||
"""Call GPT and return the assistant message content."""
|
||||
|
||||
response = httpx.post(
|
||||
openai_config.openai_chat_completions_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {openai_config.api_key}",
|
||||
"OpenAI-Project": openai_config.openai_project_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": "gpt-5.4-mini",
|
||||
"messages": messages,
|
||||
},
|
||||
timeout=openai_config.timeout_seconds,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return extract_message_content(response.json())
|
||||
|
||||
|
||||
def extract_message_content(data: dict[str, Any]) -> str:
|
||||
"""Extract message content from a chat-completions response body."""
|
||||
choices = data.get("choices")
|
||||
if not isinstance(choices, list) or not choices:
|
||||
msg = "Chat completion response did not contain choices"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
first = choices[0]
|
||||
if not isinstance(first, dict):
|
||||
msg = "Chat completion choice must be an object"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
message = first.get("message")
|
||||
if isinstance(message, dict) and isinstance(message.get("content"), str):
|
||||
return message["content"]
|
||||
if isinstance(first.get("text"), str):
|
||||
return first["text"]
|
||||
|
||||
msg = "Chat completion response did not contain message content"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
|
||||
def parse_topic_extraction_response(response_text: str) -> list[ExtractedBillTopic]:
|
||||
"""Parse, normalize, validate, and de-dupe a topic extraction response."""
|
||||
payload = _load_json_response(response_text)
|
||||
topics = payload.get("topics")
|
||||
if not isinstance(topics, list):
|
||||
msg = "Topic extraction response must contain a topics list"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
deduped: dict[tuple[str, BillTopicPosition], ExtractedBillTopic] = {}
|
||||
for item in topics:
|
||||
if not isinstance(item, dict):
|
||||
msg = "Topic extraction response topics must be objects"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
raw_topic = _extract_topic_label(item)
|
||||
topic = normalize_topic_label(raw_topic)
|
||||
if not topic:
|
||||
msg = "Topic extraction response topic must not be blank"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
raw_position = item.get("support_position")
|
||||
try:
|
||||
support_position = BillTopicPosition(raw_position)
|
||||
except ValueError as exc:
|
||||
msg = f"Invalid support_position: {raw_position!r}"
|
||||
raise TopicExtractionError(msg) from exc
|
||||
|
||||
confidence = _parse_confidence(item.get("confidence"))
|
||||
evidence = item.get("evidence")
|
||||
if evidence is not None and not isinstance(evidence, str):
|
||||
evidence = str(evidence)
|
||||
|
||||
extracted = ExtractedBillTopic(
|
||||
topic=topic,
|
||||
support_position=support_position,
|
||||
confidence=confidence,
|
||||
evidence=evidence,
|
||||
)
|
||||
key = (topic, support_position)
|
||||
existing = deduped.get(key)
|
||||
if existing is None or _confidence_rank(extracted) > _confidence_rank(existing):
|
||||
deduped[key] = extracted
|
||||
|
||||
return list(deduped.values())
|
||||
|
||||
|
||||
def extract_topics_for_bill_text(
|
||||
*,
|
||||
openai_config: OpenAIConfig,
|
||||
bill: Bill,
|
||||
text: str,
|
||||
candidate_topics: Sequence[str],
|
||||
) -> list[ExtractedBillTopic]:
|
||||
"""Extract accepted catalog topics for a bill text string."""
|
||||
normalized_candidates = {normalize_topic_label(topic) for topic in candidate_topics}
|
||||
messages = build_topic_extraction_messages(
|
||||
bill=bill,
|
||||
bill_text=text,
|
||||
candidate_topics=sorted(normalized_candidates),
|
||||
)
|
||||
response_text = call_openai_topic_extraction(
|
||||
openai_config=openai_config,
|
||||
messages=messages,
|
||||
)
|
||||
extracted_topics = parse_topic_extraction_response(response_text)
|
||||
return [topic for topic in extracted_topics if topic.topic in normalized_candidates]
|
||||
|
||||
|
||||
def store_bill_topic_result(
|
||||
*,
|
||||
session: Session,
|
||||
bill: Bill,
|
||||
topics: Sequence[ExtractedBillTopic],
|
||||
replace_existing: bool = True,
|
||||
) -> None:
|
||||
"""Store extracted topics for one bill."""
|
||||
if replace_existing:
|
||||
session.execute(delete(BillTopic).where(BillTopic.bill_id == bill.id))
|
||||
|
||||
for topic in topics:
|
||||
session.add(
|
||||
BillTopic(
|
||||
bill_id=bill.id,
|
||||
topic=normalize_topic_label(topic.topic),
|
||||
support_position=topic.support_position,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def create_select_bills_for_topic_extraction(
|
||||
congress: int | None = None,
|
||||
bill_ids: list[int] | None = None,
|
||||
bill_text_ids: list[int] | None = None,
|
||||
with_votes_only: bool = False,
|
||||
force: bool = False,
|
||||
limit: int | None = None,
|
||||
) -> Select[tuple[Bill]]:
|
||||
"""Select bill rows that have summarized bill_text rows for topic extraction."""
|
||||
has_summary = (BillText.summary.is_not(None), BillText.summary != "")
|
||||
summarized_text_filters: list[ColumnElement[bool]] = [
|
||||
BillText.bill_id == Bill.id,
|
||||
*has_summary,
|
||||
]
|
||||
if with_votes_only:
|
||||
summarized_text_filters.append(
|
||||
exists(
|
||||
select(VoteTextTarget.vote_id)
|
||||
.join(
|
||||
VoteClassification,
|
||||
VoteClassification.vote_id == VoteTextTarget.vote_id,
|
||||
)
|
||||
.where(
|
||||
VoteTextTarget.voted_text_version_id == BillText.id,
|
||||
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||
VoteClassification.vote_relationship
|
||||
== VoteRelationship.DIRECT_TEXT_VOTE,
|
||||
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||
VoteClassification.is_special_rule.is_(False),
|
||||
)
|
||||
)
|
||||
)
|
||||
summarized_text_exists = exists(select(BillText.id).where(*summarized_text_filters))
|
||||
stmt = (
|
||||
select(Bill)
|
||||
.where(summarized_text_exists)
|
||||
.options(selectinload(Bill.bill_texts.and_(*summarized_text_filters[1:])))
|
||||
.order_by(Bill.id)
|
||||
)
|
||||
if congress is not None:
|
||||
stmt = stmt.where(Bill.congress == congress)
|
||||
if bill_ids:
|
||||
stmt = stmt.where(Bill.id.in_(bill_ids))
|
||||
if bill_text_ids:
|
||||
selected_text_exists = exists(
|
||||
select(BillText.id).where(
|
||||
BillText.bill_id == Bill.id,
|
||||
BillText.id.in_(bill_text_ids),
|
||||
*summarized_text_filters[1:],
|
||||
)
|
||||
)
|
||||
stmt = stmt.where(selected_text_exists)
|
||||
if not force:
|
||||
stmt = stmt.where(
|
||||
~exists(select(BillTopic.id).where(BillTopic.bill_id == Bill.id))
|
||||
)
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
return stmt
|
||||
|
||||
|
||||
def collect_topic_extraction_diagnostics(
|
||||
session: Session,
|
||||
*,
|
||||
congress: int | None = None,
|
||||
bill_ids: list[int] | None = None,
|
||||
bill_text_ids: list[int] | None = None,
|
||||
with_votes_only: bool = False,
|
||||
force: bool = False,
|
||||
limit: int | None = None,
|
||||
) -> TopicExtractionDiagnostics:
|
||||
"""Count topic extraction inputs for explaining empty selections."""
|
||||
bill_filters = []
|
||||
bill_text_filters: list[ColumnElement[bool]] = []
|
||||
if congress is not None:
|
||||
bill_filters.append(Bill.congress == congress)
|
||||
if bill_ids:
|
||||
bill_filters.append(Bill.id.in_(bill_ids))
|
||||
bill_text_filters.append(BillText.bill_id.in_(bill_ids))
|
||||
if bill_text_ids:
|
||||
bill_text_filters.append(BillText.id.in_(bill_text_ids))
|
||||
if with_votes_only:
|
||||
bill_text_filters.append(
|
||||
exists(
|
||||
select(VoteTextTarget.vote_id)
|
||||
.join(
|
||||
VoteClassification,
|
||||
VoteClassification.vote_id == VoteTextTarget.vote_id,
|
||||
)
|
||||
.where(
|
||||
VoteTextTarget.voted_text_version_id == BillText.id,
|
||||
VoteClassification.subject_type == SubjectType.MEASURE,
|
||||
VoteClassification.vote_relationship
|
||||
== VoteRelationship.DIRECT_TEXT_VOTE,
|
||||
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
|
||||
VoteClassification.is_substantive_policy_vote.is_(True),
|
||||
VoteClassification.is_special_rule.is_(False),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
has_summary = (BillText.summary.is_not(None), BillText.summary != "")
|
||||
summary_filters = [*bill_text_filters, *has_summary]
|
||||
|
||||
bills_with_summaries = session.scalar(
|
||||
select(func.count(func.distinct(Bill.id)))
|
||||
.select_from(Bill)
|
||||
.join(BillText, BillText.bill_id == Bill.id)
|
||||
.where(*bill_filters, *summary_filters)
|
||||
)
|
||||
selected_bills = session.scalar(
|
||||
select(func.count()).select_from(
|
||||
create_select_bills_for_topic_extraction(
|
||||
congress=congress,
|
||||
bill_ids=bill_ids,
|
||||
bill_text_ids=bill_text_ids,
|
||||
with_votes_only=with_votes_only,
|
||||
force=force,
|
||||
limit=limit,
|
||||
).subquery()
|
||||
)
|
||||
)
|
||||
|
||||
return TopicExtractionDiagnostics(
|
||||
bill_rows=session.scalar(select(func.count(Bill.id)).where(*bill_filters)) or 0,
|
||||
bill_text_rows=_count_bill_texts(
|
||||
session,
|
||||
bill_filters=bill_filters,
|
||||
bill_text_filters=bill_text_filters,
|
||||
),
|
||||
summarized_bill_text_rows=_count_bill_texts(
|
||||
session,
|
||||
bill_filters=bill_filters,
|
||||
bill_text_filters=summary_filters,
|
||||
),
|
||||
bills_with_summaries=bills_with_summaries or 0,
|
||||
bill_topic_rows=session.scalar(select(func.count(BillTopic.id))) or 0,
|
||||
selected_bills=selected_bills or 0,
|
||||
)
|
||||
|
||||
|
||||
def _load_json_response(response_text: str) -> dict[str, Any]:
|
||||
text = response_text.strip()
|
||||
fenced = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.DOTALL)
|
||||
if fenced:
|
||||
text = fenced.group(1).strip()
|
||||
|
||||
try:
|
||||
payload = json.loads(text)
|
||||
except json.JSONDecodeError as exc:
|
||||
msg = f"Topic extraction response is not valid JSON: {exc}"
|
||||
raise TopicExtractionError(msg) from exc
|
||||
if not isinstance(payload, dict):
|
||||
msg = "Topic extraction response must be a JSON object"
|
||||
raise TopicExtractionError(msg)
|
||||
return payload
|
||||
|
||||
|
||||
def _parse_confidence(raw: Any) -> float | None:
|
||||
if raw is None:
|
||||
return None
|
||||
try:
|
||||
return float(raw)
|
||||
except (TypeError, ValueError) as exc:
|
||||
msg = f"Invalid confidence: {raw!r}"
|
||||
raise TopicExtractionError(msg) from exc
|
||||
|
||||
|
||||
def _confidence_rank(topic: ExtractedBillTopic) -> tuple[int, float]:
|
||||
if topic.confidence is None:
|
||||
return (0, 0.0)
|
||||
return (1, topic.confidence)
|
||||
|
||||
|
||||
def _extract_topic_label(item: dict[str, Any]) -> str:
|
||||
raw_topic = item.get("topic")
|
||||
if isinstance(raw_topic, str):
|
||||
return raw_topic
|
||||
if isinstance(raw_topic, dict):
|
||||
for key in ("topic", "label", "name", "title"):
|
||||
value = raw_topic.get(key)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
msg = "Topic extraction response topic must be a string"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
|
||||
def _count_bill_texts(
|
||||
session: Session,
|
||||
*,
|
||||
bill_filters: Sequence[ColumnElement[bool]],
|
||||
bill_text_filters: Sequence[ColumnElement[bool]],
|
||||
) -> int:
|
||||
stmt = select(func.count(BillText.id))
|
||||
if bill_filters:
|
||||
stmt = stmt.join(Bill, Bill.id == BillText.bill_id).where(*bill_filters)
|
||||
return session.scalar(stmt.where(*bill_text_filters)) or 0
|
||||
|
||||
|
||||
def main(
|
||||
topics_path: Annotated[
|
||||
Path, typer.Option(help="Path to congressional issue topic JSON.")
|
||||
] = DEFAULT_TOPICS_PATH,
|
||||
congress: Annotated[
|
||||
int | None, typer.Option(help="Only process one Congress.")
|
||||
] = None,
|
||||
bill_ids: Annotated[
|
||||
list[int] | None,
|
||||
typer.Option(
|
||||
"--bill-id",
|
||||
help="Only process one internal bill.id. Repeat for multiple bills.",
|
||||
),
|
||||
] = None,
|
||||
bill_text_ids: Annotated[
|
||||
list[int] | None,
|
||||
typer.Option(
|
||||
"--bill-text-id",
|
||||
help="Only process one internal bill_text.id. Repeat for multiple rows.",
|
||||
),
|
||||
] = None,
|
||||
with_votes_only: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
"--with-votes-only",
|
||||
help="Only process summarized bill_text rows linked to at least one vote.",
|
||||
),
|
||||
] = True,
|
||||
limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
|
||||
force: Annotated[
|
||||
bool,
|
||||
typer.Option(help="Regenerate topics for bills that already have topics."),
|
||||
] = False,
|
||||
dry_run: Annotated[
|
||||
bool,
|
||||
typer.Option(help="Select bills and print diagnostics without calling OpenAI."),
|
||||
] = False,
|
||||
diagnose: Annotated[
|
||||
bool,
|
||||
typer.Option(help="Log input-stage counts before processing."),
|
||||
] = False,
|
||||
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
|
||||
) -> None:
|
||||
"""CLI entrypoint for generating and storing bill topics."""
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
|
||||
topic_catalog = load_topic_catalog(topics_path)
|
||||
logger.info(
|
||||
"Loaded %d candidate topics from %s",
|
||||
len(topic_catalog.candidate_topics),
|
||||
topics_path,
|
||||
)
|
||||
|
||||
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||
with Session(engine) as session:
|
||||
if diagnose or dry_run:
|
||||
diagnostics = collect_topic_extraction_diagnostics(
|
||||
session,
|
||||
congress=congress,
|
||||
bill_ids=bill_ids,
|
||||
bill_text_ids=bill_text_ids,
|
||||
with_votes_only=with_votes_only,
|
||||
force=force,
|
||||
limit=limit,
|
||||
)
|
||||
_log_topic_extraction_diagnostics(diagnostics)
|
||||
if dry_run:
|
||||
return
|
||||
|
||||
openai_config = get_openai_config()
|
||||
|
||||
stmt = create_select_bills_for_topic_extraction(
|
||||
congress=congress,
|
||||
bill_ids=bill_ids,
|
||||
bill_text_ids=bill_text_ids,
|
||||
with_votes_only=with_votes_only,
|
||||
force=force,
|
||||
limit=limit,
|
||||
)
|
||||
bills = session.scalars(stmt).all()
|
||||
logger.info("Selected %d bills for topic extraction", len(bills))
|
||||
|
||||
written = 0
|
||||
failed = 0
|
||||
for index, bill in enumerate(bills, 1):
|
||||
bill_text = _select_bill_text_for_topic_extraction(bill)
|
||||
if bill_text is None:
|
||||
logger.warning("Skipping bill id=%s: no usable summary", bill.id)
|
||||
continue
|
||||
summary = bill_text.summary.strip()
|
||||
|
||||
try:
|
||||
extracted_topics = extract_topics_for_bill_text(
|
||||
openai_config=openai_config,
|
||||
bill=bill,
|
||||
text=summary,
|
||||
candidate_topics=topic_catalog.candidate_topics,
|
||||
)
|
||||
except (httpx.HTTPError, TopicExtractionError):
|
||||
failed += 1
|
||||
logger.exception(
|
||||
"Skipping bill id=%s after topic extraction failure", bill.id
|
||||
)
|
||||
continue
|
||||
|
||||
store_bill_topic_result(
|
||||
session=session,
|
||||
bill=bill,
|
||||
topics=extracted_topics,
|
||||
replace_existing=True,
|
||||
)
|
||||
written += 1
|
||||
if index % 100 == 0:
|
||||
session.commit()
|
||||
logger.info(
|
||||
"Stored %d topics for bill id=%s",
|
||||
len(extracted_topics),
|
||||
bill.id,
|
||||
)
|
||||
|
||||
session.commit()
|
||||
logger.info(
|
||||
"Done: stored topic results for %d bills; failed %d bills",
|
||||
written,
|
||||
failed,
|
||||
)
|
||||
|
||||
|
||||
def _log_topic_extraction_diagnostics(
|
||||
diagnostics: TopicExtractionDiagnostics,
|
||||
) -> None:
|
||||
logger.info(
|
||||
"Topic extraction diagnostics: bill_rows=%d bill_text_rows=%d "
|
||||
"summarized_bill_text_rows=%d bills_with_summaries=%d "
|
||||
"bill_topic_rows=%d selected_bills=%d",
|
||||
diagnostics.bill_rows,
|
||||
diagnostics.bill_text_rows,
|
||||
diagnostics.summarized_bill_text_rows,
|
||||
diagnostics.bills_with_summaries,
|
||||
diagnostics.bill_topic_rows,
|
||||
diagnostics.selected_bills,
|
||||
)
|
||||
if diagnostics.bill_rows == 0:
|
||||
logger.warning("No bills matched the topic extraction scope.")
|
||||
elif diagnostics.bill_text_rows == 0:
|
||||
logger.warning("No bill_text rows matched the topic extraction scope.")
|
||||
elif diagnostics.summarized_bill_text_rows == 0:
|
||||
logger.warning(
|
||||
"No summarized bill_text rows matched the topic extraction scope. "
|
||||
"Run pipelines.tools.summarize_bills first."
|
||||
)
|
||||
elif diagnostics.selected_bills == 0 and diagnostics.bill_topic_rows > 0:
|
||||
logger.warning(
|
||||
"No bills selected because matching bills already have topics. "
|
||||
"Use --force to regenerate them."
|
||||
)
|
||||
elif diagnostics.selected_bills == 0:
|
||||
logger.warning("No bills selected for topic extraction.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
typer.run(main)
|
||||
Reference in New Issue
Block a user