Compare commits

...

8 Commits

Author SHA1 Message Date
e0f88c126e New dockerfile because weird config issues 2026-04-20 00:07:16 -04:00
716bed5300 fixed path 2026-04-19 23:50:10 -04:00
1e9c2a6caa Add eval rouge metrics 2026-04-19 23:40:50 -04:00
db3583e7f2 removed richie dot files references 2026-04-14 20:22:02 -04:00
a2cb640481 renamed prompt_bench dir to tools 2026-04-14 18:23:00 -04:00
b8d64a5b19 moved prompt_bench 2026-04-14 18:18:31 -04:00
2abd61d3b1 added orm code 2026-04-14 18:17:35 -04:00
7979dc3328 added .gitignore 2026-04-14 17:58:14 -04:00
56 changed files with 1351 additions and 81 deletions

216
.gitignore vendored Normal file
View File

@@ -0,0 +1,216 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
# Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
# poetry.lock
# poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
# pdm.lock
# pdm.toml
.pdm-python
.pdm-build/
# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
# pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# Redis
*.rdb
*.aof
*.pid
# RabbitMQ
mnesia/
rabbitmq/
rabbitmq-data/
# ActiveMQ
activemq-data/
# SageMath parsed files
*.sage.py
# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
# .idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
# Streamlit
.streamlit/secrets.toml

View File

@@ -0,0 +1,7 @@
"""ORM package exports."""
from pipelines.orm.data_science_dev.base import DataScienceDevBase
__all__ = [
"DataScienceDevBase",
]

51
pipelines/orm/common.py Normal file
View File

@@ -0,0 +1,51 @@
"""Shared ORM definitions."""
from __future__ import annotations
from os import getenv
from typing import cast
from sqlalchemy import create_engine
from sqlalchemy.engine import URL, Engine
# Constraint/index naming convention handed to MetaData so that every index,
# unique/check/foreign-key constraint and primary key gets a deterministic
# name (important for repeatable migrations). Keys and %(...)s placeholders
# follow SQLAlchemy's naming-convention scheme.
NAMING_CONVENTION = {
    "ix": "ix_%(table_name)s_%(column_0_name)s",
    "uq": "uq_%(table_name)s_%(column_0_name)s",
    "ck": "ck_%(table_name)s_%(constraint_name)s",
    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    "pk": "pk_%(table_name)s",
}
def get_connection_info(name: str) -> tuple[str, str, str, str, str | None]:
    """Read Postgres connection settings from ``{name}_*`` environment variables.

    Looks up ``{name}_DB``, ``{name}_HOST``, ``{name}_PORT``, ``{name}_USER``
    and ``{name}_PASSWORD``.

    Returns:
        ``(database, host, port, username, password)``; only ``password`` may
        be ``None``.

    Raises:
        ValueError: if any of database/host/port/username is unset.
    """
    database = getenv(f"{name}_DB")
    host = getenv(f"{name}_HOST")
    port = getenv(f"{name}_PORT")
    username = getenv(f"{name}_USER")
    password = getenv(f"{name}_PASSWORD")
    # Password is allowed to be absent (e.g. peer/trust auth); the rest are not.
    if any(value is None for value in (database, host, port, username)):
        error = f"Missing environment variables for Postgres connection.\n{database=}\n{host=}\n{port=}\n{username=}\n"
        raise ValueError(error)
    # The None-check above guarantees everything but password is a str.
    return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password))
def get_postgres_engine(*, name: str = "POSTGRES", pool_pre_ping: bool = True) -> Engine:
    """Create a SQLAlchemy engine from environment variables.

    Reads ``{name}_DB/HOST/PORT/USER/PASSWORD`` via :func:`get_connection_info`
    and builds a ``postgresql+psycopg`` engine.

    Args:
        name: Environment-variable prefix (default ``POSTGRES``).
        pool_pre_ping: Validate pooled connections before handing them out.

    Returns:
        A configured :class:`~sqlalchemy.engine.Engine`.
    """
    database, host, port, username, password = get_connection_info(name)
    connection_url = URL.create(
        drivername="postgresql+psycopg",
        username=username,
        password=password,
        host=host,
        port=int(port),
        database=database,
    )
    # pool_recycle=1800 retires pooled connections after 30 minutes to avoid
    # handing out connections the server may have dropped.
    return create_engine(
        url=connection_url,
        pool_pre_ping=pool_pre_ping,
        pool_recycle=1800,
    )

View File

@@ -0,0 +1,15 @@
"""Data science dev database ORM exports."""
from __future__ import annotations
from pipelines.orm.data_science_dev.base import (
DataScienceDevBase,
DataScienceDevTableBase,
DataScienceDevTableBaseBig,
)
__all__ = [
"DataScienceDevBase",
"DataScienceDevTableBase",
"DataScienceDevTableBaseBig",
]

View File

