fix: generalized queries
This commit is contained in:
parent
5cb6510205
commit
12ce80005c
2 changed files with 516 additions and 409 deletions
|
|
@ -13,21 +13,7 @@ from sqlalchemy import update
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def update_node_access_timestamps(items: List[Any]):
|
async def update_node_access_timestamps(items: List[Any]):
|
||||||
"""
|
|
||||||
Update last_accessed_at for nodes in graph database and corresponding Data records in SQL.
|
|
||||||
|
|
||||||
This function:
|
|
||||||
1. Updates last_accessed_at in the graph database nodes (in properties JSON)
|
|
||||||
2. Traverses to find origin TextDocument nodes (without hardcoded relationship names)
|
|
||||||
3. Updates last_accessed in the SQL Data table for those documents
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
items : List[Any]
|
|
||||||
List of items with payload containing 'id' field (from vector search results)
|
|
||||||
"""
|
|
||||||
if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true":
|
if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true":
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -49,50 +35,95 @@ async def update_node_access_timestamps(items: List[Any]):
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Step 1: Batch update graph nodes
|
# Detect database provider and use appropriate queries
|
||||||
for node_id in node_ids:
|
provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower()
|
||||||
result = await graph_engine.query(
|
|
||||||
"MATCH (n:Node {id: $id}) RETURN n.properties",
|
if provider == "kuzu":
|
||||||
{"id": node_id}
|
await _update_kuzu_nodes(graph_engine, node_ids, timestamp_ms)
|
||||||
|
elif provider == "neo4j":
|
||||||
|
await _update_neo4j_nodes(graph_engine, node_ids, timestamp_ms)
|
||||||
|
elif provider == "neptune":
|
||||||
|
await _update_neptune_nodes(graph_engine, node_ids, timestamp_ms)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unsupported graph provider: {provider}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Find origin documents and update SQL
|
||||||
|
doc_ids = await _find_origin_documents(graph_engine, node_ids, provider)
|
||||||
|
if doc_ids:
|
||||||
|
await _update_sql_records(doc_ids, timestamp_dt)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to update timestamps: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _update_kuzu_nodes(graph_engine, node_ids, timestamp_ms):
|
||||||
|
"""Kuzu-specific node updates"""
|
||||||
|
for node_id in node_ids:
|
||||||
|
result = await graph_engine.query(
|
||||||
|
"MATCH (n:Node {id: $id}) RETURN n.properties",
|
||||||
|
{"id": node_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
if result and result[0]:
|
||||||
|
props = json.loads(result[0][0]) if result[0][0] else {}
|
||||||
|
props["last_accessed_at"] = timestamp_ms
|
||||||
|
|
||||||
|
await graph_engine.query(
|
||||||
|
"MATCH (n:Node {id: $id}) SET n.properties = $props",
|
||||||
|
{"id": node_id, "props": json.dumps(props)}
|
||||||
)
|
)
|
||||||
|
|
||||||
if result and result[0]:
|
async def _update_neo4j_nodes(graph_engine, node_ids, timestamp_ms):
|
||||||
props = json.loads(result[0][0]) if result[0][0] else {}
|
"""Neo4j-specific node updates"""
|
||||||
props["last_accessed_at"] = timestamp_ms
|
for node_id in node_ids:
|
||||||
|
await graph_engine.query(
|
||||||
|
"MATCH (n:__Node__ {id: $id}) SET n.last_accessed_at = $timestamp",
|
||||||
|
{"id": node_id, "timestamp": timestamp_ms}
|
||||||
|
)
|
||||||
|
|
||||||
await graph_engine.query(
|
async def _update_neptune_nodes(graph_engine, node_ids, timestamp_ms):
|
||||||
"MATCH (n:Node {id: $id}) SET n.properties = $props",
|
"""Neptune-specific node updates"""
|
||||||
{"id": node_id, "props": json.dumps(props)}
|
for node_id in node_ids:
|
||||||
)
|
await graph_engine.query(
|
||||||
|
"MATCH (n:Node {id: $id}) SET n.last_accessed_at = $timestamp",
|
||||||
|
{"id": node_id, "timestamp": timestamp_ms}
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(f"Updated access timestamps for {len(node_ids)} graph nodes")
|
async def _find_origin_documents(graph_engine, node_ids, provider):
|
||||||
|
"""Find origin documents with provider-specific queries"""
|
||||||
# Step 2: Find origin TextDocument nodes (without hardcoded relationship names)
|
if provider == "kuzu":
|
||||||
origin_query = """
|
query = """
|
||||||
|
UNWIND $node_ids AS node_id
|
||||||
|
MATCH (chunk:Node {id: node_id})-[e:EDGE]-(doc:Node)
|
||||||
|
WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document']
|
||||||
|
RETURN DISTINCT doc.id
|
||||||
|
"""
|
||||||
|
elif provider == "neo4j":
|
||||||
|
query = """
|
||||||
|
UNWIND $node_ids AS node_id
|
||||||
|
MATCH (chunk:__Node__ {id: node_id})-[e:EDGE]-(doc:__Node__)
|
||||||
|
WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document']
|
||||||
|
RETURN DISTINCT doc.id
|
||||||
|
"""
|
||||||
|
elif provider == "neptune":
|
||||||
|
query = """
|
||||||
UNWIND $node_ids AS node_id
|
UNWIND $node_ids AS node_id
|
||||||
MATCH (chunk:Node {id: node_id})-[e:EDGE]-(doc:Node)
|
MATCH (chunk:Node {id: node_id})-[e:EDGE]-(doc:Node)
|
||||||
WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document']
|
WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document']
|
||||||
RETURN DISTINCT doc.id
|
RETURN DISTINCT doc.id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await graph_engine.query(origin_query, {"node_ids": node_ids})
|
result = await graph_engine.query(query, {"node_ids": node_ids})
|
||||||
|
return list(set([row[0] for row in result if row and row[0]])) if result else []
|
||||||
|
|
||||||
# Extract and deduplicate document IDs
|
async def _update_sql_records(doc_ids, timestamp_dt):
|
||||||
doc_ids = list(set([row[0] for row in result if row and row[0]])) if result else []
|
"""Update SQL Data table (same for all providers)"""
|
||||||
|
db_engine = get_relational_engine()
|
||||||
|
async with db_engine.get_async_session() as session:
|
||||||
|
stmt = update(Data).where(
|
||||||
|
Data.id.in_([UUID(doc_id) for doc_id in doc_ids])
|
||||||
|
).values(last_accessed=timestamp_dt)
|
||||||
|
|
||||||
# Step 3: Update SQL Data table
|
await session.execute(stmt)
|
||||||
if doc_ids:
|
await session.commit()
|
||||||
db_engine = get_relational_engine()
|
|
||||||
async with db_engine.get_async_session() as session:
|
|
||||||
stmt = update(Data).where(
|
|
||||||
Data.id.in_([UUID(doc_id) for doc_id in doc_ids])
|
|
||||||
).values(last_accessed=timestamp_dt)
|
|
||||||
|
|
||||||
await session.execute(stmt)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
logger.debug(f"Updated last_accessed for {len(doc_ids)} Data records in SQL")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to update timestamps: {e}")
|
|
||||||
raise
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"""
|
"""
|
||||||
Task for automatically deleting unused data from the memify pipeline.
|
Task for automatically deleting unused data from the memify pipeline.
|
||||||
|
|
||||||
This task identifies and removes data (chunks, entities, summaries) that hasn't
|
This task identifies and removes data (chunks, entities, summaries)) that hasn't
|
||||||
been accessed by retrievers for a specified period, helping maintain system
|
been accessed by retrievers for a specified period, helping maintain system
|
||||||
efficiency and storage optimization.
|
efficiency and storage optimization.
|
||||||
"""
|
"""
|
||||||
|
|
@ -24,7 +24,7 @@ logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def cleanup_unused_data(
|
async def cleanup_unused_data(
|
||||||
days_threshold: Optional[int],
|
minutes_threshold: Optional[int],
|
||||||
dry_run: bool = True,
|
dry_run: bool = True,
|
||||||
user_id: Optional[UUID] = None,
|
user_id: Optional[UUID] = None,
|
||||||
text_doc: bool = False
|
text_doc: bool = False
|
||||||
|
|
@ -34,10 +34,10 @@ async def cleanup_unused_data(
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
days_threshold : int
|
minutes_threshold : int
|
||||||
days since last access to consider data unused
|
days since last access to consider data unused
|
||||||
dry_run : bool
|
dry_run : bool
|
||||||
If True, only report what would be deleted without actually deleting (default: True)
|
If True, only report what would be delete without actually deleting (default: True)
|
||||||
user_id : UUID, optional
|
user_id : UUID, optional
|
||||||
Limit cleanup to specific user's data (default: None)
|
Limit cleanup to specific user's data (default: None)
|
||||||
text_doc : bool
|
text_doc : bool
|
||||||
|
|
@ -87,14 +87,14 @@ async def cleanup_unused_data(
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Starting cleanup task",
|
"Starting cleanup task",
|
||||||
days_threshold=days_threshold,
|
minutes_threshold=minutes_threshold,
|
||||||
dry_run=dry_run,
|
dry_run=dry_run,
|
||||||
user_id=str(user_id) if user_id else None,
|
user_id=str(user_id) if user_id else None,
|
||||||
text_doc=text_doc
|
text_doc=text_doc
|
||||||
)
|
)
|
||||||
|
|
||||||
# Calculate cutoff timestamp
|
# Calculate cutoff timestamp
|
||||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_threshold)
|
cutoff_date = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold)
|
||||||
|
|
||||||
if text_doc:
|
if text_doc:
|
||||||
# SQL-based approach: Find unused TextDocuments and use cognee.delete()
|
# SQL-based approach: Find unused TextDocuments and use cognee.delete()
|
||||||
|
|
@ -104,8 +104,9 @@ async def cleanup_unused_data(
|
||||||
cutoff_timestamp_ms = int(cutoff_date.timestamp() * 1000)
|
cutoff_timestamp_ms = int(cutoff_date.timestamp() * 1000)
|
||||||
logger.debug(f"Cutoff timestamp: {cutoff_date.isoformat()} ({cutoff_timestamp_ms}ms)")
|
logger.debug(f"Cutoff timestamp: {cutoff_date.isoformat()} ({cutoff_timestamp_ms}ms)")
|
||||||
|
|
||||||
# Find unused nodes
|
# Detect database provider and find unused nodes
|
||||||
unused_nodes = await _find_unused_nodes(cutoff_timestamp_ms, user_id)
|
provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower()
|
||||||
|
unused_nodes = await _find_unused_nodes(cutoff_timestamp_ms, user_id, provider)
|
||||||
|
|
||||||
total_unused = sum(len(nodes) for nodes in unused_nodes.values())
|
total_unused = sum(len(nodes) for nodes in unused_nodes.values())
|
||||||
logger.info(f"Found {total_unused} unused nodes", unused_nodes={k: len(v) for k, v in unused_nodes.items()})
|
logger.info(f"Found {total_unused} unused nodes", unused_nodes={k: len(v) for k, v in unused_nodes.items()})
|
||||||
|
|
@ -129,8 +130,8 @@ async def cleanup_unused_data(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Delete unused nodes
|
# Delete unused nodes with provider-specific logic
|
||||||
deleted_counts = await _delete_unused_nodes(unused_nodes)
|
deleted_counts = await _delete_unused_nodes(unused_nodes, provider)
|
||||||
|
|
||||||
logger.info("Cleanup completed", deleted_counts=deleted_counts)
|
logger.info("Cleanup completed", deleted_counts=deleted_counts)
|
||||||
|
|
||||||
|
|
@ -241,10 +242,11 @@ async def _cleanup_via_sql(
|
||||||
|
|
||||||
async def _find_unused_nodes(
|
async def _find_unused_nodes(
|
||||||
cutoff_timestamp_ms: int,
|
cutoff_timestamp_ms: int,
|
||||||
user_id: Optional[UUID] = None
|
user_id: Optional[UUID] = None,
|
||||||
|
provider: str = "kuzu"
|
||||||
) -> Dict[str, list]:
|
) -> Dict[str, list]:
|
||||||
"""
|
"""
|
||||||
Query Kuzu for nodes with old last_accessed_at timestamps.
|
Find unused nodes with provider-specific queries.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
|
@ -252,6 +254,8 @@ async def _find_unused_nodes(
|
||||||
Cutoff timestamp in milliseconds since epoch
|
Cutoff timestamp in milliseconds since epoch
|
||||||
user_id : UUID, optional
|
user_id : UUID, optional
|
||||||
Filter by user ID if provided
|
Filter by user ID if provided
|
||||||
|
provider : str
|
||||||
|
Graph database provider (kuzu, neo4j, neptune)
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
|
@ -260,28 +264,33 @@ async def _find_unused_nodes(
|
||||||
"""
|
"""
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
|
|
||||||
# Query all nodes with their properties
|
if provider == "kuzu":
|
||||||
|
return await _find_unused_nodes_kuzu(graph_engine, cutoff_timestamp_ms)
|
||||||
|
elif provider == "neo4j":
|
||||||
|
return await _find_unused_nodes_neo4j(graph_engine, cutoff_timestamp_ms)
|
||||||
|
elif provider == "neptune":
|
||||||
|
return await _find_unused_nodes_neptune(graph_engine, cutoff_timestamp_ms)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unsupported graph provider: {provider}")
|
||||||
|
return {"DocumentChunk": [], "Entity": [], "TextSummary": []}
|
||||||
|
|
||||||
|
|
||||||
|
async def _find_unused_nodes_kuzu(graph_engine, cutoff_timestamp_ms):
|
||||||
|
"""Kuzu-specific unused node detection"""
|
||||||
query = "MATCH (n:Node) RETURN n.id, n.type, n.properties"
|
query = "MATCH (n:Node) RETURN n.id, n.type, n.properties"
|
||||||
results = await graph_engine.query(query)
|
results = await graph_engine.query(query)
|
||||||
|
|
||||||
unused_nodes = {
|
unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []}
|
||||||
"DocumentChunk": [],
|
|
||||||
"Entity": [],
|
|
||||||
"TextSummary": []
|
|
||||||
}
|
|
||||||
|
|
||||||
for node_id, node_type, props_json in results:
|
for node_id, node_type, props_json in results:
|
||||||
# Only process tracked node types
|
|
||||||
if node_type not in unused_nodes:
|
if node_type not in unused_nodes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Parse properties JSON
|
|
||||||
if props_json:
|
if props_json:
|
||||||
try:
|
try:
|
||||||
props = json.loads(props_json)
|
props = json.loads(props_json)
|
||||||
last_accessed = props.get("last_accessed_at")
|
last_accessed = props.get("last_accessed_at")
|
||||||
|
|
||||||
# Check if node is unused (never accessed or accessed before cutoff)
|
|
||||||
if last_accessed is None or last_accessed < cutoff_timestamp_ms:
|
if last_accessed is None or last_accessed < cutoff_timestamp_ms:
|
||||||
unused_nodes[node_type].append(node_id)
|
unused_nodes[node_type].append(node_id)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
@ -296,7 +305,59 @@ async def _find_unused_nodes(
|
||||||
return unused_nodes
|
return unused_nodes
|
||||||
|
|
||||||
|
|
||||||
async def _delete_unused_nodes(unused_nodes: Dict[str, list]) -> Dict[str, int]:
|
async def _find_unused_nodes_neo4j(graph_engine, cutoff_timestamp_ms):
|
||||||
|
"""Neo4j-specific unused node detection"""
|
||||||
|
query = "MATCH (n:__Node__) RETURN n.id, n.type, n.last_accessed_at"
|
||||||
|
results = await graph_engine.query(query)
|
||||||
|
|
||||||
|
unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []}
|
||||||
|
|
||||||
|
for row in results:
|
||||||
|
node_id = row["n"]["id"]
|
||||||
|
node_type = row["n"]["type"]
|
||||||
|
last_accessed = row["n"].get("last_accessed_at")
|
||||||
|
|
||||||
|
if node_type not in unused_nodes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if last_accessed is None or last_accessed < cutoff_timestamp_ms:
|
||||||
|
unused_nodes[node_type].append(node_id)
|
||||||
|
logger.debug(
|
||||||
|
f"Found unused {node_type}",
|
||||||
|
node_id=node_id,
|
||||||
|
last_accessed=last_accessed
|
||||||
|
)
|
||||||
|
|
||||||
|
return unused_nodes
|
||||||
|
|
||||||
|
|
||||||
|
async def _find_unused_nodes_neptune(graph_engine, cutoff_timestamp_ms):
|
||||||
|
"""Neptune-specific unused node detection"""
|
||||||
|
query = "MATCH (n:Node) RETURN n.id, n.type, n.last_accessed_at"
|
||||||
|
results = await graph_engine.query(query)
|
||||||
|
|
||||||
|
unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []}
|
||||||
|
|
||||||
|
for row in results:
|
||||||
|
node_id = row["n"]["id"]
|
||||||
|
node_type = row["n"]["type"]
|
||||||
|
last_accessed = row["n"].get("last_accessed_at")
|
||||||
|
|
||||||
|
if node_type not in unused_nodes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if last_accessed is None or last_accessed < cutoff_timestamp_ms:
|
||||||
|
unused_nodes[node_type].append(node_id)
|
||||||
|
logger.debug(
|
||||||
|
f"Found unused {node_type}",
|
||||||
|
node_id=node_id,
|
||||||
|
last_accessed=last_accessed
|
||||||
|
)
|
||||||
|
|
||||||
|
return unused_nodes
|
||||||
|
|
||||||
|
|
||||||
|
async def _delete_unused_nodes(unused_nodes: Dict[str, list], provider: str) -> Dict[str, int]:
|
||||||
"""
|
"""
|
||||||
Delete unused nodes from graph and vector databases.
|
Delete unused nodes from graph and vector databases.
|
||||||
|
|
||||||
|
|
@ -304,6 +365,8 @@ async def _delete_unused_nodes(unused_nodes: Dict[str, list]) -> Dict[str, int]:
|
||||||
----------
|
----------
|
||||||
unused_nodes : Dict[str, list]
|
unused_nodes : Dict[str, list]
|
||||||
Dictionary mapping node types to lists of node IDs to delete
|
Dictionary mapping node types to lists of node IDs to delete
|
||||||
|
provider : str
|
||||||
|
Graph database provider (kuzu, neo4j, neptune)
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
|
@ -327,12 +390,25 @@ async def _delete_unused_nodes(unused_nodes: Dict[str, list]) -> Dict[str, int]:
|
||||||
|
|
||||||
# Count edges connected to these nodes
|
# Count edges connected to these nodes
|
||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
result = await graph_engine.query(
|
if provider == "kuzu":
|
||||||
"MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)",
|
result = await graph_engine.query(
|
||||||
{"id": node_id}
|
"MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)",
|
||||||
)
|
{"id": node_id}
|
||||||
|
)
|
||||||
|
elif provider == "neo4j":
|
||||||
|
result = await graph_engine.query(
|
||||||
|
"MATCH (n:__Node__ {id: $id})-[r:EDGE]-() RETURN count(r)",
|
||||||
|
{"id": node_id}
|
||||||
|
)
|
||||||
|
elif provider == "neptune":
|
||||||
|
result = await graph_engine.query(
|
||||||
|
"MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)",
|
||||||
|
{"id": node_id}
|
||||||
|
)
|
||||||
|
|
||||||
if result and len(result) > 0:
|
if result and len(result) > 0:
|
||||||
deleted_counts["associations"] += result[0][0]
|
count = result[0][0] if provider == "kuzu" else result[0]["count_count(r)"]
|
||||||
|
deleted_counts["associations"] += count
|
||||||
|
|
||||||
# Delete from graph database (uses DETACH DELETE, so edges are automatically removed)
|
# Delete from graph database (uses DETACH DELETE, so edges are automatically removed)
|
||||||
for node_type, node_ids in unused_nodes.items():
|
for node_type, node_ids in unused_nodes.items():
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue