chore: ruff format and refactor on contributor PR

This commit is contained in:
Igor Ilic 2026-01-13 15:10:21 +01:00
parent 0d2f66fa1d
commit dce51efbe3
3 changed files with 408 additions and 387 deletions

View file

@ -10,20 +10,20 @@ logger = get_logger("extract_usage_frequency")
async def extract_usage_frequency( async def extract_usage_frequency(
subgraphs: List[CogneeGraph], subgraphs: List[CogneeGraph],
time_window: timedelta = timedelta(days=7), time_window: timedelta = timedelta(days=7),
min_interaction_threshold: int = 1 min_interaction_threshold: int = 1,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Extract usage frequency from CogneeUserInteraction nodes. Extract usage frequency from CogneeUserInteraction nodes.
When save_interaction=True in cognee.search(), the system creates: When save_interaction=True in cognee.search(), the system creates:
- CogneeUserInteraction nodes (representing the query/answer interaction) - CogneeUserInteraction nodes (representing the query/answer interaction)
- used_graph_element_to_answer edges (connecting interactions to graph elements used) - used_graph_element_to_answer edges (connecting interactions to graph elements used)
This function tallies how often each graph element is referenced via these edges, This function tallies how often each graph element is referenced via these edges,
enabling frequency-based ranking in downstream retrievers. enabling frequency-based ranking in downstream retrievers.
:param subgraphs: List of CogneeGraph instances containing interaction data :param subgraphs: List of CogneeGraph instances containing interaction data
:param time_window: Time window to consider for interactions (default: 7 days) :param time_window: Time window to consider for interactions (default: 7 days)
:param min_interaction_threshold: Minimum interactions to track (default: 1) :param min_interaction_threshold: Minimum interactions to track (default: 1)
@ -31,33 +31,35 @@ async def extract_usage_frequency(
""" """
current_time = datetime.now() current_time = datetime.now()
cutoff_time = current_time - time_window cutoff_time = current_time - time_window
# Track frequencies for graph elements (nodes and edges) # Track frequencies for graph elements (nodes and edges)
node_frequencies = {} node_frequencies = {}
edge_frequencies = {} edge_frequencies = {}
relationship_type_frequencies = {} relationship_type_frequencies = {}
# Track interaction metadata # Track interaction metadata
interaction_count = 0 interaction_count = 0
interactions_in_window = 0 interactions_in_window = 0
logger.info(f"Extracting usage frequencies from {len(subgraphs)} subgraphs") logger.info(f"Extracting usage frequencies from {len(subgraphs)} subgraphs")
logger.info(f"Time window: {time_window}, Cutoff: {cutoff_time.isoformat()}") logger.info(f"Time window: {time_window}, Cutoff: {cutoff_time.isoformat()}")
for subgraph in subgraphs: for subgraph in subgraphs:
# Find all CogneeUserInteraction nodes # Find all CogneeUserInteraction nodes
interaction_nodes = {} interaction_nodes = {}
for node_id, node in subgraph.nodes.items(): for node_id, node in subgraph.nodes.items():
node_type = node.attributes.get('type') or node.attributes.get('node_type') node_type = node.attributes.get("type") or node.attributes.get("node_type")
if node_type == 'CogneeUserInteraction': if node_type == "CogneeUserInteraction":
# Parse and validate timestamp # Parse and validate timestamp
timestamp_value = node.attributes.get('timestamp') or node.attributes.get('created_at') timestamp_value = node.attributes.get("timestamp") or node.attributes.get(
"created_at"
)
if timestamp_value is not None: if timestamp_value is not None:
try: try:
# Handle various timestamp formats # Handle various timestamp formats
interaction_time = None interaction_time = None
if isinstance(timestamp_value, datetime): if isinstance(timestamp_value, datetime):
# Already a Python datetime # Already a Python datetime
interaction_time = timestamp_value interaction_time = timestamp_value
@ -81,24 +83,24 @@ async def extract_usage_frequency(
else: else:
# ISO format string # ISO format string
interaction_time = datetime.fromisoformat(timestamp_value) interaction_time = datetime.fromisoformat(timestamp_value)
elif hasattr(timestamp_value, 'to_native'): elif hasattr(timestamp_value, "to_native"):
# Neo4j datetime object - convert to Python datetime # Neo4j datetime object - convert to Python datetime
interaction_time = timestamp_value.to_native() interaction_time = timestamp_value.to_native()
elif hasattr(timestamp_value, 'year') and hasattr(timestamp_value, 'month'): elif hasattr(timestamp_value, "year") and hasattr(timestamp_value, "month"):
# Datetime-like object - extract components # Datetime-like object - extract components
try: try:
interaction_time = datetime( interaction_time = datetime(
year=timestamp_value.year, year=timestamp_value.year,
month=timestamp_value.month, month=timestamp_value.month,
day=timestamp_value.day, day=timestamp_value.day,
hour=getattr(timestamp_value, 'hour', 0), hour=getattr(timestamp_value, "hour", 0),
minute=getattr(timestamp_value, 'minute', 0), minute=getattr(timestamp_value, "minute", 0),
second=getattr(timestamp_value, 'second', 0), second=getattr(timestamp_value, "second", 0),
microsecond=getattr(timestamp_value, 'microsecond', 0) microsecond=getattr(timestamp_value, "microsecond", 0),
) )
except (AttributeError, ValueError): except (AttributeError, ValueError):
pass pass
if interaction_time is None: if interaction_time is None:
# Last resort: try converting to string and parsing # Last resort: try converting to string and parsing
str_value = str(timestamp_value) str_value = str(timestamp_value)
@ -110,73 +112,83 @@ async def extract_usage_frequency(
interaction_time = datetime.fromtimestamp(ts_int) interaction_time = datetime.fromtimestamp(ts_int)
else: else:
interaction_time = datetime.fromisoformat(str_value) interaction_time = datetime.fromisoformat(str_value)
if interaction_time is None: if interaction_time is None:
raise ValueError(f"Could not parse timestamp: {timestamp_value}") raise ValueError(f"Could not parse timestamp: {timestamp_value}")
# Make sure it's timezone-naive for comparison # Make sure it's timezone-naive for comparison
if interaction_time.tzinfo is not None: if interaction_time.tzinfo is not None:
interaction_time = interaction_time.replace(tzinfo=None) interaction_time = interaction_time.replace(tzinfo=None)
interaction_nodes[node_id] = { interaction_nodes[node_id] = {
'node': node, "node": node,
'timestamp': interaction_time, "timestamp": interaction_time,
'in_window': interaction_time >= cutoff_time "in_window": interaction_time >= cutoff_time,
} }
interaction_count += 1 interaction_count += 1
if interaction_time >= cutoff_time: if interaction_time >= cutoff_time:
interactions_in_window += 1 interactions_in_window += 1
except (ValueError, TypeError, AttributeError, OSError) as e: except (ValueError, TypeError, AttributeError, OSError) as e:
logger.warning(f"Failed to parse timestamp for interaction node {node_id}: {e}") logger.warning(
logger.debug(f"Timestamp value type: {type(timestamp_value)}, value: {timestamp_value}") f"Failed to parse timestamp for interaction node {node_id}: {e}"
)
logger.debug(
f"Timestamp value type: {type(timestamp_value)}, value: {timestamp_value}"
)
# Process edges to find graph elements used in interactions # Process edges to find graph elements used in interactions
for edge in subgraph.edges: for edge in subgraph.edges:
relationship_type = edge.attributes.get('relationship_type') relationship_type = edge.attributes.get("relationship_type")
# Look for 'used_graph_element_to_answer' edges # Look for 'used_graph_element_to_answer' edges
if relationship_type == 'used_graph_element_to_answer': if relationship_type == "used_graph_element_to_answer":
# node1 should be the CogneeUserInteraction, node2 is the graph element # node1 should be the CogneeUserInteraction, node2 is the graph element
source_id = str(edge.node1.id) source_id = str(edge.node1.id)
target_id = str(edge.node2.id) target_id = str(edge.node2.id)
# Check if source is an interaction node in our time window # Check if source is an interaction node in our time window
if source_id in interaction_nodes: if source_id in interaction_nodes:
interaction_data = interaction_nodes[source_id] interaction_data = interaction_nodes[source_id]
if interaction_data['in_window']: if interaction_data["in_window"]:
# Count the graph element (target node) being used # Count the graph element (target node) being used
node_frequencies[target_id] = node_frequencies.get(target_id, 0) + 1 node_frequencies[target_id] = node_frequencies.get(target_id, 0) + 1
# Also track what type of element it is for analytics # Also track what type of element it is for analytics
target_node = subgraph.get_node(target_id) target_node = subgraph.get_node(target_id)
if target_node: if target_node:
element_type = target_node.attributes.get('type') or target_node.attributes.get('node_type') element_type = target_node.attributes.get(
"type"
) or target_node.attributes.get("node_type")
if element_type: if element_type:
relationship_type_frequencies[element_type] = relationship_type_frequencies.get(element_type, 0) + 1 relationship_type_frequencies[element_type] = (
relationship_type_frequencies.get(element_type, 0) + 1
)
# Also track general edge usage patterns # Also track general edge usage patterns
elif relationship_type and relationship_type != 'used_graph_element_to_answer': elif relationship_type and relationship_type != "used_graph_element_to_answer":
# Check if either endpoint is referenced in a recent interaction # Check if either endpoint is referenced in a recent interaction
source_id = str(edge.node1.id) source_id = str(edge.node1.id)
target_id = str(edge.node2.id) target_id = str(edge.node2.id)
# If this edge connects to any frequently accessed nodes, track the edge type # If this edge connects to any frequently accessed nodes, track the edge type
if source_id in node_frequencies or target_id in node_frequencies: if source_id in node_frequencies or target_id in node_frequencies:
edge_key = f"{relationship_type}:{source_id}:{target_id}" edge_key = f"{relationship_type}:{source_id}:{target_id}"
edge_frequencies[edge_key] = edge_frequencies.get(edge_key, 0) + 1 edge_frequencies[edge_key] = edge_frequencies.get(edge_key, 0) + 1
# Filter frequencies above threshold # Filter frequencies above threshold
filtered_node_frequencies = { filtered_node_frequencies = {
node_id: freq for node_id, freq in node_frequencies.items() node_id: freq
for node_id, freq in node_frequencies.items()
if freq >= min_interaction_threshold if freq >= min_interaction_threshold
} }
filtered_edge_frequencies = { filtered_edge_frequencies = {
edge_key: freq for edge_key, freq in edge_frequencies.items() edge_key: freq
for edge_key, freq in edge_frequencies.items()
if freq >= min_interaction_threshold if freq >= min_interaction_threshold
} }
logger.info( logger.info(
f"Processed {interactions_in_window}/{interaction_count} interactions in time window" f"Processed {interactions_in_window}/{interaction_count} interactions in time window"
) )
@ -185,58 +197,59 @@ async def extract_usage_frequency(
f"above threshold (min: {min_interaction_threshold})" f"above threshold (min: {min_interaction_threshold})"
) )
logger.info(f"Element type distribution: {relationship_type_frequencies}") logger.info(f"Element type distribution: {relationship_type_frequencies}")
return { return {
'node_frequencies': filtered_node_frequencies, "node_frequencies": filtered_node_frequencies,
'edge_frequencies': filtered_edge_frequencies, "edge_frequencies": filtered_edge_frequencies,
'element_type_frequencies': relationship_type_frequencies, "element_type_frequencies": relationship_type_frequencies,
'total_interactions': interaction_count, "total_interactions": interaction_count,
'interactions_in_window': interactions_in_window, "interactions_in_window": interactions_in_window,
'time_window_days': time_window.days, "time_window_days": time_window.days,
'last_processed_timestamp': current_time.isoformat(), "last_processed_timestamp": current_time.isoformat(),
'cutoff_timestamp': cutoff_time.isoformat() "cutoff_timestamp": cutoff_time.isoformat(),
} }
async def add_frequency_weights( async def add_frequency_weights(
graph_adapter: GraphDBInterface, graph_adapter: GraphDBInterface, usage_frequencies: Dict[str, Any]
usage_frequencies: Dict[str, Any]
) -> None: ) -> None:
""" """
Add frequency weights to graph nodes and edges using the graph adapter. Add frequency weights to graph nodes and edges using the graph adapter.
Uses direct Cypher queries for Neo4j adapter compatibility. Uses direct Cypher queries for Neo4j adapter compatibility.
Writes frequency_weight properties back to the graph for use in: Writes frequency_weight properties back to the graph for use in:
- Ranking frequently referenced entities higher during retrieval - Ranking frequently referenced entities higher during retrieval
- Adjusting scoring for completion strategies - Adjusting scoring for completion strategies
- Exposing usage metrics in dashboards or audits - Exposing usage metrics in dashboards or audits
:param graph_adapter: Graph database adapter interface :param graph_adapter: Graph database adapter interface
:param usage_frequencies: Calculated usage frequencies from extract_usage_frequency :param usage_frequencies: Calculated usage frequencies from extract_usage_frequency
""" """
node_frequencies = usage_frequencies.get('node_frequencies', {}) node_frequencies = usage_frequencies.get("node_frequencies", {})
edge_frequencies = usage_frequencies.get('edge_frequencies', {}) edge_frequencies = usage_frequencies.get("edge_frequencies", {})
logger.info(f"Adding frequency weights to {len(node_frequencies)} nodes") logger.info(f"Adding frequency weights to {len(node_frequencies)} nodes")
# Check adapter type and use appropriate method # Check adapter type and use appropriate method
adapter_type = type(graph_adapter).__name__ adapter_type = type(graph_adapter).__name__
logger.info(f"Using adapter: {adapter_type}") logger.info(f"Using adapter: {adapter_type}")
nodes_updated = 0 nodes_updated = 0
nodes_failed = 0 nodes_failed = 0
# Determine which method to use based on adapter type # Determine which method to use based on adapter type
use_neo4j_cypher = adapter_type == 'Neo4jAdapter' and hasattr(graph_adapter, 'query') use_neo4j_cypher = adapter_type == "Neo4jAdapter" and hasattr(graph_adapter, "query")
use_kuzu_query = adapter_type == 'KuzuAdapter' and hasattr(graph_adapter, 'query') use_kuzu_query = adapter_type == "KuzuAdapter" and hasattr(graph_adapter, "query")
use_get_update = hasattr(graph_adapter, 'get_node_by_id') and hasattr(graph_adapter, 'update_node_properties') use_get_update = hasattr(graph_adapter, "get_node_by_id") and hasattr(
graph_adapter, "update_node_properties"
)
# Method 1: Neo4j Cypher with SET (creates properties on the fly) # Method 1: Neo4j Cypher with SET (creates properties on the fly)
if use_neo4j_cypher: if use_neo4j_cypher:
try: try:
logger.info("Using Neo4j Cypher SET method") logger.info("Using Neo4j Cypher SET method")
last_updated = usage_frequencies.get('last_processed_timestamp') last_updated = usage_frequencies.get("last_processed_timestamp")
for node_id, frequency in node_frequencies.items(): for node_id, frequency in node_frequencies.items():
try: try:
query = """ query = """
@ -246,47 +259,49 @@ async def add_frequency_weights(
n.frequency_updated_at = $updated_at n.frequency_updated_at = $updated_at
RETURN n.id as id RETURN n.id as id
""" """
result = await graph_adapter.query( result = await graph_adapter.query(
query, query,
params={ params={
'node_id': node_id, "node_id": node_id,
'frequency': frequency, "frequency": frequency,
'updated_at': last_updated "updated_at": last_updated,
} },
) )
if result and len(result) > 0: if result and len(result) > 0:
nodes_updated += 1 nodes_updated += 1
else: else:
logger.warning(f"Node {node_id} not found or not updated") logger.warning(f"Node {node_id} not found or not updated")
nodes_failed += 1 nodes_failed += 1
except Exception as e: except Exception as e:
logger.error(f"Error updating node {node_id}: {e}") logger.error(f"Error updating node {node_id}: {e}")
nodes_failed += 1 nodes_failed += 1
logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed") logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed")
except Exception as e: except Exception as e:
logger.error(f"Neo4j Cypher update failed: {e}") logger.error(f"Neo4j Cypher update failed: {e}")
use_neo4j_cypher = False use_neo4j_cypher = False
# Method 2: Kuzu - use get_node + add_node (updates via re-adding with same ID) # Method 2: Kuzu - use get_node + add_node (updates via re-adding with same ID)
elif use_kuzu_query and hasattr(graph_adapter, 'get_node') and hasattr(graph_adapter, 'add_node'): elif (
use_kuzu_query and hasattr(graph_adapter, "get_node") and hasattr(graph_adapter, "add_node")
):
logger.info("Using Kuzu get_node + add_node method") logger.info("Using Kuzu get_node + add_node method")
last_updated = usage_frequencies.get('last_processed_timestamp') last_updated = usage_frequencies.get("last_processed_timestamp")
for node_id, frequency in node_frequencies.items(): for node_id, frequency in node_frequencies.items():
try: try:
# Get the existing node (returns a dict) # Get the existing node (returns a dict)
existing_node_dict = await graph_adapter.get_node(node_id) existing_node_dict = await graph_adapter.get_node(node_id)
if existing_node_dict: if existing_node_dict:
# Update the dict with new properties # Update the dict with new properties
existing_node_dict['frequency_weight'] = frequency existing_node_dict["frequency_weight"] = frequency
existing_node_dict['frequency_updated_at'] = last_updated existing_node_dict["frequency_updated_at"] = last_updated
# Kuzu's add_node likely just takes the dict directly, not a Node object # Kuzu's add_node likely just takes the dict directly, not a Node object
# Try passing the dict directly first # Try passing the dict directly first
try: try:
@ -295,20 +310,21 @@ async def add_frequency_weights(
except Exception as dict_error: except Exception as dict_error:
# If dict doesn't work, try creating a Node object # If dict doesn't work, try creating a Node object
logger.debug(f"Dict add failed, trying Node object: {dict_error}") logger.debug(f"Dict add failed, trying Node object: {dict_error}")
try: try:
from cognee.infrastructure.engine import Node from cognee.infrastructure.engine import Node
# Try different Node constructor patterns # Try different Node constructor patterns
try: try:
# Pattern 1: Just properties # Pattern 1: Just properties
node_obj = Node(existing_node_dict) node_obj = Node(existing_node_dict)
except: except Exception:
# Pattern 2: Type and properties # Pattern 2: Type and properties
node_obj = Node( node_obj = Node(
type=existing_node_dict.get('type', 'Unknown'), type=existing_node_dict.get("type", "Unknown"),
**existing_node_dict **existing_node_dict,
) )
await graph_adapter.add_node(node_obj) await graph_adapter.add_node(node_obj)
nodes_updated += 1 nodes_updated += 1
except Exception as node_error: except Exception as node_error:
@ -317,13 +333,13 @@ async def add_frequency_weights(
else: else:
logger.warning(f"Node {node_id} not found in graph") logger.warning(f"Node {node_id} not found in graph")
nodes_failed += 1 nodes_failed += 1
except Exception as e: except Exception as e:
logger.error(f"Error updating node {node_id}: {e}") logger.error(f"Error updating node {node_id}: {e}")
nodes_failed += 1 nodes_failed += 1
logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed") logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed")
# Method 3: Generic get_node_by_id + update_node_properties # Method 3: Generic get_node_by_id + update_node_properties
elif use_get_update: elif use_get_update:
logger.info("Using get/update method for adapter") logger.info("Using get/update method for adapter")
@ -331,90 +347,95 @@ async def add_frequency_weights(
try: try:
# Get current node data # Get current node data
node_data = await graph_adapter.get_node_by_id(node_id) node_data = await graph_adapter.get_node_by_id(node_id)
if node_data: if node_data:
# Tweak the properties dict - add frequency_weight # Tweak the properties dict - add frequency_weight
if isinstance(node_data, dict): if isinstance(node_data, dict):
properties = node_data.get('properties', {}) properties = node_data.get("properties", {})
else: else:
properties = getattr(node_data, 'properties', {}) or {} properties = getattr(node_data, "properties", {}) or {}
# Update with frequency weight # Update with frequency weight
properties['frequency_weight'] = frequency properties["frequency_weight"] = frequency
properties['frequency_updated_at'] = usage_frequencies.get('last_processed_timestamp') properties["frequency_updated_at"] = usage_frequencies.get(
"last_processed_timestamp"
)
# Write back via adapter # Write back via adapter
await graph_adapter.update_node_properties(node_id, properties) await graph_adapter.update_node_properties(node_id, properties)
nodes_updated += 1 nodes_updated += 1
else: else:
logger.warning(f"Node {node_id} not found in graph") logger.warning(f"Node {node_id} not found in graph")
nodes_failed += 1 nodes_failed += 1
except Exception as e: except Exception as e:
logger.error(f"Error updating node {node_id}: {e}") logger.error(f"Error updating node {node_id}: {e}")
nodes_failed += 1 nodes_failed += 1
logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed") logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed")
for node_id, frequency in node_frequencies.items(): for node_id, frequency in node_frequencies.items():
try: try:
# Get current node data # Get current node data
node_data = await graph_adapter.get_node_by_id(node_id) node_data = await graph_adapter.get_node_by_id(node_id)
if node_data: if node_data:
# Tweak the properties dict - add frequency_weight # Tweak the properties dict - add frequency_weight
if isinstance(node_data, dict): if isinstance(node_data, dict):
properties = node_data.get('properties', {}) properties = node_data.get("properties", {})
else: else:
properties = getattr(node_data, 'properties', {}) or {} properties = getattr(node_data, "properties", {}) or {}
# Update with frequency weight # Update with frequency weight
properties['frequency_weight'] = frequency properties["frequency_weight"] = frequency
properties['frequency_updated_at'] = usage_frequencies.get('last_processed_timestamp') properties["frequency_updated_at"] = usage_frequencies.get(
"last_processed_timestamp"
)
# Write back via adapter # Write back via adapter
await graph_adapter.update_node_properties(node_id, properties) await graph_adapter.update_node_properties(node_id, properties)
nodes_updated += 1 nodes_updated += 1
else: else:
logger.warning(f"Node {node_id} not found in graph") logger.warning(f"Node {node_id} not found in graph")
nodes_failed += 1 nodes_failed += 1
except Exception as e: except Exception as e:
logger.error(f"Error updating node {node_id}: {e}") logger.error(f"Error updating node {node_id}: {e}")
nodes_failed += 1 nodes_failed += 1
# If no method is available # If no method is available
if not use_neo4j_cypher and not use_kuzu_query and not use_get_update: if not use_neo4j_cypher and not use_kuzu_query and not use_get_update:
logger.error(f"Adapter {adapter_type} does not support required update methods") logger.error(f"Adapter {adapter_type} does not support required update methods")
logger.error("Required: either 'query' method or both 'get_node_by_id' and 'update_node_properties'") logger.error(
"Required: either 'query' method or both 'get_node_by_id' and 'update_node_properties'"
)
return return
# Update edge frequencies # Update edge frequencies
# Note: Edge property updates are backend-specific # Note: Edge property updates are backend-specific
if edge_frequencies: if edge_frequencies:
logger.info(f"Processing {len(edge_frequencies)} edge frequency entries") logger.info(f"Processing {len(edge_frequencies)} edge frequency entries")
edges_updated = 0 edges_updated = 0
edges_failed = 0 edges_failed = 0
for edge_key, frequency in edge_frequencies.items(): for edge_key, frequency in edge_frequencies.items():
try: try:
# Parse edge key: "relationship_type:source_id:target_id" # Parse edge key: "relationship_type:source_id:target_id"
parts = edge_key.split(':', 2) parts = edge_key.split(":", 2)
if len(parts) == 3: if len(parts) == 3:
relationship_type, source_id, target_id = parts relationship_type, source_id, target_id = parts
# Try to update edge if adapter supports it # Try to update edge if adapter supports it
if hasattr(graph_adapter, 'update_edge_properties'): if hasattr(graph_adapter, "update_edge_properties"):
edge_properties = { edge_properties = {
'frequency_weight': frequency, "frequency_weight": frequency,
'frequency_updated_at': usage_frequencies.get('last_processed_timestamp') "frequency_updated_at": usage_frequencies.get(
"last_processed_timestamp"
),
} }
await graph_adapter.update_edge_properties( await graph_adapter.update_edge_properties(
source_id, source_id, target_id, relationship_type, edge_properties
target_id,
relationship_type,
edge_properties
) )
edges_updated += 1 edges_updated += 1
else: else:
@ -423,28 +444,28 @@ async def add_frequency_weights(
f"Adapter doesn't support update_edge_properties for " f"Adapter doesn't support update_edge_properties for "
f"{relationship_type} ({source_id} -> {target_id})" f"{relationship_type} ({source_id} -> {target_id})"
) )
except Exception as e: except Exception as e:
logger.error(f"Error updating edge {edge_key}: {e}") logger.error(f"Error updating edge {edge_key}: {e}")
edges_failed += 1 edges_failed += 1
if edges_updated > 0: if edges_updated > 0:
logger.info(f"Edge update complete: {edges_updated} succeeded, {edges_failed} failed") logger.info(f"Edge update complete: {edges_updated} succeeded, {edges_failed} failed")
else: else:
logger.info( logger.info(
"Edge frequency updates skipped (adapter may not support edge property updates)" "Edge frequency updates skipped (adapter may not support edge property updates)"
) )
# Store aggregate statistics as metadata if supported # Store aggregate statistics as metadata if supported
if hasattr(graph_adapter, 'set_metadata'): if hasattr(graph_adapter, "set_metadata"):
try: try:
metadata = { metadata = {
'element_type_frequencies': usage_frequencies.get('element_type_frequencies', {}), "element_type_frequencies": usage_frequencies.get("element_type_frequencies", {}),
'total_interactions': usage_frequencies.get('total_interactions', 0), "total_interactions": usage_frequencies.get("total_interactions", 0),
'interactions_in_window': usage_frequencies.get('interactions_in_window', 0), "interactions_in_window": usage_frequencies.get("interactions_in_window", 0),
'last_frequency_update': usage_frequencies.get('last_processed_timestamp') "last_frequency_update": usage_frequencies.get("last_processed_timestamp"),
} }
await graph_adapter.set_metadata('usage_frequency_stats', metadata) await graph_adapter.set_metadata("usage_frequency_stats", metadata)
logger.info("Stored usage frequency statistics as metadata") logger.info("Stored usage frequency statistics as metadata")
except Exception as e: except Exception as e:
logger.warning(f"Could not store usage statistics as metadata: {e}") logger.warning(f"Could not store usage statistics as metadata: {e}")
@ -454,25 +475,25 @@ async def create_usage_frequency_pipeline(
graph_adapter: GraphDBInterface, graph_adapter: GraphDBInterface,
time_window: timedelta = timedelta(days=7), time_window: timedelta = timedelta(days=7),
min_interaction_threshold: int = 1, min_interaction_threshold: int = 1,
batch_size: int = 100 batch_size: int = 100,
) -> tuple: ) -> tuple:
""" """
Create memify pipeline entry for usage frequency tracking. Create memify pipeline entry for usage frequency tracking.
This follows the same pattern as feedback enrichment flows, allowing This follows the same pattern as feedback enrichment flows, allowing
the frequency update to run end-to-end in a custom memify pipeline. the frequency update to run end-to-end in a custom memify pipeline.
Use case example: Use case example:
extraction_tasks, enrichment_tasks = await create_usage_frequency_pipeline( extraction_tasks, enrichment_tasks = await create_usage_frequency_pipeline(
graph_adapter=my_adapter, graph_adapter=my_adapter,
time_window=timedelta(days=30), time_window=timedelta(days=30),
min_interaction_threshold=2 min_interaction_threshold=2
) )
# Run in memify pipeline # Run in memify pipeline
pipeline = Pipeline(extraction_tasks + enrichment_tasks) pipeline = Pipeline(extraction_tasks + enrichment_tasks)
results = await pipeline.run() results = await pipeline.run()
:param graph_adapter: Graph database adapter :param graph_adapter: Graph database adapter
:param time_window: Time window for counting interactions (default: 7 days) :param time_window: Time window for counting interactions (default: 7 days)
:param min_interaction_threshold: Minimum interactions to track (default: 1) :param min_interaction_threshold: Minimum interactions to track (default: 1)
@ -481,23 +502,23 @@ async def create_usage_frequency_pipeline(
""" """
logger.info("Creating usage frequency pipeline") logger.info("Creating usage frequency pipeline")
logger.info(f"Config: time_window={time_window}, threshold={min_interaction_threshold}") logger.info(f"Config: time_window={time_window}, threshold={min_interaction_threshold}")
extraction_tasks = [ extraction_tasks = [
Task( Task(
extract_usage_frequency, extract_usage_frequency,
time_window=time_window, time_window=time_window,
min_interaction_threshold=min_interaction_threshold min_interaction_threshold=min_interaction_threshold,
) )
] ]
enrichment_tasks = [ enrichment_tasks = [
Task( Task(
add_frequency_weights, add_frequency_weights,
graph_adapter=graph_adapter, graph_adapter=graph_adapter,
task_config={"batch_size": batch_size} task_config={"batch_size": batch_size},
) )
] ]
return extraction_tasks, enrichment_tasks return extraction_tasks, enrichment_tasks
@ -505,21 +526,21 @@ async def run_usage_frequency_update(
graph_adapter: GraphDBInterface, graph_adapter: GraphDBInterface,
subgraphs: List[CogneeGraph], subgraphs: List[CogneeGraph],
time_window: timedelta = timedelta(days=7), time_window: timedelta = timedelta(days=7),
min_interaction_threshold: int = 1 min_interaction_threshold: int = 1,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Convenience function to run the complete usage frequency update pipeline. Convenience function to run the complete usage frequency update pipeline.
This is the main entry point for updating frequency weights on graph elements This is the main entry point for updating frequency weights on graph elements
based on CogneeUserInteraction data from cognee.search(save_interaction=True). based on CogneeUserInteraction data from cognee.search(save_interaction=True).
Example usage: Example usage:
# After running searches with save_interaction=True # After running searches with save_interaction=True
from cognee.tasks.memify.extract_usage_frequency import run_usage_frequency_update from cognee.tasks.memify.extract_usage_frequency import run_usage_frequency_update
# Get the graph with interactions # Get the graph with interactions
graph = await get_cognee_graph_with_interactions() graph = await get_cognee_graph_with_interactions()
# Update frequency weights # Update frequency weights
stats = await run_usage_frequency_update( stats = await run_usage_frequency_update(
graph_adapter=graph_adapter, graph_adapter=graph_adapter,
@ -527,9 +548,9 @@ async def run_usage_frequency_update(
time_window=timedelta(days=30), # Last 30 days time_window=timedelta(days=30), # Last 30 days
min_interaction_threshold=2 # At least 2 uses min_interaction_threshold=2 # At least 2 uses
) )
print(f"Updated {len(stats['node_frequencies'])} nodes") print(f"Updated {len(stats['node_frequencies'])} nodes")
:param graph_adapter: Graph database adapter :param graph_adapter: Graph database adapter
:param subgraphs: List of CogneeGraph instances with interaction data :param subgraphs: List of CogneeGraph instances with interaction data
:param time_window: Time window for counting interactions :param time_window: Time window for counting interactions
@ -537,51 +558,48 @@ async def run_usage_frequency_update(
:return: Usage frequency statistics :return: Usage frequency statistics
""" """
logger.info("Starting usage frequency update") logger.info("Starting usage frequency update")
try: try:
# Extract frequencies from interaction data # Extract frequencies from interaction data
usage_frequencies = await extract_usage_frequency( usage_frequencies = await extract_usage_frequency(
subgraphs=subgraphs, subgraphs=subgraphs,
time_window=time_window, time_window=time_window,
min_interaction_threshold=min_interaction_threshold min_interaction_threshold=min_interaction_threshold,
) )
# Add frequency weights back to the graph # Add frequency weights back to the graph
await add_frequency_weights( await add_frequency_weights(
graph_adapter=graph_adapter, graph_adapter=graph_adapter, usage_frequencies=usage_frequencies
usage_frequencies=usage_frequencies
) )
logger.info("Usage frequency update completed successfully") logger.info("Usage frequency update completed successfully")
logger.info( logger.info(
f"Summary: {usage_frequencies['interactions_in_window']} interactions processed, " f"Summary: {usage_frequencies['interactions_in_window']} interactions processed, "
f"{len(usage_frequencies['node_frequencies'])} nodes weighted" f"{len(usage_frequencies['node_frequencies'])} nodes weighted"
) )
return usage_frequencies return usage_frequencies
except Exception as e: except Exception as e:
logger.error(f"Error during usage frequency update: {str(e)}") logger.error(f"Error during usage frequency update: {str(e)}")
raise raise
async def get_most_frequent_elements( async def get_most_frequent_elements(
graph_adapter: GraphDBInterface, graph_adapter: GraphDBInterface, top_n: int = 10, element_type: Optional[str] = None
top_n: int = 10,
element_type: Optional[str] = None
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
Retrieve the most frequently accessed graph elements. Retrieve the most frequently accessed graph elements.
Useful for analytics dashboards and understanding user behavior. Useful for analytics dashboards and understanding user behavior.
:param graph_adapter: Graph database adapter :param graph_adapter: Graph database adapter
:param top_n: Number of top elements to return :param top_n: Number of top elements to return
:param element_type: Optional filter by element type :param element_type: Optional filter by element type
:return: List of elements with their frequency weights :return: List of elements with their frequency weights
""" """
logger.info(f"Retrieving top {top_n} most frequent elements") logger.info(f"Retrieving top {top_n} most frequent elements")
# This would need to be implemented based on the specific graph adapter's query capabilities # This would need to be implemented based on the specific graph adapter's query capabilities
# Pseudocode: # Pseudocode:
# results = await graph_adapter.query_nodes_by_property( # results = await graph_adapter.query_nodes_by_property(
@ -590,6 +608,6 @@ async def get_most_frequent_elements(
# limit=top_n, # limit=top_n,
# filters={'type': element_type} if element_type else None # filters={'type': element_type} if element_type else None
# ) # )
logger.warning("get_most_frequent_elements needs adapter-specific implementation") logger.warning("get_most_frequent_elements needs adapter-specific implementation")
return [] return []

View file

@ -6,7 +6,7 @@ Tests cover extraction logic, adapter integration, edge cases, and end-to-end wo
Run with: Run with:
pytest test_usage_frequency_comprehensive.py -v pytest test_usage_frequency_comprehensive.py -v
Or without pytest: Or without pytest:
python test_usage_frequency_comprehensive.py python test_usage_frequency_comprehensive.py
""" """
@ -23,8 +23,9 @@ try:
from cognee.tasks.memify.extract_usage_frequency import ( from cognee.tasks.memify.extract_usage_frequency import (
extract_usage_frequency, extract_usage_frequency,
add_frequency_weights, add_frequency_weights,
run_usage_frequency_update run_usage_frequency_update,
) )
COGNEE_AVAILABLE = True COGNEE_AVAILABLE = True
except ImportError: except ImportError:
COGNEE_AVAILABLE = False COGNEE_AVAILABLE = False
@ -33,16 +34,16 @@ except ImportError:
class TestUsageFrequencyExtraction(unittest.TestCase): class TestUsageFrequencyExtraction(unittest.TestCase):
"""Test the core frequency extraction logic.""" """Test the core frequency extraction logic."""
def setUp(self): def setUp(self):
"""Set up test fixtures.""" """Set up test fixtures."""
if not COGNEE_AVAILABLE: if not COGNEE_AVAILABLE:
self.skipTest("Cognee modules not available") self.skipTest("Cognee modules not available")
def create_mock_graph(self, num_interactions: int = 3, num_elements: int = 5): def create_mock_graph(self, num_interactions: int = 3, num_elements: int = 5):
"""Create a mock graph with interactions and elements.""" """Create a mock graph with interactions and elements."""
graph = CogneeGraph() graph = CogneeGraph()
# Create interaction nodes # Create interaction nodes
current_time = datetime.now() current_time = datetime.now()
for i in range(num_interactions): for i in range(num_interactions):
@ -50,25 +51,22 @@ class TestUsageFrequencyExtraction(unittest.TestCase):
id=f"interaction_{i}", id=f"interaction_{i}",
node_type="CogneeUserInteraction", node_type="CogneeUserInteraction",
attributes={ attributes={
'type': 'CogneeUserInteraction', "type": "CogneeUserInteraction",
'query_text': f'Test query {i}', "query_text": f"Test query {i}",
'timestamp': int((current_time - timedelta(hours=i)).timestamp() * 1000) "timestamp": int((current_time - timedelta(hours=i)).timestamp() * 1000),
} },
) )
graph.add_node(interaction_node) graph.add_node(interaction_node)
# Create graph element nodes # Create graph element nodes
for i in range(num_elements): for i in range(num_elements):
element_node = Node( element_node = Node(
id=f"element_{i}", id=f"element_{i}",
node_type="DocumentChunk", node_type="DocumentChunk",
attributes={ attributes={"type": "DocumentChunk", "text": f"Element content {i}"},
'type': 'DocumentChunk',
'text': f'Element content {i}'
}
) )
graph.add_node(element_node) graph.add_node(element_node)
# Create usage edges (interactions reference elements) # Create usage edges (interactions reference elements)
for i in range(num_interactions): for i in range(num_interactions):
# Each interaction uses 2-3 elements # Each interaction uses 2-3 elements
@ -78,183 +76,179 @@ class TestUsageFrequencyExtraction(unittest.TestCase):
node1=graph.get_node(f"interaction_{i}"), node1=graph.get_node(f"interaction_{i}"),
node2=graph.get_node(f"element_{element_idx}"), node2=graph.get_node(f"element_{element_idx}"),
edge_type="used_graph_element_to_answer", edge_type="used_graph_element_to_answer",
attributes={'relationship_type': 'used_graph_element_to_answer'} attributes={"relationship_type": "used_graph_element_to_answer"},
) )
graph.add_edge(edge) graph.add_edge(edge)
return graph return graph
async def test_basic_frequency_extraction(self): async def test_basic_frequency_extraction(self):
"""Test basic frequency extraction with simple graph.""" """Test basic frequency extraction with simple graph."""
graph = self.create_mock_graph(num_interactions=3, num_elements=5) graph = self.create_mock_graph(num_interactions=3, num_elements=5)
result = await extract_usage_frequency( result = await extract_usage_frequency(
subgraphs=[graph], subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=1
time_window=timedelta(days=7),
min_interaction_threshold=1
) )
self.assertIn('node_frequencies', result) self.assertIn("node_frequencies", result)
self.assertIn('total_interactions', result) self.assertIn("total_interactions", result)
self.assertEqual(result['total_interactions'], 3) self.assertEqual(result["total_interactions"], 3)
self.assertGreater(len(result['node_frequencies']), 0) self.assertGreater(len(result["node_frequencies"]), 0)
async def test_time_window_filtering(self): async def test_time_window_filtering(self):
"""Test that time window correctly filters old interactions.""" """Test that time window correctly filters old interactions."""
graph = CogneeGraph() graph = CogneeGraph()
current_time = datetime.now() current_time = datetime.now()
# Add recent interaction (within window) # Add recent interaction (within window)
recent_node = Node( recent_node = Node(
id="recent_interaction", id="recent_interaction",
node_type="CogneeUserInteraction", node_type="CogneeUserInteraction",
attributes={ attributes={
'type': 'CogneeUserInteraction', "type": "CogneeUserInteraction",
'timestamp': int(current_time.timestamp() * 1000) "timestamp": int(current_time.timestamp() * 1000),
} },
) )
graph.add_node(recent_node) graph.add_node(recent_node)
# Add old interaction (outside window) # Add old interaction (outside window)
old_node = Node( old_node = Node(
id="old_interaction", id="old_interaction",
node_type="CogneeUserInteraction", node_type="CogneeUserInteraction",
attributes={ attributes={
'type': 'CogneeUserInteraction', "type": "CogneeUserInteraction",
'timestamp': int((current_time - timedelta(days=10)).timestamp() * 1000) "timestamp": int((current_time - timedelta(days=10)).timestamp() * 1000),
} },
) )
graph.add_node(old_node) graph.add_node(old_node)
# Add element # Add element
element = Node(id="element_1", node_type="DocumentChunk", attributes={'type': 'DocumentChunk'}) element = Node(
id="element_1", node_type="DocumentChunk", attributes={"type": "DocumentChunk"}
)
graph.add_node(element) graph.add_node(element)
# Add edges # Add edges
graph.add_edge(Edge( graph.add_edge(
node1=recent_node, node2=element, Edge(
edge_type="used_graph_element_to_answer", node1=recent_node,
attributes={'relationship_type': 'used_graph_element_to_answer'} node2=element,
)) edge_type="used_graph_element_to_answer",
graph.add_edge(Edge( attributes={"relationship_type": "used_graph_element_to_answer"},
node1=old_node, node2=element, )
edge_type="used_graph_element_to_answer", )
attributes={'relationship_type': 'used_graph_element_to_answer'} graph.add_edge(
)) Edge(
node1=old_node,
node2=element,
edge_type="used_graph_element_to_answer",
attributes={"relationship_type": "used_graph_element_to_answer"},
)
)
# Extract with 7-day window # Extract with 7-day window
result = await extract_usage_frequency( result = await extract_usage_frequency(
subgraphs=[graph], subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=1
time_window=timedelta(days=7),
min_interaction_threshold=1
) )
# Should only count recent interaction # Should only count recent interaction
self.assertEqual(result['interactions_in_window'], 1) self.assertEqual(result["interactions_in_window"], 1)
self.assertEqual(result['total_interactions'], 2) self.assertEqual(result["total_interactions"], 2)
async def test_threshold_filtering(self): async def test_threshold_filtering(self):
"""Test that minimum threshold filters low-frequency nodes.""" """Test that minimum threshold filters low-frequency nodes."""
graph = self.create_mock_graph(num_interactions=5, num_elements=10) graph = self.create_mock_graph(num_interactions=5, num_elements=10)
# Extract with threshold of 3 # Extract with threshold of 3
result = await extract_usage_frequency( result = await extract_usage_frequency(
subgraphs=[graph], subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=3
time_window=timedelta(days=7),
min_interaction_threshold=3
) )
# Only nodes with 3+ accesses should be included # Only nodes with 3+ accesses should be included
for node_id, freq in result['node_frequencies'].items(): for node_id, freq in result["node_frequencies"].items():
self.assertGreaterEqual(freq, 3) self.assertGreaterEqual(freq, 3)
async def test_element_type_tracking(self): async def test_element_type_tracking(self):
"""Test that element types are properly tracked.""" """Test that element types are properly tracked."""
graph = CogneeGraph() graph = CogneeGraph()
# Create interaction # Create interaction
interaction = Node( interaction = Node(
id="interaction_1", id="interaction_1",
node_type="CogneeUserInteraction", node_type="CogneeUserInteraction",
attributes={ attributes={
'type': 'CogneeUserInteraction', "type": "CogneeUserInteraction",
'timestamp': int(datetime.now().timestamp() * 1000) "timestamp": int(datetime.now().timestamp() * 1000),
} },
) )
graph.add_node(interaction) graph.add_node(interaction)
# Create elements of different types # Create elements of different types
chunk = Node(id="chunk_1", node_type="DocumentChunk", attributes={'type': 'DocumentChunk'}) chunk = Node(id="chunk_1", node_type="DocumentChunk", attributes={"type": "DocumentChunk"})
entity = Node(id="entity_1", node_type="Entity", attributes={'type': 'Entity'}) entity = Node(id="entity_1", node_type="Entity", attributes={"type": "Entity"})
graph.add_node(chunk) graph.add_node(chunk)
graph.add_node(entity) graph.add_node(entity)
# Add edges # Add edges
for element in [chunk, entity]: for element in [chunk, entity]:
graph.add_edge(Edge( graph.add_edge(
node1=interaction, node2=element, Edge(
edge_type="used_graph_element_to_answer", node1=interaction,
attributes={'relationship_type': 'used_graph_element_to_answer'} node2=element,
)) edge_type="used_graph_element_to_answer",
attributes={"relationship_type": "used_graph_element_to_answer"},
result = await extract_usage_frequency( )
subgraphs=[graph], )
time_window=timedelta(days=7)
) result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7))
# Check element types were tracked # Check element types were tracked
self.assertIn('element_type_frequencies', result) self.assertIn("element_type_frequencies", result)
types = result['element_type_frequencies'] types = result["element_type_frequencies"]
self.assertIn('DocumentChunk', types) self.assertIn("DocumentChunk", types)
self.assertIn('Entity', types) self.assertIn("Entity", types)
async def test_empty_graph(self): async def test_empty_graph(self):
"""Test handling of empty graph.""" """Test handling of empty graph."""
graph = CogneeGraph() graph = CogneeGraph()
result = await extract_usage_frequency( result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7))
subgraphs=[graph],
time_window=timedelta(days=7) self.assertEqual(result["total_interactions"], 0)
) self.assertEqual(len(result["node_frequencies"]), 0)
self.assertEqual(result['total_interactions'], 0)
self.assertEqual(len(result['node_frequencies']), 0)
async def test_no_interactions_in_window(self): async def test_no_interactions_in_window(self):
"""Test handling when all interactions are outside time window.""" """Test handling when all interactions are outside time window."""
graph = CogneeGraph() graph = CogneeGraph()
# Add old interaction # Add old interaction
old_time = datetime.now() - timedelta(days=30) old_time = datetime.now() - timedelta(days=30)
old_interaction = Node( old_interaction = Node(
id="old_interaction", id="old_interaction",
node_type="CogneeUserInteraction", node_type="CogneeUserInteraction",
attributes={ attributes={
'type': 'CogneeUserInteraction', "type": "CogneeUserInteraction",
'timestamp': int(old_time.timestamp() * 1000) "timestamp": int(old_time.timestamp() * 1000),
} },
) )
graph.add_node(old_interaction) graph.add_node(old_interaction)
result = await extract_usage_frequency( result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7))
subgraphs=[graph],
time_window=timedelta(days=7) self.assertEqual(result["interactions_in_window"], 0)
) self.assertEqual(result["total_interactions"], 1)
self.assertEqual(result['interactions_in_window'], 0)
self.assertEqual(result['total_interactions'], 1)
class TestIntegration(unittest.TestCase): class TestIntegration(unittest.TestCase):
"""Integration tests for the complete workflow.""" """Integration tests for the complete workflow."""
def setUp(self): def setUp(self):
"""Set up test fixtures.""" """Set up test fixtures."""
if not COGNEE_AVAILABLE: if not COGNEE_AVAILABLE:
self.skipTest("Cognee modules not available") self.skipTest("Cognee modules not available")
async def test_end_to_end_workflow(self): async def test_end_to_end_workflow(self):
"""Test the complete end-to-end frequency tracking workflow.""" """Test the complete end-to-end frequency tracking workflow."""
# This would require a full Cognee setup with database # This would require a full Cognee setup with database
@ -266,6 +260,7 @@ class TestIntegration(unittest.TestCase):
# Test Runner # Test Runner
# ============================================================================ # ============================================================================
def run_async_test(test_func): def run_async_test(test_func):
"""Helper to run async test functions.""" """Helper to run async test functions."""
asyncio.run(test_func()) asyncio.run(test_func())
@ -277,24 +272,24 @@ def main():
print("⚠ Cognee not available - skipping tests") print("⚠ Cognee not available - skipping tests")
print("Install with: pip install cognee[neo4j]") print("Install with: pip install cognee[neo4j]")
return return
print("=" * 80) print("=" * 80)
print("Running Usage Frequency Tests") print("Running Usage Frequency Tests")
print("=" * 80) print("=" * 80)
print() print()
# Create test suite # Create test suite
loader = unittest.TestLoader() loader = unittest.TestLoader()
suite = unittest.TestSuite() suite = unittest.TestSuite()
# Add tests # Add tests
suite.addTests(loader.loadTestsFromTestCase(TestUsageFrequencyExtraction)) suite.addTests(loader.loadTestsFromTestCase(TestUsageFrequencyExtraction))
suite.addTests(loader.loadTestsFromTestCase(TestIntegration)) suite.addTests(loader.loadTestsFromTestCase(TestIntegration))
# Run tests # Run tests
runner = unittest.TextTestRunner(verbosity=2) runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(suite) result = runner.run(suite)
# Summary # Summary
print() print()
print("=" * 80) print("=" * 80)
@ -305,9 +300,9 @@ def main():
print(f"Failures: {len(result.failures)}") print(f"Failures: {len(result.failures)}")
print(f"Errors: {len(result.errors)}") print(f"Errors: {len(result.errors)}")
print(f"Skipped: {len(result.skipped)}") print(f"Skipped: {len(result.skipped)}")
return 0 if result.wasSuccessful() else 1 return 0 if result.wasSuccessful() else 1
if __name__ == "__main__": if __name__ == "__main__":
exit(main()) exit(main())

View file

@ -39,10 +39,11 @@ load_dotenv()
# STEP 1: Setup and Configuration # STEP 1: Setup and Configuration
# ============================================================================ # ============================================================================
async def setup_knowledge_base(): async def setup_knowledge_base():
""" """
Create a fresh knowledge base with sample content. Create a fresh knowledge base with sample content.
In a real application, you would: In a real application, you would:
- Load documents from files, databases, or APIs - Load documents from files, databases, or APIs
- Process larger datasets - Process larger datasets
@ -51,13 +52,13 @@ async def setup_knowledge_base():
print("=" * 80) print("=" * 80)
print("STEP 1: Setting up knowledge base") print("STEP 1: Setting up knowledge base")
print("=" * 80) print("=" * 80)
# Reset state for clean demo (optional in production) # Reset state for clean demo (optional in production)
print("\nResetting Cognee state...") print("\nResetting Cognee state...")
await cognee.prune.prune_data() await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system(metadata=True)
print("✓ Reset complete") print("✓ Reset complete")
# Sample content: AI/ML educational material # Sample content: AI/ML educational material
documents = [ documents = [
""" """
@ -87,16 +88,16 @@ async def setup_knowledge_base():
recognition, object detection, and image segmentation tasks. recognition, object detection, and image segmentation tasks.
""", """,
] ]
print(f"\nAdding {len(documents)} documents to knowledge base...") print(f"\nAdding {len(documents)} documents to knowledge base...")
await cognee.add(documents, dataset_name="ai_ml_fundamentals") await cognee.add(documents, dataset_name="ai_ml_fundamentals")
print("✓ Documents added") print("✓ Documents added")
# Build knowledge graph # Build knowledge graph
print("\nBuilding knowledge graph (cognify)...") print("\nBuilding knowledge graph (cognify)...")
await cognee.cognify() await cognee.cognify()
print("✓ Knowledge graph built") print("✓ Knowledge graph built")
print("\n" + "=" * 80) print("\n" + "=" * 80)
@ -104,26 +105,27 @@ async def setup_knowledge_base():
# STEP 2: Simulate User Searches with Interaction Tracking # STEP 2: Simulate User Searches with Interaction Tracking
# ============================================================================ # ============================================================================
async def simulate_user_searches(queries: List[str]): async def simulate_user_searches(queries: List[str]):
""" """
Simulate users searching the knowledge base. Simulate users searching the knowledge base.
The key parameter is save_interaction=True, which creates: The key parameter is save_interaction=True, which creates:
- CogneeUserInteraction nodes (one per search) - CogneeUserInteraction nodes (one per search)
- used_graph_element_to_answer edges (connecting queries to relevant nodes) - used_graph_element_to_answer edges (connecting queries to relevant nodes)
Args: Args:
queries: List of search queries to simulate queries: List of search queries to simulate
Returns: Returns:
Number of successful searches Number of successful searches
""" """
print("=" * 80) print("=" * 80)
print("STEP 2: Simulating user searches with interaction tracking") print("STEP 2: Simulating user searches with interaction tracking")
print("=" * 80) print("=" * 80)
successful_searches = 0 successful_searches = 0
for i, query in enumerate(queries, 1): for i, query in enumerate(queries, 1):
print(f"\nSearch {i}/{len(queries)}: '{query}'") print(f"\nSearch {i}/{len(queries)}: '{query}'")
try: try:
@ -131,20 +133,20 @@ async def simulate_user_searches(queries: List[str]):
query_type=SearchType.GRAPH_COMPLETION, query_type=SearchType.GRAPH_COMPLETION,
query_text=query, query_text=query,
save_interaction=True, # ← THIS IS CRITICAL! save_interaction=True, # ← THIS IS CRITICAL!
top_k=5 top_k=5,
) )
successful_searches += 1 successful_searches += 1
# Show snippet of results # Show snippet of results
result_preview = str(results)[:100] if results else "No results" result_preview = str(results)[:100] if results else "No results"
print(f" ✓ Completed ({result_preview}...)") print(f" ✓ Completed ({result_preview}...)")
except Exception as e: except Exception as e:
print(f" ✗ Failed: {e}") print(f" ✗ Failed: {e}")
print(f"\n✓ Completed {successful_searches}/{len(queries)} searches") print(f"\n✓ Completed {successful_searches}/{len(queries)} searches")
print("=" * 80) print("=" * 80)
return successful_searches return successful_searches
@ -152,71 +154,80 @@ async def simulate_user_searches(queries: List[str]):
# STEP 3: Extract and Apply Usage Frequencies # STEP 3: Extract and Apply Usage Frequencies
# ============================================================================ # ============================================================================
async def extract_and_apply_frequencies( async def extract_and_apply_frequencies(
time_window_days: int = 7, time_window_days: int = 7, min_threshold: int = 1
min_threshold: int = 1
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Extract usage frequencies from interactions and apply them to the graph. Extract usage frequencies from interactions and apply them to the graph.
This function: This function:
1. Retrieves the graph with interaction data 1. Retrieves the graph with interaction data
2. Counts how often each node was accessed 2. Counts how often each node was accessed
3. Writes frequency_weight property back to nodes 3. Writes frequency_weight property back to nodes
Args: Args:
time_window_days: Only count interactions from last N days time_window_days: Only count interactions from last N days
min_threshold: Minimum accesses to track (filter out rarely used nodes) min_threshold: Minimum accesses to track (filter out rarely used nodes)
Returns: Returns:
Dictionary with statistics about the frequency update Dictionary with statistics about the frequency update
""" """
print("=" * 80) print("=" * 80)
print("STEP 3: Extracting and applying usage frequencies") print("STEP 3: Extracting and applying usage frequencies")
print("=" * 80) print("=" * 80)
# Get graph adapter # Get graph adapter
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
# Retrieve graph with interactions # Retrieve graph with interactions
print("\nRetrieving graph from database...") print("\nRetrieving graph from database...")
graph = CogneeGraph() graph = CogneeGraph()
await graph.project_graph_from_db( await graph.project_graph_from_db(
adapter=graph_engine, adapter=graph_engine,
node_properties_to_project=[ node_properties_to_project=[
"type", "node_type", "timestamp", "created_at", "type",
"text", "name", "query_text", "frequency_weight" "node_type",
"timestamp",
"created_at",
"text",
"name",
"query_text",
"frequency_weight",
], ],
edge_properties_to_project=["relationship_type", "timestamp"], edge_properties_to_project=["relationship_type", "timestamp"],
directed=True, directed=True,
) )
print(f"✓ Retrieved: {len(graph.nodes)} nodes, {len(graph.edges)} edges") print(f"✓ Retrieved: {len(graph.nodes)} nodes, {len(graph.edges)} edges")
# Count interaction nodes # Count interaction nodes
interaction_nodes = [ interaction_nodes = [
n for n in graph.nodes.values() n
if n.attributes.get('type') == 'CogneeUserInteraction' or for n in graph.nodes.values()
n.attributes.get('node_type') == 'CogneeUserInteraction' if n.attributes.get("type") == "CogneeUserInteraction"
or n.attributes.get("node_type") == "CogneeUserInteraction"
] ]
print(f"✓ Found {len(interaction_nodes)} interaction nodes") print(f"✓ Found {len(interaction_nodes)} interaction nodes")
# Run frequency extraction and update # Run frequency extraction and update
print(f"\nExtracting frequencies (time window: {time_window_days} days)...") print(f"\nExtracting frequencies (time window: {time_window_days} days)...")
stats = await run_usage_frequency_update( stats = await run_usage_frequency_update(
graph_adapter=graph_engine, graph_adapter=graph_engine,
subgraphs=[graph], subgraphs=[graph],
time_window=timedelta(days=time_window_days), time_window=timedelta(days=time_window_days),
min_interaction_threshold=min_threshold min_interaction_threshold=min_threshold,
)
print("\n✓ Frequency extraction complete!")
print(
f" - Interactions processed: {stats['interactions_in_window']}/{stats['total_interactions']}"
) )
print(f"\n✓ Frequency extraction complete!")
print(f" - Interactions processed: {stats['interactions_in_window']}/{stats['total_interactions']}")
print(f" - Nodes weighted: {len(stats['node_frequencies'])}") print(f" - Nodes weighted: {len(stats['node_frequencies'])}")
print(f" - Element types tracked: {stats.get('element_type_frequencies', {})}") print(f" - Element types tracked: {stats.get('element_type_frequencies', {})}")
print("=" * 80) print("=" * 80)
return stats return stats
@ -224,33 +235,30 @@ async def extract_and_apply_frequencies(
# STEP 4: Analyze and Display Results # STEP 4: Analyze and Display Results
# ============================================================================ # ============================================================================
async def analyze_results(stats: Dict[str, Any]): async def analyze_results(stats: Dict[str, Any]):
""" """
Analyze and display the frequency tracking results. Analyze and display the frequency tracking results.
Shows: Shows:
- Top most frequently accessed nodes - Top most frequently accessed nodes
- Element type distribution - Element type distribution
- Verification that weights were written to database - Verification that weights were written to database
Args: Args:
stats: Statistics from frequency extraction stats: Statistics from frequency extraction
""" """
print("=" * 80) print("=" * 80)
print("STEP 4: Analyzing usage frequency results") print("STEP 4: Analyzing usage frequency results")
print("=" * 80) print("=" * 80)
# Display top nodes by frequency # Display top nodes by frequency
if stats['node_frequencies']: if stats["node_frequencies"]:
print("\n📊 Top 10 Most Frequently Accessed Elements:") print("\n📊 Top 10 Most Frequently Accessed Elements:")
print("-" * 80) print("-" * 80)
sorted_nodes = sorted( sorted_nodes = sorted(stats["node_frequencies"].items(), key=lambda x: x[1], reverse=True)
stats['node_frequencies'].items(),
key=lambda x: x[1],
reverse=True
)
# Get graph to display node details # Get graph to display node details
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
graph = CogneeGraph() graph = CogneeGraph()
@ -260,48 +268,48 @@ async def analyze_results(stats: Dict[str, Any]):
edge_properties_to_project=[], edge_properties_to_project=[],
directed=True, directed=True,
) )
for i, (node_id, frequency) in enumerate(sorted_nodes[:10], 1): for i, (node_id, frequency) in enumerate(sorted_nodes[:10], 1):
node = graph.get_node(node_id) node = graph.get_node(node_id)
if node: if node:
node_type = node.attributes.get('type', 'Unknown') node_type = node.attributes.get("type", "Unknown")
text = node.attributes.get('text') or node.attributes.get('name') or '' text = node.attributes.get("text") or node.attributes.get("name") or ""
text_preview = text[:60] + "..." if len(text) > 60 else text text_preview = text[:60] + "..." if len(text) > 60 else text
print(f"\n{i}. Frequency: {frequency} accesses") print(f"\n{i}. Frequency: {frequency} accesses")
print(f" Type: {node_type}") print(f" Type: {node_type}")
print(f" Content: {text_preview}") print(f" Content: {text_preview}")
else: else:
print(f"\n{i}. Frequency: {frequency} accesses") print(f"\n{i}. Frequency: {frequency} accesses")
print(f" Node ID: {node_id[:50]}...") print(f" Node ID: {node_id[:50]}...")
# Display element type distribution # Display element type distribution
if stats.get('element_type_frequencies'): if stats.get("element_type_frequencies"):
print("\n\n📈 Element Type Distribution:") print("\n\n📈 Element Type Distribution:")
print("-" * 80) print("-" * 80)
type_dist = stats['element_type_frequencies'] type_dist = stats["element_type_frequencies"]
for elem_type, count in sorted(type_dist.items(), key=lambda x: x[1], reverse=True): for elem_type, count in sorted(type_dist.items(), key=lambda x: x[1], reverse=True):
print(f" {elem_type}: {count} accesses") print(f" {elem_type}: {count} accesses")
# Verify weights in database (Neo4j only) # Verify weights in database (Neo4j only)
print("\n\n🔍 Verifying weights in database...") print("\n\n🔍 Verifying weights in database...")
print("-" * 80) print("-" * 80)
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
adapter_type = type(graph_engine).__name__ adapter_type = type(graph_engine).__name__
if adapter_type == 'Neo4jAdapter': if adapter_type == "Neo4jAdapter":
try: try:
result = await graph_engine.query(""" result = await graph_engine.query("""
MATCH (n) MATCH (n)
WHERE n.frequency_weight IS NOT NULL WHERE n.frequency_weight IS NOT NULL
RETURN count(n) as weighted_count RETURN count(n) as weighted_count
""") """)
count = result[0]['weighted_count'] if result else 0 count = result[0]["weighted_count"] if result else 0
if count > 0: if count > 0:
print(f"{count} nodes have frequency_weight in Neo4j database") print(f"{count} nodes have frequency_weight in Neo4j database")
# Show sample # Show sample
sample = await graph_engine.query(""" sample = await graph_engine.query("""
MATCH (n) MATCH (n)
@ -310,7 +318,7 @@ async def analyze_results(stats: Dict[str, Any]):
ORDER BY n.frequency_weight DESC ORDER BY n.frequency_weight DESC
LIMIT 3 LIMIT 3
""") """)
print("\nSample weighted nodes:") print("\nSample weighted nodes:")
for row in sample: for row in sample:
print(f" - Weight: {row['weight']}, Type: {row['labels']}") print(f" - Weight: {row['weight']}, Type: {row['labels']}")
@ -320,7 +328,7 @@ async def analyze_results(stats: Dict[str, Any]):
print(f"Could not verify in Neo4j: {e}") print(f"Could not verify in Neo4j: {e}")
else: else:
print(f"Database verification not implemented for {adapter_type}") print(f"Database verification not implemented for {adapter_type}")
print("\n" + "=" * 80) print("\n" + "=" * 80)
@ -328,10 +336,11 @@ async def analyze_results(stats: Dict[str, Any]):
# STEP 5: Demonstrate Usage in Retrieval
# ============================================================================
async def demonstrate_retrieval_usage(): async def demonstrate_retrieval_usage():
""" """
Demonstrate how frequency weights can be used in retrieval. Demonstrate how frequency weights can be used in retrieval.
Note: This is a conceptual demonstration. To actually use frequency Note: This is a conceptual demonstration. To actually use frequency
weights in ranking, you would need to modify the retrieval/completion weights in ranking, you would need to modify the retrieval/completion
strategies to incorporate the frequency_weight property. strategies to incorporate the frequency_weight property.
@ -339,39 +348,39 @@ async def demonstrate_retrieval_usage():
print("=" * 80) print("=" * 80)
print("STEP 5: How to use frequency weights in retrieval") print("STEP 5: How to use frequency weights in retrieval")
print("=" * 80) print("=" * 80)
print(""" print("""
Frequency weights can be used to improve search results: Frequency weights can be used to improve search results:
1. RANKING BOOST: 1. RANKING BOOST:
- Multiply relevance scores by frequency_weight - Multiply relevance scores by frequency_weight
- Prioritize frequently accessed nodes in results - Prioritize frequently accessed nodes in results
2. COMPLETION STRATEGIES: 2. COMPLETION STRATEGIES:
- Adjust triplet importance based on usage - Adjust triplet importance based on usage
- Filter out rarely accessed information - Filter out rarely accessed information
3. ANALYTICS: 3. ANALYTICS:
- Track trending topics over time - Track trending topics over time
- Understand user interests and behavior - Understand user interests and behavior
- Identify knowledge gaps (low-frequency nodes) - Identify knowledge gaps (low-frequency nodes)
4. ADAPTIVE RETRIEVAL: 4. ADAPTIVE RETRIEVAL:
- Personalize results based on team usage patterns - Personalize results based on team usage patterns
- Surface popular answers faster - Surface popular answers faster
Example Cypher query with frequency boost (Neo4j): Example Cypher query with frequency boost (Neo4j):
MATCH (n) MATCH (n)
WHERE n.text CONTAINS $search_term WHERE n.text CONTAINS $search_term
RETURN n, n.frequency_weight as boost RETURN n, n.frequency_weight as boost
ORDER BY (n.relevance_score * COALESCE(n.frequency_weight, 1)) DESC ORDER BY (n.relevance_score * COALESCE(n.frequency_weight, 1)) DESC
LIMIT 10 LIMIT 10
To integrate this into Cognee, you would modify the completion To integrate this into Cognee, you would modify the completion
strategy to include frequency_weight in the scoring function. strategy to include frequency_weight in the scoring function.
""") """)
print("=" * 80) print("=" * 80)
@ -379,6 +388,7 @@ async def demonstrate_retrieval_usage():
# MAIN: Run Complete Example
# ============================================================================
async def main(): async def main():
""" """
Run the complete end-to-end usage frequency tracking example. Run the complete end-to-end usage frequency tracking example.
@ -390,25 +400,25 @@ async def main():
print("" + " " * 78 + "") print("" + " " * 78 + "")
print("" + "=" * 78 + "") print("" + "=" * 78 + "")
print("\n") print("\n")
# Configuration check # Configuration check
print("Configuration:") print("Configuration:")
print(f" Graph Provider: {os.getenv('GRAPH_DATABASE_PROVIDER')}") print(f" Graph Provider: {os.getenv('GRAPH_DATABASE_PROVIDER')}")
print(f" Graph Handler: {os.getenv('GRAPH_DATASET_HANDLER')}") print(f" Graph Handler: {os.getenv('GRAPH_DATASET_HANDLER')}")
print(f" LLM Provider: {os.getenv('LLM_PROVIDER')}") print(f" LLM Provider: {os.getenv('LLM_PROVIDER')}")
# Verify LLM key is set # Verify LLM key is set
if not os.getenv('LLM_API_KEY') or os.getenv('LLM_API_KEY') == 'sk-your-key-here': if not os.getenv("LLM_API_KEY") or os.getenv("LLM_API_KEY") == "sk-your-key-here":
print("\n⚠ WARNING: LLM_API_KEY not set in .env file") print("\n⚠ WARNING: LLM_API_KEY not set in .env file")
print(" Set your API key to run searches") print(" Set your API key to run searches")
return return
print("\n") print("\n")
try: try:
# Step 1: Setup # Step 1: Setup
await setup_knowledge_base() await setup_knowledge_base()
# Step 2: Simulate searches # Step 2: Simulate searches
# Note: Repeat queries increase frequency for those topics # Note: Repeat queries increase frequency for those topics
queries = [ queries = [
@ -422,25 +432,22 @@ async def main():
"What is reinforcement learning?", "What is reinforcement learning?",
"Tell me more about neural networks", # Third repeat "Tell me more about neural networks", # Third repeat
] ]
successful_searches = await simulate_user_searches(queries) successful_searches = await simulate_user_searches(queries)
if successful_searches == 0: if successful_searches == 0:
print("⚠ No searches completed - cannot demonstrate frequency tracking") print("⚠ No searches completed - cannot demonstrate frequency tracking")
return return
# Step 3: Extract frequencies # Step 3: Extract frequencies
stats = await extract_and_apply_frequencies( stats = await extract_and_apply_frequencies(time_window_days=7, min_threshold=1)
time_window_days=7,
min_threshold=1
)
# Step 4: Analyze results # Step 4: Analyze results
await analyze_results(stats) await analyze_results(stats)
# Step 5: Show usage examples # Step 5: Show usage examples
await demonstrate_retrieval_usage() await demonstrate_retrieval_usage()
# Summary # Summary
print("\n") print("\n")
print("" + "=" * 78 + "") print("" + "=" * 78 + "")
@ -449,26 +456,27 @@ async def main():
print("" + " " * 78 + "") print("" + " " * 78 + "")
print("" + "=" * 78 + "") print("" + "=" * 78 + "")
print("\n") print("\n")
print("Summary:") print("Summary:")
print(f" ✓ Documents added: 4") print(" ✓ Documents added: 4")
print(f" ✓ Searches performed: {successful_searches}") print(f" ✓ Searches performed: {successful_searches}")
print(f" ✓ Interactions tracked: {stats['interactions_in_window']}") print(f" ✓ Interactions tracked: {stats['interactions_in_window']}")
print(f" ✓ Nodes weighted: {len(stats['node_frequencies'])}") print(f" ✓ Nodes weighted: {len(stats['node_frequencies'])}")
print("\nNext steps:") print("\nNext steps:")
print(" 1. Open Neo4j Browser (http://localhost:7474) to explore the graph") print(" 1. Open Neo4j Browser (http://localhost:7474) to explore the graph")
print(" 2. Modify retrieval strategies to use frequency_weight") print(" 2. Modify retrieval strategies to use frequency_weight")
print(" 3. Build analytics dashboards using element_type_frequencies") print(" 3. Build analytics dashboards using element_type_frequencies")
print(" 4. Run periodic frequency updates to track trends over time") print(" 4. Run periodic frequency updates to track trends over time")
print("\n") print("\n")
except Exception as e: except Exception as e:
print(f"\n✗ Example failed: {e}") print(f"\n✗ Example failed: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
# Script entry point: run the full end-to-end example under asyncio.
# (The original lines were mangled by a side-by-side diff extraction that
# duplicated each line's content; this restores the clean guard.)
if __name__ == "__main__":
    asyncio.run(main())