162 lines
No EOL
5.8 KiB
Python
162 lines
No EOL
5.8 KiB
Python
import json
|
|
import jwt
|
|
import httpx
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, Optional, Any
|
|
from dataclasses import dataclass, asdict
|
|
from cryptography.hazmat.primitives import serialization
|
|
import os
|
|
|
|
|
|
@dataclass
class User:
    """User information from an OAuth provider.

    ``created_at`` and ``last_login`` default to the (naive, local-time)
    moment the instance is constructed when they are not supplied.
    """

    user_id: str  # Provider-assigned user id; issued as the JWT "sub" claim
    email: str
    name: str
    picture: Optional[str] = None  # Profile picture, if the provider returned one
    provider: str = "google"
    created_at: Optional[datetime] = None
    last_login: Optional[datetime] = None

    def __post_init__(self):
        """Fill in any missing timestamps with the current local time."""
        # Read the clock once so a brand-new user gets identical created_at and
        # last_login values rather than two slightly different readings.
        now = datetime.now()
        if self.created_at is None:
            self.created_at = now
        if self.last_login is None:
            self.last_login = now
|
|
|
|
|
|
class SessionManager:
    """Manages user sessions and RS256-signed JWT tokens.

    Users and per-user OpenSearch clients are held in plain instance
    dictionaries, so they live only as long as this object.
    """

    def __init__(self, secret_key: Optional[str] = None,
                 private_key_path: str = "keys/private_key.pem",
                 public_key_path: str = "keys/public_key.pem"):
        """Create the manager and load the RSA key pair from disk.

        Args:
            secret_key: Stored only; tokens are signed with the RSA private
                key. Kept for backward compatibility.
            private_key_path: PEM file with the unencrypted RSA private key
                used to sign tokens.
            public_key_path: PEM file with the RSA public key used to verify
                tokens.

        Raises:
            Exception: If a key file is missing or cannot be parsed.
        """
        self.secret_key = secret_key  # Keep for backward compatibility
        self.users: Dict[str, User] = {}  # user_id -> User
        self.user_opensearch_clients: Dict[str, Any] = {}  # user_id -> OpenSearch client
        # user_id -> JWT the cached OpenSearch client was created with, so the
        # client can be rebuilt when a different (e.g. refreshed) token arrives.
        self._user_opensearch_tokens: Dict[str, str] = {}

        # Load RSA keys
        self.private_key_path = private_key_path
        self.public_key_path = public_key_path
        self._load_rsa_keys()

    def _load_rsa_keys(self):
        """Load the RSA private and public keys.

        Sets ``self.private_key``, ``self.public_key`` and
        ``self.public_key_pem`` (the public key file's text, for JWKS).

        Raises:
            Exception: "RSA key files not found: ..." if a file is missing,
                otherwise "Failed to load RSA keys: ...". The underlying error
                is chained as ``__cause__``.
        """
        try:
            with open(self.private_key_path, 'rb') as f:
                self.private_key = serialization.load_pem_private_key(
                    f.read(),
                    password=None
                )

            with open(self.public_key_path, 'rb') as f:
                self.public_key = serialization.load_pem_public_key(f.read())

            # Also get public key in PEM format for JWKS. Use a context manager:
            # a bare open(...).read() leaves the handle open until GC.
            with open(self.public_key_path, 'r') as f:
                self.public_key_pem = f.read()

        except FileNotFoundError as e:
            raise Exception(f"RSA key files not found: {e}") from e
        except Exception as e:
            raise Exception(f"Failed to load RSA keys: {e}") from e

    async def get_user_info_from_token(self, access_token: str) -> Optional[Dict[str, Any]]:
        """Get user info from Google using an OAuth access token.

        Sends the token as a Bearer credential to Google's
        ``oauth2/v2/userinfo`` endpoint.

        Returns:
            The decoded JSON body on HTTP 200; ``None`` for any other status or
            if the request/decoding raises (the failure is printed).
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://www.googleapis.com/oauth2/v2/userinfo",
                    headers={"Authorization": f"Bearer {access_token}"}
                )

                if response.status_code == 200:
                    return response.json()
                else:
                    print(f"Failed to get user info: {response.status_code} {response.text}")
                    return None

        except Exception as e:
            # Best-effort boundary: any failure means "no session" for the caller.
            print(f"Error getting user info: {e}")
            return None

    async def create_user_session(self, access_token: str, issuer: str) -> Optional[str]:
        """Create a user session from an OAuth access token.

        Looks up the Google profile, creates or refreshes the in-memory
        ``User`` record and returns an RS256-signed JWT valid for 7 days.

        Args:
            access_token: Google OAuth access token.
            issuer: Value placed in the ``iss`` claim.

        Returns:
            The encoded JWT, or ``None`` if the profile lookup failed.

        Raises:
            KeyError: If the profile lacks ``id``, ``email`` or ``name``.
        """
        from datetime import timezone  # local import: only needed here

        user_info = await self.get_user_info_from_token(access_token)
        if not user_info:
            return None

        # Create or update user
        user_id = user_info["id"]
        user = User(
            user_id=user_id,
            email=user_info["email"],
            name=user_info["name"],
            picture=user_info.get("picture"),
            provider="google"
        )

        existing = self.users.get(user_id)
        if existing is not None:
            # Returning user: bump last login and keep the stored profile in
            # sync with the claims that go into the new token.
            existing.last_login = datetime.now()
            existing.email = user.email
            existing.name = user.name
            existing.picture = user.picture
        else:
            self.users[user_id] = user

        # Timezone-aware UTC "now". A naive datetime.utcnow() would make
        # .timestamp() interpret it as *local* time, skewing auth_time by the
        # server's UTC offset. PyJWT converts aware datetimes for exp/iat to
        # the same epoch seconds as before.
        now = datetime.now(timezone.utc)
        token_payload = {
            # OIDC standard claims
            "iss": issuer,  # Issuer from request
            "sub": user_id,  # Subject (user ID)
            "aud": ["opensearch", "openrag"],  # Audience
            "exp": now + timedelta(days=7),  # Expiration
            "iat": now,  # Issued at
            "auth_time": int(now.timestamp()),  # Authentication time (epoch seconds)

            # Custom claims
            "user_id": user_id,  # Keep for backward compatibility
            "email": user.email,
            "name": user.name,
            "preferred_username": user.email,
            # NOTE(review): asserted unconditionally; no verification flag
            # from the provider response is consulted.
            "email_verified": True,
            "roles": ["openrag_user"]  # Backend role for OpenSearch
        }

        return jwt.encode(token_payload, self.private_key, algorithm="RS256")

    def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify a JWT's RS256 signature, expiry and audience.

        Returns:
            The decoded claims, or ``None`` if the token is expired, malformed,
            wrongly signed, or not addressed to ``opensearch``/``openrag``.
        """
        try:
            return jwt.decode(
                token,
                self.public_key,
                algorithms=["RS256"],
                audience=["opensearch", "openrag"]
            )
        except jwt.InvalidTokenError:
            # Also covers jwt.ExpiredSignatureError, a subclass of InvalidTokenError.
            return None

    def get_user(self, user_id: str) -> Optional[User]:
        """Get user by ID, or ``None`` if unknown."""
        return self.users.get(user_id)

    def get_user_from_token(self, token: str) -> Optional[User]:
        """Get the stored user a valid JWT refers to, or ``None``."""
        payload = self.verify_token(token)
        if not payload:
            return None
        # Prefer the legacy "user_id" claim; fall back to the standard "sub"
        # claim (create_user_session sets both to the same value) instead of
        # raising KeyError for tokens that only carry "sub".
        user_id = payload.get("user_id", payload.get("sub"))
        if user_id is None:
            return None
        return self.get_user(user_id)

    def get_user_opensearch_client(self, user_id: str, jwt_token: str):
        """Get or create the OpenSearch client for a user with their JWT.

        The cached client is rebuilt whenever a different JWT is supplied for
        the same user. Otherwise, a client created from an earlier token would
        keep being reused after that token expires or the user logs in again.
        NOTE(review): the replaced client is dropped, not explicitly closed.
        """
        if (user_id not in self.user_opensearch_clients
                or self._user_opensearch_tokens.get(user_id) != jwt_token):
            from config.settings import clients
            self.user_opensearch_clients[user_id] = clients.create_user_opensearch_client(jwt_token)
            self._user_opensearch_tokens[user_id] = jwt_token

        return self.user_opensearch_clients[user_id]