import json
import jwt
import httpx
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional, Any
from dataclasses import dataclass, asdict
from cryptography.hazmat.primitives import serialization
import os


@dataclass
class User:
    """User information from the OAuth provider."""

    user_id: str  # Provider user ID (Google userinfo "id"); also used as the JWT "sub" claim
    email: str
    name: str
    picture: Optional[str] = None
    provider: str = "google"
    created_at: Optional[datetime] = None
    last_login: Optional[datetime] = None

    def __post_init__(self):
        """Default both timestamps to the current (naive, local) time."""
        if self.created_at is None:
            self.created_at = datetime.now()
        if self.last_login is None:
            self.last_login = datetime.now()


class SessionManager:
    """Manages user sessions and RS256-signed JWT tokens.

    Users are kept in an in-memory dict keyed by user_id, together with a
    per-user cache of OpenSearch clients built from the user's JWT.
    """

    def __init__(
        self,
        secret_key: Optional[str] = None,
        private_key_path: str = "keys/private_key.pem",
        public_key_path: str = "keys/public_key.pem",
        token_lifetime: timedelta = timedelta(days=7),
    ):
        """Create a session manager and load its RSA signing keys.

        Args:
            secret_key: Not used for signing by this class; kept for backward
                compatibility.
            private_key_path: Path to the PEM-encoded RSA private key (unencrypted).
            public_key_path: Path to the PEM-encoded RSA public key.
            token_lifetime: How long issued JWTs remain valid ("exp" - "iat").

        Raises:
            Exception: If the key files are missing or cannot be parsed.
        """
        self.secret_key = secret_key  # Keep for backward compatibility
        self.token_lifetime = token_lifetime
        self.users: Dict[str, User] = {}  # user_id -> User
        self.user_opensearch_clients: Dict[str, Any] = {}  # user_id -> OpenSearch client
        # user_id -> JWT the cached client was built with, so that a newer token
        # (re-login, or replacement of an expired one) triggers a rebuild.
        self._user_opensearch_tokens: Dict[str, str] = {}

        # Load RSA keys
        self.private_key_path = private_key_path
        self.public_key_path = public_key_path
        self._load_rsa_keys()

    def _load_rsa_keys(self):
        """Load the RSA private/public keys and keep the public key's PEM text.

        Sets ``self.private_key``, ``self.public_key`` and ``self.public_key_pem``
        (the latter is exposed e.g. for JWKS publication).

        Raises:
            Exception: Wrapping the underlying error (chained as ``__cause__``).
        """
        try:
            with open(self.private_key_path, "rb") as f:
                self.private_key = serialization.load_pem_private_key(
                    f.read(), password=None
                )
            # Read the public key once, as text, and parse it from that text;
            # the file handle is closed deterministically by the context manager.
            with open(self.public_key_path, "r", encoding="utf-8") as f:
                self.public_key_pem = f.read()
            self.public_key = serialization.load_pem_public_key(
                self.public_key_pem.encode("utf-8")
            )
        except FileNotFoundError as e:
            raise Exception(f"RSA key files not found: {e}") from e
        except Exception as e:
            raise Exception(f"Failed to load RSA keys: {e}") from e

    async def get_user_info_from_token(self, access_token: str) -> Optional[Dict[str, Any]]:
        """Fetch the user's profile from Google's userinfo endpoint.

        Args:
            access_token: Google OAuth access token.

        Returns:
            The decoded JSON profile on HTTP 200, otherwise ``None``. Network or
            decoding errors are printed and reported as ``None`` (best effort).
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://www.googleapis.com/oauth2/v2/userinfo",
                    headers={"Authorization": f"Bearer {access_token}"},
                )
                if response.status_code == 200:
                    return response.json()
                print(f"Failed to get user info: {response.status_code} {response.text}")
                return None
        except Exception as e:
            print(f"Error getting user info: {e}")
            return None

    async def create_user_session(self, access_token: str, issuer: str) -> Optional[str]:
        """Create (or refresh) a user session from an OAuth access token.

        Args:
            access_token: Google OAuth access token.
            issuer: Value for the JWT "iss" claim, supplied by the caller.

        Returns:
            An RS256-signed JWT, or ``None`` if the profile could not be fetched.

        Raises:
            KeyError: If the profile lacks "id", "email" or "name".
        """
        user_info = await self.get_user_info_from_token(access_token)
        if not user_info:
            return None

        user_id = user_info["id"]
        email = user_info["email"]
        name = user_info["name"]
        picture = user_info.get("picture")

        # Create or update user: refresh stored profile fields on re-login so the
        # in-memory record matches what goes into the token; created_at is kept.
        user = self.users.get(user_id)
        if user is None:
            user = User(
                user_id=user_id,
                email=email,
                name=name,
                picture=picture,
                provider="google",
            )
            self.users[user_id] = user
        else:
            user.email = email
            user.name = name
            user.picture = picture
            user.last_login = datetime.now()

        # Timezone-aware UTC: a naive utcnow() would make .timestamp() interpret
        # the value as local time and skew "auth_time" on non-UTC hosts.
        now = datetime.now(timezone.utc)
        token_payload = {
            # OIDC standard claims
            "iss": issuer,  # Issuer from request
            "sub": user_id,  # Subject (user ID)
            "aud": ["opensearch", "openrag"],  # Audience
            "exp": now + self.token_lifetime,  # Expiration
            "iat": now,  # Issued at
            "auth_time": int(now.timestamp()),  # Authentication time
            # Custom claims
            "user_id": user_id,  # Keep for backward compatibility
            "email": user.email,
            "name": user.name,
            "preferred_username": user.email,
            "email_verified": True,
            "roles": ["openrag_user"],  # Backend role for OpenSearch
        }

        return jwt.encode(token_payload, self.private_key, algorithm="RS256")

    def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify a JWT's RS256 signature, expiry and audience.

        Returns:
            The decoded claims, or ``None`` if the token is invalid or expired.
        """
        try:
            return jwt.decode(
                token,
                self.public_key,
                algorithms=["RS256"],
                audience=["opensearch", "openrag"],
            )
        except jwt.InvalidTokenError:
            # Also covers jwt.ExpiredSignatureError, which subclasses InvalidTokenError.
            return None

    def get_user(self, user_id: str) -> Optional[User]:
        """Return the in-memory user with this ID, or ``None``."""
        return self.users.get(user_id)

    def get_user_from_token(self, token: str) -> Optional[User]:
        """Resolve a JWT to a known user.

        Returns ``None`` if the token does not verify or names no known user.
        Falls back to the standard "sub" claim when "user_id" is absent.
        """
        payload = self.verify_token(token)
        if not payload:
            return None
        user_id = payload.get("user_id", payload.get("sub"))
        if user_id is None:
            return None
        return self.get_user(user_id)

    def get_user_opensearch_client(self, user_id: str, jwt_token: str):
        """Get or create the OpenSearch client for a user, authenticated with their JWT.

        The cached client is rebuilt whenever a different JWT is presented, so an
        expired token captured at first use is not reused forever.
        """
        cached_token = self._user_opensearch_tokens.get(user_id)
        if user_id not in self.user_opensearch_clients or cached_token != jwt_token:
            from config.settings import clients

            # NOTE(review): any previous client is simply dropped; if the client
            # type holds connections that need explicit closing, close it here.
            self.user_opensearch_clients[user_id] = clients.create_user_opensearch_client(jwt_token)
            self._user_opensearch_tokens[user_id] = jwt_token
        return self.user_opensearch_clients[user_id]