diff --git a/age_adapter/adapter.py b/age_adapter/adapter.py
new file mode 100644
index 000000000..c5405e828
--- /dev/null
+++ b/age_adapter/adapter.py
@@ -0,0 +1,687 @@
+"""
+Apache AGE Adapter for Graph Database Operations
+
+This module provides an async Python adapter for Apache AGE (A Graph Extension),
+a PostgreSQL extension that adds graph database capabilities with Cypher support.
+"""
+
+import json
+import asyncio
+from typing import Optional, Any, List, Dict, Tuple
+from dataclasses import dataclass
+
+try:
+    import asyncpg
+    ASYNCPG_AVAILABLE = True
+except ImportError:
+    ASYNCPG_AVAILABLE = False
+
+
+@dataclass
+class NodeData:
+    """Represents a graph node."""
+    id: str
+    properties: Dict[str, Any]
+
+
+@dataclass
+class EdgeData:
+    """Represents a graph edge."""
+    source_id: str
+    target_id: str
+    relationship_type: str
+    properties: Dict[str, Any]
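+
+
+# Typical usage (a minimal sketch; host, credentials, and graph name are
+# illustrative placeholders, not requirements of AGE itself):
+#
+#   adapter = ApacheAGEAdapter(host="localhost", database="cognee")
+#   await adapter.connect()
+#   await adapter.add_node("user_1", ["User"], {"name": "Alice"})
+#   await adapter.add_node("user_2", ["User"], {"name": "Bob"})
+#   await adapter.add_edge("user_1", "user_2", "KNOWS")
+#   print(await adapter.get_stats())
+#   await adapter.close()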
+
+
+class ApacheAGEAdapter:
+    """
+    Async adapter for Apache AGE graph database operations.
+
+    Provides methods for:
+    - Node CRUD operations
+    - Edge CRUD operations
+    - Cypher query execution
+    - Graph traversal and metrics
+    """
+
+    def __init__(
+        self,
+        host: str = "localhost",
+        port: int = 5432,
+        username: str = "postgres",
+        password: str = "password",
+        database: str = "cognee",
+        graph_name: str = "cognee_graph",
+    ):
+        """
+        Initialize the AGE adapter.
+
+        Args:
+            host: PostgreSQL host
+            port: PostgreSQL port
+            username: PostgreSQL username
+            password: PostgreSQL password
+            database: PostgreSQL database name
+            graph_name: AGE graph name (schema)
+        """
+        if not ASYNCPG_AVAILABLE:
+            raise ImportError(
+                "asyncpg is required. Install with: pip install asyncpg"
+            )
+
+        self.host = host
+        self.port = port
+        self.username = username
+        self.password = password
+        self.database = database
+        self.graph_name = graph_name
+        self.pool: Optional[asyncpg.Pool] = None
+
+    async def connect(self) -> None:
+        """
+        Create the connection pool and initialize AGE.
+
+        Automatically creates the graph if it doesn't exist.
+        """
+        if self.pool is None:
+            # Connection initialization callback
+            async def init_connection(conn):
+                """Initialize each connection in the pool with AGE settings."""
+                await conn.execute("LOAD 'age';")
+                await conn.execute("SET search_path = ag_catalog, '$user', public;")
+
+            self.pool = await asyncpg.create_pool(
+                host=self.host,
+                port=self.port,
+                user=self.username,
+                password=self.password,
+                database=self.database,
+                min_size=2,
+                max_size=10,
+                init=init_connection,  # Initialize each connection
+            )
+
+            # Initialize the AGE extension (only once)
+            async with self.pool.acquire() as conn:
+                await conn.execute("CREATE EXTENSION IF NOT EXISTS age;")
+
+            # Create the graph if it doesn't exist
+            await self.create_graph_if_not_exists()
+
+            # Create the default vertex label and index its properties for
+            # faster MERGE lookups. AGE stores all node properties in a
+            # single agtype column, so the index goes on that column rather
+            # than on a (non-existent) id column.
+            try:
+                async with self.pool.acquire() as conn:
+                    await conn.execute(f"""
+                        SELECT create_vlabel('{self.graph_name}', 'Node');
+                        CREATE INDEX IF NOT EXISTS idx_node_properties
+                            ON {self.graph_name}."Node" USING gin (properties);
+                    """)
+            except Exception:
+                # The label (and its index) may already exist; safe to ignore.
+                pass
+
+    async def create_graph_if_not_exists(self) -> bool:
+        """
+        Create the graph if it doesn't exist.
+
+        Returns:
+            True if the graph was created, False if it already existed
+        """
+        async with self.pool.acquire() as conn:
+            try:
+                await conn.execute(f"SELECT create_graph('{self.graph_name}');")
+                print(f"✓ Created graph: {self.graph_name}")
+                return True
+            except Exception as e:
+                if "already exists" in str(e).lower():
+                    print(f"✓ Graph '{self.graph_name}' already exists")
+                    return False
+                raise
+
+    async def close(self) -> None:
+        """Close the connection pool."""
+        if self.pool:
+            await self.pool.close()
+            self.pool = None
+
+    async def execute_cypher(
+        self,
+        query: str,
+        params: Optional[Dict[str, Any]] = None
+    ) -> List[Dict[str, Any]]:
+        """
+        Execute a Cypher query.
+
+        Args:
+            query: Cypher query string
+            params: Query parameters
+
+        Returns:
+            List of result dictionaries
+        """
+        if self.pool is None:
+            await self.connect()
+
+        # Wrap the Cypher query in AGE's SQL syntax. AGE's documented way to
+        # pass parameters is through prepared statements, so inlining the
+        # JSON with an agtype cast here is a best-effort fallback.
+        if params:
+            param_str = json.dumps(params)
+            wrapped_query = f"""
+                SELECT * FROM cypher('{self.graph_name}', $$
+                    {query}
+                $$, '{param_str}'::agtype) as (result agtype);
+            """
+        else:
+            wrapped_query = f"""
+                SELECT * FROM cypher('{self.graph_name}', $$
+                    {query}
+                $$) as (result agtype);
+            """
+
+        async with self.pool.acquire() as conn:
+            # Connections are already initialized with AGE settings in the
+            # pool init callback
+            rows = await conn.fetch(wrapped_query)
+
+        # Parse AGE's agtype results
+        results = []
+        for row in rows:
+            if row['result']:
+                # AGE returns agtype; convert it to a Python dict
+                result_str = str(row['result'])
+                if result_str and result_str != 'null':
+                    try:
+                        # Strip AGE-specific type annotations if present
+                        # (AGE can return values suffixed with ::vertex or ::edge)
+                        if '::' in result_str:
+                            result_str = result_str.split('::')[0]
+
+                        result_data = json.loads(result_str)
+                        results.append(result_data)
+                    except json.JSONDecodeError:
+                        # If JSON parsing fails, keep the raw value
+                        results.append({'value': result_str})
+
+        return results
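+
+    # For reference, a call such as execute_cypher("MATCH (n) RETURN n") is
+    # sent to PostgreSQL roughly as (graph name shown for the default
+    # configuration):
+    #
+    #   SELECT * FROM cypher('cognee_graph', $$
+    #       MATCH (n) RETURN n
+    #   $$) as (result agtype);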
+
+    async def add_node(
+        self,
+        node_id: str,
+        labels: List[str],
+        properties: Dict[str, Any]
+    ) -> NodeData:
+        """
+        Add a node to the graph. Uses MERGE to avoid duplicates.
+
+        If a node with the same id already exists, it is updated with the
+        new properties.
+
+        Args:
+            node_id: Unique node identifier
+            labels: Node labels (types)
+            properties: Node properties
+
+        Returns:
+            Created/updated node data
+        """
+        props = {**properties, 'id': node_id}
+
+        # Build the property string manually for AGE compatibility
+        props_parts = []
+        for k, v in props.items():
+            if isinstance(v, str):
+                # Escape single quotes so values survive the Cypher string literal
+                escaped = v.replace("'", "\\'")
+                props_parts.append(f"{k}: '{escaped}'")
+            elif isinstance(v, bool):
+                props_parts.append(f"{k}: {str(v).lower()}")
+            elif isinstance(v, (int, float)):
+                props_parts.append(f"{k}: {v}")
+            elif v is None:
+                props_parts.append(f"{k}: null")
+            else:
+                # For complex types, fall back to a JSON string
+                props_parts.append(f"{k}: '{json.dumps(v)}'")
+
+        props_str = ', '.join(props_parts)
+        label_str = ':'.join(labels) if labels else 'Node'
+
+        # Use MERGE to avoid duplicates - matches on the id property
+        query = f"""
+        MERGE (n:{label_str} {{id: '{node_id}'}})
+        SET n = {{{props_str}}}
+        RETURN n
+        """
+
+        await self.execute_cypher(query)
+
+        return NodeData(
+            id=node_id,
+            properties=props
+        )
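+
+    # For a call such as add_node("user_1", ["User"], {"name": "Alice"}),
+    # the generated Cypher is roughly:
+    #
+    #   MERGE (n:User {id: 'user_1'})
+    #   SET n = {name: 'Alice', id: 'user_1'}
+    #   RETURN n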
+
+    async def add_nodes(self, nodes: List[Tuple[str, List[str], Dict[str, Any]]]) -> None:
+        """
+        Add multiple nodes in a single batch operation using UNWIND.
+
+        This is significantly faster than calling add_node() multiple times.
+        Uses MERGE to avoid duplicates.
+
+        Note: labels are accepted for API symmetry but are not applied in
+        batch mode, since Cypher cannot take a label from UNWIND data;
+        nodes are merged on the id property alone.
+
+        Args:
+            nodes: List of tuples (node_id, labels, properties)
+
+        Example:
+            nodes = [
+                ("user_1", ["User"], {"name": "Alice", "age": 30}),
+                ("user_2", ["User"], {"name": "Bob", "age": 25}),
+            ]
+            await adapter.add_nodes(nodes)
+        """
+        if not nodes:
+            return
+
+        # Process in batches of 100
+        BATCH_SIZE = 100
+
+        for i in range(0, len(nodes), BATCH_SIZE):
+            batch = nodes[i:i + BATCH_SIZE]
+
+            node_data_list = []
+            for node_id, labels, properties in batch:
+                props = {"id": node_id, **properties}
+                props_parts = []
+                for k, v in props.items():
+                    if isinstance(v, str):
+                        # Escape double quotes so values survive the Cypher string literal
+                        escaped = v.replace('"', '\\"')
+                        props_parts.append(f'{k}: "{escaped}"')
+                    elif isinstance(v, bool):
+                        props_parts.append(f'{k}: {str(v).lower()}')
+                    elif isinstance(v, (int, float)):
+                        props_parts.append(f'{k}: {v}')
+                    elif v is None:
+                        props_parts.append(f'{k}: null')
+                    else:
+                        # For complex types, fall back to an escaped JSON string
+                        escaped = json.dumps(v).replace('"', '\\"')
+                        props_parts.append(f'{k}: "{escaped}"')
+                props_str = '{' + ', '.join(props_parts) + '}'
+                node_data_list.append(f'{{id: "{node_id}", props: {props_str}}}')
+
+            unwind_data = '[' + ', '.join(node_data_list) + ']'
+
+            # Collect every property key used in the batch; nodes that lack
+            # a key simply have it set to null
+            all_prop_keys = set()
+            for node_id, labels, properties in batch:
+                all_prop_keys.update(properties.keys())
+            all_prop_keys.add('id')
+
+            set_clauses = [f"n.{key} = node_data.props.{key}" for key in sorted(all_prop_keys)]
+            set_clause = "SET " + ", ".join(set_clauses)
+
+            query = f"""
+            UNWIND {unwind_data} AS node_data
+            MERGE (n {{id: node_data.id}})
+            {set_clause}
+            """
+            await self.execute_cypher(query)
+
+    async def get_node(self, node_id: str) -> Optional[NodeData]:
+        """
+        Get a node by ID.
+
+        Args:
+            node_id: Node identifier
+
+        Returns:
+            Node data, or None if not found
+        """
+        query = f"MATCH (n {{id: '{node_id}'}}) RETURN n"
+        results = await self.execute_cypher(query)
+
+        if results:
+            node_data = results[0]
+            # Extract the properties from the AGE vertex structure
+            if isinstance(node_data, dict) and 'properties' in node_data:
+                props = node_data['properties']
+            else:
+                props = node_data
+
+            return NodeData(
+                id=props.get('id', node_id),
+                properties=props
+            )
+        return None
+
+    async def delete_node(self, node_id: str) -> bool:
+        """
+        Delete a node by ID.
+
+        Args:
+            node_id: Node identifier
+
+        Returns:
+            True. DETACH DELETE does not report whether a node matched, so
+            the result does not distinguish deleted from missing nodes.
+        """
+        query = f"MATCH (n {{id: '{node_id}'}}) DETACH DELETE n"
+        await self.execute_cypher(query)
+        return True
+
+    async def add_edge(
+        self,
+        source_id: str,
+        target_id: str,
+        relationship_type: str,
+        properties: Optional[Dict[str, Any]] = None
+    ) -> EdgeData:
+        """
+        Add an edge between two nodes. Uses MERGE to avoid duplicates.
+
+        If an edge of the same type already exists between the same nodes,
+        it is updated with the new properties.
+
+        Args:
+            source_id: Source node ID
+            target_id: Target node ID
+            relationship_type: Relationship type/name
+            properties: Edge properties
+
+        Returns:
+            Created/updated edge data
+        """
+        props = properties or {}
+
+        # Build the property string manually for AGE compatibility
+        if props:
+            props_parts = []
+            for k, v in props.items():
+                if isinstance(v, str):
+                    # Escape single quotes so values survive the Cypher string literal
+                    escaped = v.replace("'", "\\'")
+                    props_parts.append(f"{k}: '{escaped}'")
+                elif isinstance(v, bool):
+                    props_parts.append(f"{k}: {str(v).lower()}")
+                elif isinstance(v, (int, float)):
+                    props_parts.append(f"{k}: {v}")
+                elif v is None:
+                    props_parts.append(f"{k}: null")
+                else:
+                    # For complex types, fall back to a JSON string
+                    props_parts.append(f"{k}: '{json.dumps(v)}'")
+            props_str = ', '.join(props_parts)
+
+            # Use MERGE to avoid duplicate edges
+            query = f"""
+            MATCH (a {{id: '{source_id}'}}), (b {{id: '{target_id}'}})
+            MERGE (a)-[r:{relationship_type}]->(b)
+            SET r = {{{props_str}}}
+            RETURN r
+            """
+        else:
+            # Use MERGE without properties
+            query = f"""
+            MATCH (a {{id: '{source_id}'}}), (b {{id: '{target_id}'}})
+            MERGE (a)-[r:{relationship_type}]->(b)
+            RETURN r
+            """
+
+        await self.execute_cypher(query)
+
+        return EdgeData(
+            source_id=source_id,
+            target_id=target_id,
+            relationship_type=relationship_type,
+            properties=props
+        )
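+
+    # A call such as add_edge("user_1", "user_2", "KNOWS", {"since": 2020})
+    # produces roughly:
+    #
+    #   MATCH (a {id: 'user_1'}), (b {id: 'user_2'})
+    #   MERGE (a)-[r:KNOWS]->(b)
+    #   SET r = {since: 2020}
+    #   RETURN r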
+
+    async def add_edges(self, edges: List[Tuple[str, str, str, Optional[Dict[str, Any]]]]) -> None:
+        """
+        Add multiple edges in a single batch operation using UNWIND.
+
+        This is significantly faster than calling add_edge() multiple times.
+        Uses MERGE to avoid duplicates.
+
+        Args:
+            edges: List of tuples (source_id, target_id, relationship_type, properties)
+
+        Example:
+            edges = [
+                ("user_1", "user_2", "KNOWS", {"since": 2020}),
+                ("user_2", "user_3", "FOLLOWS", {"weight": 0.5}),
+            ]
+            await adapter.add_edges(edges)
+        """
+        if not edges:
+            return
+
+        # Group edges by relationship type, since the type has to be inlined
+        # into the MERGE pattern
+        edges_by_type = {}
+        for source_id, target_id, rel_type, properties in edges:
+            if rel_type not in edges_by_type:
+                edges_by_type[rel_type] = []
+            edges_by_type[rel_type].append({
+                "source_id": source_id,
+                "target_id": target_id,
+                "properties": properties or {}
+            })
+
+        # Process each relationship type in batches
+        BATCH_SIZE = 100  # Smaller batches to avoid huge query strings
+
+        for rel_type, edge_list in edges_by_type.items():
+            # Get all unique property keys for this relationship type
+            all_prop_keys = set()
+            for edge in edge_list:
+                all_prop_keys.update(edge["properties"].keys())
+
+            # Process in batches
+            for i in range(0, len(edge_list), BATCH_SIZE):
+                batch = edge_list[i:i + BATCH_SIZE]
+
+                # Build the inline list for the batch MERGE
+                values_parts = []
+                for edge in batch:
+                    props = edge["properties"]
+                    # Build the property map for this edge
+                    props_cypher_parts = []
+                    for key in all_prop_keys:
+                        value = props.get(key)
+                        if value is None:
+                            props_cypher_parts.append(f'{key}: null')
+                        elif isinstance(value, str):
+                            # Escape quotes in strings
+                            escaped = value.replace('"', '\\"')
+                            props_cypher_parts.append(f'{key}: "{escaped}"')
+                        elif isinstance(value, bool):
+                            props_cypher_parts.append(f'{key}: {str(value).lower()}')
+                        elif isinstance(value, (int, float)):
+                            props_cypher_parts.append(f'{key}: {value}')
+                        else:
+                            # For complex types, fall back to an escaped JSON string
+                            escaped = json.dumps(value).replace('"', '\\"')
+                            props_cypher_parts.append(f'{key}: "{escaped}"')
+                    props_str = ', '.join(props_cypher_parts)
+
+                    values_parts.append(f'{{src: "{edge["source_id"]}", tgt: "{edge["target_id"]}", props: {{{props_str}}}}}')
+
+                # Build the UNWIND query (AGE requires inline data here, not
+                # bound parameters)
+                values_list = '[' + ', '.join(values_parts) + ']'
+
+                # Build the SET clause with explicit per-property assignments
+                # (AGE does not support assigning a map taken from the UNWIND
+                # variable, e.g. SET r = edge.props)
+                if all_prop_keys:
+                    set_parts = [f'r.{key} = edge.props.{key}' for key in all_prop_keys]
+                    set_clause = 'SET ' + ', '.join(set_parts)
+                else:
+                    set_clause = ''
+
+                query = f"""
+                UNWIND {values_list} AS edge
+                MATCH (a {{id: edge.src}}), (b {{id: edge.tgt}})
+                MERGE (a)-[r:{rel_type}]->(b)
+                {set_clause}
+                """
+
+                await self.execute_cypher(query)
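+
+    # Each batch issued by add_edges expands to an inline UNWIND query of
+    # roughly this shape:
+    #
+    #   UNWIND [{src: "user_1", tgt: "user_2", props: {since: 2020}}] AS edge
+    #   MATCH (a {id: edge.src}), (b {id: edge.tgt})
+    #   MERGE (a)-[r:KNOWS]->(b)
+    #   SET r.since = edge.props.since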
+
+    async def get_edges(self, node_id: str) -> List[EdgeData]:
+        """
+        Get all edges connected to a node.
+
+        Args:
+            node_id: Node identifier
+
+        Returns:
+            List of edge data
+        """
+        query = f"""
+        MATCH (a {{id: '{node_id}'}})-[r]-(b)
+        RETURN {{source: a.id, target: b.id, rel_type: type(r), props: properties(r)}}
+        """
+
+        results = await self.execute_cypher(query)
+
+        edges = []
+        for result in results:
+            # Each result is directly the edge data map
+            edges.append(EdgeData(
+                source_id=result.get('source', ''),
+                target_id=result.get('target', ''),
+                relationship_type=result.get('rel_type', ''),
+                properties=result.get('props', {})
+            ))
+
+        return edges
+
+    async def get_neighbors(self, node_id: str) -> List[NodeData]:
+        """
+        Get all neighboring nodes.
+
+        Args:
+            node_id: Node identifier
+
+        Returns:
+            List of neighbor nodes
+        """
+        # Return a simple map instead of the full vertex object for better
+        # performance
+        query = f"""
+        MATCH (n {{id: '{node_id}'}})-[]-(neighbor)
+        RETURN DISTINCT {{id: neighbor.id, properties: properties(neighbor)}}
+        """
+
+        results = await self.execute_cypher(query)
+
+        neighbors = []
+        for result in results:
+            # Each result is already a simple map
+            neighbors.append(NodeData(
+                id=result.get('id', ''),
+                properties=result.get('properties', {})
+            ))
+
+        return neighbors
+
+    async def count_nodes(self) -> int:
+        """
+        Count the total number of nodes in the graph.
+
+        Returns:
+            Number of nodes
+        """
+        query = "MATCH (n) RETURN {count: count(n)}"
+        results = await self.execute_cypher(query)
+
+        if results:
+            return results[0].get('count', 0)
+        return 0
+
+    async def count_edges(self) -> int:
+        """
+        Count the total number of edges in the graph.
+
+        Returns:
+            Number of edges
+        """
+        query = "MATCH ()-[r]->() RETURN {count: count(r)}"
+        results = await self.execute_cypher(query)
+
+        if results:
+            return results[0].get('count', 0)
+        return 0
+
+    async def clear_graph(self) -> None:
+        """
+        Delete all nodes and edges from the graph.
+
+        Note: This only removes the data. The graph schema and tables remain.
+        Use drop_graph() to completely remove the graph including its tables.
+        """
+        query = "MATCH (n) DETACH DELETE n"
+        await self.execute_cypher(query)
+
+    async def drop_graph(self, recreate: bool = False) -> None:
+        """
+        Completely drop the graph, including its schema and tables, from
+        PostgreSQL.
+
+        Args:
+            recreate: If True, recreate an empty graph immediately after dropping
+
+        Warning: This permanently removes all graph data, schema, and tables.
+        """
+        if self.pool is None:
+            await self.connect()
+
+        # Pool connections already have AGE loaded and the search_path set
+        # via the init callback, so no per-call setup is needed here
+        async with self.pool.acquire() as conn:
+            try:
+                await conn.execute(f"SELECT drop_graph('{self.graph_name}', true);")
+            except asyncpg.exceptions.UndefinedObjectError:
+                # The graph doesn't exist; nothing to drop
+                pass
+            except Exception as e:
+                raise Exception(f"Error dropping graph '{self.graph_name}': {e}") from e
+
+        # Recreate if requested
+        if recreate:
+            await self.create_graph_if_not_exists()
+
+    async def list_all_graphs(self) -> List[str]:
+        """
+        List all Apache AGE graphs in the database.
+
+        Returns:
+            List of graph names
+        """
+        if self.pool is None:
+            await self.connect()
+
+        async with self.pool.acquire() as conn:
+            # Query ag_catalog.ag_graph to get all graphs
+            result = await conn.fetch(
+                "SELECT name FROM ag_catalog.ag_graph ORDER BY name;"
+            )
+
+        return [row['name'] for row in result]
+
+    async def drop_all_graphs(self) -> List[str]:
+        """
+        Drop ALL Apache AGE graphs in the database.
+
+        Returns:
+            List of dropped graph names
+
+        Warning: This permanently removes ALL graphs from the database!
+        """
+        graphs = await self.list_all_graphs()
+
+        dropped = []
+        async with self.pool.acquire() as conn:
+            for graph_name in graphs:
+                try:
+                    await conn.execute(f"SELECT drop_graph('{graph_name}', true);")
+                    dropped.append(graph_name)
+                    print(f"✓ Dropped graph: {graph_name}")
+                except Exception as e:
+                    print(f"✗ Failed to drop graph '{graph_name}': {e}")
+
+        return dropped
+
+    async def get_stats(self) -> Dict[str, float]:
+        """
+        Get graph statistics.
+
+        Returns:
+            Dictionary with node and edge counts plus the mean degree,
+            computed as 2E / N (each edge contributes to two endpoints)
+        """
+        num_nodes = await self.count_nodes()
+        num_edges = await self.count_edges()
+
+        return {
+            'nodes': num_nodes,
+            'edges': num_edges,
+            'mean_degree': (2 * num_edges / num_nodes) if num_nodes > 0 else 0
+        }
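+
+
+if __name__ == "__main__":
+    # Minimal smoke test (a sketch: assumes a local PostgreSQL instance with
+    # the AGE extension installed; credentials and names are placeholders)
+    async def _demo():
+        adapter = ApacheAGEAdapter(host="localhost", database="cognee")
+        await adapter.connect()
+        await adapter.add_node("demo_1", ["Demo"], {"name": "demo"})
+        await adapter.add_node("demo_2", ["Demo"], {"name": "demo"})
+        await adapter.add_edge("demo_1", "demo_2", "LINKS_TO")
+        print(await adapter.get_stats())
+        await adapter.close()
+
+    asyncio.run(_demo())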
diff --git a/age_adapter/benchmark.py b/age_adapter/benchmark.py
new file mode 100644
index 000000000..ef0947c23
--- /dev/null
+++ b/age_adapter/benchmark.py
@@ -0,0 +1,135 @@
+import asyncio
+import time
+import sys
+from pathlib import Path
+from typing import List, Dict, Any, Tuple
+
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from age_adapter.adapter import ApacheAGEAdapter
+from cognee.infrastructure.databases.graph.neo4j_driver.adapter import Neo4jAdapter
+from cognee.infrastructure.engine.models.DataPoint import DataPoint
+from uuid import UUID
+
+
+class SimpleNode(DataPoint):
+    model_config = {"extra": "allow"}
+
+    def __init__(self, node_id: str, properties: Dict[str, Any]):
+        # Neo4jAdapter expects DataPoint ids as UUIDs; derive a stable UUID
+        # from the string id when it isn't one already
+        try:
+            node_uuid = UUID(node_id) if '-' in node_id else UUID(int=hash(node_id) & ((1 << 128) - 1))
+        except ValueError:
+            node_uuid = UUID(int=hash(node_id) & ((1 << 128) - 1))
+        super().__init__(id=node_uuid, **properties)
+
+
+async def main():
+    age_adapter = ApacheAGEAdapter(
+        host="localhost",
+        port=5432,
+        username="cognee",
+        password="cognee",
+        database="cognee_db",
+        graph_name="benchmark_graph"
+    )
+
+    neo4j_adapter = Neo4jAdapter(
+        graph_database_url="bolt://localhost:7687",
+        graph_database_username="neo4j",
+        graph_database_password="pleaseletmein",
+        graph_database_name=None
+    )
+
+    await age_adapter.connect()
+    await neo4j_adapter.initialize()
+
+    batch_size = 3000
+    node_ids = [f"node_{i}" for i in range(batch_size)]
+    nodes = [(nid, ["TestNode"], {"name": f"Node {i}", "value": i})
+             for i, nid in enumerate(node_ids)]
+
+    # Start from empty graphs
+    await age_adapter.drop_graph(recreate=True)
+    await neo4j_adapter.delete_graph()
+
+    # Single-node ingestion into an empty graph
+    start = time.perf_counter()
+    for node_id, labels, props in nodes:
+        await age_adapter.add_node(node_id, labels, props)
+    age_time_single_new = time.perf_counter() - start
+
+    start = time.perf_counter()
+    for node_id, labels, props in nodes:
+        neo4j_node = SimpleNode(node_id, props)
+        await neo4j_adapter.add_node(neo4j_node)
+    neo4j_time_single_new = time.perf_counter() - start
+
+    print(f"Node Ingestion Single (New): AGE={age_time_single_new:.4f}s, Neo4j={neo4j_time_single_new:.4f}s")
+
+    # Merge workload: half the ids already exist, half are new
+    half = batch_size // 2
+    existing_nodes = nodes[:half]
+    new_node_ids = [f"node_{i}" for i in range(batch_size, batch_size + half)]
+    new_nodes = [(nid, ["TestNode"], {"name": f"Node {i}", "value": i})
+                 for i, nid in enumerate(new_node_ids, start=batch_size)]
+    merge_nodes = existing_nodes + new_nodes
+
+    for node_id, labels, props in existing_nodes:
+        await age_adapter.add_node(node_id, labels, props)
+        neo4j_node = SimpleNode(node_id, props)
+        await neo4j_adapter.add_node(neo4j_node)
+
+    start = time.perf_counter()
+    for node_id, labels, props in merge_nodes:
+        await age_adapter.add_node(node_id, labels, props)
+    age_time_single_merge = time.perf_counter() - start
+
+    start = time.perf_counter()
+    for node_id, labels, props in merge_nodes:
+        neo4j_node = SimpleNode(node_id, props)
+        await neo4j_adapter.add_node(neo4j_node)
+    neo4j_time_single_merge = time.perf_counter() - start
+
+    print(f"Node Ingestion Single (Merge - {half} existing, {len(new_nodes)} new): AGE={age_time_single_merge:.4f}s, Neo4j={neo4j_time_single_merge:.4f}s")
+
+    await age_adapter.drop_graph(recreate=True)
+    await neo4j_adapter.delete_graph()
+
+    # Batch ingestion into an empty graph
+    start = time.perf_counter()
+    for i in range(0, len(nodes), 100):
+        await age_adapter.add_nodes(nodes[i:i+100])
+    age_time_batch_new = time.perf_counter() - start
+
+    start = time.perf_counter()
+    for i in range(0, len(nodes), 100):
+        batch = nodes[i:i+100]
+        neo4j_nodes = [SimpleNode(node_id, props) for node_id, _, props in batch]
+        await neo4j_adapter.add_nodes(neo4j_nodes)
+    neo4j_time_batch_new = time.perf_counter() - start
+
+    print(f"Node Ingestion Batch (New): AGE={age_time_batch_new:.4f}s, Neo4j={neo4j_time_batch_new:.4f}s")
+
+    # Batch merge workload: seed half the ids, then re-ingest in batches
+    for i in range(0, len(existing_nodes), 100):
+        await age_adapter.add_nodes(existing_nodes[i:i+100])
+        batch = existing_nodes[i:i+100]
+        neo4j_existing = [SimpleNode(node_id, props) for node_id, _, props in batch]
+        await neo4j_adapter.add_nodes(neo4j_existing)
+
+    start = time.perf_counter()
+    for i in range(0, len(merge_nodes), 100):
+        await age_adapter.add_nodes(merge_nodes[i:i+100])
+    age_time_batch_merge = time.perf_counter() - start
+
+    start = time.perf_counter()
+    for i in range(0, len(merge_nodes), 100):
+        batch = merge_nodes[i:i+100]
+        neo4j_merge_nodes = [SimpleNode(node_id, props) for node_id, _, props in batch]
+        await neo4j_adapter.add_nodes(neo4j_merge_nodes)
+    neo4j_time_batch_merge = time.perf_counter() - start
+
+    print(f"Node Ingestion Batch (Merge - {half} existing, {len(new_nodes)} new): AGE={age_time_batch_merge:.4f}s, Neo4j={neo4j_time_batch_merge:.4f}s")
+
+    await age_adapter.close()
+    await neo4j_adapter.driver.close()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
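+
+# To reproduce (assumes the PostgreSQL+AGE and Neo4j services configured in
+# main() are running):
+#
+#   python age_adapter/benchmark.py
+#
+# Timings vary with hardware and network; treat the printed numbers as
+# relative comparisons between the two adapters rather than absolutes.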