Merge branch 'main' of github.com:phact/gendb

This commit is contained in:
phact 2025-09-08 16:12:43 -04:00
commit a4e271cb37

View file

@ -1,6 +1,5 @@
import json import json
import uuid import uuid
import asyncio
import aiofiles import aiofiles
from typing import Dict, List, Any, Optional from typing import Dict, List, Any, Optional
from datetime import datetime from datetime import datetime
@ -62,6 +61,9 @@ class ConnectionManager:
config = ConnectionConfig(**conn_data) config = ConnectionConfig(**conn_data)
self.connections[config.connection_id] = config self.connections[config.connection_id] = config
# Now that connections are loaded, clean up duplicates
await self.cleanup_duplicate_connections(remove_duplicates=True)
async def save_connections(self): async def save_connections(self):
"""Save connections to persistent storage""" """Save connections to persistent storage"""
data = {"connections": []} data = {"connections": []}
@ -78,33 +80,72 @@ class ConnectionManager:
async with aiofiles.open(self.connections_file, "w") as f: async with aiofiles.open(self.connections_file, "w") as f:
await f.write(json.dumps(data, indent=2)) await f.write(json.dumps(data, indent=2))
async def create_connection( async def _get_existing_connection(
self, self, connector_type: str, user_id: Optional[str] = None
connector_type: str, ) -> Optional[ConnectionConfig]:
name: str, """Find existing active connection for the same connector type and user"""
config: Dict[str, Any], for connection in self.connections.values():
user_id: Optional[str] = None, if (
) -> str: connection.connector_type == connector_type
"""Create a new connection configuration""" and connection.user_id == user_id
connection_id = str(uuid.uuid4()) and connection.is_active
):
connection_config = ConnectionConfig( return connection
connection_id=connection_id, return None
connector_type=connector_type,
name=name,
config=config,
user_id=user_id,
)
self.connections[connection_id] = connection_config
await self.save_connections()
return connection_id
async def get_connection(self, connection_id: str) -> Optional[ConnectionConfig]:
"""Get connection configuration"""
return self.connections.get(connection_id)
async def cleanup_duplicate_connections(self, remove_duplicates=False):
"""
Clean up duplicate connections, keeping only the most recent connection
per provider per user
Args:
remove_duplicates: If True, physically removes duplicates from connections.json
If False (default), just deactivates them
"""
logger.info("Starting cleanup of duplicate connections")
# Group connections by (connector_type, user_id)
grouped_connections = {}
for connection_id, connection in self.connections.items():
if not connection.is_active:
continue # Skip inactive connections
key = (connection.connector_type, connection.user_id)
if key not in grouped_connections:
grouped_connections[key] = []
grouped_connections[key].append((connection_id, connection))
# For each group, keep only the most recent connection
connections_to_remove = []
for (connector_type, user_id), connections in grouped_connections.items():
if len(connections) <= 1:
continue # No duplicates
logger.info(f"Found {len(connections)} duplicate connections for {connector_type}, user {user_id}")
# Sort by created_at, keep the most recent
connections.sort(key=lambda x: x[1].created_at, reverse=True)
# Keep the first (most recent), remove/deactivate the rest
for connection_id, connection in connections[1:]:
connections_to_remove.append((connection_id, connection))
logger.info(f"Marking connection {connection_id} for {'removal' if remove_duplicates else 'deactivation'}")
# Remove or deactivate duplicate connections
for connection_id, connection in connections_to_remove:
if remove_duplicates:
await self.delete_connection(connection_id) # Handles token cleanup
else:
await self.deactivate_connection(connection_id)
action = "Removed" if remove_duplicates else "Deactivated"
logger.info(f"Cleanup complete. {action} {len(connections_to_remove)} duplicate connections")
return len(connections_to_remove)
async def update_connection( async def update_connection(
self, self,
connection_id: str, connection_id: str,
@ -146,6 +187,61 @@ class ConnectionManager:
return True return True
async def create_connection(
self,
connector_type: str,
name: str,
config: Dict[str, Any],
user_id: Optional[str] = None,
) -> str:
"""Create a new connection configuration, ensuring only one per provider per user"""
# Check if we already have an active connection for this provider and user
existing_connection = await self._get_existing_connection(connector_type, user_id)
if existing_connection:
# Check if the existing connection has a valid token
try:
connector = self._create_connector(existing_connection)
if await connector.authenticate():
logger.info(
f"Using existing valid connection for {connector_type}",
connection_id=existing_connection.connection_id
)
# Update the existing connection with new config if needed
if config != existing_connection.config:
logger.info("Updating existing connection config")
await self.update_connection(
existing_connection.connection_id,
config=config
)
return existing_connection.connection_id
except Exception as e:
logger.warning(
f"Existing connection authentication failed: {e}",
connection_id=existing_connection.connection_id
)
# If authentication fails, we'll create a new connection and clean up the old one
# Create new connection
connection_id = str(uuid.uuid4())
connection_config = ConnectionConfig(
connection_id=connection_id,
connector_type=connector_type,
name=name,
config=config,
user_id=user_id,
)
self.connections[connection_id] = connection_config
# Clean up duplicates (will keep the newest, which is the one we just created)
await self.cleanup_duplicate_connections(remove_duplicates=True)
await self.save_connections()
return connection_id
async def list_connections( async def list_connections(
self, user_id: Optional[str] = None, connector_type: Optional[str] = None self, user_id: Optional[str] = None, connector_type: Optional[str] = None
) -> List[ConnectionConfig]: ) -> List[ConnectionConfig]:
@ -165,6 +261,18 @@ class ConnectionManager:
if connection_id not in self.connections: if connection_id not in self.connections:
return False return False
connection = self.connections[connection_id]
# Clean up token file if it exists
if connection.config.get("token_file"):
token_file = Path(connection.config["token_file"])
if token_file.exists():
try:
token_file.unlink()
logger.info(f"Deleted token file: {token_file}")
except Exception as e:
logger.warning(f"Failed to delete token file {token_file}: {e}")
# Clean up active connector if exists # Clean up active connector if exists
if connection_id in self.active_connectors: if connection_id in self.active_connectors:
connector = self.active_connectors[connection_id] connector = self.active_connectors[connection_id]
@ -296,6 +404,10 @@ class ConnectionManager:
return True return True
return False return False
async def get_connection(self, connection_id: str) -> Optional[ConnectionConfig]:
"""Get connection configuration"""
return self.connections.get(connection_id)
async def get_connection_by_webhook_id( async def get_connection_by_webhook_id(
self, webhook_id: str self, webhook_id: str