203 lines
6.3 KiB
Python
203 lines
6.3 KiB
Python
"""WorkOS AuthKit helpers for the FastAPI web app."""
|
|
|
|
from __future__ import annotations

from dataclasses import dataclass, field
from functools import lru_cache
from os import getenv
from typing import Any

from fastapi import Request
from workos import WorkOSClient
from workos.session import seal_session_from_auth_response
|
|
|
|
|
|
@dataclass(frozen=True)
class AuthConfig:
    """Runtime configuration for WorkOS AuthKit.

    Instances are built by get_auth_config(). The two secrets (``api_key`` and
    ``cookie_password``) are excluded from the generated ``repr`` so that
    logging, printing, or a traceback showing a config object does not leak
    credentials.
    """

    # WorkOS API key (secret; hidden from repr).
    api_key: str = field(repr=False)
    client_id: str
    # Password used to seal/unseal the session cookie (secret; hidden from repr).
    cookie_password: str = field(repr=False)
    # Callback URI passed to WorkOS when building the authorization URL.
    redirect_uri: str
    # Where the user lands after logging out.
    logout_redirect_uri: str
    # Name of the cookie holding the sealed session.
    session_cookie_name: str
    organization_id: str

    @property
    def secure_cookies(self) -> bool:
        """Return True when ``redirect_uri`` uses the ``https`` scheme.

        Presumably used by callers to decide the cookie ``Secure`` flag
        (callers are not visible here). URI schemes are case-insensitive
        (RFC 3986, section 3.1), so the comparison ignores case.
        """
        return self.redirect_uri.lower().startswith("https://")
|
|
|
|
|
|
@dataclass(frozen=True)
class AuthSession:
    """Signed-in user plus session details, normalized for the rest of the app."""

    user_id: str
    email: str
    first_name: str | None
    last_name: str | None
    # Role slugs granted to the user; ``is_admin`` looks for "admin" here.
    role_slugs: set[str]
    organization_id: str | None
    # The user object and authenticate() response this session was built from.
    raw_user: Any
    raw_session: Any

    @property
    def display_name(self) -> str:
        """Return the non-empty name parts joined by a space, or the email if none."""
        full_name = " ".join(filter(None, (self.first_name, self.last_name)))
        return full_name or self.email

    @property
    def is_admin(self) -> bool:
        """Return True when the "admin" role slug is present."""
        return "admin" in self.role_slugs
|
|
|
|
|
|
@dataclass(frozen=True)
class CallbackResult:
    """Outcome of a successful WorkOS callback code exchange (see exchange_code())."""

    # Sealed session string; exchange_code() describes it as the session cookie value.
    sealed_session: str
    # Post-login local path, already passed through safe_next_path() by exchange_code().
    next_path: str
|
|
|
|
|
|
def safe_next_path(value: str | None, default: str = "/dashboard") -> str:
    """Return *value* if it is a safe same-site redirect path, else *default*.

    Accepted values are path-absolute references such as ``/reports?id=1``.
    Rejected values:

    * ``None`` or the empty string;
    * anything not starting with exactly one ``/``, i.e. absolute URLs,
      protocol-relative ``//host`` URLs and relative paths;
    * anything containing a backslash or an ASCII control character.
      Browsers treat a backslash like ``/`` and silently drop tabs and
      newlines while parsing URLs, so such paths can collapse into
      ``//host`` and send the user off-site (open redirect).

    Args:
        value: Candidate redirect target; may come from untrusted input.
        default: Fallback returned when *value* is rejected.

    Returns:
        *value* unchanged when accepted, otherwise *default*.
    """
    if not value or not value.startswith("/") or value.startswith("//"):
        return default
    # Block browser URL normalization tricks: backslash-as-slash and
    # stripped tab/newline characters (plus other C0 controls and DEL).
    if any(char == "\\" or ord(char) < 0x20 or ord(char) == 0x7F for char in value):
        return default
    return value
