mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 21:18:18 -04:00
created main prompt bench
This commit is contained in:
81
python/prompt_bench/container.py
Normal file
81
python/prompt_bench/container.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
"""Docker container lifecycle management for vLLM."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)


# Fixed container name so start_vllm/stop_vllm can always find the container.
CONTAINER_NAME = "vllm-bench"
# Pinned image tag for reproducible benchmark runs.
VLLM_IMAGE = "vllm/vllm-openai:v0.8.5"
|
||||||
|
|
||||||
|
|
||||||
|
def start_vllm(
    *,
    model: str,
    port: int,
    model_dir: str,
    gpu_memory_utilization: float,
) -> None:
    """Launch a detached vLLM container serving one model.

    Args:
        model: HuggingFace model directory name (relative to model_dir).
        port: Host port to bind.
        model_dir: Host path containing HuggingFace model directories.
        gpu_memory_utilization: Fraction of GPU memory to use (0-1).

    Raises:
        RuntimeError: If `docker run` exits with a non-zero status.

    """
    # Docker-side flags: detached, named, all GPUs, host IPC (vLLM needs it
    # for shared-memory tensor transfer), model volume, and port mapping.
    docker_flags = [
        "docker", "run", "-d",
        "--name", CONTAINER_NAME,
        "--device=nvidia.com/gpu=all",
        "--ipc=host",
        "-v", f"{model_dir}:/models",
        "-p", f"{port}:8000",
    ]
    # vLLM-side arguments passed to the server inside the container.
    server_args = [
        VLLM_IMAGE,
        "--model", f"/models/{model}",
        "--served-model-name", model,
        "--gpu-memory-utilization", str(gpu_memory_utilization),
        "--max-model-len", "4096",
    ]
    logger.info("Starting vLLM container with model: %s", model)
    proc = subprocess.run(docker_flags + server_args, capture_output=True, text=True, check=False)
    if proc.returncode != 0:
        msg = f"Failed to start vLLM container: {proc.stderr.strip()}"
        raise RuntimeError(msg)
    # `docker run -d` prints the container id; the first 12 chars are the short id.
    logger.info("vLLM container started: %s", proc.stdout.strip()[:12])
|
||||||
|
|
||||||
|
|
||||||
|
def stop_vllm() -> None:
    """Stop and remove the vLLM benchmark container.

    Best-effort: both docker commands run with check=False and captured
    output, so this is safe to call even when no container exists.
    """
    logger.info("Stopping vLLM container")
    for verb in ("stop", "rm"):
        subprocess.run(["docker", verb, CONTAINER_NAME], capture_output=True, check=False)
    logger.info("vLLM container stopped and removed")
|
||||||
|
|
||||||
|
|
||||||
|
def check_gpu_free() -> None:
    """Warn if GPU-heavy processes (e.g. Ollama) are running."""
    query = ["nvidia-smi", "--query-compute-apps=pid,process_name", "--format=csv,noheader"]
    result = subprocess.run(query, capture_output=True, text=True, check=False)
    # nvidia-smi missing or erroring is non-fatal: warn and carry on.
    if result.returncode != 0:
        logger.warning("Could not query GPU processes: %s", result.stderr.strip())
        return
    processes = result.stdout.strip()
    if not processes:
        return
    logger.warning("GPU processes detected:\n%s", processes)
    logger.warning("Consider stopping Ollama (sudo systemctl stop ollama) before benchmarking")
|
||||||
218
python/prompt_bench/main.py
Normal file
218
python/prompt_bench/main.py
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
"""CLI entry point for the prompt benchmarking system."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import tomllib
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from python.prompt_bench.container import check_gpu_free, start_vllm, stop_vllm
|
||||||
|
from python.prompt_bench.downloader import is_model_present
|
||||||
|
from python.prompt_bench.models import BenchmarkConfig
|
||||||
|
from python.prompt_bench.vllm_client import VLLMClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def discover_prompts(input_dir: Path) -> list[Path]:
    """Find all .txt files in the input directory.

    Args:
        input_dir: Directory to scan (non-recursive).

    Returns:
        Prompt paths sorted by name. Sorting makes run order deterministic;
        a bare glob() yields filesystem-enumeration order, which varies
        between machines and runs.

    Raises:
        FileNotFoundError: If the directory contains no .txt files.

    """
    prompts = sorted(input_dir.glob("*.txt"))
    if not prompts:
        message = f"No .txt files found in {input_dir}"
        raise FileNotFoundError(message)
    return prompts
