made fastapi tools

This commit is contained in:
2026-06-12 14:45:10 -04:00
parent c5418b50fd
commit 479191050e
6 changed files with 10 additions and 4 deletions
+6
View File
@@ -0,0 +1,6 @@
"""Reusable FastAPI tools."""
from python.fastapi_tools.db import DbSession, get_db
from python.fastapi_tools.zstd_middleware import ZstdMiddleware
__all__ = ["DbSession", "ZstdMiddleware", "get_db"]
+16
View File
@@ -0,0 +1,16 @@
"""FastAPI dependencies."""
from collections.abc import Iterator
from typing import Annotated
from fastapi import Depends, Request
from sqlalchemy.orm import Session
def get_db(request: Request) -> Iterator[Session]:
"""Get database session from app state."""
with Session(request.app.state.engine) as session:
yield session
DbSession = Annotated[Session, Depends(get_db)]
+49
View File
@@ -0,0 +1,49 @@
"""Zstd response compression middleware."""
from compression import zstd
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
MINIMUM_RESPONSE_SIZE = 500
class ZstdMiddleware(BaseHTTPMiddleware):
"""Middleware that compresses responses with zstd when the client supports it."""
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
"""Compress the response with zstd if the client accepts it."""
accepted_encodings = request.headers.get("accept-encoding", "")
if "zstd" not in accepted_encodings:
return await call_next(request)
response = await call_next(request)
if response.headers.get("content-encoding") or "text/event-stream" in response.headers.get("content-type", ""):
return response
body = b""
async for chunk in response.body_iterator:
body += chunk if isinstance(chunk, bytes) else chunk.encode()
if len(body) < MINIMUM_RESPONSE_SIZE:
return Response(
content=body,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
compressed = zstd.compress(body)
headers = dict(response.headers)
headers["content-encoding"] = "zstd"
headers["content-length"] = str(len(compressed))
headers.pop("transfer-encoding", None)
return Response(
content=compressed,
status_code=response.status_code,
headers=headers,
media_type=response.media_type,
)