235 lines
No EOL
7.9 KiB
Python
235 lines
No EOL
7.9 KiB
Python
import json
|
|
import jwt
|
|
import httpx
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, Optional, Any
|
|
from dataclasses import dataclass, asdict
|
|
from cryptography.hazmat.primitives import serialization
|
|
import os
|
|
from utils.logging_config import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
from utils.logging_config import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
@dataclass
class User:
    """User information from OAuth provider.

    Attributes are populated from the provider's user-info response.
    ``created_at`` and ``last_login`` default to the moment the instance is
    built when the caller does not supply them.
    """

    user_id: str  # From OAuth sub claim
    email: str
    name: str
    picture: Optional[str] = None  # Avatar URL; None when not supplied
    provider: str = "google"
    created_at: Optional[datetime] = None  # Filled in by __post_init__ when None
    last_login: Optional[datetime] = None  # Filled in by __post_init__ when None

    def __post_init__(self):
        """Fill in missing timestamps with the current (naive, local) time."""
        if self.created_at is None or self.last_login is None:
            # Read the clock once so a brand-new user gets identical
            # created_at / last_login values instead of two timestamps a few
            # microseconds apart.
            now = datetime.now()
            if self.created_at is None:
                self.created_at = now
            if self.last_login is None:
                self.last_login = now
|
|
|
|
class AnonymousUser(User):
    """Fixed placeholder identity for requests that carry no authenticated user."""

    def __init__(self):
        # Every anonymous user shares the same constant profile; timestamps
        # are left for the parent initializer to fill in.
        anonymous_profile = {
            "user_id": "anonymous",
            "email": "anonymous@localhost",
            "name": "Anonymous User",
            "picture": None,
            "provider": "none",
        }
        super().__init__(**anonymous_profile)
