webhooks work
This commit is contained in:
parent
e498f9416a
commit
93b72a19be
8 changed files with 376 additions and 36 deletions
|
|
@ -81,3 +81,123 @@ async def connector_status(request: Request, connector_service, session_manager)
|
||||||
for conn in connections
|
for conn in connections
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
|
|
||||||
|
async def connector_webhook(request: Request, connector_service, session_manager):
|
||||||
|
"""Handle webhook notifications from any connector type"""
|
||||||
|
connector_type = request.path_params.get("connector_type")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get the raw payload and headers
|
||||||
|
payload = {}
|
||||||
|
headers = dict(request.headers)
|
||||||
|
|
||||||
|
if request.method == "POST":
|
||||||
|
content_type = headers.get('content-type', '').lower()
|
||||||
|
if 'application/json' in content_type:
|
||||||
|
payload = await request.json()
|
||||||
|
else:
|
||||||
|
# Some webhooks send form data or plain text
|
||||||
|
body = await request.body()
|
||||||
|
payload = {"raw_body": body.decode('utf-8') if body else ""}
|
||||||
|
else:
|
||||||
|
# GET webhooks use query params
|
||||||
|
payload = dict(request.query_params)
|
||||||
|
|
||||||
|
# Add headers to payload for connector processing
|
||||||
|
payload["_headers"] = headers
|
||||||
|
payload["_method"] = request.method
|
||||||
|
|
||||||
|
print(f"[WEBHOOK] {connector_type} notification received")
|
||||||
|
|
||||||
|
# Extract channel/subscription ID from headers (Google Drive specific)
|
||||||
|
channel_id = headers.get('x-goog-channel-id')
|
||||||
|
if not channel_id:
|
||||||
|
print(f"[WEBHOOK] No channel ID found in {connector_type} webhook")
|
||||||
|
return JSONResponse({"status": "ignored", "reason": "no_channel_id"})
|
||||||
|
|
||||||
|
# Find the specific connection for this webhook
|
||||||
|
connection = await connector_service.connection_manager.get_connection_by_webhook_id(channel_id)
|
||||||
|
if not connection or not connection.is_active:
|
||||||
|
print(f"[WEBHOOK] Unknown channel {channel_id} - attempting to cancel old subscription")
|
||||||
|
|
||||||
|
# Try to cancel this unknown subscription using any active connection of this connector type
|
||||||
|
try:
|
||||||
|
all_connections = await connector_service.connection_manager.list_connections(
|
||||||
|
connector_type=connector_type
|
||||||
|
)
|
||||||
|
active_connections = [c for c in all_connections if c.is_active]
|
||||||
|
|
||||||
|
if active_connections:
|
||||||
|
# Use the first active connection to cancel the unknown subscription
|
||||||
|
connector = await connector_service._get_connector(active_connections[0].connection_id)
|
||||||
|
if connector:
|
||||||
|
print(f"[WEBHOOK] Cancelling unknown subscription {channel_id}")
|
||||||
|
resource_id = headers.get('x-goog-resource-id')
|
||||||
|
await connector.cleanup_subscription(channel_id, resource_id)
|
||||||
|
print(f"[WEBHOOK] Successfully cancelled unknown subscription {channel_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARNING] Failed to cancel unknown subscription {channel_id}: {e}")
|
||||||
|
|
||||||
|
return JSONResponse({"status": "cancelled_unknown", "channel_id": channel_id})
|
||||||
|
|
||||||
|
# Process webhook for the specific connection
|
||||||
|
results = []
|
||||||
|
try:
|
||||||
|
# Get the connector instance
|
||||||
|
connector = await connector_service._get_connector(connection.connection_id)
|
||||||
|
if not connector:
|
||||||
|
print(f"[WEBHOOK] Could not get connector for connection {connection.connection_id}")
|
||||||
|
return JSONResponse({"status": "error", "reason": "connector_not_found"})
|
||||||
|
|
||||||
|
# Let the connector handle the webhook and return affected file IDs
|
||||||
|
affected_files = await connector.handle_webhook(payload)
|
||||||
|
|
||||||
|
if affected_files:
|
||||||
|
print(f"[WEBHOOK] Connection {connection.connection_id}: {len(affected_files)} files affected")
|
||||||
|
|
||||||
|
# Trigger incremental sync for affected files
|
||||||
|
task_id = await connector_service.sync_specific_files(
|
||||||
|
connection.connection_id,
|
||||||
|
connection.user_id,
|
||||||
|
affected_files
|
||||||
|
)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"connection_id": connection.connection_id,
|
||||||
|
"task_id": task_id,
|
||||||
|
"affected_files": len(affected_files)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# No specific files identified - just log the webhook
|
||||||
|
print(f"[WEBHOOK] Connection {connection.connection_id}: general change detected, no specific files to sync")
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"connection_id": connection.connection_id,
|
||||||
|
"action": "logged_only",
|
||||||
|
"reason": "no_specific_files"
|
||||||
|
}
|
||||||
|
|
||||||
|
return JSONResponse({
|
||||||
|
"status": "processed",
|
||||||
|
"connector_type": connector_type,
|
||||||
|
"channel_id": channel_id,
|
||||||
|
**result
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Failed to process webhook for connection {connection.connection_id}: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return JSONResponse({
|
||||||
|
"status": "error",
|
||||||
|
"connector_type": connector_type,
|
||||||
|
"channel_id": channel_id,
|
||||||
|
"error": str(e)
|
||||||
|
}, status_code=500)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
print(f"[ERROR] Webhook processing failed: {str(e)}")
|
||||||
|
traceback.print_exc()
|
||||||
|
return JSONResponse({"error": f"Webhook processing failed: {str(e)}"}, status_code=500)
|
||||||
|
|
@ -21,6 +21,9 @@ SESSION_SECRET = os.getenv("SESSION_SECRET", "your-secret-key-change-in-producti
|
||||||
GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
||||||
GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
|
GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
|
||||||
|
|
||||||
|
# Webhook configuration - must be set to enable webhooks
|
||||||
|
WEBHOOK_BASE_URL = os.getenv("WEBHOOK_BASE_URL") # No default - must be explicitly configured
|
||||||
|
|
||||||
# OpenSearch configuration
|
# OpenSearch configuration
|
||||||
INDEX_NAME = "documents"
|
INDEX_NAME = "documents"
|
||||||
VECTOR_DIM = 1536
|
VECTOR_DIM = 1536
|
||||||
|
|
@ -96,6 +99,9 @@ class AppClients:
|
||||||
# Initialize patched OpenAI client
|
# Initialize patched OpenAI client
|
||||||
self.patched_async_client = patch_openai_with_mcp(AsyncOpenAI())
|
self.patched_async_client = patch_openai_with_mcp(AsyncOpenAI())
|
||||||
|
|
||||||
|
# Initialize document converter
|
||||||
|
self.converter = DocumentConverter()
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
# Global clients instance
|
# Global clients instance
|
||||||
|
|
|
||||||
|
|
@ -99,6 +99,15 @@ class ConnectionManager:
|
||||||
|
|
||||||
connection = self.connections[connection_id]
|
connection = self.connections[connection_id]
|
||||||
|
|
||||||
|
# Check if this update is adding authentication and webhooks are configured
|
||||||
|
should_setup_webhook = (
|
||||||
|
config is not None and
|
||||||
|
config.get('token_file') and
|
||||||
|
config.get('webhook_url') and # Only if webhook URL is configured
|
||||||
|
not connection.config.get('webhook_channel_id') and
|
||||||
|
connection.is_active
|
||||||
|
)
|
||||||
|
|
||||||
# Update fields if provided
|
# Update fields if provided
|
||||||
if connector_type is not None:
|
if connector_type is not None:
|
||||||
connection.connector_type = connector_type
|
connection.connector_type = connector_type
|
||||||
|
|
@ -110,6 +119,11 @@ class ConnectionManager:
|
||||||
connection.user_id = user_id
|
connection.user_id = user_id
|
||||||
|
|
||||||
await self.save_connections()
|
await self.save_connections()
|
||||||
|
|
||||||
|
# Setup webhook subscription if this connection just got authenticated with webhook URL
|
||||||
|
if should_setup_webhook:
|
||||||
|
await self._setup_webhook_for_new_connection(connection_id, connection)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def list_connections(self, user_id: Optional[str] = None, connector_type: Optional[str] = None) -> List[ConnectionConfig]:
|
async def list_connections(self, user_id: Optional[str] = None, connector_type: Optional[str] = None) -> List[ConnectionConfig]:
|
||||||
|
|
@ -164,6 +178,10 @@ class ConnectionManager:
|
||||||
connector = self._create_connector(connection_config)
|
connector = self._create_connector(connection_config)
|
||||||
if await connector.authenticate():
|
if await connector.authenticate():
|
||||||
self.active_connectors[connection_id] = connector
|
self.active_connectors[connection_id] = connector
|
||||||
|
|
||||||
|
# Setup webhook subscription if not already set up
|
||||||
|
await self._setup_webhook_if_needed(connection_id, connection_config, connector)
|
||||||
|
|
||||||
return connector
|
return connector
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
@ -207,3 +225,71 @@ class ConnectionManager:
|
||||||
|
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def get_connection_by_webhook_id(self, webhook_id: str) -> Optional[ConnectionConfig]:
|
||||||
|
"""Find a connection by its webhook/subscription ID"""
|
||||||
|
for connection in self.connections.values():
|
||||||
|
# Check if the webhook ID is stored in the connection config
|
||||||
|
if connection.config.get('webhook_channel_id') == webhook_id:
|
||||||
|
return connection
|
||||||
|
# Also check for subscription_id (alternative field name)
|
||||||
|
if connection.config.get('subscription_id') == webhook_id:
|
||||||
|
return connection
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _setup_webhook_if_needed(self, connection_id: str, connection_config: ConnectionConfig, connector: BaseConnector):
|
||||||
|
"""Setup webhook subscription if not already configured"""
|
||||||
|
# Check if webhook is already set up
|
||||||
|
if connection_config.config.get('webhook_channel_id') or connection_config.config.get('subscription_id'):
|
||||||
|
print(f"[WEBHOOK] Subscription already exists for connection {connection_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if webhook URL is configured
|
||||||
|
webhook_url = connection_config.config.get('webhook_url')
|
||||||
|
if not webhook_url:
|
||||||
|
print(f"[WEBHOOK] No webhook URL configured for connection {connection_id}, skipping subscription setup")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f"[WEBHOOK] Setting up subscription for connection {connection_id}")
|
||||||
|
subscription_id = await connector.setup_subscription()
|
||||||
|
|
||||||
|
# Store the subscription ID in connection config
|
||||||
|
connection_config.config['webhook_channel_id'] = subscription_id
|
||||||
|
connection_config.config['subscription_id'] = subscription_id # Alternative field
|
||||||
|
|
||||||
|
# Save updated connection config
|
||||||
|
await self.save_connections()
|
||||||
|
|
||||||
|
print(f"[WEBHOOK] Successfully set up subscription {subscription_id} for connection {connection_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Failed to setup webhook subscription for connection {connection_id}: {e}")
|
||||||
|
# Don't fail the entire connection setup if webhook fails
|
||||||
|
|
||||||
|
async def _setup_webhook_for_new_connection(self, connection_id: str, connection_config: ConnectionConfig):
|
||||||
|
"""Setup webhook subscription for a newly authenticated connection"""
|
||||||
|
try:
|
||||||
|
print(f"[WEBHOOK] Setting up subscription for newly authenticated connection {connection_id}")
|
||||||
|
|
||||||
|
# Create and authenticate connector
|
||||||
|
connector = self._create_connector(connection_config)
|
||||||
|
if not await connector.authenticate():
|
||||||
|
print(f"[ERROR] Failed to authenticate connector for webhook setup: {connection_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Setup subscription
|
||||||
|
subscription_id = await connector.setup_subscription()
|
||||||
|
|
||||||
|
# Store the subscription ID in connection config
|
||||||
|
connection_config.config['webhook_channel_id'] = subscription_id
|
||||||
|
connection_config.config['subscription_id'] = subscription_id
|
||||||
|
|
||||||
|
# Save updated connection config
|
||||||
|
await self.save_connections()
|
||||||
|
|
||||||
|
print(f"[WEBHOOK] Successfully set up subscription {subscription_id} for connection {connection_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Failed to setup webhook subscription for new connection {connection_id}: {e}")
|
||||||
|
# Don't fail the connection setup if webhook fails
|
||||||
|
|
@ -152,7 +152,8 @@ class GoogleDriveConnector(BaseConnector):
|
||||||
token_file=config.get('token_file', 'gdrive_token.json')
|
token_file=config.get('token_file', 'gdrive_token.json')
|
||||||
)
|
)
|
||||||
self.service = None
|
self.service = None
|
||||||
self.webhook_channel_id = None
|
# Load existing webhook channel ID from config if available
|
||||||
|
self.webhook_channel_id = config.get('webhook_channel_id') or config.get('subscription_id')
|
||||||
|
|
||||||
async def authenticate(self) -> bool:
|
async def authenticate(self) -> bool:
|
||||||
"""Authenticate with Google Drive"""
|
"""Authenticate with Google Drive"""
|
||||||
|
|
@ -224,7 +225,7 @@ class GoogleDriveConnector(BaseConnector):
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
# Use the same process pool as docling processing
|
# Use the same process pool as docling processing
|
||||||
from app import process_pool
|
from utils.process_pool import process_pool
|
||||||
results = await loop.run_in_executor(
|
results = await loop.run_in_executor(
|
||||||
process_pool,
|
process_pool,
|
||||||
_sync_list_files_worker,
|
_sync_list_files_worker,
|
||||||
|
|
@ -268,7 +269,7 @@ class GoogleDriveConnector(BaseConnector):
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
# Use the same process pool as docling processing
|
# Use the same process pool as docling processing
|
||||||
from app import process_pool
|
from utils.process_pool import process_pool
|
||||||
file_metadata = await loop.run_in_executor(
|
file_metadata = await loop.run_in_executor(
|
||||||
process_pool,
|
process_pool,
|
||||||
_sync_get_metadata_worker,
|
_sync_get_metadata_worker,
|
||||||
|
|
@ -313,7 +314,7 @@ class GoogleDriveConnector(BaseConnector):
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
# Use the same process pool as docling processing
|
# Use the same process pool as docling processing
|
||||||
from app import process_pool
|
from utils.process_pool import process_pool
|
||||||
return await loop.run_in_executor(
|
return await loop.run_in_executor(
|
||||||
process_pool,
|
process_pool,
|
||||||
_sync_download_worker,
|
_sync_download_worker,
|
||||||
|
|
@ -359,46 +360,92 @@ class GoogleDriveConnector(BaseConnector):
|
||||||
if not self._authenticated:
|
if not self._authenticated:
|
||||||
raise ValueError("Not authenticated")
|
raise ValueError("Not authenticated")
|
||||||
|
|
||||||
# Google Drive sends change notifications
|
# Google Drive sends headers with the important info
|
||||||
# We need to query for actual changes
|
headers = payload.get('_headers', {})
|
||||||
|
|
||||||
|
# Extract Google Drive specific headers
|
||||||
|
channel_id = headers.get('x-goog-channel-id')
|
||||||
|
resource_state = headers.get('x-goog-resource-state')
|
||||||
|
|
||||||
|
if not channel_id:
|
||||||
|
print("[WEBHOOK] No channel ID found in Google Drive webhook")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Check if this webhook belongs to this connection
|
||||||
|
if self.webhook_channel_id != channel_id:
|
||||||
|
print(f"[WEBHOOK] Channel ID mismatch: expected {self.webhook_channel_id}, got {channel_id}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Only process certain states (ignore 'sync' which is just a ping)
|
||||||
|
if resource_state not in ['exists', 'not_exists', 'change']:
|
||||||
|
print(f"[WEBHOOK] Ignoring resource state: {resource_state}")
|
||||||
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
page_token = payload.get('pageToken')
|
# Extract page token from the resource URI if available
|
||||||
|
page_token = None
|
||||||
|
headers = payload.get('_headers', {})
|
||||||
|
resource_uri = headers.get('x-goog-resource-uri')
|
||||||
|
|
||||||
|
if resource_uri and 'pageToken=' in resource_uri:
|
||||||
|
# Extract page token from URI like:
|
||||||
|
# https://www.googleapis.com/drive/v3/changes?alt=json&pageToken=4337807
|
||||||
|
import urllib.parse
|
||||||
|
parsed = urllib.parse.urlparse(resource_uri)
|
||||||
|
query_params = urllib.parse.parse_qs(parsed.query)
|
||||||
|
page_token = query_params.get('pageToken', [None])[0]
|
||||||
|
|
||||||
if not page_token:
|
if not page_token:
|
||||||
# Get current page token and return empty list
|
print("[WEBHOOK] No page token found, cannot identify specific changes")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Get list of changes
|
print(f"[WEBHOOK] Getting changes since page token: {page_token}")
|
||||||
|
|
||||||
|
# Get list of changes since the page token
|
||||||
changes = self.service.changes().list(
|
changes = self.service.changes().list(
|
||||||
pageToken=page_token,
|
pageToken=page_token,
|
||||||
fields="changes(fileId, file(id, name, mimeType, trashed))"
|
fields="changes(fileId, file(id, name, mimeType, trashed, parents))"
|
||||||
).execute()
|
).execute()
|
||||||
|
|
||||||
affected_files = []
|
affected_files = []
|
||||||
for change in changes.get('changes', []):
|
for change in changes.get('changes', []):
|
||||||
file_info = change.get('file', {})
|
file_info = change.get('file', {})
|
||||||
# Only include supported file types that aren't trashed
|
file_id = change.get('fileId')
|
||||||
if (file_info.get('mimeType') in self.SUPPORTED_MIMETYPES and
|
|
||||||
not file_info.get('trashed', False)):
|
|
||||||
affected_files.append(change['fileId'])
|
|
||||||
|
|
||||||
|
if not file_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Only include supported file types that aren't trashed
|
||||||
|
mime_type = file_info.get('mimeType', '')
|
||||||
|
is_trashed = file_info.get('trashed', False)
|
||||||
|
|
||||||
|
if not is_trashed and mime_type in self.SUPPORTED_MIMETYPES:
|
||||||
|
print(f"[WEBHOOK] File changed: {file_info.get('name', 'Unknown')} ({file_id})")
|
||||||
|
affected_files.append(file_id)
|
||||||
|
elif is_trashed:
|
||||||
|
print(f"[WEBHOOK] File deleted/trashed: {file_info.get('name', 'Unknown')} ({file_id})")
|
||||||
|
# TODO: Handle file deletion (remove from index)
|
||||||
|
else:
|
||||||
|
print(f"[WEBHOOK] Ignoring unsupported file type: {mime_type}")
|
||||||
|
|
||||||
|
print(f"[WEBHOOK] Found {len(affected_files)} affected supported files")
|
||||||
return affected_files
|
return affected_files
|
||||||
|
|
||||||
except HttpError as e:
|
except HttpError as e:
|
||||||
print(f"Failed to handle webhook: {e}")
|
print(f"Failed to handle webhook: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def cleanup_subscription(self, subscription_id: str) -> bool:
|
async def cleanup_subscription(self, subscription_id: str, resource_id: str = None) -> bool:
|
||||||
"""Clean up Google Drive subscription"""
|
"""Clean up Google Drive subscription"""
|
||||||
if not self._authenticated:
|
if not self._authenticated:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.service.channels().stop(
|
body = {'id': subscription_id}
|
||||||
body={
|
if resource_id:
|
||||||
'id': subscription_id,
|
body['resourceId'] = resource_id
|
||||||
'resourceId': subscription_id # This might need adjustment based on Google's response
|
|
||||||
}
|
self.service.channels().stop(body=body).execute()
|
||||||
).execute()
|
|
||||||
return True
|
return True
|
||||||
except HttpError as e:
|
except HttpError as e:
|
||||||
print(f"Failed to cleanup subscription: {e}")
|
print(f"Failed to cleanup subscription: {e}")
|
||||||
|
|
|
||||||
|
|
@ -37,11 +37,13 @@ class ConnectorService:
|
||||||
tmp_file.flush()
|
tmp_file.flush()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use existing process_file_common function from app.py with connector document ID
|
# Use existing process_file_common function with connector document metadata
|
||||||
from app import process_file_common
|
# We'll use the document service's process_file_common method
|
||||||
|
from services.document_service import DocumentService
|
||||||
|
doc_service = DocumentService()
|
||||||
|
|
||||||
# Process using the existing pipeline but with connector document metadata
|
# Process using the existing pipeline but with connector document metadata
|
||||||
result = await process_file_common(
|
result = await doc_service.process_file_common(
|
||||||
file_path=tmp_file.name,
|
file_path=tmp_file.name,
|
||||||
file_hash=document.id, # Use connector document ID as hash
|
file_hash=document.id, # Use connector document ID as hash
|
||||||
owner_user_id=owner_user_id
|
owner_user_id=owner_user_id
|
||||||
|
|
@ -170,3 +172,32 @@ class ConnectorService:
|
||||||
task_id = await self.task_service.create_custom_task(user_id, file_ids, processor)
|
task_id = await self.task_service.create_custom_task(user_id, file_ids, processor)
|
||||||
|
|
||||||
return task_id
|
return task_id
|
||||||
|
|
||||||
|
async def sync_specific_files(self, connection_id: str, user_id: str, file_ids: List[str]) -> str:
|
||||||
|
"""Sync specific files by their IDs (used for webhook-triggered syncs)"""
|
||||||
|
if not self.task_service:
|
||||||
|
raise ValueError("TaskService not available - connector sync requires task service dependency")
|
||||||
|
|
||||||
|
connector = await self.get_connector(connection_id)
|
||||||
|
if not connector:
|
||||||
|
raise ValueError(f"Connection '{connection_id}' not found or not authenticated")
|
||||||
|
|
||||||
|
if not connector.is_authenticated:
|
||||||
|
raise ValueError(f"Connection '{connection_id}' not authenticated")
|
||||||
|
|
||||||
|
if not file_ids:
|
||||||
|
raise ValueError("No file IDs provided")
|
||||||
|
|
||||||
|
# Create custom processor for specific connector files
|
||||||
|
from models.processors import ConnectorFileProcessor
|
||||||
|
# We'll pass file_ids as the files_info, the processor will handle ID-only files
|
||||||
|
processor = ConnectorFileProcessor(self, connection_id, file_ids)
|
||||||
|
|
||||||
|
# Create custom task using TaskService
|
||||||
|
task_id = await self.task_service.create_custom_task(user_id, file_ids, processor)
|
||||||
|
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
async def _get_connector(self, connection_id: str) -> Optional[BaseConnector]:
|
||||||
|
"""Get a connector by connection ID (alias for get_connector)"""
|
||||||
|
return await self.get_connector(connection_id)
|
||||||
43
src/main.py
43
src/main.py
|
|
@ -217,11 +217,22 @@ def create_app():
|
||||||
connector_service=services['connector_service'],
|
connector_service=services['connector_service'],
|
||||||
session_manager=services['session_manager'])
|
session_manager=services['session_manager'])
|
||||||
), methods=["GET"]),
|
), methods=["GET"]),
|
||||||
|
|
||||||
|
Route("/connectors/{connector_type}/webhook",
|
||||||
|
partial(connectors.connector_webhook,
|
||||||
|
connector_service=services['connector_service'],
|
||||||
|
session_manager=services['session_manager']),
|
||||||
|
methods=["POST", "GET"]),
|
||||||
]
|
]
|
||||||
|
|
||||||
app = Starlette(debug=True, routes=routes)
|
app = Starlette(debug=True, routes=routes)
|
||||||
app.state.services = services # Store services for cleanup
|
app.state.services = services # Store services for cleanup
|
||||||
|
|
||||||
|
# Add shutdown event handler
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def shutdown_event():
|
||||||
|
await cleanup_subscriptions_proper(services)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
async def startup():
|
async def startup():
|
||||||
|
|
@ -233,9 +244,39 @@ async def startup():
|
||||||
|
|
||||||
def cleanup():
|
def cleanup():
|
||||||
"""Cleanup on application shutdown"""
|
"""Cleanup on application shutdown"""
|
||||||
# This will be called on exit to cleanup process pools
|
# Cleanup process pools only (webhooks handled by Starlette shutdown)
|
||||||
|
print("[CLEANUP] Shutting down...")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def cleanup_subscriptions_proper(services):
|
||||||
|
"""Cancel all active webhook subscriptions"""
|
||||||
|
print("[CLEANUP] Cancelling active webhook subscriptions...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
connector_service = services['connector_service']
|
||||||
|
await connector_service.connection_manager.load_connections()
|
||||||
|
|
||||||
|
# Get all active connections with webhook subscriptions
|
||||||
|
all_connections = await connector_service.connection_manager.list_connections()
|
||||||
|
active_connections = [c for c in all_connections if c.is_active and c.config.get('webhook_channel_id')]
|
||||||
|
|
||||||
|
for connection in active_connections:
|
||||||
|
try:
|
||||||
|
print(f"[CLEANUP] Cancelling subscription for connection {connection.connection_id}")
|
||||||
|
connector = await connector_service.get_connector(connection.connection_id)
|
||||||
|
if connector:
|
||||||
|
subscription_id = connection.config.get('webhook_channel_id')
|
||||||
|
resource_id = connection.config.get('resource_id') # If stored
|
||||||
|
await connector.cleanup_subscription(subscription_id, resource_id)
|
||||||
|
print(f"[CLEANUP] Cancelled subscription {subscription_id}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Failed to cancel subscription for {connection.connection_id}: {e}")
|
||||||
|
|
||||||
|
print(f"[CLEANUP] Finished cancelling {len(active_connections)} subscriptions")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Failed to cleanup subscriptions: {e}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -38,8 +38,15 @@ class ConnectorFileProcessor(TaskProcessor):
|
||||||
self.connector_service = connector_service
|
self.connector_service = connector_service
|
||||||
self.connection_id = connection_id
|
self.connection_id = connection_id
|
||||||
self.files_to_process = files_to_process
|
self.files_to_process = files_to_process
|
||||||
# Create lookup map for file info
|
# Create lookup map for file info - handle both file objects and file IDs
|
||||||
self.file_info_map = {f['id']: f for f in files_to_process}
|
self.file_info_map = {}
|
||||||
|
for f in files_to_process:
|
||||||
|
if isinstance(f, dict):
|
||||||
|
# Full file info objects
|
||||||
|
self.file_info_map[f['id']] = f
|
||||||
|
else:
|
||||||
|
# Just file IDs - will need to fetch metadata during processing
|
||||||
|
self.file_info_map[f] = None
|
||||||
|
|
||||||
async def process_item(self, upload_task: UploadTask, item: str, file_task: FileTask) -> None:
|
async def process_item(self, upload_task: UploadTask, item: str, file_task: FileTask) -> None:
|
||||||
"""Process a connector file using ConnectorService"""
|
"""Process a connector file using ConnectorService"""
|
||||||
|
|
@ -49,16 +56,13 @@ class ConnectorFileProcessor(TaskProcessor):
|
||||||
file_id = item # item is the connector file ID
|
file_id = item # item is the connector file ID
|
||||||
file_info = self.file_info_map.get(file_id)
|
file_info = self.file_info_map.get(file_id)
|
||||||
|
|
||||||
if not file_info:
|
|
||||||
raise ValueError(f"File info not found for {file_id}")
|
|
||||||
|
|
||||||
# Get the connector
|
# Get the connector
|
||||||
connector = await self.connector_service.get_connector(self.connection_id)
|
connector = await self.connector_service.get_connector(self.connection_id)
|
||||||
if not connector:
|
if not connector:
|
||||||
raise ValueError(f"Connection '{self.connection_id}' not found")
|
raise ValueError(f"Connection '{self.connection_id}' not found")
|
||||||
|
|
||||||
# Get file content from connector
|
# Get file content from connector (the connector will fetch metadata if needed)
|
||||||
document = await connector.get_file_content(file_info['id'])
|
document = await connector.get_file_content(file_id)
|
||||||
|
|
||||||
# Get user_id from task store lookup
|
# Get user_id from task store lookup
|
||||||
user_id = None
|
user_id = None
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import aiofiles
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from config.settings import GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET
|
from config.settings import GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET, WEBHOOK_BASE_URL
|
||||||
from session_manager import SessionManager
|
from session_manager import SessionManager
|
||||||
|
|
||||||
class AuthService:
|
class AuthService:
|
||||||
|
|
@ -37,6 +37,10 @@ class AuthService:
|
||||||
"redirect_uri": redirect_uri
|
"redirect_uri": redirect_uri
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Only add webhook URL if WEBHOOK_BASE_URL is configured
|
||||||
|
if WEBHOOK_BASE_URL:
|
||||||
|
config["webhook_url"] = f"{WEBHOOK_BASE_URL}/connectors/{provider}_drive/webhook"
|
||||||
|
|
||||||
# Create connection in manager
|
# Create connection in manager
|
||||||
connector_type = f"{provider}_drive" if purpose == "data_source" else f"{provider}_auth"
|
connector_type = f"{provider}_drive" if purpose == "data_source" else f"{provider}_auth"
|
||||||
connection_id = await self.connector_service.connection_manager.create_connection(
|
connection_id = await self.connector_service.connection_manager.create_connection(
|
||||||
|
|
@ -167,7 +171,8 @@ class AuthService:
|
||||||
config={
|
config={
|
||||||
**connection_config.config,
|
**connection_config.config,
|
||||||
"purpose": "data_source",
|
"purpose": "data_source",
|
||||||
"user_email": user_info.get("email")
|
"user_email": user_info.get("email"),
|
||||||
|
**({"webhook_url": f"{WEBHOOK_BASE_URL}/connectors/google_drive/webhook"} if WEBHOOK_BASE_URL else {})
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
response_data["google_drive_connection_id"] = connection_id
|
response_data["google_drive_connection_id"] = connection_id
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue