fix: using graph projection instead of conditions

This commit is contained in:
chinu0609 2025-12-02 18:55:47 +05:30
parent 12ce80005c
commit 6a4d31356b
2 changed files with 418 additions and 497 deletions

View file

@@ -10,6 +10,7 @@ from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Data from cognee.modules.data.models import Data
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from sqlalchemy import update from sqlalchemy import update
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -35,21 +36,11 @@ async def update_node_access_timestamps(items: List[Any]):
return return
try: try:
# Detect database provider and use appropriate queries # Update nodes using graph projection (database-agnostic approach)
provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower() await _update_nodes_via_projection(graph_engine, node_ids, timestamp_ms)
if provider == "kuzu":
await _update_kuzu_nodes(graph_engine, node_ids, timestamp_ms)
elif provider == "neo4j":
await _update_neo4j_nodes(graph_engine, node_ids, timestamp_ms)
elif provider == "neptune":
await _update_neptune_nodes(graph_engine, node_ids, timestamp_ms)
else:
logger.warning(f"Unsupported graph provider: {provider}")
return
# Find origin documents and update SQL # Find origin documents and update SQL
doc_ids = await _find_origin_documents(graph_engine, node_ids, provider) doc_ids = await _find_origin_documents_via_projection(graph_engine, node_ids)
if doc_ids: if doc_ids:
await _update_sql_records(doc_ids, timestamp_dt) await _update_sql_records(doc_ids, timestamp_dt)
@@ -57,65 +48,72 @@ async def update_node_access_timestamps(items: List[Any]):
logger.error(f"Failed to update timestamps: {e}") logger.error(f"Failed to update timestamps: {e}")
raise raise
async def _update_kuzu_nodes(graph_engine, node_ids, timestamp_ms): async def _update_nodes_via_projection(graph_engine, node_ids, timestamp_ms):
"""Kuzu-specific node updates""" """Update nodes using graph projection - works with any graph database"""
# Project the graph with necessary properties
memory_fragment = CogneeGraph()
await memory_fragment.project_graph_from_db(
graph_engine,
node_properties_to_project=["id"],
edge_properties_to_project=[]
)
# Update each node's last_accessed_at property
for node_id in node_ids: for node_id in node_ids:
result = await graph_engine.query( node = memory_fragment.get_node(node_id)
"MATCH (n:Node {id: $id}) RETURN n.properties", if node:
{"id": node_id} # Update the node in the database
) provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower()
if result and result[0]: if provider == "kuzu":
props = json.loads(result[0][0]) if result[0][0] else {} # Kuzu stores properties as JSON
props["last_accessed_at"] = timestamp_ms result = await graph_engine.query(
"MATCH (n:Node {id: $id}) RETURN n.properties",
{"id": node_id}
)
await graph_engine.query( if result and result[0]:
"MATCH (n:Node {id: $id}) SET n.properties = $props", props = json.loads(result[0][0]) if result[0][0] else {}
{"id": node_id, "props": json.dumps(props)} props["last_accessed_at"] = timestamp_ms
)
async def _update_neo4j_nodes(graph_engine, node_ids, timestamp_ms): await graph_engine.query(
"""Neo4j-specific node updates""" "MATCH (n:Node {id: $id}) SET n.properties = $props",
{"id": node_id, "props": json.dumps(props)}
)
elif provider == "neo4j":
await graph_engine.query(
"MATCH (n:__Node__ {id: $id}) SET n.last_accessed_at = $timestamp",
{"id": node_id, "timestamp": timestamp_ms}
)
elif provider == "neptune":
await graph_engine.query(
"MATCH (n:Node {id: $id}) SET n.last_accessed_at = $timestamp",
{"id": node_id, "timestamp": timestamp_ms}
)
async def _find_origin_documents_via_projection(graph_engine, node_ids):
"""Find origin documents using graph projection instead of DB queries"""
# Project the entire graph with necessary properties
memory_fragment = CogneeGraph()
await memory_fragment.project_graph_from_db(
graph_engine,
node_properties_to_project=["id", "type"],
edge_properties_to_project=["relationship_name"]
)
# Find origin documents by traversing the in-memory graph
doc_ids = set()
for node_id in node_ids: for node_id in node_ids:
await graph_engine.query( node = memory_fragment.get_node(node_id)
"MATCH (n:__Node__ {id: $id}) SET n.last_accessed_at = $timestamp", if node and node.get_attribute("type") == "DocumentChunk":
{"id": node_id, "timestamp": timestamp_ms} # Traverse edges to find connected documents
) for edge in node.get_skeleton_edges():
# Get the neighbor node
neighbor = edge.get_destination_node() if edge.get_source_node().id == node_id else edge.get_source_node()
if neighbor and neighbor.get_attribute("type") in ["TextDocument", "Document"]:
doc_ids.add(neighbor.id)
async def _update_neptune_nodes(graph_engine, node_ids, timestamp_ms): return list(doc_ids)
"""Neptune-specific node updates"""
for node_id in node_ids:
await graph_engine.query(
"MATCH (n:Node {id: $id}) SET n.last_accessed_at = $timestamp",
{"id": node_id, "timestamp": timestamp_ms}
)
async def _find_origin_documents(graph_engine, node_ids, provider):
"""Find origin documents with provider-specific queries"""
if provider == "kuzu":
query = """
UNWIND $node_ids AS node_id
MATCH (chunk:Node {id: node_id})-[e:EDGE]-(doc:Node)
WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document']
RETURN DISTINCT doc.id
"""
elif provider == "neo4j":
query = """
UNWIND $node_ids AS node_id
MATCH (chunk:__Node__ {id: node_id})-[e:EDGE]-(doc:__Node__)
WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document']
RETURN DISTINCT doc.id
"""
elif provider == "neptune":
query = """
UNWIND $node_ids AS node_id
MATCH (chunk:Node {id: node_id})-[e:EDGE]-(doc:Node)
WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document']
RETURN DISTINCT doc.id
"""
result = await graph_engine.query(query, {"node_ids": node_ids})
return list(set([row[0] for row in result if row and row[0]])) if result else []
async def _update_sql_records(doc_ids, timestamp_dt): async def _update_sql_records(doc_ids, timestamp_dt):
"""Update SQL Data table (same for all providers)""" """Update SQL Data table (same for all providers)"""

View file

@@ -19,6 +19,7 @@ from cognee.shared.logging_utils import get_logger
from sqlalchemy import select, or_ from sqlalchemy import select, or_
import cognee import cognee
import sqlalchemy as sa import sqlalchemy as sa
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -100,13 +101,12 @@ async def cleanup_unused_data(
# SQL-based approach: Find unused TextDocuments and use cognee.delete() # SQL-based approach: Find unused TextDocuments and use cognee.delete()
return await _cleanup_via_sql(cutoff_date, dry_run, user_id) return await _cleanup_via_sql(cutoff_date, dry_run, user_id)
else: else:
# Graph-based approach: Find unused nodes directly from graph # Graph-based approach: Find unused nodes using projection (database-agnostic)
cutoff_timestamp_ms = int(cutoff_date.timestamp() * 1000) cutoff_timestamp_ms = int(cutoff_date.timestamp() * 1000)
logger.debug(f"Cutoff timestamp: {cutoff_date.isoformat()} ({cutoff_timestamp_ms}ms)") logger.debug(f"Cutoff timestamp: {cutoff_date.isoformat()} ({cutoff_timestamp_ms}ms)")
# Detect database provider and find unused nodes # Find unused nodes using graph projection
provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower() unused_nodes = await _find_unused_nodes_via_projection(cutoff_timestamp_ms)
unused_nodes = await _find_unused_nodes(cutoff_timestamp_ms, user_id, provider)
total_unused = sum(len(nodes) for nodes in unused_nodes.values()) total_unused = sum(len(nodes) for nodes in unused_nodes.values())
logger.info(f"Found {total_unused} unused nodes", unused_nodes={k: len(v) for k, v in unused_nodes.items()}) logger.info(f"Found {total_unused} unused nodes", unused_nodes={k: len(v) for k, v in unused_nodes.items()})
@@ -130,8 +130,8 @@ async def cleanup_unused_data(
} }
} }
# Delete unused nodes with provider-specific logic # Delete unused nodes (provider-agnostic deletion)
deleted_counts = await _delete_unused_nodes(unused_nodes, provider) deleted_counts = await _delete_unused_nodes(unused_nodes)
logger.info("Cleanup completed", deleted_counts=deleted_counts) logger.info("Cleanup completed", deleted_counts=deleted_counts)
@@ -240,22 +240,14 @@ async def _cleanup_via_sql(
} }
async def _find_unused_nodes( async def _find_unused_nodes_via_projection(cutoff_timestamp_ms: int) -> Dict[str, list]:
cutoff_timestamp_ms: int,
user_id: Optional[UUID] = None,
provider: str = "kuzu"
) -> Dict[str, list]:
""" """
Find unused nodes with provider-specific queries. Find unused nodes using graph projection - database-agnostic approach.
Parameters Parameters
---------- ----------
cutoff_timestamp_ms : int cutoff_timestamp_ms : int
Cutoff timestamp in milliseconds since epoch Cutoff timestamp in milliseconds since epoch
user_id : UUID, optional
Filter by user ID if provided
provider : str
Graph database provider (kuzu, neo4j, neptune)
Returns Returns
------- -------
@@ -264,100 +256,39 @@ async def _find_unused_nodes(
""" """
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
if provider == "kuzu": # Project the entire graph with necessary properties
return await _find_unused_nodes_kuzu(graph_engine, cutoff_timestamp_ms) memory_fragment = CogneeGraph()
elif provider == "neo4j": await memory_fragment.project_graph_from_db(
return await _find_unused_nodes_neo4j(graph_engine, cutoff_timestamp_ms) graph_engine,
elif provider == "neptune": node_properties_to_project=["id", "type", "last_accessed_at"],
return await _find_unused_nodes_neptune(graph_engine, cutoff_timestamp_ms) edge_properties_to_project=[]
else: )
logger.warning(f"Unsupported graph provider: {provider}")
return {"DocumentChunk": [], "Entity": [], "TextSummary": []}
async def _find_unused_nodes_kuzu(graph_engine, cutoff_timestamp_ms):
"""Kuzu-specific unused node detection"""
query = "MATCH (n:Node) RETURN n.id, n.type, n.properties"
results = await graph_engine.query(query)
unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []} unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []}
for node_id, node_type, props_json in results: # Get all nodes from the projected graph
all_nodes = memory_fragment.get_nodes()
for node in all_nodes:
node_type = node.get_attribute("type")
if node_type not in unused_nodes: if node_type not in unused_nodes:
continue continue
if props_json: # Check last_accessed_at property
try: last_accessed = node.get_attribute("last_accessed_at")
props = json.loads(props_json)
last_accessed = props.get("last_accessed_at")
if last_accessed is None or last_accessed < cutoff_timestamp_ms:
unused_nodes[node_type].append(node_id)
logger.debug(
f"Found unused {node_type}",
node_id=node_id,
last_accessed=last_accessed
)
except json.JSONDecodeError:
logger.warning(f"Failed to parse properties for node {node_id}")
continue
return unused_nodes
async def _find_unused_nodes_neo4j(graph_engine, cutoff_timestamp_ms):
"""Neo4j-specific unused node detection"""
query = "MATCH (n:__Node__) RETURN n.id, n.type, n.last_accessed_at"
results = await graph_engine.query(query)
unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []}
for row in results:
node_id = row["n"]["id"]
node_type = row["n"]["type"]
last_accessed = row["n"].get("last_accessed_at")
if node_type not in unused_nodes:
continue
if last_accessed is None or last_accessed < cutoff_timestamp_ms: if last_accessed is None or last_accessed < cutoff_timestamp_ms:
unused_nodes[node_type].append(node_id) unused_nodes[node_type].append(node.id)
logger.debug( logger.debug(
f"Found unused {node_type}", f"Found unused {node_type}",
node_id=node_id, node_id=node.id,
last_accessed=last_accessed last_accessed=last_accessed
) )
return unused_nodes return unused_nodes
async def _find_unused_nodes_neptune(graph_engine, cutoff_timestamp_ms): async def _delete_unused_nodes(unused_nodes: Dict[str, list]) -> Dict[str, int]:
"""Neptune-specific unused node detection"""
query = "MATCH (n:Node) RETURN n.id, n.type, n.last_accessed_at"
results = await graph_engine.query(query)
unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []}
for row in results:
node_id = row["n"]["id"]
node_type = row["n"]["type"]
last_accessed = row["n"].get("last_accessed_at")
if node_type not in unused_nodes:
continue
if last_accessed is None or last_accessed < cutoff_timestamp_ms:
unused_nodes[node_type].append(node_id)
logger.debug(
f"Found unused {node_type}",
node_id=node_id,
last_accessed=last_accessed
)
return unused_nodes
async def _delete_unused_nodes(unused_nodes: Dict[str, list], provider: str) -> Dict[str, int]:
""" """
Delete unused nodes from graph and vector databases. Delete unused nodes from graph and vector databases.
@@ -365,8 +296,6 @@ async def _delete_unused_nodes(unused_nodes: Dict[str, list], provider: str) ->
---------- ----------
unused_nodes : Dict[str, list] unused_nodes : Dict[str, list]
Dictionary mapping node types to lists of node IDs to delete Dictionary mapping node types to lists of node IDs to delete
provider : str
Graph database provider (kuzu, neo4j, neptune)
Returns Returns
------- -------
@@ -383,32 +312,26 @@ async def _delete_unused_nodes(unused_nodes: Dict[str, list], provider: str) ->
"associations": 0 "associations": 0
} }
# Count associations before deletion # Count associations before deletion (using graph projection for consistency)
for node_type, node_ids in unused_nodes.items(): if any(unused_nodes.values()):
if not node_ids: memory_fragment = CogneeGraph()
continue await memory_fragment.project_graph_from_db(
graph_engine,
node_properties_to_project=["id"],
edge_properties_to_project=[]
)
# Count edges connected to these nodes for node_type, node_ids in unused_nodes.items():
for node_id in node_ids: if not node_ids:
if provider == "kuzu": continue
result = await graph_engine.query(
"MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)",
{"id": node_id}
)
elif provider == "neo4j":
result = await graph_engine.query(
"MATCH (n:__Node__ {id: $id})-[r:EDGE]-() RETURN count(r)",
{"id": node_id}
)
elif provider == "neptune":
result = await graph_engine.query(
"MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)",
{"id": node_id}
)
if result and len(result) > 0: # Count edges connected to these nodes
count = result[0][0] if provider == "kuzu" else result[0]["count_count(r)"] for node_id in node_ids:
deleted_counts["associations"] += count node = memory_fragment.get_node(node_id)
if node:
# Count edges from the in-memory graph
edge_count = len(node.get_skeleton_edges())
deleted_counts["associations"] += edge_count
# Delete from graph database (uses DETACH DELETE, so edges are automatically removed) # Delete from graph database (uses DETACH DELETE, so edges are automatically removed)
for node_type, node_ids in unused_nodes.items(): for node_type, node_ids in unused_nodes.items():
@@ -417,7 +340,7 @@ async def _delete_unused_nodes(unused_nodes: Dict[str, list], provider: str) ->
logger.info(f"Deleting {len(node_ids)} {node_type} nodes from graph database") logger.info(f"Deleting {len(node_ids)} {node_type} nodes from graph database")
# Delete nodes in batches # Delete nodes in batches (database-agnostic)
await graph_engine.delete_nodes(node_ids) await graph_engine.delete_nodes(node_ids)
deleted_counts[node_type] = len(node_ids) deleted_counts[node_type] = len(node_ids)