mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 13:08:19 -04:00
creating prompt_bench downloader
This commit is contained in:
@@ -24,6 +24,7 @@
|
|||||||
fastapi
|
fastapi
|
||||||
fastapi-cli
|
fastapi-cli
|
||||||
httpx
|
httpx
|
||||||
|
huggingface-hub
|
||||||
mypy
|
mypy
|
||||||
orjson
|
orjson
|
||||||
polars
|
polars
|
||||||
|
|||||||
1
python/prompt_bench/__init__.py
Normal file
1
python/prompt_bench/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Prompt benchmarking system for evaluating LLMs via vLLM."""
|
||||||
79
python/prompt_bench/downloader.py
Normal file
79
python/prompt_bench/downloader.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""HuggingFace model downloader."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import tomllib
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
from python.prompt_bench.models import BenchmarkConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def local_model_path(repo: str, model_dir: str) -> Path:
    """Build the on-disk location for a HuggingFace repo under ``model_dir``.

    The repo identifier is appended to the model directory as-is, so any
    ``/``-separated segments in ``repo`` become nested directories. Nothing
    is created or checked on disk.
    """
    base_dir = Path(model_dir)
    return base_dir.joinpath(repo)
|
||||||
|
|
||||||
|
|
||||||
|
def is_model_present(repo: str, model_dir: str) -> bool:
    """Check if a model has already been downloaded.

    A model counts as present when its local path is a directory that holds
    at least one entry. A missing path, an empty directory, or a path that is
    not a directory all count as "not present".

    Returns:
        True if the model directory exists and is non-empty, otherwise False.
    """
    path = local_model_path(repo, model_dir)
    # is_dir() rather than exists(): a stray regular file at the model path
    # would otherwise reach iterdir() and raise NotADirectoryError.
    return path.is_dir() and any(path.iterdir())
|
||||||
|
|
||||||
|
|
||||||
|
def download_model(repo: str, model_dir: str) -> Path:
    """Fetch a HuggingFace repo into the local model directory.

    No download is attempted when ``is_model_present`` already reports the
    model as present; in that case only an informational message is logged.

    Returns:
        The local path of the model, whether it was just downloaded or was
        already there.
    """
    target_dir = local_model_path(repo, model_dir)

    if not is_model_present(repo, model_dir):
        logger.info("Downloading model: %s -> %s", repo, target_dir)
        snapshot_download(repo_id=repo, local_dir=str(target_dir))
        logger.info("Download complete: %s", repo)
    else:
        logger.info("Model already exists: %s", target_dir)

    return target_dir
|
||||||
|
|
||||||
|
|
||||||
|
def download_all(config: BenchmarkConfig) -> None:
    """Fetch each repo named in ``config.models``, in the order listed.

    Every repo is handed to ``download_model`` together with
    ``config.model_dir``; downloads run one after another.
    """
    for model_repo in config.models:
        download_model(repo=model_repo, model_dir=config.model_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def main(
    config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"),
    log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
    """Download all models listed in the benchmark config.

    Configures logging, reads the TOML file at ``config``, builds a
    ``BenchmarkConfig`` from its contents and downloads every listed model.

    Args:
        config: Path to the TOML config file.
        log_level: Logging level name; matched case-insensitively.

    Raises:
        typer.BadParameter: If ``config`` is not an existing file.
    """
    # The logging module only recognises upper-case level names, so normalise
    # the value to let "--log-level info" work instead of raising ValueError.
    logging.basicConfig(level=log_level.upper(), format="%(asctime)s %(levelname)s %(name)s: %(message)s")

    if not config.is_file():
        message = f"Config file does not exist: {config}"
        raise typer.BadParameter(message)

    # tomllib requires a binary file handle.
    with config.open("rb") as file:
        raw = tomllib.load(file)

    benchmark_config = BenchmarkConfig(**raw)
    download_all(benchmark_config)
|
||||||
|
|
||||||
|
|
||||||
|
def cli() -> None:
    """Run ``main`` as a Typer command-line application."""
    typer.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
# Start the Typer CLI only when this module is executed as a script.
if __name__ == "__main__":
    cli()
|
||||||
17
python/prompt_bench/models.py
Normal file
17
python/prompt_bench/models.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Pydantic models for benchmark configuration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkConfig(BaseModel):
    """Settings for a benchmark run, populated from a TOML config file.

    Only ``models`` is required; every other field falls back to the default
    declared below.
    """

    # HuggingFace repo identifiers, processed in the order listed.
    models: list[str]
    # Base directory under which each repo gets its own sub-directory.
    model_dir: str = "/zfs/models/hf"
    # NOTE(review): presumably the port of the vLLM server - confirm against the consumer.
    port: int = 8000
    # NOTE(review): presumably forwarded to vLLM as a fraction of GPU memory - confirm.
    gpu_memory_utilization: float = 0.90
    # NOTE(review): presumably the sampling temperature for benchmark requests - confirm.
    temperature: float = 0.0
    # NOTE(review): unit is not shown here (likely seconds) - confirm.
    timeout: int = 300
    # NOTE(review): presumably the number of requests in flight at once - confirm.
    concurrency: int = 4
|
||||||
Reference in New Issue
Block a user