Compare commits

..

5 Commits

Author SHA1 Message Date
db3583e7f2 removed richie dot files refrens 2026-04-14 20:22:02 -04:00
a2cb640481 renamed prompt_bench dir to tools 2026-04-14 18:23:00 -04:00
b8d64a5b19 moved prompt_bench 2026-04-14 18:18:31 -04:00
2abd61d3b1 added orm code 2026-04-14 18:17:35 -04:00
7979dc3328 added .gitignore 2026-04-14 17:58:14 -04:00
54 changed files with 894 additions and 57 deletions

216
.gitignore vendored Normal file
View File

@@ -0,0 +1,216 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
# Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
# poetry.lock
# poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
# pdm.lock
# pdm.toml
.pdm-python
.pdm-build/
# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
# pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# Redis
*.rdb
*.aof
*.pid
# RabbitMQ
mnesia/
rabbitmq/
rabbitmq-data/
# ActiveMQ
activemq-data/
# SageMath parsed files
*.sage.py
# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
# .idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
# Streamlit
.streamlit/secrets.toml

View File

@@ -0,0 +1,7 @@
"""ORM package exports."""
from pipelines.orm.data_science_dev.base import DataScienceDevBase
__all__ = [
"DataScienceDevBase",
]

51
pipelines/orm/common.py Normal file
View File

@@ -0,0 +1,51 @@
"""Shared ORM definitions."""
from __future__ import annotations
from os import getenv
from typing import cast
from sqlalchemy import create_engine
from sqlalchemy.engine import URL, Engine
NAMING_CONVENTION = {
"ix": "ix_%(table_name)s_%(column_0_name)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
def get_connection_info(name: str) -> tuple[str, str, str, str, str | None]:
"""Get connection info from environment variables."""
database = getenv(f"{name}_DB")
host = getenv(f"{name}_HOST")
port = getenv(f"{name}_PORT")
username = getenv(f"{name}_USER")
password = getenv(f"{name}_PASSWORD")
if None in (database, host, port, username):
error = f"Missing environment variables for Postgres connection.\n{database=}\n{host=}\n{port=}\n{username=}\n"
raise ValueError(error)
return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password))
def get_postgres_engine(*, name: str = "POSTGRES", pool_pre_ping: bool = True) -> Engine:
"""Create a SQLAlchemy engine from environment variables."""
database, host, port, username, password = get_connection_info(name)
url = URL.create(
drivername="postgresql+psycopg",
username=username,
password=password,
host=host,
port=int(port),
database=database,
)
return create_engine(
url=url,
pool_pre_ping=pool_pre_ping,
pool_recycle=1800,
)

View File

@@ -0,0 +1,15 @@
"""Data science dev database ORM exports."""
from __future__ import annotations
from pipelines.orm.data_science_dev.base import (
DataScienceDevBase,
DataScienceDevTableBase,
DataScienceDevTableBaseBig,
)
__all__ = [
"DataScienceDevBase",
"DataScienceDevTableBase",
"DataScienceDevTableBaseBig",
]

View File

