Compare commits
8 Commits
97bc78a6ef
...
matt_ds
| Author | SHA1 | Date | |
|---|---|---|---|
| e0f88c126e | |||
| 716bed5300 | |||
| 1e9c2a6caa | |||
| db3583e7f2 | |||
| a2cb640481 | |||
| b8d64a5b19 | |||
| 2abd61d3b1 | |||
| 7979dc3328 |
216
.gitignore
vendored
Normal file
216
.gitignore
vendored
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[codz]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
# Pipfile.lock
|
||||||
|
|
||||||
|
# UV
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# uv.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
# poetry.lock
|
||||||
|
# poetry.toml
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
||||||
|
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
||||||
|
# pdm.lock
|
||||||
|
# pdm.toml
|
||||||
|
.pdm-python
|
||||||
|
.pdm-build/
|
||||||
|
|
||||||
|
# pixi
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
||||||
|
# pixi.lock
|
||||||
|
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
||||||
|
# in the .venv directory. It is recommended not to include this directory in version control.
|
||||||
|
.pixi
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# Redis
|
||||||
|
*.rdb
|
||||||
|
*.aof
|
||||||
|
*.pid
|
||||||
|
|
||||||
|
# RabbitMQ
|
||||||
|
mnesia/
|
||||||
|
rabbitmq/
|
||||||
|
rabbitmq-data/
|
||||||
|
|
||||||
|
# ActiveMQ
|
||||||
|
activemq-data/
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.envrc
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
# .idea/
|
||||||
|
|
||||||
|
# Abstra
|
||||||
|
# Abstra is an AI-powered process automation framework.
|
||||||
|
# Ignore directories containing user credentials, local state, and settings.
|
||||||
|
# Learn more at https://abstra.io/docs
|
||||||
|
.abstra/
|
||||||
|
|
||||||
|
# Visual Studio Code
|
||||||
|
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||||
|
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||||
|
# you could uncomment the following to ignore the entire vscode folder
|
||||||
|
# .vscode/
|
||||||
|
|
||||||
|
# Ruff stuff:
|
||||||
|
.ruff_cache/
|
||||||
|
|
||||||
|
# PyPI configuration file
|
||||||
|
.pypirc
|
||||||
|
|
||||||
|
# Marimo
|
||||||
|
marimo/_static/
|
||||||
|
marimo/_lsp/
|
||||||
|
__marimo__/
|
||||||
|
|
||||||
|
# Streamlit
|
||||||
|
.streamlit/secrets.toml
|
||||||
7
pipelines/orm/__init__.py
Normal file
7
pipelines/orm/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""ORM package exports."""
|
||||||
|
|
||||||
|
from pipelines.orm.data_science_dev.base import DataScienceDevBase
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DataScienceDevBase",
|
||||||
|
]
|
||||||
51
pipelines/orm/common.py
Normal file
51
pipelines/orm/common.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
"""Shared ORM definitions."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from os import getenv
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.engine import URL, Engine
|
||||||
|
|
||||||
|
# Deterministic constraint/index naming convention, passed to MetaData so
# generated DDL (and migration autogeneration) produces stable names instead
# of database-assigned ones.
NAMING_CONVENTION = {
    "ix": "ix_%(table_name)s_%(column_0_name)s",
    "uq": "uq_%(table_name)s_%(column_0_name)s",
    "ck": "ck_%(table_name)s_%(constraint_name)s",
    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    "pk": "pk_%(table_name)s",
}
||||||
|
|
||||||
|
|
||||||
|
def get_connection_info(name: str) -> tuple[str, str, str, str, str | None]:
    """Read Postgres connection settings from ``{name}_*`` environment variables.

    Looks up ``{name}_DB``, ``{name}_HOST``, ``{name}_PORT``, ``{name}_USER``
    and ``{name}_PASSWORD``.

    Returns:
        (database, host, port, username, password); password may be ``None``.

    Raises:
        ValueError: If any of DB/HOST/PORT/USER is unset (password is optional).
    """
    database, host, port, username, password = (
        getenv(f"{name}_{suffix}") for suffix in ("DB", "HOST", "PORT", "USER", "PASSWORD")
    )

    if any(value is None for value in (database, host, port, username)):
        error = f"Missing environment variables for Postgres connection.\n{database=}\n{host=}\n{port=}\n{username=}\n"
        raise ValueError(error)

    # Narrow Optional[str] values for the type checker; runtime check is above.
    return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password))
||||||
|
|
||||||
|
|
||||||
|
def get_postgres_engine(
    *,
    name: str = "POSTGRES",
    pool_pre_ping: bool = True,
    pool_recycle: int = 1800,
) -> Engine:
    """Create a SQLAlchemy engine from environment variables.

    Args:
        name: Environment-variable prefix, e.g. "POSTGRES" reads POSTGRES_DB,
            POSTGRES_HOST, POSTGRES_PORT, POSTGRES_USER, POSTGRES_PASSWORD.
        pool_pre_ping: Test pooled connections for liveness before checkout.
        pool_recycle: Seconds after which pooled connections are recycled
            (previously hard-coded to 1800; kept as the default).

    Returns:
        A SQLAlchemy Engine using the postgresql+psycopg driver.

    Raises:
        ValueError: Propagated from get_connection_info when required
            variables are missing, or from int() if the port is not numeric.
    """
    database, host, port, username, password = get_connection_info(name)

    url = URL.create(
        drivername="postgresql+psycopg",
        username=username,
        password=password,
        host=host,
        port=int(port),
        database=database,
    )

    return create_engine(
        url=url,
        pool_pre_ping=pool_pre_ping,
        pool_recycle=pool_recycle,
    )
