mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 13:08:19 -04:00
50 lines
1.7 KiB
Python
50 lines
1.7 KiB
Python
"""Middleware for the FastAPI application."""
|
|
|
|
from compression import zstd
|
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
|
|
# Response bodies shorter than this many bytes are returned without zstd compression.
MINIMUM_RESPONSE_SIZE = 500
|
|
|
|
|
|
class ZstdMiddleware(BaseHTTPMiddleware):
    """Middleware that compresses responses with zstd when the client supports it.

    A response is compressed only when all of the following hold:

    * the request's ``Accept-Encoding`` lists ``zstd`` with a non-zero weight,
    * the response is not already encoded and is not a ``text/event-stream``,
    * the buffered body is at least ``MINIMUM_RESPONSE_SIZE`` bytes long.
    """

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Compress the response with zstd if the client accepts it.

        Args:
            request: The incoming request; only its ``Accept-Encoding`` header is read.
            call_next: Callable that forwards the request down the stack and returns the response.

        Returns:
            The downstream response untouched when compression does not apply, otherwise a
            buffered copy of it (zstd-compressed when the body is large enough).
        """
        response = await call_next(request)

        if not self._accepts_zstd(request.headers.get("accept-encoding", "")):
            return response

        # Leave already-encoded payloads and server-sent event streams alone:
        # the former must not be encoded twice and the latter must not be buffered.
        if response.headers.get("content-encoding") or "text/event-stream" in response.headers.get("content-type", ""):
            return response

        # Join once instead of ``bytes +=`` per chunk, which copies the buffer every iteration.
        body = b"".join(
            [chunk if isinstance(chunk, bytes) else chunk.encode() async for chunk in response.body_iterator]
        )

        if len(body) < MINIMUM_RESPONSE_SIZE:
            return self._rebuild(response, body)

        compressed = zstd.compress(body)
        rebuilt = self._rebuild(response, compressed)

        headers = rebuilt.headers
        headers["content-encoding"] = "zstd"
        headers["content-length"] = str(len(compressed))
        del headers["transfer-encoding"]  # no-op when the header is absent
        # The body now depends on the request's Accept-Encoding, so caches must key on it.
        headers.add_vary_header("Accept-Encoding")
        return rebuilt

    @staticmethod
    def _accepts_zstd(accept_encoding: str) -> bool:
        """Return True when *accept_encoding* lists ``zstd`` with a non-zero weight."""
        for entry in accept_encoding.split(","):
            coding, _, parameter = entry.partition(";")
            if coding.strip().lower() != "zstd":
                continue
            name, _, weight = parameter.partition("=")
            if name.strip().lower() != "q":
                return True
            try:
                return float(weight) > 0
            except ValueError:
                # Malformed weight: stay lenient and treat the coding as acceptable.
                return True
        return False

    @staticmethod
    def _rebuild(original: Response, payload: bytes) -> Response:
        """Return a buffered response carrying *payload* and every header of *original*.

        Headers are copied from the raw header list rather than through ``dict`` so that
        repeated fields (for example several ``Set-Cookie`` lines) are all preserved.
        """
        rebuilt = Response(content=payload, status_code=original.status_code, media_type=original.media_type)
        # Headers Response() derived on its own (content-length, and content-type when a
        # media type is set); they are only used to fill gaps in the original header set.
        generated = rebuilt.raw_headers
        rebuilt.raw_headers = list(original.raw_headers)
        for key, value in generated:
            rebuilt.headers.setdefault(key.decode("latin-1"), value.decode("latin-1"))
        return rebuilt
|