Merge branch 'main' of github.com:phact/gendb
This commit is contained in:
commit
a4e271cb37
1 changed files with 139 additions and 27 deletions
|
|
@ -1,6 +1,5 @@
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
import asyncio
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
from typing import Dict, List, Any, Optional
|
from typing import Dict, List, Any, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
@ -62,6 +61,9 @@ class ConnectionManager:
|
||||||
config = ConnectionConfig(**conn_data)
|
config = ConnectionConfig(**conn_data)
|
||||||
self.connections[config.connection_id] = config
|
self.connections[config.connection_id] = config
|
||||||
|
|
||||||
|
# Now that connections are loaded, clean up duplicates
|
||||||
|
await self.cleanup_duplicate_connections(remove_duplicates=True)
|
||||||
|
|
||||||
async def save_connections(self):
|
async def save_connections(self):
|
||||||
"""Save connections to persistent storage"""
|
"""Save connections to persistent storage"""
|
||||||
data = {"connections": []}
|
data = {"connections": []}
|
||||||
|
|
@ -78,32 +80,71 @@ class ConnectionManager:
|
||||||
async with aiofiles.open(self.connections_file, "w") as f:
|
async with aiofiles.open(self.connections_file, "w") as f:
|
||||||
await f.write(json.dumps(data, indent=2))
|
await f.write(json.dumps(data, indent=2))
|
||||||
|
|
||||||
async def create_connection(
|
async def _get_existing_connection(
|
||||||
self,
|
self, connector_type: str, user_id: Optional[str] = None
|
||||||
connector_type: str,
|
) -> Optional[ConnectionConfig]:
|
||||||
name: str,
|
"""Find existing active connection for the same connector type and user"""
|
||||||
config: Dict[str, Any],
|
for connection in self.connections.values():
|
||||||
user_id: Optional[str] = None,
|
if (
|
||||||
) -> str:
|
connection.connector_type == connector_type
|
||||||
"""Create a new connection configuration"""
|
and connection.user_id == user_id
|
||||||
connection_id = str(uuid.uuid4())
|
and connection.is_active
|
||||||
|
):
|
||||||
|
return connection
|
||||||
|
return None
|
||||||
|
|
||||||
connection_config = ConnectionConfig(
|
async def cleanup_duplicate_connections(self, remove_duplicates=False):
|
||||||
connection_id=connection_id,
|
"""
|
||||||
connector_type=connector_type,
|
Clean up duplicate connections, keeping only the most recent connection
|
||||||
name=name,
|
per provider per user
|
||||||
config=config,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.connections[connection_id] = connection_config
|
Args:
|
||||||
await self.save_connections()
|
remove_duplicates: If True, physically removes duplicates from connections.json
|
||||||
|
If False (default), just deactivates them
|
||||||
|
"""
|
||||||
|
logger.info("Starting cleanup of duplicate connections")
|
||||||
|
|
||||||
return connection_id
|
# Group connections by (connector_type, user_id)
|
||||||
|
grouped_connections = {}
|
||||||
|
|
||||||
async def get_connection(self, connection_id: str) -> Optional[ConnectionConfig]:
|
for connection_id, connection in self.connections.items():
|
||||||
"""Get connection configuration"""
|
if not connection.is_active:
|
||||||
return self.connections.get(connection_id)
|
continue # Skip inactive connections
|
||||||
|
|
||||||
|
key = (connection.connector_type, connection.user_id)
|
||||||
|
|
||||||
|
if key not in grouped_connections:
|
||||||
|
grouped_connections[key] = []
|
||||||
|
|
||||||
|
grouped_connections[key].append((connection_id, connection))
|
||||||
|
|
||||||
|
# For each group, keep only the most recent connection
|
||||||
|
connections_to_remove = []
|
||||||
|
|
||||||
|
for (connector_type, user_id), connections in grouped_connections.items():
|
||||||
|
if len(connections) <= 1:
|
||||||
|
continue # No duplicates
|
||||||
|
|
||||||
|
logger.info(f"Found {len(connections)} duplicate connections for {connector_type}, user {user_id}")
|
||||||
|
|
||||||
|
# Sort by created_at, keep the most recent
|
||||||
|
connections.sort(key=lambda x: x[1].created_at, reverse=True)
|
||||||
|
|
||||||
|
# Keep the first (most recent), remove/deactivate the rest
|
||||||
|
for connection_id, connection in connections[1:]:
|
||||||
|
connections_to_remove.append((connection_id, connection))
|
||||||
|
logger.info(f"Marking connection {connection_id} for {'removal' if remove_duplicates else 'deactivation'}")
|
||||||
|
|
||||||
|
# Remove or deactivate duplicate connections
|
||||||
|
for connection_id, connection in connections_to_remove:
|
||||||
|
if remove_duplicates:
|
||||||
|
await self.delete_connection(connection_id) # Handles token cleanup
|
||||||
|
else:
|
||||||
|
await self.deactivate_connection(connection_id)
|
||||||
|
|
||||||
|
action = "Removed" if remove_duplicates else "Deactivated"
|
||||||
|
logger.info(f"Cleanup complete. {action} {len(connections_to_remove)} duplicate connections")
|
||||||
|
return len(connections_to_remove)
|
||||||
|
|
||||||
async def update_connection(
|
async def update_connection(
|
||||||
self,
|
self,
|
||||||
|
|
@ -146,6 +187,61 @@ class ConnectionManager:
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def create_connection(
|
||||||
|
self,
|
||||||
|
connector_type: str,
|
||||||
|
name: str,
|
||||||
|
config: Dict[str, Any],
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Create a new connection configuration, ensuring only one per provider per user"""
|
||||||
|
|
||||||
|
# Check if we already have an active connection for this provider and user
|
||||||
|
existing_connection = await self._get_existing_connection(connector_type, user_id)
|
||||||
|
|
||||||
|
if existing_connection:
|
||||||
|
# Check if the existing connection has a valid token
|
||||||
|
try:
|
||||||
|
connector = self._create_connector(existing_connection)
|
||||||
|
if await connector.authenticate():
|
||||||
|
logger.info(
|
||||||
|
f"Using existing valid connection for {connector_type}",
|
||||||
|
connection_id=existing_connection.connection_id
|
||||||
|
)
|
||||||
|
# Update the existing connection with new config if needed
|
||||||
|
if config != existing_connection.config:
|
||||||
|
logger.info("Updating existing connection config")
|
||||||
|
await self.update_connection(
|
||||||
|
existing_connection.connection_id,
|
||||||
|
config=config
|
||||||
|
)
|
||||||
|
return existing_connection.connection_id
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Existing connection authentication failed: {e}",
|
||||||
|
connection_id=existing_connection.connection_id
|
||||||
|
)
|
||||||
|
# If authentication fails, we'll create a new connection and clean up the old one
|
||||||
|
|
||||||
|
# Create new connection
|
||||||
|
connection_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
connection_config = ConnectionConfig(
|
||||||
|
connection_id=connection_id,
|
||||||
|
connector_type=connector_type,
|
||||||
|
name=name,
|
||||||
|
config=config,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.connections[connection_id] = connection_config
|
||||||
|
|
||||||
|
# Clean up duplicates (will keep the newest, which is the one we just created)
|
||||||
|
await self.cleanup_duplicate_connections(remove_duplicates=True)
|
||||||
|
|
||||||
|
await self.save_connections()
|
||||||
|
return connection_id
|
||||||
|
|
||||||
async def list_connections(
|
async def list_connections(
|
||||||
self, user_id: Optional[str] = None, connector_type: Optional[str] = None
|
self, user_id: Optional[str] = None, connector_type: Optional[str] = None
|
||||||
) -> List[ConnectionConfig]:
|
) -> List[ConnectionConfig]:
|
||||||
|
|
@ -165,6 +261,18 @@ class ConnectionManager:
|
||||||
if connection_id not in self.connections:
|
if connection_id not in self.connections:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
connection = self.connections[connection_id]
|
||||||
|
|
||||||
|
# Clean up token file if it exists
|
||||||
|
if connection.config.get("token_file"):
|
||||||
|
token_file = Path(connection.config["token_file"])
|
||||||
|
if token_file.exists():
|
||||||
|
try:
|
||||||
|
token_file.unlink()
|
||||||
|
logger.info(f"Deleted token file: {token_file}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to delete token file {token_file}: {e}")
|
||||||
|
|
||||||
# Clean up active connector if exists
|
# Clean up active connector if exists
|
||||||
if connection_id in self.active_connectors:
|
if connection_id in self.active_connectors:
|
||||||
connector = self.active_connectors[connection_id]
|
connector = self.active_connectors[connection_id]
|
||||||
|
|
@ -297,6 +405,10 @@ class ConnectionManager:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def get_connection(self, connection_id: str) -> Optional[ConnectionConfig]:
|
||||||
|
"""Get connection configuration"""
|
||||||
|
return self.connections.get(connection_id)
|
||||||
|
|
||||||
async def get_connection_by_webhook_id(
|
async def get_connection_by_webhook_id(
|
||||||
self, webhook_id: str
|
self, webhook_id: str
|
||||||
) -> Optional[ConnectionConfig]:
|
) -> Optional[ConnectionConfig]:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue