test
parent fad1a84516
commit c021e832e7

2 changed files with 822 additions and 0 deletions
age_adapter/adapter.py (new file, 687 lines)
@@ -0,0 +1,687 @@
"""
|
||||
Apache AGE Adapter for Graph Database Operations
|
||||
|
||||
This module provides a Python async adapter for Apache AGE (A Graph Extension),
|
||||
a PostgreSQL extension that adds graph database capabilities with Cypher support.
|
||||
"""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Optional, Any, List, Dict, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
try:
|
||||
import asyncpg
|
||||
ASYNCPG_AVAILABLE = True
|
||||
except ImportError:
|
||||
ASYNCPG_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeData:
|
||||
"""Represents a graph node."""
|
||||
id: str
|
||||
properties: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class EdgeData:
|
||||
"""Represents a graph edge."""
|
||||
source_id: str
|
||||
target_id: str
|
||||
relationship_type: str
|
||||
properties: Dict[str, Any]
|
||||
|
||||
|
||||

class ApacheAGEAdapter:
    """
    Async adapter for Apache AGE graph database operations.

    Provides methods for:
    - Node CRUD operations
    - Edge CRUD operations
    - Cypher query execution
    - Graph traversal and metrics
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 5432,
        username: str = "postgres",
        password: str = "password",
        database: str = "cognee",
        graph_name: str = "cognee_graph",
    ):
        """
        Initialize the AGE adapter.

        Args:
            host: PostgreSQL host
            port: PostgreSQL port
            username: PostgreSQL username
            password: PostgreSQL password
            database: PostgreSQL database name
            graph_name: AGE graph name (schema)
        """
        if not ASYNCPG_AVAILABLE:
            raise ImportError(
                "asyncpg is required. Install with: pip install asyncpg"
            )

        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.database = database
        self.graph_name = graph_name
        self.pool: Optional[asyncpg.Pool] = None

    async def connect(self) -> None:
        """
        Create connection pool and initialize AGE.

        Automatically creates the graph if it doesn't exist.
        """
        if self.pool is None:
            # Connection initialization callback
            async def init_connection(conn):
                """Initialize each connection in the pool with AGE settings."""
                await conn.execute("LOAD 'age';")
                await conn.execute("SET search_path = ag_catalog, '$user', public;")

            self.pool = await asyncpg.create_pool(
                host=self.host,
                port=self.port,
                user=self.username,
                password=self.password,
                database=self.database,
                min_size=2,
                max_size=10,
                init=init_connection  # Initialize each connection
            )

            # Initialize AGE extension (only once)
            async with self.pool.acquire() as conn:
                await conn.execute("CREATE EXTENSION IF NOT EXISTS age;")

            # Create graph if it doesn't exist
            await self.create_graph_if_not_exists()

            # Create index on id for faster MERGE operations
            try:
                async with self.pool.acquire() as conn:
                    await conn.execute(f"""
                        SELECT create_vlabel('{self.graph_name}', 'Node');
                        CREATE INDEX IF NOT EXISTS idx_node_id ON {self.graph_name}.Node(id);
                    """)
            except Exception:
                # The label and index may already exist; ignore setup errors here.
                pass

    async def create_graph_if_not_exists(self) -> bool:
        """
        Create the graph if it doesn't exist.

        Returns:
            True if graph was created, False if it already existed
        """
        async with self.pool.acquire() as conn:
            try:
                await conn.execute(f"SELECT create_graph('{self.graph_name}');")
                print(f"✓ Created graph: {self.graph_name}")
                return True
            except Exception as e:
                if "already exists" in str(e).lower():
                    print(f"✓ Graph '{self.graph_name}' already exists")
                    return False
                else:
                    raise

    async def close(self) -> None:
        """Close connection pool."""
        if self.pool:
            await self.pool.close()
            self.pool = None

    async def execute_cypher(
        self,
        query: str,
        params: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Execute a Cypher query.

        Args:
            query: Cypher query string
            params: Query parameters

        Returns:
            List of result dictionaries
        """
        if self.pool is None:
            await self.connect()

        # Wrap Cypher query in AGE SQL syntax
        if params:
            param_str = json.dumps(params)
            wrapped_query = f"""
                SELECT * FROM cypher('{self.graph_name}', $$
                    {query}
                $$, '{param_str}') as (result agtype);
            """
        else:
            wrapped_query = f"""
                SELECT * FROM cypher('{self.graph_name}', $$
                    {query}
                $$) as (result agtype);
            """

        async with self.pool.acquire() as conn:
            # Connections are already initialized with AGE settings in the pool init callback
            rows = await conn.fetch(wrapped_query)

        # Parse AGE's agtype results
        results = []
        for row in rows:
            if row['result']:
                # AGE returns agtype, convert to Python dict
                result_str = str(row['result'])
                if result_str and result_str != 'null':
                    try:
                        # Remove AGE-specific type annotations if present
                        # AGE can return things like ::vertex or ::edge
                        if '::' in result_str:
                            result_str = result_str.split('::')[0]

                        result_data = json.loads(result_str)
                        results.append(result_data)
                    except json.JSONDecodeError:
                        # If JSON parsing fails, try to extract just the value
                        results.append({'value': result_str})

        return results
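
    # For illustration (not part of the original commit): with graph_name
    # "cognee_graph" and query "MATCH (n) RETURN n" (no params), the SQL sent
    # to PostgreSQL by execute_cypher() has this shape:
    #
    #     SELECT * FROM cypher('cognee_graph', $$
    #         MATCH (n) RETURN n
    #     $$) as (result agtype);
    #
    # Each returned row's agtype value is then parsed back into a Python dict.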

    async def add_node(
        self,
        node_id: str,
        labels: List[str],
        properties: Dict[str, Any]
    ) -> NodeData:
        """
        Add a node to the graph. Uses MERGE to avoid duplicates.
        If a node with the same id already exists, it will be updated with new properties.

        Args:
            node_id: Unique node identifier
            labels: Node labels (types)
            properties: Node properties

        Returns:
            Created/updated node data
        """
        props = {**properties, 'id': node_id}

        # Build property string manually for AGE compatibility
        props_parts = []
        for k, v in props.items():
            if isinstance(v, str):
                props_parts.append(f"{k}: '{v}'")
            elif isinstance(v, bool):
                props_parts.append(f"{k}: {str(v).lower()}")
            elif isinstance(v, (int, float)):
                props_parts.append(f"{k}: {v}")
            elif v is None:
                props_parts.append(f"{k}: null")
            else:
                # For complex types, convert to JSON string
                props_parts.append(f"{k}: '{json.dumps(v)}'")

        props_str = ', '.join(props_parts)
        label_str = ':'.join(labels) if labels else 'Node'

        # Use MERGE to avoid duplicates - matches on id property
        query = f"""
        MERGE (n:{label_str} {{id: '{node_id}'}})
        SET n = {{{props_str}}}
        RETURN n
        """

        results = await self.execute_cypher(query)

        return NodeData(
            id=node_id,
            properties=props
        )
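
    # For illustration (not part of the original commit): a call like
    # add_node("user_1", ["User"], {"name": "Alice", "age": 30}) builds the
    # following Cypher, which execute_cypher() then wraps in the AGE SQL form:
    #
    #     MERGE (n:User {id: 'user_1'})
    #     SET n = {name: 'Alice', age: 30, id: 'user_1'}
    #     RETURN n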

    async def add_nodes(self, nodes: List[Tuple[str, List[str], Dict[str, Any]]]) -> None:
        """
        Add multiple nodes in a single batch operation using UNWIND.
        This is significantly faster than calling add_node() multiple times.
        Uses MERGE to avoid duplicates.

        Args:
            nodes: List of tuples (node_id, labels, properties)

        Example:
            nodes = [
                ("user_1", ["User"], {"name": "Alice", "age": 30}),
                ("user_2", ["User"], {"name": "Bob", "age": 25}),
            ]
            await adapter.add_nodes(nodes)
        """
        if not nodes:
            return

        # Process in batches of 100
        BATCH_SIZE = 100

        for i in range(0, len(nodes), BATCH_SIZE):
            batch = nodes[i:i + BATCH_SIZE]

            node_data_list = []
            for node_id, labels, properties in batch:
                props = {"id": node_id, **properties}
                props_parts = []
                for k, v in props.items():
                    if isinstance(v, str):
                        props_parts.append(f'{k}: "{v}"')
                    elif isinstance(v, bool):
                        props_parts.append(f'{k}: {str(v).lower()}')
                    elif isinstance(v, (int, float)):
                        props_parts.append(f'{k}: {v}')
                    elif v is None:
                        props_parts.append(f'{k}: null')
                    else:
                        props_parts.append(f'{k}: "{json.dumps(v)}"')
                props_str = '{' + ', '.join(props_parts) + '}'
                label_str = ':'.join(labels) if labels else "Node"
                node_data_list.append(f'{{id: "{node_id}", props: {props_str}, label: "{label_str}"}}')

            unwind_data = '[' + ', '.join(node_data_list) + ']'

            all_prop_keys = set()
            for node_id, labels, properties in batch:
                all_prop_keys.update(properties.keys())
            all_prop_keys.add('id')

            set_clauses = [f"n.{key} = node_data.props.{key}" for key in sorted(all_prop_keys)]
            set_clause = "SET " + ", ".join(set_clauses)

            common_label = batch[0][1][0] if batch[0][1] else "Node"
            query = f"""
            UNWIND {unwind_data} AS node_data
            MERGE (n {{id: node_data.id}})
            {set_clause}
            """
            await self.execute_cypher(query)

    async def get_node(self, node_id: str) -> Optional[NodeData]:
        """
        Get a node by ID.

        Args:
            node_id: Node identifier

        Returns:
            Node data or None if not found
        """
        query = f"MATCH (n {{id: '{node_id}'}}) RETURN n"
        results = await self.execute_cypher(query)

        if results:
            node_data = results[0]
            # Extract properties from AGE vertex structure
            if isinstance(node_data, dict) and 'properties' in node_data:
                props = node_data['properties']
            else:
                props = node_data

            return NodeData(
                id=props.get('id', node_id),
                properties=props
            )
        return None

    async def delete_node(self, node_id: str) -> bool:
        """
        Delete a node by ID.

        Args:
            node_id: Node identifier

        Returns:
            True once the DETACH DELETE query has been issued
            (the query succeeds whether or not the node exists)
        """
        query = f"MATCH (n {{id: '{node_id}'}}) DETACH DELETE n"
        await self.execute_cypher(query)
        return True

    async def add_edge(
        self,
        source_id: str,
        target_id: str,
        relationship_type: str,
        properties: Optional[Dict[str, Any]] = None
    ) -> EdgeData:
        """
        Add an edge between two nodes. Uses MERGE to avoid duplicates.
        If an edge already exists between the same nodes with the same type,
        it will be updated with new properties.

        Args:
            source_id: Source node ID
            target_id: Target node ID
            relationship_type: Relationship type/name
            properties: Edge properties

        Returns:
            Created/updated edge data
        """
        props = properties or {}

        # Build property string manually for AGE compatibility
        if props:
            props_parts = []
            for k, v in props.items():
                if isinstance(v, str):
                    props_parts.append(f"{k}: '{v}'")
                elif isinstance(v, bool):
                    props_parts.append(f"{k}: {str(v).lower()}")
                elif isinstance(v, (int, float)):
                    props_parts.append(f"{k}: {v}")
                elif v is None:
                    props_parts.append(f"{k}: null")
                else:
                    # For complex types, convert to JSON string
                    props_parts.append(f"{k}: '{json.dumps(v)}'")
            props_str = ', '.join(props_parts)

            # Use MERGE to avoid duplicate edges
            query = f"""
            MATCH (a {{id: '{source_id}'}}), (b {{id: '{target_id}'}})
            MERGE (a)-[r:{relationship_type}]->(b)
            SET r = {{{props_str}}}
            RETURN r
            """
        else:
            # Use MERGE without properties
            query = f"""
            MATCH (a {{id: '{source_id}'}}), (b {{id: '{target_id}'}})
            MERGE (a)-[r:{relationship_type}]->(b)
            RETURN r
            """

        await self.execute_cypher(query)

        return EdgeData(
            source_id=source_id,
            target_id=target_id,
            relationship_type=relationship_type,
            properties=props
        )

    async def add_edges(self, edges: List[Tuple[str, str, str, Optional[Dict[str, Any]]]]) -> None:
        """
        Add multiple edges in a single batch operation using UNWIND.
        This is significantly faster than calling add_edge() multiple times.
        Uses MERGE to avoid duplicates.

        Args:
            edges: List of tuples (source_id, target_id, relationship_type, properties)

        Example:
            edges = [
                ("user_1", "user_2", "KNOWS", {"since": 2020}),
                ("user_2", "user_3", "FOLLOWS", {"weight": 0.5}),
            ]
            await adapter.add_edges(edges)
        """
        if not edges:
            return

        # Group edges by relationship type for efficiency
        edges_by_type = {}
        for source_id, target_id, rel_type, properties in edges:
            if rel_type not in edges_by_type:
                edges_by_type[rel_type] = []
            edges_by_type[rel_type].append({
                "source_id": source_id,
                "target_id": target_id,
                "properties": properties or {}
            })

        # Process each relationship type in batches
        BATCH_SIZE = 100  # Smaller batches to avoid huge query strings

        for rel_type, edge_list in edges_by_type.items():
            # Get all unique property keys for this relationship type
            all_prop_keys = set()
            for edge in edge_list:
                all_prop_keys.update(edge["properties"].keys())

            # Process in batches
            for i in range(0, len(edge_list), BATCH_SIZE):
                batch = edge_list[i:i + BATCH_SIZE]

                # Build VALUES clause for batch MERGE
                values_parts = []
                for edge in batch:
                    props = edge["properties"]
                    # Build property map for this edge
                    props_cypher_parts = []
                    for key in all_prop_keys:
                        value = props.get(key)
                        if value is None:
                            props_cypher_parts.append(f'{key}: null')
                        elif isinstance(value, str):
                            # Escape quotes in strings
                            escaped = value.replace('"', '\\"')
                            props_cypher_parts.append(f'{key}: "{escaped}"')
                        elif isinstance(value, bool):
                            props_cypher_parts.append(f'{key}: {str(value).lower()}')
                        elif isinstance(value, (int, float)):
                            props_cypher_parts.append(f'{key}: {value}')
                    props_str = ', '.join(props_cypher_parts)

                    values_parts.append(f'{{src: "{edge["source_id"]}", tgt: "{edge["target_id"]}", props: {{{props_str}}}}}')

                # Build UNWIND query (AGE requires inline data, not parameters)
                values_list = '[' + ', '.join(values_parts) + ']'

                # Build SET clause with explicit assignments (AGE doesn't support SET r = map)
                if all_prop_keys:
                    set_parts = [f'r.{key} = edge.props.{key}' for key in all_prop_keys]
                    set_clause = 'SET ' + ', '.join(set_parts)
                else:
                    set_clause = ''

                query = f"""
                UNWIND {values_list} AS edge
                MATCH (a {{id: edge.src}}), (b {{id: edge.tgt}})
                MERGE (a)-[r:{rel_type}]->(b)
                {set_clause}
                """

                await self.execute_cypher(query)
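
    # For illustration (not part of the original commit): for the input
    # [("user_1", "user_2", "KNOWS", {"since": 2020})] the batch above builds
    # roughly this Cypher before it is wrapped by execute_cypher():
    #
    #     UNWIND [{src: "user_1", tgt: "user_2", props: {since: 2020}}] AS edge
    #     MATCH (a {id: edge.src}), (b {id: edge.tgt})
    #     MERGE (a)-[r:KNOWS]->(b)
    #     SET r.since = edge.props.since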

    async def get_edges(self, node_id: str) -> List[EdgeData]:
        """
        Get all edges connected to a node.

        Args:
            node_id: Node identifier

        Returns:
            List of edge data
        """
        query = f"""
        MATCH (a {{id: '{node_id}'}})-[r]-(b)
        RETURN {{source: a.id, target: b.id, rel_type: type(r), props: properties(r)}}
        """

        results = await self.execute_cypher(query)

        edges = []
        for result in results:
            # Result is directly the edge data map
            edges.append(EdgeData(
                source_id=result.get('source', ''),
                target_id=result.get('target', ''),
                relationship_type=result.get('rel_type', ''),
                properties=result.get('props', {})
            ))

        return edges

    async def get_neighbors(self, node_id: str) -> List[NodeData]:
        """
        Get all neighboring nodes.

        Args:
            node_id: Node identifier

        Returns:
            List of neighbor nodes
        """
        # Use simple map return instead of full vertex object for better performance
        query = f"""
        MATCH (n {{id: '{node_id}'}})-[]-(neighbor)
        RETURN DISTINCT {{id: neighbor.id, properties: properties(neighbor)}}
        """

        results = await self.execute_cypher(query)

        neighbors = []
        for result in results:
            # Result is already a simple map
            neighbors.append(NodeData(
                id=result.get('id', ''),
                properties=result.get('properties', {})
            ))

        return neighbors

    async def count_nodes(self) -> int:
        """
        Count total nodes in the graph.

        Returns:
            Number of nodes
        """
        query = "MATCH (n) RETURN {count: count(n)}"
        results = await self.execute_cypher(query)

        if results:
            return results[0].get('count', 0)
        return 0

    async def count_edges(self) -> int:
        """
        Count total edges in the graph.

        Returns:
            Number of edges
        """
        query = "MATCH ()-[r]->() RETURN {count: count(r)}"
        results = await self.execute_cypher(query)

        if results:
            return results[0].get('count', 0)
        return 0

    async def clear_graph(self) -> None:
        """
        Delete all nodes and edges from the graph.

        Note: This only removes the data. The graph schema and tables remain.
        Use drop_graph() to completely remove the graph including its tables.
        """
        query = "MATCH (n) DETACH DELETE n"
        await self.execute_cypher(query)

    async def drop_graph(self, recreate: bool = False) -> None:
        """
        Completely drop the graph including its schema and tables from PostgreSQL.

        Args:
            recreate: If True, recreates an empty graph immediately after dropping

        Warning: This permanently removes all graph data, schema, and tables.
        """
        async with self.pool.acquire() as conn:
            await conn.execute("SET search_path = ag_catalog, '$user', public;")
            await conn.execute("LOAD 'age';")

            try:
                await conn.execute(f"SELECT drop_graph('{self.graph_name}', true);")
            except asyncpg.exceptions.UndefinedObjectError:
                # Graph doesn't exist, nothing to drop
                pass
            except Exception as e:
                raise Exception(f"Error dropping graph '{self.graph_name}': {e}") from e

        # Recreate if requested
        if recreate:
            await self.create_graph_if_not_exists()

    async def list_all_graphs(self) -> List[str]:
        """
        List all Apache AGE graphs in the database.

        Returns:
            List of graph names
        """
        async with self.pool.acquire() as conn:
            await conn.execute("SET search_path = ag_catalog, '$user', public;")
            await conn.execute("LOAD 'age';")

            # Query ag_catalog.ag_graph to get all graphs
            result = await conn.fetch(
                "SELECT name FROM ag_catalog.ag_graph ORDER BY name;"
            )

            return [row['name'] for row in result]

    async def drop_all_graphs(self) -> List[str]:
        """
        Drop ALL Apache AGE graphs in the database.

        Returns:
            List of dropped graph names

        Warning: This permanently removes ALL graphs from the database!
        """
        graphs = await self.list_all_graphs()

        async with self.pool.acquire() as conn:
            await conn.execute("SET search_path = ag_catalog, '$user', public;")
            await conn.execute("LOAD 'age';")

            dropped = []
            for graph_name in graphs:
                try:
                    await conn.execute(f"SELECT drop_graph('{graph_name}', true);")
                    dropped.append(graph_name)
                    print(f"✓ Dropped graph: {graph_name}")
                except Exception as e:
                    print(f"✗ Failed to drop graph '{graph_name}': {e}")

        return dropped

    async def get_stats(self) -> Dict[str, int]:
        """
        Get graph statistics.

        Returns:
            Dictionary with node count, edge count, and mean degree
        """
        num_nodes = await self.count_nodes()
        num_edges = await self.count_edges()

        return {
            'nodes': num_nodes,
            'edges': num_edges,
            'mean_degree': (2 * num_edges / num_nodes) if num_nodes > 0 else 0
        }
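
For reference, a minimal usage sketch of the adapter above (not part of the commit; the connection parameters are the adapter defaults and purely illustrative):

import asyncio
from age_adapter.adapter import ApacheAGEAdapter

async def demo():
    adapter = ApacheAGEAdapter(database="cognee", graph_name="cognee_graph")
    await adapter.connect()  # creates the pool and the graph if missing

    await adapter.add_node("user_1", ["User"], {"name": "Alice"})
    await adapter.add_node("user_2", ["User"], {"name": "Bob"})
    await adapter.add_edge("user_1", "user_2", "KNOWS", {"since": 2020})

    print([n.properties for n in await adapter.get_neighbors("user_1")])
    print(await adapter.get_stats())  # {'nodes': ..., 'edges': ..., 'mean_degree': ...}

    await adapter.close()

if __name__ == "__main__":
    asyncio.run(demo())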
age_adapter/benchmark.py (new file, 135 lines)
@@ -0,0 +1,135 @@
import asyncio
import time
import sys
from pathlib import Path
from typing import List, Dict, Any, Tuple

sys.path.insert(0, str(Path(__file__).parent.parent))

from age_adapter.adapter import ApacheAGEAdapter
from cognee.infrastructure.databases.graph.neo4j_driver.adapter import Neo4jAdapter
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from uuid import UUID


class SimpleNode(DataPoint):
    model_config = {"extra": "allow"}

    def __init__(self, node_id: str, properties: Dict[str, Any]):
        try:
            node_uuid = UUID(node_id) if '-' in node_id else UUID(int=hash(node_id) & ((1 << 128) - 1))
        except ValueError:
            # node_id is not a valid UUID string; fall back to a UUID derived from its hash
            node_uuid = UUID(int=hash(node_id) & ((1 << 128) - 1))
        super().__init__(id=node_uuid, **properties)


async def main():
    age_adapter = ApacheAGEAdapter(
        host="localhost",
        port=5432,
        username="cognee",
        password="cognee",
        database="cognee_db",
        graph_name="benchmark_graph"
    )

    neo4j_adapter = Neo4jAdapter(
        graph_database_url="bolt://localhost:7687",
        graph_database_username="neo4j",
        graph_database_password="pleaseletmein",
        graph_database_name=None
    )

    await age_adapter.connect()
    await neo4j_adapter.initialize()

    batch_size = 3000
    node_ids = [f"node_{i}" for i in range(batch_size)]
    nodes = [(nid, ["TestNode"], {"name": f"Node {i}", "value": i})
             for i, nid in enumerate(node_ids)]

    await age_adapter.drop_graph(recreate=True)
    await neo4j_adapter.delete_graph()

    start = time.perf_counter()
    for node_id, labels, props in nodes:
        await age_adapter.add_node(node_id, labels, props)
    age_time_single_new = time.perf_counter() - start

    start = time.perf_counter()
    for node_id, labels, props in nodes:
        neo4j_node = SimpleNode(node_id, props)
        await neo4j_adapter.add_node(neo4j_node)
    neo4j_time_single_new = time.perf_counter() - start

    print(f"Node Ingestion Single (New): AGE={age_time_single_new:.4f}s, Neo4j={neo4j_time_single_new:.4f}s")

    half = batch_size // 2
    existing_nodes = nodes[:half]
    new_node_ids = [f"node_{i}" for i in range(batch_size, batch_size + half)]
    new_nodes = [(nid, ["TestNode"], {"name": f"Node {i}", "value": i})
                 for i, nid in enumerate(new_node_ids, start=batch_size)]
    merge_nodes = existing_nodes + new_nodes

    for node_id, labels, props in existing_nodes:
        await age_adapter.add_node(node_id, labels, props)
        neo4j_node = SimpleNode(node_id, props)
        await neo4j_adapter.add_node(neo4j_node)

    start = time.perf_counter()
    for node_id, labels, props in merge_nodes:
        await age_adapter.add_node(node_id, labels, props)
    age_time_single_merge = time.perf_counter() - start

    start = time.perf_counter()
    for node_id, labels, props in merge_nodes:
        neo4j_node = SimpleNode(node_id, props)
        await neo4j_adapter.add_node(neo4j_node)
    neo4j_time_single_merge = time.perf_counter() - start

    print(f"Node Ingestion Single (Merge - {half} existing, {len(new_nodes)} new): AGE={age_time_single_merge:.4f}s, Neo4j={neo4j_time_single_merge:.4f}s")

    await age_adapter.drop_graph(recreate=True)
    await neo4j_adapter.delete_graph()

    start = time.perf_counter()
    for i in range(0, len(nodes), 100):
        await age_adapter.add_nodes(nodes[i:i+100])
    age_time_batch_new = time.perf_counter() - start

    start = time.perf_counter()
    for i in range(0, len(nodes), 100):
        batch = nodes[i:i+100]
        neo4j_nodes = [SimpleNode(node_id, props) for node_id, _, props in batch]
        await neo4j_adapter.add_nodes(neo4j_nodes)
    neo4j_time_batch_new = time.perf_counter() - start

    print(f"Node Ingestion Batch (New): AGE={age_time_batch_new:.4f}s, Neo4j={neo4j_time_batch_new:.4f}s")

    for i in range(0, len(existing_nodes), 100):
        await age_adapter.add_nodes(existing_nodes[i:i+100])
        batch = existing_nodes[i:i+100]
        neo4j_existing = [SimpleNode(node_id, props) for node_id, _, props in batch]
        await neo4j_adapter.add_nodes(neo4j_existing)

    start = time.perf_counter()
    for i in range(0, len(merge_nodes), 100):
        await age_adapter.add_nodes(merge_nodes[i:i+100])
    age_time_batch_merge = time.perf_counter() - start

    start = time.perf_counter()
    for i in range(0, len(merge_nodes), 100):
        batch = merge_nodes[i:i+100]
        neo4j_merge_nodes = [SimpleNode(node_id, props) for node_id, _, props in batch]
        await neo4j_adapter.add_nodes(neo4j_merge_nodes)
    neo4j_time_batch_merge = time.perf_counter() - start

    print(f"Node Ingestion Batch (Merge - {half} existing, {len(new_nodes)} new): AGE={age_time_batch_merge:.4f}s, Neo4j={neo4j_time_batch_merge:.4f}s")

    await age_adapter.close()
    await neo4j_adapter.driver.close()


if __name__ == "__main__":
    asyncio.run(main())
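
The benchmark above only times node ingestion. A sketch of an analogous edge-ingestion pass, assuming the same AGE adapter instance and node_ids list from main() above (the KNOWS chain data is illustrative, not part of the commit):

async def benchmark_edges(age_adapter, node_ids):
    # Chain consecutive nodes with KNOWS edges (illustrative data).
    edges = [(node_ids[i], node_ids[i + 1], "KNOWS", {"weight": 1.0})
             for i in range(len(node_ids) - 1)]

    # Single-edge ingestion via add_edge().
    start = time.perf_counter()
    for source_id, target_id, rel_type, props in edges:
        await age_adapter.add_edge(source_id, target_id, rel_type, props)
    single_time = time.perf_counter() - start

    # Batched ingestion via add_edges() (UNWIND + MERGE under the hood).
    start = time.perf_counter()
    await age_adapter.add_edges(edges)
    batch_time = time.perf_counter() - start

    print(f"Edge Ingestion: single={single_time:.4f}s, batch={batch_time:.4f}s")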