Merge branch 'main' of github.com:phact/gendb

This commit is contained in:
phact 2025-09-08 16:12:43 -04:00
commit a4e271cb37

View file

@ -1,6 +1,5 @@
import json
import uuid
import asyncio
import aiofiles
from typing import Dict, List, Any, Optional
from datetime import datetime
@ -62,6 +61,9 @@ class ConnectionManager:
config = ConnectionConfig(**conn_data)
self.connections[config.connection_id] = config
# Now that connections are loaded, clean up duplicates
await self.cleanup_duplicate_connections(remove_duplicates=True)
async def save_connections(self):
"""Save connections to persistent storage"""
data = {"connections": []}
@ -78,33 +80,72 @@ class ConnectionManager:
async with aiofiles.open(self.connections_file, "w") as f:
await f.write(json.dumps(data, indent=2))
async def create_connection(
self,
connector_type: str,
name: str,
config: Dict[str, Any],
user_id: Optional[str] = None,
) -> str:
"""Create a new connection configuration"""
connection_id = str(uuid.uuid4())
connection_config = ConnectionConfig(
connection_id=connection_id,
connector_type=connector_type,
name=name,
config=config,
user_id=user_id,
)
self.connections[connection_id] = connection_config
await self.save_connections()
return connection_id
async def get_connection(self, connection_id: str) -> Optional[ConnectionConfig]:
"""Get connection configuration"""
return self.connections.get(connection_id)
async def _get_existing_connection(
self, connector_type: str, user_id: Optional[str] = None
) -> Optional[ConnectionConfig]:
"""Find existing active connection for the same connector type and user"""
for connection in self.connections.values():
if (
connection.connector_type == connector_type
and connection.user_id == user_id
and connection.is_active
):
return connection
return None
async def cleanup_duplicate_connections(self, remove_duplicates=False):
"""
Clean up duplicate connections, keeping only the most recent connection
per provider per user
Args:
remove_duplicates: If True, physically removes duplicates from connections.json
If False (default), just deactivates them
"""
logger.info("Starting cleanup of duplicate connections")
# Group connections by (connector_type, user_id)
grouped_connections = {}
for connection_id, connection in self.connections.items():
if not connection.is_active:
continue # Skip inactive connections
key = (connection.connector_type, connection.user_id)
if key not in grouped_connections:
grouped_connections[key] = []
grouped_connections[key].append((connection_id, connection))
# For each group, keep only the most recent connection
connections_to_remove = []
for (connector_type, user_id), connections in grouped_connections.items():
if len(connections) <= 1:
continue # No duplicates
logger.info(f"Found {len(connections)} duplicate connections for {connector_type}, user {user_id}")
# Sort by created_at, keep the most recent
connections.sort(key=lambda x: x[1].created_at, reverse=True)
# Keep the first (most recent), remove/deactivate the rest
for connection_id, connection in connections[1:]:
connections_to_remove.append((connection_id, connection))
logger.info(f"Marking connection {connection_id} for {'removal' if remove_duplicates else 'deactivation'}")
# Remove or deactivate duplicate connections
for connection_id, connection in connections_to_remove:
if remove_duplicates:
await self.delete_connection(connection_id) # Handles token cleanup
else:
await self.deactivate_connection(connection_id)
action = "Removed" if remove_duplicates else "Deactivated"
logger.info(f"Cleanup complete. {action} {len(connections_to_remove)} duplicate connections")
return len(connections_to_remove)
async def update_connection(
self,
connection_id: str,
@ -146,6 +187,61 @@ class ConnectionManager:
return True
async def create_connection(
self,
connector_type: str,
name: str,
config: Dict[str, Any],
user_id: Optional[str] = None,
) -> str:
"""Create a new connection configuration, ensuring only one per provider per user"""
# Check if we already have an active connection for this provider and user
existing_connection = await self._get_existing_connection(connector_type, user_id)
if existing_connection:
# Check if the existing connection has a valid token
try:
connector = self._create_connector(existing_connection)
if await connector.authenticate():
logger.info(
f"Using existing valid connection for {connector_type}",
connection_id=existing_connection.connection_id
)
# Update the existing connection with new config if needed
if config != existing_connection.config:
logger.info("Updating existing connection config")
await self.update_connection(
existing_connection.connection_id,
config=config
)
return existing_connection.connection_id
except Exception as e:
logger.warning(
f"Existing connection authentication failed: {e}",
connection_id=existing_connection.connection_id
)
# If authentication fails, we'll create a new connection and clean up the old one
# Create new connection
connection_id = str(uuid.uuid4())
connection_config = ConnectionConfig(
connection_id=connection_id,
connector_type=connector_type,
name=name,
config=config,
user_id=user_id,
)
self.connections[connection_id] = connection_config
# Clean up duplicates (will keep the newest, which is the one we just created)
await self.cleanup_duplicate_connections(remove_duplicates=True)
await self.save_connections()
return connection_id
async def list_connections(
self, user_id: Optional[str] = None, connector_type: Optional[str] = None
) -> List[ConnectionConfig]:
@ -165,6 +261,18 @@ class ConnectionManager:
if connection_id not in self.connections:
return False
connection = self.connections[connection_id]
# Clean up token file if it exists
if connection.config.get("token_file"):
token_file = Path(connection.config["token_file"])
if token_file.exists():
try:
token_file.unlink()
logger.info(f"Deleted token file: {token_file}")
except Exception as e:
logger.warning(f"Failed to delete token file {token_file}: {e}")
# Clean up active connector if exists
if connection_id in self.active_connectors:
connector = self.active_connectors[connection_id]
@ -296,6 +404,10 @@ class ConnectionManager:
return True
return False
async def get_connection(self, connection_id: str) -> Optional[ConnectionConfig]:
"""Get connection configuration"""
return self.connections.get(connection_id)
async def get_connection_by_webhook_id(
self, webhook_id: str