|
|
|
|
|
|
def build_authorization_url(next_path: str) -> str:
    """Return the WorkOS-hosted AuthKit login URL.

    Args:
        next_path: Where to send the user after login. It is sanitized with
            safe_next_path() and carried in the OAuth ``state`` parameter;
            exchange_code() sanitizes it again on the way back.

    Returns:
        The authorization URL produced by the WorkOS SDK.
    """
    settings = get_auth_config()
    user_management = get_workos_client().user_management
    # NOTE(review): ``state`` holds only the redirect path, with no per-login
    # nonce bound to the browser; confirm login-CSRF protection exists elsewhere.
    return user_management.get_authorization_url(
        provider="authkit",
        redirect_uri=settings.redirect_uri,
        state=safe_next_path(next_path),
        organization_id=settings.organization_id,
    )
|
|
|
|
|
|
def exchange_code(request: Request) -> CallbackResult:
    """Trade the ``code`` from a WorkOS callback for a sealed session value.

    Args:
        request: The callback request. ``code`` and ``state`` are read from
            its query string; the client address and ``User-Agent`` header
            are forwarded to WorkOS.

    Returns:
        A CallbackResult holding the sealed session (the session cookie
        value) and the post-login path taken from ``state`` after it has
        passed through safe_next_path().

    Raises:
        ValueError: If the ``code`` query parameter is missing or empty.
    """
    params = request.query_params
    authorization_code = params.get("code")
    if not authorization_code:
        raise ValueError("Missing authentication code.")

    settings = get_auth_config()
    tokens = get_workos_client().user_management.authenticate_with_code(
        code=authorization_code,
        ip_address=_request_ip(request),
        user_agent=request.headers.get("user-agent"),
    )

    # Impersonation data is optional; pass None when WorkOS returned none.
    impersonator = tokens.impersonator
    impersonator_data = impersonator.to_dict() if impersonator is not None else None

    sealed = seal_session_from_auth_response(
        access_token=tokens.access_token,
        refresh_token=tokens.refresh_token,
        user=tokens.user.to_dict(),
        impersonator=impersonator_data,
        cookie_password=settings.cookie_password,
    )

    # ``state`` arrives in the query string and is client-controlled, so it is
    # re-sanitized rather than trusted.
    destination = safe_next_path(params.get("state"))
    return CallbackResult(sealed_session=sealed, next_path=destination)
|
|
|
|
|
|
def get_current_session(request: Request) -> AuthSession | None:
    """Load the signed-in WorkOS session from the sealed session cookie.

    Args:
        request: Incoming request; the cookie named by
            ``AuthConfig.session_cookie_name`` is read from it.

    Returns:
        An AuthSession when the cookie is present, WorkOS reports it as
        authenticated, and its organization matches the configured one;
        otherwise None.

    Raises:
        RuntimeError: From get_auth_config() when WorkOS configuration is missing.
    """
    config = get_auth_config()
    # Use the shared config (as get_logout_url() does) rather than re-reading
    # the environment, so reading and clearing the cookie agree on its name.
    sealed_session = request.cookies.get(config.session_cookie_name)
    if not sealed_session:
        return None

    session = get_workos_client().user_management.load_sealed_session(
        session_data=sealed_session,
        cookie_password=config.cookie_password,
    )
    auth_response = session.authenticate()
    # NOTE(review): nothing here refreshes the session when authenticate()
    # fails; confirm whether an expired access token should be refreshed
    # rather than treated as signed out.
    if not getattr(auth_response, "authenticated", False):
        return None

    user = auth_response.user or {}
    organization_id = getattr(auth_response, "organization_id", None)
    # Only accept sessions for the organization this deployment is pinned to.
    if config.organization_id and organization_id != config.organization_id:
        return None

    # Merge the optional ``roles`` collection with the optional single ``role``.
    role_slugs = set(getattr(auth_response, "roles", None) or [])
    role = getattr(auth_response, "role", None)
    if role:
        role_slugs.add(role)

    return AuthSession(
        user_id=_user_field(user, "id") or "",
        email=_user_field(user, "email") or "",
        first_name=_user_field(user, "first_name"),
        last_name=_user_field(user, "last_name"),
        role_slugs=role_slugs,
        organization_id=organization_id,
        raw_user=user,
        raw_session=auth_response,
    )
