diff --git a/python/prompt_bench/downloader.py b/python/prompt_bench/downloader.py index ad7fba7..8710b9e 100644 --- a/python/prompt_bench/downloader.py +++ b/python/prompt_bench/downloader.py @@ -3,7 +3,6 @@ from __future__ import annotations import logging -import tomllib from pathlib import Path from typing import Annotated @@ -63,10 +62,7 @@ def main( message = f"Config file does not exist: {config}" raise typer.BadParameter(message) - with config.open("rb") as file: - raw = tomllib.load(file) - - benchmark_config = BenchmarkConfig(**raw) + benchmark_config = BenchmarkConfig.from_toml(config) download_all(benchmark_config) diff --git a/python/prompt_bench/main.py b/python/prompt_bench/main.py index 50bd04e..2d6a725 100644 --- a/python/prompt_bench/main.py +++ b/python/prompt_bench/main.py @@ -5,7 +5,6 @@ from __future__ import annotations import json import logging import time -import tomllib from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path from typing import Annotated @@ -201,10 +200,7 @@ def main( message = f"Config file does not exist: {config}" raise typer.BadParameter(message) - with config.open("rb") as file: - raw = tomllib.load(file) - - benchmark_config = BenchmarkConfig(**raw) + benchmark_config = BenchmarkConfig.from_toml(config) output_dir.mkdir(parents=True, exist_ok=True) run_benchmark(benchmark_config, input_dir, output_dir) diff --git a/python/prompt_bench/models.py b/python/prompt_bench/models.py index 1abae25..c722aba 100644 --- a/python/prompt_bench/models.py +++ b/python/prompt_bench/models.py @@ -2,8 +2,14 @@ from __future__ import annotations +import tomllib +from typing import TYPE_CHECKING + from pydantic import BaseModel +if TYPE_CHECKING: + from pathlib import Path + class BenchmarkConfig(BaseModel): """Top-level benchmark configuration loaded from TOML.""" @@ -16,3 +22,9 @@ class BenchmarkConfig(BaseModel): timeout: int = 300 concurrency: int = 4 vllm_startup_timeout: int = 900 + + @classmethod + def from_toml(cls, config_path: Path) -> BenchmarkConfig: + """Load benchmark config from a TOML file.""" + raw = tomllib.loads(config_path.read_text())["bench"] + return cls(**raw)