76 lines
2.1 KiB
Python
76 lines
2.1 KiB
Python
"""HuggingFace model downloader."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Annotated
|
|
|
|
import typer
|
|
from huggingface_hub import snapshot_download
|
|
|
|
from python.prompt_bench.models import BenchmarkConfig
|
|
|
|
# Module-level logger; when run through ``main`` its level and format are
# set up by ``logging.basicConfig``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def local_model_path(repo: str, model_dir: str) -> Path:
    """Map a HuggingFace repo id to its directory under ``model_dir``.

    The repo id is joined onto ``model_dir`` as a relative path, so an id
    such as ``"org/name"`` becomes ``<model_dir>/org/name``.
    """
    base_dir = Path(model_dir)
    return base_dir.joinpath(repo)
|
|
|
|
|
|
def is_model_present(repo: str, model_dir: str) -> bool:
    """Check if a model has already been downloaded.

    A model counts as present when its local directory (as computed by
    ``local_model_path``) is an existing directory containing at least one
    entry. Only non-emptiness is checked, not completeness.

    NOTE(review): an interrupted download that already wrote some files will
    presumably also be reported as present — confirm that is acceptable.

    Args:
        repo: HuggingFace repo id, used as a relative path under ``model_dir``.
        model_dir: Root directory holding local model copies.

    Returns:
        True if the model directory exists and is non-empty, otherwise False.
    """
    path = local_model_path(repo, model_dir)
    # ``is_dir`` rather than ``exists``: if a regular file occupies this path,
    # ``iterdir`` would raise NotADirectoryError instead of reporting "absent".
    return path.is_dir() and any(path.iterdir())
|
|
|
|
|
|
def download_model(repo: str, model_dir: str) -> Path:
    """Ensure ``repo`` is available locally and return its directory.

    The target directory comes from ``local_model_path``. When
    ``is_model_present`` reports it as already populated, nothing is fetched;
    otherwise ``snapshot_download`` copies the repo snapshot into it.

    Args:
        repo: HuggingFace repo id to fetch.
        model_dir: Root directory that holds local model copies.

    Returns:
        The local directory for ``repo``, whether or not a download happened.
    """
    target = local_model_path(repo, model_dir)

    if not is_model_present(repo, model_dir):
        logger.info("Downloading model: %s -> %s", repo, target)
        snapshot_download(repo_id=repo, local_dir=str(target))
        logger.info("Download complete: %s", repo)
    else:
        logger.info("Model already exists: %s", target)

    return target
|
|
|
|
|
|
def download_all(config: BenchmarkConfig) -> None:
    """Fetch every repo in ``config.models`` into ``config.model_dir``.

    Repos are processed sequentially in config order. There is no error
    handling here, so the first failing download propagates its exception
    and stops the remaining ones.
    """
    for repo_id in config.models:
        # download_model returns early for repos that are already present.
        download_model(repo_id, config.model_dir)
|
|
|
|
|
|
def main(
    config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"),
    log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
    """Download all models listed in the benchmark config.

    Args:
        config: Path to the TOML benchmark config; must be an existing file.
        log_level: Name of a standard ``logging`` level, matched
            case-insensitively (e.g. ``info`` or ``DEBUG``).

    Raises:
        typer.BadParameter: If ``log_level`` is not a known logging level name,
            or if ``config`` does not point to an existing file.
    """
    level_name = log_level.upper()
    # getLevelName returns the int value for a registered level name and a
    # "Level <name>" string otherwise; reject unknown names as a CLI usage
    # error instead of letting basicConfig raise a bare ValueError.
    level = logging.getLevelName(level_name)
    if not isinstance(level, int):
        message = f"Unknown log level: {log_level}"
        raise typer.BadParameter(message, param_hint="'--log-level'")
    logging.basicConfig(level=level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")

    if not config.is_file():
        message = f"Config file does not exist: {config}"
        raise typer.BadParameter(message, param_hint="'--config'")

    benchmark_config = BenchmarkConfig.from_toml(config)
    download_all(benchmark_config)
|
|
|
|
|
|
def cli() -> None:
    """Typer entry point.

    Hands ``main`` to ``typer.run``, which builds a single-command Typer app
    from its signature (options ``--config`` and ``--log-level``), parses the
    command line, and invokes it.
    """
    typer.run(main)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script: parse the command line and run ``main``
    # through the Typer wrapper in ``cli``.
    cli()
|