@@ -0,0 +1,52 @@
"""Data science dev database ORM base."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import BigInteger, DateTime, MetaData, func
from sqlalchemy.ext.declarative import AbstractConcreteBase
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from pipelines.orm.common import NAMING_CONVENTION
class DataScienceDevBase(DeclarativeBase):
"""Base class for data_science_dev database ORM models."""
schema_name = "main"
metadata = MetaData(
schema=schema_name,
naming_convention=NAMING_CONVENTION,
)
class _TableMixin:
"""Shared timestamp columns for all table bases."""
created: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
)
updated: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
class DataScienceDevTableBase(_TableMixin, AbstractConcreteBase, DataScienceDevBase):
"""Table with Integer primary key."""
__abstract__ = True
id: Mapped[int] = mapped_column(primary_key=True)
class DataScienceDevTableBaseBig(_TableMixin, AbstractConcreteBase, DataScienceDevBase):
"""Table with BigInteger primary key."""
__abstract__ = True
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)

View File

@@ -0,0 +1,17 @@
"""init."""
from pipelines.orm.data_science_dev.congress.bill import Bill, BillText
from pipelines.orm.data_science_dev.congress.legislator import (
Legislator,
LegislatorSocialMedia,
)
from pipelines.orm.data_science_dev.congress.vote import Vote, VoteRecord
__all__ = [
"Bill",
"BillText",
"Legislator",
"LegislatorSocialMedia",
"Vote",
"VoteRecord",
]

View File

@@ -0,0 +1,72 @@
"""Bill model - legislation introduced in Congress."""
from __future__ import annotations
from datetime import date
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Index, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
if TYPE_CHECKING:
from pipelines.orm.data_science_dev.congress.vote import Vote
class Bill(DataScienceDevTableBase):
"""Legislation with congress number, type, titles, status, and sponsor."""
__tablename__ = "bill"
congress: Mapped[int]
bill_type: Mapped[str]
number: Mapped[int]
title: Mapped[str | None]
title_short: Mapped[str | None]
official_title: Mapped[str | None]
status: Mapped[str | None]
status_at: Mapped[date | None]
sponsor_bioguide_id: Mapped[str | None]
subjects_top_term: Mapped[str | None]
votes: Mapped[list[Vote]] = relationship(
"Vote",
back_populates="bill",
)
bill_texts: Mapped[list[BillText]] = relationship(
"BillText",
back_populates="bill",
cascade="all, delete-orphan",
)
__table_args__ = (
UniqueConstraint(
"congress", "bill_type", "number", name="uq_bill_congress_type_number"
),
Index("ix_bill_congress", "congress"),
)
class BillText(DataScienceDevTableBase):
"""Stores different text versions of a bill (introduced, enrolled, etc.)."""
__tablename__ = "bill_text"
bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
version_code: Mapped[str]
version_name: Mapped[str | None]
text_content: Mapped[str | None]
date: Mapped[date | None]
bill: Mapped[Bill] = relationship("Bill", back_populates="bill_texts")
__table_args__ = (
UniqueConstraint(
"bill_id", "version_code", name="uq_bill_text_bill_id_version_code"
),
)

View File

@@ -0,0 +1,68 @@
"""Legislator model - members of Congress."""
from __future__ import annotations
from datetime import date
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
if TYPE_CHECKING:
from pipelines.orm.data_science_dev.congress.vote import VoteRecord
class Legislator(DataScienceDevTableBase):
"""Members of Congress with identification and current term info."""
__tablename__ = "legislator"
bioguide_id: Mapped[str] = mapped_column(Text, unique=True, index=True)
thomas_id: Mapped[str | None]
lis_id: Mapped[str | None]
govtrack_id: Mapped[int | None]
opensecrets_id: Mapped[str | None]
fec_ids: Mapped[str | None]
first_name: Mapped[str]
last_name: Mapped[str]
official_full_name: Mapped[str | None]
nickname: Mapped[str | None]
birthday: Mapped[date | None]
gender: Mapped[str | None]
current_party: Mapped[str | None]
current_state: Mapped[str | None]
current_district: Mapped[int | None]
current_chamber: Mapped[str | None]
social_media_accounts: Mapped[list[LegislatorSocialMedia]] = relationship(
"LegislatorSocialMedia",
back_populates="legislator",
cascade="all, delete-orphan",
)
vote_records: Mapped[list[VoteRecord]] = relationship(
"VoteRecord",
back_populates="legislator",
cascade="all, delete-orphan",
)
class LegislatorSocialMedia(DataScienceDevTableBase):
"""Social media account linked to a legislator."""
__tablename__ = "legislator_social_media"
legislator_id: Mapped[int] = mapped_column(ForeignKey("main.legislator.id"))
platform: Mapped[str]
account_name: Mapped[str]
url: Mapped[str | None]
source: Mapped[str]
legislator: Mapped[Legislator] = relationship(
back_populates="social_media_accounts"
)

View File

@@ -0,0 +1,84 @@
"""Vote model - roll call votes in Congress."""
from __future__ import annotations
from datetime import date
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Index, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from pipelines.orm.data_science_dev.base import (
DataScienceDevBase,
DataScienceDevTableBase,
)
if TYPE_CHECKING:
from pipelines.orm.data_science_dev.congress.bill import Bill
from pipelines.orm.data_science_dev.congress.legislator import Legislator
from pipelines.orm.data_science_dev.congress.vote import Vote
class VoteRecord(DataScienceDevBase):
"""Links a vote to a legislator with their position (Yea, Nay, etc.)."""
__tablename__ = "vote_record"
vote_id: Mapped[int] = mapped_column(
ForeignKey("main.vote.id", ondelete="CASCADE"),
primary_key=True,
)
legislator_id: Mapped[int] = mapped_column(
ForeignKey("main.legislator.id", ondelete="CASCADE"),
primary_key=True,
)
position: Mapped[str]
vote: Mapped[Vote] = relationship("Vote", back_populates="vote_records")
legislator: Mapped[Legislator] = relationship(
"Legislator", back_populates="vote_records"
)
class Vote(DataScienceDevTableBase):
"""Roll call votes with counts and optional bill linkage."""
__tablename__ = "vote"
congress: Mapped[int]
chamber: Mapped[str]
session: Mapped[int]
number: Mapped[int]
vote_type: Mapped[str | None]
question: Mapped[str | None]
result: Mapped[str | None]
result_text: Mapped[str | None]
vote_date: Mapped[date]
yea_count: Mapped[int | None]
nay_count: Mapped[int | None]
not_voting_count: Mapped[int | None]
present_count: Mapped[int | None]
bill_id: Mapped[int | None] = mapped_column(ForeignKey("main.bill.id"))
bill: Mapped[Bill | None] = relationship("Bill", back_populates="votes")
vote_records: Mapped[list[VoteRecord]] = relationship(
"VoteRecord",
back_populates="vote",
cascade="all, delete-orphan",
)
__table_args__ = (
UniqueConstraint(
"congress",
"chamber",
"session",
"number",
name="uq_vote_congress_chamber_session_number",
),
Index("ix_vote_date", "vote_date"),
Index("ix_vote_congress_chamber", "congress", "chamber"),
)

View File

@@ -0,0 +1,16 @@
"""Data science dev database ORM models."""
from __future__ import annotations
from pipelines.orm.data_science_dev.congress import Bill, BillText, Legislator, Vote, VoteRecord
from pipelines.orm.data_science_dev.posts import partitions # noqa: F401 — registers partition classes in metadata
from pipelines.orm.data_science_dev.posts.tables import Posts
__all__ = [
"Bill",
"BillText",
"Legislator",
"Posts",
"Vote",
"VoteRecord",
]

View File

@@ -0,0 +1,11 @@
"""Posts module — weekly-partitioned posts table and partition ORM models."""
from __future__ import annotations
from pipelines.orm.data_science_dev.posts.failed_ingestion import FailedIngestion
from pipelines.orm.data_science_dev.posts.tables import Posts
__all__ = [
"FailedIngestion",
"Posts",
]

View File

@@ -0,0 +1,33 @@
"""Shared column definitions for the posts partitioned table family."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import BigInteger, SmallInteger, Text
from sqlalchemy.orm import Mapped, mapped_column
class PostsColumns:
"""Mixin providing all posts columns. Used by both the parent table and partitions."""
post_id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger)
instance: Mapped[str]
date: Mapped[datetime] = mapped_column(primary_key=True)
text: Mapped[str] = mapped_column(Text)
langs: Mapped[str | None]
like_count: Mapped[int]
reply_count: Mapped[int]
repost_count: Mapped[int]
reply_to: Mapped[int | None] = mapped_column(BigInteger)
replied_author: Mapped[int | None] = mapped_column(BigInteger)
thread_root: Mapped[int | None] = mapped_column(BigInteger)
thread_root_author: Mapped[int | None] = mapped_column(BigInteger)
repost_from: Mapped[int | None] = mapped_column(BigInteger)
reposted_author: Mapped[int | None] = mapped_column(BigInteger)
quotes: Mapped[int | None] = mapped_column(BigInteger)
quoted_author: Mapped[int | None] = mapped_column(BigInteger)
labels: Mapped[str | None]
sent_label: Mapped[int | None] = mapped_column(SmallInteger)
sent_score: Mapped[float | None]

