reworded fastapi code

This commit is contained in:
2026-01-20 10:22:04 -05:00
parent cf4635922e
commit 258f918794
8 changed files with 339 additions and 106 deletions

View File

@@ -0,0 +1,16 @@
"""FastAPI dependencies."""
from collections.abc import Iterator
from typing import Annotated
from fastapi import Depends, Request
from sqlalchemy.orm import Session
def get_db(request: Request) -> Iterator[Session]:
"""Get database session from app state."""
with Session(request.app.state.engine) as session:
yield session
DbSession = Annotated[Session, Depends(get_db)]

113
python/api/main.py Normal file
View File

@@ -0,0 +1,113 @@
"""FastAPI interface for Contact database."""
import shutil
import subprocess
import tempfile
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from os import environ
from pathlib import Path
from typing import Annotated
import typer
import uvicorn
from fastapi import FastAPI
from python.api.routers import contact_router, create_frontend_router
from python.orm.base import get_postgres_engine
def create_app(frontend_dir: Path | None = None) -> FastAPI:
"""Create and configure the FastAPI application."""
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
"""Manage application lifespan."""
app.state.engine = get_postgres_engine()
yield
app.state.engine.dispose()
app = FastAPI(title="Contact Database API", lifespan=lifespan)
app.include_router(contact_router)
if frontend_dir:
print(f"Serving frontend from {frontend_dir}")
frontend_router = create_frontend_router(frontend_dir)
app.include_router(frontend_router)
return app
cli = typer.Typer()
def build_frontend(source_dir: Path | None, cache_dir: Path | None = None) -> Path | None:
"""Run npm build and copy output to a temp directory.
Works even if source_dir is read-only by copying to a temp directory first.
Args:
source_dir: Frontend source directory.
cache_dir: Optional npm cache directory for faster repeated builds.
Returns:
Path to frontend build directory, or None if no source_dir provided.
"""
if not source_dir:
return None
if not source_dir.exists():
error = f"Error: Frontend directory {source_dir} does not exist"
raise FileExistsError(error)
print(f"Building frontend from {source_dir}...")
# Copy source to a writable temp directory
build_dir = Path(tempfile.mkdtemp(prefix="contact_frontend_build_"))
shutil.copytree(source_dir, build_dir, dirs_exist_ok=True)
env = dict(environ)
if cache_dir:
cache_dir.mkdir(parents=True, exist_ok=True)
env["npm_config_cache"] = str(cache_dir)
subprocess.run(["npm", "install"], cwd=build_dir, env=env, check=True)
subprocess.run(["npm", "run", "build"], cwd=build_dir, env=env, check=True)
dist_dir = build_dir / "dist"
if not dist_dir.exists():
error = f"Build output not found at {dist_dir}"
raise FileNotFoundError(error)
output_dir = Path(tempfile.mkdtemp(prefix="contact_frontend_"))
shutil.copytree(dist_dir, output_dir, dirs_exist_ok=True)
print(f"Frontend built and copied to {output_dir}")
shutil.rmtree(build_dir)
return output_dir
@cli.command()
def serve(
frontend_dir: Annotated[
Path | None,
typer.Option(
"--frontend-dir",
"-f",
help="Frontend source directory. If provided, runs npm build and serves from temp dir.",
),
] = None,
host: Annotated[str, typer.Option("--host", "-h", help="Host to bind to")] = "0.0.0.0",
port: Annotated[int, typer.Option("--port", "-p", help="Port to bind to")] = 8000,
) -> None:
"""Start the Contact API server."""
serve_dir = build_frontend(frontend_dir)
app = create_app(frontend_dir=serve_dir)
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
cli()

View File

@@ -0,0 +1,6 @@
"""API routers."""
from python.api.routers.contact import router as contact_router
from python.api.routers.frontend import create_frontend_router
__all__ = ["contact_router", "create_frontend_router"]

View File

