openrag/src/session_manager.py

162 lines
No EOL
5.8 KiB
Python

import json
import os
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional, Any

import httpx
import jwt
from cryptography.hazmat.primitives import serialization
@dataclass
class User:
    """User information obtained from an OAuth provider."""

    user_id: str  # Provider's stable user identifier
    email: str
    name: str
    picture: Optional[str] = None
    provider: str = "google"
    # Both timestamps default to the moment of construction (naive local time).
    created_at: Optional[datetime] = None
    last_login: Optional[datetime] = None

    def __post_init__(self):
        """Fill any unset timestamp with a single shared "now"."""
        # Read the clock once so a brand-new user has created_at == last_login.
        now = datetime.now()
        if self.created_at is None:
            self.created_at = now
        if self.last_login is None:
            self.last_login = now
class SessionManager:
    """Manages user sessions and RS256-signed JWT tokens.

    Users are kept in an in-memory dict keyed by user id. A per-user
    OpenSearch client is cached together with the JWT it was built from.
    """

    def __init__(self, secret_key: Optional[str] = None,
                 private_key_path: str = "keys/private_key.pem",
                 public_key_path: str = "keys/public_key.pem"):
        """Initialise the in-memory stores and load the RSA key pair.

        Args:
            secret_key: Legacy shared secret, stored for backward
                compatibility; tokens are signed with the RSA private key.
            private_key_path: Path to the PEM-encoded RSA private key.
            public_key_path: Path to the PEM-encoded RSA public key.

        Raises:
            Exception: If a key file is missing or cannot be parsed.
        """
        self.secret_key = secret_key  # Keep for backward compatibility
        self.users: Dict[str, User] = {}  # user_id -> User
        self.user_opensearch_clients: Dict[str, Any] = {}  # user_id -> OpenSearch client
        # user_id -> JWT the cached OpenSearch client was created with, so the
        # client can be rebuilt when a different (e.g. renewed) token arrives.
        self._user_opensearch_client_tokens: Dict[str, str] = {}
        # Load RSA keys
        self.private_key_path = private_key_path
        self.public_key_path = public_key_path
        self._load_rsa_keys()

    def _load_rsa_keys(self):
        """Load the RSA private and public keys from disk.

        Sets ``self.private_key``, ``self.public_key`` and
        ``self.public_key_pem`` (the public key as PEM text, for JWKS).

        Raises:
            Exception: Wrapping (and chained to) the underlying error if a key
                file is missing or its contents cannot be parsed.
        """
        try:
            with open(self.private_key_path, "rb") as f:
                self.private_key = serialization.load_pem_private_key(
                    f.read(),
                    password=None,
                )
            with open(self.public_key_path, "rb") as f:
                self.public_key = serialization.load_pem_public_key(f.read())
            # Also keep the public key as PEM text for JWKS; use ``with`` so
            # the file handle is closed deterministically.
            with open(self.public_key_path, "r") as f:
                self.public_key_pem = f.read()
        except FileNotFoundError as e:
            raise Exception(f"RSA key files not found: {e}") from e
        except Exception as e:
            raise Exception(f"Failed to load RSA keys: {e}") from e

    async def get_user_info_from_token(self, access_token: str) -> Optional[Dict[str, Any]]:
        """Fetch the user's profile from Google's OAuth2 userinfo endpoint.

        Args:
            access_token: OAuth access token, sent as a Bearer token.

        Returns:
            The decoded JSON profile on HTTP 200, otherwise ``None``. Any error
            is printed and swallowed (best-effort lookup).
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://www.googleapis.com/oauth2/v2/userinfo",
                    headers={"Authorization": f"Bearer {access_token}"},
                )
                if response.status_code == 200:
                    return response.json()
                print(f"Failed to get user info: {response.status_code} {response.text}")
                return None
        except Exception as e:
            print(f"Error getting user info: {e}")
            return None

    async def create_user_session(self, access_token: str, issuer: str) -> Optional[str]:
        """Create or refresh a user from an OAuth access token and mint a JWT.

        Args:
            access_token: OAuth access token used to fetch the user's profile.
            issuer: Value placed in the ``iss`` claim.

        Returns:
            An RS256-signed JWT that expires 7 days after issuance, or ``None``
            if the profile could not be fetched.

        Raises:
            KeyError: If the profile lacks ``id`` or ``email``.
        """
        user_info = await self.get_user_info_from_token(access_token)
        if not user_info:
            return None

        user_id = user_info["id"]
        email = user_info["email"]
        # NOTE(review): ``name`` may be absent from the userinfo response
        # (e.g. depending on granted scopes); default to "" rather than
        # failing the whole login with a KeyError.
        name = user_info.get("name", "")
        picture = user_info.get("picture")

        # Create or update user
        user = self.users.get(user_id)
        if user is None:
            user = User(
                user_id=user_id,
                email=email,
                name=name,
                picture=picture,
                provider="google",
            )
            self.users[user_id] = user
        else:
            # Refresh the stored profile so it matches the token being issued.
            user.email = email
            user.name = name
            user.picture = picture
            user.last_login = datetime.now()

        # Aware UTC "now": a naive utcnow() would be read as *local* time by
        # .timestamp(), skewing auth_time on hosts not running in UTC.
        now = datetime.now(timezone.utc)
        token_payload = {
            # OIDC standard claims
            "iss": issuer,  # Issuer from request
            "sub": user_id,  # Subject (user ID)
            "aud": ["opensearch", "openrag"],  # Audience
            "exp": now + timedelta(days=7),  # Expiration
            "iat": now,  # Issued at
            "auth_time": int(now.timestamp()),  # Authentication time
            # Custom claims
            "user_id": user_id,  # Keep for backward compatibility
            "email": user.email,
            "name": user.name,
            "preferred_username": user.email,
            # NOTE(review): hard-coded; confirm whether the provider's own
            # email-verification flag should be propagated instead.
            "email_verified": True,
            "roles": ["openrag_user"],  # Backend role for OpenSearch
        }
        return jwt.encode(token_payload, self.private_key, algorithm="RS256")

    def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify a JWT's RS256 signature, expiry and audience.

        Returns:
            The decoded claims, or ``None`` if the token is expired or invalid.
        """
        try:
            return jwt.decode(
                token,
                self.public_key,
                algorithms=["RS256"],
                audience=["opensearch", "openrag"],
            )
        except jwt.InvalidTokenError:
            # Also covers ExpiredSignatureError, which subclasses InvalidTokenError.
            return None

    def get_user(self, user_id: str) -> Optional[User]:
        """Return the in-memory user for ``user_id``, or ``None`` if unknown."""
        return self.users.get(user_id)

    def get_user_from_token(self, token: str) -> Optional[User]:
        """Verify ``token`` and return the user it refers to, if known."""
        payload = self.verify_token(token)
        if not payload:
            return None
        # Prefer the legacy ``user_id`` claim; fall back to ``sub`` so a valid
        # token without it yields a lookup instead of a KeyError.
        user_id = payload.get("user_id", payload.get("sub"))
        if user_id is None:
            return None
        return self.get_user(user_id)

    def get_user_opensearch_client(self, user_id: str, jwt_token: str):
        """Return a cached OpenSearch client for ``user_id`` built from ``jwt_token``.

        The client is rebuilt whenever ``jwt_token`` differs from the token the
        cached client was created with. A renewed session therefore does not
        keep using an expired JWT. Concurrent sessions of one user that hold
        different tokens will cause rebuilds when they alternate.
        """
        cached_token = self._user_opensearch_client_tokens.get(user_id)
        if user_id not in self.user_opensearch_clients or cached_token != jwt_token:
            from config.settings import clients
            # NOTE(review): any previous client is dropped without being
            # closed — confirm whether it holds resources that need releasing.
            self.user_opensearch_clients[user_id] = clients.create_user_opensearch_client(jwt_token)
            self._user_opensearch_client_tokens[user_id] = jwt_token
        return self.user_opensearch_clients[user_id]