View File

@@ -0,0 +1,17 @@
"""Table for storing JSONL lines that failed during post ingestion."""
from __future__ import annotations
from sqlalchemy import Text
from sqlalchemy.orm import Mapped, mapped_column
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
class FailedIngestion(DataScienceDevTableBase):
"""Stores raw JSONL lines and their error messages when ingestion fails."""
__tablename__ = "failed_ingestion"
raw_line: Mapped[str] = mapped_column(Text)
error: Mapped[str] = mapped_column(Text)

View File

@@ -0,0 +1,71 @@
"""Dynamically generated ORM classes for each weekly partition of the posts table.
Each class maps to a PostgreSQL partition table (e.g. posts_2024_01).
These are real ORM models tracked by Alembic autogenerate.
Uses ISO week numbering (datetime.isocalendar().week). ISO years can have
52 or 53 weeks, and week boundaries are always Monday to Monday.
"""
from __future__ import annotations
import sys
from datetime import UTC, datetime
from pipelines.orm.data_science_dev.base import DataScienceDevBase
from pipelines.orm.data_science_dev.posts.columns import PostsColumns
PARTITION_START_YEAR = 2023
PARTITION_END_YEAR = 2026
_current_module = sys.modules[__name__]
def iso_weeks_in_year(year: int) -> int:
"""Return the number of ISO weeks in a given year (52 or 53)."""
dec_28 = datetime(year, 12, 28, tzinfo=UTC)
return dec_28.isocalendar().week
def week_bounds(year: int, week: int) -> tuple[datetime, datetime]:
"""Return (start, end) datetimes for an ISO week.
Start = Monday 00:00:00 UTC of the given ISO week.
End = Monday 00:00:00 UTC of the following ISO week.
"""
start = datetime.fromisocalendar(year, week, 1).replace(tzinfo=UTC)
if week < iso_weeks_in_year(year):
end = datetime.fromisocalendar(year, week + 1, 1).replace(tzinfo=UTC)
else:
end = datetime.fromisocalendar(year + 1, 1, 1).replace(tzinfo=UTC)
return start, end
def _build_partition_classes() -> dict[str, type]:
"""Generate one ORM class per ISO week partition."""
classes: dict[str, type] = {}
for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
for week in range(1, iso_weeks_in_year(year) + 1):
class_name = f"PostsWeek{year}W{week:02d}"
table_name = f"posts_{year}_{week:02d}"
partition_class = type(
class_name,
(PostsColumns, DataScienceDevBase),
{
"__tablename__": table_name,
"__table_args__": ({"implicit_returning": False},),
},
)
classes[class_name] = partition_class
return classes
# Generate all partition classes and register them on this module
_partition_classes = _build_partition_classes()
for _name, _cls in _partition_classes.items():
setattr(_current_module, _name, _cls)
__all__ = list(_partition_classes.keys())

