connector jwt and connector cleanup fixes
This commit is contained in:
parent
585d71515c
commit
0fb9246e88
7 changed files with 66 additions and 43 deletions
|
|
@ -20,6 +20,7 @@ async def connector_sync(request: Request, connector_service, session_manager):
|
||||||
print(f"[DEBUG] Starting connector sync for connector_type={connector_type}, max_files={max_files}")
|
print(f"[DEBUG] Starting connector sync for connector_type={connector_type}, max_files={max_files}")
|
||||||
|
|
||||||
user = request.state.user
|
user = request.state.user
|
||||||
|
jwt_token = request.cookies.get("auth_token")
|
||||||
print(f"[DEBUG] User: {user.user_id}")
|
print(f"[DEBUG] User: {user.user_id}")
|
||||||
|
|
||||||
# Get all active connections for this connector type and user
|
# Get all active connections for this connector type and user
|
||||||
|
|
@ -36,7 +37,7 @@ async def connector_sync(request: Request, connector_service, session_manager):
|
||||||
task_ids = []
|
task_ids = []
|
||||||
for connection in active_connections:
|
for connection in active_connections:
|
||||||
print(f"[DEBUG] About to call sync_connector_files for connection {connection.connection_id}")
|
print(f"[DEBUG] About to call sync_connector_files for connection {connection.connection_id}")
|
||||||
task_id = await connector_service.sync_connector_files(connection.connection_id, user.user_id, max_files)
|
task_id = await connector_service.sync_connector_files(connection.connection_id, user.user_id, max_files, jwt_token=jwt_token)
|
||||||
task_ids.append(task_id)
|
task_ids.append(task_id)
|
||||||
print(f"[DEBUG] Got task_id: {task_id}")
|
print(f"[DEBUG] Got task_id: {task_id}")
|
||||||
|
|
||||||
|
|
@ -154,27 +155,8 @@ async def connector_webhook(request: Request, connector_service, session_manager
|
||||||
# Find the specific connection for this webhook
|
# Find the specific connection for this webhook
|
||||||
connection = await connector_service.connection_manager.get_connection_by_webhook_id(channel_id)
|
connection = await connector_service.connection_manager.get_connection_by_webhook_id(channel_id)
|
||||||
if not connection or not connection.is_active:
|
if not connection or not connection.is_active:
|
||||||
print(f"[WEBHOOK] Unknown channel {channel_id} - attempting to cancel old subscription")
|
print(f"[WEBHOOK] Unknown channel {channel_id} - no cleanup attempted (will auto-expire)")
|
||||||
|
return JSONResponse({"status": "ignored_unknown_channel", "channel_id": channel_id})
|
||||||
# Try to cancel this unknown subscription using any active connection of this connector type
|
|
||||||
try:
|
|
||||||
all_connections = await connector_service.connection_manager.list_connections(
|
|
||||||
connector_type=connector_type
|
|
||||||
)
|
|
||||||
active_connections = [c for c in all_connections if c.is_active]
|
|
||||||
|
|
||||||
if active_connections:
|
|
||||||
# Use the first active connection to cancel the unknown subscription
|
|
||||||
connector = await connector_service._get_connector(active_connections[0].connection_id)
|
|
||||||
if connector:
|
|
||||||
print(f"[WEBHOOK] Cancelling unknown subscription {channel_id}")
|
|
||||||
await connector.cleanup_subscription(channel_id, resource_id)
|
|
||||||
print(f"[WEBHOOK] Successfully cancelled unknown subscription {channel_id}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[WARNING] Failed to cancel unknown subscription {channel_id}: {e}")
|
|
||||||
|
|
||||||
return JSONResponse({"status": "cancelled_unknown", "channel_id": channel_id})
|
|
||||||
|
|
||||||
# Process webhook for the specific connection
|
# Process webhook for the specific connection
|
||||||
results = []
|
results = []
|
||||||
|
|
@ -191,11 +173,19 @@ async def connector_webhook(request: Request, connector_service, session_manager
|
||||||
if affected_files:
|
if affected_files:
|
||||||
print(f"[WEBHOOK] Connection {connection.connection_id}: {len(affected_files)} files affected")
|
print(f"[WEBHOOK] Connection {connection.connection_id}: {len(affected_files)} files affected")
|
||||||
|
|
||||||
|
# Generate JWT token for the user (needed for OpenSearch authentication)
|
||||||
|
user = session_manager.get_user(connection.user_id)
|
||||||
|
if user:
|
||||||
|
jwt_token = session_manager.create_jwt_token(user)
|
||||||
|
else:
|
||||||
|
jwt_token = None
|
||||||
|
|
||||||
# Trigger incremental sync for affected files
|
# Trigger incremental sync for affected files
|
||||||
task_id = await connector_service.sync_specific_files(
|
task_id = await connector_service.sync_specific_files(
|
||||||
connection.connection_id,
|
connection.connection_id,
|
||||||
connection.user_id,
|
connection.user_id,
|
||||||
affected_files
|
affected_files,
|
||||||
|
jwt_token=jwt_token
|
||||||
)
|
)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
|
|
@ -235,4 +225,4 @@ async def connector_webhook(request: Request, connector_service, session_manager
|
||||||
import traceback
|
import traceback
|
||||||
print(f"[ERROR] Webhook processing failed: {str(e)}")
|
print(f"[ERROR] Webhook processing failed: {str(e)}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return JSONResponse({"error": f"Webhook processing failed: {str(e)}"}, status_code=500)
|
return JSONResponse({"error": f"Webhook processing failed: {str(e)}"}, status_code=500)
|
||||||
|
|
|
||||||
|
|
@ -300,9 +300,11 @@ class ConnectionManager:
|
||||||
print(f"[WEBHOOK] Setting up subscription for connection {connection_id}")
|
print(f"[WEBHOOK] Setting up subscription for connection {connection_id}")
|
||||||
subscription_id = await connector.setup_subscription()
|
subscription_id = await connector.setup_subscription()
|
||||||
|
|
||||||
# Store the subscription ID in connection config
|
# Store the subscription and resource IDs in connection config
|
||||||
connection_config.config['webhook_channel_id'] = subscription_id
|
connection_config.config['webhook_channel_id'] = subscription_id
|
||||||
connection_config.config['subscription_id'] = subscription_id # Alternative field
|
connection_config.config['subscription_id'] = subscription_id # Alternative field
|
||||||
|
if getattr(connector, 'webhook_resource_id', None):
|
||||||
|
connection_config.config['resource_id'] = connector.webhook_resource_id
|
||||||
|
|
||||||
# Save updated connection config
|
# Save updated connection config
|
||||||
await self.save_connections()
|
await self.save_connections()
|
||||||
|
|
@ -327,9 +329,11 @@ class ConnectionManager:
|
||||||
# Setup subscription
|
# Setup subscription
|
||||||
subscription_id = await connector.setup_subscription()
|
subscription_id = await connector.setup_subscription()
|
||||||
|
|
||||||
# Store the subscription ID in connection config
|
# Store the subscription and resource IDs in connection config
|
||||||
connection_config.config['webhook_channel_id'] = subscription_id
|
connection_config.config['webhook_channel_id'] = subscription_id
|
||||||
connection_config.config['subscription_id'] = subscription_id
|
connection_config.config['subscription_id'] = subscription_id
|
||||||
|
if getattr(connector, 'webhook_resource_id', None):
|
||||||
|
connection_config.config['resource_id'] = connector.webhook_resource_id
|
||||||
|
|
||||||
# Save updated connection config
|
# Save updated connection config
|
||||||
await self.save_connections()
|
await self.save_connections()
|
||||||
|
|
@ -338,4 +342,4 @@ class ConnectionManager:
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[ERROR] Failed to setup webhook subscription for new connection {connection_id}: {e}")
|
print(f"[ERROR] Failed to setup webhook subscription for new connection {connection_id}: {e}")
|
||||||
# Don't fail the connection setup if webhook fails
|
# Don't fail the connection setup if webhook fails
|
||||||
|
|
|
||||||
|
|
@ -164,6 +164,8 @@ class GoogleDriveConnector(BaseConnector):
|
||||||
self.service = None
|
self.service = None
|
||||||
# Load existing webhook channel ID from config if available
|
# Load existing webhook channel ID from config if available
|
||||||
self.webhook_channel_id = config.get('webhook_channel_id') or config.get('subscription_id')
|
self.webhook_channel_id = config.get('webhook_channel_id') or config.get('subscription_id')
|
||||||
|
# Load existing webhook resource ID (Google Drive requires this to stop a channel)
|
||||||
|
self.webhook_resource_id = config.get('resource_id')
|
||||||
|
|
||||||
async def authenticate(self) -> bool:
|
async def authenticate(self) -> bool:
|
||||||
"""Authenticate with Google Drive"""
|
"""Authenticate with Google Drive"""
|
||||||
|
|
@ -207,6 +209,11 @@ class GoogleDriveConnector(BaseConnector):
|
||||||
).execute()
|
).execute()
|
||||||
|
|
||||||
self.webhook_channel_id = channel_id
|
self.webhook_channel_id = channel_id
|
||||||
|
# Persist the resourceId returned by Google to allow proper cleanup
|
||||||
|
try:
|
||||||
|
self.webhook_resource_id = result.get('resourceId')
|
||||||
|
except Exception:
|
||||||
|
self.webhook_resource_id = None
|
||||||
return channel_id
|
return channel_id
|
||||||
|
|
||||||
except HttpError as e:
|
except HttpError as e:
|
||||||
|
|
@ -388,6 +395,10 @@ class GoogleDriveConnector(BaseConnector):
|
||||||
def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]:
|
def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]:
|
||||||
"""Extract Google Drive channel ID from webhook headers"""
|
"""Extract Google Drive channel ID from webhook headers"""
|
||||||
return headers.get('x-goog-channel-id')
|
return headers.get('x-goog-channel-id')
|
||||||
|
|
||||||
|
def extract_webhook_resource_id(self, headers: Dict[str, str]) -> Optional[str]:
|
||||||
|
"""Extract Google Drive resource ID from webhook headers"""
|
||||||
|
return headers.get('x-goog-resource-id')
|
||||||
|
|
||||||
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
|
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
|
||||||
"""Handle Google Drive webhook notification"""
|
"""Handle Google Drive webhook notification"""
|
||||||
|
|
@ -469,15 +480,19 @@ class GoogleDriveConnector(BaseConnector):
|
||||||
print(f"Failed to handle webhook: {e}")
|
print(f"Failed to handle webhook: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def cleanup_subscription(self, subscription_id: str, resource_id: str = None) -> bool:
|
async def cleanup_subscription(self, subscription_id: str) -> bool:
|
||||||
"""Clean up Google Drive subscription"""
|
"""Clean up Google Drive subscription for this connection.
|
||||||
|
|
||||||
|
Uses the stored resource_id captured during subscription setup.
|
||||||
|
"""
|
||||||
if not self._authenticated:
|
if not self._authenticated:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body = {'id': subscription_id}
|
# Google Channels API requires both 'id' (channel) and 'resourceId'
|
||||||
if resource_id:
|
if not self.webhook_resource_id:
|
||||||
body['resourceId'] = resource_id
|
raise ValueError("Missing resource_id for cleanup; ensure subscription state is persisted")
|
||||||
|
body = {'id': subscription_id, 'resourceId': self.webhook_resource_id}
|
||||||
|
|
||||||
self.service.channels().stop(body=body).execute()
|
self.service.channels().stop(body=body).execute()
|
||||||
return True
|
return True
|
||||||
|
|
|
||||||
|
|
@ -149,7 +149,7 @@ class ConnectorService:
|
||||||
}
|
}
|
||||||
return mime_to_ext.get(mimetype, '.bin')
|
return mime_to_ext.get(mimetype, '.bin')
|
||||||
|
|
||||||
async def sync_connector_files(self, connection_id: str, user_id: str, max_files: int = None) -> str:
|
async def sync_connector_files(self, connection_id: str, user_id: str, max_files: int = None, jwt_token: str = None) -> str:
|
||||||
"""Sync files from a connector connection using existing task tracking system"""
|
"""Sync files from a connector connection using existing task tracking system"""
|
||||||
if not self.task_service:
|
if not self.task_service:
|
||||||
raise ValueError("TaskService not available - connector sync requires task service dependency")
|
raise ValueError("TaskService not available - connector sync requires task service dependency")
|
||||||
|
|
@ -203,7 +203,7 @@ class ConnectorService:
|
||||||
|
|
||||||
# Create custom processor for connector files
|
# Create custom processor for connector files
|
||||||
from models.processors import ConnectorFileProcessor
|
from models.processors import ConnectorFileProcessor
|
||||||
processor = ConnectorFileProcessor(self, connection_id, files_to_process, user_id, owner_name=owner_name, owner_email=owner_email)
|
processor = ConnectorFileProcessor(self, connection_id, files_to_process, user_id, jwt_token=jwt_token, owner_name=owner_name, owner_email=owner_email)
|
||||||
|
|
||||||
# Use file IDs as items (no more fake file paths!)
|
# Use file IDs as items (no more fake file paths!)
|
||||||
file_ids = [file_info['id'] for file_info in files_to_process]
|
file_ids = [file_info['id'] for file_info in files_to_process]
|
||||||
|
|
@ -213,7 +213,7 @@ class ConnectorService:
|
||||||
|
|
||||||
return task_id
|
return task_id
|
||||||
|
|
||||||
async def sync_specific_files(self, connection_id: str, user_id: str, file_ids: List[str]) -> str:
|
async def sync_specific_files(self, connection_id: str, user_id: str, file_ids: List[str], jwt_token: str = None) -> str:
|
||||||
"""Sync specific files by their IDs (used for webhook-triggered syncs)"""
|
"""Sync specific files by their IDs (used for webhook-triggered syncs)"""
|
||||||
if not self.task_service:
|
if not self.task_service:
|
||||||
raise ValueError("TaskService not available - connector sync requires task service dependency")
|
raise ValueError("TaskService not available - connector sync requires task service dependency")
|
||||||
|
|
@ -236,7 +236,7 @@ class ConnectorService:
|
||||||
# Create custom processor for specific connector files
|
# Create custom processor for specific connector files
|
||||||
from models.processors import ConnectorFileProcessor
|
from models.processors import ConnectorFileProcessor
|
||||||
# We'll pass file_ids as the files_info, the processor will handle ID-only files
|
# We'll pass file_ids as the files_info, the processor will handle ID-only files
|
||||||
processor = ConnectorFileProcessor(self, connection_id, file_ids, user_id, owner_name=owner_name, owner_email=owner_email)
|
processor = ConnectorFileProcessor(self, connection_id, file_ids, user_id, jwt_token=jwt_token, owner_name=owner_name, owner_email=owner_email)
|
||||||
|
|
||||||
# Create custom task using TaskService
|
# Create custom task using TaskService
|
||||||
task_id = await self.task_service.create_custom_task(user_id, file_ids, processor)
|
task_id = await self.task_service.create_custom_task(user_id, file_ids, processor)
|
||||||
|
|
|
||||||
14
src/main.py
14
src/main.py
|
|
@ -190,6 +190,15 @@ async def initialize_services():
|
||||||
# Initialize auth service
|
# Initialize auth service
|
||||||
auth_service = AuthService(session_manager, connector_service)
|
auth_service = AuthService(session_manager, connector_service)
|
||||||
|
|
||||||
|
# Load persisted connector connections at startup so webhooks and syncs
|
||||||
|
# can resolve existing subscriptions immediately after server boot
|
||||||
|
try:
|
||||||
|
await connector_service.initialize()
|
||||||
|
loaded_count = len(connector_service.connection_manager.connections)
|
||||||
|
print(f"[CONNECTORS] Loaded {loaded_count} persisted connection(s) on startup")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARNING] Failed to load persisted connections on startup: {e}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'document_service': document_service,
|
'document_service': document_service,
|
||||||
'search_service': search_service,
|
'search_service': search_service,
|
||||||
|
|
@ -497,8 +506,7 @@ async def cleanup_subscriptions_proper(services):
|
||||||
connector = await connector_service.get_connector(connection.connection_id)
|
connector = await connector_service.get_connector(connection.connection_id)
|
||||||
if connector:
|
if connector:
|
||||||
subscription_id = connection.config.get('webhook_channel_id')
|
subscription_id = connection.config.get('webhook_channel_id')
|
||||||
resource_id = connection.config.get('resource_id') # If stored
|
await connector.cleanup_subscription(subscription_id)
|
||||||
await connector.cleanup_subscription(subscription_id, resource_id)
|
|
||||||
print(f"[CLEANUP] Cancelled subscription {subscription_id}")
|
print(f"[CLEANUP] Cancelled subscription {subscription_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[ERROR] Failed to cancel subscription for {connection.connection_id}: {e}")
|
print(f"[ERROR] Failed to cancel subscription for {connection.connection_id}: {e}")
|
||||||
|
|
@ -524,4 +532,4 @@ if __name__ == "__main__":
|
||||||
host="0.0.0.0",
|
host="0.0.0.0",
|
||||||
port=8000,
|
port=8000,
|
||||||
reload=False, # Disable reload since we're running from main
|
reload=False, # Disable reload since we're running from main
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -44,11 +44,12 @@ class DocumentFileProcessor(TaskProcessor):
|
||||||
class ConnectorFileProcessor(TaskProcessor):
|
class ConnectorFileProcessor(TaskProcessor):
|
||||||
"""Processor for connector file uploads"""
|
"""Processor for connector file uploads"""
|
||||||
|
|
||||||
def __init__(self, connector_service, connection_id: str, files_to_process: list, user_id: str = None, owner_name: str = None, owner_email: str = None):
|
def __init__(self, connector_service, connection_id: str, files_to_process: list, user_id: str = None, jwt_token: str = None, owner_name: str = None, owner_email: str = None):
|
||||||
self.connector_service = connector_service
|
self.connector_service = connector_service
|
||||||
self.connection_id = connection_id
|
self.connection_id = connection_id
|
||||||
self.files_to_process = files_to_process
|
self.files_to_process = files_to_process
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
|
self.jwt_token = jwt_token
|
||||||
self.owner_name = owner_name
|
self.owner_name = owner_name
|
||||||
self.owner_email = owner_email
|
self.owner_email = owner_email
|
||||||
# Create lookup map for file info - handle both file objects and file IDs
|
# Create lookup map for file info - handle both file objects and file IDs
|
||||||
|
|
@ -83,7 +84,7 @@ class ConnectorFileProcessor(TaskProcessor):
|
||||||
raise ValueError("user_id not provided to ConnectorFileProcessor")
|
raise ValueError("user_id not provided to ConnectorFileProcessor")
|
||||||
|
|
||||||
# Process using existing pipeline
|
# Process using existing pipeline
|
||||||
result = await self.connector_service.process_connector_document(document, self.user_id, connection.connector_type, owner_name=self.owner_name, owner_email=self.owner_email)
|
result = await self.connector_service.process_connector_document(document, self.user_id, connection.connector_type, jwt_token=self.jwt_token, owner_name=self.owner_name, owner_email=self.owner_email)
|
||||||
|
|
||||||
file_task.status = TaskStatus.COMPLETED
|
file_task.status = TaskStatus.COMPLETED
|
||||||
file_task.result = result
|
file_task.result = result
|
||||||
|
|
|
||||||
|
|
@ -101,6 +101,11 @@ class SessionManager:
|
||||||
else:
|
else:
|
||||||
self.users[user_id] = user
|
self.users[user_id] = user
|
||||||
|
|
||||||
|
# Create JWT token using the shared method
|
||||||
|
return self.create_jwt_token(user)
|
||||||
|
|
||||||
|
def create_jwt_token(self, user: User) -> str:
|
||||||
|
"""Create JWT token for an existing user"""
|
||||||
# Use OpenSearch-compatible issuer for OIDC validation
|
# Use OpenSearch-compatible issuer for OIDC validation
|
||||||
oidc_issuer = "http://openrag-backend:8000"
|
oidc_issuer = "http://openrag-backend:8000"
|
||||||
|
|
||||||
|
|
@ -109,14 +114,14 @@ class SessionManager:
|
||||||
token_payload = {
|
token_payload = {
|
||||||
# OIDC standard claims
|
# OIDC standard claims
|
||||||
"iss": oidc_issuer, # Fixed issuer for OpenSearch OIDC
|
"iss": oidc_issuer, # Fixed issuer for OpenSearch OIDC
|
||||||
"sub": user_id, # Subject (user ID)
|
"sub": user.user_id, # Subject (user ID)
|
||||||
"aud": ["opensearch", "openrag"], # Audience
|
"aud": ["opensearch", "openrag"], # Audience
|
||||||
"exp": now + timedelta(days=7), # Expiration
|
"exp": now + timedelta(days=7), # Expiration
|
||||||
"iat": now, # Issued at
|
"iat": now, # Issued at
|
||||||
"auth_time": int(now.timestamp()), # Authentication time
|
"auth_time": int(now.timestamp()), # Authentication time
|
||||||
|
|
||||||
# Custom claims
|
# Custom claims
|
||||||
"user_id": user_id, # Keep for backward compatibility
|
"user_id": user.user_id, # Keep for backward compatibility
|
||||||
"email": user.email,
|
"email": user.email,
|
||||||
"name": user.name,
|
"name": user.name,
|
||||||
"preferred_username": user.email,
|
"preferred_username": user.email,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue