openrag/src/services/auth_service.py
2025-12-12 15:13:11 -05:00

411 lines
16 KiB
Python

import asyncio
import json
import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import Optional
from urllib.parse import urlparse

import aiofiles
import httpx

from config.settings import WEBHOOK_BASE_URL, is_no_auth_mode
from session_manager import SessionManager
from services.langflow_mcp_service import LangflowMCPService
from connectors.google_drive.oauth import GoogleDriveOAuth
from connectors.onedrive.oauth import OneDriveOAuth
from connectors.sharepoint.oauth import SharePointOAuth
from connectors.google_drive import GoogleDriveConnector
from connectors.onedrive import OneDriveConnector
from connectors.sharepoint import SharePointConnector
class AuthService:
    """Drive OAuth flows for app login and for data-source connectors.

    Registers connector connections, exchanges authorization codes for
    tokens, persists the tokens to each connection's token file and, for
    app login, creates a user session through the session manager.
    """

    # Used for best-effort failures that must never block the login flow.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        session_manager: SessionManager,
        connector_service=None,
        langflow_mcp_service: LangflowMCPService | None = None,
    ):
        """Store collaborators.

        Args:
            session_manager: Creates user sessions/JWTs and resolves user info
                from provider access tokens.
            connector_service: Must expose ``connection_manager``; every OAuth
                method uses it (only ``get_user_info`` works without it).
            langflow_mcp_service: Optional; when set, Langflow MCP servers are
                updated with the user's JWT after a successful app login.
        """
        self.session_manager = session_manager
        self.connector_service = connector_service
        # Authorization codes already redeemed by this process; used to reject
        # replayed callbacks. NOTE(review): codes from successful exchanges are
        # never removed, so this set grows with every login.
        self.used_auth_codes: set[str] = set()
        self.langflow_mcp_service = langflow_mcp_service
        # Strong references to fire-and-forget tasks so they are not garbage
        # collected before they finish.
        self._background_tasks: set[asyncio.Task] = set()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _resolve_connector_classes(connector_type: str) -> tuple[type, type]:
        """Return ``(connector_class, oauth_class)`` for ``connector_type``.

        Raises:
            ValueError: if no classes are registered for the connector type.
        """
        connector_classes = {
            "google_drive": (GoogleDriveConnector, GoogleDriveOAuth),
            "onedrive": (OneDriveConnector, OneDriveOAuth),
            "sharepoint": (SharePointConnector, SharePointOAuth),
        }
        try:
            return connector_classes[connector_type]
        except KeyError:
            raise ValueError(
                f"No classes found for connector type: {connector_type}"
            ) from None

    @staticmethod
    def _read_client_credentials(connector_class) -> tuple[str, str]:
        """Read the OAuth client id and secret for ``connector_class``.

        The connector class must name the environment variables to read via
        its ``CLIENT_ID_ENV_VAR`` and ``CLIENT_SECRET_ENV_VAR`` attributes.

        Raises:
            RuntimeError: if either attribute is missing or blank, or if either
                environment variable is unset or empty.
        """
        client_key = getattr(connector_class, "CLIENT_ID_ENV_VAR", None)
        secret_key = getattr(connector_class, "CLIENT_SECRET_ENV_VAR", None)
        for attr_name, env_name in (
            ("CLIENT_ID_ENV_VAR", client_key),
            ("CLIENT_SECRET_ENV_VAR", secret_key),
        ):
            if not isinstance(env_name, str) or not env_name.strip():
                raise RuntimeError(
                    f"{connector_class.__name__} misconfigured: {attr_name} must be a non-empty string "
                    f"(got {env_name!r}). Define it as a class attribute on the connector."
                )
        client_id = os.getenv(client_key)
        client_secret = os.getenv(secret_key)
        if not client_id or not client_secret:
            raise RuntimeError(
                f"Missing OAuth env vars for {connector_class.__name__}. "
                f"Set {client_key} and {secret_key} in the environment."
            )
        return client_id, client_secret

    @staticmethod
    def _webhook_config(connector_type: str) -> dict:
        """Return ``{"webhook_url": ...}`` when WEBHOOK_BASE_URL is set, else ``{}``."""
        if not WEBHOOK_BASE_URL:
            return {}
        return {
            "webhook_url": f"{WEBHOOK_BASE_URL}/connectors/{connector_type}/webhook"
        }

    @staticmethod
    def _normalize_scopes(raw_scopes) -> list:
        """Return scopes as a list; a string is split on any whitespace."""
        if not raw_scopes:
            return []
        if isinstance(raw_scopes, str):
            return raw_scopes.split()
        return list(raw_scopes)

    @staticmethod
    async def _exchange_code_for_tokens(token_url: str, payload: dict) -> dict:
        """POST the form-encoded ``payload`` to ``token_url`` and return the JSON body.

        Raises:
            Exception: if the provider answers with a non-200 status.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(token_url, data=payload)
            if response.status_code != 200:
                raise Exception(f"Token exchange failed: {response.text}")
            return response.json()

    @staticmethod
    async def _write_token_file(token_file_path: str, token_file_data: dict) -> None:
        """Write ``token_file_data`` as indented JSON, creating the parent dir if needed."""
        token_dir = os.path.dirname(token_file_path)
        if token_dir:
            os.makedirs(token_dir, exist_ok=True)
        async with aiofiles.open(token_file_path, "w") as f:
            await f.write(json.dumps(token_file_data, indent=2))

    def _on_background_task_done(self, task: asyncio.Task) -> None:
        """Drop the finished task's reference and log any exception it raised."""
        self._background_tasks.discard(task)
        if not task.cancelled() and task.exception() is not None:
            self._logger.warning(
                "Background task failed", exc_info=task.exception()
            )

    def _schedule_mcp_update(self, jwt_token, user_info: Optional[dict]) -> None:
        """Best-effort: push the user's JWT and identity into Langflow MCP servers.

        The update runs as a background task so login is not delayed. This
        method never raises; scheduling failures are logged instead.
        """
        if not (
            self.langflow_mcp_service
            and isinstance(jwt_token, str)
            and jwt_token.strip()
        ):
            return
        try:
            global_vars = {"JWT": jwt_token, "CONNECTOR_TYPE_URL": "url"}
            if user_info:
                if user_info.get("id"):
                    global_vars["OWNER"] = user_info["id"]
                owner_name = user_info.get("name")
                if owner_name:
                    # Quoted because names may contain spaces, which can cause
                    # issues in headers.
                    global_vars["OWNER_NAME"] = f'"{owner_name}"'
                if user_info.get("email"):
                    global_vars["OWNER_EMAIL"] = user_info["email"]

            # Imported lazily, as in the rest of this service's login path.
            from config.settings import get_openrag_config
            from utils.langflow_headers import build_mcp_global_vars_from_config

            # Provider credentials are merged on top of the user identity vars.
            global_vars.update(
                build_mcp_global_vars_from_config(get_openrag_config())
            )
            task = asyncio.create_task(
                self.langflow_mcp_service.update_mcp_servers_with_global_vars(
                    global_vars
                )
            )
        except Exception:
            # Do not block login on MCP update issues, but leave a trace.
            self._logger.warning(
                "Failed to schedule Langflow MCP server update", exc_info=True
            )
            return
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)

    # --------------------------------------------------------------- public API

    async def init_oauth(
        self,
        connector_type: str,
        purpose: str,
        connection_name: str,
        redirect_uri: str,
        user_id: Optional[str] = None,
    ) -> dict:
        """Initialize an OAuth flow for app login or a data-source connection.

        Validates the request and the connector's client credentials first,
        then registers a new connection whose token file lives under ``data/``.

        Args:
            connector_type: ``"google_drive"``, ``"onedrive"`` or
                ``"sharepoint"``; only ``"google_drive"`` is allowed for
                ``"app_auth"``.
            purpose: ``"app_auth"`` or ``"data_source"``.
            connection_name: Display name for the new connection.
            redirect_uri: OAuth redirect URI; stored on the connection and
                reused for the code exchange.
            user_id: Owner of the connection, if known.

        Returns:
            ``{"connection_id": ..., "oauth_config": {...}}``. ``oauth_config``
            holds the client id, scopes, redirect URI and the provider's
            authorization/token endpoints. The client secret is never returned.

        Raises:
            ValueError: in no-auth mode, or for an unsupported purpose or
                connector type, or a missing redirect URI.
            RuntimeError: if the connector's client credentials are not
                configured.
        """
        if is_no_auth_mode():
            if purpose == "app_auth":
                raise ValueError(
                    "OAuth credentials not configured. Please add GOOGLE_OAUTH_CLIENT_ID and GOOGLE_OAUTH_CLIENT_SECRET environment variables to enable authentication."
                )
            raise ValueError(
                "OAuth credentials not configured. Data source connections require OAuth setup."
            )

        if purpose == "app_auth" and connector_type != "google_drive":
            raise ValueError("Only Google login supported for app authentication")
        elif purpose == "data_source" and connector_type not in [
            "google_drive",
            "onedrive",
            "sharepoint",
        ]:
            raise ValueError(f"Unsupported connector type: {connector_type}")
        elif purpose not in ["app_auth", "data_source"]:
            raise ValueError(f"Unsupported purpose: {purpose}")
        if not redirect_uri:
            raise ValueError("redirect_uri is required")

        # Resolve classes, credentials and endpoints BEFORE creating the
        # connection so a misconfiguration cannot leave an orphaned connection.
        connector_class, oauth_class = self._resolve_connector_classes(connector_type)
        client_id, _client_secret = self._read_client_credentials(connector_class)
        oauth_config = {
            "client_id": client_id,
            "scopes": oauth_class.SCOPES,
            "redirect_uri": redirect_uri,
            "authorization_endpoint": oauth_class.AUTH_ENDPOINT,
            "token_endpoint": oauth_class.TOKEN_ENDPOINT,
        }

        # Token files go under data/ so they persist; the random suffix keeps
        # concurrent flows for the same connector/purpose apart.
        token_file = f"data/{connector_type}_{purpose}_{uuid.uuid4().hex[:8]}.json"
        config = {
            "token_file": token_file,
            "connector_type": connector_type,
            "purpose": purpose,
            "redirect_uri": redirect_uri,
            **self._webhook_config(connector_type),
        }
        connection_id = await self.connector_service.connection_manager.create_connection(
            connector_type=connector_type,
            name=connection_name,
            config=config,
            user_id=user_id,
        )
        return {"connection_id": connection_id, "oauth_config": oauth_config}

    async def handle_oauth_callback(
        self,
        connection_id: str,
        authorization_code: str,
        state: Optional[str] = None,
        request=None,
    ) -> dict:
        """Exchange an authorization code for tokens and finish the OAuth flow.

        Tokens (never the client secret) are written to the connection's token
        file. The connection's ``purpose`` then selects app-login handling or
        data-source handling.

        NOTE(review): ``state`` is accepted but not checked here; if CSRF
        protection relies on it, confirm the caller validates it.

        Raises:
            ValueError: for missing parameters, a reused code, an unknown
                connection, a missing redirect URI, an unsupported connector
                type, or when no scopes can be determined.
            Exception: if the token exchange fails or session creation fails.
        """
        if not all([connection_id, authorization_code]):
            raise ValueError(
                "Missing required parameters (connection_id, authorization_code)"
            )
        if authorization_code in self.used_auth_codes:
            raise ValueError("Authorization code already used")
        # Reserve the code before the first await so a concurrent duplicate
        # callback hits the check above.
        self.used_auth_codes.add(authorization_code)

        try:
            connection_manager = self.connector_service.connection_manager
            connection_config = await connection_manager.get_connection(connection_id)
            if not connection_config:
                raise ValueError("Connection not found")

            redirect_uri = connection_config.config.get("redirect_uri")
            if not redirect_uri:
                raise ValueError("Redirect URI not found in connection config")

            # The connector supplies the client credentials for the exchange.
            connector = connection_manager._create_connector(connection_config)
            connector_type = connection_config.connector_type
            _connector_class, oauth_class = self._resolve_connector_classes(
                connector_type
            )

            token_data = await self._exchange_code_for_tokens(
                oauth_class.TOKEN_ENDPOINT,
                {
                    "code": authorization_code,
                    "client_id": connector.get_client_id(),
                    "client_secret": connector.get_client_secret(),
                    "redirect_uri": redirect_uri,
                    "grant_type": "authorization_code",
                },
            )

            # Prefer the scopes actually granted. Per RFC 6749 section 5.1 the
            # provider may omit "scope" when it equals the requested scope, so
            # fall back to the scopes this connector requests.
            scopes = self._normalize_scopes(
                token_data.get("scope")
            ) or self._normalize_scopes(oauth_class.SCOPES)
            if not scopes:
                raise ValueError(
                    f"OAuth provider for {connector_type} did not return granted scopes in token response"
                )

            token_file_data = {
                "token": token_data["access_token"],
                "refresh_token": token_data.get("refresh_token"),
                "scopes": scopes,
            }
            if token_data.get("expires_in"):
                # NOTE(review): naive local time; confirm the token-file reader
                # expects local time rather than UTC.
                expiry = datetime.now() + timedelta(
                    seconds=int(token_data["expires_in"])
                )
                token_file_data["expiry"] = expiry.isoformat()

            await self._write_token_file(
                connection_config.config["token_file"], token_file_data
            )

            purpose = connection_config.config.get("purpose", "data_source")
            if purpose == "app_auth":
                return await self._handle_app_auth(
                    connection_id, connection_config, token_data, request
                )
            return await self._handle_data_source_auth(
                connection_id, connection_config
            )
        except Exception:
            # Un-reserve the code so retrying this callback after a failure is
            # not rejected as a replay.
            self.used_auth_codes.discard(authorization_code)
            raise

    async def _handle_app_auth(
        self, connection_id: str, connection_config, token_data: dict, request=None
    ) -> dict:
        """Create a user session for app login.

        On success, the temporary auth connection is converted into a
        persistent Google Drive data-source connection owned by the user.
        If the user id cannot be resolved, the connection is deleted instead.

        Raises:
            ValueError: if the connection has no redirect URI.
            Exception: if the session manager returns no JWT; the connection
                is deleted first.
        """
        redirect_uri = connection_config.config.get("redirect_uri")
        if not redirect_uri:
            raise ValueError("redirect_uri not found in connection config")

        # The issuer is the origin (scheme + host[:port]) of the redirect URI.
        parsed = urlparse(redirect_uri)
        issuer = f"{parsed.scheme}://{parsed.netloc}"

        access_token = token_data["access_token"]
        jwt_token = await self.session_manager.create_user_session(access_token, issuer)
        connection_manager = self.connector_service.connection_manager
        if not jwt_token:
            await connection_manager.delete_connection(connection_id)
            raise Exception("Failed to create user session")

        user_info = await self.session_manager.get_user_info_from_token(access_token)
        self._schedule_mcp_update(jwt_token, user_info)

        response_data = {
            "status": "authenticated",
            "purpose": "app_auth",
            "redirect": "/",
            "jwt_token": jwt_token,
        }

        if not (user_info and user_info.get("id")):
            # Without a user id the connection cannot be owned; drop it.
            await connection_manager.delete_connection(connection_id)
            return response_data

        email = user_info.get("email")
        await connection_manager.update_connection(
            connection_id=connection_id,
            connector_type="google_drive",
            name=f"Google Drive ({email or 'Unknown'})",
            user_id=user_info.get("id"),
            config={
                **connection_config.config,
                "purpose": "data_source",
                "user_email": email,
                **self._webhook_config("google_drive"),
            },
        )
        response_data["google_drive_connection_id"] = connection_id
        return response_data

    async def _handle_data_source_auth(
        self, connection_id: str, connection_config
    ) -> dict:
        """Report a completed data-source connection; the connection is kept for syncing."""
        return {
            "status": "authenticated",
            "connection_id": connection_id,
            "purpose": "data_source",
            "connector_type": connection_config.connector_type,
        }

    async def get_user_info(self, request) -> Optional[dict]:
        """Describe the user attached to ``request.state.user``, if any.

        Returns ``{"authenticated": False, "user": None, "no_auth_mode": True}``
        in no-auth mode and ``{"authenticated": False, "user": None}`` when no
        user is attached.
        """
        if is_no_auth_mode():
            return {"authenticated": False, "user": None, "no_auth_mode": True}

        user = getattr(request.state, "user", None)
        if not user:
            return {"authenticated": False, "user": None}

        last_login = user.last_login.isoformat() if user.last_login else None
        return {
            "authenticated": True,
            "user": {
                "user_id": user.user_id,
                "email": user.email,
                "name": user.name,
                "picture": user.picture,
                "provider": user.provider,
                "last_login": last_login,
            },
        }