@@ -0,0 +1,52 @@
"""Data science dev database ORM base."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import BigInteger, DateTime, MetaData, func
from sqlalchemy.ext.declarative import AbstractConcreteBase
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from pipelines.orm.common import NAMING_CONVENTION
class DataScienceDevBase(DeclarativeBase):
    """Base class for data_science_dev database ORM models.

    All models share one MetaData, live in the ``main`` schema, and use the
    deterministic constraint naming convention from ``common``.
    """

    # Schema every mapped table is created in.
    schema_name = "main"
    # Single shared MetaData for the whole package, so all models (including
    # dynamically generated partition classes) register in one place.
    metadata = MetaData(
        schema=schema_name,
        naming_convention=NAMING_CONVENTION,
    )
class _TableMixin:
    """Shared timestamp columns for all table bases."""

    # Set server-side to now() when the row is inserted.
    created: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    # Defaults to now() on insert; `onupdate` re-emits now() on each ORM
    # UPDATE (SQLAlchemy-issued, not a database trigger).
    updated: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
class DataScienceDevTableBase(_TableMixin, AbstractConcreteBase, DataScienceDevBase):
    """Abstract table base: Integer surrogate primary key plus created/updated."""

    # Abstract: no table is emitted for this class itself.
    __abstract__ = True

    # Integer surrogate key; subclasses inherit it unchanged.
    id: Mapped[int] = mapped_column(primary_key=True)
class DataScienceDevTableBaseBig(_TableMixin, AbstractConcreteBase, DataScienceDevBase):
    """Abstract table base: BigInteger surrogate primary key plus created/updated."""

    # Abstract: no table is emitted for this class itself.
    __abstract__ = True

    # BigInteger key for tables expected to outgrow a 32-bit id space.
    id: Mapped[int] = mapped_column(BigInteger, primary_key=True)

View File

@@ -0,0 +1,17 @@
"""init."""
from pipelines.orm.data_science_dev.congress.bill import Bill, BillText
from pipelines.orm.data_science_dev.congress.legislator import (
Legislator,
LegislatorSocialMedia,
)
from pipelines.orm.data_science_dev.congress.vote import Vote, VoteRecord
__all__ = [
"Bill",
"BillText",
"Legislator",
"LegislatorSocialMedia",
"Vote",
"VoteRecord",
]

View File

@@ -0,0 +1,72 @@
"""Bill model - legislation introduced in Congress."""
from __future__ import annotations
from datetime import date
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Index, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
if TYPE_CHECKING:
from pipelines.orm.data_science_dev.congress.vote import Vote
class Bill(DataScienceDevTableBase):
    """Legislation with congress number, type, titles, status, and sponsor."""

    __tablename__ = "bill"

    # Natural identity: (congress, bill_type, number) is unique — see
    # __table_args__ below.
    congress: Mapped[int]
    bill_type: Mapped[str]
    number: Mapped[int]
    # Title/status fields are nullable; presumably the upstream feed can omit
    # them — confirm against the ingestion code.
    title: Mapped[str | None]
    title_short: Mapped[str | None]
    official_title: Mapped[str | None]
    status: Mapped[str | None]
    status_at: Mapped[date | None]
    sponsor_bioguide_id: Mapped[str | None]
    subjects_top_term: Mapped[str | None]
    # Roll-call votes linked to this bill (no delete cascade here).
    votes: Mapped[list[Vote]] = relationship(
        "Vote",
        back_populates="bill",
    )
    # Text versions of the bill; removed together with the bill
    # (delete-orphan cascade).
    bill_texts: Mapped[list[BillText]] = relationship(
        "BillText",
        back_populates="bill",
        cascade="all, delete-orphan",
    )
    __table_args__ = (
        UniqueConstraint(
            "congress", "bill_type", "number", name="uq_bill_congress_type_number"
        ),
        Index("ix_bill_congress", "congress"),
    )
class BillText(DataScienceDevTableBase):
    """Stores different text versions of a bill (introduced, enrolled, etc.)."""

    __tablename__ = "bill_text"

    # FK into main.bill; rows are deleted by the database when the bill is.
    bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
    # One row per (bill, version_code) — enforced below.
    version_code: Mapped[str]
    version_name: Mapped[str | None]
    text_content: Mapped[str | None]
    # NOTE: shadows the imported `date` type within this class body; the
    # annotation still resolves because annotations are lazy (future import).
    date: Mapped[date | None]
    bill: Mapped[Bill] = relationship("Bill", back_populates="bill_texts")
    __table_args__ = (
        UniqueConstraint(
            "bill_id", "version_code", name="uq_bill_text_bill_id_version_code"
        ),
    )

View File

@@ -0,0 +1,68 @@
"""Legislator model - members of Congress."""
from __future__ import annotations
from datetime import date
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
if TYPE_CHECKING:
from pipelines.orm.data_science_dev.congress.vote import VoteRecord
class Legislator(DataScienceDevTableBase):
    """Members of Congress with identification and current term info."""

    __tablename__ = "legislator"

    # Canonical identifier; unique and indexed for lookups by bioguide id.
    bioguide_id: Mapped[str] = mapped_column(Text, unique=True, index=True)
    # Cross-reference ids from other datasets; all optional.
    thomas_id: Mapped[str | None]
    lis_id: Mapped[str | None]
    govtrack_id: Mapped[int | None]
    opensecrets_id: Mapped[str | None]
    fec_ids: Mapped[str | None]
    first_name: Mapped[str]
    last_name: Mapped[str]
    official_full_name: Mapped[str | None]
    nickname: Mapped[str | None]
    birthday: Mapped[date | None]
    gender: Mapped[str | None]
    # "current_*" columns describe the most recent term only.
    current_party: Mapped[str | None]
    current_state: Mapped[str | None]
    current_district: Mapped[int | None]
    current_chamber: Mapped[str | None]
    # Social accounts die with the legislator row (delete-orphan cascade).
    social_media_accounts: Mapped[list[LegislatorSocialMedia]] = relationship(
        "LegislatorSocialMedia",
        back_populates="legislator",
        cascade="all, delete-orphan",
    )
    # Per-vote positions cast by this legislator.
    vote_records: Mapped[list[VoteRecord]] = relationship(
        "VoteRecord",
        back_populates="legislator",
        cascade="all, delete-orphan",
    )
class LegislatorSocialMedia(DataScienceDevTableBase):
    """Social media account linked to a legislator."""

    __tablename__ = "legislator_social_media"

    # FK into main.legislator (no DB-level delete cascade; ORM cascade on the
    # parent relationship handles removal).
    legislator_id: Mapped[int] = mapped_column(ForeignKey("main.legislator.id"))
    platform: Mapped[str]
    account_name: Mapped[str]
    url: Mapped[str | None]
    # Where this account record came from (dataset/provider label).
    source: Mapped[str]
    legislator: Mapped[Legislator] = relationship(
        back_populates="social_media_accounts"
    )

View File

@@ -0,0 +1,84 @@
"""Vote model - roll call votes in Congress."""
from __future__ import annotations
from datetime import date
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Index, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from pipelines.orm.data_science_dev.base import (
DataScienceDevBase,
DataScienceDevTableBase,
)
if TYPE_CHECKING:
from pipelines.orm.data_science_dev.congress.bill import Bill
from pipelines.orm.data_science_dev.congress.legislator import Legislator
from pipelines.orm.data_science_dev.congress.vote import Vote
class VoteRecord(DataScienceDevBase):
    """Links a vote to a legislator with their position (Yea, Nay, etc.).

    Association table with a composite (vote_id, legislator_id) primary key;
    inherits DataScienceDevBase directly, so it has no surrogate id or
    created/updated timestamps.
    """

    __tablename__ = "vote_record"

    # Composite PK part 1; rows are removed by the DB when the vote is deleted.
    vote_id: Mapped[int] = mapped_column(
        ForeignKey("main.vote.id", ondelete="CASCADE"),
        primary_key=True,
    )
    # Composite PK part 2; rows are removed when the legislator is deleted.
    legislator_id: Mapped[int] = mapped_column(
        ForeignKey("main.legislator.id", ondelete="CASCADE"),
        primary_key=True,
    )
    # How the legislator voted (e.g. Yea/Nay — values come from the source data).
    position: Mapped[str]
    vote: Mapped[Vote] = relationship("Vote", back_populates="vote_records")
    legislator: Mapped[Legislator] = relationship(
        "Legislator", back_populates="vote_records"
    )
class Vote(DataScienceDevTableBase):
    """Roll call votes with counts and optional bill linkage."""

    __tablename__ = "vote"

    # Natural identity: (congress, chamber, session, number) is unique — see
    # __table_args__ below.
    congress: Mapped[int]
    chamber: Mapped[str]
    session: Mapped[int]
    number: Mapped[int]
    vote_type: Mapped[str | None]
    question: Mapped[str | None]
    result: Mapped[str | None]
    result_text: Mapped[str | None]
    vote_date: Mapped[date]
    # Aggregate tallies; nullable when counts are unavailable in the source.
    yea_count: Mapped[int | None]
    nay_count: Mapped[int | None]
    not_voting_count: Mapped[int | None]
    present_count: Mapped[int | None]
    # Optional linkage — not every roll call concerns a bill.
    bill_id: Mapped[int | None] = mapped_column(ForeignKey("main.bill.id"))
    bill: Mapped[Bill | None] = relationship("Bill", back_populates="votes")
    # Per-legislator positions; deleted with the vote (delete-orphan cascade).
    vote_records: Mapped[list[VoteRecord]] = relationship(
        "VoteRecord",
        back_populates="vote",
        cascade="all, delete-orphan",
    )
    __table_args__ = (
        UniqueConstraint(
            "congress",
            "chamber",
            "session",
            "number",
            name="uq_vote_congress_chamber_session_number",
        ),
        Index("ix_vote_date", "vote_date"),
        Index("ix_vote_congress_chamber", "congress", "chamber"),
    )

View File

@@ -0,0 +1,16 @@
"""Data science dev database ORM models."""
from __future__ import annotations
from pipelines.orm.data_science_dev.congress import Bill, BillText, Legislator, Vote, VoteRecord
from pipelines.orm.data_science_dev.posts import partitions # noqa: F401 — registers partition classes in metadata
from pipelines.orm.data_science_dev.posts.tables import Posts
__all__ = [
"Bill",
"BillText",
"Legislator",
"Posts",
"Vote",
"VoteRecord",
]

View File

@@ -0,0 +1,11 @@
"""Posts module — weekly-partitioned posts table and partition ORM models."""
from __future__ import annotations
from pipelines.orm.data_science_dev.posts.failed_ingestion import FailedIngestion
from pipelines.orm.data_science_dev.posts.tables import Posts
__all__ = [
"FailedIngestion",
"Posts",
]

View File

@@ -0,0 +1,33 @@
"""Shared column definitions for the posts partitioned table family."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import BigInteger, SmallInteger, Text
from sqlalchemy.orm import Mapped, mapped_column
class PostsColumns:
    """Mixin providing all posts columns. Used by both the parent table and partitions."""

    # Composite PK (post_id, date): the partition key (`date`) must be part of
    # the primary key for Postgres range partitioning.
    post_id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
    user_id: Mapped[int] = mapped_column(BigInteger)
    # Originating instance/host of the post — presumably a federated source;
    # confirm against ingestion.
    instance: Mapped[str]
    date: Mapped[datetime] = mapped_column(primary_key=True)
    text: Mapped[str] = mapped_column(Text)
    langs: Mapped[str | None]
    # Engagement counters.
    like_count: Mapped[int]
    reply_count: Mapped[int]
    repost_count: Mapped[int]
    # Thread/reply linkage — all optional post ids (BigInteger to match post_id).
    reply_to: Mapped[int | None] = mapped_column(BigInteger)
    replied_author: Mapped[int | None] = mapped_column(BigInteger)
    thread_root: Mapped[int | None] = mapped_column(BigInteger)
    thread_root_author: Mapped[int | None] = mapped_column(BigInteger)
    repost_from: Mapped[int | None] = mapped_column(BigInteger)
    reposted_author: Mapped[int | None] = mapped_column(BigInteger)
    quotes: Mapped[int | None] = mapped_column(BigInteger)
    quoted_author: Mapped[int | None] = mapped_column(BigInteger)
    labels: Mapped[str | None]
    # Sentiment outputs; SmallInteger label plus float score.
    sent_label: Mapped[int | None] = mapped_column(SmallInteger)
    sent_score: Mapped[float | None]

View File

@@ -0,0 +1,17 @@
"""Table for storing JSONL lines that failed during post ingestion."""
from __future__ import annotations
from sqlalchemy import Text
from sqlalchemy.orm import Mapped, mapped_column
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
class FailedIngestion(DataScienceDevTableBase):
    """Stores raw JSONL lines and their error messages when ingestion fails."""

    __tablename__ = "failed_ingestion"

    # The offending input line, verbatim, so it can be replayed later.
    raw_line: Mapped[str] = mapped_column(Text)
    # Error message captured at ingestion time.
    error: Mapped[str] = mapped_column(Text)

View File

@@ -0,0 +1,71 @@
"""Dynamically generated ORM classes for each weekly partition of the posts table.
Each class maps to a PostgreSQL partition table (e.g. posts_2024_01).
These are real ORM models tracked by Alembic autogenerate.
Uses ISO week numbering (datetime.isocalendar().week). ISO years can have
52 or 53 weeks, and week boundaries are always Monday to Monday.
"""
from __future__ import annotations
import sys
from datetime import UTC, datetime
from pipelines.orm.data_science_dev.base import DataScienceDevBase
from pipelines.orm.data_science_dev.posts.columns import PostsColumns
PARTITION_START_YEAR = 2023
PARTITION_END_YEAR = 2026
_current_module = sys.modules[__name__]
def iso_weeks_in_year(year: int) -> int:
"""Return the number of ISO weeks in a given year (52 or 53)."""
dec_28 = datetime(year, 12, 28, tzinfo=UTC)
return dec_28.isocalendar().week
def week_bounds(year: int, week: int) -> tuple[datetime, datetime]:
"""Return (start, end) datetimes for an ISO week.
Start = Monday 00:00:00 UTC of the given ISO week.
End = Monday 00:00:00 UTC of the following ISO week.
"""
start = datetime.fromisocalendar(year, week, 1).replace(tzinfo=UTC)
if week < iso_weeks_in_year(year):
end = datetime.fromisocalendar(year, week + 1, 1).replace(tzinfo=UTC)
else:
end = datetime.fromisocalendar(year + 1, 1, 1).replace(tzinfo=UTC)
return start, end
def _build_partition_classes() -> dict[str, type]:
    """Generate one ORM class per ISO week partition.

    Builds ``PostsWeek{year}W{week}`` classes mapped to ``posts_{year}_{week}``
    tables for every ISO week from PARTITION_START_YEAR through
    PARTITION_END_YEAR inclusive.
    """
    generated: dict[str, type] = {}
    for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
        # Years have 52 or 53 ISO weeks; ask rather than assume.
        week_count = iso_weeks_in_year(year)
        for week in range(1, week_count + 1):
            name = f"PostsWeek{year}W{week:02d}"
            attrs = {
                "__tablename__": f"posts_{year}_{week:02d}",
                # Partitions are written to directly; disable RETURNING.
                "__table_args__": ({"implicit_returning": False},),
            }
            generated[name] = type(name, (PostsColumns, DataScienceDevBase), attrs)
    return generated
# Generate all partition classes and register them on this module so they are
# importable by name and present in the shared metadata (which is what makes
# them visible to Alembic autogenerate — see module docstring).
_partition_classes = _build_partition_classes()
for _name, _cls in _partition_classes.items():
    setattr(_current_module, _name, _cls)
# Expose every generated class as part of the module's public API.
__all__ = list(_partition_classes.keys())

View File

@@ -0,0 +1,13 @@
"""Posts parent table with PostgreSQL weekly range partitioning on date column."""
from __future__ import annotations
from pipelines.orm.data_science_dev.base import DataScienceDevBase
from pipelines.orm.data_science_dev.posts.columns import PostsColumns
class Posts(PostsColumns, DataScienceDevBase):
    """Parent partitioned table for posts, partitioned by week on `date`."""

    __tablename__ = "posts"
    # Emits PARTITION BY RANGE (date) in the CREATE TABLE DDL; actual rows
    # live in the posts_YYYY_WW partition tables (see partitions module).
    __table_args__ = ({"postgresql_partition_by": "RANGE (date)"},)

View File

@@ -0,0 +1,26 @@
# Unsloth fine-tuning container for Qwen 3.5 4B on RTX 3090.
#
# Build:
#   docker build -f pipelines/pipelines/tools/Dockerfile.finetune -t bill-finetune .
#
# Run:
#   docker run --rm --device=nvidia.com/gpu=all --ipc=host \
#     -v $(pwd)/output:/workspace/output \
#     -v $(pwd)/output/finetune_dataset.jsonl:/workspace/dataset.jsonl:ro \
#     -v /zfs/models/hf:/models \
#     bill-finetune \
#     --dataset /workspace/dataset.jsonl \
#     --output-dir /workspace/output/qwen-bill-summarizer
FROM ghcr.io/unslothai/unsloth:latest