|
|
|
|
|
|
def get_logout_url(request: Request) -> str:
    """Return the URL that signs the current user out.

    Args:
        request: Incoming request; its WorkOS session cookie is read if present.

    Returns:
        The WorkOS logout URL for the sealed session in the request cookie,
        which returns to ``logout_redirect_uri`` afterwards. When there is no
        cookie, or the SDK cannot build a logout URL from it, the configured
        ``logout_redirect_uri`` itself is returned.
    """
    config = get_auth_config()
    sealed_session = request.cookies.get(config.session_cookie_name)
    if not sealed_session:
        return config.logout_redirect_uri

    session = get_workos_client().user_management.load_sealed_session(
        session_data=sealed_session,
        cookie_password=config.cookie_password,
    )
    try:
        return session.get_logout_url(return_to=config.logout_redirect_uri)
    except ValueError:
        # The WorkOS SDK's Session.get_logout_url() re-authenticates the cookie
        # and raises ValueError when that fails (e.g. expired or tampered
        # cookie). Fall back to the local post-logout page instead of failing
        # the logout request.
        # NOTE(review): confirm this behavior against the installed workos SDK.
        return config.logout_redirect_uri
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_auth_config() -> AuthConfig:
    """Load and validate WorkOS configuration from environment variables.

    The result is cached by lru_cache, so environment changes made after the
    first call are not picked up.

    Optional settings (WORKOS_REDIRECT_URI, WORKOS_LOGOUT_REDIRECT_URI,
    WORKOS_SESSION_COOKIE_NAME) fall back to their defaults when unset *or
    empty*. This matches how an empty required setting is treated as missing.

    Returns:
        The populated AuthConfig.

    Raises:
        RuntimeError: If any of WORKOS_API_KEY, WORKOS_CLIENT_ID,
            WORKOS_COOKIE_PASSWORD or WORKOS_ORGANIZATION_ID is unset or empty.
    """
    required = {
        name: getenv(name) or ""
        for name in (
            "WORKOS_API_KEY",
            "WORKOS_CLIENT_ID",
            "WORKOS_COOKIE_PASSWORD",
            "WORKOS_ORGANIZATION_ID",
        )
    }
    missing = sorted(name for name, value in required.items() if not value)
    if missing:
        raise RuntimeError("Missing WorkOS configuration: " + ", ".join(missing))

    return AuthConfig(
        api_key=required["WORKOS_API_KEY"],
        client_id=required["WORKOS_CLIENT_ID"],
        cookie_password=required["WORKOS_COOKIE_PASSWORD"],
        # ``or`` (not getenv's default) so an empty value also uses the default.
        redirect_uri=getenv("WORKOS_REDIRECT_URI") or "http://localhost:8000/callback",
        logout_redirect_uri=getenv("WORKOS_LOGOUT_REDIRECT_URI") or "http://localhost:8000/",
        session_cookie_name=getenv("WORKOS_SESSION_COOKIE_NAME") or "workos_session",
        organization_id=required["WORKOS_ORGANIZATION_ID"],
    )
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_workos_client():
    """Return the shared WorkOS SDK client.

    The client is built on the first call from get_auth_config() and memoized
    by lru_cache, so every later call returns the same instance.
    """
    settings = get_auth_config()
    return WorkOSClient(client_id=settings.client_id, api_key=settings.api_key)
|
|
|
|
|
|
def _request_ip(request: Request) -> str | None:
|
|
if request.client is None:
|
|
return None
|
|
return request.client.host
|
|
|
|
|
|
def _user_field(user: Any, key: str) -> Any:
|
|
if isinstance(user, dict):
|
|
return user.get(key)
|
|
return getattr(user, key, None)
|