connector jwt and connector cleanup fixes
commit 0fb9246e88 (parent 585d71515c)
7 changed files with 66 additions and 43 deletions
@@ -20,6 +20,7 @@ async def connector_sync(request: Request, connector_service, session_manager):
     print(f"[DEBUG] Starting connector sync for connector_type={connector_type}, max_files={max_files}")

     user = request.state.user
+    jwt_token = request.cookies.get("auth_token")
     print(f"[DEBUG] User: {user.user_id}")

     # Get all active connections for this connector type and user
@@ -36,7 +37,7 @@ async def connector_sync(request: Request, connector_service, session_manager):
     task_ids = []
     for connection in active_connections:
         print(f"[DEBUG] About to call sync_connector_files for connection {connection.connection_id}")
-        task_id = await connector_service.sync_connector_files(connection.connection_id, user.user_id, max_files)
+        task_id = await connector_service.sync_connector_files(connection.connection_id, user.user_id, max_files, jwt_token=jwt_token)
         task_ids.append(task_id)
         print(f"[DEBUG] Got task_id: {task_id}")

@@ -154,27 +155,8 @@ async def connector_webhook(request: Request, connector_service, session_manager
         # Find the specific connection for this webhook
         connection = await connector_service.connection_manager.get_connection_by_webhook_id(channel_id)
         if not connection or not connection.is_active:
-            print(f"[WEBHOOK] Unknown channel {channel_id} - attempting to cancel old subscription")
-
-            # Try to cancel this unknown subscription using any active connection of this connector type
-            try:
-                all_connections = await connector_service.connection_manager.list_connections(
-                    connector_type=connector_type
-                )
-                active_connections = [c for c in all_connections if c.is_active]
-
-                if active_connections:
-                    # Use the first active connection to cancel the unknown subscription
-                    connector = await connector_service._get_connector(active_connections[0].connection_id)
-                    if connector:
-                        print(f"[WEBHOOK] Cancelling unknown subscription {channel_id}")
-                        await connector.cleanup_subscription(channel_id, resource_id)
-                        print(f"[WEBHOOK] Successfully cancelled unknown subscription {channel_id}")
-
-            except Exception as e:
-                print(f"[WARNING] Failed to cancel unknown subscription {channel_id}: {e}")
-
-            return JSONResponse({"status": "cancelled_unknown", "channel_id": channel_id})
+            print(f"[WEBHOOK] Unknown channel {channel_id} - no cleanup attempted (will auto-expire)")
+            return JSONResponse({"status": "ignored_unknown_channel", "channel_id": channel_id})

         # Process webhook for the specific connection
         results = []
@@ -191,11 +173,19 @@ async def connector_webhook(request: Request, connector_service, session_manager
         if affected_files:
             print(f"[WEBHOOK] Connection {connection.connection_id}: {len(affected_files)} files affected")

+            # Generate JWT token for the user (needed for OpenSearch authentication)
+            user = session_manager.get_user(connection.user_id)
+            if user:
+                jwt_token = session_manager.create_jwt_token(user)
+            else:
+                jwt_token = None
+
             # Trigger incremental sync for affected files
             task_id = await connector_service.sync_specific_files(
                 connection.connection_id,
                 connection.user_id,
-                affected_files
+                affected_files,
+                jwt_token=jwt_token
             )

             result = {
@@ -235,4 +225,4 @@ async def connector_webhook(request: Request, connector_service, session_manager
         import traceback
         print(f"[ERROR] Webhook processing failed: {str(e)}")
         traceback.print_exc()
-        return JSONResponse({"error": f"Webhook processing failed: {str(e)}"}, status_code=500)
+        return JSONResponse({"error": f"Webhook processing failed: {str(e)}"}, status_code=500)
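Note: the webhook path above has no auth cookie to forward, so the handler mints a token from the stored session user before kicking off the sync. A minimal sketch of that flow in Python, using the session_manager and connector_service names from the hunks above (not the verbatim handler):

    # Hedged sketch of the webhook-triggered sync path.
    async def sync_affected_files(connection, affected_files, connector_service, session_manager):
        # Webhook requests carry no auth cookie, so mint a JWT for the
        # connection's owner to authenticate downstream OpenSearch writes.
        user = session_manager.get_user(connection.user_id)
        jwt_token = session_manager.create_jwt_token(user) if user else None

        return await connector_service.sync_specific_files(
            connection.connection_id,
            connection.user_id,
            affected_files,
            jwt_token=jwt_token,
        )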
@@ -300,9 +300,11 @@ class ConnectionManager:
            print(f"[WEBHOOK] Setting up subscription for connection {connection_id}")
            subscription_id = await connector.setup_subscription()

-           # Store the subscription ID in connection config
+           # Store the subscription and resource IDs in connection config
            connection_config.config['webhook_channel_id'] = subscription_id
            connection_config.config['subscription_id'] = subscription_id  # Alternative field
+           if getattr(connector, 'webhook_resource_id', None):
+               connection_config.config['resource_id'] = connector.webhook_resource_id

            # Save updated connection config
            await self.save_connections()
@@ -327,9 +329,11 @@ class ConnectionManager:
            # Setup subscription
            subscription_id = await connector.setup_subscription()

-           # Store the subscription ID in connection config
+           # Store the subscription and resource IDs in connection config
            connection_config.config['webhook_channel_id'] = subscription_id
            connection_config.config['subscription_id'] = subscription_id
+           if getattr(connector, 'webhook_resource_id', None):
+               connection_config.config['resource_id'] = connector.webhook_resource_id

            # Save updated connection config
            await self.save_connections()
@@ -338,4 +342,4 @@ class ConnectionManager:

        except Exception as e:
            print(f"[ERROR] Failed to setup webhook subscription for new connection {connection_id}: {e}")
-           # Don't fail the connection setup if webhook fails
+           # Don't fail the connection setup if webhook fails
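Note: the cleanup change later in this commit depends on the resource ID being persisted next to the channel ID at subscription time. An illustrative shape of a persisted connection config entry after setup_subscription() (field names from the diff; the values and the connector_type key are example assumptions):

    # Example only; 'subscription_id' duplicates 'webhook_channel_id' for
    # callers that read either key, and 'resource_id' is what Google Drive
    # needs later to stop the channel.
    example_connection_config = {
        "connector_type": "google_drive",
        "webhook_channel_id": "chan-1234",
        "subscription_id": "chan-1234",
        "resource_id": "res-5678",
    }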
@@ -164,6 +164,8 @@ class GoogleDriveConnector(BaseConnector):
         self.service = None
         # Load existing webhook channel ID from config if available
         self.webhook_channel_id = config.get('webhook_channel_id') or config.get('subscription_id')
+        # Load existing webhook resource ID (Google Drive requires this to stop a channel)
+        self.webhook_resource_id = config.get('resource_id')

     async def authenticate(self) -> bool:
         """Authenticate with Google Drive"""
@@ -207,6 +209,11 @@ class GoogleDriveConnector(BaseConnector):
            ).execute()

            self.webhook_channel_id = channel_id
+           # Persist the resourceId returned by Google to allow proper cleanup
+           try:
+               self.webhook_resource_id = result.get('resourceId')
+           except Exception:
+               self.webhook_resource_id = None
            return channel_id

        except HttpError as e:
@@ -388,6 +395,10 @@ class GoogleDriveConnector(BaseConnector):
     def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]:
         """Extract Google Drive channel ID from webhook headers"""
         return headers.get('x-goog-channel-id')
+
+    def extract_webhook_resource_id(self, headers: Dict[str, str]) -> Optional[str]:
+        """Extract Google Drive resource ID from webhook headers"""
+        return headers.get('x-goog-resource-id')

     async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
         """Handle Google Drive webhook notification"""
@@ -469,15 +480,19 @@ class GoogleDriveConnector(BaseConnector):
             print(f"Failed to handle webhook: {e}")
             return []

-    async def cleanup_subscription(self, subscription_id: str, resource_id: str = None) -> bool:
-        """Clean up Google Drive subscription"""
+    async def cleanup_subscription(self, subscription_id: str) -> bool:
+        """Clean up Google Drive subscription for this connection.
+
+        Uses the stored resource_id captured during subscription setup.
+        """
         if not self._authenticated:
             return False

         try:
-            body = {'id': subscription_id}
-            if resource_id:
-                body['resourceId'] = resource_id
+            # Google Channels API requires both 'id' (channel) and 'resourceId'
+            if not self.webhook_resource_id:
+                raise ValueError("Missing resource_id for cleanup; ensure subscription state is persisted")
+            body = {'id': subscription_id, 'resourceId': self.webhook_resource_id}

             self.service.channels().stop(body=body).execute()
             return True
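Note: the Drive v3 Channels API stops a push notification channel only when given both the channel id and the resourceId returned by the original files().watch() call, which is why the connector now persists webhook_resource_id. A minimal standalone sketch with google-api-python-client (the credentials handling is assumed; only the stop() call mirrors the connector):

    from googleapiclient.discovery import build

    def stop_drive_channel(credentials, channel_id: str, resource_id: str) -> None:
        """Stop a Drive push channel; channels().stop() requires both ids."""
        service = build("drive", "v3", credentials=credentials)
        service.channels().stop(body={
            "id": channel_id,            # channel id chosen at watch() time
            "resourceId": resource_id,   # returned in the watch() response
        }).execute()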
@@ -149,7 +149,7 @@ class ConnectorService:
         }
         return mime_to_ext.get(mimetype, '.bin')

-    async def sync_connector_files(self, connection_id: str, user_id: str, max_files: int = None) -> str:
+    async def sync_connector_files(self, connection_id: str, user_id: str, max_files: int = None, jwt_token: str = None) -> str:
         """Sync files from a connector connection using existing task tracking system"""
         if not self.task_service:
             raise ValueError("TaskService not available - connector sync requires task service dependency")
@@ -203,7 +203,7 @@ class ConnectorService:

         # Create custom processor for connector files
         from models.processors import ConnectorFileProcessor
-        processor = ConnectorFileProcessor(self, connection_id, files_to_process, user_id, owner_name=owner_name, owner_email=owner_email)
+        processor = ConnectorFileProcessor(self, connection_id, files_to_process, user_id, jwt_token=jwt_token, owner_name=owner_name, owner_email=owner_email)

         # Use file IDs as items (no more fake file paths!)
         file_ids = [file_info['id'] for file_info in files_to_process]
@@ -213,7 +213,7 @@ class ConnectorService:

         return task_id

-    async def sync_specific_files(self, connection_id: str, user_id: str, file_ids: List[str]) -> str:
+    async def sync_specific_files(self, connection_id: str, user_id: str, file_ids: List[str], jwt_token: str = None) -> str:
         """Sync specific files by their IDs (used for webhook-triggered syncs)"""
         if not self.task_service:
             raise ValueError("TaskService not available - connector sync requires task service dependency")
@@ -236,7 +236,7 @@ class ConnectorService:
         # Create custom processor for specific connector files
         from models.processors import ConnectorFileProcessor
         # We'll pass file_ids as the files_info, the processor will handle ID-only files
-        processor = ConnectorFileProcessor(self, connection_id, file_ids, user_id, owner_name=owner_name, owner_email=owner_email)
+        processor = ConnectorFileProcessor(self, connection_id, file_ids, user_id, jwt_token=jwt_token, owner_name=owner_name, owner_email=owner_email)

         # Create custom task using TaskService
         task_id = await self.task_service.create_custom_task(user_id, file_ids, processor)
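Note: the three ConnectorService hunks above are one change seen at three points: the token accepted at the service boundary is handed to ConnectorFileProcessor, which later forwards it to process_connector_document. A condensed Python sketch of the new service signatures (from the diff; the bodies here are placeholders, not the real implementation):

    class ConnectorService:
        async def sync_connector_files(self, connection_id: str, user_id: str,
                                       max_files: int = None, jwt_token: str = None) -> str:
            """Builds a ConnectorFileProcessor(..., jwt_token=jwt_token, ...) and schedules it."""
            ...

        async def sync_specific_files(self, connection_id: str, user_id: str,
                                      file_ids: list, jwt_token: str = None) -> str:
            """Same plumbing, used for webhook-triggered syncs of specific file IDs."""
            ...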
src/main.py (14 lines changed)
@@ -190,6 +190,15 @@ async def initialize_services():
     # Initialize auth service
     auth_service = AuthService(session_manager, connector_service)

+    # Load persisted connector connections at startup so webhooks and syncs
+    # can resolve existing subscriptions immediately after server boot
+    try:
+        await connector_service.initialize()
+        loaded_count = len(connector_service.connection_manager.connections)
+        print(f"[CONNECTORS] Loaded {loaded_count} persisted connection(s) on startup")
+    except Exception as e:
+        print(f"[WARNING] Failed to load persisted connections on startup: {e}")
+
     return {
         'document_service': document_service,
         'search_service': search_service,
@@ -497,8 +506,7 @@ async def cleanup_subscriptions_proper(services):
            connector = await connector_service.get_connector(connection.connection_id)
            if connector:
                subscription_id = connection.config.get('webhook_channel_id')
-               resource_id = connection.config.get('resource_id')  # If stored
-               await connector.cleanup_subscription(subscription_id, resource_id)
+               await connector.cleanup_subscription(subscription_id)
                print(f"[CLEANUP] Cancelled subscription {subscription_id}")
        except Exception as e:
            print(f"[ERROR] Failed to cancel subscription for {connection.connection_id}: {e}")
@@ -524,4 +532,4 @@ if __name__ == "__main__":
        host="0.0.0.0",
        port=8000,
        reload=False,  # Disable reload since we're running from main
-    )
+    )
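Note: the startup and shutdown hunks in src/main.py are two halves of the same fix: connections are loaded at boot so that, on shutdown, the stored webhook_channel_id can be found and the channel cancelled. A rough Python lifecycle sketch (names from the hunks above; it assumes connection_manager.connections is a mapping of connection_id to connection objects, and error handling is trimmed):

    async def startup(connector_service):
        await connector_service.initialize()  # loads persisted connections
        count = len(connector_service.connection_manager.connections)
        print(f"[CONNECTORS] Loaded {count} persisted connection(s)")

    async def shutdown(connector_service):
        for connection in connector_service.connection_manager.connections.values():
            subscription_id = connection.config.get("webhook_channel_id")
            connector = await connector_service.get_connector(connection.connection_id)
            if connector and subscription_id:
                await connector.cleanup_subscription(subscription_id)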
@@ -44,11 +44,12 @@ class DocumentFileProcessor(TaskProcessor):
 class ConnectorFileProcessor(TaskProcessor):
     """Processor for connector file uploads"""

-    def __init__(self, connector_service, connection_id: str, files_to_process: list, user_id: str = None, owner_name: str = None, owner_email: str = None):
+    def __init__(self, connector_service, connection_id: str, files_to_process: list, user_id: str = None, jwt_token: str = None, owner_name: str = None, owner_email: str = None):
         self.connector_service = connector_service
         self.connection_id = connection_id
         self.files_to_process = files_to_process
         self.user_id = user_id
+        self.jwt_token = jwt_token
         self.owner_name = owner_name
         self.owner_email = owner_email
         # Create lookup map for file info - handle both file objects and file IDs
@@ -83,7 +84,7 @@ class ConnectorFileProcessor(TaskProcessor):
                raise ValueError("user_id not provided to ConnectorFileProcessor")

            # Process using existing pipeline
-           result = await self.connector_service.process_connector_document(document, self.user_id, connection.connector_type, owner_name=self.owner_name, owner_email=self.owner_email)
+           result = await self.connector_service.process_connector_document(document, self.user_id, connection.connector_type, jwt_token=self.jwt_token, owner_name=self.owner_name, owner_email=self.owner_email)

            file_task.status = TaskStatus.COMPLETED
            file_task.result = result
@@ -101,6 +101,11 @@ class SessionManager:
         else:
             self.users[user_id] = user

+        # Create JWT token using the shared method
+        return self.create_jwt_token(user)
+
+    def create_jwt_token(self, user: User) -> str:
+        """Create JWT token for an existing user"""
         # Use OpenSearch-compatible issuer for OIDC validation
         oidc_issuer = "http://openrag-backend:8000"

@@ -109,14 +114,14 @@ class SessionManager:
         token_payload = {
             # OIDC standard claims
             "iss": oidc_issuer,  # Fixed issuer for OpenSearch OIDC
-            "sub": user_id,  # Subject (user ID)
+            "sub": user.user_id,  # Subject (user ID)
             "aud": ["opensearch", "openrag"],  # Audience
             "exp": now + timedelta(days=7),  # Expiration
             "iat": now,  # Issued at
             "auth_time": int(now.timestamp()),  # Authentication time

             # Custom claims
-            "user_id": user_id,  # Keep for backward compatibility
+            "user_id": user.user_id,  # Keep for backward compatibility
             "email": user.email,
             "name": user.name,
             "preferred_username": user.email,
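Note: create_jwt_token now keys every claim off the User object rather than a bare user_id. A minimal standalone sketch of issuing such a token with PyJWT; the signing secret, algorithm, and User dataclass are assumptions, and only the claim set mirrors the diff:

    import jwt  # PyJWT
    from dataclasses import dataclass
    from datetime import datetime, timedelta, timezone

    @dataclass
    class User:
        user_id: str
        email: str
        name: str

    def create_jwt_token(user: User, secret: str) -> str:
        now = datetime.now(timezone.utc)
        payload = {
            # OIDC standard claims
            "iss": "http://openrag-backend:8000",
            "sub": user.user_id,
            "aud": ["opensearch", "openrag"],
            "exp": now + timedelta(days=7),  # PyJWT converts datetimes to timestamps
            "iat": now,
            "auth_time": int(now.timestamp()),
            # Custom claims
            "user_id": user.user_id,
            "email": user.email,
            "name": user.name,
            "preferred_username": user.email,
        }
        return jwt.encode(payload, secret, algorithm="HS256")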