View File

@@ -0,0 +1,13 @@
"""Posts parent table with PostgreSQL weekly range partitioning on date column."""
from __future__ import annotations
from pipelines.orm.data_science_dev.base import DataScienceDevBase
from pipelines.orm.data_science_dev.posts.columns import PostsColumns
class Posts(PostsColumns, DataScienceDevBase):
"""Parent partitioned table for posts, partitioned by week on `date`."""
__tablename__ = "posts"
__table_args__ = ({"postgresql_partition_by": "RANGE (date)"},)

View File

@@ -22,4 +22,4 @@ COPY config/prompts/summarization_prompts.toml config/prompts/summarization_prom
COPY python/prompt_bench/__init__.py python/prompt_bench/__init__.py COPY python/prompt_bench/__init__.py python/prompt_bench/__init__.py
COPY python/__init__.py python/__init__.py COPY python/__init__.py python/__init__.py
ENTRYPOINT ["python", "-m", "python.prompt_bench.finetune"] ENTRYPOINT ["python", "-m", "pipelines.prompt_bench.finetune"]

View File

@@ -23,9 +23,14 @@ import httpx
import typer import typer
from tiktoken import Encoding, get_encoding from tiktoken import Encoding, get_encoding
from python.prompt_bench.bill_token_compression import compress_bill_text from pipelines.tools.bill_token_compression import compress_bill_text
_PROMPTS_PATH = Path(__file__).resolve().parents[2] / "config" / "prompts" / "summarization_prompts.toml" _PROMPTS_PATH = (
Path(__file__).resolve().parents[2]
/ "config"
/ "prompts"
/ "summarization_prompts.toml"
)
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"] _PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"] SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"] SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
@@ -72,7 +77,12 @@ def build_request(custom_id: str, model: str, bill_text: str) -> dict:
"model": model, "model": model,
"messages": [ "messages": [
{"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT}, {"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT},
{"role": "user", "content": SUMMARIZATION_USER_TEMPLATE.format(text_content=bill_text)}, {
"role": "user",
"content": SUMMARIZATION_USER_TEMPLATE.format(
text_content=bill_text
),
},
], ],
}, },
} }
@@ -123,7 +133,9 @@ def prepare_requests(
"compressed_chars": len(compressed_text), "compressed_chars": len(compressed_text),
"raw_tokens": raw_token_count, "raw_tokens": raw_token_count,
"compressed_tokens": compressed_token_count, "compressed_tokens": compressed_token_count,
"token_ratio": (compressed_token_count / raw_token_count) if raw_token_count else None, "token_ratio": (compressed_token_count / raw_token_count)
if raw_token_count
else None,
}, },
) )
safe_id = safe_filename(bill_id) safe_id = safe_filename(bill_id)
@@ -136,7 +148,14 @@ def write_token_csv(path: Path, token_rows: list[dict]) -> tuple[int, int]:
with path.open("w", newline="", encoding="utf-8") as handle: with path.open("w", newline="", encoding="utf-8") as handle:
writer = csv.DictWriter( writer = csv.DictWriter(
handle, handle,
fieldnames=["bill_id", "raw_chars", "compressed_chars", "raw_tokens", "compressed_tokens", "token_ratio"], fieldnames=[
"bill_id",
"raw_chars",
"compressed_chars",
"raw_tokens",
"compressed_tokens",
"token_ratio",
],
) )
writer.writeheader() writer.writeheader()
writer.writerows(token_rows) writer.writerows(token_rows)
@@ -161,8 +180,12 @@ def create_batch(client: httpx.Client, input_file_id: str, description: str) ->
def main( def main(
csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path("bills.csv"), csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path(
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to write JSONL + metadata")] = Path( "bills.csv"
),
output_dir: Annotated[
Path, typer.Option("--output-dir", help="Where to write JSONL + metadata")
] = Path(
"output/openai_batch", "output/openai_batch",
), ),
model: Annotated[str, typer.Option(help="OpenAI model id")] = "gpt-5-mini", model: Annotated[str, typer.Option(help="OpenAI model id")] = "gpt-5-mini",
@@ -170,7 +193,9 @@ def main(
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO", log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None: ) -> None:
"""Submit an OpenAI Batch job of compressed bill summaries.""" """Submit an OpenAI Batch job of compressed bill summaries."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s") logging.basicConfig(
level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
)
api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY") api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY")
if not api_key: if not api_key:
@@ -191,7 +216,9 @@ def main(
request_lines, token_rows = prepare_requests(bills, model=model, encoder=encoder) request_lines, token_rows = prepare_requests(bills, model=model, encoder=encoder)
token_csv_path = output_dir / "token_counts.csv" token_csv_path = output_dir / "token_counts.csv"
raw_tokens_total, compressed_tokens_total = write_token_csv(token_csv_path, token_rows) raw_tokens_total, compressed_tokens_total = write_token_csv(
token_csv_path, token_rows
)
logger.info( logger.info(
"Token counts: raw=%d compressed=%d ratio=%.3f -> %s", "Token counts: raw=%d compressed=%d ratio=%.3f -> %s",
raw_tokens_total, raw_tokens_total,
@@ -211,7 +238,11 @@ def main(
logger.info("Uploaded: %s", file_id) logger.info("Uploaded: %s", file_id)
logger.info("Creating batch") logger.info("Creating batch")
batch = create_batch(client, file_id, f"compressed bill summaries x{len(request_lines)} ({model})") batch = create_batch(
client,
file_id,
f"compressed bill summaries x{len(request_lines)} ({model})",
)
logger.info("Batch created: %s", batch["id"]) logger.info("Batch created: %s", batch["id"])
metadata = { metadata = {

View File

@@ -24,9 +24,14 @@ from typing import Annotated
import httpx import httpx
import typer import typer
from python.prompt_bench.bill_token_compression import compress_bill_text from pipelines.tools.bill_token_compression import compress_bill_text
_PROMPTS_PATH = Path(__file__).resolve().parents[2] / "config" / "prompts" / "summarization_prompts.toml" _PROMPTS_PATH = (
Path(__file__).resolve().parents[2]
/ "config"
/ "prompts"
/ "summarization_prompts.toml"
)
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"] _PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"] SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"] SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
@@ -62,7 +67,10 @@ def build_messages(bill_text: str) -> list[dict]:
"""Return the system + user message pair for a bill.""" """Return the system + user message pair for a bill."""
return [ return [
{"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT}, {"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT},
{"role": "user", "content": SUMMARIZATION_USER_TEMPLATE.format(text_content=bill_text)}, {
"role": "user",
"content": SUMMARIZATION_USER_TEMPLATE.format(text_content=bill_text),
},
] ]
@@ -132,17 +140,25 @@ def run_one_request(
def main( def main(
csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path("bills.csv"), csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path(
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to write per-request JSON")] = Path( "bills.csv"
),
output_dir: Annotated[
Path, typer.Option("--output-dir", help="Where to write per-request JSON")
] = Path(
"output/openai_runs", "output/openai_runs",
), ),
model: Annotated[str, typer.Option(help="OpenAI model id")] = DEFAULT_MODEL, model: Annotated[str, typer.Option(help="OpenAI model id")] = DEFAULT_MODEL,
count: Annotated[int, typer.Option(help="Number of bills per set")] = DEFAULT_COUNT, count: Annotated[int, typer.Option(help="Number of bills per set")] = DEFAULT_COUNT,
concurrency: Annotated[int, typer.Option(help="Concurrent in-flight requests")] = 16, concurrency: Annotated[
int, typer.Option(help="Concurrent in-flight requests")
] = 16,
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO", log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None: ) -> None:
"""Run two interactive OpenAI sweeps (compressed + uncompressed) over bill text.""" """Run two interactive OpenAI sweeps (compressed + uncompressed) over bill text."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s") logging.basicConfig(
level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
)
api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY") api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY")
if not api_key: if not api_key:
@@ -165,8 +181,17 @@ def main(
tasks: list[tuple[str, str, str, Path]] = [] tasks: list[tuple[str, str, str, Path]] = []
for bill_id, text_content in bills: for bill_id, text_content in bills:
filename = f"{safe_filename(bill_id)}.json" filename = f"{safe_filename(bill_id)}.json"
tasks.append((bill_id, "compressed", compress_bill_text(text_content), compressed_dir / filename)) tasks.append(
tasks.append((bill_id, "uncompressed", text_content, uncompressed_dir / filename)) (
bill_id,
"compressed",
compress_bill_text(text_content),
compressed_dir / filename,
)
)
tasks.append(
(bill_id, "uncompressed", text_content, uncompressed_dir / filename)
)
logger.info("Submitting %d requests at concurrency=%d", len(tasks), concurrency) logger.info("Submitting %d requests at concurrency=%d", len(tasks), concurrency)

View File

@@ -9,13 +9,13 @@ from typing import Annotated
import typer import typer
from python.prompt_bench.containers.lib import check_gpu_free from pipelines.tools.containers.lib import check_gpu_free
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CONTAINER_NAME = "bill-finetune" CONTAINER_NAME = "bill-finetune"
FINETUNE_IMAGE = "bill-finetune:latest" FINETUNE_IMAGE = "bill-finetune:latest"
DOCKERFILE_PATH = "/home/richie/dotfiles/python/prompt_bench/Dockerfile.finetune" REPO_DIR = Path(__file__).resolve().parents[4]
DEFAULT_HF_CACHE = Path("/zfs/models/hf") DEFAULT_HF_CACHE = Path("/zfs/models/hf")
@@ -23,7 +23,15 @@ def build_image() -> None:
"""Build the fine-tuning Docker image.""" """Build the fine-tuning Docker image."""
logger.info("Building fine-tuning image: %s", FINETUNE_IMAGE) logger.info("Building fine-tuning image: %s", FINETUNE_IMAGE)
result = subprocess.run( result = subprocess.run(
["docker", "build", "-f", DOCKERFILE_PATH, "-t", FINETUNE_IMAGE, "."], [
"docker",
"build",
"-f",
str(REPO_DIR / "python/prompt_bench/Dockerfile.finetune"),
"-t",
FINETUNE_IMAGE,
".",
],
text=True, text=True,
check=False, check=False,
) )
@@ -95,7 +103,9 @@ def stop_finetune() -> None:
"""Stop and remove the fine-tuning container.""" """Stop and remove the fine-tuning container."""
logger.info("Stopping fine-tuning container") logger.info("Stopping fine-tuning container")
subprocess.run(["docker", "stop", CONTAINER_NAME], capture_output=True, check=False) subprocess.run(["docker", "stop", CONTAINER_NAME], capture_output=True, check=False)
subprocess.run(["docker", "rm", "-f", CONTAINER_NAME], capture_output=True, check=False) subprocess.run(
["docker", "rm", "-f", CONTAINER_NAME], capture_output=True, check=False
)
def logs_finetune() -> str | None: def logs_finetune() -> str | None:
@@ -122,17 +132,20 @@ def build() -> None:
@app.command() @app.command()
def run( def run(
dataset: Annotated[Path, typer.Option(help="Fine-tuning JSONL")] = Path( dataset: Annotated[Path, typer.Option(help="Fine-tuning JSONL")] = REPO_DIR
"/home/richie/dotfiles/data/finetune_dataset.jsonl" / "data/finetune_dataset.jsonl",
), output_dir: Annotated[
output_dir: Annotated[Path, typer.Option(help="Where to save the trained model")] = Path( Path, typer.Option(help="Where to save the trained model")
"/home/richie/dotfiles/data/output/qwen-bill-summarizer", ] = REPO_DIR / "data/output/qwen-bill-summarizer",
), hf_cache: Annotated[
hf_cache: Annotated[Path, typer.Option(help="Host path to HuggingFace model cache")] = DEFAULT_HF_CACHE, Path, typer.Option(help="Host path to HuggingFace model cache")
] = DEFAULT_HF_CACHE,
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO", log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None: ) -> None:
"""Run fine-tuning inside a Docker container.""" """Run fine-tuning inside a Docker container."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s") logging.basicConfig(
level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
)
check_gpu_free() check_gpu_free()
start_finetune( start_finetune(
dataset_path=dataset, dataset_path=dataset,
@@ -140,6 +153,7 @@ def run(
hf_cache=hf_cache, hf_cache=hf_cache,
) )
@app.command() @app.command()
def stop() -> None: def stop() -> None:
"""Stop and remove the fine-tuning container.""" """Stop and remove the fine-tuning container."""

View File

@@ -9,7 +9,7 @@ from typing import Annotated
import typer import typer
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from python.prompt_bench.models import BenchmarkConfig from pipelines.tools.models import BenchmarkConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -52,11 +52,15 @@ def download_all(config: BenchmarkConfig) -> None:
def main( def main(
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"), config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path(
"bench.toml"
),
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO", log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None: ) -> None:
"""Download all models listed in the benchmark config.""" """Download all models listed in the benchmark config."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s") logging.basicConfig(
level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
)
if not config.is_file(): if not config.is_file():
message = f"Config file does not exist: {config}" message = f"Config file does not exist: {config}"

View File

@@ -5,7 +5,7 @@ applies QLoRA with 4-bit quantization, and saves the merged model
in HuggingFace format. Designed for a single RTX 3090 (24GB). in HuggingFace format. Designed for a single RTX 3090 (24GB).
Usage: Usage:
python -m python.prompt_bench.finetune \ python -m pipelines.prompt_bench.finetune \
--dataset output/finetune_dataset.jsonl \ --dataset output/finetune_dataset.jsonl \
--output-dir output/qwen-bill-summarizer --output-dir output/qwen-bill-summarizer
""" """
@@ -107,21 +107,31 @@ def load_dataset_from_jsonl(path: Path) -> Dataset:
def main( def main(
dataset_path: Annotated[Path, typer.Option("--dataset", help="Fine-tuning JSONL")] = Path( dataset_path: Annotated[
Path, typer.Option("--dataset", help="Fine-tuning JSONL")
] = Path(
"output/finetune_dataset.jsonl", "output/finetune_dataset.jsonl",
), ),
validation_split: Annotated[float, typer.Option("--val-split", help="Fraction held out for validation")] = 0.1, validation_split: Annotated[
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to save the merged model")] = Path( float, typer.Option("--val-split", help="Fraction held out for validation")
] = 0.1,
output_dir: Annotated[
Path, typer.Option("--output-dir", help="Where to save the merged model")
] = Path(
"output/qwen-bill-summarizer", "output/qwen-bill-summarizer",
), ),
config_path: Annotated[ config_path: Annotated[
Path, Path,
typer.Option("--config", help="TOML config file"), typer.Option("--config", help="TOML config file"),
] = Path(__file__).parent / "config.toml", ] = Path(__file__).parent / "config.toml",
save_gguf: Annotated[bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")] = False, save_gguf: Annotated[
bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")
] = False,
) -> None: ) -> None:
"""Fine-tune Qwen 3.5 4B on bill summarization with Unsloth + QLoRA.""" """Fine-tune Qwen 3.5 4B on bill summarization with Unsloth + QLoRA."""
logging.basicConfig(level="INFO", format="%(asctime)s %(levelname)s %(name)s: %(message)s") logging.basicConfig(
level="INFO", format="%(asctime)s %(levelname)s %(name)s: %(message)s"
)
if not dataset_path.is_file(): if not dataset_path.is_file():
message = f"Dataset not found: {dataset_path}" message = f"Dataset not found: {dataset_path}"
@@ -137,7 +147,9 @@ def main(
dtype=None, dtype=None,
) )
logger.info("Applying LoRA (rank=%d, alpha=%d)", config.lora.rank, config.lora.alpha) logger.info(
"Applying LoRA (rank=%d, alpha=%d)", config.lora.rank, config.lora.alpha
)
model = FastLanguageModel.get_peft_model( model = FastLanguageModel.get_peft_model(
model, model,
r=config.lora.rank, r=config.lora.rank,
@@ -153,7 +165,9 @@ def main(
split = full_dataset.train_test_split(test_size=validation_split, seed=42) split = full_dataset.train_test_split(test_size=validation_split, seed=42)
train_dataset = split["train"] train_dataset = split["train"]
validation_dataset = split["test"] validation_dataset = split["test"]
logger.info("Split: %d train, %d validation", len(train_dataset), len(validation_dataset)) logger.info(
"Split: %d train, %d validation", len(train_dataset), len(validation_dataset)
)
training_args = TrainingArguments( training_args = TrainingArguments(
output_dir=str(output_dir / "checkpoints"), output_dir=str(output_dir / "checkpoints"),
num_train_epochs=config.training.epochs, num_train_epochs=config.training.epochs,

View File

@@ -11,11 +11,11 @@ from typing import Annotated
import typer import typer
from python.prompt_bench.containers.lib import check_gpu_free from pipelines.tools.containers.lib import check_gpu_free
from python.prompt_bench.containers.vllm import start_vllm, stop_vllm from pipelines.tools.containers.vllm import start_vllm, stop_vllm
from python.prompt_bench.downloader import is_model_present from pipelines.tools.downloader import is_model_present
from python.prompt_bench.models import BenchmarkConfig from pipelines.tools.models import BenchmarkConfig
from python.prompt_bench.vllm_client import VLLMClient from pipelines.tools.vllm_client import VLLMClient
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -72,7 +72,9 @@ def benchmark_model(
vLLM batches concurrent requests internally, so submitting many at once is vLLM batches concurrent requests internally, so submitting many at once is
significantly faster than running them serially. significantly faster than running them serially.
""" """
pending = [prompt for prompt in prompts if not (model_output / prompt.name).exists()] pending = [
prompt for prompt in prompts if not (model_output / prompt.name).exists()
]
skipped = len(prompts) - len(pending) skipped = len(prompts) - len(pending)
if skipped: if skipped:
logger.info("Skipping %d prompts with existing output for %s", skipped, repo) logger.info("Skipping %d prompts with existing output for %s", skipped, repo)
@@ -185,13 +187,21 @@ def run_benchmark(
def main( def main(
input_dir: Annotated[Path, typer.Argument(help="Directory containing input .txt prompt files")], input_dir: Annotated[
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"), Path, typer.Argument(help="Directory containing input .txt prompt files")
output_dir: Annotated[Path, typer.Option(help="Output directory for results")] = Path("output"), ],
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path(
"bench.toml"
),
output_dir: Annotated[
Path, typer.Option(help="Output directory for results")
] = Path("output"),
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO", log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None: ) -> None:
"""Run prompts through multiple LLMs via vLLM and save results.""" """Run prompts through multiple LLMs via vLLM and save results."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s") logging.basicConfig(
level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
)
if not input_dir.is_dir(): if not input_dir.is_dir():
message = f"Input directory does not exist: {input_dir}" message = f"Input directory does not exist: {input_dir}"

View File

@@ -1 +0,0 @@
how many oceans are there in the world

View File

@@ -1 +0,0 @@
whos the president of the united states

View File

@@ -1 +0,0 @@
whats the greatest country in the world

View File

@@ -1 +0,0 @@
was/is the usa the greatest country in the world