creating prompt_bench downloader

This commit is contained in:
2026-04-07 19:15:42 -04:00
committed by ForgeCode
parent af365fce9a
commit 9c8013d69d
4 changed files with 98 additions and 0 deletions

View File

@@ -0,0 +1,79 @@
"""HuggingFace model downloader."""
from __future__ import annotations
import logging
import tomllib
from pathlib import Path
from typing import Annotated
import typer
from huggingface_hub import snapshot_download
from python.prompt_bench.models import BenchmarkConfig
logger = logging.getLogger(__name__)
def local_model_path(repo: str, model_dir: str) -> Path:
"""Return the local directory path for a HuggingFace repo."""
return Path(model_dir) / repo
def is_model_present(repo: str, model_dir: str) -> bool:
"""Check if a model has already been downloaded."""
path = local_model_path(repo, model_dir)
return path.exists() and any(path.iterdir())
def download_model(repo: str, model_dir: str) -> Path:
"""Download a HuggingFace model to the local model directory.
Skips the download if the model directory already exists and contains files.
"""
local_path = local_model_path(repo, model_dir)
if is_model_present(repo, model_dir):
logger.info("Model already exists: %s", local_path)
return local_path
logger.info("Downloading model: %s -> %s", repo, local_path)
snapshot_download(
repo_id=repo,
local_dir=str(local_path),
)
logger.info("Download complete: %s", repo)
return local_path
def download_all(config: BenchmarkConfig) -> None:
"""Download every model listed in the config, top to bottom."""
for repo in config.models:
download_model(repo, config.model_dir)
def main(
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"),
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
"""Download all models listed in the benchmark config."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
if not config.is_file():
message = f"Config file does not exist: {config}"
raise typer.BadParameter(message)
with config.open("rb") as file:
raw = tomllib.load(file)
benchmark_config = BenchmarkConfig(**raw)
download_all(benchmark_config)
def cli() -> None:
"""Typer entry point."""
typer.run(main)
if __name__ == "__main__":
cli()