@@ -1,23 +1,13 @@
"""FastAPI interface for Contact database."""
"""Contact API router."""
from collections.abc import AsyncIterator, Iterator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Annotated
from fastapi import Depends, FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, selectinload
from sqlalchemy.orm import selectinload
from python.orm.base import get_postgres_engine
from python.api.dependencies import DbSession
from python.orm.contact import Contact, ContactRelationship, Need, RelationshipType
FRONTEND_DIR = Path(__file__).parent.parent.parent / "frontend" / "dist"
class NeedBase(BaseModel):
"""Base schema for Need."""
@@ -162,55 +152,10 @@ class ContactListResponse(ContactBase):
model_config = {"from_attributes": True}
class DatabaseSession:
"""Database session manager."""
def __init__(self) -> None:
"""Initialize with no engine."""
self._engine: Engine | None = None
@property
def engine(self) -> Engine:
"""Get or create the database engine."""
if self._engine is None:
self._engine = get_postgres_engine()
return self._engine
def get_session(self) -> Iterator[Session]:
"""Yield a database session."""
with Session(self.engine) as session:
yield session
def dispose(self) -> None:
"""Dispose of the engine."""
if self._engine is not None:
self._engine.dispose()
self._engine = None
router = APIRouter(prefix="/api", tags=["contacts"])
db_manager = DatabaseSession()
def get_db() -> Iterator[Session]:
"""Get database session dependency."""
yield from db_manager.get_session()
DbSession = Annotated[Session, Depends(get_db)]
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
"""Manage application lifespan."""
yield
db_manager.dispose()
app = FastAPI(title="Contact Database API", lifespan=lifespan)
# API routes
@app.post("/api/needs", response_model=NeedResponse)
@router.post("/needs", response_model=NeedResponse)
def create_need(need: NeedCreate, db: DbSession) -> Need:
"""Create a new need."""
db_need = Need(name=need.name, description=need.description)
@@ -220,13 +165,13 @@ def create_need(need: NeedCreate, db: DbSession) -> Need:
return db_need
@app.get("/api/needs", response_model=list[NeedResponse])
@router.get("/needs", response_model=list[NeedResponse])
def list_needs(db: DbSession) -> list[Need]:
"""List all needs."""
return list(db.scalars(select(Need)).all())
@app.get("/api/needs/{need_id}", response_model=NeedResponse)
@router.get("/needs/{need_id}", response_model=NeedResponse)
def get_need(need_id: int, db: DbSession) -> Need:
"""Get a need by ID."""
need = db.get(Need, need_id)
@@ -235,7 +180,7 @@ def get_need(need_id: int, db: DbSession) -> Need:
return need
@app.delete("/api/needs/{need_id}")
@router.delete("/needs/{need_id}")
def delete_need(need_id: int, db: DbSession) -> dict[str, bool]:
"""Delete a need by ID."""
need = db.get(Need, need_id)
@@ -246,7 +191,7 @@ def delete_need(need_id: int, db: DbSession) -> dict[str, bool]:
return {"deleted": True}
@app.post("/api/contacts", response_model=ContactResponse)
@router.post("/contacts", response_model=ContactResponse)
def create_contact(contact: ContactCreate, db: DbSession) -> Contact:
"""Create a new contact."""
need_ids = contact.need_ids
@@ -263,7 +208,7 @@ def create_contact(contact: ContactCreate, db: DbSession) -> Contact:
return db_contact
@app.get("/api/contacts", response_model=list[ContactListResponse])
@router.get("/contacts", response_model=list[ContactListResponse])
def list_contacts(
db: DbSession,
skip: int = 0,
@@ -273,7 +218,7 @@ def list_contacts(
return list(db.scalars(select(Contact).offset(skip).limit(limit)).all())
@app.get("/api/contacts/{contact_id}", response_model=ContactResponse)
@router.get("/contacts/{contact_id}", response_model=ContactResponse)
def get_contact(contact_id: int, db: DbSession) -> Contact:
"""Get a contact by ID with all relationships."""
contact = db.scalar(
@@ -290,7 +235,7 @@ def get_contact(contact_id: int, db: DbSession) -> Contact:
return contact
@app.patch("/api/contacts/{contact_id}", response_model=ContactResponse)
@router.patch("/contacts/{contact_id}", response_model=ContactResponse)
def update_contact(
contact_id: int,
contact: ContactUpdate,
@@ -316,7 +261,7 @@ def update_contact(
return db_contact
@app.delete("/api/contacts/{contact_id}")
@router.delete("/contacts/{contact_id}")
def delete_contact(contact_id: int, db: DbSession) -> dict[str, bool]:
"""Delete a contact by ID."""
contact = db.get(Contact, contact_id)
@@ -327,7 +272,7 @@ def delete_contact(contact_id: int, db: DbSession) -> dict[str, bool]:
return {"deleted": True}
@app.post("/api/contacts/{contact_id}/needs/{need_id}")
@router.post("/contacts/{contact_id}/needs/{need_id}")
def add_need_to_contact(
contact_id: int,
need_id: int,
@@ -349,7 +294,7 @@ def add_need_to_contact(
return {"added": True}
@app.delete("/api/contacts/{contact_id}/needs/{need_id}")
@router.delete("/contacts/{contact_id}/needs/{need_id}")
def remove_need_from_contact(
contact_id: int,
need_id: int,
@@ -371,8 +316,8 @@ def remove_need_from_contact(
return {"removed": True}
@app.post(
"/api/contacts/{contact_id}/relationships",
@router.post(
"/contacts/{contact_id}/relationships",
response_model=ContactRelationshipResponse,
)
def add_contact_relationship(
@@ -409,8 +354,8 @@ def add_contact_relationship(
return db_relationship
@app.get(
"/api/contacts/{contact_id}/relationships",
@router.get(
"/contacts/{contact_id}/relationships",
response_model=list[ContactRelationshipResponse],
)
def get_contact_relationships(
@@ -422,21 +367,15 @@ def get_contact_relationships(
if not contact:
raise HTTPException(status_code=404, detail="Contact not found")
outgoing = list(
db.scalars(
select(ContactRelationship).where(ContactRelationship.contact_id == contact_id)
).all()
)
outgoing = list(db.scalars(select(ContactRelationship).where(ContactRelationship.contact_id == contact_id)).all())
incoming = list(
db.scalars(
select(ContactRelationship).where(ContactRelationship.related_contact_id == contact_id)
).all()
db.scalars(select(ContactRelationship).where(ContactRelationship.related_contact_id == contact_id)).all()
)
return outgoing + incoming
@app.patch(
"/api/contacts/{contact_id}/relationships/{related_contact_id}",
@router.patch(
"/contacts/{contact_id}/relationships/{related_contact_id}",
response_model=ContactRelationshipResponse,
)
def update_contact_relationship(
@@ -465,7 +404,7 @@ def update_contact_relationship(
return relationship
@app.delete("/api/contacts/{contact_id}/relationships/{related_contact_id}")
@router.delete("/contacts/{contact_id}/relationships/{related_contact_id}")
def remove_contact_relationship(
contact_id: int,
related_contact_id: int,
@@ -486,7 +425,7 @@ def remove_contact_relationship(
return {"deleted": True}
@app.get("/api/relationship-types")
@router.get("/relationship-types")
def list_relationship_types() -> list[RelationshipTypeInfo]:
"""List all available relationship types with their default weights."""
return [
@@ -499,16 +438,13 @@ def list_relationship_types() -> list[RelationshipTypeInfo]:
]
@app.get("/api/graph")
@router.get("/graph")
def get_relationship_graph(db: DbSession) -> GraphData:
"""Get all contacts and relationships as graph data for visualization."""
contacts = list(db.scalars(select(Contact)).all())
relationships = list(db.scalars(select(ContactRelationship)).all())
nodes = [
GraphNode(id=c.id, name=c.name, current_job=c.current_job)
for c in contacts
]
nodes = [GraphNode(id=c.id, name=c.name, current_job=c.current_job) for c in contacts]
edges = [
GraphEdge(
@@ -521,16 +457,3 @@ def get_relationship_graph(db: DbSession) -> GraphData:
]
return GraphData(nodes=nodes, edges=edges)
# Serve React frontend
if FRONTEND_DIR.exists():
app.mount("/assets", StaticFiles(directory=FRONTEND_DIR / "assets"), name="assets")
@app.get("/{full_path:path}")
async def serve_spa(full_path: str) -> FileResponse:
"""Serve React SPA for all non-API routes."""
file_path = FRONTEND_DIR / full_path
if file_path.is_file():
return FileResponse(file_path)
return FileResponse(FRONTEND_DIR / "index.html")

View File

@@ -0,0 +1,24 @@
"""Frontend SPA router."""
from pathlib import Path
from fastapi import APIRouter
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
def create_frontend_router(frontend_dir: Path) -> APIRouter:
"""Create a router for serving the frontend SPA."""
router = APIRouter(tags=["frontend"])
router.mount("/assets", StaticFiles(directory=frontend_dir / "assets"), name="assets")
@router.get("/{full_path:path}")
async def serve_spa(full_path: str) -> FileResponse:
"""Serve React SPA for all non-API routes."""
file_path = frontend_dir / full_path
if file_path.is_file():
return FileResponse(file_path)
return FileResponse(frontend_dir / "index.html")
return router