fix: using graph projection instead of conditions
This commit is contained in:
parent
12ce80005c
commit
6a4d31356b
2 changed files with 418 additions and 497 deletions
|
|
@ -10,6 +10,7 @@ from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from cognee.modules.data.models import Data
|
from cognee.modules.data.models import Data
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
@ -35,21 +36,11 @@ async def update_node_access_timestamps(items: List[Any]):
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Detect database provider and use appropriate queries
|
# Update nodes using graph projection (database-agnostic approach)
|
||||||
provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower()
|
await _update_nodes_via_projection(graph_engine, node_ids, timestamp_ms)
|
||||||
|
|
||||||
if provider == "kuzu":
|
|
||||||
await _update_kuzu_nodes(graph_engine, node_ids, timestamp_ms)
|
|
||||||
elif provider == "neo4j":
|
|
||||||
await _update_neo4j_nodes(graph_engine, node_ids, timestamp_ms)
|
|
||||||
elif provider == "neptune":
|
|
||||||
await _update_neptune_nodes(graph_engine, node_ids, timestamp_ms)
|
|
||||||
else:
|
|
||||||
logger.warning(f"Unsupported graph provider: {provider}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Find origin documents and update SQL
|
# Find origin documents and update SQL
|
||||||
doc_ids = await _find_origin_documents(graph_engine, node_ids, provider)
|
doc_ids = await _find_origin_documents_via_projection(graph_engine, node_ids)
|
||||||
if doc_ids:
|
if doc_ids:
|
||||||
await _update_sql_records(doc_ids, timestamp_dt)
|
await _update_sql_records(doc_ids, timestamp_dt)
|
||||||
|
|
||||||
|
|
@ -57,65 +48,72 @@ async def update_node_access_timestamps(items: List[Any]):
|
||||||
logger.error(f"Failed to update timestamps: {e}")
|
logger.error(f"Failed to update timestamps: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _update_kuzu_nodes(graph_engine, node_ids, timestamp_ms):
|
async def _update_nodes_via_projection(graph_engine, node_ids, timestamp_ms):
|
||||||
"""Kuzu-specific node updates"""
|
"""Update nodes using graph projection - works with any graph database"""
|
||||||
|
# Project the graph with necessary properties
|
||||||
|
memory_fragment = CogneeGraph()
|
||||||
|
await memory_fragment.project_graph_from_db(
|
||||||
|
graph_engine,
|
||||||
|
node_properties_to_project=["id"],
|
||||||
|
edge_properties_to_project=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update each node's last_accessed_at property
|
||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
result = await graph_engine.query(
|
node = memory_fragment.get_node(node_id)
|
||||||
"MATCH (n:Node {id: $id}) RETURN n.properties",
|
if node:
|
||||||
{"id": node_id}
|
# Update the node in the database
|
||||||
)
|
provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower()
|
||||||
|
|
||||||
if result and result[0]:
|
if provider == "kuzu":
|
||||||
props = json.loads(result[0][0]) if result[0][0] else {}
|
# Kuzu stores properties as JSON
|
||||||
props["last_accessed_at"] = timestamp_ms
|
result = await graph_engine.query(
|
||||||
|
"MATCH (n:Node {id: $id}) RETURN n.properties",
|
||||||
|
{"id": node_id}
|
||||||
|
)
|
||||||
|
|
||||||
await graph_engine.query(
|
if result and result[0]:
|
||||||
"MATCH (n:Node {id: $id}) SET n.properties = $props",
|
props = json.loads(result[0][0]) if result[0][0] else {}
|
||||||
{"id": node_id, "props": json.dumps(props)}
|
props["last_accessed_at"] = timestamp_ms
|
||||||
)
|
|
||||||
|
|
||||||
async def _update_neo4j_nodes(graph_engine, node_ids, timestamp_ms):
|
await graph_engine.query(
|
||||||
"""Neo4j-specific node updates"""
|
"MATCH (n:Node {id: $id}) SET n.properties = $props",
|
||||||
|
{"id": node_id, "props": json.dumps(props)}
|
||||||
|
)
|
||||||
|
elif provider == "neo4j":
|
||||||
|
await graph_engine.query(
|
||||||
|
"MATCH (n:__Node__ {id: $id}) SET n.last_accessed_at = $timestamp",
|
||||||
|
{"id": node_id, "timestamp": timestamp_ms}
|
||||||
|
)
|
||||||
|
elif provider == "neptune":
|
||||||
|
await graph_engine.query(
|
||||||
|
"MATCH (n:Node {id: $id}) SET n.last_accessed_at = $timestamp",
|
||||||
|
{"id": node_id, "timestamp": timestamp_ms}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _find_origin_documents_via_projection(graph_engine, node_ids):
|
||||||
|
"""Find origin documents using graph projection instead of DB queries"""
|
||||||
|
# Project the entire graph with necessary properties
|
||||||
|
memory_fragment = CogneeGraph()
|
||||||
|
await memory_fragment.project_graph_from_db(
|
||||||
|
graph_engine,
|
||||||
|
node_properties_to_project=["id", "type"],
|
||||||
|
edge_properties_to_project=["relationship_name"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find origin documents by traversing the in-memory graph
|
||||||
|
doc_ids = set()
|
||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
await graph_engine.query(
|
node = memory_fragment.get_node(node_id)
|
||||||
"MATCH (n:__Node__ {id: $id}) SET n.last_accessed_at = $timestamp",
|
if node and node.get_attribute("type") == "DocumentChunk":
|
||||||
{"id": node_id, "timestamp": timestamp_ms}
|
# Traverse edges to find connected documents
|
||||||
)
|
for edge in node.get_skeleton_edges():
|
||||||
|
# Get the neighbor node
|
||||||
|
neighbor = edge.get_destination_node() if edge.get_source_node().id == node_id else edge.get_source_node()
|
||||||
|
if neighbor and neighbor.get_attribute("type") in ["TextDocument", "Document"]:
|
||||||
|
doc_ids.add(neighbor.id)
|
||||||
|
|
||||||
async def _update_neptune_nodes(graph_engine, node_ids, timestamp_ms):
|
return list(doc_ids)
|
||||||
"""Neptune-specific node updates"""
|
|
||||||
for node_id in node_ids:
|
|
||||||
await graph_engine.query(
|
|
||||||
"MATCH (n:Node {id: $id}) SET n.last_accessed_at = $timestamp",
|
|
||||||
{"id": node_id, "timestamp": timestamp_ms}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _find_origin_documents(graph_engine, node_ids, provider):
|
|
||||||
"""Find origin documents with provider-specific queries"""
|
|
||||||
if provider == "kuzu":
|
|
||||||
query = """
|
|
||||||
UNWIND $node_ids AS node_id
|
|
||||||
MATCH (chunk:Node {id: node_id})-[e:EDGE]-(doc:Node)
|
|
||||||
WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document']
|
|
||||||
RETURN DISTINCT doc.id
|
|
||||||
"""
|
|
||||||
elif provider == "neo4j":
|
|
||||||
query = """
|
|
||||||
UNWIND $node_ids AS node_id
|
|
||||||
MATCH (chunk:__Node__ {id: node_id})-[e:EDGE]-(doc:__Node__)
|
|
||||||
WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document']
|
|
||||||
RETURN DISTINCT doc.id
|
|
||||||
"""
|
|
||||||
elif provider == "neptune":
|
|
||||||
query = """
|
|
||||||
UNWIND $node_ids AS node_id
|
|
||||||
MATCH (chunk:Node {id: node_id})-[e:EDGE]-(doc:Node)
|
|
||||||
WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document']
|
|
||||||
RETURN DISTINCT doc.id
|
|
||||||
"""
|
|
||||||
|
|
||||||
result = await graph_engine.query(query, {"node_ids": node_ids})
|
|
||||||
return list(set([row[0] for row in result if row and row[0]])) if result else []
|
|
||||||
|
|
||||||
async def _update_sql_records(doc_ids, timestamp_dt):
|
async def _update_sql_records(doc_ids, timestamp_dt):
|
||||||
"""Update SQL Data table (same for all providers)"""
|
"""Update SQL Data table (same for all providers)"""
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ from cognee.shared.logging_utils import get_logger
|
||||||
from sqlalchemy import select, or_
|
from sqlalchemy import select, or_
|
||||||
import cognee
|
import cognee
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
@ -100,13 +101,12 @@ async def cleanup_unused_data(
|
||||||
# SQL-based approach: Find unused TextDocuments and use cognee.delete()
|
# SQL-based approach: Find unused TextDocuments and use cognee.delete()
|
||||||
return await _cleanup_via_sql(cutoff_date, dry_run, user_id)
|
return await _cleanup_via_sql(cutoff_date, dry_run, user_id)
|
||||||
else:
|
else:
|
||||||
# Graph-based approach: Find unused nodes directly from graph
|
# Graph-based approach: Find unused nodes using projection (database-agnostic)
|
||||||
cutoff_timestamp_ms = int(cutoff_date.timestamp() * 1000)
|
cutoff_timestamp_ms = int(cutoff_date.timestamp() * 1000)
|
||||||
logger.debug(f"Cutoff timestamp: {cutoff_date.isoformat()} ({cutoff_timestamp_ms}ms)")
|
logger.debug(f"Cutoff timestamp: {cutoff_date.isoformat()} ({cutoff_timestamp_ms}ms)")
|
||||||
|
|
||||||
# Detect database provider and find unused nodes
|
# Find unused nodes using graph projection
|
||||||
provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower()
|
unused_nodes = await _find_unused_nodes_via_projection(cutoff_timestamp_ms)
|
||||||
unused_nodes = await _find_unused_nodes(cutoff_timestamp_ms, user_id, provider)
|
|
||||||
|
|
||||||
total_unused = sum(len(nodes) for nodes in unused_nodes.values())
|
total_unused = sum(len(nodes) for nodes in unused_nodes.values())
|
||||||
logger.info(f"Found {total_unused} unused nodes", unused_nodes={k: len(v) for k, v in unused_nodes.items()})
|
logger.info(f"Found {total_unused} unused nodes", unused_nodes={k: len(v) for k, v in unused_nodes.items()})
|
||||||
|
|
@ -130,8 +130,8 @@ async def cleanup_unused_data(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Delete unused nodes with provider-specific logic
|
# Delete unused nodes (provider-agnostic deletion)
|
||||||
deleted_counts = await _delete_unused_nodes(unused_nodes, provider)
|
deleted_counts = await _delete_unused_nodes(unused_nodes)
|
||||||
|
|
||||||
logger.info("Cleanup completed", deleted_counts=deleted_counts)
|
logger.info("Cleanup completed", deleted_counts=deleted_counts)
|
||||||
|
|
||||||
|
|
@ -240,22 +240,14 @@ async def _cleanup_via_sql(
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def _find_unused_nodes(
|
async def _find_unused_nodes_via_projection(cutoff_timestamp_ms: int) -> Dict[str, list]:
|
||||||
cutoff_timestamp_ms: int,
|
|
||||||
user_id: Optional[UUID] = None,
|
|
||||||
provider: str = "kuzu"
|
|
||||||
) -> Dict[str, list]:
|
|
||||||
"""
|
"""
|
||||||
Find unused nodes with provider-specific queries.
|
Find unused nodes using graph projection - database-agnostic approach.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
cutoff_timestamp_ms : int
|
cutoff_timestamp_ms : int
|
||||||
Cutoff timestamp in milliseconds since epoch
|
Cutoff timestamp in milliseconds since epoch
|
||||||
user_id : UUID, optional
|
|
||||||
Filter by user ID if provided
|
|
||||||
provider : str
|
|
||||||
Graph database provider (kuzu, neo4j, neptune)
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
|
@ -264,100 +256,39 @@ async def _find_unused_nodes(
|
||||||
"""
|
"""
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
|
|
||||||
if provider == "kuzu":
|
# Project the entire graph with necessary properties
|
||||||
return await _find_unused_nodes_kuzu(graph_engine, cutoff_timestamp_ms)
|
memory_fragment = CogneeGraph()
|
||||||
elif provider == "neo4j":
|
await memory_fragment.project_graph_from_db(
|
||||||
return await _find_unused_nodes_neo4j(graph_engine, cutoff_timestamp_ms)
|
graph_engine,
|
||||||
elif provider == "neptune":
|
node_properties_to_project=["id", "type", "last_accessed_at"],
|
||||||
return await _find_unused_nodes_neptune(graph_engine, cutoff_timestamp_ms)
|
edge_properties_to_project=[]
|
||||||
else:
|
)
|
||||||
logger.warning(f"Unsupported graph provider: {provider}")
|
|
||||||
return {"DocumentChunk": [], "Entity": [], "TextSummary": []}
|
|
||||||
|
|
||||||
|
|
||||||
async def _find_unused_nodes_kuzu(graph_engine, cutoff_timestamp_ms):
|
|
||||||
"""Kuzu-specific unused node detection"""
|
|
||||||
query = "MATCH (n:Node) RETURN n.id, n.type, n.properties"
|
|
||||||
results = await graph_engine.query(query)
|
|
||||||
|
|
||||||
unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []}
|
unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []}
|
||||||
|
|
||||||
for node_id, node_type, props_json in results:
|
# Get all nodes from the projected graph
|
||||||
|
all_nodes = memory_fragment.get_nodes()
|
||||||
|
|
||||||
|
for node in all_nodes:
|
||||||
|
node_type = node.get_attribute("type")
|
||||||
if node_type not in unused_nodes:
|
if node_type not in unused_nodes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if props_json:
|
# Check last_accessed_at property
|
||||||
try:
|
last_accessed = node.get_attribute("last_accessed_at")
|
||||||
props = json.loads(props_json)
|
|
||||||
last_accessed = props.get("last_accessed_at")
|
|
||||||
|
|
||||||
if last_accessed is None or last_accessed < cutoff_timestamp_ms:
|
|
||||||
unused_nodes[node_type].append(node_id)
|
|
||||||
logger.debug(
|
|
||||||
f"Found unused {node_type}",
|
|
||||||
node_id=node_id,
|
|
||||||
last_accessed=last_accessed
|
|
||||||
)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.warning(f"Failed to parse properties for node {node_id}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
return unused_nodes
|
|
||||||
|
|
||||||
|
|
||||||
async def _find_unused_nodes_neo4j(graph_engine, cutoff_timestamp_ms):
|
|
||||||
"""Neo4j-specific unused node detection"""
|
|
||||||
query = "MATCH (n:__Node__) RETURN n.id, n.type, n.last_accessed_at"
|
|
||||||
results = await graph_engine.query(query)
|
|
||||||
|
|
||||||
unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []}
|
|
||||||
|
|
||||||
for row in results:
|
|
||||||
node_id = row["n"]["id"]
|
|
||||||
node_type = row["n"]["type"]
|
|
||||||
last_accessed = row["n"].get("last_accessed_at")
|
|
||||||
|
|
||||||
if node_type not in unused_nodes:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if last_accessed is None or last_accessed < cutoff_timestamp_ms:
|
if last_accessed is None or last_accessed < cutoff_timestamp_ms:
|
||||||
unused_nodes[node_type].append(node_id)
|
unused_nodes[node_type].append(node.id)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Found unused {node_type}",
|
f"Found unused {node_type}",
|
||||||
node_id=node_id,
|
node_id=node.id,
|
||||||
last_accessed=last_accessed
|
last_accessed=last_accessed
|
||||||
)
|
)
|
||||||
|
|
||||||
return unused_nodes
|
return unused_nodes
|
||||||
|
|
||||||
|
|
||||||
async def _find_unused_nodes_neptune(graph_engine, cutoff_timestamp_ms):
|
async def _delete_unused_nodes(unused_nodes: Dict[str, list]) -> Dict[str, int]:
|
||||||
"""Neptune-specific unused node detection"""
|
|
||||||
query = "MATCH (n:Node) RETURN n.id, n.type, n.last_accessed_at"
|
|
||||||
results = await graph_engine.query(query)
|
|
||||||
|
|
||||||
unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []}
|
|
||||||
|
|
||||||
for row in results:
|
|
||||||
node_id = row["n"]["id"]
|
|
||||||
node_type = row["n"]["type"]
|
|
||||||
last_accessed = row["n"].get("last_accessed_at")
|
|
||||||
|
|
||||||
if node_type not in unused_nodes:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if last_accessed is None or last_accessed < cutoff_timestamp_ms:
|
|
||||||
unused_nodes[node_type].append(node_id)
|
|
||||||
logger.debug(
|
|
||||||
f"Found unused {node_type}",
|
|
||||||
node_id=node_id,
|
|
||||||
last_accessed=last_accessed
|
|
||||||
)
|
|
||||||
|
|
||||||
return unused_nodes
|
|
||||||
|
|
||||||
|
|
||||||
async def _delete_unused_nodes(unused_nodes: Dict[str, list], provider: str) -> Dict[str, int]:
|
|
||||||
"""
|
"""
|
||||||
Delete unused nodes from graph and vector databases.
|
Delete unused nodes from graph and vector databases.
|
||||||
|
|
||||||
|
|
@ -365,8 +296,6 @@ async def _delete_unused_nodes(unused_nodes: Dict[str, list], provider: str) ->
|
||||||
----------
|
----------
|
||||||
unused_nodes : Dict[str, list]
|
unused_nodes : Dict[str, list]
|
||||||
Dictionary mapping node types to lists of node IDs to delete
|
Dictionary mapping node types to lists of node IDs to delete
|
||||||
provider : str
|
|
||||||
Graph database provider (kuzu, neo4j, neptune)
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
|
@ -383,32 +312,26 @@ async def _delete_unused_nodes(unused_nodes: Dict[str, list], provider: str) ->
|
||||||
"associations": 0
|
"associations": 0
|
||||||
}
|
}
|
||||||
|
|
||||||
# Count associations before deletion
|
# Count associations before deletion (using graph projection for consistency)
|
||||||
for node_type, node_ids in unused_nodes.items():
|
if any(unused_nodes.values()):
|
||||||
if not node_ids:
|
memory_fragment = CogneeGraph()
|
||||||
continue
|
await memory_fragment.project_graph_from_db(
|
||||||
|
graph_engine,
|
||||||
|
node_properties_to_project=["id"],
|
||||||
|
edge_properties_to_project=[]
|
||||||
|
)
|
||||||
|
|
||||||
# Count edges connected to these nodes
|
for node_type, node_ids in unused_nodes.items():
|
||||||
for node_id in node_ids:
|
if not node_ids:
|
||||||
if provider == "kuzu":
|
continue
|
||||||
result = await graph_engine.query(
|
|
||||||
"MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)",
|
|
||||||
{"id": node_id}
|
|
||||||
)
|
|
||||||
elif provider == "neo4j":
|
|
||||||
result = await graph_engine.query(
|
|
||||||
"MATCH (n:__Node__ {id: $id})-[r:EDGE]-() RETURN count(r)",
|
|
||||||
{"id": node_id}
|
|
||||||
)
|
|
||||||
elif provider == "neptune":
|
|
||||||
result = await graph_engine.query(
|
|
||||||
"MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)",
|
|
||||||
{"id": node_id}
|
|
||||||
)
|
|
||||||
|
|
||||||
if result and len(result) > 0:
|
# Count edges connected to these nodes
|
||||||
count = result[0][0] if provider == "kuzu" else result[0]["count_count(r)"]
|
for node_id in node_ids:
|
||||||
deleted_counts["associations"] += count
|
node = memory_fragment.get_node(node_id)
|
||||||
|
if node:
|
||||||
|
# Count edges from the in-memory graph
|
||||||
|
edge_count = len(node.get_skeleton_edges())
|
||||||
|
deleted_counts["associations"] += edge_count
|
||||||
|
|
||||||
# Delete from graph database (uses DETACH DELETE, so edges are automatically removed)
|
# Delete from graph database (uses DETACH DELETE, so edges are automatically removed)
|
||||||
for node_type, node_ids in unused_nodes.items():
|
for node_type, node_ids in unused_nodes.items():
|
||||||
|
|
@ -417,7 +340,7 @@ async def _delete_unused_nodes(unused_nodes: Dict[str, list], provider: str) ->
|
||||||
|
|
||||||
logger.info(f"Deleting {len(node_ids)} {node_type} nodes from graph database")
|
logger.info(f"Deleting {len(node_ids)} {node_type} nodes from graph database")
|
||||||
|
|
||||||
# Delete nodes in batches
|
# Delete nodes in batches (database-agnostic)
|
||||||
await graph_engine.delete_nodes(node_ids)
|
await graph_engine.delete_nodes(node_ids)
|
||||||
deleted_counts[node_type] = len(node_ids)
|
deleted_counts[node_type] = len(node_ids)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue