fix: generalized queries

This commit is contained in:
chinu0609 2025-11-26 17:32:50 +05:30
parent 5cb6510205
commit 12ce80005c
2 changed files with 516 additions and 409 deletions

View file

@ -13,21 +13,7 @@ from sqlalchemy import update
logger = get_logger(__name__) logger = get_logger(__name__)
async def update_node_access_timestamps(items: List[Any]): async def update_node_access_timestamps(items: List[Any]):
"""
Update last_accessed_at for nodes in graph database and corresponding Data records in SQL.
This function:
1. Updates last_accessed_at in the graph database nodes (in properties JSON)
2. Traverses to find origin TextDocument nodes (without hardcoded relationship names)
3. Updates last_accessed in the SQL Data table for those documents
Parameters
----------
items : List[Any]
List of items with payload containing 'id' field (from vector search results)
"""
if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true": if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true":
return return
@ -49,50 +35,95 @@ async def update_node_access_timestamps(items: List[Any]):
return return
try: try:
# Step 1: Batch update graph nodes # Detect database provider and use appropriate queries
for node_id in node_ids: provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower()
result = await graph_engine.query(
"MATCH (n:Node {id: $id}) RETURN n.properties", if provider == "kuzu":
{"id": node_id} await _update_kuzu_nodes(graph_engine, node_ids, timestamp_ms)
elif provider == "neo4j":
await _update_neo4j_nodes(graph_engine, node_ids, timestamp_ms)
elif provider == "neptune":
await _update_neptune_nodes(graph_engine, node_ids, timestamp_ms)
else:
logger.warning(f"Unsupported graph provider: {provider}")
return
# Find origin documents and update SQL
doc_ids = await _find_origin_documents(graph_engine, node_ids, provider)
if doc_ids:
await _update_sql_records(doc_ids, timestamp_dt)
except Exception as e:
logger.error(f"Failed to update timestamps: {e}")
raise
async def _update_kuzu_nodes(graph_engine, node_ids, timestamp_ms):
"""Kuzu-specific node updates"""
for node_id in node_ids:
result = await graph_engine.query(
"MATCH (n:Node {id: $id}) RETURN n.properties",
{"id": node_id}
)
if result and result[0]:
props = json.loads(result[0][0]) if result[0][0] else {}
props["last_accessed_at"] = timestamp_ms
await graph_engine.query(
"MATCH (n:Node {id: $id}) SET n.properties = $props",
{"id": node_id, "props": json.dumps(props)}
) )
if result and result[0]: async def _update_neo4j_nodes(graph_engine, node_ids, timestamp_ms):
props = json.loads(result[0][0]) if result[0][0] else {} """Neo4j-specific node updates"""
props["last_accessed_at"] = timestamp_ms for node_id in node_ids:
await graph_engine.query(
"MATCH (n:__Node__ {id: $id}) SET n.last_accessed_at = $timestamp",
{"id": node_id, "timestamp": timestamp_ms}
)
await graph_engine.query( async def _update_neptune_nodes(graph_engine, node_ids, timestamp_ms):
"MATCH (n:Node {id: $id}) SET n.properties = $props", """Neptune-specific node updates"""
{"id": node_id, "props": json.dumps(props)} for node_id in node_ids:
) await graph_engine.query(
"MATCH (n:Node {id: $id}) SET n.last_accessed_at = $timestamp",
{"id": node_id, "timestamp": timestamp_ms}
)
logger.debug(f"Updated access timestamps for {len(node_ids)} graph nodes") async def _find_origin_documents(graph_engine, node_ids, provider):
"""Find origin documents with provider-specific queries"""
# Step 2: Find origin TextDocument nodes (without hardcoded relationship names) if provider == "kuzu":
origin_query = """ query = """
UNWIND $node_ids AS node_id
MATCH (chunk:Node {id: node_id})-[e:EDGE]-(doc:Node)
WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document']
RETURN DISTINCT doc.id
"""
elif provider == "neo4j":
query = """
UNWIND $node_ids AS node_id
MATCH (chunk:__Node__ {id: node_id})-[e:EDGE]-(doc:__Node__)
WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document']
RETURN DISTINCT doc.id
"""
elif provider == "neptune":
query = """
UNWIND $node_ids AS node_id UNWIND $node_ids AS node_id
MATCH (chunk:Node {id: node_id})-[e:EDGE]-(doc:Node) MATCH (chunk:Node {id: node_id})-[e:EDGE]-(doc:Node)
WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document'] WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document']
RETURN DISTINCT doc.id RETURN DISTINCT doc.id
""" """
result = await graph_engine.query(origin_query, {"node_ids": node_ids}) result = await graph_engine.query(query, {"node_ids": node_ids})
return list(set([row[0] for row in result if row and row[0]])) if result else []
# Extract and deduplicate document IDs async def _update_sql_records(doc_ids, timestamp_dt):
doc_ids = list(set([row[0] for row in result if row and row[0]])) if result else [] """Update SQL Data table (same for all providers)"""
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
stmt = update(Data).where(
Data.id.in_([UUID(doc_id) for doc_id in doc_ids])
).values(last_accessed=timestamp_dt)
# Step 3: Update SQL Data table await session.execute(stmt)
if doc_ids: await session.commit()
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
stmt = update(Data).where(
Data.id.in_([UUID(doc_id) for doc_id in doc_ids])
).values(last_accessed=timestamp_dt)
await session.execute(stmt)
await session.commit()
logger.debug(f"Updated last_accessed for {len(doc_ids)} Data records in SQL")
except Exception as e:
logger.error(f"Failed to update timestamps: {e}")
raise