# Extra deps on top of the unsloth image: CLI framework + eval metric.
RUN pip install --no-cache-dir typer rouge-score

WORKDIR /workspace

# Only the tool modules needed for fine-tuning/eval are copied in.
COPY pipelines/tools/__init__.py pipelines/tools/__init__.py
COPY pipelines/tools/finetune.py pipelines/tools/finetune.py
COPY pipelines/tools/summarization_eval.py pipelines/tools/summarization_eval.py
# NOTE(review): these COPY sources assume the build context layout shown in the
# build command above; confirm summarization_prompts.toml and config.toml exist
# at the context root (not under config/prompts/) or the build will fail.
COPY summarization_prompts.toml config/prompts/summarization_prompts.toml
COPY config.toml pipelines/tools/config.toml

# CLI args passed to `docker run` are forwarded to the finetune module.
ENTRYPOINT ["python", "-m", "pipelines.tools.finetune"]

View File

@@ -23,9 +23,14 @@ import httpx
import typer import typer
from tiktoken import Encoding, get_encoding from tiktoken import Encoding, get_encoding
from python.prompt_bench.bill_token_compression import compress_bill_text from pipelines.tools.bill_token_compression import compress_bill_text
_PROMPTS_PATH = Path(__file__).resolve().parents[2] / "config" / "prompts" / "summarization_prompts.toml" _PROMPTS_PATH = (
Path(__file__).resolve().parents[2]
/ "config"
/ "prompts"
/ "summarization_prompts.toml"
)
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"] _PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"] SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"] SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
@@ -72,7 +77,12 @@ def build_request(custom_id: str, model: str, bill_text: str) -> dict:
"model": model, "model": model,
"messages": [ "messages": [
{"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT}, {"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT},
{"role": "user", "content": SUMMARIZATION_USER_TEMPLATE.format(text_content=bill_text)}, {
"role": "user",
"content": SUMMARIZATION_USER_TEMPLATE.format(
text_content=bill_text
),
},
], ],
}, },
} }
@@ -123,7 +133,9 @@ def prepare_requests(
"compressed_chars": len(compressed_text), "compressed_chars": len(compressed_text),
"raw_tokens": raw_token_count, "raw_tokens": raw_token_count,
"compressed_tokens": compressed_token_count, "compressed_tokens": compressed_token_count,
"token_ratio": (compressed_token_count / raw_token_count) if raw_token_count else None, "token_ratio": (compressed_token_count / raw_token_count)
if raw_token_count
else None,
}, },
) )
safe_id = safe_filename(bill_id) safe_id = safe_filename(bill_id)
@@ -136,7 +148,14 @@ def write_token_csv(path: Path, token_rows: list[dict]) -> tuple[int, int]:
with path.open("w", newline="", encoding="utf-8") as handle: with path.open("w", newline="", encoding="utf-8") as handle:
writer = csv.DictWriter( writer = csv.DictWriter(
handle, handle,
fieldnames=["bill_id", "raw_chars", "compressed_chars", "raw_tokens", "compressed_tokens", "token_ratio"], fieldnames=[
"bill_id",
"raw_chars",
"compressed_chars",
"raw_tokens",
"compressed_tokens",
"token_ratio",
],
) )
writer.writeheader() writer.writeheader()
writer.writerows(token_rows) writer.writerows(token_rows)
@@ -161,8 +180,12 @@ def create_batch(client: httpx.Client, input_file_id: str, description: str) ->
def main( def main(
csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path("bills.csv"), csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path(
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to write JSONL + metadata")] = Path( "bills.csv"
),
output_dir: Annotated[
Path, typer.Option("--output-dir", help="Where to write JSONL + metadata")
] = Path(
"output/openai_batch", "output/openai_batch",
), ),
model: Annotated[str, typer.Option(help="OpenAI model id")] = "gpt-5-mini", model: Annotated[str, typer.Option(help="OpenAI model id")] = "gpt-5-mini",
@@ -170,7 +193,9 @@ def main(
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO", log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None: ) -> None:
"""Submit an OpenAI Batch job of compressed bill summaries.""" """Submit an OpenAI Batch job of compressed bill summaries."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s") logging.basicConfig(
level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
)
api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY") api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY")
if not api_key: if not api_key:
@@ -191,7 +216,9 @@ def main(
request_lines, token_rows = prepare_requests(bills, model=model, encoder=encoder) request_lines, token_rows = prepare_requests(bills, model=model, encoder=encoder)
token_csv_path = output_dir / "token_counts.csv" token_csv_path = output_dir / "token_counts.csv"
raw_tokens_total, compressed_tokens_total = write_token_csv(token_csv_path, token_rows) raw_tokens_total, compressed_tokens_total = write_token_csv(
token_csv_path, token_rows
)
logger.info( logger.info(
"Token counts: raw=%d compressed=%d ratio=%.3f -> %s", "Token counts: raw=%d compressed=%d ratio=%.3f -> %s",
raw_tokens_total, raw_tokens_total,
@@ -211,7 +238,11 @@ def main(
logger.info("Uploaded: %s", file_id) logger.info("Uploaded: %s", file_id)
logger.info("Creating batch") logger.info("Creating batch")
batch = create_batch(client, file_id, f"compressed bill summaries x{len(request_lines)} ({model})") batch = create_batch(
client,
file_id,
f"compressed bill summaries x{len(request_lines)} ({model})",
)
logger.info("Batch created: %s", batch["id"]) logger.info("Batch created: %s", batch["id"])
metadata = { metadata = {

View File

@@ -24,9 +24,14 @@ from typing import Annotated
import httpx import httpx
import typer import typer
from python.prompt_bench.bill_token_compression import compress_bill_text from pipelines.tools.bill_token_compression import compress_bill_text
_PROMPTS_PATH = Path(__file__).resolve().parents[2] / "config" / "prompts" / "summarization_prompts.toml" _PROMPTS_PATH = (
Path(__file__).resolve().parents[2]
/ "config"
/ "prompts"
/ "summarization_prompts.toml"
)
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"] _PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"] SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"] SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
@@ -62,7 +67,10 @@ def build_messages(bill_text: str) -> list[dict]:
"""Return the system + user message pair for a bill.""" """Return the system + user message pair for a bill."""
return [ return [
{"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT}, {"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT},
{"role": "user", "content": SUMMARIZATION_USER_TEMPLATE.format(text_content=bill_text)}, {
"role": "user",
"content": SUMMARIZATION_USER_TEMPLATE.format(text_content=bill_text),
},
] ]
@@ -132,17 +140,25 @@ def run_one_request(
def main( def main(
csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path("bills.csv"), csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path(
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to write per-request JSON")] = Path( "bills.csv"
),
output_dir: Annotated[
Path, typer.Option("--output-dir", help="Where to write per-request JSON")
] = Path(
"output/openai_runs", "output/openai_runs",
), ),
model: Annotated[str, typer.Option(help="OpenAI model id")] = DEFAULT_MODEL, model: Annotated[str, typer.Option(help="OpenAI model id")] = DEFAULT_MODEL,
count: Annotated[int, typer.Option(help="Number of bills per set")] = DEFAULT_COUNT, count: Annotated[int, typer.Option(help="Number of bills per set")] = DEFAULT_COUNT,
concurrency: Annotated[int, typer.Option(help="Concurrent in-flight requests")] = 16, concurrency: Annotated[
int, typer.Option(help="Concurrent in-flight requests")
] = 16,
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO", log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None: ) -> None:
"""Run two interactive OpenAI sweeps (compressed + uncompressed) over bill text.""" """Run two interactive OpenAI sweeps (compressed + uncompressed) over bill text."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s") logging.basicConfig(
level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
)
api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY") api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY")
if not api_key: if not api_key:
@@ -165,8 +181,17 @@ def main(
tasks: list[tuple[str, str, str, Path]] = [] tasks: list[tuple[str, str, str, Path]] = []
for bill_id, text_content in bills: for bill_id, text_content in bills:
filename = f"{safe_filename(bill_id)}.json" filename = f"{safe_filename(bill_id)}.json"
tasks.append((bill_id, "compressed", compress_bill_text(text_content), compressed_dir / filename)) tasks.append(
tasks.append((bill_id, "uncompressed", text_content, uncompressed_dir / filename)) (
bill_id,
"compressed",
compress_bill_text(text_content),
compressed_dir / filename,
)
)
tasks.append(
(bill_id, "uncompressed", text_content, uncompressed_dir / filename)
)
logger.info("Submitting %d requests at concurrency=%d", len(tasks), concurrency) logger.info("Submitting %d requests at concurrency=%d", len(tasks), concurrency)

View File

@@ -9,13 +9,13 @@ from typing import Annotated
import typer import typer
from python.prompt_bench.containers.lib import check_gpu_free from pipelines.tools.containers.lib import check_gpu_free
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CONTAINER_NAME = "bill-finetune" CONTAINER_NAME = "bill-finetune"
FINETUNE_IMAGE = "bill-finetune:latest" FINETUNE_IMAGE = "bill-finetune:latest"
DOCKERFILE_PATH = "/home/richie/dotfiles/python/prompt_bench/Dockerfile.finetune" REPO_DIR = Path(__file__).resolve().parents[4]
DEFAULT_HF_CACHE = Path("/zfs/models/hf") DEFAULT_HF_CACHE = Path("/zfs/models/hf")
@@ -23,7 +23,15 @@ def build_image() -> None:
"""Build the fine-tuning Docker image.""" """Build the fine-tuning Docker image."""
logger.info("Building fine-tuning image: %s", FINETUNE_IMAGE) logger.info("Building fine-tuning image: %s", FINETUNE_IMAGE)
result = subprocess.run( result = subprocess.run(
["docker", "build", "-f", DOCKERFILE_PATH, "-t", FINETUNE_IMAGE, "."], [
"docker",
"build",
"-f",
str(REPO_DIR / "pipelines/pipelines/tools/Dockerfile.finetune"),
"-t",
FINETUNE_IMAGE,
".",
],
text=True, text=True,
check=False, check=False,
) )
@@ -95,7 +103,9 @@ def stop_finetune() -> None:
"""Stop and remove the fine-tuning container.""" """Stop and remove the fine-tuning container."""
logger.info("Stopping fine-tuning container") logger.info("Stopping fine-tuning container")
subprocess.run(["docker", "stop", CONTAINER_NAME], capture_output=True, check=False) subprocess.run(["docker", "stop", CONTAINER_NAME], capture_output=True, check=False)
subprocess.run(["docker", "rm", "-f", CONTAINER_NAME], capture_output=True, check=False) subprocess.run(
["docker", "rm", "-f", CONTAINER_NAME], capture_output=True, check=False
)
def logs_finetune() -> str | None: def logs_finetune() -> str | None:
@@ -122,17 +132,20 @@ def build() -> None:
@app.command() @app.command()
def run( def run(
dataset: Annotated[Path, typer.Option(help="Fine-tuning JSONL")] = Path( dataset: Annotated[Path, typer.Option(help="Fine-tuning JSONL")] = REPO_DIR
"/home/richie/dotfiles/data/finetune_dataset.jsonl" / "/zfs/storage/data_science/data/finetune_dataset.jsonl",
), output_dir: Annotated[
output_dir: Annotated[Path, typer.Option(help="Where to save the trained model")] = Path( Path, typer.Option(help="Where to save the trained model")
"/home/richie/dotfiles/data/output/qwen-bill-summarizer", ] = REPO_DIR / "data/output/qwen-bill-summarizer",
), hf_cache: Annotated[
hf_cache: Annotated[Path, typer.Option(help="Host path to HuggingFace model cache")] = DEFAULT_HF_CACHE, Path, typer.Option(help="Host path to HuggingFace model cache")
] = DEFAULT_HF_CACHE,
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO", log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None: ) -> None:
"""Run fine-tuning inside a Docker container.""" """Run fine-tuning inside a Docker container."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s") logging.basicConfig(
level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
)
check_gpu_free() check_gpu_free()
start_finetune( start_finetune(
dataset_path=dataset, dataset_path=dataset,
@@ -140,6 +153,7 @@ def run(
hf_cache=hf_cache, hf_cache=hf_cache,
) )
@app.command() @app.command()
def stop() -> None: def stop() -> None:
"""Stop and remove the fine-tuning container.""" """Stop and remove the fine-tuning container."""

View File

@@ -9,7 +9,7 @@ from typing import Annotated
import typer import typer
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from python.prompt_bench.models import BenchmarkConfig from pipelines.tools.models import BenchmarkConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -52,11 +52,15 @@ def download_all(config: BenchmarkConfig) -> None:
def main( def main(
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"), config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path(
"bench.toml"
),
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO", log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None: ) -> None:
"""Download all models listed in the benchmark config.""" """Download all models listed in the benchmark config."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s") logging.basicConfig(
level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
)
if not config.is_file(): if not config.is_file():
message = f"Config file does not exist: {config}" message = f"Config file does not exist: {config}"

View File

@@ -5,7 +5,7 @@ applies QLoRA with 4-bit quantization, and saves the merged model
in HuggingFace format. Designed for a single RTX 3090 (24GB). in HuggingFace format. Designed for a single RTX 3090 (24GB).
Usage: Usage:
python -m python.prompt_bench.finetune \ python -m pipelines.prompt_bench.finetune \
--dataset output/finetune_dataset.jsonl \ --dataset output/finetune_dataset.jsonl \
--output-dir output/qwen-bill-summarizer --output-dir output/qwen-bill-summarizer
""" """
@@ -25,6 +25,8 @@ from datasets import Dataset
from transformers import TrainingArguments from transformers import TrainingArguments
from trl import SFTTrainer from trl import SFTTrainer
from .summarization_eval import make_compute_metrics
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -107,21 +109,31 @@ def load_dataset_from_jsonl(path: Path) -> Dataset:
def main( def main(
dataset_path: Annotated[Path, typer.Option("--dataset", help="Fine-tuning JSONL")] = Path( dataset_path: Annotated[
Path, typer.Option("--dataset", help="Fine-tuning JSONL")
] = Path(
"output/finetune_dataset.jsonl", "output/finetune_dataset.jsonl",
), ),
validation_split: Annotated[float, typer.Option("--val-split", help="Fraction held out for validation")] = 0.1, validation_split: Annotated[
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to save the merged model")] = Path( float, typer.Option("--val-split", help="Fraction held out for validation")
] = 0.1,
output_dir: Annotated[
Path, typer.Option("--output-dir", help="Where to save the merged model")
] = Path(
"output/qwen-bill-summarizer", "output/qwen-bill-summarizer",
), ),
config_path: Annotated[ config_path: Annotated[
Path, Path,
typer.Option("--config", help="TOML config file"), typer.Option("--config", help="TOML config file"),
] = Path(__file__).parent / "config.toml", ] = Path(__file__).parent / "config.toml",
save_gguf: Annotated[bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")] = False, save_gguf: Annotated[
bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")
] = False,
) -> None: ) -> None:
"""Fine-tune Qwen 3.5 4B on bill summarization with Unsloth + QLoRA.""" """Fine-tune Qwen 3.5 4B on bill summarization with Unsloth + QLoRA."""
logging.basicConfig(level="INFO", format="%(asctime)s %(levelname)s %(name)s: %(message)s") logging.basicConfig(
level="INFO", format="%(asctime)s %(levelname)s %(name)s: %(message)s"
)
if not dataset_path.is_file(): if not dataset_path.is_file():
message = f"Dataset not found: {dataset_path}" message = f"Dataset not found: {dataset_path}"
@@ -137,7 +149,9 @@ def main(
dtype=None, dtype=None,
) )
logger.info("Applying LoRA (rank=%d, alpha=%d)", config.lora.rank, config.lora.alpha) logger.info(
"Applying LoRA (rank=%d, alpha=%d)", config.lora.rank, config.lora.alpha
)
model = FastLanguageModel.get_peft_model( model = FastLanguageModel.get_peft_model(
model, model,
r=config.lora.rank, r=config.lora.rank,
@@ -153,7 +167,9 @@ def main(
split = full_dataset.train_test_split(test_size=validation_split, seed=42) split = full_dataset.train_test_split(test_size=validation_split, seed=42)
train_dataset = split["train"] train_dataset = split["train"]
validation_dataset = split["test"] validation_dataset = split["test"]
logger.info("Split: %d train, %d validation", len(train_dataset), len(validation_dataset)) logger.info(
"Split: %d train, %d validation", len(train_dataset), len(validation_dataset)
)
training_args = TrainingArguments( training_args = TrainingArguments(
output_dir=str(output_dir / "checkpoints"), output_dir=str(output_dir / "checkpoints"),
num_train_epochs=config.training.epochs, num_train_epochs=config.training.epochs,
@@ -173,6 +189,9 @@ def main(
optim="adamw_8bit", optim="adamw_8bit",
seed=42, seed=42,
report_to="none", report_to="none",
metric_for_best_model="eval_composite",
greater_is_better=True,
predict_with_generate=True,
) )
trainer = SFTTrainer( trainer = SFTTrainer(
@@ -183,6 +202,7 @@ def main(
args=training_args, args=training_args,
max_seq_length=config.training.max_seq_length, max_seq_length=config.training.max_seq_length,
packing=True, packing=True,
compute_metrics=make_compute_metrics(tokenizer),
) )
logger.info( logger.info(

View File

@@ -11,11 +11,11 @@ from typing import Annotated
import typer import typer
from python.prompt_bench.containers.lib import check_gpu_free from pipelines.tools.containers.lib import check_gpu_free
from python.prompt_bench.containers.vllm import start_vllm, stop_vllm from pipelines.tools.containers.vllm import start_vllm, stop_vllm
from python.prompt_bench.downloader import is_model_present from pipelines.tools.downloader import is_model_present
from python.prompt_bench.models import BenchmarkConfig from pipelines.tools.models import BenchmarkConfig
from python.prompt_bench.vllm_client import VLLMClient from pipelines.tools.vllm_client import VLLMClient
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -72,7 +72,9 @@ def benchmark_model(
vLLM batches concurrent requests internally, so submitting many at once is vLLM batches concurrent requests internally, so submitting many at once is
significantly faster than running them serially. significantly faster than running them serially.
""" """
pending = [prompt for prompt in prompts if not (model_output / prompt.name).exists()] pending = [
prompt for prompt in prompts if not (model_output / prompt.name).exists()
]
skipped = len(prompts) - len(pending) skipped = len(prompts) - len(pending)
if skipped: if skipped:
logger.info("Skipping %d prompts with existing output for %s", skipped, repo) logger.info("Skipping %d prompts with existing output for %s", skipped, repo)
@@ -185,13 +187,21 @@ def run_benchmark(
def main( def main(
input_dir: Annotated[Path, typer.Argument(help="Directory containing input .txt prompt files")], input_dir: Annotated[
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"), Path, typer.Argument(help="Directory containing input .txt prompt files")
output_dir: Annotated[Path, typer.Option(help="Output directory for results")] = Path("output"), ],
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path(
"bench.toml"
),
output_dir: Annotated[
Path, typer.Option(help="Output directory for results")
] = Path("output"),
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO", log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None: ) -> None:
"""Run prompts through multiple LLMs via vLLM and save results.""" """Run prompts through multiple LLMs via vLLM and save results."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s") logging.basicConfig(
level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
)
if not input_dir.is_dir(): if not input_dir.is_dir():
message = f"Input directory does not exist: {input_dir}" message = f"Input directory does not exist: {input_dir}"

View File

@@ -0,0 +1,426 @@
"""Summarization evaluation for Congressional bill summaries.
Three use cases from one module:
1. Data filtering — score GPT batch outputs before building the fine-tune JSONL:
from summarization_eval import filter_dataset
filter_dataset("output/finetune_dataset.jsonl", "output/filtered_dataset.jsonl")
2. Training compute_metrics hook — plug into SFTTrainer for ROUGE-based checkpoint selection:
from summarization_eval import make_compute_metrics
trainer = SFTTrainer(..., compute_metrics=make_compute_metrics(tokenizer))
3. Inference eval — score a finished model against held-out references:
from summarization_eval import evaluate_file
results = evaluate_file("output/predictions.jsonl", "output/references.jsonl")
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Callable
import numpy as np
from rouge_score import rouge_scorer
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# The five headers every well-formed bill summary must contain.
# check_sections() matches them case-insensitively (it uppercases the
# candidate text before searching).
SECTION_HEADERS = [
    "OPERATIVE ACTIONS",
    "AFFECTED POPULATIONS",
    "MECHANISMS",
    "POLICY THREADS",
    "SYMBOLIC/PROCEDURAL ONLY",
]

# Weighted composite: de-emphasise unigram overlap, weight phrase + structure
# equally. Weights sum to 1.0, so the composite stays on the same 0-1 scale
# as the individual ROUGE f-measures it combines.
ROUGE_WEIGHTS = {
    "rouge1": 0.2,
    "rouge2": 0.4,
    "rougeL": 0.4,
}

# Composite score floor below which a training example is considered low quality
FILTER_THRESHOLD = 0.25
# ---------------------------------------------------------------------------
# Core data structures
# ---------------------------------------------------------------------------
@dataclass
class SummaryScore:
    """Per-example ROUGE and structural-guardrail result for one
    (prediction, reference) pair."""

    rouge1: float
    rouge2: float
    rougeL: float
    composite: float
    has_all_sections: bool  # every required section header was found
    missing_sections: list[str]
    structural_fail: bool  # inverse of has_all_sections; hard guardrail flag

    def as_dict(self) -> dict:
        """Serialise to a plain dict (stable key order, JSONL-friendly)."""
        return dict(
            rouge1=self.rouge1,
            rouge2=self.rouge2,
            rougeL=self.rougeL,
            composite=self.composite,
            has_all_sections=self.has_all_sections,
            missing_sections=self.missing_sections,
            structural_fail=self.structural_fail,
        )
@dataclass
class BatchResult:
    """Aggregate scores over a batch, with structural failures tallied
    separately from the ROUGE means."""

    n_total: int
    n_structural_fail: int
    n_scored: int  # n_total minus structural failures
    rouge1_mean: float
    rouge2_mean: float
    rougeL_mean: float
    composite_mean: float
    scores: list[SummaryScore]

    def as_dict(self) -> dict:
        """Summary-only dict; the per-example `scores` list is omitted."""
        summary_fields = (
            "n_total",
            "n_structural_fail",
            "n_scored",
            "rouge1_mean",
            "rouge2_mean",
            "rougeL_mean",
            "composite_mean",
        )
        return {name: getattr(self, name) for name in summary_fields}
# ---------------------------------------------------------------------------
# Core scoring
# ---------------------------------------------------------------------------

# Module-level scorer shared by every score_pair() call, configured once with
# exactly the three metrics weighted in ROUGE_WEIGHTS. use_stemmer=True
# enables rouge_score's stemming so inflected word forms count as overlaps.
_scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
def check_sections(text: str) -> tuple[bool, list[str]]:
    """Return (all_present, missing_headers) for the 5 required section headers.

    Matching is case-insensitive: the text is uppercased once and each header
    is looked for as a plain substring.
    """
    haystack = text.upper()
    missing: list[str] = []
    for header in SECTION_HEADERS:
        if header not in haystack:
            missing.append(header)
    return not missing, missing
def score_pair(prediction: str, reference: str) -> SummaryScore:
    """Score a single (prediction, reference) pair.

    ROUGE is computed even when the prediction fails the section-header
    guardrail, so malformed outputs can still be inspected for content
    quality — but callers should treat structural_fail=True as a hard
    guardrail failure regardless of the numbers.
    """
    all_present, missing_headers = check_sections(prediction)
    measures = _scorer.score(reference, prediction)
    unigram = measures["rouge1"].fmeasure
    bigram = measures["rouge2"].fmeasure
    lcs = measures["rougeL"].fmeasure
    # Weighted blend; ROUGE_WEIGHTS sums to 1.0 so this stays on a 0-1 scale.
    weighted = (
        ROUGE_WEIGHTS["rouge1"] * unigram
        + ROUGE_WEIGHTS["rouge2"] * bigram
        + ROUGE_WEIGHTS["rougeL"] * lcs
    )
    return SummaryScore(
        rouge1=unigram,
        rouge2=bigram,
        rougeL=lcs,
        composite=weighted,
        has_all_sections=all_present,
        missing_sections=missing_headers,
        structural_fail=not all_present,
    )
def score_batch(pairs: list[tuple[str, str]]) -> BatchResult:
    """Score (prediction, reference) pairs and return aggregate results.

    ROUGE means are taken over structurally valid examples only, so a batch
    with broken formatting shows up in n_structural_fail rather than
    dragging the averages down unfairly.
    """
    per_example = [score_pair(pred, ref) for pred, ref in pairs]
    failed: list[SummaryScore] = []
    usable: list[SummaryScore] = []
    for item in per_example:
        (failed if item.structural_fail else usable).append(item)
    if usable:
        r1 = float(np.mean([s.rouge1 for s in usable]))
        r2 = float(np.mean([s.rouge2 for s in usable]))
        rl = float(np.mean([s.rougeL for s in usable]))
        comp = float(np.mean([s.composite for s in usable]))
    else:
        # Every example failed structurally (or the batch was empty).
        r1 = r2 = rl = comp = 0.0
    return BatchResult(
        n_total=len(per_example),
        n_structural_fail=len(failed),
        n_scored=len(usable),
        rouge1_mean=r1,
        rouge2_mean=r2,
        rougeL_mean=rl,
        composite_mean=comp,
        scores=per_example,
    )
# ---------------------------------------------------------------------------
# Use case 1: Data filtering
# ---------------------------------------------------------------------------
def filter_dataset(
    input_path: Path | str,
    output_path: Path | str,
    *,
    threshold: float = FILTER_THRESHOLD,
    min_words: int = 80,
) -> tuple[int, int]:
    """Filter a fine-tuning JSONL on structural and length guardrails.

    Each line must be a ChatML messages dict:

        {"messages": [{"role": "system", ...}, {"role": "user", ...},
                      {"role": "assistant", ...}]}

    The last assistant turn is the candidate summary. An example is dropped
    when any of the following holds:

    * there is no assistant turn at all;
    * the summary is missing any required section header (hard guardrail);
    * the summary has fewer than ``min_words`` words (proxy for degenerate
      or truncated GPT outputs).

    ``threshold`` is accepted for interface stability but is NOT applied:
    scoring a summary against itself always yields a perfect ROUGE, so there
    is no meaningful reference to apply the composite floor against at filter
    time. Use evaluate_file() for reference-based scoring.

    Intended to run after joining requests + GPT completions
    (build_finetune_dataset.py) to drop malformed or low-quality outputs.

    Returns (kept, dropped).
    """
    input_path = Path(input_path)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    kept = 0
    dropped = 0
    with input_path.open(encoding="utf-8") as fin, output_path.open("w", encoding="utf-8") as fout:
        for line_num, raw_line in enumerate(fin, 1):
            stripped = raw_line.strip()
            if not stripped:
                # Tolerate blank lines between records.
                continue
            example = json.loads(stripped)
            messages = example.get("messages", [])
            assistant_turns = [m for m in messages if m.get("role") == "assistant"]
            if not assistant_turns:
                logger.warning("Line %d: no assistant turn, dropping", line_num)
                dropped += 1
                continue
            prediction = assistant_turns[-1].get("content", "")
            # Hard guardrail: every section header must be present.
            has_all, missing = check_sections(prediction)
            if not has_all:
                logger.warning(
                    "Line %d: structural fail (missing: %s), dropping",
                    line_num,
                    ", ".join(missing),
                )
                dropped += 1
                continue
            # Length floor: very short outputs are almost always degenerate.
            word_count = len(prediction.split())
            if word_count < min_words:
                logger.warning(
                    "Line %d: too short (%d words), dropping", line_num, word_count
                )
                dropped += 1
                continue
            fout.write(json.dumps(example, ensure_ascii=False) + "\n")
            kept += 1
    logger.info("Filtered dataset: kept=%d dropped=%d -> %s", kept, dropped, output_path)
    return kept, dropped
# ---------------------------------------------------------------------------
# Use case 2: compute_metrics hook for SFTTrainer
# ---------------------------------------------------------------------------
def make_compute_metrics(tokenizer) -> Callable:  # noqa: ANN001
    """Build a compute_metrics closure compatible with HuggingFace Trainer.

    Usage in finetune.py:

        from summarization_eval import make_compute_metrics
        trainer = SFTTrainer(
            ...
            compute_metrics=make_compute_metrics(tokenizer),
        )

    The closure accepts an EvalPrediction whose predictions may be either
    logits or token ids; logits are reduced with argmax. For SFTTrainer with
    packing=True, you may need predict_with_generate=True in
    TrainingArguments to get decoded text.
    """

    def compute_metrics(eval_pred) -> dict[str, float]:  # noqa: ANN001
        pred_tokens, label_tokens = eval_pred
        if pred_tokens.ndim == 3:
            # Logits of shape (batch, seq, vocab): reduce to greedy token ids.
            pred_tokens = np.argmax(pred_tokens, axis=-1)
        # -100 marks ignored label positions; swap in pad so decoding works.
        label_tokens = np.where(label_tokens == -100, tokenizer.pad_token_id, label_tokens)
        pred_texts = tokenizer.batch_decode(pred_tokens, skip_special_tokens=True)
        label_texts = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)
        batch = score_batch(list(zip(pred_texts, label_texts)))
        fail_rate = batch.n_structural_fail / batch.n_total if batch.n_total else 0.0
        metrics = {
            "eval_rouge1": batch.rouge1_mean,
            "eval_rouge2": batch.rouge2_mean,
            "eval_rougeL": batch.rougeL_mean,
            "eval_composite": batch.composite_mean,
            "eval_structural_fail_rate": fail_rate,
        }
        logger.info(
            "Eval: composite=%.4f rouge1=%.4f rouge2=%.4f rougeL=%.4f structural_fail=%d/%d",
            metrics["eval_composite"],
            metrics["eval_rouge1"],
            metrics["eval_rouge2"],
            metrics["eval_rougeL"],
            batch.n_structural_fail,
            batch.n_total,
        )
        return metrics

    return compute_metrics
# ---------------------------------------------------------------------------
# Use case 3: Inference eval against held-out references
# ---------------------------------------------------------------------------
def evaluate_file(
    predictions_path: Path | str,
    references_path: Path | str,
    output_path: Path | str | None = None,
) -> BatchResult:
    """Score a predictions JSONL against a line-matched references JSONL.

    Line N of predictions corresponds to line N of references. Each line is
    a JSON object carrying the text in a "text" key, a "content" key, or as
    the last assistant turn of a ChatML "messages" list (checked in that
    order; an object matching none of these scores as the empty string).

    If output_path is provided, per-example scores are written there as
    JSONL and an aggregate summary next to it with suffix ``.summary.json``.

    Raises ValueError when the two files have different line counts.
    """
    predictions_path = Path(predictions_path)
    references_path = Path(references_path)

    def extract_text(line: str) -> str:
        """Pull the summary text out of one JSONL line."""
        obj = json.loads(line)
        # Plain text field
        if "text" in obj:
            return obj["text"]
        if "content" in obj:
            return obj["content"]
        # ChatML messages — take last assistant turn
        for message in reversed(obj.get("messages", [])):
            if message.get("role") == "assistant":
                return message.get("content", "")
        return ""

    # Read explicitly as UTF-8 so results don't depend on the platform's
    # locale encoding (the JSONL write side below is already UTF-8).
    preds = [
        extract_text(line)
        for line in predictions_path.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    refs = [
        extract_text(line)
        for line in references_path.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    if len(preds) != len(refs):
        msg = f"Prediction count ({len(preds)}) != reference count ({len(refs)})"
        raise ValueError(msg)
    result = score_batch(list(zip(preds, refs)))
    logger.info(
        "Inference eval: n=%d structural_fails=%d composite=%.4f "
        "rouge1=%.4f rouge2=%.4f rougeL=%.4f",
        result.n_total,
        result.n_structural_fail,
        result.composite_mean,
        result.rouge1_mean,
        result.rouge2_mean,
        result.rougeL_mean,
    )
    if output_path is not None:
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("w", encoding="utf-8") as fout:
            for score in result.scores:
                fout.write(json.dumps(score.as_dict(), ensure_ascii=False) + "\n")
        summary_path = output_path.with_suffix(".summary.json")
        summary_path.write_text(
            json.dumps(result.as_dict(), indent=2), encoding="utf-8"
        )
        logger.info("Wrote per-example scores to %s", output_path)
        logger.info("Wrote summary to %s", summary_path)
    return result
# ---------------------------------------------------------------------------
# CLI — quick sanity check / standalone use
# ---------------------------------------------------------------------------
def _cli() -> None:
    """Standalone CLI: `filter` a fine-tune JSONL or `eval` predictions."""
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate bill summarization quality.")
    subparsers = parser.add_subparsers(dest="command", required=True)

    # "filter" subcommand: guardrail-filter a fine-tuning dataset.
    filter_parser = subparsers.add_parser("filter", help="Filter a fine-tuning JSONL dataset")
    filter_parser.add_argument("--input", required=True, type=Path)
    filter_parser.add_argument("--output", required=True, type=Path)
    filter_parser.add_argument("--threshold", type=float, default=FILTER_THRESHOLD)

    # "eval" subcommand: score predictions against held-out references.
    eval_parser = subparsers.add_parser("eval", help="Score predictions against references")
    eval_parser.add_argument("--predictions", required=True, type=Path)
    eval_parser.add_argument("--references", required=True, type=Path)
    eval_parser.add_argument("--output", type=Path, default=None)

    args = parser.parse_args()
    logging.basicConfig(level="INFO", format="%(asctime)s %(levelname)s: %(message)s")

    if args.command == "filter":
        kept, dropped = filter_dataset(args.input, args.output, threshold=args.threshold)
        print(f"Kept: {kept} Dropped: {dropped}")
    elif args.command == "eval":
        result = evaluate_file(args.predictions, args.references, args.output)
        print(f"\nResults ({result.n_scored} scored, {result.n_structural_fail} structural fails):")
        print(f" ROUGE-1: {result.rouge1_mean:.4f}")
        print(f" ROUGE-2: {result.rouge2_mean:.4f}")
        print(f" ROUGE-L: {result.rougeL_mean:.4f}")
        print(f" Composite: {result.composite_mean:.4f}")


if __name__ == "__main__":
    _cli()

View File

@@ -1,25 +0,0 @@
# Unsloth fine-tuning container for Qwen 3.5 4B on RTX 3090.
#
# Build:
# docker build -f python/prompt_bench/Dockerfile.finetune -t bill-finetune .
#
# Run:
# docker run --rm --device=nvidia.com/gpu=all --ipc=host \
# -v $(pwd)/output:/workspace/output \
# -v $(pwd)/output/finetune_dataset.jsonl:/workspace/dataset.jsonl:ro \
# -v /zfs/models/hf:/models \
# bill-finetune \
# --dataset /workspace/dataset.jsonl \
# --output-dir /workspace/output/qwen-bill-summarizer
FROM ghcr.io/unslothai/unsloth:latest
RUN pip install --no-cache-dir typer
WORKDIR /workspace
COPY python/prompt_bench/finetune.py python/prompt_bench/finetune.py
COPY config/prompts/summarization_prompts.toml config/prompts/summarization_prompts.toml
COPY python/prompt_bench/__init__.py python/prompt_bench/__init__.py
COPY python/__init__.py python/__init__.py
ENTRYPOINT ["python", "-m", "python.prompt_bench.finetune"]

View File

@@ -1 +0,0 @@
how many oceans are there in the world

View File

@@ -1 +0,0 @@
whos the president of the united states

View File

@@ -1 +0,0 @@
whats the greatest country in the world

View File

@@ -1 +0,0 @@
was/is the usa the greatest country in the world