created main prompt bench

2026-04-08 09:08:25 -04:00
committed by ForgeCode
parent d358f0fbec
commit d17c883476
4 changed files with 368 additions and 0 deletions

81
python/prompt_bench/container.py Normal file

@@ -0,0 +1,81 @@
"""Docker container lifecycle management for vLLM."""
from __future__ import annotations
import logging
import subprocess
logger = logging.getLogger(__name__)
CONTAINER_NAME = "vllm-bench"
VLLM_IMAGE = "vllm/vllm-openai:v0.8.5"
def start_vllm(
*,
model: str,
port: int,
model_dir: str,
gpu_memory_utilization: float,
) -> None:
"""Start a vLLM container serving the given model.
Args:
model: HuggingFace model directory name (relative to model_dir).
port: Host port to bind.
model_dir: Host path containing HuggingFace model directories.
gpu_memory_utilization: Fraction of GPU memory to use (0-1).
"""
command = [
"docker",
"run",
"-d",
"--name",
CONTAINER_NAME,
"--device=nvidia.com/gpu=all",
"--ipc=host",
"-v",
f"{model_dir}:/models",
"-p",
f"{port}:8000",
VLLM_IMAGE,
"--model",
f"/models/{model}",
"--served-model-name",
model,
"--gpu-memory-utilization",
str(gpu_memory_utilization),
"--max-model-len",
"4096",
]
logger.info("Starting vLLM container with model: %s", model)
result = subprocess.run(command, capture_output=True, text=True, check=False)
if result.returncode != 0:
msg = f"Failed to start vLLM container: {result.stderr.strip()}"
raise RuntimeError(msg)
logger.info("vLLM container started: %s", result.stdout.strip()[:12])
def stop_vllm() -> None:
"""Stop and remove the vLLM benchmark container."""
logger.info("Stopping vLLM container")
subprocess.run(["docker", "stop", CONTAINER_NAME], capture_output=True, check=False)
subprocess.run(["docker", "rm", CONTAINER_NAME], capture_output=True, check=False)
logger.info("vLLM container stopped and removed")
def check_gpu_free() -> None:
"""Warn if GPU-heavy processes (e.g. Ollama) are running."""
result = subprocess.run(
["nvidia-smi", "--query-compute-apps=pid,process_name", "--format=csv,noheader"],
capture_output=True,
text=True,
check=False,
)
if result.returncode != 0:
logger.warning("Could not query GPU processes: %s", result.stderr.strip())
return
processes = result.stdout.strip()
if processes:
logger.warning("GPU processes detected:\n%s", processes)
logger.warning("Consider stopping Ollama (sudo systemctl stop ollama) before benchmarking")

218
python/prompt_bench/main.py Normal file

@@ -0,0 +1,218 @@
"""CLI entry point for the prompt benchmarking system."""
from __future__ import annotations
import json
import logging
import time
import tomllib
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Annotated
import typer
from python.prompt_bench.container import check_gpu_free, start_vllm, stop_vllm
from python.prompt_bench.downloader import is_model_present
from python.prompt_bench.models import BenchmarkConfig
from python.prompt_bench.vllm_client import VLLMClient
logger = logging.getLogger(__name__)
def discover_prompts(input_dir: Path) -> list[Path]:
"""Find all .txt files in the input directory."""
prompts = list(input_dir.glob("*.txt"))
if not prompts:
message = f"No .txt files found in {input_dir}"
raise FileNotFoundError(message)
return prompts
def _run_prompt(
client: VLLMClient,
prompt_path: Path,
*,
repo: str,
model_dir_name: str,
model_output: Path,
temperature: float,
) -> tuple[bool, float]:
"""Run a single prompt. Returns (success, elapsed_seconds)."""
filename = prompt_path.name
output_path = model_output / filename
start = time.monotonic()
try:
prompt_text = prompt_path.read_text()
response = client.complete(prompt_text, model_dir_name, temperature=temperature)
output_path.write_text(response)
elapsed = time.monotonic() - start
logger.info("Completed: %s / %s in %.2fs", repo, filename, elapsed)
except Exception as exc:
elapsed = time.monotonic() - start
error_path = model_output / f"{filename}.error"
logger.exception("Failed: %s / %s after %.2fs", repo, filename, elapsed)
error_path.write_text(f"Error processing {filename}: {exc}")
return False, elapsed
return True, elapsed
def benchmark_model(
client: VLLMClient,
prompts: list[Path],
*,
repo: str,
model_dir_name: str,
model_output: Path,
temperature: float,
concurrency: int,
) -> tuple[int, int]:
"""Run all prompts against a single model in parallel.
vLLM batches concurrent requests internally, so submitting many at once is
significantly faster than running them serially.
"""
pending = [prompt for prompt in prompts if not (model_output / prompt.name).exists()]
skipped = len(prompts) - len(pending)
if skipped:
logger.info("Skipping %d prompts with existing output for %s", skipped, repo)
if not pending:
logger.info("Nothing to do for %s", repo)
return 0, 0
completed = 0
failed = 0
latencies: list[float] = []
wall_start = time.monotonic()
with ThreadPoolExecutor(max_workers=concurrency) as executor:
futures = [
executor.submit(
_run_prompt,
client,
prompt_path,
repo=repo,
model_dir_name=model_dir_name,
model_output=model_output,
temperature=temperature,
)
for prompt_path in pending
]
for future in as_completed(futures):
success, elapsed = future.result()
latencies.append(elapsed)
if success:
completed += 1
else:
failed += 1
wall_elapsed = time.monotonic() - wall_start
attempted = completed + failed
avg_latency = sum(latencies) / attempted
throughput = attempted / wall_elapsed if wall_elapsed > 0 else 0.0
timing = {
"repo": repo,
"wall_seconds": wall_elapsed,
"attempted": attempted,
"completed": completed,
"failed": failed,
"avg_latency_seconds": avg_latency,
"throughput_prompts_per_second": throughput,
"concurrency": concurrency,
}
timing_path = model_output / "_timing.json"
timing_path.write_text(json.dumps(timing, indent=2))
return completed, failed
def run_benchmark(
config: BenchmarkConfig,
input_dir: Path,
output_dir: Path,
) -> None:
"""Execute the benchmark across all models and prompts."""
prompts = discover_prompts(input_dir)
logger.info("Found %d prompts in %s", len(prompts), input_dir)
check_gpu_free()
total_completed = 0
total_failed = 0
for repo in config.models:
if not is_model_present(repo, config.model_dir):
logger.warning("Skipping (not downloaded): %s", repo)
continue
model_output = output_dir / repo
model_output.mkdir(parents=True, exist_ok=True)
logger.info("=== Benchmarking model: %s ===", repo)
stop_vllm()
try:
start_vllm(
model=repo,
port=config.port,
model_dir=config.model_dir,
gpu_memory_utilization=config.gpu_memory_utilization,
)
except RuntimeError:
logger.exception("Failed to start vLLM for %s, skipping", repo)
continue
logger.info("vLLM started for %s", repo)
try:
with VLLMClient(port=config.port, timeout=config.timeout) as client:
client.wait_ready(max_wait=config.vllm_startup_timeout)
completed, failed = benchmark_model(
client,
prompts,
repo=repo,
model_dir_name=repo,
model_output=model_output,
temperature=config.temperature,
concurrency=config.concurrency,
)
total_completed += completed
total_failed += failed
finally:
stop_vllm()
logger.info("=== Benchmark complete ===")
logger.info("Completed: %d | Failed: %d", total_completed, total_failed)
def main(
input_dir: Annotated[Path, typer.Argument(help="Directory containing input .txt prompt files")],
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"),
output_dir: Annotated[Path, typer.Option(help="Output directory for results")] = Path("output"),
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
"""Run prompts through multiple LLMs via vLLM and save results."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
if not input_dir.is_dir():
message = f"Input directory does not exist: {input_dir}"
raise typer.BadParameter(message)
if not config.is_file():
message = f"Config file does not exist: {config}"
raise typer.BadParameter(message)
with config.open("rb") as file:
raw = tomllib.load(file)
benchmark_config = BenchmarkConfig(**raw)
output_dir.mkdir(parents=True, exist_ok=True)
run_benchmark(benchmark_config, input_dir, output_dir)
def cli() -> None:
"""Typer entry point."""
typer.run(main)
if __name__ == "__main__":
cli()
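
The same flow can also be driven without the Typer CLI. A minimal sketch, assuming a bench.toml that matches BenchmarkConfig and a prompts/ directory of .txt files; both paths are placeholders.

from pathlib import Path
import tomllib

from python.prompt_bench.main import run_benchmark
from python.prompt_bench.models import BenchmarkConfig

with Path("bench.toml").open("rb") as fh:        # placeholder config path
    config = BenchmarkConfig(**tomllib.load(fh))

output_dir = Path("output")
output_dir.mkdir(parents=True, exist_ok=True)
run_benchmark(config, Path("prompts"), output_dir)  # placeholder input dir

# Results land in output/<repo>/<prompt>.txt, failures in <prompt>.txt.error,
# and per-model stats in output/<repo>/_timing.json.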

python/prompt_bench/models.py

@@ -15,3 +15,4 @@ class BenchmarkConfig(BaseModel):
temperature: float = 0.0
timeout: int = 300
concurrency: int = 4
vllm_startup_timeout: int = 900
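
The new vllm_startup_timeout field bounds how long main.py's wait_ready polls before giving up. For reference, a sketch of a full config constructed in Python rather than read from bench.toml; the model names and model_dir are placeholders, and only the last four fields have defaults confirmed by this hunk.

from python.prompt_bench.models import BenchmarkConfig

config = BenchmarkConfig(
    models=["example-org/example-7b"],  # placeholder directory names under model_dir
    model_dir="/srv/models",            # placeholder host path mounted at /models
    port=8000,
    gpu_memory_utilization=0.90,
    temperature=0.0,                    # these four mirror the defaults in the hunk above
    timeout=300,
    concurrency=4,
    vllm_startup_timeout=900,
)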

68
python/prompt_bench/vllm_client.py Normal file

@@ -0,0 +1,68 @@
"""OpenAI-compatible client for vLLM's API."""
from __future__ import annotations
import logging
import time
from typing import Self
import httpx
logger = logging.getLogger(__name__)
READY_POLL_INTERVAL = 2.0
class VLLMClient:
"""Talk to a vLLM server via its OpenAI-compatible API.
Args:
host: vLLM host.
port: vLLM port.
timeout: Per-request timeout in seconds.
"""
def __init__(self, *, host: str = "localhost", port: int = 8000, timeout: int = 300) -> None:
"""Create a client connected to a vLLM server."""
self._client = httpx.Client(base_url=f"http://{host}:{port}", timeout=timeout)
def wait_ready(self, max_wait: int) -> None:
"""Poll /v1/models until the server is ready or timeout."""
deadline = time.monotonic() + max_wait
while time.monotonic() < deadline:
try:
response = self._client.get("/v1/models")
if response.is_success:
logger.info("vLLM server is ready")
return
except httpx.TransportError:
pass
time.sleep(READY_POLL_INTERVAL)
msg = f"vLLM server not ready after {max_wait}s"
raise TimeoutError(msg)
def complete(self, prompt: str, model: str, *, temperature: float = 0.0, max_tokens: int = 4096) -> str:
"""Send a prompt to /v1/completions and return the response text."""
payload = {
"model": model,
"prompt": prompt,
"temperature": temperature,
"max_tokens": max_tokens,
}
logger.info("Sending prompt to %s (%d chars)", model, len(prompt))
response = self._client.post("/v1/completions", json=payload)
response.raise_for_status()
data = response.json()
return data["choices"][0]["text"]
def close(self) -> None:
"""Close the HTTP client."""
self._client.close()
def __enter__(self) -> Self:
"""Enter the context manager."""
return self
def __exit__(self, *args: object) -> None:
"""Close the HTTP client on exit."""
self.close()
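
Usage sketch for the client on its own, against an already-running server; the port, prompt, and model name are placeholders. Note that vLLM's OpenAI-compatible server generally rejects requests where the prompt tokens plus max_tokens exceed --max-model-len (4096 in container.py), so a smaller max_tokens than the 4096 default is usually needed for non-trivial prompts.

from python.prompt_bench.vllm_client import VLLMClient

with VLLMClient(port=8000, timeout=300) as client:
    client.wait_ready(max_wait=900)
    text = client.complete(
        "Explain continuous batching in one paragraph.",  # placeholder prompt
        "example-model",          # must match the --served-model-name passed to vLLM
        temperature=0.0,
        max_tokens=512,           # leave room for the prompt in the 4096-token context
    )
    print(text)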