View file

@ -1,7 +1,7 @@
""" """
Task for automatically deleting unused data from the memify pipeline. Task for automatically deleting unused data from the memify pipeline.
This task identifies and removes data (chunks, entities, summaries) that hasn't This task identifies and removes data (chunks, entities, summaries) that hasn't
been accessed by retrievers for a specified period, helping maintain system been accessed by retrievers for a specified period, helping maintain system
efficiency and storage optimization. efficiency and storage optimization.
""" """
@ -24,7 +24,7 @@ logger = get_logger(__name__)
async def cleanup_unused_data( async def cleanup_unused_data(
days_threshold: Optional[int], minutes_threshold: Optional[int],
dry_run: bool = True, dry_run: bool = True,
user_id: Optional[UUID] = None, user_id: Optional[UUID] = None,
text_doc: bool = False text_doc: bool = False
@ -34,10 +34,10 @@ async def cleanup_unused_data(
Parameters Parameters
---------- ----------
days_threshold : int minutes_threshold : int
days since last access to consider data unused minutes since last access to consider data unused
dry_run : bool dry_run : bool
If True, only report what would be deleted without actually deleting (default: True) If True, only report what would be deleted without actually deleting (default: True)
user_id : UUID, optional user_id : UUID, optional
Limit cleanup to specific user's data (default: None) Limit cleanup to specific user's data (default: None)
text_doc : bool text_doc : bool
@ -87,14 +87,14 @@ async def cleanup_unused_data(
logger.info( logger.info(
"Starting cleanup task", "Starting cleanup task",
days_threshold=days_threshold, minutes_threshold=minutes_threshold,
dry_run=dry_run, dry_run=dry_run,
user_id=str(user_id) if user_id else None, user_id=str(user_id) if user_id else None,
text_doc=text_doc text_doc=text_doc
) )
# Calculate cutoff timestamp # Calculate cutoff timestamp
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_threshold) cutoff_date = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold)
if text_doc: if text_doc:
# SQL-based approach: Find unused TextDocuments and use cognee.delete() # SQL-based approach: Find unused TextDocuments and use cognee.delete()
@ -104,8 +104,9 @@ async def cleanup_unused_data(
cutoff_timestamp_ms = int(cutoff_date.timestamp() * 1000) cutoff_timestamp_ms = int(cutoff_date.timestamp() * 1000)
logger.debug(f"Cutoff timestamp: {cutoff_date.isoformat()} ({cutoff_timestamp_ms}ms)") logger.debug(f"Cutoff timestamp: {cutoff_date.isoformat()} ({cutoff_timestamp_ms}ms)")
# Find unused nodes # Detect database provider and find unused nodes
unused_nodes = await _find_unused_nodes(cutoff_timestamp_ms, user_id) provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower()
unused_nodes = await _find_unused_nodes(cutoff_timestamp_ms, user_id, provider)
total_unused = sum(len(nodes) for nodes in unused_nodes.values()) total_unused = sum(len(nodes) for nodes in unused_nodes.values())
logger.info(f"Found {total_unused} unused nodes", unused_nodes={k: len(v) for k, v in unused_nodes.items()}) logger.info(f"Found {total_unused} unused nodes", unused_nodes={k: len(v) for k, v in unused_nodes.items()})
@ -129,8 +130,8 @@ async def cleanup_unused_data(
} }
} }
# Delete unused nodes # Delete unused nodes with provider-specific logic
deleted_counts = await _delete_unused_nodes(unused_nodes) deleted_counts = await _delete_unused_nodes(unused_nodes, provider)
logger.info("Cleanup completed", deleted_counts=deleted_counts) logger.info("Cleanup completed", deleted_counts=deleted_counts)
@ -241,10 +242,11 @@ async def _cleanup_via_sql(
async def _find_unused_nodes( async def _find_unused_nodes(
cutoff_timestamp_ms: int, cutoff_timestamp_ms: int,
user_id: Optional[UUID] = None user_id: Optional[UUID] = None,
provider: str = "kuzu"
) -> Dict[str, list]: ) -> Dict[str, list]:
""" """
Query Kuzu for nodes with old last_accessed_at timestamps. Find unused nodes with provider-specific queries.
Parameters Parameters
---------- ----------
@ -252,6 +254,8 @@ async def _find_unused_nodes(
Cutoff timestamp in milliseconds since epoch Cutoff timestamp in milliseconds since epoch
user_id : UUID, optional user_id : UUID, optional
Filter by user ID if provided Filter by user ID if provided
provider : str
Graph database provider (kuzu, neo4j, neptune)
Returns Returns
------- -------
@ -260,28 +264,33 @@ async def _find_unused_nodes(
""" """
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
# Query all nodes with their properties if provider == "kuzu":
return await _find_unused_nodes_kuzu(graph_engine, cutoff_timestamp_ms)
elif provider == "neo4j":
return await _find_unused_nodes_neo4j(graph_engine, cutoff_timestamp_ms)
elif provider == "neptune":
return await _find_unused_nodes_neptune(graph_engine, cutoff_timestamp_ms)
else:
logger.warning(f"Unsupported graph provider: {provider}")
return {"DocumentChunk": [], "Entity": [], "TextSummary": []}
async def _find_unused_nodes_kuzu(graph_engine, cutoff_timestamp_ms):
"""Kuzu-specific unused node detection"""
query = "MATCH (n:Node) RETURN n.id, n.type, n.properties" query = "MATCH (n:Node) RETURN n.id, n.type, n.properties"
results = await graph_engine.query(query) results = await graph_engine.query(query)
unused_nodes = { unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []}
"DocumentChunk": [],
"Entity": [],
"TextSummary": []
}
for node_id, node_type, props_json in results: for node_id, node_type, props_json in results:
# Only process tracked node types
if node_type not in unused_nodes: if node_type not in unused_nodes:
continue continue
# Parse properties JSON
if props_json: if props_json:
try: try:
props = json.loads(props_json) props = json.loads(props_json)
last_accessed = props.get("last_accessed_at") last_accessed = props.get("last_accessed_at")
# Check if node is unused (never accessed or accessed before cutoff)
if last_accessed is None or last_accessed < cutoff_timestamp_ms: if last_accessed is None or last_accessed < cutoff_timestamp_ms:
unused_nodes[node_type].append(node_id) unused_nodes[node_type].append(node_id)
logger.debug( logger.debug(
@ -296,7 +305,59 @@ async def _find_unused_nodes(
return unused_nodes return unused_nodes
async def _delete_unused_nodes(unused_nodes: Dict[str, list]) -> Dict[str, int]: async def _find_unused_nodes_neo4j(graph_engine, cutoff_timestamp_ms):
"""Neo4j-specific unused node detection"""
query = "MATCH (n:__Node__) RETURN n.id, n.type, n.last_accessed_at"
results = await graph_engine.query(query)
unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []}
for row in results:
node_id = row["n"]["id"]
node_type = row["n"]["type"]
last_accessed = row["n"].get("last_accessed_at")
if node_type not in unused_nodes:
continue
if last_accessed is None or last_accessed < cutoff_timestamp_ms:
unused_nodes[node_type].append(node_id)
logger.debug(
f"Found unused {node_type}",
node_id=node_id,
last_accessed=last_accessed
)
return unused_nodes
async def _find_unused_nodes_neptune(graph_engine, cutoff_timestamp_ms):
"""Neptune-specific unused node detection"""
query = "MATCH (n:Node) RETURN n.id, n.type, n.last_accessed_at"
results = await graph_engine.query(query)
unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []}
for row in results:
node_id = row["n"]["id"]
node_type = row["n"]["type"]
last_accessed = row["n"].get("last_accessed_at")
if node_type not in unused_nodes:
continue
if last_accessed is None or last_accessed < cutoff_timestamp_ms:
unused_nodes[node_type].append(node_id)
logger.debug(
f"Found unused {node_type}",
node_id=node_id,
last_accessed=last_accessed
)
return unused_nodes
async def _delete_unused_nodes(unused_nodes: Dict[str, list], provider: str) -> Dict[str, int]:
""" """
Delete unused nodes from graph and vector databases. Delete unused nodes from graph and vector databases.
@ -304,6 +365,8 @@ async def _delete_unused_nodes(unused_nodes: Dict[str, list]) -> Dict[str, int]:
---------- ----------
unused_nodes : Dict[str, list] unused_nodes : Dict[str, list]
Dictionary mapping node types to lists of node IDs to delete Dictionary mapping node types to lists of node IDs to delete
provider : str
Graph database provider (kuzu, neo4j, neptune)
Returns Returns
------- -------
@ -327,12 +390,25 @@ async def _delete_unused_nodes(unused_nodes: Dict[str, list]) -> Dict[str, int]:
# Count edges connected to these nodes # Count edges connected to these nodes
for node_id in node_ids: for node_id in node_ids:
result = await graph_engine.query( if provider == "kuzu":
"MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)", result = await graph_engine.query(
{"id": node_id} "MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)",
) {"id": node_id}
)
elif provider == "neo4j":
result = await graph_engine.query(
"MATCH (n:__Node__ {id: $id})-[r:EDGE]-() RETURN count(r)",
{"id": node_id}
)
elif provider == "neptune":
result = await graph_engine.query(
"MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)",
{"id": node_id}
)
if result and len(result) > 0: if result and len(result) > 0:
deleted_counts["associations"] += result[0][0] count = result[0][0] if provider == "kuzu" else result[0]["count_count(r)"]
deleted_counts["associations"] += count
# Delete from graph database (uses DETACH DELETE, so edges are automatically removed) # Delete from graph database (uses DETACH DELETE, so edges are automatically removed)
for node_type, node_ids in unused_nodes.items(): for node_type, node_ids in unused_nodes.items():