# openrag/src/connectors/connection_manager.py
# 2025-12-12 15:13:11 -05:00
#
# 534 lines
# 21 KiB
# Python

import json
import uuid
import aiofiles
from typing import Dict, List, Any, Optional
from datetime import datetime
from dataclasses import dataclass, asdict
from pathlib import Path
from utils.logging_config import get_logger
logger = get_logger(__name__)
from .base import BaseConnector
from .google_drive import GoogleDriveConnector
from .sharepoint import SharePointConnector
from .onedrive import OneDriveConnector
@dataclass
class ConnectionConfig:
    """Configuration for a connector connection.

    ``created_at`` may be omitted (or passed as ``None``); ``__post_init__``
    then fills it with the current naive local time from ``datetime.now()``.
    """

    connection_id: str  # Identifier of this connection
    connector_type: str  # "google_drive", "box", etc.
    name: str  # User-friendly name
    config: Dict[str, Any]  # Connector-specific config
    user_id: Optional[str] = None  # For multi-tenant support
    created_at: Optional[datetime] = None  # Set to datetime.now() when None
    last_sync: Optional[datetime] = None  # None until a sync is recorded
    is_active: bool = True

    def __post_init__(self):
        # A plain default would be evaluated once at class definition time;
        # filling it here gives each instance its own timestamp and also
        # covers callers that pass created_at=None explicitly.
        if self.created_at is None:
            self.created_at = datetime.now()
