From 0fb9246e888257c65ddba5ea8e09a4a11ea9ead0 Mon Sep 17 00:00:00 2001 From: phact Date: Fri, 29 Aug 2025 00:00:06 -0400 Subject: [PATCH] connector jwt and connector cleanup fixes --- src/api/connectors.py | 38 +++++++++--------------- src/connectors/connection_manager.py | 10 +++++-- src/connectors/google_drive/connector.py | 25 ++++++++++++---- src/connectors/service.py | 8 ++--- src/main.py | 14 +++++++-- src/models/processors.py | 5 ++-- src/session_manager.py | 9 ++++-- 7 files changed, 66 insertions(+), 43 deletions(-) diff --git a/src/api/connectors.py b/src/api/connectors.py index ba511f19..b548e88e 100644 --- a/src/api/connectors.py +++ b/src/api/connectors.py @@ -20,6 +20,7 @@ async def connector_sync(request: Request, connector_service, session_manager): print(f"[DEBUG] Starting connector sync for connector_type={connector_type}, max_files={max_files}") user = request.state.user + jwt_token = request.cookies.get("auth_token") print(f"[DEBUG] User: {user.user_id}") # Get all active connections for this connector type and user @@ -36,7 +37,7 @@ async def connector_sync(request: Request, connector_service, session_manager): task_ids = [] for connection in active_connections: print(f"[DEBUG] About to call sync_connector_files for connection {connection.connection_id}") - task_id = await connector_service.sync_connector_files(connection.connection_id, user.user_id, max_files) + task_id = await connector_service.sync_connector_files(connection.connection_id, user.user_id, max_files, jwt_token=jwt_token) task_ids.append(task_id) print(f"[DEBUG] Got task_id: {task_id}") @@ -154,27 +155,8 @@ async def connector_webhook(request: Request, connector_service, session_manager # Find the specific connection for this webhook connection = await connector_service.connection_manager.get_connection_by_webhook_id(channel_id) if not connection or not connection.is_active: - print(f"[WEBHOOK] Unknown channel {channel_id} - attempting to cancel old subscription") - - # Try to cancel this unknown subscription using any active connection of this connector type - try: - all_connections = await connector_service.connection_manager.list_connections( - connector_type=connector_type - ) - active_connections = [c for c in all_connections if c.is_active] - - if active_connections: - # Use the first active connection to cancel the unknown subscription - connector = await connector_service._get_connector(active_connections[0].connection_id) - if connector: - print(f"[WEBHOOK] Cancelling unknown subscription {channel_id}") - await connector.cleanup_subscription(channel_id, resource_id) - print(f"[WEBHOOK] Successfully cancelled unknown subscription {channel_id}") - - except Exception as e: - print(f"[WARNING] Failed to cancel unknown subscription {channel_id}: {e}") - - return JSONResponse({"status": "cancelled_unknown", "channel_id": channel_id}) + print(f"[WEBHOOK] Unknown channel {channel_id} - no cleanup attempted (will auto-expire)") + return JSONResponse({"status": "ignored_unknown_channel", "channel_id": channel_id}) # Process webhook for the specific connection results = [] @@ -191,11 +173,19 @@ async def connector_webhook(request: Request, connector_service, session_manager if affected_files: print(f"[WEBHOOK] Connection {connection.connection_id}: {len(affected_files)} files affected") + # Generate JWT token for the user (needed for OpenSearch authentication) + user = session_manager.get_user(connection.user_id) + if user: + jwt_token = session_manager.create_jwt_token(user) + else: + jwt_token = None + # Trigger incremental sync for affected files task_id = await connector_service.sync_specific_files( connection.connection_id, connection.user_id, - affected_files + affected_files, + jwt_token=jwt_token ) result = { @@ -235,4 +225,4 @@ async def connector_webhook(request: Request, connector_service, session_manager import traceback print(f"[ERROR] Webhook processing failed: {str(e)}") traceback.print_exc() - return JSONResponse({"error": f"Webhook processing failed: {str(e)}"}, status_code=500) \ No newline at end of file + return JSONResponse({"error": f"Webhook processing failed: {str(e)}"}, status_code=500) diff --git a/src/connectors/connection_manager.py b/src/connectors/connection_manager.py index 19799cb3..f57df0e3 100644 --- a/src/connectors/connection_manager.py +++ b/src/connectors/connection_manager.py @@ -300,9 +300,11 @@ class ConnectionManager: print(f"[WEBHOOK] Setting up subscription for connection {connection_id}") subscription_id = await connector.setup_subscription() - # Store the subscription ID in connection config + # Store the subscription and resource IDs in connection config connection_config.config['webhook_channel_id'] = subscription_id connection_config.config['subscription_id'] = subscription_id # Alternative field + if getattr(connector, 'webhook_resource_id', None): + connection_config.config['resource_id'] = connector.webhook_resource_id # Save updated connection config await self.save_connections() @@ -327,9 +329,11 @@ class ConnectionManager: # Setup subscription subscription_id = await connector.setup_subscription() - # Store the subscription ID in connection config + # Store the subscription and resource IDs in connection config connection_config.config['webhook_channel_id'] = subscription_id connection_config.config['subscription_id'] = subscription_id + if getattr(connector, 'webhook_resource_id', None): + connection_config.config['resource_id'] = connector.webhook_resource_id # Save updated connection config await self.save_connections() @@ -338,4 +342,4 @@ class ConnectionManager: except Exception as e: print(f"[ERROR] Failed to setup webhook subscription for new connection {connection_id}: {e}") - # Don't fail the connection setup if webhook fails \ No newline at end of file + # Don't fail the connection setup if webhook fails diff --git a/src/connectors/google_drive/connector.py b/src/connectors/google_drive/connector.py index 9fe73c42..869880d8 100644 --- a/src/connectors/google_drive/connector.py +++ b/src/connectors/google_drive/connector.py @@ -164,6 +164,8 @@ class GoogleDriveConnector(BaseConnector): self.service = None # Load existing webhook channel ID from config if available self.webhook_channel_id = config.get('webhook_channel_id') or config.get('subscription_id') + # Load existing webhook resource ID (Google Drive requires this to stop a channel) + self.webhook_resource_id = config.get('resource_id') async def authenticate(self) -> bool: """Authenticate with Google Drive""" @@ -207,6 +209,11 @@ class GoogleDriveConnector(BaseConnector): ).execute() self.webhook_channel_id = channel_id + # Persist the resourceId returned by Google to allow proper cleanup + try: + self.webhook_resource_id = result.get('resourceId') + except Exception: + self.webhook_resource_id = None return channel_id except HttpError as e: @@ -388,6 +395,10 @@ class GoogleDriveConnector(BaseConnector): def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]: """Extract Google Drive channel ID from webhook headers""" return headers.get('x-goog-channel-id') + + def extract_webhook_resource_id(self, headers: Dict[str, str]) -> Optional[str]: + """Extract Google Drive resource ID from webhook headers""" + return headers.get('x-goog-resource-id') async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]: """Handle Google Drive webhook notification""" @@ -469,15 +480,19 @@ class GoogleDriveConnector(BaseConnector): print(f"Failed to handle webhook: {e}") return [] - async def cleanup_subscription(self, subscription_id: str, resource_id: str = None) -> bool: - """Clean up Google Drive subscription""" + async def cleanup_subscription(self, subscription_id: str) -> bool: + """Clean up Google Drive subscription for this connection. + + Uses the stored resource_id captured during subscription setup. + """ if not self._authenticated: return False try: - body = {'id': subscription_id} - if resource_id: - body['resourceId'] = resource_id + # Google Channels API requires both 'id' (channel) and 'resourceId' + if not self.webhook_resource_id: + raise ValueError("Missing resource_id for cleanup; ensure subscription state is persisted") + body = {'id': subscription_id, 'resourceId': self.webhook_resource_id} self.service.channels().stop(body=body).execute() return True diff --git a/src/connectors/service.py b/src/connectors/service.py index 69dc5179..33c287f8 100644 --- a/src/connectors/service.py +++ b/src/connectors/service.py @@ -149,7 +149,7 @@ class ConnectorService: } return mime_to_ext.get(mimetype, '.bin') - async def sync_connector_files(self, connection_id: str, user_id: str, max_files: int = None) -> str: + async def sync_connector_files(self, connection_id: str, user_id: str, max_files: int = None, jwt_token: str = None) -> str: """Sync files from a connector connection using existing task tracking system""" if not self.task_service: raise ValueError("TaskService not available - connector sync requires task service dependency") @@ -203,7 +203,7 @@ class ConnectorService: # Create custom processor for connector files from models.processors import ConnectorFileProcessor - processor = ConnectorFileProcessor(self, connection_id, files_to_process, user_id, owner_name=owner_name, owner_email=owner_email) + processor = ConnectorFileProcessor(self, connection_id, files_to_process, user_id, jwt_token=jwt_token, owner_name=owner_name, owner_email=owner_email) # Use file IDs as items (no more fake file paths!) file_ids = [file_info['id'] for file_info in files_to_process] @@ -213,7 +213,7 @@ class ConnectorService: return task_id - async def sync_specific_files(self, connection_id: str, user_id: str, file_ids: List[str]) -> str: + async def sync_specific_files(self, connection_id: str, user_id: str, file_ids: List[str], jwt_token: str = None) -> str: """Sync specific files by their IDs (used for webhook-triggered syncs)""" if not self.task_service: raise ValueError("TaskService not available - connector sync requires task service dependency") @@ -236,7 +236,7 @@ class ConnectorService: # Create custom processor for specific connector files from models.processors import ConnectorFileProcessor # We'll pass file_ids as the files_info, the processor will handle ID-only files - processor = ConnectorFileProcessor(self, connection_id, file_ids, user_id, owner_name=owner_name, owner_email=owner_email) + processor = ConnectorFileProcessor(self, connection_id, file_ids, user_id, jwt_token=jwt_token, owner_name=owner_name, owner_email=owner_email) # Create custom task using TaskService task_id = await self.task_service.create_custom_task(user_id, file_ids, processor) diff --git a/src/main.py b/src/main.py index 11718faa..32cc31a9 100644 --- a/src/main.py +++ b/src/main.py @@ -190,6 +190,15 @@ async def initialize_services(): # Initialize auth service auth_service = AuthService(session_manager, connector_service) + # Load persisted connector connections at startup so webhooks and syncs + # can resolve existing subscriptions immediately after server boot + try: + await connector_service.initialize() + loaded_count = len(connector_service.connection_manager.connections) + print(f"[CONNECTORS] Loaded {loaded_count} persisted connection(s) on startup") + except Exception as e: + print(f"[WARNING] Failed to load persisted connections on startup: {e}") + return { 'document_service': document_service, 'search_service': search_service, @@ -497,8 +506,7 @@ async def cleanup_subscriptions_proper(services): connector = await connector_service.get_connector(connection.connection_id) if connector: subscription_id = connection.config.get('webhook_channel_id') - resource_id = connection.config.get('resource_id') # If stored - await connector.cleanup_subscription(subscription_id, resource_id) + await connector.cleanup_subscription(subscription_id) print(f"[CLEANUP] Cancelled subscription {subscription_id}") except Exception as e: print(f"[ERROR] Failed to cancel subscription for {connection.connection_id}: {e}") @@ -524,4 +532,4 @@ if __name__ == "__main__": host="0.0.0.0", port=8000, reload=False, # Disable reload since we're running from main - ) \ No newline at end of file + ) diff --git a/src/models/processors.py b/src/models/processors.py index d07914be..dd79b994 100644 --- a/src/models/processors.py +++ b/src/models/processors.py @@ -44,11 +44,12 @@ class DocumentFileProcessor(TaskProcessor): class ConnectorFileProcessor(TaskProcessor): """Processor for connector file uploads""" - def __init__(self, connector_service, connection_id: str, files_to_process: list, user_id: str = None, owner_name: str = None, owner_email: str = None): + def __init__(self, connector_service, connection_id: str, files_to_process: list, user_id: str = None, jwt_token: str = None, owner_name: str = None, owner_email: str = None): self.connector_service = connector_service self.connection_id = connection_id self.files_to_process = files_to_process self.user_id = user_id + self.jwt_token = jwt_token self.owner_name = owner_name self.owner_email = owner_email # Create lookup map for file info - handle both file objects and file IDs @@ -83,7 +84,7 @@ class ConnectorFileProcessor(TaskProcessor): raise ValueError("user_id not provided to ConnectorFileProcessor") # Process using existing pipeline - result = await self.connector_service.process_connector_document(document, self.user_id, connection.connector_type, owner_name=self.owner_name, owner_email=self.owner_email) + result = await self.connector_service.process_connector_document(document, self.user_id, connection.connector_type, jwt_token=self.jwt_token, owner_name=self.owner_name, owner_email=self.owner_email) file_task.status = TaskStatus.COMPLETED file_task.result = result diff --git a/src/session_manager.py b/src/session_manager.py index c69ededd..794067ad 100644 --- a/src/session_manager.py +++ b/src/session_manager.py @@ -101,6 +101,11 @@ class SessionManager: else: self.users[user_id] = user + # Create JWT token using the shared method + return self.create_jwt_token(user) + + def create_jwt_token(self, user: User) -> str: + """Create JWT token for an existing user""" # Use OpenSearch-compatible issuer for OIDC validation oidc_issuer = "http://openrag-backend:8000" @@ -109,14 +114,14 @@ class SessionManager: token_payload = { # OIDC standard claims "iss": oidc_issuer, # Fixed issuer for OpenSearch OIDC - "sub": user_id, # Subject (user ID) + "sub": user.user_id, # Subject (user ID) "aud": ["opensearch", "openrag"], # Audience "exp": now + timedelta(days=7), # Expiration "iat": now, # Issued at "auth_time": int(now.timestamp()), # Authentication time # Custom claims - "user_id": user_id, # Keep for backward compatibility + "user_id": user.user_id, # Keep for backward compatibility "email": user.email, "name": user.name, "preferred_username": user.email,