diff --git a/python/api/main.py b/python/api/main.py index 584ad78..3ac65ba 100644 --- a/python/api/main.py +++ b/python/api/main.py @@ -9,6 +9,7 @@ import typer import uvicorn from fastapi import FastAPI +from python.api.middleware import ZstdMiddleware from python.api.routers import contact_router, views_router from python.common import configure_logger from python.orm.common import get_postgres_engine @@ -27,6 +28,7 @@ def create_app() -> FastAPI: app.state.engine.dispose() app = FastAPI(title="Contact Database API", lifespan=lifespan) + app.add_middleware(ZstdMiddleware) app.include_router(contact_router) app.include_router(views_router) diff --git a/python/api/middleware.py b/python/api/middleware.py new file mode 100644 index 0000000..f710a66 --- /dev/null +++ b/python/api/middleware.py @@ -0,0 +1,49 @@ +"""Middleware for the FastAPI application.""" + +from compression import zstd +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.requests import Request +from starlette.responses import Response + +MINIMUM_RESPONSE_SIZE = 500 + + +class ZstdMiddleware(BaseHTTPMiddleware): + """Middleware that compresses responses with zstd when the client supports it.""" + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + """Compress the response with zstd if the client accepts it.""" + accepted_encodings = request.headers.get("accept-encoding", "") + if "zstd" not in accepted_encodings: + return await call_next(request) + + response = await call_next(request) + + if response.headers.get("content-encoding") or "text/event-stream" in response.headers.get("content-type", ""): + return response + + body = b"" + async for chunk in response.body_iterator: + body += chunk if isinstance(chunk, bytes) else chunk.encode() + + if len(body) < MINIMUM_RESPONSE_SIZE: + return Response( + content=body, + status_code=response.status_code, + headers=dict(response.headers), + media_type=response.media_type, + ) + + compressed = zstd.compress(body) + + headers = dict(response.headers) + headers["content-encoding"] = "zstd" + headers["content-length"] = str(len(compressed)) + headers.pop("transfer-encoding", None) + + return Response( + content=compressed, + status_code=response.status_code, + headers=headers, + media_type=response.media_type, + )