|
||||||
15
pipelines/orm/data_science_dev/__init__.py
Normal file
15
pipelines/orm/data_science_dev/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""Data science dev database ORM exports."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pipelines.orm.data_science_dev.base import (
|
||||||
|
DataScienceDevBase,
|
||||||
|
DataScienceDevTableBase,
|
||||||
|
DataScienceDevTableBaseBig,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DataScienceDevBase",
|
||||||
|
"DataScienceDevTableBase",
|
||||||
|
"DataScienceDevTableBaseBig",
|
||||||
|
]
|
||||||
52
pipelines/orm/data_science_dev/base.py
Normal file
52
pipelines/orm/data_science_dev/base.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""Data science dev database ORM base."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import BigInteger, DateTime, MetaData, func
|
||||||
|
from sqlalchemy.ext.declarative import AbstractConcreteBase
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
|
|
||||||
|
from pipelines.orm.common import NAMING_CONVENTION
|
||||||
|
|
||||||
|
|
||||||
|
class DataScienceDevBase(DeclarativeBase):
    """Base class for data_science_dev database ORM models."""

    # All tables in this database live in the "main" schema.
    schema_name = "main"

    # Shared metadata: every model inheriting this base registers its table
    # here, with deterministic constraint names from NAMING_CONVENTION.
    metadata = MetaData(
        schema=schema_name,
        naming_convention=NAMING_CONVENTION,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class _TableMixin:
    """Shared timestamp columns for all table bases."""

    # Row creation time: DB-side default of now() at INSERT.
    created: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    # Last-modified time: now() default at INSERT; `onupdate` refreshes it on
    # UPDATE statements issued through SQLAlchemy (this is not a DB trigger).
    updated: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class DataScienceDevTableBase(_TableMixin, AbstractConcreteBase, DataScienceDevBase):
    """Table with Integer primary key."""

    # Abstract: concrete subclasses supply __tablename__ and inherit the
    # id/created/updated columns.
    __abstract__ = True

    # Surrogate integer primary key.
    id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
|
||||||
|
|
||||||
|
class DataScienceDevTableBaseBig(_TableMixin, AbstractConcreteBase, DataScienceDevBase):
    """Table with BigInteger primary key."""

    # Abstract: concrete subclasses supply __tablename__ and inherit the
    # id/created/updated columns.
    __abstract__ = True

    # Surrogate BigInteger primary key for high-volume tables.
    id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
|
||||||
17
pipelines/orm/data_science_dev/congress/__init__.py
Normal file
17
pipelines/orm/data_science_dev/congress/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""init."""
|
||||||
|
|
||||||
|
from pipelines.orm.data_science_dev.congress.bill import Bill, BillText
|
||||||
|
from pipelines.orm.data_science_dev.congress.legislator import (
|
||||||
|
Legislator,
|
||||||
|
LegislatorSocialMedia,
|
||||||
|
)
|
||||||
|
from pipelines.orm.data_science_dev.congress.vote import Vote, VoteRecord
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Bill",
|
||||||
|
"BillText",
|
||||||
|
"Legislator",
|
||||||
|
"LegislatorSocialMedia",
|
||||||
|
"Vote",
|
||||||
|
"VoteRecord",
|
||||||
|
]
|
||||||
72
pipelines/orm/data_science_dev/congress/bill.py
Normal file
72
pipelines/orm/data_science_dev/congress/bill.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""Bill model - legislation introduced in Congress."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import ForeignKey, Index, UniqueConstraint
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pipelines.orm.data_science_dev.congress.vote import Vote
|
||||||
|
|
||||||
|
|
||||||
|
class Bill(DataScienceDevTableBase):
    """Legislation with congress number, type, titles, status, and sponsor."""

    __tablename__ = "bill"

    # Natural key components (also covered by the unique constraint below).
    congress: Mapped[int]
    bill_type: Mapped[str]
    number: Mapped[int]

    # Display titles; any may be absent in source data.
    title: Mapped[str | None]
    title_short: Mapped[str | None]
    official_title: Mapped[str | None]

    status: Mapped[str | None]
    status_at: Mapped[date | None]

    # Sponsor identified by bioguide string ID, not a FK to legislator.
    sponsor_bioguide_id: Mapped[str | None]

    subjects_top_term: Mapped[str | None]

    # Roll-call votes referencing this bill. No delete-orphan cascade here
    # (unlike bill_texts), so votes are not removed with the bill via the ORM.
    votes: Mapped[list[Vote]] = relationship(
        "Vote",
        back_populates="bill",
    )
    # Text versions of the bill; deleted together with the bill.
    bill_texts: Mapped[list[BillText]] = relationship(
        "BillText",
        back_populates="bill",
        cascade="all, delete-orphan",
    )

    __table_args__ = (
        # One row per (congress, bill_type, number).
        UniqueConstraint(
            "congress", "bill_type", "number", name="uq_bill_congress_type_number"
        ),
        Index("ix_bill_congress", "congress"),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class BillText(DataScienceDevTableBase):
    """Stores different text versions of a bill (introduced, enrolled, etc.)."""

    __tablename__ = "bill_text"

    # FK uses the explicit "main" schema; DB-level cascade when bill is deleted.
    bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
    version_code: Mapped[str]
    version_name: Mapped[str | None]
    text_content: Mapped[str | None]
    # NOTE(review): attribute name shadows the imported datetime.date inside
    # the class body; harmless here since annotations are strings
    # (`from __future__ import annotations`), but worth keeping in mind.
    date: Mapped[date | None]

    bill: Mapped[Bill] = relationship("Bill", back_populates="bill_texts")

    __table_args__ = (
        # At most one stored text per version of a given bill.
        UniqueConstraint(
            "bill_id", "version_code", name="uq_bill_text_bill_id_version_code"
        ),
    )
|
||||||
68
pipelines/orm/data_science_dev/congress/legislator.py
Normal file
68
pipelines/orm/data_science_dev/congress/legislator.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""Legislator model - members of Congress."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import ForeignKey, Text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pipelines.orm.data_science_dev.congress.vote import VoteRecord
|
||||||
|
|
||||||
|
|
||||||
|
class Legislator(DataScienceDevTableBase):
    """Members of Congress with identification and current term info."""

    __tablename__ = "legislator"

    # Canonical cross-dataset identifier; unique and indexed for lookups.
    bioguide_id: Mapped[str] = mapped_column(Text, unique=True, index=True)

    # Alternate identifiers from other datasets; all optional.
    thomas_id: Mapped[str | None]
    lis_id: Mapped[str | None]
    govtrack_id: Mapped[int | None]
    opensecrets_id: Mapped[str | None]
    # Presumably multiple FEC IDs serialized into one string — confirm loader format.
    fec_ids: Mapped[str | None]

    first_name: Mapped[str]
    last_name: Mapped[str]
    official_full_name: Mapped[str | None]
    nickname: Mapped[str | None]

    birthday: Mapped[date | None]
    gender: Mapped[str | None]

    # Snapshot of the legislator's current term.
    current_party: Mapped[str | None]
    current_state: Mapped[str | None]
    current_district: Mapped[int | None]
    current_chamber: Mapped[str | None]

    # Child rows are removed with the legislator (delete-orphan cascade).
    social_media_accounts: Mapped[list[LegislatorSocialMedia]] = relationship(
        "LegislatorSocialMedia",
        back_populates="legislator",
        cascade="all, delete-orphan",
    )
    vote_records: Mapped[list[VoteRecord]] = relationship(
        "VoteRecord",
        back_populates="legislator",
        cascade="all, delete-orphan",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class LegislatorSocialMedia(DataScienceDevTableBase):
    """Social media account linked to a legislator."""

    __tablename__ = "legislator_social_media"

    # No ondelete here; cleanup happens via the ORM delete-orphan cascade on
    # Legislator.social_media_accounts.
    legislator_id: Mapped[int] = mapped_column(ForeignKey("main.legislator.id"))
    platform: Mapped[str]
    account_name: Mapped[str]
    url: Mapped[str | None]
    # Presumably which upstream dataset the account entry came from — confirm.
    source: Mapped[str]

    legislator: Mapped[Legislator] = relationship(
        back_populates="social_media_accounts"
    )
|
||||||
84
pipelines/orm/data_science_dev/congress/vote.py
Normal file
84
pipelines/orm/data_science_dev/congress/vote.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""Vote model - roll call votes in Congress."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import ForeignKey, Index, UniqueConstraint
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from pipelines.orm.data_science_dev.base import (
|
||||||
|
DataScienceDevBase,
|
||||||
|
DataScienceDevTableBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pipelines.orm.data_science_dev.congress.bill import Bill
|
||||||
|
from pipelines.orm.data_science_dev.congress.legislator import Legislator
|
||||||
|
from pipelines.orm.data_science_dev.congress.vote import Vote
|
||||||
|
|
||||||
|
|
||||||
|
class VoteRecord(DataScienceDevBase):
    """Links a vote to a legislator with their position (Yea, Nay, etc.)."""

    # Association table with a composite PK, so it derives from the plain
    # declarative base (no surrogate id or timestamp columns).
    __tablename__ = "vote_record"

    vote_id: Mapped[int] = mapped_column(
        ForeignKey("main.vote.id", ondelete="CASCADE"),
        primary_key=True,
    )
    legislator_id: Mapped[int] = mapped_column(
        ForeignKey("main.legislator.id", ondelete="CASCADE"),
        primary_key=True,
    )
    # The legislator's recorded position on this vote.
    position: Mapped[str]

    vote: Mapped[Vote] = relationship("Vote", back_populates="vote_records")
    legislator: Mapped[Legislator] = relationship(
        "Legislator", back_populates="vote_records"
    )
|
||||||
|
|
||||||
|
|
||||||
|
class Vote(DataScienceDevTableBase):
    """Roll call votes with counts and optional bill linkage."""

    __tablename__ = "vote"

    # Natural key components of a roll call (see unique constraint below).
    congress: Mapped[int]
    chamber: Mapped[str]
    session: Mapped[int]
    number: Mapped[int]

    vote_type: Mapped[str | None]
    question: Mapped[str | None]
    result: Mapped[str | None]
    result_text: Mapped[str | None]

    vote_date: Mapped[date]

    # Aggregate tallies; may be absent in source data.
    yea_count: Mapped[int | None]
    nay_count: Mapped[int | None]
    not_voting_count: Mapped[int | None]
    present_count: Mapped[int | None]

    # Optional link to a bill — nullable, presumably for votes with no
    # associated bill; confirm with the ingestion source.
    bill_id: Mapped[int | None] = mapped_column(ForeignKey("main.bill.id"))

    bill: Mapped[Bill | None] = relationship("Bill", back_populates="votes")
    # Per-legislator positions; deleted together with the vote.
    vote_records: Mapped[list[VoteRecord]] = relationship(
        "VoteRecord",
        back_populates="vote",
        cascade="all, delete-orphan",
    )

    __table_args__ = (
        # One row per roll call within a congress/chamber/session.
        UniqueConstraint(
            "congress",
            "chamber",
            "session",
            "number",
            name="uq_vote_congress_chamber_session_number",
        ),
        Index("ix_vote_date", "vote_date"),
        Index("ix_vote_congress_chamber", "congress", "chamber"),
    )
|
||||||
16
pipelines/orm/data_science_dev/models.py
Normal file
16
pipelines/orm/data_science_dev/models.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""Data science dev database ORM models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pipelines.orm.data_science_dev.congress import Bill, BillText, Legislator, Vote, VoteRecord
|
||||||
|
from pipelines.orm.data_science_dev.posts import partitions # noqa: F401 — registers partition classes in metadata
|
||||||
|
from pipelines.orm.data_science_dev.posts.tables import Posts
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Bill",
|
||||||
|
"BillText",
|
||||||
|
"Legislator",
|
||||||
|
"Posts",
|
||||||
|
"Vote",
|
||||||
|
"VoteRecord",
|
||||||
|
]
|
||||||
11
pipelines/orm/data_science_dev/posts/__init__.py
Normal file
11
pipelines/orm/data_science_dev/posts/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""Posts module — weekly-partitioned posts table and partition ORM models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pipelines.orm.data_science_dev.posts.failed_ingestion import FailedIngestion
|
||||||
|
from pipelines.orm.data_science_dev.posts.tables import Posts
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"FailedIngestion",
|
||||||
|
"Posts",
|
||||||
|
]
|
||||||
33
pipelines/orm/data_science_dev/posts/columns.py
Normal file
33
pipelines/orm/data_science_dev/posts/columns.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
"""Shared column definitions for the posts partitioned table family."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import BigInteger, SmallInteger, Text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
|
||||||
|
class PostsColumns:
    """Mixin providing all posts columns. Used by both the parent table and partitions."""

    # Composite primary key (post_id, date): on a PostgreSQL partitioned
    # table the partition key column must be part of the primary key.
    post_id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
    user_id: Mapped[int] = mapped_column(BigInteger)
    instance: Mapped[str]
    date: Mapped[datetime] = mapped_column(primary_key=True)
    text: Mapped[str] = mapped_column(Text)
    # Presumably language codes serialized into one string — confirm ingestion format.
    langs: Mapped[str | None]
    like_count: Mapped[int]
    reply_count: Mapped[int]
    repost_count: Mapped[int]
    # Threading / repost / quote linkage: IDs of related posts and authors, if any.
    reply_to: Mapped[int | None] = mapped_column(BigInteger)
    replied_author: Mapped[int | None] = mapped_column(BigInteger)
    thread_root: Mapped[int | None] = mapped_column(BigInteger)
    thread_root_author: Mapped[int | None] = mapped_column(BigInteger)
    repost_from: Mapped[int | None] = mapped_column(BigInteger)
    reposted_author: Mapped[int | None] = mapped_column(BigInteger)
    quotes: Mapped[int | None] = mapped_column(BigInteger)
    quoted_author: Mapped[int | None] = mapped_column(BigInteger)
    labels: Mapped[str | None]
    # Presumably sentiment-model output (class label + score) — confirm producer.
    sent_label: Mapped[int | None] = mapped_column(SmallInteger)
    sent_score: Mapped[float | None]
|
||||||
17
pipelines/orm/data_science_dev/posts/failed_ingestion.py
Normal file
17
pipelines/orm/data_science_dev/posts/failed_ingestion.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Table for storing JSONL lines that failed during post ingestion."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sqlalchemy import Text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from pipelines.orm.data_science_dev.base import DataScienceDevTableBase
|
||||||
|
|
||||||
|
|
||||||
|
class FailedIngestion(DataScienceDevTableBase):
    """Stores raw JSONL lines and their error messages when ingestion fails."""

    __tablename__ = "failed_ingestion"

    # The offending input line, stored verbatim.
    raw_line: Mapped[str] = mapped_column(Text)
    # Error message captured at ingestion time.
    error: Mapped[str] = mapped_column(Text)
|
||||||
71
pipelines/orm/data_science_dev/posts/partitions.py
Normal file
71
pipelines/orm/data_science_dev/posts/partitions.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
"""Dynamically generated ORM classes for each weekly partition of the posts table.
|
||||||
|
|
||||||
|
Each class maps to a PostgreSQL partition table (e.g. posts_2024_01).
|
||||||
|
These are real ORM models tracked by Alembic autogenerate.
|
||||||
|
|
||||||
|
Uses ISO week numbering (datetime.isocalendar().week). ISO years can have
|
||||||
|
52 or 53 weeks, and week boundaries are always Monday to Monday.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
from pipelines.orm.data_science_dev.base import DataScienceDevBase
|
||||||
|
from pipelines.orm.data_science_dev.posts.columns import PostsColumns
|
||||||
|
|
||||||
|
# Inclusive range of ISO years for which weekly partition classes are generated.
PARTITION_START_YEAR = 2023
PARTITION_END_YEAR = 2026

# This module object; generated classes are attached to it via setattr below.
_current_module = sys.modules[__name__]
|
||||||
|
|
||||||
|
|
||||||
|
def iso_weeks_in_year(year: int) -> int:
|
||||||
|
"""Return the number of ISO weeks in a given year (52 or 53)."""
|
||||||
|
dec_28 = datetime(year, 12, 28, tzinfo=UTC)
|
||||||
|
return dec_28.isocalendar().week
|
||||||
|
|
||||||
|
|
||||||
|
def week_bounds(year: int, week: int) -> tuple[datetime, datetime]:
|
||||||
|
"""Return (start, end) datetimes for an ISO week.
|
||||||
|
|
||||||
|
Start = Monday 00:00:00 UTC of the given ISO week.
|
||||||
|
End = Monday 00:00:00 UTC of the following ISO week.
|
||||||
|
"""
|
||||||
|
start = datetime.fromisocalendar(year, week, 1).replace(tzinfo=UTC)
|
||||||
|
if week < iso_weeks_in_year(year):
|
||||||
|
end = datetime.fromisocalendar(year, week + 1, 1).replace(tzinfo=UTC)
|
||||||
|
else:
|
||||||
|
end = datetime.fromisocalendar(year + 1, 1, 1).replace(tzinfo=UTC)
|
||||||
|
return start, end
|
||||||
|
|
||||||
|
|
||||||
|
def _build_partition_classes() -> dict[str, type]:
    """Generate one ORM class per ISO week partition.

    Returns:
        Mapping of class name (e.g. "PostsWeek2024W01") to the generated
        declarative class for the matching table (e.g. "posts_2024_01").
    """
    classes: dict[str, type] = {}

    for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
        for week in range(1, iso_weeks_in_year(year) + 1):
            class_name = f"PostsWeek{year}W{week:02d}"
            table_name = f"posts_{year}_{week:02d}"

            # Build the declarative class dynamically; inheriting
            # DataScienceDevBase registers the table in the shared metadata.
            partition_class = type(
                class_name,
                (PostsColumns, DataScienceDevBase),
                {
                    "__tablename__": table_name,
                    # NOTE(review): implicit_returning disabled — presumably
                    # because of RETURNING behavior on partition tables; confirm.
                    "__table_args__": ({"implicit_returning": False},),
                },
            )

            classes[class_name] = partition_class

    return classes
|
||||||
|
|
||||||
|
|
||||||
|
# Generate all partition classes and register them on this module
_partition_classes = _build_partition_classes()
for _name, _cls in _partition_classes.items():
    # Attach each generated class as a module attribute so it is importable.
    setattr(_current_module, _name, _cls)
# Every generated class is part of the module's public API.
__all__ = list(_partition_classes.keys())
|
||||||
13
pipelines/orm/data_science_dev/posts/tables.py
Normal file
13
pipelines/orm/data_science_dev/posts/tables.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""Posts parent table with PostgreSQL weekly range partitioning on date column."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pipelines.orm.data_science_dev.base import DataScienceDevBase
|
||||||
|
from pipelines.orm.data_science_dev.posts.columns import PostsColumns
|
||||||
|
|
||||||
|
|
||||||
|
class Posts(PostsColumns, DataScienceDevBase):
    """Parent partitioned table for posts, partitioned by week on `date`."""

    __tablename__ = "posts"
    # Emits PARTITION BY RANGE (date) in the CREATE TABLE DDL; the weekly
    # child tables are the generated classes in posts/partitions.py.
    __table_args__ = ({"postgresql_partition_by": "RANGE (date)"},)
|
||||||
26
pipelines/tools/Dockerfile.finetune
Normal file
26
pipelines/tools/Dockerfile.finetune
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# Unsloth fine-tuning container for Qwen 3.5 4B on RTX 3090.
#
# Build:
#   docker build -f pipelines/pipelines/tools/Dockerfile.finetune -t bill-finetune .
# NOTE(review): the COPY paths below assume the build context root contains
# pipelines/tools/, so the doubled "pipelines/pipelines/" in the -f path
# above looks suspect — confirm against the repository layout.
#
# Run:
#   docker run --rm --device=nvidia.com/gpu=all --ipc=host \
#     -v $(pwd)/output:/workspace/output \
#     -v $(pwd)/output/finetune_dataset.jsonl:/workspace/dataset.jsonl:ro \
#     -v /zfs/models/hf:/models \
#     bill-finetune \
#     --dataset /workspace/dataset.jsonl \
#     --output-dir /workspace/output/qwen-bill-summarizer

FROM ghcr.io/unslothai/unsloth:latest

# Extra deps on top of the base image: typer (CLI), rouge-score (eval metrics).
RUN pip install --no-cache-dir typer rouge-score

WORKDIR /workspace
# Copy only the modules the entrypoint needs, preserving the package layout.
COPY pipelines/tools/__init__.py pipelines/tools/__init__.py
COPY pipelines/tools/finetune.py pipelines/tools/finetune.py
COPY pipelines/tools/summarization_eval.py pipelines/tools/summarization_eval.py
# Placed under config/prompts/ — presumably where the tooling resolves prompt
# files relative to /workspace; confirm against the module's path logic.
COPY summarization_prompts.toml config/prompts/summarization_prompts.toml
COPY config.toml pipelines/tools/config.toml

ENTRYPOINT ["python", "-m", "pipelines.tools.finetune"]
|
||||||
@@ -23,9 +23,14 @@ import httpx
|
|||||||
import typer
|
import typer
|
||||||
from tiktoken import Encoding, get_encoding
|
from tiktoken import Encoding, get_encoding
|
||||||
|
|
||||||
from python.prompt_bench.bill_token_compression import compress_bill_text
|
from pipelines.tools.bill_token_compression import compress_bill_text
|
||||||
|
|
||||||
_PROMPTS_PATH = Path(__file__).resolve().parents[2] / "config" / "prompts" / "summarization_prompts.toml"
|
_PROMPTS_PATH = (
|
||||||
|
Path(__file__).resolve().parents[2]
|
||||||
|
/ "config"
|
||||||
|
/ "prompts"
|
||||||
|
/ "summarization_prompts.toml"
|
||||||
|
)
|
||||||
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
||||||
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
||||||
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
||||||
@@ -72,7 +77,12 @@ def build_request(custom_id: str, model: str, bill_text: str) -> dict:
|
|||||||
"model": model,
|
"model": model,
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT},
|
{"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT},
|
||||||
{"role": "user", "content": SUMMARIZATION_USER_TEMPLATE.format(text_content=bill_text)},
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": SUMMARIZATION_USER_TEMPLATE.format(
|
||||||
|
text_content=bill_text
|
||||||
|
),
|
||||||
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -123,7 +133,9 @@ def prepare_requests(
|
|||||||
"compressed_chars": len(compressed_text),
|
"compressed_chars": len(compressed_text),
|
||||||
"raw_tokens": raw_token_count,
|
"raw_tokens": raw_token_count,
|
||||||
"compressed_tokens": compressed_token_count,
|
"compressed_tokens": compressed_token_count,
|
||||||
"token_ratio": (compressed_token_count / raw_token_count) if raw_token_count else None,
|
"token_ratio": (compressed_token_count / raw_token_count)
|
||||||
|
if raw_token_count
|
||||||
|
else None,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
safe_id = safe_filename(bill_id)
|
safe_id = safe_filename(bill_id)
|
||||||
@@ -136,7 +148,14 @@ def write_token_csv(path: Path, token_rows: list[dict]) -> tuple[int, int]:
|
|||||||
with path.open("w", newline="", encoding="utf-8") as handle:
|
with path.open("w", newline="", encoding="utf-8") as handle:
|
||||||
writer = csv.DictWriter(
|
writer = csv.DictWriter(
|
||||||
handle,
|
handle,
|
||||||
fieldnames=["bill_id", "raw_chars", "compressed_chars", "raw_tokens", "compressed_tokens", "token_ratio"],
|
fieldnames=[
|
||||||
|
"bill_id",
|
||||||
|
"raw_chars",
|
||||||
|
"compressed_chars",
|
||||||
|
"raw_tokens",
|
||||||
|
"compressed_tokens",
|
||||||
|
"token_ratio",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
writer.writeheader()
|
writer.writeheader()
|
||||||
writer.writerows(token_rows)
|
writer.writerows(token_rows)
|
||||||
@@ -161,8 +180,12 @@ def create_batch(client: httpx.Client, input_file_id: str, description: str) ->
|
|||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path("bills.csv"),
|
csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path(
|
||||||
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to write JSONL + metadata")] = Path(
|
"bills.csv"
|
||||||
|
),
|
||||||
|
output_dir: Annotated[
|
||||||
|
Path, typer.Option("--output-dir", help="Where to write JSONL + metadata")
|
||||||
|
] = Path(
|
||||||
"output/openai_batch",
|
"output/openai_batch",
|
||||||
),
|
),
|
||||||
model: Annotated[str, typer.Option(help="OpenAI model id")] = "gpt-5-mini",
|
model: Annotated[str, typer.Option(help="OpenAI model id")] = "gpt-5-mini",
|
||||||
@@ -170,7 +193,9 @@ def main(
|
|||||||
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Submit an OpenAI Batch job of compressed bill summaries."""
|
"""Submit an OpenAI Batch job of compressed bill summaries."""
|
||||||
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
logging.basicConfig(
|
||||||
|
level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY")
|
api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
@@ -191,7 +216,9 @@ def main(
|
|||||||
request_lines, token_rows = prepare_requests(bills, model=model, encoder=encoder)
|
request_lines, token_rows = prepare_requests(bills, model=model, encoder=encoder)
|
||||||
|
|
||||||
token_csv_path = output_dir / "token_counts.csv"
|
token_csv_path = output_dir / "token_counts.csv"
|
||||||
raw_tokens_total, compressed_tokens_total = write_token_csv(token_csv_path, token_rows)
|
raw_tokens_total, compressed_tokens_total = write_token_csv(
|
||||||
|
token_csv_path, token_rows
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Token counts: raw=%d compressed=%d ratio=%.3f -> %s",
|
"Token counts: raw=%d compressed=%d ratio=%.3f -> %s",
|
||||||
raw_tokens_total,
|
raw_tokens_total,
|
||||||
@@ -211,7 +238,11 @@ def main(
|
|||||||
logger.info("Uploaded: %s", file_id)
|
logger.info("Uploaded: %s", file_id)
|
||||||
|
|
||||||
logger.info("Creating batch")
|
logger.info("Creating batch")
|
||||||
batch = create_batch(client, file_id, f"compressed bill summaries x{len(request_lines)} ({model})")
|
batch = create_batch(
|
||||||
|
client,
|
||||||
|
file_id,
|
||||||
|
f"compressed bill summaries x{len(request_lines)} ({model})",
|
||||||
|
)
|
||||||
logger.info("Batch created: %s", batch["id"])
|
logger.info("Batch created: %s", batch["id"])
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
@@ -24,9 +24,14 @@ from typing import Annotated
|
|||||||
import httpx
|
import httpx
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
from python.prompt_bench.bill_token_compression import compress_bill_text
|
from pipelines.tools.bill_token_compression import compress_bill_text
|
||||||
|
|
||||||
_PROMPTS_PATH = Path(__file__).resolve().parents[2] / "config" / "prompts" / "summarization_prompts.toml"
|
_PROMPTS_PATH = (
|
||||||
|
Path(__file__).resolve().parents[2]
|
||||||
|
/ "config"
|
||||||
|
/ "prompts"
|
||||||
|
/ "summarization_prompts.toml"
|
||||||
|
)
|
||||||
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
||||||
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
||||||
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
||||||
@@ -62,7 +67,10 @@ def build_messages(bill_text: str) -> list[dict]:
|
|||||||
"""Return the system + user message pair for a bill."""
|
"""Return the system + user message pair for a bill."""
|
||||||
return [
|
return [
|
||||||
{"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT},
|
{"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT},
|
||||||
{"role": "user", "content": SUMMARIZATION_USER_TEMPLATE.format(text_content=bill_text)},
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": SUMMARIZATION_USER_TEMPLATE.format(text_content=bill_text),
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -132,17 +140,25 @@ def run_one_request(
|
|||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path("bills.csv"),
|
csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path(
|
||||||
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to write per-request JSON")] = Path(
|
"bills.csv"
|
||||||
|
),
|
||||||
|
output_dir: Annotated[
|
||||||
|
Path, typer.Option("--output-dir", help="Where to write per-request JSON")
|
||||||
|
] = Path(
|
||||||
"output/openai_runs",
|
"output/openai_runs",
|
||||||
),
|
),
|
||||||
model: Annotated[str, typer.Option(help="OpenAI model id")] = DEFAULT_MODEL,
|
model: Annotated[str, typer.Option(help="OpenAI model id")] = DEFAULT_MODEL,
|
||||||
count: Annotated[int, typer.Option(help="Number of bills per set")] = DEFAULT_COUNT,
|
count: Annotated[int, typer.Option(help="Number of bills per set")] = DEFAULT_COUNT,
|
||||||
concurrency: Annotated[int, typer.Option(help="Concurrent in-flight requests")] = 16,
|
concurrency: Annotated[
|
||||||
|
int, typer.Option(help="Concurrent in-flight requests")
|
||||||
|
] = 16,
|
||||||
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run two interactive OpenAI sweeps (compressed + uncompressed) over bill text."""
|
"""Run two interactive OpenAI sweeps (compressed + uncompressed) over bill text."""
|
||||||
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
logging.basicConfig(
|
||||||
|
level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY")
|
api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
@@ -165,8 +181,17 @@ def main(
|
|||||||
tasks: list[tuple[str, str, str, Path]] = []
|
tasks: list[tuple[str, str, str, Path]] = []
|
||||||
for bill_id, text_content in bills:
|
for bill_id, text_content in bills:
|
||||||
filename = f"{safe_filename(bill_id)}.json"
|
filename = f"{safe_filename(bill_id)}.json"
|
||||||
tasks.append((bill_id, "compressed", compress_bill_text(text_content), compressed_dir / filename))
|
tasks.append(
|
||||||
tasks.append((bill_id, "uncompressed", text_content, uncompressed_dir / filename))
|
(
|
||||||
|
bill_id,
|
||||||
|
"compressed",
|
||||||
|
compress_bill_text(text_content),
|
||||||
|
compressed_dir / filename,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tasks.append(
|
||||||
|
(bill_id, "uncompressed", text_content, uncompressed_dir / filename)
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Submitting %d requests at concurrency=%d", len(tasks), concurrency)
|
logger.info("Submitting %d requests at concurrency=%d", len(tasks), concurrency)
|
||||||
|
|
||||||
@@ -9,13 +9,13 @@ from typing import Annotated
|
|||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
from python.prompt_bench.containers.lib import check_gpu_free
|
from pipelines.tools.containers.lib import check_gpu_free
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
CONTAINER_NAME = "bill-finetune"
|
CONTAINER_NAME = "bill-finetune"
|
||||||
FINETUNE_IMAGE = "bill-finetune:latest"
|
FINETUNE_IMAGE = "bill-finetune:latest"
|
||||||
DOCKERFILE_PATH = "/home/richie/dotfiles/python/prompt_bench/Dockerfile.finetune"
|
REPO_DIR = Path(__file__).resolve().parents[4]
|
||||||
DEFAULT_HF_CACHE = Path("/zfs/models/hf")
|
DEFAULT_HF_CACHE = Path("/zfs/models/hf")
|
||||||
|
|
||||||
|
|
||||||
@@ -23,7 +23,15 @@ def build_image() -> None:
|
|||||||
"""Build the fine-tuning Docker image."""
|
"""Build the fine-tuning Docker image."""
|
||||||
logger.info("Building fine-tuning image: %s", FINETUNE_IMAGE)
|
logger.info("Building fine-tuning image: %s", FINETUNE_IMAGE)
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
["docker", "build", "-f", DOCKERFILE_PATH, "-t", FINETUNE_IMAGE, "."],
|
[
|
||||||
|
"docker",
|
||||||
|
"build",
|
||||||
|
"-f",
|
||||||
|
str(REPO_DIR / "pipelines/pipelines/tools/Dockerfile.finetune"),
|
||||||
|
"-t",
|
||||||
|
FINETUNE_IMAGE,
|
||||||
|
".",
|
||||||
|
],
|
||||||
text=True,
|
text=True,
|
||||||
check=False,
|
check=False,
|
||||||
)
|
)
|
||||||
@@ -95,7 +103,9 @@ def stop_finetune() -> None:
|
|||||||
"""Stop and remove the fine-tuning container."""
|
"""Stop and remove the fine-tuning container."""
|
||||||
logger.info("Stopping fine-tuning container")
|
logger.info("Stopping fine-tuning container")
|
||||||
subprocess.run(["docker", "stop", CONTAINER_NAME], capture_output=True, check=False)
|
subprocess.run(["docker", "stop", CONTAINER_NAME], capture_output=True, check=False)
|
||||||
subprocess.run(["docker", "rm", "-f", CONTAINER_NAME], capture_output=True, check=False)
|
subprocess.run(
|
||||||
|
["docker", "rm", "-f", CONTAINER_NAME], capture_output=True, check=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def logs_finetune() -> str | None:
|
def logs_finetune() -> str | None:
|
||||||
@@ -122,17 +132,20 @@ def build() -> None:
|
|||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def run(
|
def run(
|
||||||
dataset: Annotated[Path, typer.Option(help="Fine-tuning JSONL")] = Path(
|
dataset: Annotated[Path, typer.Option(help="Fine-tuning JSONL")] = REPO_DIR
|
||||||
"/home/richie/dotfiles/data/finetune_dataset.jsonl"
|
/ "/zfs/storage/data_science/data/finetune_dataset.jsonl",
|
||||||
),
|
output_dir: Annotated[
|
||||||
output_dir: Annotated[Path, typer.Option(help="Where to save the trained model")] = Path(
|
Path, typer.Option(help="Where to save the trained model")
|
||||||
"/home/richie/dotfiles/data/output/qwen-bill-summarizer",
|
] = REPO_DIR / "data/output/qwen-bill-summarizer",
|
||||||
),
|
hf_cache: Annotated[
|
||||||
hf_cache: Annotated[Path, typer.Option(help="Host path to HuggingFace model cache")] = DEFAULT_HF_CACHE,
|
Path, typer.Option(help="Host path to HuggingFace model cache")
|
||||||
|
] = DEFAULT_HF_CACHE,
|
||||||
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run fine-tuning inside a Docker container."""
|
"""Run fine-tuning inside a Docker container."""
|
||||||
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
logging.basicConfig(
|
||||||
|
level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
|
||||||
|
)
|
||||||
check_gpu_free()
|
check_gpu_free()
|
||||||
start_finetune(
|
start_finetune(
|
||||||
dataset_path=dataset,
|
dataset_path=dataset,
|
||||||
@@ -140,6 +153,7 @@ def run(
|
|||||||
hf_cache=hf_cache,
|
hf_cache=hf_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def stop() -> None:
|
def stop() -> None:
|
||||||
"""Stop and remove the fine-tuning container."""
|
"""Stop and remove the fine-tuning container."""
|
||||||
@@ -9,7 +9,7 @@ from typing import Annotated
|
|||||||
import typer
|
import typer
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from python.prompt_bench.models import BenchmarkConfig
|
from pipelines.tools.models import BenchmarkConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -52,11 +52,15 @@ def download_all(config: BenchmarkConfig) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"),
|
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path(
|
||||||
|
"bench.toml"
|
||||||
|
),
|
||||||
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Download all models listed in the benchmark config."""
|
"""Download all models listed in the benchmark config."""
|
||||||
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
logging.basicConfig(
|
||||||
|
level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
if not config.is_file():
|
if not config.is_file():
|
||||||
message = f"Config file does not exist: {config}"
|
message = f"Config file does not exist: {config}"
|
||||||
@@ -5,7 +5,7 @@ applies QLoRA with 4-bit quantization, and saves the merged model
|
|||||||
in HuggingFace format. Designed for a single RTX 3090 (24GB).
|
in HuggingFace format. Designed for a single RTX 3090 (24GB).
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python -m python.prompt_bench.finetune \
|
python -m pipelines.prompt_bench.finetune \
|
||||||
--dataset output/finetune_dataset.jsonl \
|
--dataset output/finetune_dataset.jsonl \
|
||||||
--output-dir output/qwen-bill-summarizer
|
--output-dir output/qwen-bill-summarizer
|
||||||
"""
|
"""
|
||||||
@@ -25,6 +25,8 @@ from datasets import Dataset
|
|||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import SFTTrainer
|
from trl import SFTTrainer
|
||||||
|
|
||||||
|
from .summarization_eval import make_compute_metrics
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -107,21 +109,31 @@ def load_dataset_from_jsonl(path: Path) -> Dataset:
|
|||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
dataset_path: Annotated[Path, typer.Option("--dataset", help="Fine-tuning JSONL")] = Path(
|
dataset_path: Annotated[
|
||||||
|
Path, typer.Option("--dataset", help="Fine-tuning JSONL")
|
||||||
|
] = Path(
|
||||||
"output/finetune_dataset.jsonl",
|
"output/finetune_dataset.jsonl",
|
||||||
),
|
),
|
||||||
validation_split: Annotated[float, typer.Option("--val-split", help="Fraction held out for validation")] = 0.1,
|
validation_split: Annotated[
|
||||||
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to save the merged model")] = Path(
|
float, typer.Option("--val-split", help="Fraction held out for validation")
|
||||||
|
] = 0.1,
|
||||||
|
output_dir: Annotated[
|
||||||
|
Path, typer.Option("--output-dir", help="Where to save the merged model")
|
||||||
|
] = Path(
|
||||||
"output/qwen-bill-summarizer",
|
"output/qwen-bill-summarizer",
|
||||||
),
|
),
|
||||||
config_path: Annotated[
|
config_path: Annotated[
|
||||||
Path,
|
Path,
|
||||||
typer.Option("--config", help="TOML config file"),
|
typer.Option("--config", help="TOML config file"),
|
||||||
] = Path(__file__).parent / "config.toml",
|
] = Path(__file__).parent / "config.toml",
|
||||||
save_gguf: Annotated[bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")] = False,
|
save_gguf: Annotated[
|
||||||
|
bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")
|
||||||
|
] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Fine-tune Qwen 3.5 4B on bill summarization with Unsloth + QLoRA."""
|
"""Fine-tune Qwen 3.5 4B on bill summarization with Unsloth + QLoRA."""
|
||||||
logging.basicConfig(level="INFO", format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
logging.basicConfig(
|
||||||
|
level="INFO", format="%(asctime)s %(levelname)s %(name)s: %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
if not dataset_path.is_file():
|
if not dataset_path.is_file():
|
||||||
message = f"Dataset not found: {dataset_path}"
|
message = f"Dataset not found: {dataset_path}"
|
||||||
@@ -137,7 +149,9 @@ def main(
|
|||||||
dtype=None,
|
dtype=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Applying LoRA (rank=%d, alpha=%d)", config.lora.rank, config.lora.alpha)
|
logger.info(
|
||||||
|
"Applying LoRA (rank=%d, alpha=%d)", config.lora.rank, config.lora.alpha
|
||||||
|
)
|
||||||
model = FastLanguageModel.get_peft_model(
|
model = FastLanguageModel.get_peft_model(
|
||||||
model,
|
model,
|
||||||
r=config.lora.rank,
|
r=config.lora.rank,
|
||||||
@@ -153,7 +167,9 @@ def main(
|
|||||||
split = full_dataset.train_test_split(test_size=validation_split, seed=42)
|
split = full_dataset.train_test_split(test_size=validation_split, seed=42)
|
||||||
train_dataset = split["train"]
|
train_dataset = split["train"]
|
||||||
validation_dataset = split["test"]
|
validation_dataset = split["test"]
|
||||||
logger.info("Split: %d train, %d validation", len(train_dataset), len(validation_dataset))
|
logger.info(
|
||||||
|
"Split: %d train, %d validation", len(train_dataset), len(validation_dataset)
|
||||||
|
)
|
||||||
training_args = TrainingArguments(
|
training_args = TrainingArguments(
|
||||||
output_dir=str(output_dir / "checkpoints"),
|
output_dir=str(output_dir / "checkpoints"),
|
||||||
num_train_epochs=config.training.epochs,
|
num_train_epochs=config.training.epochs,
|
||||||
@@ -173,6 +189,9 @@ def main(
|
|||||||
optim="adamw_8bit",
|
optim="adamw_8bit",
|
||||||
seed=42,
|
seed=42,
|
||||||
report_to="none",
|
report_to="none",
|
||||||
|
metric_for_best_model="eval_composite",
|
||||||
|
greater_is_better=True,
|
||||||
|
predict_with_generate=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = SFTTrainer(
|
trainer = SFTTrainer(
|
||||||
@@ -183,6 +202,7 @@ def main(
|
|||||||
args=training_args,
|
args=training_args,
|
||||||
max_seq_length=config.training.max_seq_length,
|
max_seq_length=config.training.max_seq_length,
|
||||||
packing=True,
|
packing=True,
|
||||||
|
compute_metrics=make_compute_metrics(tokenizer),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -11,11 +11,11 @@ from typing import Annotated
|
|||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
from python.prompt_bench.containers.lib import check_gpu_free
|
from pipelines.tools.containers.lib import check_gpu_free
|
||||||
from python.prompt_bench.containers.vllm import start_vllm, stop_vllm
|
from pipelines.tools.containers.vllm import start_vllm, stop_vllm
|
||||||
from python.prompt_bench.downloader import is_model_present
|
from pipelines.tools.downloader import is_model_present
|
||||||
from python.prompt_bench.models import BenchmarkConfig
|
from pipelines.tools.models import BenchmarkConfig
|
||||||
from python.prompt_bench.vllm_client import VLLMClient
|
from pipelines.tools.vllm_client import VLLMClient
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -72,7 +72,9 @@ def benchmark_model(
|
|||||||
vLLM batches concurrent requests internally, so submitting many at once is
|
vLLM batches concurrent requests internally, so submitting many at once is
|
||||||
significantly faster than running them serially.
|
significantly faster than running them serially.
|
||||||
"""
|
"""
|
||||||
pending = [prompt for prompt in prompts if not (model_output / prompt.name).exists()]
|
pending = [
|
||||||
|
prompt for prompt in prompts if not (model_output / prompt.name).exists()
|
||||||
|
]
|
||||||
skipped = len(prompts) - len(pending)
|
skipped = len(prompts) - len(pending)
|
||||||
if skipped:
|
if skipped:
|
||||||
logger.info("Skipping %d prompts with existing output for %s", skipped, repo)
|
logger.info("Skipping %d prompts with existing output for %s", skipped, repo)
|
||||||
@@ -185,13 +187,21 @@ def run_benchmark(
|
|||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
input_dir: Annotated[Path, typer.Argument(help="Directory containing input .txt prompt files")],
|
input_dir: Annotated[
|
||||||
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"),
|
Path, typer.Argument(help="Directory containing input .txt prompt files")
|
||||||
output_dir: Annotated[Path, typer.Option(help="Output directory for results")] = Path("output"),
|
],
|
||||||
|
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path(
|
||||||
|
"bench.toml"
|
||||||
|
),
|
||||||
|
output_dir: Annotated[
|
||||||
|
Path, typer.Option(help="Output directory for results")
|
||||||
|
] = Path("output"),
|
||||||
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run prompts through multiple LLMs via vLLM and save results."""
|
"""Run prompts through multiple LLMs via vLLM and save results."""
|
||||||
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
logging.basicConfig(
|
||||||
|
level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
if not input_dir.is_dir():
|
if not input_dir.is_dir():
|
||||||
message = f"Input directory does not exist: {input_dir}"
|
message = f"Input directory does not exist: {input_dir}"
|
||||||
426
pipelines/tools/summarization_eval.py
Normal file
426
pipelines/tools/summarization_eval.py
Normal file
@@ -0,0 +1,426 @@
|
|||||||
|
"""Summarization evaluation for Congressional bill summaries.
|
||||||
|
|
||||||
|
Three use cases from one module:
|
||||||
|
|
||||||
|
1. Data filtering — score GPT batch outputs before building the fine-tune JSONL:
|
||||||
|
from summarization_eval import filter_dataset
|
||||||
|
filter_dataset("output/finetune_dataset.jsonl", "output/filtered_dataset.jsonl")
|
||||||
|
|
||||||
|
2. Training compute_metrics hook — plug into SFTTrainer for ROUGE-based checkpoint selection:
|
||||||
|
from summarization_eval import make_compute_metrics
|
||||||
|
trainer = SFTTrainer(..., compute_metrics=make_compute_metrics(tokenizer))
|
||||||
|
|
||||||
|
3. Inference eval — score a finished model against held-out references:
|
||||||
|
from summarization_eval import evaluate_file
|
||||||
|
results = evaluate_file("output/predictions.jsonl", "output/references.jsonl")
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from rouge_score import rouge_scorer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Constants
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
SECTION_HEADERS = [
|
||||||
|
"OPERATIVE ACTIONS",
|
||||||
|
"AFFECTED POPULATIONS",
|
||||||
|
"MECHANISMS",
|
||||||
|
"POLICY THREADS",
|
||||||
|
"SYMBOLIC/PROCEDURAL ONLY",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Weighted composite: de-emphasise unigram overlap, weight phrase + structure equally
|
||||||
|
ROUGE_WEIGHTS = {
|
||||||
|
"rouge1": 0.2,
|
||||||
|
"rouge2": 0.4,
|
||||||
|
"rougeL": 0.4,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Composite score floor below which a training example is considered low quality
|
||||||
|
FILTER_THRESHOLD = 0.25
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Core data structures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SummaryScore:
|
||||||
|
"""Scores for a single (prediction, reference) pair."""
|
||||||
|
|
||||||
|
rouge1: float
|
||||||
|
rouge2: float
|
||||||
|
rougeL: float
|
||||||
|
composite: float
|
||||||
|
has_all_sections: bool # True = all 5 headers present
|
||||||
|
missing_sections: list[str]
|
||||||
|
structural_fail: bool # True = one or more headers missing (hard guardrail)
|
||||||
|
|
||||||
|
def as_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"rouge1": self.rouge1,
|
||||||
|
"rouge2": self.rouge2,
|
||||||
|
"rougeL": self.rougeL,
|
||||||
|
"composite": self.composite,
|
||||||
|
"has_all_sections": self.has_all_sections,
|
||||||
|
"missing_sections": self.missing_sections,
|
||||||
|
"structural_fail": self.structural_fail,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BatchResult:
|
||||||
|
"""Aggregate results over a batch of summaries."""
|
||||||
|
|
||||||
|
n_total: int
|
||||||
|
n_structural_fail: int
|
||||||
|
n_scored: int # excludes structural failures
|
||||||
|
rouge1_mean: float
|
||||||
|
rouge2_mean: float
|
||||||
|
rougeL_mean: float
|
||||||
|
composite_mean: float
|
||||||
|
scores: list[SummaryScore]
|
||||||
|
|
||||||
|
def as_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"n_total": self.n_total,
|
||||||
|
"n_structural_fail": self.n_structural_fail,
|
||||||
|
"n_scored": self.n_scored,
|
||||||
|
"rouge1_mean": self.rouge1_mean,
|
||||||
|
"rouge2_mean": self.rouge2_mean,
|
||||||
|
"rougeL_mean": self.rougeL_mean,
|
||||||
|
"composite_mean": self.composite_mean,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Core scoring
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
|
||||||
|
|
||||||
|
|
||||||
|
def check_sections(text: str) -> tuple[bool, list[str]]:
|
||||||
|
"""Return (all_present, missing_headers) for the 5 required section headers."""
|
||||||
|
missing = [h for h in SECTION_HEADERS if h not in text.upper()]
|
||||||
|
return len(missing) == 0, missing
|
||||||
|
|
||||||
|
|
||||||
|
def score_pair(prediction: str, reference: str) -> SummaryScore:
|
||||||
|
"""Score a single (prediction, reference) pair.
|
||||||
|
|
||||||
|
If the prediction is missing any section header, structural_fail is True
|
||||||
|
and ROUGE scores are still computed (so you can inspect quality even on
|
||||||
|
structural failures) but the example should be treated as a guardrail failure.
|
||||||
|
"""
|
||||||
|
has_all, missing = check_sections(prediction)
|
||||||
|
|
||||||
|
rouge = _scorer.score(reference, prediction)
|
||||||
|
r1 = rouge["rouge1"].fmeasure
|
||||||
|
r2 = rouge["rouge2"].fmeasure
|
||||||
|
rl = rouge["rougeL"].fmeasure
|
||||||
|
composite = (
|
||||||
|
ROUGE_WEIGHTS["rouge1"] * r1
|
||||||
|
+ ROUGE_WEIGHTS["rouge2"] * r2
|
||||||
|
+ ROUGE_WEIGHTS["rougeL"] * rl
|
||||||
|
)
|
||||||
|
|
||||||
|
return SummaryScore(
|
||||||
|
rouge1=r1,
|
||||||
|
rouge2=r2,
|
||||||
|
rougeL=rl,
|
||||||
|
composite=composite,
|
||||||
|
has_all_sections=has_all,
|
||||||
|
missing_sections=missing,
|
||||||
|
structural_fail=not has_all,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def score_batch(pairs: list[tuple[str, str]]) -> BatchResult:
|
||||||
|
"""Score a list of (prediction, reference) pairs and return aggregate results.
|
||||||
|
|
||||||
|
Structural failures are counted separately and excluded from ROUGE means
|
||||||
|
so a batch with broken formatting doesn't drag down the score unfairly.
|
||||||
|
"""
|
||||||
|
scores = [score_pair(pred, ref) for pred, ref in pairs]
|
||||||
|
|
||||||
|
structural_fails = [s for s in scores if s.structural_fail]
|
||||||
|
valid = [s for s in scores if not s.structural_fail]
|
||||||
|
|
||||||
|
if valid:
|
||||||
|
rouge1_mean = float(np.mean([s.rouge1 for s in valid]))
|
||||||
|
rouge2_mean = float(np.mean([s.rouge2 for s in valid]))
|
||||||
|
rougeL_mean = float(np.mean([s.rougeL for s in valid]))
|
||||||
|
composite_mean = float(np.mean([s.composite for s in valid]))
|
||||||
|
else:
|
||||||
|
rouge1_mean = rouge2_mean = rougeL_mean = composite_mean = 0.0
|
||||||
|
|
||||||
|
return BatchResult(
|
||||||
|
n_total=len(scores),
|
||||||
|
n_structural_fail=len(structural_fails),
|
||||||
|
n_scored=len(valid),
|
||||||
|
rouge1_mean=rouge1_mean,
|
||||||
|
rouge2_mean=rouge2_mean,
|
||||||
|
rougeL_mean=rougeL_mean,
|
||||||
|
composite_mean=composite_mean,
|
||||||
|
scores=scores,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Use case 1: Data filtering
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def filter_dataset(
    input_path: Path | str,
    output_path: Path | str,
    *,
    threshold: float = FILTER_THRESHOLD,
    min_words: int = 80,
) -> tuple[int, int]:
    """Filter a fine-tuning JSONL by section guardrail and minimum length.

    Each line must be a ChatML messages dict:
        {"messages": [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}]}

    The last assistant turn is the prediction. An example is dropped when:
      * the line is not valid JSON,
      * it has no assistant turn,
      * the prediction is missing any required section header, or
      * the prediction has fewer than ``min_words`` words.

    In practice you'd call this after joining requests + GPT completions
    (build_finetune_dataset.py) to drop any GPT outputs that are malformed
    or suspiciously short/low quality.

    Args:
        input_path: Source JSONL file.
        output_path: Destination JSONL; parent directories are created.
        threshold: Kept for API compatibility but currently unused — scoring
            a prediction against itself is not a meaningful quality check,
            so filtering relies on the structural and length guardrails.
        min_words: Minimum word count for the assistant prediction.

    Returns:
        (kept, dropped) line counts.
    """
    input_path = Path(input_path)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    kept = 0
    dropped = 0

    with input_path.open(encoding="utf-8") as fin, output_path.open("w", encoding="utf-8") as fout:
        for line_num, raw_line in enumerate(fin, 1):
            stripped = raw_line.strip()
            if not stripped:
                continue

            # Robustness: one malformed line should not abort the whole
            # filtering run — drop it and keep going, like the other checks.
            try:
                example = json.loads(stripped)
            except json.JSONDecodeError:
                logger.warning("Line %d: invalid JSON, dropping", line_num)
                dropped += 1
                continue

            messages = example.get("messages", [])
            assistant_turns = [m for m in messages if m.get("role") == "assistant"]

            if not assistant_turns:
                logger.warning("Line %d: no assistant turn, dropping", line_num)
                dropped += 1
                continue

            prediction = assistant_turns[-1].get("content", "")

            # Guardrail: drop if any section header missing
            has_all, missing = check_sections(prediction)
            if not has_all:
                logger.warning(
                    "Line %d: structural fail (missing: %s), dropping",
                    line_num,
                    ", ".join(missing),
                )
                dropped += 1
                continue

            # Quality floor: for filtering GPT outputs, the structural check
            # plus a minimum word count is usually sufficient.
            word_count = len(prediction.split())
            if word_count < min_words:
                logger.warning(
                    "Line %d: too short (%d words), dropping", line_num, word_count
                )
                dropped += 1
                continue

            fout.write(json.dumps(example, ensure_ascii=False) + "\n")
            kept += 1

    logger.info("Filtered dataset: kept=%d dropped=%d -> %s", kept, dropped, output_path)
    return kept, dropped
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Use case 2: compute_metrics hook for SFTTrainer
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def make_compute_metrics(tokenizer) -> Callable:  # noqa: ANN001
    """Build a compute_metrics callable for the HuggingFace Trainer API.

    Usage in finetune.py:
        from summarization_eval import make_compute_metrics
        trainer = SFTTrainer(
            ...
            compute_metrics=make_compute_metrics(tokenizer),
        )

    Note: EvalPrediction.predictions are logits (or token ids if
    include_inputs_for_metrics is False). This function handles both.
    For SFTTrainer with packing=True, you may need to set
    predict_with_generate=True in TrainingArguments to get decoded text.
    """

    def compute_metrics(eval_pred) -> dict[str, float]:  # noqa: ANN001
        preds, label_ids = eval_pred

        # A 3-D array means we got logits; collapse to token ids.
        if preds.ndim == 3:
            preds = preds.argmax(axis=-1)

        # Replace the -100 ignore index with the pad token so decoding works.
        label_ids = np.where(label_ids == -100, tokenizer.pad_token_id, label_ids)

        pred_texts = tokenizer.batch_decode(preds, skip_special_tokens=True)
        ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        batch = score_batch(list(zip(pred_texts, ref_texts)))

        fail_rate = batch.n_structural_fail / batch.n_total if batch.n_total else 0.0
        metrics = {
            "eval_rouge1": batch.rouge1_mean,
            "eval_rouge2": batch.rouge2_mean,
            "eval_rougeL": batch.rougeL_mean,
            "eval_composite": batch.composite_mean,
            "eval_structural_fail_rate": fail_rate,
        }
        logger.info(
            "Eval: composite=%.4f rouge1=%.4f rouge2=%.4f rougeL=%.4f structural_fail=%d/%d",
            metrics["eval_composite"],
            metrics["eval_rouge1"],
            metrics["eval_rouge2"],
            metrics["eval_rougeL"],
            batch.n_structural_fail,
            batch.n_total,
        )
        return metrics

    return compute_metrics
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Use case 3: Inference eval against held-out references
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def evaluate_file(
    predictions_path: Path | str,
    references_path: Path | str,
    output_path: Path | str | None = None,
) -> BatchResult:
    """Score a predictions JSONL against a references JSONL.

    Both files should be line-matched: line N of predictions corresponds
    to line N of references. Each line should be a plain JSON object with
    a "text" or "content" key, or a ChatML messages dict (the last
    assistant turn is used).

    Args:
        predictions_path: JSONL of model outputs.
        references_path: Line-aligned JSONL of gold references.
        output_path: If provided, per-example scores are written here as
            JSONL and a batch summary next to it as ``.summary.json``.

    Returns:
        The aggregate BatchResult.

    Raises:
        ValueError: If the two files contain different numbers of examples.
    """
    predictions_path = Path(predictions_path)
    references_path = Path(references_path)

    def extract_text(line: str) -> str:
        """Pull the scoreable text out of one JSONL line."""
        obj = json.loads(line)
        # Plain text field
        if "text" in obj:
            return obj["text"]
        if "content" in obj:
            return obj["content"]
        # ChatML messages — take last assistant turn
        messages = obj.get("messages", [])
        for m in reversed(messages):
            if m.get("role") == "assistant":
                return m.get("content", "")
        return ""

    # Explicit encoding: don't depend on the platform's default.
    preds = [
        extract_text(line)
        for line in predictions_path.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    refs = [
        extract_text(line)
        for line in references_path.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]

    if len(preds) != len(refs):
        msg = f"Prediction count ({len(preds)}) != reference count ({len(refs)})"
        raise ValueError(msg)

    result = score_batch(list(zip(preds, refs)))

    logger.info(
        "Inference eval: n=%d structural_fails=%d composite=%.4f "
        "rouge1=%.4f rouge2=%.4f rougeL=%.4f",
        result.n_total,
        result.n_structural_fail,
        result.composite_mean,
        result.rouge1_mean,
        result.rouge2_mean,
        result.rougeL_mean,
    )

    if output_path is not None:
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("w", encoding="utf-8") as fout:
            for score in result.scores:
                fout.write(json.dumps(score.as_dict(), ensure_ascii=False) + "\n")
        summary_path = output_path.with_suffix(".summary.json")
        summary_path.write_text(json.dumps(result.as_dict(), indent=2), encoding="utf-8")
        logger.info("Wrote per-example scores to %s", output_path)
        logger.info("Wrote summary to %s", summary_path)

    return result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CLI — quick sanity check / standalone use
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _cli() -> None:
    """Standalone entry point: `filter` and `eval` subcommands."""
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate bill summarization quality.")
    sub = parser.add_subparsers(dest="command", required=True)

    # filter subcommand
    filter_cmd = sub.add_parser("filter", help="Filter a fine-tuning JSONL dataset")
    filter_cmd.add_argument("--input", required=True, type=Path)
    filter_cmd.add_argument("--output", required=True, type=Path)
    filter_cmd.add_argument("--threshold", type=float, default=FILTER_THRESHOLD)

    # eval subcommand
    eval_cmd = sub.add_parser("eval", help="Score predictions against references")
    eval_cmd.add_argument("--predictions", required=True, type=Path)
    eval_cmd.add_argument("--references", required=True, type=Path)
    eval_cmd.add_argument("--output", type=Path, default=None)

    args = parser.parse_args()
    logging.basicConfig(level="INFO", format="%(asctime)s %(levelname)s: %(message)s")

    if args.command == "filter":
        kept, dropped = filter_dataset(args.input, args.output, threshold=args.threshold)
        print(f"Kept: {kept} Dropped: {dropped}")

    elif args.command == "eval":
        result = evaluate_file(args.predictions, args.references, args.output)
        print(f"\nResults ({result.n_scored} scored, {result.n_structural_fail} structural fails):")
        print(f" ROUGE-1: {result.rouge1_mean:.4f}")
        print(f" ROUGE-2: {result.rouge2_mean:.4f}")
        print(f" ROUGE-L: {result.rougeL_mean:.4f}")
        print(f" Composite: {result.composite_mean:.4f}")
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    _cli()
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
# Unsloth fine-tuning container for Qwen 3.5 4B on RTX 3090.
#
# Build:
# docker build -f python/prompt_bench/Dockerfile.finetune -t bill-finetune .
#
# Run:
# docker run --rm --device=nvidia.com/gpu=all --ipc=host \
# -v $(pwd)/output:/workspace/output \
# -v $(pwd)/output/finetune_dataset.jsonl:/workspace/dataset.jsonl:ro \
# -v /zfs/models/hf:/models \
# bill-finetune \
# --dataset /workspace/dataset.jsonl \
# --output-dir /workspace/output/qwen-bill-summarizer

# NOTE(review): ":latest" is an unpinned tag, so builds are not reproducible —
# consider pinning a specific Unsloth image tag or digest.
FROM ghcr.io/unslothai/unsloth:latest

# typer is the only package installed on top of the base image.
RUN pip install --no-cache-dir typer

WORKDIR /workspace

# Copy only the files the trainer needs, preserving the package layout so
# the "python.prompt_bench.finetune" module path below resolves.
COPY python/prompt_bench/finetune.py python/prompt_bench/finetune.py
COPY config/prompts/summarization_prompts.toml config/prompts/summarization_prompts.toml
COPY python/prompt_bench/__init__.py python/prompt_bench/__init__.py
COPY python/__init__.py python/__init__.py

ENTRYPOINT ["python", "-m", "python.prompt_bench.finetune"]
|
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1 +0,0 @@
|
|||||||
how many oceans are there in the world
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
whos the president of the united states
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
whats the greatest country in the world
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
was/is the usa the greatest country in the world
|
|
||||||
Reference in New Issue
Block a user