|
||||||
|
|
||||||
|
|
||||||
|
def _run_prompt(
    client: VLLMClient,
    prompt_path: Path,
    *,
    repo: str,
    model_dir_name: str,
    model_output: Path,
    temperature: float,
) -> tuple[bool, float]:
    """Run a single prompt through the model and persist the result.

    Args:
        client: Connected vLLM client.
        prompt_path: Input .txt prompt file.
        repo: Model repo name (for logging only).
        model_dir_name: Model name passed to the completions API.
        model_output: Directory where the response (or error marker) is written.
        temperature: Sampling temperature.

    Returns:
        (success, elapsed_seconds).

    """
    filename = prompt_path.name
    output_path = model_output / filename
    start = time.monotonic()
    try:
        prompt_text = prompt_path.read_text()
        response = client.complete(prompt_text, model_dir_name, temperature=temperature)
        output_path.write_text(response)
        elapsed = time.monotonic() - start
        logger.info("Completed: %s / %s in %.2fs", repo, filename, elapsed)
    except Exception:
        elapsed = time.monotonic() - start
        # Fix: interpolate the prompt filename. The previous f-strings had no
        # placeholders, so every failure wrote to the same literal error file.
        error_path = model_output / f"{filename}.error"
        logger.exception("Failed: %s / %s after %.2fs", repo, filename, elapsed)
        error_path.write_text(f"Error processing {filename}")
        return False, elapsed
    return True, elapsed
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_model(
    client: VLLMClient,
    prompts: list[Path],
    *,
    repo: str,
    model_dir_name: str,
    model_output: Path,
    temperature: float,
    concurrency: int,
) -> tuple[int, int]:
    """Run every prompt against one model, fanning out over a thread pool.

    vLLM batches concurrent requests internally, so submitting many at once
    is significantly faster than running them serially. Prompts that already
    have an output file are skipped, making reruns resumable.

    Returns:
        (completed, failed) counts for this model.

    """
    pending = [p for p in prompts if not (model_output / p.name).exists()]
    skipped = len(prompts) - len(pending)
    if skipped:
        logger.info("Skipping %d prompts with existing output for %s", skipped, repo)

    if not pending:
        logger.info("Nothing to do for %s", repo)
        return 0, 0

    # Each outcome is the (success, elapsed_seconds) pair from _run_prompt.
    outcomes: list[tuple[bool, float]] = []

    wall_start = time.monotonic()
    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        in_flight = [
            pool.submit(
                _run_prompt,
                client,
                path,
                repo=repo,
                model_dir_name=model_dir_name,
                model_output=model_output,
                temperature=temperature,
            )
            for path in pending
        ]
        for done in as_completed(in_flight):
            outcomes.append(done.result())
    wall_elapsed = time.monotonic() - wall_start

    attempted = len(outcomes)
    completed = sum(1 for ok, _ in outcomes if ok)
    failed = attempted - completed
    avg_latency = sum(seconds for _, seconds in outcomes) / attempted
    throughput = attempted / wall_elapsed if wall_elapsed > 0 else 0.0
    timing = {
        "repo": repo,
        "wall_seconds": wall_elapsed,
        "attempted": attempted,
        "completed": completed,
        "failed": failed,
        "avg_latency_seconds": avg_latency,
        "throughput_prompts_per_second": throughput,
        "concurrency": concurrency,
    }
    # Per-model timing summary lands next to the prompt outputs.
    (model_output / "_timing.json").write_text(json.dumps(timing, indent=2))

    return completed, failed
|
||||||
|
|
||||||
|
|
||||||
|
def run_benchmark(
    config: BenchmarkConfig,
    input_dir: Path,
    output_dir: Path,
) -> None:
    """Execute the benchmark: for each downloaded model, boot vLLM, run all
    prompts, and tear the container down again.
    """
    prompts = discover_prompts(input_dir)
    logger.info("Found %d prompts in %s", len(prompts), input_dir)

    check_gpu_free()

    ok_total = 0
    fail_total = 0

    for repo in config.models:
        if not is_model_present(repo, config.model_dir):
            logger.warning("Skipping (not downloaded): %s", repo)
            continue

        results_dir = output_dir / repo
        results_dir.mkdir(parents=True, exist_ok=True)

        logger.info("=== Benchmarking model: %s ===", repo)

        # Clear any stale container left over from a previous (crashed) run.
        stop_vllm()
        try:
            start_vllm(
                model=repo,
                port=config.port,
                model_dir=config.model_dir,
                gpu_memory_utilization=config.gpu_memory_utilization,
            )
        except RuntimeError:
            logger.exception("Failed to start vLLM for %s, skipping", repo)
            continue
        logger.info("vLLM started for %s", repo)
        try:
            with VLLMClient(port=config.port, timeout=config.timeout) as client:
                client.wait_ready(max_wait=config.vllm_startup_timeout)
                done, errs = benchmark_model(
                    client,
                    prompts,
                    repo=repo,
                    model_dir_name=repo,
                    model_output=results_dir,
                    temperature=config.temperature,
                    concurrency=config.concurrency,
                )
                ok_total += done
                fail_total += errs
        finally:
            # Always free the GPU for the next model, even on failure.
            stop_vllm()

    logger.info("=== Benchmark complete ===")
    logger.info("Completed: %d | Failed: %d", ok_total, fail_total)
