diff --git a/src/connectors/connection_manager.py b/src/connectors/connection_manager.py index 2d0d08bd..05cc85c9 100644 --- a/src/connectors/connection_manager.py +++ b/src/connectors/connection_manager.py @@ -1,6 +1,5 @@ import json import uuid -import asyncio import aiofiles from typing import Dict, List, Any, Optional from datetime import datetime @@ -62,6 +61,9 @@ class ConnectionManager: config = ConnectionConfig(**conn_data) self.connections[config.connection_id] = config + # Now that connections are loaded, clean up duplicates + await self.cleanup_duplicate_connections(remove_duplicates=True) + async def save_connections(self): """Save connections to persistent storage""" data = {"connections": []} @@ -78,33 +80,72 @@ class ConnectionManager: async with aiofiles.open(self.connections_file, "w") as f: await f.write(json.dumps(data, indent=2)) - async def create_connection( - self, - connector_type: str, - name: str, - config: Dict[str, Any], - user_id: Optional[str] = None, - ) -> str: - """Create a new connection configuration""" - connection_id = str(uuid.uuid4()) - - connection_config = ConnectionConfig( - connection_id=connection_id, - connector_type=connector_type, - name=name, - config=config, - user_id=user_id, - ) - - self.connections[connection_id] = connection_config - await self.save_connections() - - return connection_id - - async def get_connection(self, connection_id: str) -> Optional[ConnectionConfig]: - """Get connection configuration""" - return self.connections.get(connection_id) + async def _get_existing_connection( + self, connector_type: str, user_id: Optional[str] = None + ) -> Optional[ConnectionConfig]: + """Find existing active connection for the same connector type and user""" + for connection in self.connections.values(): + if ( + connection.connector_type == connector_type + and connection.user_id == user_id + and connection.is_active + ): + return connection + return None + async def cleanup_duplicate_connections(self, remove_duplicates=False): + """ + Clean up duplicate connections, keeping only the most recent connection + per provider per user + + Args: + remove_duplicates: If True, physically removes duplicates from connections.json + If False (default), just deactivates them + """ + logger.info("Starting cleanup of duplicate connections") + + # Group connections by (connector_type, user_id) + grouped_connections = {} + + for connection_id, connection in self.connections.items(): + if not connection.is_active: + continue # Skip inactive connections + + key = (connection.connector_type, connection.user_id) + + if key not in grouped_connections: + grouped_connections[key] = [] + + grouped_connections[key].append((connection_id, connection)) + + # For each group, keep only the most recent connection + connections_to_remove = [] + + for (connector_type, user_id), connections in grouped_connections.items(): + if len(connections) <= 1: + continue # No duplicates + + logger.info(f"Found {len(connections)} duplicate connections for {connector_type}, user {user_id}") + + # Sort by created_at, keep the most recent + connections.sort(key=lambda x: x[1].created_at, reverse=True) + + # Keep the first (most recent), remove/deactivate the rest + for connection_id, connection in connections[1:]: + connections_to_remove.append((connection_id, connection)) + logger.info(f"Marking connection {connection_id} for {'removal' if remove_duplicates else 'deactivation'}") + + # Remove or deactivate duplicate connections + for connection_id, connection in connections_to_remove: + if remove_duplicates: + await self.delete_connection(connection_id) # Handles token cleanup + else: + await self.deactivate_connection(connection_id) + + action = "Removed" if remove_duplicates else "Deactivated" + logger.info(f"Cleanup complete. {action} {len(connections_to_remove)} duplicate connections") + return len(connections_to_remove) + async def update_connection( self, connection_id: str, @@ -146,6 +187,61 @@ class ConnectionManager: return True + async def create_connection( + self, + connector_type: str, + name: str, + config: Dict[str, Any], + user_id: Optional[str] = None, + ) -> str: + """Create a new connection configuration, ensuring only one per provider per user""" + + # Check if we already have an active connection for this provider and user + existing_connection = await self._get_existing_connection(connector_type, user_id) + + if existing_connection: + # Check if the existing connection has a valid token + try: + connector = self._create_connector(existing_connection) + if await connector.authenticate(): + logger.info( + f"Using existing valid connection for {connector_type}", + connection_id=existing_connection.connection_id + ) + # Update the existing connection with new config if needed + if config != existing_connection.config: + logger.info("Updating existing connection config") + await self.update_connection( + existing_connection.connection_id, + config=config + ) + return existing_connection.connection_id + except Exception as e: + logger.warning( + f"Existing connection authentication failed: {e}", + connection_id=existing_connection.connection_id + ) + # If authentication fails, we'll create a new connection and clean up the old one + + # Create new connection + connection_id = str(uuid.uuid4()) + + connection_config = ConnectionConfig( + connection_id=connection_id, + connector_type=connector_type, + name=name, + config=config, + user_id=user_id, + ) + + self.connections[connection_id] = connection_config + + # Clean up duplicates (will keep the newest, which is the one we just created) + await self.cleanup_duplicate_connections(remove_duplicates=True) + + await self.save_connections() + return connection_id + async def list_connections( self, user_id: Optional[str] = None, connector_type: Optional[str] = None ) -> List[ConnectionConfig]: @@ -165,6 +261,18 @@ class ConnectionManager: if connection_id not in self.connections: return False + connection = self.connections[connection_id] + + # Clean up token file if it exists + if connection.config.get("token_file"): + token_file = Path(connection.config["token_file"]) + if token_file.exists(): + try: + token_file.unlink() + logger.info(f"Deleted token file: {token_file}") + except Exception as e: + logger.warning(f"Failed to delete token file {token_file}: {e}") + # Clean up active connector if exists if connection_id in self.active_connectors: connector = self.active_connectors[connection_id] @@ -296,6 +404,10 @@ class ConnectionManager: return True return False + + async def get_connection(self, connection_id: str) -> Optional[ConnectionConfig]: + """Get connection configuration""" + return self.connections.get(connection_id) async def get_connection_by_webhook_id( self, webhook_id: str