start
This commit is contained in:
215
prompt_bench/main.py
Normal file
215
prompt_bench/main.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""CLI entry point for the prompt benchmarking system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
|
||||
from python.prompt_bench.containers.lib import check_gpu_free
|
||||
from python.prompt_bench.containers.vllm import start_vllm, stop_vllm
|
||||
from python.prompt_bench.downloader import is_model_present
|
||||
from python.prompt_bench.models import BenchmarkConfig
|
||||
from python.prompt_bench.vllm_client import VLLMClient
|
||||
|
||||
# Module-level logger, named after this module, used by the functions below.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def discover_prompts(input_dir: Path) -> list[Path]:
    """Find all ``.txt`` prompt files directly inside the input directory.

    Only regular files are returned (a directory that happens to be named
    ``something.txt`` is ignored), and the result is sorted so prompts are
    processed in the same order on every run and every filesystem.

    Args:
        input_dir: Directory to search. Subdirectories are not searched.

    Returns:
        The matching prompt files, sorted by path.

    Raises:
        FileNotFoundError: If ``input_dir`` contains no ``.txt`` file.
    """
    # Path.glob yields entries in arbitrary, filesystem-dependent order.
    prompts = sorted(path for path in input_dir.glob("*.txt") if path.is_file())
    if not prompts:
        message = f"No .txt files found in {input_dir}"
        raise FileNotFoundError(message)
    return prompts
|
||||
|
||||
|
||||
def _run_prompt(
    client: VLLMClient,
    prompt_path: Path,
    *,
    repo: str,
    model_dir_name: str,
    model_output: Path,
    temperature: float,
) -> tuple[bool, float]:
    """Run a single prompt and persist the outcome next to the other results.

    On success the response is written to ``model_output / <prompt file name>``.
    On any failure a ``<prompt file name>.error`` marker describing the
    exception is written instead and the exception is logged; it is never
    re-raised, so one bad prompt cannot abort the surrounding batch.

    Args:
        client: Client whose ``complete`` method is called with the prompt text.
        prompt_path: File containing the prompt text.
        repo: Model identifier, used only in log messages.
        model_dir_name: Model name passed to ``client.complete``.
        model_output: Directory receiving the output or error file.
        temperature: Sampling temperature passed to ``client.complete``.

    Returns:
        ``(success, elapsed_seconds)`` where ``elapsed_seconds`` is measured
        with a monotonic clock from just before the prompt file is read.
    """
    filename = prompt_path.name
    output_path = model_output / filename
    start = time.monotonic()
    try:
        # Explicit UTF-8: model output routinely contains non-ASCII text, and the
        # locale default encoding may be unable to represent it.
        prompt_text = prompt_path.read_text(encoding="utf-8")
        response = client.complete(prompt_text, model_dir_name, temperature=temperature)
        output_path.write_text(response, encoding="utf-8")
    except Exception as error:
        elapsed = time.monotonic() - start
        error_path = model_output / f"{filename}.error"
        logger.exception("Failed: %s / %s after %.2fs", repo, filename, elapsed)
        try:
            # Record what went wrong so the marker is useful without the log.
            error_path.write_text(
                f"Error processing {filename}: {type(error).__name__}: {error}",
                encoding="utf-8",
            )
        except OSError:
            # The marker is best-effort; failing to write it must not escape
            # the worker and take the rest of the batch down with it.
            logger.exception("Could not write error marker %s", error_path)
        return False, elapsed
    else:
        elapsed = time.monotonic() - start
        logger.info("Completed: %s / %s in %.2fs", repo, filename, elapsed)
    return True, elapsed
|
||||
|
||||
|
||||
def benchmark_model(
    client: VLLMClient,
    prompts: list[Path],
    *,
    repo: str,
    model_dir_name: str,
    model_output: Path,
    temperature: float,
    concurrency: int,
) -> tuple[int, int]:
    """Run all prompts against a single model in parallel.

    vLLM batches concurrent requests internally, so submitting many at once is
    significantly faster than running them serially.

    Prompts whose output file already exists in ``model_output`` are skipped.
    After the remaining prompts have run, a ``_timing.json`` summary is written
    to ``model_output``. If there is nothing left to run, no timing file is
    written.

    Args:
        client: Client handed to every worker.
        prompts: Prompt files to run.
        repo: Model identifier, used in logs and in the timing summary.
        model_dir_name: Model name forwarded to each worker.
        model_output: Directory holding this model's outputs.
        temperature: Sampling temperature forwarded to each worker.
        concurrency: Maximum number of worker threads.

    Returns:
        ``(completed, failed)`` counts for the prompts attempted in this call.
    """
    pending = [prompt for prompt in prompts if not (model_output / prompt.name).exists()]
    skipped = len(prompts) - len(pending)
    if skipped:
        logger.info("Skipping %d prompts with existing output for %s", skipped, repo)

    if not pending:
        logger.info("Nothing to do for %s", repo)
        return 0, 0

    completed = 0
    failed = 0
    latencies: list[float] = []

    wall_start = time.monotonic()
    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        # Map each future back to its prompt so a crash can be attributed.
        futures = {
            executor.submit(
                _run_prompt,
                client,
                prompt_path,
                repo=repo,
                model_dir_name=model_dir_name,
                model_output=model_output,
                temperature=temperature,
            ): prompt_path
            for prompt_path in pending
        }
        for future in as_completed(futures):
            try:
                success, elapsed = future.result()
            except Exception:
                # A worker that raises is counted as a failure instead of
                # aborting the whole model run and losing the timing summary.
                logger.exception("Worker crashed: %s / %s", repo, futures[future].name)
                failed += 1
                continue
            latencies.append(elapsed)
            if success:
                completed += 1
            else:
                failed += 1
    wall_elapsed = time.monotonic() - wall_start

    attempted = completed + failed
    # Crashed workers report no latency, so average over what was measured.
    avg_latency = sum(latencies) / len(latencies) if latencies else 0.0
    throughput = attempted / wall_elapsed if wall_elapsed > 0 else 0.0
    timing = {
        "repo": repo,
        "wall_seconds": wall_elapsed,
        "attempted": attempted,
        "completed": completed,
        "failed": failed,
        "avg_latency_seconds": avg_latency,
        "throughput_prompts_per_second": throughput,
        "concurrency": concurrency,
    }
    timing_path = model_output / "_timing.json"
    timing_path.write_text(json.dumps(timing, indent=2))

    return completed, failed