|
|
|
|
|
|
|
|
class SessionManager:
    """Manages user sessions and JWT tokens.

    Keeps an in-memory registry of known users and a per-user cache of
    OpenSearch clients, and signs/verifies RS256 JWTs with an RSA key pair
    loaded from disk at construction time.
    """

    def __init__(
        self,
        secret_key: Optional[str] = None,
        private_key_path: str = "keys/private_key.pem",
        public_key_path: str = "keys/public_key.pem",
    ):
        """Initialise the registries and load the RSA key pair.

        Args:
            secret_key: Stored on the instance only; not used by any method
                in this class (kept for backward compatibility).
            private_key_path: Path to the PEM-encoded RSA private key.
            public_key_path: Path to the PEM-encoded RSA public key.

        Raises:
            Exception: If either key file is missing or cannot be parsed.
        """
        self.secret_key = secret_key  # Keep for backward compatibility
        self.users: Dict[str, User] = {}  # user_id -> User
        self.user_opensearch_clients: Dict[
            str, Any
        ] = {}  # user_id -> OpenSearch client

        # Load RSA keys
        self.private_key_path = private_key_path
        self.public_key_path = public_key_path
        self._load_rsa_keys()

    def _load_rsa_keys(self):
        """Load RSA private and public keys.

        Sets ``self.private_key``, ``self.public_key`` and
        ``self.public_key_pem`` (the public key file's text content).

        Raises:
            Exception: If a key file is missing or fails to load; the
                original error is chained as ``__cause__``.
        """
        try:
            with open(self.private_key_path, "rb") as f:
                self.private_key = serialization.load_pem_private_key(
                    f.read(), password=None
                )

            with open(self.public_key_path, "rb") as f:
                self.public_key = serialization.load_pem_public_key(f.read())

            # Also get public key in PEM format for JWKS. Read inside a
            # context manager so the file handle is closed deterministically.
            with open(self.public_key_path, "r") as f:
                self.public_key_pem = f.read()

        except FileNotFoundError as e:
            raise Exception(f"RSA key files not found: {e}") from e
        except Exception as e:
            raise Exception(f"Failed to load RSA keys: {e}") from e

    async def get_user_info_from_token(
        self, access_token: str
    ) -> Optional[Dict[str, Any]]:
        """Get user info from Google using access token.

        Args:
            access_token: OAuth access token sent as a Bearer credential.

        Returns:
            The decoded JSON body when the endpoint answers 200; ``None`` on
            any other status code or on any exception (both are logged).
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://www.googleapis.com/oauth2/v2/userinfo",
                    headers={"Authorization": f"Bearer {access_token}"},
                )

                if response.status_code == 200:
                    return response.json()

                logger.error(
                    "Failed to get user info",
                    status_code=response.status_code,
                    response_text=response.text,
                )
                return None

        except Exception as e:
            # Best-effort lookup: any failure is reported as "no user info".
            logger.error("Error getting user info", error=str(e))
            return None

    async def create_user_session(
        self, access_token: str, issuer: str
    ) -> Optional[str]:
        """Create user session from OAuth access token.

        Fetches the user's profile, creates or refreshes the in-memory user
        record, and returns a signed JWT for that user.

        Args:
            access_token: OAuth access token used to fetch the profile.
            issuer: Accepted for interface compatibility; not used here (the
                token issuer is fixed inside ``create_jwt_token``).

        Returns:
            The encoded JWT, or ``None`` when the profile could not be
            fetched or lacks one of the ``id``/``email``/``name`` fields.
        """
        user_info = await self.get_user_info_from_token(access_token)
        if not user_info:
            return None

        # Reject incomplete profiles explicitly instead of failing with a
        # KeyError part-way through session creation.
        missing_fields = [
            key for key in ("id", "email", "name") if key not in user_info
        ]
        if missing_fields:
            logger.error(
                "User info missing required fields", missing_fields=missing_fields
            )
            return None

        # Create or update user
        user_id = user_info["id"]
        user = self.users.get(user_id)
        if user is None:
            user = User(
                user_id=user_id,
                email=user_info["email"],
                name=user_info["name"],
                picture=user_info.get("picture"),
                provider="google",
            )
            self.users[user_id] = user
        else:
            # Returning user: refresh the stored profile so later get_user()
            # lookups see current data, keep created_at, bump last_login.
            user.email = user_info["email"]
            user.name = user_info["name"]
            user.picture = user_info.get("picture")
            user.provider = "google"
            user.last_login = datetime.now()

        # Create JWT token using the shared method
        return self.create_jwt_token(user)

    def create_jwt_token(self, user: User) -> str:
        """Create JWT token for an existing user.

        The token is signed with the loaded RSA private key using RS256 and
        expires seven days after issuance.

        Args:
            user: The user whose ``user_id``, ``email`` and ``name`` are
                embedded in the claims.

        Returns:
            The encoded JWT.
        """
        from datetime import timezone

        # Use OpenSearch-compatible issuer for OIDC validation
        oidc_issuer = "http://openrag-backend:8000"

        # Timezone-aware "now": calling .timestamp() on a naive UTC datetime
        # would interpret it as local time and skew auth_time by the host's
        # UTC offset.
        now = datetime.now(timezone.utc)
        token_payload = {
            # OIDC standard claims
            "iss": oidc_issuer,  # Fixed issuer for OpenSearch OIDC
            "sub": user.user_id,  # Subject (user ID)
            "aud": ["opensearch", "openrag"],  # Audience
            "exp": now + timedelta(days=7),  # Expiration
            "iat": now,  # Issued at
            "auth_time": int(now.timestamp()),  # Authentication time
            # Custom claims
            "user_id": user.user_id,  # Keep for backward compatibility
            "email": user.email,
            "name": user.name,
            "preferred_username": user.email,
            "email_verified": True,
            "roles": ["openrag_user"],  # Backend role for OpenSearch
        }

        return jwt.encode(token_payload, self.private_key, algorithm="RS256")

    def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify JWT token and return user info.

        Args:
            token: Encoded JWT to validate against the loaded public key.

        Returns:
            The decoded claims, or ``None`` when the token is expired or
            otherwise invalid.
        """
        try:
            return jwt.decode(
                token,
                self.public_key,
                algorithms=["RS256"],
                audience=["opensearch", "openrag"],
            )
        except jwt.InvalidTokenError:
            # ExpiredSignatureError is a subclass of InvalidTokenError, so
            # expired tokens are covered by this handler as well.
            return None

    def get_user(self, user_id: str) -> Optional[User]:
        """Get user by ID; returns ``None`` when the ID is not registered."""
        return self.users.get(user_id)

    def get_user_from_token(self, token: str) -> Optional[User]:
        """Get user from JWT token.

        Returns:
            The registered user named by the token's ``user_id`` claim
            (falling back to ``sub``), or ``None`` when the token does not
            verify or the user is unknown.
        """
        payload = self.verify_token(token)
        if not payload:
            return None

        # Tokens issued by create_jwt_token carry both claims with the same
        # value; fall back to "sub" rather than raising on a missing key.
        return self.get_user(payload.get("user_id", payload.get("sub")))

    def get_user_opensearch_client(self, user_id: str, jwt_token: str):
        """Get or create OpenSearch client for user with their JWT.

        Args:
            user_id: Cache key for the client.
            jwt_token: The user's JWT; may be ``None``, in which case the
                anonymous token may be substituted (see
                ``get_effective_jwt_token``).

        Returns:
            The cached client for ``user_id``, created on first use.
        """
        # Get the effective JWT token (handles anonymous JWT creation)
        jwt_token = self.get_effective_jwt_token(user_id, jwt_token)

        # NOTE(review): a cached client is reused even when a different
        # jwt_token is passed later; confirm whether clients must be rebuilt
        # once the token they were created with has expired.
        if user_id not in self.user_opensearch_clients:
            from config.settings import clients

            self.user_opensearch_clients[user_id] = (
                clients.create_user_opensearch_client(jwt_token)
            )

        return self.user_opensearch_clients[user_id]

    def get_effective_jwt_token(self, user_id: str, jwt_token: str) -> str:
        """Get the effective JWT token, creating anonymous JWT if needed in no-auth mode.

        Args:
            user_id: ID of the requesting user (may be ``None``).
            jwt_token: Token supplied by the caller (may be ``None``).

        Returns:
            ``jwt_token`` unchanged when it is not ``None``. Otherwise the
            cached anonymous JWT when no-auth mode is on or the user is
            anonymous/``None``; in any other case ``None`` is returned as-is.
        """
        from config.settings import is_no_auth_mode

        no_auth_mode = is_no_auth_mode()
        logger.debug(
            "get_effective_jwt_token",
            user_id=user_id,
            jwt_token_present=(jwt_token is not None),
            no_auth_mode=no_auth_mode,
        )

        # In no-auth mode, create anonymous JWT if needed
        if jwt_token is None and (
            no_auth_mode or user_id in (None, AnonymousUser().user_id)
        ):
            # NOTE(review): the anonymous token is created once and cached on
            # the instance; it carries the 7-day expiry set in
            # create_jwt_token and is never refreshed here.
            if not hasattr(self, "_anonymous_jwt"):
                # Create anonymous JWT token for OpenSearch OIDC
                logger.debug("Creating anonymous JWT")
                self._anonymous_jwt = self._create_anonymous_jwt()
                logger.debug(
                    "Anonymous JWT created", jwt_prefix=self._anonymous_jwt[:50]
                )
            jwt_token = self._anonymous_jwt
            logger.debug("Using anonymous JWT")

        return jwt_token

    def _create_anonymous_jwt(self) -> str:
        """Create JWT token for anonymous user in no-auth mode."""
        anonymous_user = AnonymousUser()
        return self.create_jwt_token(anonymous_user)