72 lines
No EOL
2.4 KiB
Python
72 lines
No EOL
2.4 KiB
Python
from typing import Any, Dict, List, Optional

from fastapi import HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import jwt, jwk, JWTError
from jose.utils import base64url_decode
from pydantic import BaseModel
from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN
|
|
|
|
# A single JSON Web Key as a JSON object (e.g. "kid", "kty", "alg", ...).
# Values are ``Any`` rather than ``str`` because JWK members are not all
# strings: "x5c" and "key_ops" are arrays, and a ``Dict[str, str]`` schema
# makes pydantic reject any key set that contains them.
JWK = Dict[str, Any]
|
|
|
|
|
|
class JWKS(BaseModel):
    """JSON Web Key Set: a document of the form ``{"keys": [<JWK>, ...]}``.

    ``JWTBearer`` indexes ``keys`` by each entry's ``"kid"`` and uses the
    matching entry to verify token signatures.
    """

    # The individual keys; each one must carry a "kid" member, because
    # JWTBearer.__init__ reads jwk["kid"] for every entry.
    keys: List[JWK]
|
|
|
|
|
|
class JWTAuthorizationCredentials(BaseModel):
    """Decoded pieces of a compact-serialized bearer JWT.

    ``header`` and ``claims`` are decoded WITHOUT verification; treat them as
    untrusted until the signature over ``message`` has been checked (as
    ``JWTBearer.verify_jwk_token`` does).
    """

    # The raw token exactly as received in the Authorization header.
    jwt_token: str
    # Unverified JOSE header; JWTBearer reads its "kid" to select a key.
    header: Dict[str, str]
    # Unverified payload. Values are ``Any`` because standard claims are not
    # all strings: "exp"/"iat" are numbers and "aud" may be a list. With a
    # ``Dict[str, str]`` annotation pydantic rejects list-valued claims and
    # either rejects (v2) or stringifies (v1) numeric ones.
    claims: Dict[str, Any]
    # Base64url-encoded signature segment (the text after the last ".").
    signature: str
    # Signing input "<header>.<payload>" (the text before the last ".").
    message: str
|
|
|
|
|
|
class JWTBearer(HTTPBearer):
    """Bearer-token security dependency that verifies a JWT's signature.

    The key set passed at construction is indexed by ``kid``; each incoming
    token's header ``kid`` selects the public key used to check its signature.

    NOTE(review): only the signature is verified. Registered claims such as
    ``exp``, ``aud`` and ``iss`` are NOT checked here, so an expired token with
    a valid signature is accepted — confirm that callers validate ``claims``.
    """

    def __init__(self, jwks: JWKS, auto_error: bool = True):
        """Index the key set by ``kid``.

        Args:
            jwks: Key set; every entry must contain a ``"kid"`` member
                (a missing one raises ``KeyError`` here).
            auto_error: Forwarded unchanged to ``HTTPBearer.__init__``.
        """
        super().__init__(auto_error=auto_error)

        # Loop variable named ``key`` so it does not shadow the ``jose.jwk``
        # module that verify_jwk_token relies on.
        self.kid_to_jwk = {key["kid"]: key for key in jwks.keys}

    def verify_jwk_token(self, jwt_credentials: JWTAuthorizationCredentials) -> bool:
        """Return whether the token's signature verifies against its ``kid`` key.

        Raises:
            HTTPException: 403 if the header has no ``kid`` or the ``kid`` is
                not present in the configured key set.
        """
        try:
            public_key = self.kid_to_jwk[jwt_credentials.header["kid"]]
        except KeyError:
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN, detail="JWK public key not found"
            ) from None

        # No algorithm argument is passed, so the token header's (untrusted)
        # "alg" is never consulted; the key is built from the JWKS entry alone.
        key = jwk.construct(public_key)
        decoded_signature = base64url_decode(jwt_credentials.signature.encode())

        return key.verify(jwt_credentials.message.encode(), decoded_signature)

    async def __call__(self, request: Request) -> Optional[JWTAuthorizationCredentials]:
        """Parse and signature-check the request's bearer JWT.

        Returns:
            The parsed credentials, or ``None`` when the parent ``HTTPBearer``
            returns no credentials.

        Raises:
            HTTPException: 403 if the scheme is not bearer, the token cannot
                be parsed into ``JWTAuthorizationCredentials``, or its
                signature does not verify.
        """
        credentials: HTTPAuthorizationCredentials = await super().__call__(request)
        if not credentials:
            return None

        # Authentication schemes are case-insensitive (RFC 7235 section 2.1),
        # so "bearer" must be accepted as well as "Bearer".
        if credentials.scheme.lower() != "bearer":
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN, detail="Wrong authentication method"
            )

        jwt_token = credentials.credentials

        try:
            header = jwt.get_unverified_header(jwt_token)
            claims = jwt.get_unverified_claims(jwt_token)
            # Split off the signature only after jose has parsed the token, so
            # a token without any "." is rejected as a JWTError (403) instead
            # of crashing the tuple unpacking with an uncaught ValueError (500).
            message, signature = jwt_token.rsplit(".", 1)
            jwt_credentials = JWTAuthorizationCredentials(
                jwt_token=jwt_token,
                header=header,
                claims=claims,
                signature=signature,
                message=message,
            )
        except (JWTError, ValueError):
            # ValueError also covers pydantic's ValidationError, raised when
            # the decoded header/claims do not fit the credentials model.
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN, detail="JWK invalid"
            ) from None

        if not self.verify_jwk_token(jwt_credentials):
            raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="JWK invalid")

        return jwt_credentials