diff --git a/python/prompt_bench/container.py b/python/prompt_bench/container.py new file mode 100644 index 0000000..dc73fcc --- /dev/null +++ b/python/prompt_bench/container.py @@ -0,0 +1,81 @@ +"""Docker container lifecycle management for vLLM.""" + +from __future__ import annotations + +import logging +import subprocess + +logger = logging.getLogger(__name__) + +CONTAINER_NAME = "vllm-bench" +VLLM_IMAGE = "vllm/vllm-openai:v0.8.5" + + +def start_vllm( + *, + model: str, + port: int, + model_dir: str, + gpu_memory_utilization: float, +) -> None: + """Start a vLLM container serving the given model. + + Args: + model: HuggingFace model directory name (relative to model_dir). + port: Host port to bind. + model_dir: Host path containing HuggingFace model directories. + gpu_memory_utilization: Fraction of GPU memory to use (0-1). + """ + command = [ + "docker", + "run", + "-d", + "--name", + CONTAINER_NAME, + "--device=nvidia.com/gpu=all", + "--ipc=host", + "-v", + f"{model_dir}:/models", + "-p", + f"{port}:8000", + VLLM_IMAGE, + "--model", + f"/models/{model}", + "--served-model-name", + model, + "--gpu-memory-utilization", + str(gpu_memory_utilization), + "--max-model-len", + "4096", + ] + logger.info("Starting vLLM container with model: %s", model) + result = subprocess.run(command, capture_output=True, text=True, check=False) + if result.returncode != 0: + msg = f"Failed to start vLLM container: {result.stderr.strip()}" + raise RuntimeError(msg) + logger.info("vLLM container started: %s", result.stdout.strip()[:12]) + + +def stop_vllm() -> None: + """Stop and remove the vLLM benchmark container.""" + logger.info("Stopping vLLM container") + subprocess.run(["docker", "stop", CONTAINER_NAME], capture_output=True, check=False) + subprocess.run(["docker", "rm", CONTAINER_NAME], capture_output=True, check=False) + logger.info("vLLM container stopped and removed") + + +def check_gpu_free() -> None: + """Warn if GPU-heavy processes (e.g. Ollama) are running.""" + result = subprocess.run( + ["nvidia-smi", "--query-compute-apps=pid,process_name", "--format=csv,noheader"], + capture_output=True, + text=True, + check=False, + ) + if result.returncode != 0: + logger.warning("Could not query GPU processes: %s", result.stderr.strip()) + return + processes = result.stdout.strip() + if processes: + logger.warning("GPU processes detected:\n%s", processes) + logger.warning("Consider stopping Ollama (sudo systemctl stop ollama) before benchmarking") diff --git a/python/prompt_bench/main.py b/python/prompt_bench/main.py new file mode 100644 index 0000000..0d39e69 --- /dev/null +++ b/python/prompt_bench/main.py @@ -0,0 +1,218 @@ +"""CLI entry point for the prompt benchmarking system.""" + +from __future__ import annotations + +import json +import logging +import time +import tomllib +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Annotated + +import typer + +from python.prompt_bench.container import check_gpu_free, start_vllm, stop_vllm +from python.prompt_bench.downloader import is_model_present +from python.prompt_bench.models import BenchmarkConfig +from python.prompt_bench.vllm_client import VLLMClient + +logger = logging.getLogger(__name__) + + +def discover_prompts(input_dir: Path) -> list[Path]: + """Find all .txt files in the input directory.""" + prompts = list(input_dir.glob("*.txt")) + if not prompts: + message = f"No .txt files found in {input_dir}" + raise FileNotFoundError(message) + return prompts + + +def _run_prompt( + client: VLLMClient, + prompt_path: Path, + *, + repo: str, + model_dir_name: str, + model_output: Path, + temperature: float, +) -> tuple[bool, float]: + """Run a single prompt. Returns (success, elapsed_seconds).""" + filename = prompt_path.name + output_path = model_output / filename + start = time.monotonic() + try: + prompt_text = prompt_path.read_text() + response = client.complete(prompt_text, model_dir_name, temperature=temperature) + output_path.write_text(response) + elapsed = time.monotonic() - start + logger.info("Completed: %s / %s in %.2fs", repo, filename, elapsed) + except Exception: + elapsed = time.monotonic() - start + error_path = model_output / f"{filename}.error" + logger.exception("Failed: %s / %s after %.2fs", repo, filename, elapsed) + error_path.write_text(f"Error processing {filename}") + return False, elapsed + return True, elapsed + + +def benchmark_model( + client: VLLMClient, + prompts: list[Path], + *, + repo: str, + model_dir_name: str, + model_output: Path, + temperature: float, + concurrency: int, +) -> tuple[int, int]: + """Run all prompts against a single model in parallel. + + vLLM batches concurrent requests internally, so submitting many at once is + significantly faster than running them serially. + """ + pending = [prompt for prompt in prompts if not (model_output / prompt.name).exists()] + skipped = len(prompts) - len(pending) + if skipped: + logger.info("Skipping %d prompts with existing output for %s", skipped, repo) + + if not pending: + logger.info("Nothing to do for %s", repo) + return 0, 0 + + completed = 0 + failed = 0 + latencies: list[float] = [] + + wall_start = time.monotonic() + with ThreadPoolExecutor(max_workers=concurrency) as executor: + futures = [ + executor.submit( + _run_prompt, + client, + prompt_path, + repo=repo, + model_dir_name=model_dir_name, + model_output=model_output, + temperature=temperature, + ) + for prompt_path in pending + ] + for future in as_completed(futures): + success, elapsed = future.result() + latencies.append(elapsed) + if success: + completed += 1 + else: + failed += 1 + wall_elapsed = time.monotonic() - wall_start + + attempted = completed + failed + avg_latency = sum(latencies) / attempted + throughput = attempted / wall_elapsed if wall_elapsed > 0 else 0.0 + timing = { + "repo": repo, + "wall_seconds": wall_elapsed, + "attempted": attempted, + "completed": completed, + "failed": failed, + "avg_latency_seconds": avg_latency, + "throughput_prompts_per_second": throughput, + "concurrency": concurrency, + } + timing_path = model_output / "_timing.json" + timing_path.write_text(json.dumps(timing, indent=2)) + + return completed, failed + + +def run_benchmark( + config: BenchmarkConfig, + input_dir: Path, + output_dir: Path, +) -> None: + """Execute the benchmark across all models and prompts.""" + prompts = discover_prompts(input_dir) + logger.info("Found %d prompts in %s", len(prompts), input_dir) + + check_gpu_free() + + total_completed = 0 + total_failed = 0 + + for repo in config.models: + if not is_model_present(repo, config.model_dir): + logger.warning("Skipping (not downloaded): %s", repo) + continue + + model_output = output_dir / repo + model_output.mkdir(parents=True, exist_ok=True) + + logger.info("=== Benchmarking model: %s ===", repo) + + stop_vllm() + try: + start_vllm( + model=repo, + port=config.port, + model_dir=config.model_dir, + gpu_memory_utilization=config.gpu_memory_utilization, + ) + except RuntimeError: + logger.exception("Failed to start vLLM for %s, skipping", repo) + continue + logger.info("vLLM started for %s", repo) + try: + with VLLMClient(port=config.port, timeout=config.timeout) as client: + client.wait_ready(max_wait=config.vllm_startup_timeout) + completed, failed = benchmark_model( + client, + prompts, + repo=repo, + model_dir_name=repo, + model_output=model_output, + temperature=config.temperature, + concurrency=config.concurrency, + ) + total_completed += completed + total_failed += failed + finally: + stop_vllm() + + logger.info("=== Benchmark complete ===") + logger.info("Completed: %d | Failed: %d", total_completed, total_failed) + + +def main( + input_dir: Annotated[Path, typer.Argument(help="Directory containing input .txt prompt files")], + config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"), + output_dir: Annotated[Path, typer.Option(help="Output directory for results")] = Path("output"), + log_level: Annotated[str, typer.Option(help="Log level")] = "INFO", +) -> None: + """Run prompts through multiple LLMs via vLLM and save results.""" + logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s") + + if not input_dir.is_dir(): + message = f"Input directory does not exist: {input_dir}" + raise typer.BadParameter(message) + if not config.is_file(): + message = f"Config file does not exist: {config}" + raise typer.BadParameter(message) + + with config.open("rb") as file: + raw = tomllib.load(file) + + benchmark_config = BenchmarkConfig(**raw) + output_dir.mkdir(parents=True, exist_ok=True) + + run_benchmark(benchmark_config, input_dir, output_dir) + + +def cli() -> None: + """Typer entry point.""" + typer.run(main) + + +if __name__ == "__main__": + cli() diff --git a/python/prompt_bench/models.py b/python/prompt_bench/models.py index f07084b..1abae25 100644 --- a/python/prompt_bench/models.py +++ b/python/prompt_bench/models.py @@ -15,3 +15,4 @@ class BenchmarkConfig(BaseModel): temperature: float = 0.0 timeout: int = 300 concurrency: int = 4 + vllm_startup_timeout: int = 900 diff --git a/python/prompt_bench/vllm_client.py b/python/prompt_bench/vllm_client.py new file mode 100644 index 0000000..b7d9045 --- /dev/null +++ b/python/prompt_bench/vllm_client.py @@ -0,0 +1,68 @@ +"""OpenAI-compatible client for vLLM's API.""" + +from __future__ import annotations + +import logging +import time +from typing import Self + +import httpx + +logger = logging.getLogger(__name__) + +READY_POLL_INTERVAL = 2.0 + + +class VLLMClient: + """Talk to a vLLM server via its OpenAI-compatible API. + + Args: + host: vLLM host. + port: vLLM port. + timeout: Per-request timeout in seconds. + """ + + def __init__(self, *, host: str = "localhost", port: int = 8000, timeout: int = 300) -> None: + """Create a client connected to a vLLM server.""" + self._client = httpx.Client(base_url=f"http://{host}:{port}", timeout=timeout) + + def wait_ready(self, max_wait: int) -> None: + """Poll /v1/models until the server is ready or timeout.""" + deadline = time.monotonic() + max_wait + while time.monotonic() < deadline: + try: + response = self._client.get("/v1/models") + if response.is_success: + logger.info("vLLM server is ready") + return + except httpx.TransportError: + pass + time.sleep(READY_POLL_INTERVAL) + msg = f"vLLM server not ready after {max_wait}s" + raise TimeoutError(msg) + + def complete(self, prompt: str, model: str, *, temperature: float = 0.0, max_tokens: int = 4096) -> str: + """Send a prompt to /v1/completions and return the response text.""" + payload = { + "model": model, + "prompt": prompt, + "temperature": temperature, + "max_tokens": max_tokens, + } + logger.info("Sending prompt to %s (%d chars)", model, len(prompt)) + response = self._client.post("/v1/completions", json=payload) + response.raise_for_status() + data = response.json() + return data["choices"][0]["text"] + + def close(self) -> None: + """Close the HTTP client.""" + self._client.close() + + def __enter__(self) -> Self: + """Enter the context manager.""" + return self + + def __exit__(self, *args: object) -> None: + """Close the HTTP client on exit.""" + self.close()