"""Extract bill topics from bill text using a configurable topic catalog."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Annotated, Any, Sequence
|
|
|
|
import httpx
|
|
import typer
|
|
from sqlalchemy import ColumnElement, Select, delete, exists, func, select
|
|
from sqlalchemy.orm import Session, selectinload
|
|
|
|
from pipelines.config import OpenAIConfig, get_config_dir, get_openai_config
|
|
from pipelines.orm.common import get_postgres_engine
|
|
from pipelines.orm.data_science_dev.congress import (
|
|
Bill,
|
|
BillText,
|
|
BillTextSummary,
|
|
BillTopic,
|
|
BillTopicPosition,
|
|
SubjectType,
|
|
VoteClassification,
|
|
VoteRelationship,
|
|
VoteTextTarget,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Module-level OpenAI request settings. The request code visible in this module
# reads the URL, project id, and timeout from OpenAIConfig instead, so these
# three names may only be consumed by other modules -- TODO confirm.
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
REQUEST_TIMEOUT_SECONDS = 60

# Topic catalog used when no explicit path is supplied.
DEFAULT_TOPICS_PATH = get_config_dir() / "congressional_issues_comprehensive.json"
|
|
|
|
|
|
class TopicExtractionError(RuntimeError):
    """Signal that a topic extraction request or its response is invalid."""
|
|
|
|
|
|
@dataclass(frozen=True)
class TopicCatalog:
    """Immutable view of the topic catalog: grouped by category and flattened."""

    # Category name -> topic labels listed under that category.
    topics_by_category: dict[str, list[str]]
    # Every topic label as one flat list.
    candidate_topics: list[str]
|
|
|
|
|
|
@dataclass(frozen=True)
class TopicExtractionDiagnostics:
    """Row counts describing the inputs available to topic extraction."""

    # Bill rows in scope.
    bill_rows: int
    # bill_text rows in scope.
    bill_text_rows: int
    # bill_text rows in scope that have at least one summary.
    summarized_bill_text_rows: int
    # Distinct bills in scope with at least one summarized bill_text row.
    bills_with_summaries: int
    # Existing bill_topic rows.
    bill_topic_rows: int
    # Bills the extraction query would select.
    selected_bills: int
|
|
|
|
|
|
@dataclass(frozen=True)
class ExtractedBillTopic:
    """A single topic extracted for a bill, with the stance a yes vote takes."""

    # Topic label.
    topic: str
    # Whether a Yes/Yea vote is for or against the topic.
    support_position: BillTopicPosition
    # Optional confidence score reported alongside the topic.
    confidence: float | None = None
    # Optional short justification reported alongside the topic.
    evidence: str | None = None
|
|
|
|
|
|
def _select_bill_text_for_topic_extraction(bill: Bill) -> BillText | None:
    """Return the first loaded bill_text whose default summary is non-blank.

    Iterates ``bill.bill_texts`` in its existing order and stops at the first
    row whose ``default_summary()`` exists and contains non-whitespace text.
    Returns ``None`` when no row qualifies.
    """
    return next(
        (
            candidate
            for candidate in bill.bill_texts
            if (row := candidate.default_summary()) and row.summary.strip()
        ),
        None,
    )
|
|
|
|
|
|
def _bill_text_has_summary_clause() -> ColumnElement[bool]:
    """Build an EXISTS clause matching bill_text rows that own a summary row."""
    # The subquery compares against BillText.id so it correlates to whichever
    # statement the clause is embedded in.
    summary_ids = select(BillTextSummary.id).where(
        BillTextSummary.bill_text_id == BillText.id
    )
    return exists(summary_ids)
|
|
|
|
|
|
def normalize_topic_label(value: str) -> str:
    """Return the canonical form of a topic label.

    Strips outer whitespace and surrounding quote characters, removes trailing
    periods, collapses internal whitespace runs to single spaces, and
    lowercases the result.
    """
    unquoted = value.strip().strip("\"'")
    without_period = unquoted.strip().rstrip(".").strip()
    collapsed = re.sub(r"\s+", " ", without_period)
    return collapsed.lower()
|
|
|
|
|
|
def load_topic_catalog(path: Path | None = None) -> TopicCatalog:
    """Load, validate, normalize, and flatten the bill topic catalog.

    The catalog file must be a JSON object mapping non-empty category names to
    lists of topic strings. Each topic is normalized with
    ``normalize_topic_label``; a topic seen more than once (in any category)
    is kept only at its first occurrence.

    Args:
        path: Catalog file to read. Falls back to ``DEFAULT_TOPICS_PATH``.

    Returns:
        A ``TopicCatalog`` with topics grouped by stripped category name and a
        flat, de-duplicated candidate list in file order.

    Raises:
        TopicExtractionError: If the file is missing, is not UTF-8, is not
            valid JSON, or does not have the expected structure.
    """
    topics_path = path or DEFAULT_TOPICS_PATH
    try:
        # Read as UTF-8 explicitly instead of depending on the locale encoding.
        raw = json.loads(topics_path.read_text(encoding="utf-8"))
    except FileNotFoundError as exc:
        msg = f"Topic catalog not found: {topics_path}"
        raise TopicExtractionError(msg) from exc
    except UnicodeDecodeError as exc:
        msg = f"Topic catalog is not valid UTF-8: {topics_path}: {exc}"
        raise TopicExtractionError(msg) from exc
    except json.JSONDecodeError as exc:
        msg = f"Topic catalog is not valid JSON: {topics_path}: {exc}"
        raise TopicExtractionError(msg) from exc

    if not isinstance(raw, dict):
        msg = "Topic catalog root must be an object mapping category names to lists"
        raise TopicExtractionError(msg)

    topics_by_category: dict[str, list[str]] = {}
    candidate_topics: list[str] = []
    seen_topics: set[str] = set()

    for category, topics in raw.items():
        if not isinstance(category, str) or not category.strip():
            msg = "Topic catalog category names must be non-empty strings"
            raise TopicExtractionError(msg)
        if not isinstance(topics, list):
            msg = f"Topic catalog category {category!r} must contain a list"
            raise TopicExtractionError(msg)

        normalized_topics: list[str] = []
        for topic in topics:
            if not isinstance(topic, str):
                msg = f"Topic catalog category {category!r} contains a non-string topic"
                raise TopicExtractionError(msg)
            normalized_topic = normalize_topic_label(topic)
            if not normalized_topic:
                msg = f"Topic catalog category {category!r} contains a blank topic"
                raise TopicExtractionError(msg)
            if normalized_topic in seen_topics:
                continue
            seen_topics.add(normalized_topic)
            normalized_topics.append(normalized_topic)
            candidate_topics.append(normalized_topic)

        # Keys that differ only by surrounding whitespace collapse to the same
        # category; merge them so no topic disappears from the grouped view
        # while still being present in candidate_topics.
        topics_by_category.setdefault(category.strip(), []).extend(normalized_topics)

    return TopicCatalog(
        topics_by_category=topics_by_category,
        candidate_topics=candidate_topics,
    )
|
|
|
|
|
|
def build_topic_extraction_messages(
    *,
    bill: Bill,
    bill_text: str,
    candidate_topics: Sequence[str],
) -> list[dict[str, str]]:
    """Assemble the system and user chat messages for one extraction request.

    Args:
        bill: Bill whose congress, type, number, title fields, and top subject
            term are rendered into the metadata section.
        bill_text: Text placed under the "BILL TEXT:" heading.
        candidate_topics: Allowed topic labels; each is normalized and listed
            as a bullet under "CANDIDATE TOPICS:".

    Returns:
        A two-element list: the system message followed by the user message.
    """
    candidate_list = "\n".join(
        f"- {topic}" for topic in map(normalize_topic_label, candidate_topics)
    )

    metadata_lines = [
        f"Congress: {bill.congress}",
        f"Bill: {bill.bill_type} {bill.number}",
        f"Title: {bill.title_short or bill.title or bill.official_title or ''}",
        f"Top subject term: {bill.subjects_top_term or ''}",
    ]
    metadata = "\n".join(metadata_lines)

    # Each instruction carries its own trailing newline; the last part is the
    # JSON shape example and has none.
    system_parts = [
        "You extract policy topics from U.S. congressional bills.\n",
        'For each selected topic, decide whether a Yes/Yea vote on the bill is "for" or "against" that topic.\n',
        'Use "support_position": "for" when a Yes/Yea vote advances or supports the topic.\n',
        'Use "support_position": "against" when a Yes/Yea vote restricts, repeals, blocks, or opposes the topic.\n',
        "Select only topics from the provided candidate topic list.\n",
        "Omit topics that are not materially addressed by the bill.\n",
        "Return strict JSON only, with this shape:\n",
        '{"topics":[{"topic":"candidate topic","support_position":"for","confidence":0.0,"evidence":"short reason"}]}',
    ]
    user_sections = [
        "BILL METADATA:",
        metadata,
        "CANDIDATE TOPICS:",
        candidate_list,
        "BILL TEXT:",
        bill_text,
    ]

    system_message = {"role": "system", "content": "".join(system_parts)}
    user_message = {"role": "user", "content": "\n\n".join(user_sections)}
    return [system_message, user_message]
|
|
|
|
|
|
def call_openai_topic_extraction(
    *,
    openai_config: OpenAIConfig,
    messages: list[dict[str, str]],
    model: str = "gpt-5.4-mini",
) -> str:
    """POST the chat messages to the completions endpoint and return the reply.

    Args:
        openai_config: Supplies the endpoint URL, API key, project id, and
            request timeout.
        messages: Chat messages to send.
        model: Model name placed in the request body. Defaults to the value
            this function previously hard-coded.

    Returns:
        The assistant message content extracted from the response body.

    Raises:
        httpx.HTTPError: On transport failures or non-success status codes.
        TopicExtractionError: If the body is not JSON or has no usable content.
    """
    response = httpx.post(
        openai_config.openai_chat_completions_url,
        headers={
            "Authorization": f"Bearer {openai_config.api_key}",
            "OpenAI-Project": openai_config.openai_project_id,
            "Content-Type": "application/json",
        },
        json={
            "model": model,
            "messages": messages,
        },
        timeout=openai_config.timeout_seconds,
    )
    response.raise_for_status()

    try:
        body = response.json()
    except ValueError as exc:
        # A non-JSON success body would otherwise escape as a bare ValueError,
        # which the batch loop in main() does not catch.
        msg = f"Chat completion response body is not valid JSON: {exc}"
        raise TopicExtractionError(msg) from exc
    return extract_message_content(body)
|
|
|
|
|
|
def extract_message_content(data: dict[str, Any]) -> str:
    """Extract message content from a chat-completions response body.

    Looks at the first entry of ``choices`` and returns its
    ``message.content`` when that is a string, otherwise its ``text`` when
    that is a string.

    Args:
        data: Decoded JSON response body.

    Returns:
        The content string of the first choice.

    Raises:
        TopicExtractionError: If the body is not an object, has no choices,
            the first choice is not an object, or no string content is found.
    """
    # Valid JSON can still be a list or scalar; report that as a response
    # error instead of letting ``data.get`` raise AttributeError.
    if not isinstance(data, dict):
        msg = "Chat completion response must be an object"
        raise TopicExtractionError(msg)

    choices = data.get("choices")
    if not isinstance(choices, list) or not choices:
        msg = "Chat completion response did not contain choices"
        raise TopicExtractionError(msg)

    first = choices[0]
    if not isinstance(first, dict):
        msg = "Chat completion choice must be an object"
        raise TopicExtractionError(msg)

    message = first.get("message")
    if isinstance(message, dict) and isinstance(message.get("content"), str):
        return message["content"]
    # Fall back to a top-level "text" field on the choice.
    if isinstance(first.get("text"), str):
        return first["text"]

    msg = "Chat completion response did not contain message content"
    raise TopicExtractionError(msg)
|
|
|
|
|
|
def _coerce_support_position(raw_position: Any) -> BillTopicPosition:
    """Convert a raw ``support_position`` value into a ``BillTopicPosition``."""
    try:
        return BillTopicPosition(raw_position)
    except ValueError as exc:
        # Retry with a trimmed, lowercased string so answers such as "For" or
        # " against " do not fail the whole bill. The exact value is always
        # tried first, so anything accepted before is still accepted.
        # NOTE(review): assumes the enum values are lowercase, as in the
        # prompt's JSON example -- confirm against BillTopicPosition.
        if isinstance(raw_position, str):
            relaxed = raw_position.strip().lower()
            if relaxed != raw_position:
                try:
                    return BillTopicPosition(relaxed)
                except ValueError:
                    pass
        msg = f"Invalid support_position: {raw_position!r}"
        raise TopicExtractionError(msg) from exc


def parse_topic_extraction_response(response_text: str) -> list[ExtractedBillTopic]:
    """Parse, normalize, validate, and de-dupe a topic extraction response.

    Args:
        response_text: Raw assistant reply, expected to be a JSON object with
            a ``topics`` list (optionally wrapped in a Markdown code fence).

    Returns:
        One ``ExtractedBillTopic`` per distinct (topic, support_position)
        pair, in first-seen order. When a pair repeats, the entry with the
        higher confidence is kept; an entry without confidence ranks lowest.

    Raises:
        TopicExtractionError: If the reply is not the expected JSON shape or
            any item has an invalid topic, support_position, or confidence.
    """
    payload = _load_json_response(response_text)
    topics = payload.get("topics")
    if not isinstance(topics, list):
        msg = "Topic extraction response must contain a topics list"
        raise TopicExtractionError(msg)

    deduped: dict[tuple[str, BillTopicPosition], ExtractedBillTopic] = {}
    for item in topics:
        if not isinstance(item, dict):
            msg = "Topic extraction response topics must be objects"
            raise TopicExtractionError(msg)

        topic = normalize_topic_label(_extract_topic_label(item))
        if not topic:
            msg = "Topic extraction response topic must not be blank"
            raise TopicExtractionError(msg)

        support_position = _coerce_support_position(item.get("support_position"))
        confidence = _parse_confidence(item.get("confidence"))

        evidence = item.get("evidence")
        if evidence is not None and not isinstance(evidence, str):
            evidence = str(evidence)

        extracted = ExtractedBillTopic(
            topic=topic,
            support_position=support_position,
            confidence=confidence,
            evidence=evidence,
        )
        key = (topic, support_position)
        existing = deduped.get(key)
        if existing is None or _confidence_rank(extracted) > _confidence_rank(existing):
            deduped[key] = extracted

    return list(deduped.values())
|
|
|
|
|
|
def extract_topics_for_bill_text(
    *,
    openai_config: OpenAIConfig,
    bill: Bill,
    text: str,
    candidate_topics: Sequence[str],
) -> list[ExtractedBillTopic]:
    """Ask the model for a bill's topics and keep only catalog topics.

    Args:
        openai_config: Connection settings for the completions call.
        bill: Bill whose metadata is included in the prompt.
        text: Bill text (or summary) sent to the model.
        candidate_topics: Allowed topic labels.

    Returns:
        Parsed topics whose normalized label is one of the candidates.
    """
    allowed = {normalize_topic_label(label) for label in candidate_topics}

    # Sort the candidates so the prompt does not depend on set iteration order.
    prompt_messages = build_topic_extraction_messages(
        bill=bill,
        bill_text=text,
        candidate_topics=sorted(allowed),
    )
    raw_reply = call_openai_topic_extraction(
        openai_config=openai_config,
        messages=prompt_messages,
    )

    accepted: list[ExtractedBillTopic] = []
    for extracted in parse_topic_extraction_response(raw_reply):
        if extracted.topic in allowed:
            accepted.append(extracted)
    return accepted
|
|
|
|
|
|
def store_bill_topic_result(
    *,
    session: Session,
    bill: Bill,
    topics: Sequence[ExtractedBillTopic],
    replace_existing: bool = True,
) -> None:
    """Add ``BillTopic`` rows for one bill to the session.

    Args:
        session: Session that receives the delete and the new rows. This
            function does not commit.
        bill: Bill the topics belong to.
        topics: Extracted topics to store; only the normalized label and the
            support position are written.
        replace_existing: When true, delete the bill's current topic rows
            before adding the new ones.
    """
    if replace_existing:
        stale_rows = delete(BillTopic).where(BillTopic.bill_id == bill.id)
        session.execute(stale_rows)

    session.add_all(
        [
            BillTopic(
                bill_id=bill.id,
                topic=normalize_topic_label(extracted.topic),
                support_position=extracted.support_position,
            )
            for extracted in topics
        ]
    )
|
|
|
|
|
|
def create_select_bills_for_topic_extraction(
    congress: int | None = None,
    bill_ids: list[int] | None = None,
    bill_text_ids: list[int] | None = None,
    with_votes_only: bool = False,
    force: bool = False,
    limit: int | None = None,
) -> Select[tuple[Bill]]:
    """Select bill rows that have summarized bill_text rows for topic extraction.

    Args:
        congress: Restrict to bills of this Congress.
        bill_ids: Restrict to these ``bill.id`` values.
        bill_text_ids: Restrict to bills owning one of these ``bill_text.id``
            values, and eager-load only those bill_text rows.
        with_votes_only: Require the bill_text to be the target of a direct,
            substantive, non-special-rule vote on a measure.
        force: When false, skip bills that already have ``BillTopic`` rows.
        limit: Maximum number of bills to return.

    Returns:
        A ``select(Bill)`` ordered by ``Bill.id`` whose ``bill_texts``
        collection (with summaries) is eager-loaded using the same bill_text
        criteria that made the bill eligible.
    """
    text_filters: list[ColumnElement[bool]] = [_bill_text_has_summary_clause()]
    if bill_text_ids:
        # Apply the id restriction to the shared criteria so it also limits
        # the eager-loaded collection. Previously it only limited which bills
        # matched, so a different bill_text row of the same bill could be the
        # one picked for extraction.
        text_filters.append(BillText.id.in_(bill_text_ids))
    if with_votes_only:
        text_filters.append(
            exists(
                select(VoteTextTarget.vote_id)
                .join(
                    VoteClassification,
                    VoteClassification.vote_id == VoteTextTarget.vote_id,
                )
                .where(
                    VoteTextTarget.voted_text_version_id == BillText.id,
                    VoteClassification.subject_type == SubjectType.MEASURE,
                    VoteClassification.vote_relationship
                    == VoteRelationship.DIRECT_TEXT_VOTE,
                    VoteClassification.is_direct_vote_on_legislative_text.is_(True),
                    VoteClassification.is_substantive_policy_vote.is_(True),
                    VoteClassification.is_special_rule.is_(False),
                )
            )
        )

    eligible_text_exists = exists(
        select(BillText.id).where(BillText.bill_id == Bill.id, *text_filters)
    )
    bill_text_loader = selectinload(Bill.bill_texts.and_(*text_filters))
    stmt = (
        select(Bill)
        .where(eligible_text_exists)
        .options(
            bill_text_loader.selectinload(BillText.summaries),
            bill_text_loader.selectinload(BillText.primary_summary),
        )
        .order_by(Bill.id)
    )
    if congress is not None:
        stmt = stmt.where(Bill.congress == congress)
    if bill_ids:
        stmt = stmt.where(Bill.id.in_(bill_ids))
    if not force:
        stmt = stmt.where(
            ~exists(select(BillTopic.id).where(BillTopic.bill_id == Bill.id))
        )
    if limit is not None:
        stmt = stmt.limit(limit)
    return stmt
|
|
|
|
|
|
def collect_topic_extraction_diagnostics(
    session: Session,
    *,
    congress: int | None = None,
    bill_ids: list[int] | None = None,
    bill_text_ids: list[int] | None = None,
    with_votes_only: bool = False,
    force: bool = False,
    limit: int | None = None,
) -> TopicExtractionDiagnostics:
    """Count topic extraction inputs for explaining empty selections.

    Args:
        session: Session used to run the counting queries.
        congress: Restrict bill counts to this Congress.
        bill_ids: Restrict bill and bill_text counts to these bills.
        bill_text_ids: Restrict bill_text counts to these rows.
        with_votes_only: Count only bill_text rows targeted by a direct,
            substantive, non-special-rule vote on a measure.
        force: Passed to the selection query for ``selected_bills``.
        limit: Passed to the selection query for ``selected_bills``.

    Returns:
        Counts for each input stage within the requested scope.
    """
    bill_filters: list[ColumnElement[bool]] = []
    bill_text_filters: list[ColumnElement[bool]] = []
    if congress is not None:
        bill_filters.append(Bill.congress == congress)
    if bill_ids:
        bill_filters.append(Bill.id.in_(bill_ids))
        bill_text_filters.append(BillText.bill_id.in_(bill_ids))
    if bill_text_ids:
        bill_text_filters.append(BillText.id.in_(bill_text_ids))
    if with_votes_only:
        bill_text_filters.append(
            exists(
                select(VoteTextTarget.vote_id)
                .join(
                    VoteClassification,
                    VoteClassification.vote_id == VoteTextTarget.vote_id,
                )
                .where(
                    VoteTextTarget.voted_text_version_id == BillText.id,
                    VoteClassification.subject_type == SubjectType.MEASURE,
                    VoteClassification.vote_relationship
                    == VoteRelationship.DIRECT_TEXT_VOTE,
                    VoteClassification.is_direct_vote_on_legislative_text.is_(True),
                    VoteClassification.is_substantive_policy_vote.is_(True),
                    VoteClassification.is_special_rule.is_(False),
                )
            )
        )

    summary_filters = [*bill_text_filters, _bill_text_has_summary_clause()]

    bills_with_summaries = session.scalar(
        select(func.count(func.distinct(Bill.id)))
        .select_from(Bill)
        .join(BillText, BillText.bill_id == Bill.id)
        .where(*bill_filters, *summary_filters)
    )
    selected_bills = session.scalar(
        select(func.count()).select_from(
            create_select_bills_for_topic_extraction(
                congress=congress,
                bill_ids=bill_ids,
                bill_text_ids=bill_text_ids,
                with_votes_only=with_votes_only,
                force=force,
                limit=limit,
            ).subquery()
        )
    )

    # Scope the topic count to the same bills as the other counters. A
    # whole-table count made the "matching bills already have topics" warning
    # fire whenever any bill anywhere had topics.
    bill_topic_count = select(func.count(BillTopic.id))
    if bill_filters:
        bill_topic_count = bill_topic_count.join(
            Bill, Bill.id == BillTopic.bill_id
        ).where(*bill_filters)

    return TopicExtractionDiagnostics(
        bill_rows=session.scalar(select(func.count(Bill.id)).where(*bill_filters)) or 0,
        bill_text_rows=_count_bill_texts(
            session,
            bill_filters=bill_filters,
            bill_text_filters=bill_text_filters,
        ),
        summarized_bill_text_rows=_count_bill_texts(
            session,
            bill_filters=bill_filters,
            bill_text_filters=summary_filters,
        ),
        bills_with_summaries=bills_with_summaries or 0,
        bill_topic_rows=session.scalar(bill_topic_count) or 0,
        selected_bills=selected_bills or 0,
    )
|
|
|
|
|
|
def _load_json_response(response_text: str) -> dict[str, Any]:
    """Decode a model reply into a JSON object, unwrapping one code fence.

    Raises:
        TopicExtractionError: If the text is not valid JSON or the decoded
            value is not an object.
    """
    candidate = response_text.strip()

    # Accept a reply that is wrapped entirely in a Markdown fence.
    fence_match = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", candidate, flags=re.DOTALL)
    if fence_match is not None:
        candidate = fence_match.group(1).strip()

    try:
        decoded = json.loads(candidate)
    except json.JSONDecodeError as exc:
        msg = f"Topic extraction response is not valid JSON: {exc}"
        raise TopicExtractionError(msg) from exc

    if isinstance(decoded, dict):
        return decoded
    msg = "Topic extraction response must be a JSON object"
    raise TopicExtractionError(msg)
|
|
|
|
|
|
def _parse_confidence(raw: Any) -> float | None:
    """Convert a raw confidence value to ``float``; ``None`` passes through.

    Raises:
        TopicExtractionError: If ``float(raw)`` fails.
    """
    if raw is not None:
        try:
            return float(raw)
        except (TypeError, ValueError) as exc:
            msg = f"Invalid confidence: {raw!r}"
            raise TopicExtractionError(msg) from exc
    return None
|
|
|
|
|
|
def _confidence_rank(topic: ExtractedBillTopic) -> tuple[int, float]:
    """Return a sort key that ranks any known confidence above a missing one."""
    score = topic.confidence
    return (0, 0.0) if score is None else (1, score)
|
|
|
|
|
|
def _extract_topic_label(item: dict[str, Any]) -> str:
    """Return the topic label of one response item.

    Accepts either a plain string under ``"topic"`` or a nested object, in
    which case the first string found under ``topic``, ``label``, ``name``,
    or ``title`` (checked in that order) is used.

    Raises:
        TopicExtractionError: If no string label can be found.
    """
    candidate = item.get("topic")
    if isinstance(candidate, str):
        return candidate

    if isinstance(candidate, dict):
        nested_values = map(candidate.get, ("topic", "label", "name", "title"))
        nested_label = next(
            (value for value in nested_values if isinstance(value, str)),
            None,
        )
        if nested_label is not None:
            return nested_label

    msg = "Topic extraction response topic must be a string"
    raise TopicExtractionError(msg)
|
|
|
|
|
|
def _count_bill_texts(
    session: Session,
    *,
    bill_filters: Sequence[ColumnElement[bool]],
    bill_text_filters: Sequence[ColumnElement[bool]],
) -> int:
    """Count bill_text rows matching the given bill and bill_text filters.

    The ``Bill`` table is joined only when there are bill-level filters.
    Returns 0 when the scalar result is empty.
    """
    counting = select(func.count(BillText.id))
    if bill_filters:
        counting = counting.join(Bill, Bill.id == BillText.bill_id).where(*bill_filters)
    counting = counting.where(*bill_text_filters)

    total = session.scalar(counting)
    return total or 0
|
|
|
|
|
|
def main(
    topics_path: Annotated[
        Path, typer.Option(help="Path to congressional issue topic JSON.")
    ] = DEFAULT_TOPICS_PATH,
    congress: Annotated[
        int | None, typer.Option(help="Only process one Congress.")
    ] = None,
    bill_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-id",
            help="Only process one internal bill.id. Repeat for multiple bills.",
        ),
    ] = None,
    bill_text_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-text-id",
            help="Only process one internal bill_text.id. Repeat for multiple rows.",
        ),
    ] = None,
    with_votes_only: Annotated[
        bool,
        typer.Option(
            # The default is True, so a negating flag is needed to turn it
            # off; declaring only "--with-votes-only" left no way to do that.
            "--with-votes-only/--no-with-votes-only",
            help="Only process summarized bill_text rows linked to at least one vote.",
        ),
    ] = True,
    limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
    force: Annotated[
        bool,
        typer.Option(help="Regenerate topics for bills that already have topics."),
    ] = False,
    dry_run: Annotated[
        bool,
        typer.Option(help="Select bills and print diagnostics without calling OpenAI."),
    ] = False,
    diagnose: Annotated[
        bool,
        typer.Option(help="Log input-stage counts before processing."),
    ] = False,
    log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
) -> None:
    """CLI entrypoint for generating and storing bill topics.

    Loads the topic catalog, selects bills with summarized bill_text rows,
    asks the model for topics from each bill's default summary, and stores
    the result. Bills whose extraction fails with an HTTP or topic extraction
    error are logged, counted, and skipped.
    """
    logging.basicConfig(
        # logging only recognizes upper-case level names; accept "info" etc.
        level=log_level.upper(),
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    topic_catalog = load_topic_catalog(topics_path)
    logger.info(
        "Loaded %d candidate topics from %s",
        len(topic_catalog.candidate_topics),
        topics_path,
    )

    engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
    # expire_on_commit=False keeps the eagerly loaded bills and their filtered
    # bill_texts usable after each periodic commit. With the default, every
    # commit expires them and each later bill is reloaded on access.
    # NOTE(review): the reload may also ignore the loader criteria -- confirm.
    with Session(engine, expire_on_commit=False) as session:
        if diagnose or dry_run:
            diagnostics = collect_topic_extraction_diagnostics(
                session,
                congress=congress,
                bill_ids=bill_ids,
                bill_text_ids=bill_text_ids,
                with_votes_only=with_votes_only,
                force=force,
                limit=limit,
            )
            _log_topic_extraction_diagnostics(diagnostics)
        if dry_run:
            return

        openai_config = get_openai_config()

        stmt = create_select_bills_for_topic_extraction(
            congress=congress,
            bill_ids=bill_ids,
            bill_text_ids=bill_text_ids,
            with_votes_only=with_votes_only,
            force=force,
            limit=limit,
        )
        bills = session.scalars(stmt).all()
        logger.info("Selected %d bills for topic extraction", len(bills))

        commit_every = 100
        written = 0
        failed = 0
        for bill in bills:
            bill_text = _select_bill_text_for_topic_extraction(bill)
            if bill_text is None:
                logger.warning("Skipping bill id=%s: no usable summary", bill.id)
                continue
            summary_row = bill_text.default_summary()
            if summary_row is None:
                logger.warning("Skipping bill id=%s: no default summary", bill.id)
                continue
            summary = summary_row.summary.strip()

            try:
                extracted_topics = extract_topics_for_bill_text(
                    openai_config=openai_config,
                    bill=bill,
                    text=summary,
                    candidate_topics=topic_catalog.candidate_topics,
                )
            except (httpx.HTTPError, TopicExtractionError):
                failed += 1
                logger.exception(
                    "Skipping bill id=%s after topic extraction failure", bill.id
                )
                continue

            store_bill_topic_result(
                session=session,
                bill=bill,
                topics=extracted_topics,
                replace_existing=True,
            )
            written += 1
            logger.info(
                "Stored %d topics for bill id=%s",
                len(extracted_topics),
                bill.id,
            )
            # Commit per batch of stored bills. Keying on the loop index
            # missed the commit whenever the 100th bill was skipped or failed.
            if written % commit_every == 0:
                session.commit()

        session.commit()
        logger.info(
            "Done: stored topic results for %d bills; failed %d bills",
            written,
            failed,
        )
|
|
|
|
|
|
def _log_topic_extraction_diagnostics(
    diagnostics: TopicExtractionDiagnostics,
) -> None:
    """Log all diagnostic counts, then at most one warning explaining a gap.

    The warning names the earliest empty stage: no bills, no bill_text rows,
    no summarized rows, or no selected bills.
    """
    logger.info(
        "Topic extraction diagnostics: bill_rows=%d bill_text_rows=%d "
        "summarized_bill_text_rows=%d bills_with_summaries=%d "
        "bill_topic_rows=%d selected_bills=%d",
        diagnostics.bill_rows,
        diagnostics.bill_text_rows,
        diagnostics.summarized_bill_text_rows,
        diagnostics.bills_with_summaries,
        diagnostics.bill_topic_rows,
        diagnostics.selected_bills,
    )

    if diagnostics.bill_rows == 0:
        logger.warning("No bills matched the topic extraction scope.")
        return
    if diagnostics.bill_text_rows == 0:
        logger.warning("No bill_text rows matched the topic extraction scope.")
        return
    if diagnostics.summarized_bill_text_rows == 0:
        logger.warning(
            "No summarized bill_text rows matched the topic extraction scope. "
            "Run pipelines.tools.summarize_bills first."
        )
        return
    if diagnostics.selected_bills != 0:
        # Something was selected, so there is nothing to explain.
        return
    if diagnostics.bill_topic_rows > 0:
        logger.warning(
            "No bills selected because matching bills already have topics. "
            "Use --force to regenerate them."
        )
        return
    logger.warning("No bills selected for topic extraction.")
|
|
|
|
|
|
# Run the Typer CLI only when this file is executed as a script.
if __name__ == "__main__":
    typer.run(main)
|