Source code for app.middlewares.valid_access_token

import asyncio
import logging
from typing import Annotated, Any, Optional

import httpx
import jwt
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import OAuth2AuthorizationCodeBearer
from jwt import PyJWKClient, exceptions

from app.core.config import configs
from app.util.response import Response

logger = logging.getLogger(__name__)


OIDC_SCOPES = {
    "openid": "OpenID Connect scope required for subject claims.",
    "profile": "Basic profile claims.",
    "email": "Email claims.",
    "greencrowd-roles": "GreenCrowd role claims.",
}

SUBJECT_FALLBACK_CLAIMS = (
    "sub",
    "preferred_username",
    "email",
    "client_id",
    "azp",
)


# Lazy module-level singleton. Building a fresh PyJWKClient per request issues
# a JWKS HTTP roundtrip every time; reusing one client lets it cache signing
# keys for ``lifespan`` seconds.
_JWKS_CLIENT: Optional[PyJWKClient] = None


def _get_jwks_client() -> PyJWKClient:
    """
    Return the lazily-built, process-wide ``PyJWKClient`` singleton.

    The client points at the Keycloak realm's JWKS endpoint and caches signing
    keys for its ``lifespan`` so repeated token validations avoid a JWKS HTTP
    roundtrip on every request.

    Returns:
        PyJWKClient: The shared signing-key client.
    """
    global _JWKS_CLIENT
    if _JWKS_CLIENT is None:
        url = f"{configs.KEYCLOAK_URL_DOCKER}/realms/{configs.KEYCLOAK_REALM}/protocol/openid-connect/certs"
        _JWKS_CLIENT = PyJWKClient(
            url,
            headers={"User-agent": "fastapi-jwt-auth/0.1.0 ( GAME )"},
            cache_keys=True,
            lifespan=300,
        )
    return _JWKS_CLIENT


def _reset_jwks_client_for_tests() -> None:
    """Test hook: drop the cached PyJWKClient so monkeypatched classes apply."""
    global _JWKS_CLIENT
    _JWKS_CLIENT = None


[docs] class CustomOAuth2AuthorizationCodeBearer(OAuth2AuthorizationCodeBearer): async def __call__(self, request: Request) -> Optional[str]: try: token = await super().__call__(request) except HTTPException: return False return token
oauth_2_scheme = CustomOAuth2AuthorizationCodeBearer( tokenUrl=f"{configs.KEYCLOAK_URL}/realms/{configs.KEYCLOAK_REALM}/protocol/openid-connect/token", authorizationUrl=f"{configs.KEYCLOAK_URL}/realms/{configs.KEYCLOAK_REALM}/protocol/openid-connect/auth", refreshUrl=f"{configs.KEYCLOAK_URL}/realms/{configs.KEYCLOAK_REALM}/protocol/openid-connect/token", scopes=OIDC_SCOPES, ) def _normalize_subject_claim(payload: dict[str, Any]) -> dict[str, Any]: """ Ensure the token payload has a usable ``sub`` claim. Keycloak access tokens do not always carry ``sub`` (e.g. client-credential tokens). This copies the payload and fills ``sub`` from the first non-empty fallback claim in :data:`SUBJECT_FALLBACK_CLAIMS` (``preferred_username``, ``email``, ``client_id``, ``azp``). Args: payload (dict[str, Any]): Decoded JWT claims. Returns: dict[str, Any]: A copy of the payload with ``sub`` normalized when a suitable fallback exists. """ normalized = dict(payload) for claim_name in SUBJECT_FALLBACK_CLAIMS: claim_value = normalized.get(claim_name) if isinstance(claim_value, str) and claim_value.strip(): normalized["sub"] = claim_value.strip() break return normalized def _build_token_response(payload: dict[str, Any]) -> Response: """ Wrap a decoded JWT payload into a success/failure :class:`Response`. Normalizes the subject claim and rejects tokens that still lack a usable ``sub`` after normalization. Args: payload (dict[str, Any]): Decoded JWT claims. Returns: Response: ``Response.ok`` with the normalized claims, or ``Response.fail`` carrying a 401 ``HTTPException`` when no subject is present. """ normalized = _normalize_subject_claim(payload) subject = normalized.get("sub") if not isinstance(subject, str) or not subject.strip(): return Response.fail( error=HTTPException(status_code=401, detail="Token subject not found") ) return Response.ok(normalized)
[docs] async def valid_access_token( access_token: Annotated[str, Depends(oauth_2_scheme)], ) -> Response: """ Validate a Keycloak-issued JWT access token. Fetches the realm signing key (offloaded to a worker thread so the JWKS roundtrip does not block the event loop) and verifies the token's signature, expiry, issuer and audience. Each failure mode is mapped to an appropriate HTTP status carried inside the returned ``Response``. Args: access_token (str): The bearer token extracted by ``oauth_2_scheme``. Returns: Response: ``Response.ok`` with the decoded/normalized claims on success; ``Response.fail`` wrapping an ``HTTPException`` (401/403/500) describing why validation failed. """ try: jwks_client = _get_jwks_client() # PyJWKClient does sync HTTP under the hood; offload to a worker thread # so a JWKS roundtrip doesn't stall the event loop. signing_key = await asyncio.to_thread( jwks_client.get_signing_key_from_jwt, access_token ) data = jwt.decode( access_token, key=signing_key.key, algorithms=["RS256"], issuer=f"{configs.KEYCLOAK_URL}/realms/{configs.KEYCLOAK_REALM}", audience=configs.KEYCLOAK_AUDIENCE, leeway=30, options={ "verify_exp": True, "verify_aud": True, }, ) return _build_token_response(data) except exceptions.InvalidSignatureError as e: logger.warning("JWT rejected: invalid signature: %s", e) return Response.fail( error=HTTPException(status_code=401, detail="Invalid token signature") ) except exceptions.ExpiredSignatureError as e: logger.warning("JWT rejected: expired: %s", e) return Response.fail( error=HTTPException(status_code=401, detail="Token has expired") ) except exceptions.InvalidAudienceError as e: logger.warning("JWT rejected: invalid audience: %s", e) return Response.fail( error=HTTPException(status_code=403, detail="Invalid audience") ) except exceptions.InvalidTokenError as e: logger.warning("JWT rejected: invalid token: %s", e) return Response.fail( error=HTTPException(status_code=401, detail="Invalid token") ) except exceptions.PyJWKClientError: logger.exception("JWKS client failure fetching signing key") return Response.fail( error=HTTPException(status_code=500, detail="Internal server error") ) except jwt.PyJWTError as e: logger.warning("JWT rejected: %s", e) return Response.fail( error=HTTPException(status_code=401, detail=f"Invalid token: {str(e)}") )
[docs] async def refresh_access_token(refresh_token: str): """ Refresh the access token using the refresh token. Args: refresh_token (str): The refresh token to be used to generate a new access token. Returns: dict: The new access token and other related information. """ url = f"{configs.KEYCLOAK_URL}/realms/{configs.KEYCLOAK_REALM}/protocol/openid-connect/token" data = { "grant_type": "refresh_token", "refresh_token": refresh_token, "client_id": configs.KEYCLOAK_CLIENT_ID, "client_secret": configs.KEYCLOAK_CLIENT_SECRET, } try: async with httpx.AsyncClient() as client: response = await client.post(url, data=data) response.raise_for_status() return response.json() except httpx.HTTPStatusError as http_err: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Failed to refresh token: {http_err}", ) except Exception as err: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Internal server error: {err}", )