feat(ebook-search): add load-test CLI for the search service
Add a Typer CLI script that drives POST /search on a running server at a configurable concurrency and reports latency percentiles (p50/p90/p95/p99), throughput, and HTTP status distribution. Queries are drawn from the shared eval JSONL set so load testing and evaluation exercise the same questions.
This commit is contained in:
@@ -0,0 +1,218 @@
|
||||
"""Load test for the EPUB search service.
|
||||
|
||||
Drives ``POST /search`` on a running server at a configurable concurrency and reports
|
||||
latency percentiles, throughput, and HTTP status distribution. Queries are drawn from
|
||||
the shared JSONL set (see ``eval/data/queries.jsonl``) that the eval also uses, so load
|
||||
and evaluation exercise the same questions. Answer generation and reranking happen
|
||||
server-side, so this exercises the full retrieval pipeline.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
import statistics
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
import httpx
|
||||
import typer
|
||||
|
||||
from python.common import configure_logger
|
||||
from python.ebook_search.eval.dataset import DEFAULT_QUERIES_PATH, load_gold_queries
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RequestResult:
|
||||
"""Outcome of a single search request."""
|
||||
|
||||
status_code: int
|
||||
latency_ms: float
|
||||
ok: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LoadSummary:
|
||||
"""Aggregate results of a load test run."""
|
||||
|
||||
total: int
|
||||
successes: int
|
||||
failures: int
|
||||
wall_seconds: float
|
||||
throughput_rps: float
|
||||
latency_p50_ms: float
|
||||
latency_p90_ms: float
|
||||
latency_p95_ms: float
|
||||
latency_p99_ms: float
|
||||
latency_mean_ms: float
|
||||
latency_max_ms: float
|
||||
status_counts: dict[int, int]
|
||||
|
||||
|
||||
def load_queries(queries_file: str | None) -> list[str]:
|
||||
"""Return the query strings from the shared JSONL set (or a custom JSONL file)."""
|
||||
path = Path(queries_file) if queries_file else DEFAULT_QUERIES_PATH
|
||||
queries = [gold.query for gold in load_gold_queries(path)]
|
||||
if not queries:
|
||||
msg = f"No queries found in {path}"
|
||||
raise typer.BadParameter(msg)
|
||||
return queries
|
||||
|
||||
|
||||
def pick_query(queries: list[str]) -> str:
|
||||
"""Return a uniformly random query from the pool (not a security context)."""
|
||||
return random.choice(queries) # noqa: S311 load-test query sampling is not security-sensitive
|
||||
|
||||
|
||||
def percentile(values_sorted: list[float], pct: float) -> float:
|
||||
"""Return the linearly-interpolated percentile of a sorted list."""
|
||||
if not values_sorted:
|
||||
return 0.0
|
||||
rank = (pct / 100) * (len(values_sorted) - 1)
|
||||
low = math.floor(rank)
|
||||
high = math.ceil(rank)
|
||||
if low == high:
|
||||
return values_sorted[low]
|
||||
return values_sorted[low] + (values_sorted[high] - values_sorted[low]) * (rank - low)
|
||||
|
||||
|
||||
def summarize(results: list[RequestResult], wall_seconds: float) -> LoadSummary:
|
||||
"""Aggregate per-request results into a load summary."""
|
||||
latencies = sorted(result.latency_ms for result in results)
|
||||
successes = sum(1 for result in results if result.ok)
|
||||
status_counts: dict[int, int] = {}
|
||||
for result in results:
|
||||
status_counts[result.status_code] = status_counts.get(result.status_code, 0) + 1
|
||||
return LoadSummary(
|
||||
total=len(results),
|
||||
successes=successes,
|
||||
failures=len(results) - successes,
|
||||
wall_seconds=wall_seconds,
|
||||
throughput_rps=len(results) / wall_seconds if wall_seconds > 0 else 0.0,
|
||||
latency_p50_ms=percentile(latencies, 50),
|
||||
latency_p90_ms=percentile(latencies, 90),
|
||||
latency_p95_ms=percentile(latencies, 95),
|
||||
latency_p99_ms=percentile(latencies, 99),
|
||||
latency_mean_ms=statistics.fmean(latencies) if latencies else 0.0,
|
||||
latency_max_ms=latencies[-1] if latencies else 0.0,
|
||||
status_counts=status_counts,
|
||||
)
|
||||
|
||||
|
||||
async def send_search(client: httpx.AsyncClient, query: str, *, rerank: bool) -> RequestResult:
|
||||
"""Send one search request and record its status and latency."""
|
||||
data = {"query": query, "rerank": "true"} if rerank else {"query": query}
|
||||
start = time.perf_counter()
|
||||
try:
|
||||
response = await client.post("/search", data=data)
|
||||
except httpx.HTTPError as error:
|
||||
logger.warning("ebook_loadtest_request_failed error=%s", error)
|
||||
return RequestResult(status_code=0, latency_ms=(time.perf_counter() - start) * 1000, ok=False)
|
||||
return RequestResult(
|
||||
status_code=response.status_code,
|
||||
latency_ms=(time.perf_counter() - start) * 1000,
|
||||
ok=response.is_success,
|
||||
)
|
||||
|
||||
|
||||
async def worker(
|
||||
client: httpx.AsyncClient,
|
||||
queue: asyncio.Queue[str],
|
||||
results: list[RequestResult],
|
||||
*,
|
||||
rerank: bool,
|
||||
) -> None:
|
||||
"""Pull queries off the queue and send requests until it is empty."""
|
||||
while True:
|
||||
try:
|
||||
query = queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
return
|
||||
results.append(await send_search(client, query, rerank=rerank))
|
||||
|
||||
|
||||
async def run_load(
|
||||
*,
|
||||
base_url: str,
|
||||
queries: list[str],
|
||||
request_count: int,
|
||||
concurrency: int,
|
||||
rerank: bool,
|
||||
warmup: int,
|
||||
timeout_seconds: float,
|
||||
) -> LoadSummary:
|
||||
"""Run the load test and return its aggregate summary."""
|
||||
limits = httpx.Limits(max_connections=concurrency, max_keepalive_connections=concurrency)
|
||||
async with httpx.AsyncClient(base_url=base_url, timeout=timeout_seconds, limits=limits) as client:
|
||||
for _ in range(warmup):
|
||||
await send_search(client, pick_query(queries), rerank=rerank)
|
||||
|
||||
queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
for _ in range(request_count):
|
||||
queue.put_nowait(pick_query(queries))
|
||||
|
||||
results: list[RequestResult] = []
|
||||
start = time.perf_counter()
|
||||
workers = [asyncio.create_task(worker(client, queue, results, rerank=rerank)) for _ in range(concurrency)]
|
||||
await asyncio.gather(*workers)
|
||||
wall_seconds = time.perf_counter() - start
|
||||
return summarize(results, wall_seconds)
|
||||
|
||||
|
||||
def print_summary(summary: LoadSummary) -> None:
|
||||
"""Print the load summary to stdout."""
|
||||
typer.echo(f"requests={summary.total} successes={summary.successes} failures={summary.failures}")
|
||||
typer.echo(f"wall={summary.wall_seconds:.2f}s throughput={summary.throughput_rps:.1f} req/s")
|
||||
typer.echo(
|
||||
f"latency_ms p50={summary.latency_p50_ms:.1f} p90={summary.latency_p90_ms:.1f} "
|
||||
f"p95={summary.latency_p95_ms:.1f} p99={summary.latency_p99_ms:.1f} "
|
||||
f"mean={summary.latency_mean_ms:.1f} max={summary.latency_max_ms:.1f}"
|
||||
)
|
||||
status_summary = " ".join(f"{code}={count}" for code, count in sorted(summary.status_counts.items()))
|
||||
typer.echo(f"status {status_summary}")
|
||||
|
||||
|
||||
def main(
|
||||
*,
|
||||
base_url: Annotated[str, typer.Option(help="Base URL of the running service")] = "http://127.0.0.1:8070",
|
||||
request_count: Annotated[int, typer.Option("--requests", help="Total requests to send")] = 200,
|
||||
concurrency: Annotated[int, typer.Option(help="Concurrent in-flight requests")] = 10,
|
||||
rerank: Annotated[bool, typer.Option(help="Request server-side reranking")] = False,
|
||||
warmup: Annotated[int, typer.Option(help="Warmup requests, not measured")] = 5,
|
||||
timeout_seconds: Annotated[float, typer.Option("--timeout", help="Per-request timeout seconds")] = 120.0,
|
||||
queries_file: Annotated[str | None, typer.Option(help="Query JSONL file (defaults to the shared set)")] = None,
|
||||
log_level: Annotated[str, typer.Option(help="Log level")] = "WARNING",
|
||||
) -> None:
|
||||
"""Load test the search endpoint and report latency and throughput."""
|
||||
configure_logger(log_level)
|
||||
queries = load_queries(queries_file)
|
||||
logger.info(
|
||||
"ebook_loadtest_start base_url=%s requests=%s concurrency=%s rerank=%s queries=%s",
|
||||
base_url,
|
||||
request_count,
|
||||
concurrency,
|
||||
rerank,
|
||||
len(queries),
|
||||
)
|
||||
summary = asyncio.run(
|
||||
run_load(
|
||||
base_url=base_url,
|
||||
queries=queries,
|
||||
request_count=request_count,
|
||||
concurrency=concurrency,
|
||||
rerank=rerank,
|
||||
warmup=warmup,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
)
|
||||
print_summary(summary)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
typer.run(main)
|
||||
Reference in New Issue
Block a user