mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 13:08:19 -04:00
adding zstd compression to fastapi
This commit is contained in:
@@ -9,6 +9,7 @@ import typer
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
|
||||
from python.api.middleware import ZstdMiddleware
|
||||
from python.api.routers import contact_router, views_router
|
||||
from python.common import configure_logger
|
||||
from python.orm.common import get_postgres_engine
|
||||
@@ -27,6 +28,7 @@ def create_app() -> FastAPI:
|
||||
app.state.engine.dispose()
|
||||
|
||||
app = FastAPI(title="Contact Database API", lifespan=lifespan)
|
||||
app.add_middleware(ZstdMiddleware)
|
||||
|
||||
app.include_router(contact_router)
|
||||
app.include_router(views_router)
|
||||
|
||||
49
python/api/middleware.py
Normal file
49
python/api/middleware.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Middleware for the FastAPI application."""
|
||||
|
||||
from compression import zstd
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
MINIMUM_RESPONSE_SIZE = 500
|
||||
|
||||
|
||||
class ZstdMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware that compresses responses with zstd when the client supports it."""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|
||||
"""Compress the response with zstd if the client accepts it."""
|
||||
accepted_encodings = request.headers.get("accept-encoding", "")
|
||||
if "zstd" not in accepted_encodings:
|
||||
return await call_next(request)
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
if response.headers.get("content-encoding") or "text/event-stream" in response.headers.get("content-type", ""):
|
||||
return response
|
||||
|
||||
body = b""
|
||||
async for chunk in response.body_iterator:
|
||||
body += chunk if isinstance(chunk, bytes) else chunk.encode()
|
||||
|
||||
if len(body) < MINIMUM_RESPONSE_SIZE:
|
||||
return Response(
|
||||
content=body,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type,
|
||||
)
|
||||
|
||||
compressed = zstd.compress(body)
|
||||
|
||||
headers = dict(response.headers)
|
||||
headers["content-encoding"] = "zstd"
|
||||
headers["content-length"] = str(len(compressed))
|
||||
headers.pop("transfer-encoding", None)
|
||||
|
||||
return Response(
|
||||
content=compressed,
|
||||
status_code=response.status_code,
|
||||
headers=headers,
|
||||
media_type=response.media_type,
|
||||
)
|
||||
Reference in New Issue
Block a user