openrag/src/api/oidc.py
2025-08-11 16:45:54 -04:00

102 lines
No EOL
3.5 KiB
Python

from starlette.requests import Request
from starlette.responses import JSONResponse
import json
import base64
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
async def oidc_discovery(request: Request, session_manager):
"""OIDC discovery endpoint"""
base_url = str(request.base_url).rstrip('/')
discovery_config = {
"issuer": base_url,
"authorization_endpoint": f"{base_url}/auth/init",
"token_endpoint": f"{base_url}/auth/callback",
"jwks_uri": f"{base_url}/auth/jwks",
"userinfo_endpoint": f"{base_url}/auth/me",
"response_types_supported": ["code"],
"subject_types_supported": ["public"],
"id_token_signing_alg_values_supported": ["RS256"],
"scopes_supported": ["openid", "email", "profile"],
"token_endpoint_auth_methods_supported": ["client_secret_basic"],
"claims_supported": [
"sub", "iss", "aud", "exp", "iat", "auth_time",
"email", "email_verified", "name", "preferred_username"
]
}
return JSONResponse(discovery_config)
async def jwks_endpoint(request: Request, session_manager):
"""JSON Web Key Set endpoint"""
try:
# Get the public key from session manager
public_key_pem = session_manager.public_key_pem
# Parse the PEM to extract key components
public_key = serialization.load_pem_public_key(public_key_pem.encode())
# Convert RSA components to base64url
def int_to_base64url(value):
# Convert integer to bytes, then to base64url
byte_length = (value.bit_length() + 7) // 8
value_bytes = value.to_bytes(byte_length, byteorder='big')
return base64.urlsafe_b64encode(value_bytes).decode('ascii').rstrip('=')
# Get public key components
public_numbers = public_key.public_numbers()
jwk = {
"kty": "RSA",
"use": "sig",
"alg": "RS256",
"kid": "gendb-key-1",
"n": int_to_base64url(public_numbers.n),
"e": int_to_base64url(public_numbers.e)
}
jwks = {
"keys": [jwk]
}
return JSONResponse(jwks)
except Exception as e:
return JSONResponse(
{"error": f"Failed to generate JWKS: {str(e)}"},
status_code=500
)
async def token_introspection(request: Request, session_manager):
"""Token introspection endpoint (optional)"""
try:
data = await request.json()
token = data.get("token")
if not token:
return JSONResponse({"active": False})
# Verify the token
payload = session_manager.verify_token(token)
if payload:
return JSONResponse({
"active": True,
"sub": payload.get("sub"),
"aud": payload.get("aud"),
"iss": payload.get("iss"),
"exp": payload.get("exp"),
"iat": payload.get("iat"),
"email": payload.get("email"),
"name": payload.get("name"),
"preferred_username": payload.get("preferred_username")
})
else:
return JSONResponse({"active": False})
except Exception as e:
return JSONResponse(
{"error": f"Token introspection failed: {str(e)}"},
status_code=500
)