class ConnectionManager:
    """Manages multiple connector connections with persistence.

    Connection configurations are kept in ``self.connections`` (keyed by
    connection ID) and persisted as JSON to ``connections_file``.
    Authenticated connector instances are cached in
    ``self.active_connectors``, also keyed by connection ID.
    """

    def __init__(self, connections_file: str = "data/connections.json"):
        """Create a manager backed by ``connections_file``.

        The file's parent directory is created if missing; connections are
        not read until ``load_connections`` is awaited.
        """
        self.connections_file = Path(connections_file)
        # Ensure data directory exists
        self.connections_file.parent.mkdir(parents=True, exist_ok=True)
        self.connections: Dict[str, ConnectionConfig] = {}
        self.active_connectors: Dict[str, BaseConnector] = {}

    async def load_connections(self):
        """Load connections from persistent storage.

        Reads ``connections_file`` if it exists, converts the ISO-8601
        ``created_at`` / ``last_sync`` strings back into ``datetime``
        objects, then removes duplicate connections (deleting them through
        ``delete_connection``).

        Raises:
            json.JSONDecodeError: If the file exists but is not valid JSON.
            ValueError: If a stored timestamp is not valid ISO-8601.
            TypeError: If a stored record has fields ``ConnectionConfig``
                does not accept.
        """
        if self.connections_file.exists():
            async with aiofiles.open(self.connections_file, "r") as f:
                data = json.loads(await f.read())
            for conn_data in data.get("connections", []):
                # Convert datetime strings back to datetime objects
                if conn_data.get("created_at"):
                    conn_data["created_at"] = datetime.fromisoformat(
                        conn_data["created_at"]
                    )
                if conn_data.get("last_sync"):
                    conn_data["last_sync"] = datetime.fromisoformat(
                        conn_data["last_sync"]
                    )
                config = ConnectionConfig(**conn_data)
                self.connections[config.connection_id] = config
        # Now that connections are loaded, clean up duplicates
        await self.cleanup_duplicate_connections(remove_duplicates=True)

    async def save_connections(self):
        """Save all connections to ``connections_file`` as JSON.

        ``datetime`` fields are written as ISO-8601 strings so that
        ``load_connections`` can parse them back.
        """
        data = {"connections": []}
        for config in self.connections.values():
            conn_data = asdict(config)
            # Convert datetime objects to strings
            if conn_data.get("created_at"):
                conn_data["created_at"] = conn_data["created_at"].isoformat()
            if conn_data.get("last_sync"):
                conn_data["last_sync"] = conn_data["last_sync"].isoformat()
            data["connections"].append(conn_data)
        # Serialize BEFORE opening: mode "w" truncates the file immediately,
        # so a serialization error after opening would leave it empty and
        # lose every stored connection.
        payload = json.dumps(data, indent=2)
        async with aiofiles.open(self.connections_file, "w") as f:
            await f.write(payload)

    async def _get_existing_connection(
        self, connector_type: str, user_id: Optional[str] = None
    ) -> Optional[ConnectionConfig]:
        """Return the first active connection for ``connector_type`` and ``user_id``, or None."""
        for connection in self.connections.values():
            if (
                connection.connector_type == connector_type
                and connection.user_id == user_id
                and connection.is_active
            ):
                return connection
        return None

    async def cleanup_duplicate_connections(self, remove_duplicates=False):
        """Keep only the most recent active connection per provider per user.

        Active connections are grouped by ``(connector_type, user_id)``;
        within each group the connection with the latest ``created_at`` is
        kept and the others are removed or deactivated.

        Args:
            remove_duplicates: If True, duplicates are deleted via
                ``delete_connection`` (removed from connections.json, with
                token-file and subscription cleanup). If False (default),
                they are only deactivated.

        Returns:
            The number of connections removed or deactivated.
        """
        logger.info("Starting cleanup of duplicate connections")
        # Group active connections by (connector_type, user_id)
        grouped_connections = {}
        for connection_id, connection in self.connections.items():
            if not connection.is_active:
                continue  # Skip inactive connections
            key = (connection.connector_type, connection.user_id)
            grouped_connections.setdefault(key, []).append((connection_id, connection))
        # For each group, keep only the most recent connection
        connections_to_remove = []
        for (connector_type, user_id), group in grouped_connections.items():
            if len(group) <= 1:
                continue  # No duplicates
            logger.info(f"Found {len(group)} duplicate connections for {connector_type}, user {user_id}")
            # Sort by created_at, keep the most recent
            group.sort(key=lambda item: item[1].created_at, reverse=True)
            # Keep the first (most recent), remove/deactivate the rest
            for connection_id, connection in group[1:]:
                connections_to_remove.append((connection_id, connection))
                logger.info(f"Marking connection {connection_id} for {'removal' if remove_duplicates else 'deactivation'}")
        # Remove or deactivate duplicate connections
        for connection_id, _connection in connections_to_remove:
            if remove_duplicates:
                await self.delete_connection(connection_id)  # Handles token cleanup
            else:
                await self.deactivate_connection(connection_id)
        action = "Removed" if remove_duplicates else "Deactivated"
        logger.info(f"Cleanup complete. {action} {len(connections_to_remove)} duplicate connections")
        return len(connections_to_remove)

    async def update_connection(
        self,
        connection_id: str,
        connector_type: Optional[str] = None,
        name: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
        user_id: Optional[str] = None,
    ) -> bool:
        """Update fields of an existing connection and persist the change.

        Only arguments that are not None are applied; ``config`` replaces the
        whole connector config dict. If the new config carries both a
        ``token_file`` and a ``webhook_url``, the connection is active, and
        its current config has no ``webhook_channel_id``, a webhook
        subscription is set up after saving.

        Returns:
            True if the connection exists and was updated, False otherwise.
        """
        if connection_id not in self.connections:
            return False
        connection = self.connections[connection_id]
        # Check if this update is adding authentication and webhooks are configured
        # (evaluated before the config is replaced below)
        should_setup_webhook = (
            config is not None
            and config.get("token_file")
            and config.get("webhook_url")  # Only if webhook URL is configured
            and not connection.config.get("webhook_channel_id")
            and connection.is_active
        )
        # Update fields if provided
        if connector_type is not None:
            connection.connector_type = connector_type
        if name is not None:
            connection.name = name
        if config is not None:
            connection.config = config
        if user_id is not None:
            connection.user_id = user_id
        await self.save_connections()
        # Setup webhook subscription if this connection just got authenticated with webhook URL
        if should_setup_webhook:
            await self._setup_webhook_for_new_connection(connection_id, connection)
        return True

    async def create_connection(
        self,
        connector_type: str,
        name: str,
        config: Dict[str, Any],
        user_id: Optional[str] = None,
    ) -> str:
        """Create a connection, ensuring only one per provider per user.

        If an active connection for the same ``connector_type`` and
        ``user_id`` exists and its connector authenticates, that connection
        is reused (its config replaced by ``config`` when different) and its
        ID returned. Otherwise a new connection is created and older
        duplicates are deleted.

        Returns:
            The ID of the reused or newly created connection.
        """
        # Check if we already have an active connection for this provider and user
        existing_connection = await self._get_existing_connection(connector_type, user_id)
        if existing_connection:
            # Check if the existing connection has a valid token
            try:
                connector = self._create_connector(existing_connection)
                if await connector.authenticate():
                    logger.info(
                        f"Using existing valid connection for {connector_type}",
                        connection_id=existing_connection.connection_id
                    )
                    # Update the existing connection with new config if needed
                    if config != existing_connection.config:
                        logger.info("Updating existing connection config")
                        await self.update_connection(
                            existing_connection.connection_id,
                            config=config
                        )
                    return existing_connection.connection_id
            except Exception as e:
                logger.warning(
                    f"Existing connection authentication failed: {e}",
                    connection_id=existing_connection.connection_id
                )
            # If authentication fails (exception or False), we'll create a new
            # connection and clean up the old one
        # Create new connection
        connection_id = str(uuid.uuid4())
        connection_config = ConnectionConfig(
            connection_id=connection_id,
            connector_type=connector_type,
            name=name,
            config=config,
            user_id=user_id,
        )
        self.connections[connection_id] = connection_config
        # Clean up duplicates (will keep the newest, which is the one we just created)
        await self.cleanup_duplicate_connections(remove_duplicates=True)
        await self.save_connections()
        return connection_id

    async def list_connections(
        self, user_id: Optional[str] = None, connector_type: Optional[str] = None
    ) -> List[ConnectionConfig]:
        """List connections, optionally filtered by user or connector type"""
        connections = list(self.connections.values())
        if user_id is not None:
            connections = [c for c in connections if c.user_id == user_id]
        if connector_type is not None:
            connections = [c for c in connections if c.connector_type == connector_type]
        return connections

    async def delete_connection(self, connection_id: str) -> bool:
        """Delete a connection and release what is tied to it.

        Deletes the connection's token file unless another remaining
        connection references the same file, attempts (best effort) to cancel
        the cached connector's webhook subscription, drops the cached
        connector, and persists the change.

        Returns:
            True if the connection existed and was deleted, False otherwise.
        """
        if connection_id not in self.connections:
            return False
        connection = self.connections[connection_id]
        # Clean up token file if it exists
        if connection.config.get("token_file"):
            token_file = Path(connection.config["token_file"])
            # A replacement connection for the same provider/user may point at
            # the same token file; deleting it would break that connection.
            still_referenced = any(
                other_id != connection_id
                and other.config.get("token_file")
                and Path(other.config["token_file"]) == token_file
                for other_id, other in self.connections.items()
            )
            if still_referenced:
                logger.info(f"Keeping token file still used by another connection: {token_file}")
            elif token_file.exists():
                try:
                    token_file.unlink()
                    logger.info(f"Deleted token file: {token_file}")
                except Exception as e:
                    logger.warning(f"Failed to delete token file {token_file}: {e}")
        # Clean up active connector if exists
        connector = self.active_connectors.get(connection_id)
        if connector is not None:
            # Try to cleanup subscriptions if applicable
            try:
                channel_id = getattr(connector, "webhook_channel_id", None)
                if channel_id:
                    await connector.cleanup_subscription(channel_id)
            except Exception as e:
                # Best effort: a failed unsubscribe must not block deletion.
                # (Catching Exception, not everything, lets cancellation propagate.)
                logger.warning(f"Failed to clean up webhook subscription for {connection_id}: {e}")
            # pop() tolerates a concurrent removal during the await above
            self.active_connectors.pop(connection_id, None)
        del self.connections[connection_id]
        await self.save_connections()
        return True

    async def get_connector(self, connection_id: str) -> Optional[BaseConnector]:
        """Return an authenticated connector for ``connection_id``.

        A cached connector is returned if it is still authenticated;
        otherwise a new one is created from the (active) connection config,
        authenticated, and cached.

        Returns:
            The connector, or None if the connection is unknown/inactive or
            authentication fails. Errors raised by ``_create_connector``
            propagate to the caller.
        """
        logger.debug(f"Getting connector for connection_id: {connection_id}")
        # Return cached connector if available
        if connection_id in self.active_connectors:
            connector = self.active_connectors[connection_id]
            if connector.is_authenticated:
                logger.debug(f"Returning cached authenticated connector for {connection_id}")
                return connector
            # Remove unauthenticated connector from cache
            logger.debug(f"Removing unauthenticated connector from cache for {connection_id}")
            del self.active_connectors[connection_id]
        # Try to create and authenticate connector
        connection_config = self.connections.get(connection_id)
        if not connection_config or not connection_config.is_active:
            logger.debug(f"No active connection config found for {connection_id}")
            return None
        logger.debug(f"Creating connector for {connection_config.connector_type}")
        connector = self._create_connector(connection_config)
        logger.debug(f"Attempting authentication for {connection_id}")
        auth_result = await connector.authenticate()
        logger.debug(f"Authentication result for {connection_id}: {auth_result}")
        if auth_result:
            self.active_connectors[connection_id] = connector
            return connector
        logger.warning(f"Authentication failed for {connection_id}")
        return None

    def get_available_connector_types(self) -> Dict[str, Dict[str, Any]]:
        """Get available connector types with their metadata"""
        return {
            "google_drive": {
                "name": GoogleDriveConnector.CONNECTOR_NAME,
                "description": GoogleDriveConnector.CONNECTOR_DESCRIPTION,
                "icon": GoogleDriveConnector.CONNECTOR_ICON,
                "available": self._is_connector_available("google_drive"),
            },
            "sharepoint": {
                "name": SharePointConnector.CONNECTOR_NAME,
                "description": SharePointConnector.CONNECTOR_DESCRIPTION,
                "icon": SharePointConnector.CONNECTOR_ICON,
                "available": self._is_connector_available("sharepoint"),
            },
            "onedrive": {
                "name": OneDriveConnector.CONNECTOR_NAME,
                "description": OneDriveConnector.CONNECTOR_DESCRIPTION,
                "icon": OneDriveConnector.CONNECTOR_ICON,
                "available": self._is_connector_available("onedrive"),
            },
        }

    def _is_connector_available(self, connector_type: str) -> bool:
        """Check whether a connector type can be used.

        Builds a throwaway connector and asks for its client ID and secret;
        the type counts as unavailable if that raises ``ValueError`` or
        ``NotImplementedError``.
        """
        try:
            temp_config = ConnectionConfig(
                connection_id="temp",
                connector_type=connector_type,
                name="temp",
                config={},
            )
            connector = self._create_connector(temp_config)
            # Try to get credentials to check if env vars are set
            connector.get_client_id()
            connector.get_client_secret()
            return True
        except (ValueError, NotImplementedError):
            return False

    def _create_connector(self, config: ConnectionConfig) -> BaseConnector:
        """Factory method to create connector instances.

        Raises:
            NotImplementedError: For known but unimplemented types ("box", "dropbox").
            ValueError: For unknown connector types.
        """
        try:
            if config.connector_type == "google_drive":
                return GoogleDriveConnector(config.config)
            elif config.connector_type == "sharepoint":
                return SharePointConnector(config.config)
            elif config.connector_type == "onedrive":
                return OneDriveConnector(config.config)
            elif config.connector_type == "box":
                raise NotImplementedError("Box connector not implemented yet")
            elif config.connector_type == "dropbox":
                raise NotImplementedError("Dropbox connector not implemented yet")
            else:
                raise ValueError(f"Unknown connector type: {config.connector_type}")
        except Exception as e:
            logger.error(f"Failed to create {config.connector_type} connector: {e}")
            # Re-raise the exception so caller can handle appropriately
            raise

    async def update_last_sync(self, connection_id: str):
        """Set ``last_sync`` to now for a known connection and persist it."""
        if connection_id in self.connections:
            self.connections[connection_id].last_sync = datetime.now()
            await self.save_connections()

    async def activate_connection(self, connection_id: str) -> bool:
        """Mark a connection active; return False if it does not exist."""
        if connection_id in self.connections:
            self.connections[connection_id].is_active = True
            await self.save_connections()
            return True
        return False

    async def deactivate_connection(self, connection_id: str) -> bool:
        """Mark a connection inactive and drop its cached connector.

        Returns False if the connection does not exist.
        """
        if connection_id in self.connections:
            self.connections[connection_id].is_active = False
            await self.save_connections()
            # Remove from active connectors
            self.active_connectors.pop(connection_id, None)
            return True
        return False

    async def get_connection(self, connection_id: str) -> Optional[ConnectionConfig]:
        """Get connection configuration"""
        return self.connections.get(connection_id)

    async def get_connection_by_webhook_id(
        self, webhook_id: str
    ) -> Optional[ConnectionConfig]:
        """Find a connection by its webhook/subscription ID.

        Matches either the ``webhook_channel_id`` or ``subscription_id``
        config key. Returns None for an empty/None ``webhook_id``.
        """
        if not webhook_id:
            # Without this guard, None would "match" every connection that has
            # no webhook configured (None == None) and return the wrong one.
            return None
        for connection in self.connections.values():
            # Check if the webhook ID is stored in the connection config
            if connection.config.get("webhook_channel_id") == webhook_id:
                return connection
            # Also check for subscription_id (alternative field name)
            if connection.config.get("subscription_id") == webhook_id:
                return connection
        return None

    async def _record_subscription(
        self,
        connection_config: ConnectionConfig,
        connector: BaseConnector,
        subscription_id: str,
    ) -> None:
        """Store subscription/resource IDs in the connection config and save."""
        connection_config.config["webhook_channel_id"] = subscription_id
        # Same value under the alternative key that get_connection_by_webhook_id also checks
        connection_config.config["subscription_id"] = subscription_id
        resource_id = getattr(connector, "webhook_resource_id", None)
        if resource_id:
            connection_config.config["resource_id"] = resource_id
        await self.save_connections()

    async def _setup_webhook_if_needed(
        self,
        connection_id: str,
        connection_config: ConnectionConfig,
        connector: BaseConnector,
    ):
        """Set up a webhook subscription unless one exists or no URL is configured.

        Failures are logged and not raised, so connection setup is not
        aborted by a webhook problem.
        """
        # Check if webhook is already set up
        if connection_config.config.get(
            "webhook_channel_id"
        ) or connection_config.config.get("subscription_id"):
            logger.info(
                "Webhook subscription already exists", connection_id=connection_id
            )
            return
        # Check if webhook URL is configured
        webhook_url = connection_config.config.get("webhook_url")
        if not webhook_url:
            logger.info(
                "No webhook URL configured, skipping subscription setup",
                connection_id=connection_id,
            )
            return
        try:
            logger.info("Setting up webhook subscription", connection_id=connection_id)
            subscription_id = await connector.setup_subscription()
            await self._record_subscription(connection_config, connector, subscription_id)
            logger.info(
                "Successfully set up webhook subscription",
                connection_id=connection_id,
                subscription_id=subscription_id,
            )
        except Exception as e:
            logger.error(
                "Failed to setup webhook subscription",
                connection_id=connection_id,
                error=str(e),
            )
            # Don't fail the entire connection setup if webhook fails

    async def _setup_webhook_for_new_connection(
        self, connection_id: str, connection_config: ConnectionConfig
    ):
        """Create/authenticate a connector and set up its webhook subscription.

        Failures (including failed authentication) are logged and not raised.
        """
        try:
            logger.info(
                "Setting up subscription for newly authenticated connection",
                connection_id=connection_id,
            )
            # Create and authenticate connector
            connector = self._create_connector(connection_config)
            if not await connector.authenticate():
                logger.error(
                    "Failed to authenticate connector for webhook setup",
                    connection_id=connection_id,
                )
                return
            # Setup subscription
            subscription_id = await connector.setup_subscription()
            await self._record_subscription(connection_config, connector, subscription_id)
            logger.info(
                "Successfully set up webhook subscription",
                connection_id=connection_id,
                subscription_id=subscription_id,
            )
        except Exception as e:
            logger.error(
                "Failed to setup webhook subscription for new connection",
                connection_id=connection_id,
                error=str(e),
            )
            # Don't fail the connection setup if webhook fails