|
||||
|
||||
|
||||
def run_benchmark(
    config: BenchmarkConfig,
    input_dir: Path,
    output_dir: Path,
) -> None:
    """Benchmark every configured model against every discovered prompt.

    For each entry in ``config.models`` that is present locally, a vLLM
    server is started, all prompts are run through ``benchmark_model`` with
    results stored under ``output_dir / <repo>``, and the server is stopped
    again. Models that are not downloaded, or whose server fails to start with
    a ``RuntimeError``, are logged and skipped. Overall totals are logged at
    the end.
    """
    prompts = discover_prompts(input_dir)
    logger.info("Found %d prompts in %s", len(prompts), input_dir)

    check_gpu_free()

    grand_completed = 0
    grand_failed = 0

    for repo in config.models:
        model_available = is_model_present(repo, config.model_dir)
        if not model_available:
            logger.warning("Skipping (not downloaded): %s", repo)
            continue

        model_output = output_dir / repo
        model_output.mkdir(parents=True, exist_ok=True)

        logger.info("=== Benchmarking model: %s ===", repo)

        # Presumably clears any server left over from a previous run or model.
        stop_vllm()
        try:
            start_vllm(
                model=repo,
                port=config.port,
                model_dir=config.model_dir,
                gpu_memory_utilization=config.gpu_memory_utilization,
            )
        except RuntimeError:
            logger.exception("Failed to start vLLM for %s, skipping", repo)
            continue
        logger.info("vLLM started for %s", repo)

        try:
            with VLLMClient(port=config.port, timeout=config.timeout) as client:
                client.wait_ready(max_wait=config.vllm_startup_timeout)
                model_completed, model_failed = benchmark_model(
                    client,
                    prompts,
                    repo=repo,
                    model_dir_name=repo,
                    model_output=model_output,
                    temperature=config.temperature,
                    concurrency=config.concurrency,
                )
                grand_completed += model_completed
                grand_failed += model_failed
        finally:
            # Always stop the server, even when the benchmark raised.
            stop_vllm()

    logger.info("=== Benchmark complete ===")
    logger.info("Completed: %d | Failed: %d", grand_completed, grand_failed)
|
||||
|
||||
|
||||
def main(
    input_dir: Annotated[Path, typer.Argument(help="Directory containing input .txt prompt files")],
    config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"),
    output_dir: Annotated[Path, typer.Option(help="Output directory for results")] = Path("output"),
    log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
    """Run prompts through multiple LLMs via vLLM and save results."""
    # NOTE: the docstring above is kept short on purpose; Typer shows it as the
    # command's help text.

    # The logging module only knows upper-case level names, so normalise the
    # value to accept e.g. "--log-level debug", and report an unknown level as
    # a parameter error like the other validation failures below.
    level_name = log_level.upper()
    try:
        logging.basicConfig(
            level=level_name,
            format="%(asctime)s %(levelname)s %(name)s: %(message)s",
        )
    except ValueError as error:
        message = f"Unknown log level: {log_level}"
        raise typer.BadParameter(message) from error

    # Validate both paths up front so the user gets a clear error before any
    # config parsing or benchmarking starts.
    if not input_dir.is_dir():
        message = f"Input directory does not exist: {input_dir}"
        raise typer.BadParameter(message)
    if not config.is_file():
        message = f"Config file does not exist: {config}"
        raise typer.BadParameter(message)

    benchmark_config = BenchmarkConfig.from_toml(config)
    output_dir.mkdir(parents=True, exist_ok=True)

    run_benchmark(benchmark_config, input_dir, output_dir)
|
||||
|
||||
|
||||
def cli() -> None:
    """Launch the command-line interface by running ``main`` under Typer."""
    typer.run(main)
|
||||
|
||||
|
||||
# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    cli()
|
||||
Reference in New Issue
Block a user