|
||||||
|
|
||||||
|
|
||||||
|
def main(
    input_dir: Annotated[Path, typer.Argument(help="Directory containing input .txt prompt files")],
    config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"),
    output_dir: Annotated[Path, typer.Option(help="Output directory for results")] = Path("output"),
    log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
    """Run prompts through multiple LLMs via vLLM and save results."""
    logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")

    # Validate CLI paths up front so Typer can render friendly errors.
    if not input_dir.is_dir():
        raise typer.BadParameter(f"Input directory does not exist: {input_dir}")
    if not config.is_file():
        raise typer.BadParameter(f"Config file does not exist: {config}")

    # tomllib requires binary mode.
    with config.open("rb") as toml_file:
        config_data = tomllib.load(toml_file)
    benchmark_config = BenchmarkConfig(**config_data)

    output_dir.mkdir(parents=True, exist_ok=True)
    run_benchmark(benchmark_config, input_dir, output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def cli() -> None:
    """Typer entry point.

    Wraps main() so a console-script / direct invocation needs no arguments;
    Typer builds the CLI from main()'s signature.
    """
    typer.run(main)


if __name__ == "__main__":
    cli()
|
||||||
@@ -15,3 +15,4 @@ class BenchmarkConfig(BaseModel):
|
|||||||
temperature: float = 0.0
|
temperature: float = 0.0
|
||||||
timeout: int = 300
|
timeout: int = 300
|
||||||
concurrency: int = 4
|
concurrency: int = 4
|
||||||
|
vllm_startup_timeout: int = 900
|
||||||
|
|||||||
68
python/prompt_bench/vllm_client.py
Normal file
68
python/prompt_bench/vllm_client.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""OpenAI-compatible client for vLLM's API."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Self
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)

# Seconds between /v1/models readiness polls in VLLMClient.wait_ready().
READY_POLL_INTERVAL = 2.0
|
||||||
|
|
||||||
|
|
||||||
|
class VLLMClient:
    """Talk to a vLLM server via its OpenAI-compatible API.

    Usable as a context manager; the underlying HTTP connection is closed
    on exit.

    Args:
        host: vLLM host.
        port: vLLM port.
        timeout: Per-request timeout in seconds.
    """

    def __init__(self, *, host: str = "localhost", port: int = 8000, timeout: int = 300) -> None:
        """Create a client connected to a vLLM server."""
        base = f"http://{host}:{port}"
        self._client = httpx.Client(base_url=base, timeout=timeout)

    def wait_ready(self, max_wait: int) -> None:
        """Poll /v1/models until the server responds or max_wait elapses.

        Raises:
            TimeoutError: If the server never becomes ready within max_wait.
        """
        give_up_at = time.monotonic() + max_wait
        while time.monotonic() < give_up_at:
            try:
                if self._client.get("/v1/models").is_success:
                    logger.info("vLLM server is ready")
                    return
            except httpx.TransportError:
                # Server not accepting connections yet; keep polling.
                pass
            time.sleep(READY_POLL_INTERVAL)
        raise TimeoutError(f"vLLM server not ready after {max_wait}s")

    def complete(self, prompt: str, model: str, *, temperature: float = 0.0, max_tokens: int = 4096) -> str:
        """Send a prompt to /v1/completions and return the response text.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
        """
        body = {
            "model": model,
            "prompt": prompt,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        logger.info("Sending prompt to %s (%d chars)", model, len(prompt))
        reply = self._client.post("/v1/completions", json=body)
        reply.raise_for_status()
        return reply.json()["choices"][0]["text"]

    def close(self) -> None:
        """Close the HTTP client."""
        self._client.close()

    def __enter__(self) -> Self:
        """Enter the context manager."""
        return self

    def __exit__(self, *args: object) -> None:
        """Close the HTTP client on exit."""
        self.close()
|
||||||
Reference in New Issue
Block a user