mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 21:18:18 -04:00
creating prompt_bench downloader
This commit is contained in:
1
python/prompt_bench/__init__.py
Normal file
1
python/prompt_bench/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Prompt benchmarking system for evaluating LLMs via vLLM."""
|
||||
79
python/prompt_bench/downloader.py
Normal file
79
python/prompt_bench/downloader.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""HuggingFace model downloader."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from python.prompt_bench.models import BenchmarkConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def local_model_path(repo: str, model_dir: str) -> Path:
|
||||
"""Return the local directory path for a HuggingFace repo."""
|
||||
return Path(model_dir) / repo
|
||||
|
||||
|
||||
def is_model_present(repo: str, model_dir: str) -> bool:
|
||||
"""Check if a model has already been downloaded."""
|
||||
path = local_model_path(repo, model_dir)
|
||||
return path.exists() and any(path.iterdir())
|
||||
|
||||
|
||||
def download_model(repo: str, model_dir: str) -> Path:
|
||||
"""Download a HuggingFace model to the local model directory.
|
||||
|
||||
Skips the download if the model directory already exists and contains files.
|
||||
"""
|
||||
local_path = local_model_path(repo, model_dir)
|
||||
|
||||
if is_model_present(repo, model_dir):
|
||||
logger.info("Model already exists: %s", local_path)
|
||||
return local_path
|
||||
|
||||
logger.info("Downloading model: %s -> %s", repo, local_path)
|
||||
snapshot_download(
|
||||
repo_id=repo,
|
||||
local_dir=str(local_path),
|
||||
)
|
||||
logger.info("Download complete: %s", repo)
|
||||
return local_path
|
||||
|
||||
|
||||
def download_all(config: BenchmarkConfig) -> None:
|
||||
"""Download every model listed in the config, top to bottom."""
|
||||
for repo in config.models:
|
||||
download_model(repo, config.model_dir)
|
||||
|
||||
|
||||
def main(
|
||||
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"),
|
||||
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||
) -> None:
|
||||
"""Download all models listed in the benchmark config."""
|
||||
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||
|
||||
if not config.is_file():
|
||||
message = f"Config file does not exist: {config}"
|
||||
raise typer.BadParameter(message)
|
||||
|
||||
with config.open("rb") as file:
|
||||
raw = tomllib.load(file)
|
||||
|
||||
benchmark_config = BenchmarkConfig(**raw)
|
||||
download_all(benchmark_config)
|
||||
|
||||
|
||||
def cli() -> None:
|
||||
"""Typer entry point."""
|
||||
typer.run(main)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
17
python/prompt_bench/models.py
Normal file
17
python/prompt_bench/models.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Pydantic models for benchmark configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BenchmarkConfig(BaseModel):
|
||||
"""Top-level benchmark configuration loaded from TOML."""
|
||||
|
||||
models: list[str]
|
||||
model_dir: str = "/zfs/models/hf"
|
||||
port: int = 8000
|
||||
gpu_memory_utilization: float = 0.90
|
||||
temperature: float = 0.0
|
||||
timeout: int = 300
|
||||
concurrency: int = 4
|
||||
Reference in New Issue
Block a user