mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 13:08:19 -04:00
created main prompt bench
This commit is contained in:
81
python/prompt_bench/container.py
Normal file
81
python/prompt_bench/container.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Docker container lifecycle management for vLLM."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CONTAINER_NAME = "vllm-bench"
|
||||
VLLM_IMAGE = "vllm/vllm-openai:v0.8.5"
|
||||
|
||||
|
||||
def start_vllm(
|
||||
*,
|
||||
model: str,
|
||||
port: int,
|
||||
model_dir: str,
|
||||
gpu_memory_utilization: float,
|
||||
) -> None:
|
||||
"""Start a vLLM container serving the given model.
|
||||
|
||||
Args:
|
||||
model: HuggingFace model directory name (relative to model_dir).
|
||||
port: Host port to bind.
|
||||
model_dir: Host path containing HuggingFace model directories.
|
||||
gpu_memory_utilization: Fraction of GPU memory to use (0-1).
|
||||
"""
|
||||
command = [
|
||||
"docker",
|
||||
"run",
|
||||
"-d",
|
||||
"--name",
|
||||
CONTAINER_NAME,
|
||||
"--device=nvidia.com/gpu=all",
|
||||
"--ipc=host",
|
||||
"-v",
|
||||
f"{model_dir}:/models",
|
||||
"-p",
|
||||
f"{port}:8000",
|
||||
VLLM_IMAGE,
|
||||
"--model",
|
||||
f"/models/{model}",
|
||||
"--served-model-name",
|
||||
model,
|
||||
"--gpu-memory-utilization",
|
||||
str(gpu_memory_utilization),
|
||||
"--max-model-len",
|
||||
"4096",
|
||||
]
|
||||
logger.info("Starting vLLM container with model: %s", model)
|
||||
result = subprocess.run(command, capture_output=True, text=True, check=False)
|
||||
if result.returncode != 0:
|
||||
msg = f"Failed to start vLLM container: {result.stderr.strip()}"
|
||||
raise RuntimeError(msg)
|
||||
logger.info("vLLM container started: %s", result.stdout.strip()[:12])
|
||||
|
||||
|
||||
def stop_vllm() -> None:
|
||||
"""Stop and remove the vLLM benchmark container."""
|
||||
logger.info("Stopping vLLM container")
|
||||
subprocess.run(["docker", "stop", CONTAINER_NAME], capture_output=True, check=False)
|
||||
subprocess.run(["docker", "rm", CONTAINER_NAME], capture_output=True, check=False)
|
||||
logger.info("vLLM container stopped and removed")
|
||||
|
||||
|
||||
def check_gpu_free() -> None:
|
||||
"""Warn if GPU-heavy processes (e.g. Ollama) are running."""
|
||||
result = subprocess.run(
|
||||
["nvidia-smi", "--query-compute-apps=pid,process_name", "--format=csv,noheader"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning("Could not query GPU processes: %s", result.stderr.strip())
|
||||
return
|
||||
processes = result.stdout.strip()
|
||||
if processes:
|
||||
logger.warning("GPU processes detected:\n%s", processes)
|
||||
logger.warning("Consider stopping Ollama (sudo systemctl stop ollama) before benchmarking")
|
||||
218
python/prompt_bench/main.py
Normal file
218
python/prompt_bench/main.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""CLI entry point for the prompt benchmarking system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import tomllib
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
|
||||
from python.prompt_bench.container import check_gpu_free, start_vllm, stop_vllm
|
||||
from python.prompt_bench.downloader import is_model_present
|
||||
from python.prompt_bench.models import BenchmarkConfig
|
||||
from python.prompt_bench.vllm_client import VLLMClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def discover_prompts(input_dir: Path) -> list[Path]:
|
||||
"""Find all .txt files in the input directory."""
|
||||
prompts = list(input_dir.glob("*.txt"))
|
||||
if not prompts:
|
||||
message = f"No .txt files found in {input_dir}"
|
||||
raise FileNotFoundError(message)
|
||||
return prompts
|
||||
|
||||
|
||||
def _run_prompt(
|
||||
client: VLLMClient,
|
||||
prompt_path: Path,
|
||||
*,
|
||||
repo: str,
|
||||
model_dir_name: str,
|
||||
model_output: Path,
|
||||
temperature: float,
|
||||
) -> tuple[bool, float]:
|
||||
"""Run a single prompt. Returns (success, elapsed_seconds)."""
|
||||
filename = prompt_path.name
|
||||
output_path = model_output / filename
|
||||
start = time.monotonic()
|
||||
try:
|
||||
prompt_text = prompt_path.read_text()
|
||||
response = client.complete(prompt_text, model_dir_name, temperature=temperature)
|
||||
output_path.write_text(response)
|
||||
elapsed = time.monotonic() - start
|
||||
logger.info("Completed: %s / %s in %.2fs", repo, filename, elapsed)
|
||||
except Exception:
|
||||
elapsed = time.monotonic() - start
|
||||
error_path = model_output / f"{filename}.error"
|
||||
logger.exception("Failed: %s / %s after %.2fs", repo, filename, elapsed)
|
||||
error_path.write_text(f"Error processing {filename}")
|
||||
return False, elapsed
|
||||
return True, elapsed
|
||||
|
||||
|
||||
def benchmark_model(
|
||||
client: VLLMClient,
|
||||
prompts: list[Path],
|
||||
*,
|
||||
repo: str,
|
||||
model_dir_name: str,
|
||||
model_output: Path,
|
||||
temperature: float,
|
||||
concurrency: int,
|
||||
) -> tuple[int, int]:
|
||||
"""Run all prompts against a single model in parallel.
|
||||
|
||||
vLLM batches concurrent requests internally, so submitting many at once is
|
||||
significantly faster than running them serially.
|
||||
"""
|
||||
pending = [prompt for prompt in prompts if not (model_output / prompt.name).exists()]
|
||||
skipped = len(prompts) - len(pending)
|
||||
if skipped:
|
||||
logger.info("Skipping %d prompts with existing output for %s", skipped, repo)
|
||||
|
||||
if not pending:
|
||||
logger.info("Nothing to do for %s", repo)
|
||||
return 0, 0
|
||||
|
||||
completed = 0
|
||||
failed = 0
|
||||
latencies: list[float] = []
|
||||
|
||||
wall_start = time.monotonic()
|
||||
with ThreadPoolExecutor(max_workers=concurrency) as executor:
|
||||
futures = [
|
||||
executor.submit(
|
||||
_run_prompt,
|
||||
client,
|
||||
prompt_path,
|
||||
repo=repo,
|
||||
model_dir_name=model_dir_name,
|
||||
model_output=model_output,
|
||||
temperature=temperature,
|
||||
)
|
||||
for prompt_path in pending
|
||||
]
|
||||
for future in as_completed(futures):
|
||||
success, elapsed = future.result()
|
||||
latencies.append(elapsed)
|
||||
if success:
|
||||
completed += 1
|
||||
else:
|
||||
failed += 1
|
||||
wall_elapsed = time.monotonic() - wall_start
|
||||
|
||||
attempted = completed + failed
|
||||
avg_latency = sum(latencies) / attempted
|
||||
throughput = attempted / wall_elapsed if wall_elapsed > 0 else 0.0
|
||||
timing = {
|
||||
"repo": repo,
|
||||
"wall_seconds": wall_elapsed,
|
||||
"attempted": attempted,
|
||||
"completed": completed,
|
||||
"failed": failed,
|
||||
"avg_latency_seconds": avg_latency,
|
||||
"throughput_prompts_per_second": throughput,
|
||||
"concurrency": concurrency,
|
||||
}
|
||||
timing_path = model_output / "_timing.json"
|
||||
timing_path.write_text(json.dumps(timing, indent=2))
|
||||
|
||||
return completed, failed
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
config: BenchmarkConfig,
|
||||
input_dir: Path,
|
||||
output_dir: Path,
|
||||
) -> None:
|
||||
"""Execute the benchmark across all models and prompts."""
|
||||
prompts = discover_prompts(input_dir)
|
||||
logger.info("Found %d prompts in %s", len(prompts), input_dir)
|
||||
|
||||
check_gpu_free()
|
||||
|
||||
total_completed = 0
|
||||
total_failed = 0
|
||||
|
||||
for repo in config.models:
|
||||
if not is_model_present(repo, config.model_dir):
|
||||
logger.warning("Skipping (not downloaded): %s", repo)
|
||||
continue
|
||||
|
||||
model_output = output_dir / repo
|
||||
model_output.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info("=== Benchmarking model: %s ===", repo)
|
||||
|
||||
stop_vllm()
|
||||
try:
|
||||
start_vllm(
|
||||
model=repo,
|
||||
port=config.port,
|
||||
model_dir=config.model_dir,
|
||||
gpu_memory_utilization=config.gpu_memory_utilization,
|
||||
)
|
||||
except RuntimeError:
|
||||
logger.exception("Failed to start vLLM for %s, skipping", repo)
|
||||
continue
|
||||
logger.info("vLLM started for %s", repo)
|
||||
try:
|
||||
with VLLMClient(port=config.port, timeout=config.timeout) as client:
|
||||
client.wait_ready(max_wait=config.vllm_startup_timeout)
|
||||
completed, failed = benchmark_model(
|
||||
client,
|
||||
prompts,
|
||||
repo=repo,
|
||||
model_dir_name=repo,
|
||||
model_output=model_output,
|
||||
temperature=config.temperature,
|
||||
concurrency=config.concurrency,
|
||||
)
|
||||
total_completed += completed
|
||||
total_failed += failed
|
||||
finally:
|
||||
stop_vllm()
|
||||
|
||||
logger.info("=== Benchmark complete ===")
|
||||
logger.info("Completed: %d | Failed: %d", total_completed, total_failed)
|
||||
|
||||
|
||||
def main(
|
||||
input_dir: Annotated[Path, typer.Argument(help="Directory containing input .txt prompt files")],
|
||||
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"),
|
||||
output_dir: Annotated[Path, typer.Option(help="Output directory for results")] = Path("output"),
|
||||
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||
) -> None:
|
||||
"""Run prompts through multiple LLMs via vLLM and save results."""
|
||||
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||
|
||||
if not input_dir.is_dir():
|
||||
message = f"Input directory does not exist: {input_dir}"
|
||||
raise typer.BadParameter(message)
|
||||
if not config.is_file():
|
||||
message = f"Config file does not exist: {config}"
|
||||
raise typer.BadParameter(message)
|
||||
|
||||
with config.open("rb") as file:
|
||||
raw = tomllib.load(file)
|
||||
|
||||
benchmark_config = BenchmarkConfig(**raw)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
run_benchmark(benchmark_config, input_dir, output_dir)
|
||||
|
||||
|
||||
def cli() -> None:
|
||||
"""Typer entry point."""
|
||||
typer.run(main)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
@@ -15,3 +15,4 @@ class BenchmarkConfig(BaseModel):
|
||||
temperature: float = 0.0
|
||||
timeout: int = 300
|
||||
concurrency: int = 4
|
||||
vllm_startup_timeout: int = 900
|
||||
|
||||
68
python/prompt_bench/vllm_client.py
Normal file
68
python/prompt_bench/vllm_client.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""OpenAI-compatible client for vLLM's API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Self
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
READY_POLL_INTERVAL = 2.0
|
||||
|
||||
|
||||
class VLLMClient:
|
||||
"""Talk to a vLLM server via its OpenAI-compatible API.
|
||||
|
||||
Args:
|
||||
host: vLLM host.
|
||||
port: vLLM port.
|
||||
timeout: Per-request timeout in seconds.
|
||||
"""
|
||||
|
||||
def __init__(self, *, host: str = "localhost", port: int = 8000, timeout: int = 300) -> None:
|
||||
"""Create a client connected to a vLLM server."""
|
||||
self._client = httpx.Client(base_url=f"http://{host}:{port}", timeout=timeout)
|
||||
|
||||
def wait_ready(self, max_wait: int) -> None:
|
||||
"""Poll /v1/models until the server is ready or timeout."""
|
||||
deadline = time.monotonic() + max_wait
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
response = self._client.get("/v1/models")
|
||||
if response.is_success:
|
||||
logger.info("vLLM server is ready")
|
||||
return
|
||||
except httpx.TransportError:
|
||||
pass
|
||||
time.sleep(READY_POLL_INTERVAL)
|
||||
msg = f"vLLM server not ready after {max_wait}s"
|
||||
raise TimeoutError(msg)
|
||||
|
||||
def complete(self, prompt: str, model: str, *, temperature: float = 0.0, max_tokens: int = 4096) -> str:
|
||||
"""Send a prompt to /v1/completions and return the response text."""
|
||||
payload = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
logger.info("Sending prompt to %s (%d chars)", model, len(prompt))
|
||||
response = self._client.post("/v1/completions", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data["choices"][0]["text"]
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the HTTP client."""
|
||||
self._client.close()
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
"""Enter the context manager."""
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: object) -> None:
|
||||
"""Close the HTTP client on exit."""
|
||||
self.close()
|
||||
Reference in New Issue
Block a user