"""WorkOS AuthKit helpers for the FastAPI web app.""" from __future__ import annotations from dataclasses import dataclass from functools import lru_cache from os import getenv from typing import Any from fastapi import Request from workos import WorkOSClient from workos.session import seal_session_from_auth_response @dataclass(frozen=True) class AuthConfig: """Runtime configuration for WorkOS AuthKit.""" api_key: str client_id: str cookie_password: str redirect_uri: str logout_redirect_uri: str session_cookie_name: str organization_id: str @property def secure_cookies(self) -> bool: return self.redirect_uri.startswith("https://") @dataclass(frozen=True) class AuthSession: """Normalized auth session passed through the app.""" user_id: str email: str first_name: str | None last_name: str | None role_slugs: set[str] organization_id: str | None raw_user: Any raw_session: Any @property def display_name(self) -> str: parts = [part for part in (self.first_name, self.last_name) if part] return " ".join(parts) if parts else self.email @property def is_admin(self) -> bool: return "admin" in self.role_slugs @dataclass(frozen=True) class CallbackResult: """Result of exchanging a WorkOS callback code.""" sealed_session: str next_path: str def safe_next_path(value: str | None, default: str = "/dashboard") -> str: """Allow only local relative redirect targets.""" if value and value.startswith("/") and not value.startswith("//"): return value return default def build_authorization_url(next_path: str) -> str: """Build the WorkOS hosted login URL.""" config = get_auth_config() return get_workos_client().user_management.get_authorization_url( provider="authkit", redirect_uri=config.redirect_uri, state=safe_next_path(next_path), organization_id=config.organization_id, ) def exchange_code(request: Request) -> CallbackResult: """Exchange a WorkOS callback code for a sealed session cookie value.""" code = request.query_params.get("code") if not code: raise ValueError("Missing authentication code.") config = get_auth_config() auth_response = get_workos_client().user_management.authenticate_with_code( code=code, ip_address=_request_ip(request), user_agent=request.headers.get("user-agent"), ) sealed_session = seal_session_from_auth_response( access_token=auth_response.access_token, refresh_token=auth_response.refresh_token, user=auth_response.user.to_dict(), impersonator=auth_response.impersonator.to_dict() if auth_response.impersonator is not None else None, cookie_password=config.cookie_password, ) return CallbackResult( sealed_session=sealed_session, next_path=safe_next_path(request.query_params.get("state")), ) def get_current_session(request: Request) -> AuthSession | None: """Load the current signed-in WorkOS session from the sealed cookie.""" cookie_name = getenv("WORKOS_SESSION_COOKIE_NAME", "workos_session") sealed_session = request.cookies.get(cookie_name) if not sealed_session: return None config = get_auth_config() try: session = get_workos_client().user_management.load_sealed_session( session_data=sealed_session, cookie_password=config.cookie_password, ) auth_response = session.authenticate() except ValueError: return None if not getattr(auth_response, "authenticated", False): return None user = auth_response.user or {} organization_id = getattr(auth_response, "organization_id", None) if config.organization_id and organization_id != config.organization_id: return None role_slugs = set(getattr(auth_response, "roles", None) or []) role = getattr(auth_response, "role", None) if role: role_slugs.add(role) return AuthSession( user_id=_user_field(user, "id") or "", email=_user_field(user, "email") or "", first_name=_user_field(user, "first_name"), last_name=_user_field(user, "last_name"), role_slugs=role_slugs, organization_id=organization_id, raw_user=user, raw_session=auth_response, ) def get_logout_url(request: Request) -> str: """Return the WorkOS logout URL for the current sealed session.""" config = get_auth_config() sealed_session = request.cookies.get(config.session_cookie_name) if not sealed_session: return config.logout_redirect_uri try: session = get_workos_client().user_management.load_sealed_session( session_data=sealed_session, cookie_password=config.cookie_password, ) return session.get_logout_url(return_to=config.logout_redirect_uri) except ValueError: return config.logout_redirect_uri @lru_cache(maxsize=1) def get_auth_config() -> AuthConfig: """Load and validate WorkOS environment configuration.""" values = { "WORKOS_API_KEY": getenv("WORKOS_API_KEY"), "WORKOS_CLIENT_ID": getenv("WORKOS_CLIENT_ID"), "WORKOS_COOKIE_PASSWORD": getenv("WORKOS_COOKIE_PASSWORD"), "WORKOS_ORGANIZATION_ID": getenv("WORKOS_ORGANIZATION_ID"), } missing = [name for name, value in values.items() if not value] if missing: raise RuntimeError( "Missing WorkOS configuration: " + ", ".join(sorted(missing)) ) return AuthConfig( api_key=values["WORKOS_API_KEY"] or "", client_id=values["WORKOS_CLIENT_ID"] or "", cookie_password=values["WORKOS_COOKIE_PASSWORD"] or "", redirect_uri=getenv("WORKOS_REDIRECT_URI", "http://localhost:8000/callback"), logout_redirect_uri=getenv("WORKOS_LOGOUT_REDIRECT_URI", "http://localhost:8000/"), session_cookie_name=getenv("WORKOS_SESSION_COOKIE_NAME", "workos_session"), organization_id=values["WORKOS_ORGANIZATION_ID"] or "", ) @lru_cache(maxsize=1) def get_workos_client(): """Create and cache the WorkOS SDK client.""" config = get_auth_config() return WorkOSClient(api_key=config.api_key, client_id=config.client_id) def _request_ip(request: Request) -> str | None: if request.client is None: return None return request.client.host def _user_field(user: Any, key: str) -> Any: if isinstance(user, dict): return user.get(key) return getattr(user, key, None)