diff --git a/cognee/tasks/memify/extract_usage_frequency.py b/cognee/tasks/memify/extract_usage_frequency.py new file mode 100644 index 000000000..7e437bd18 --- /dev/null +++ b/cognee/tasks/memify/extract_usage_frequency.py @@ -0,0 +1,595 @@ +# cognee/tasks/memify/extract_usage_frequency.py +from typing import List, Dict, Any, Optional +from datetime import datetime, timedelta +from cognee.shared.logging_utils import get_logger +from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph +from cognee.modules.pipelines.tasks.task import Task +from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface + +logger = get_logger("extract_usage_frequency") + + +async def extract_usage_frequency( + subgraphs: List[CogneeGraph], + time_window: timedelta = timedelta(days=7), + min_interaction_threshold: int = 1 +) -> Dict[str, Any]: + """ + Extract usage frequency from CogneeUserInteraction nodes. + + When save_interaction=True in cognee.search(), the system creates: + - CogneeUserInteraction nodes (representing the query/answer interaction) + - used_graph_element_to_answer edges (connecting interactions to graph elements used) + + This function tallies how often each graph element is referenced via these edges, + enabling frequency-based ranking in downstream retrievers. + + :param subgraphs: List of CogneeGraph instances containing interaction data + :param time_window: Time window to consider for interactions (default: 7 days) + :param min_interaction_threshold: Minimum interactions to track (default: 1) + :return: Dictionary containing node frequencies, edge frequencies, and metadata + """ + current_time = datetime.now() + cutoff_time = current_time - time_window + + # Track frequencies for graph elements (nodes and edges) + node_frequencies = {} + edge_frequencies = {} + relationship_type_frequencies = {} + + # Track interaction metadata + interaction_count = 0 + interactions_in_window = 0 + + logger.info(f"Extracting usage frequencies from {len(subgraphs)} subgraphs") + logger.info(f"Time window: {time_window}, Cutoff: {cutoff_time.isoformat()}") + + for subgraph in subgraphs: + # Find all CogneeUserInteraction nodes + interaction_nodes = {} + for node_id, node in subgraph.nodes.items(): + node_type = node.attributes.get('type') or node.attributes.get('node_type') + + if node_type == 'CogneeUserInteraction': + # Parse and validate timestamp + timestamp_value = node.attributes.get('timestamp') or node.attributes.get('created_at') + if timestamp_value is not None: + try: + # Handle various timestamp formats + interaction_time = None + + if isinstance(timestamp_value, datetime): + # Already a Python datetime + interaction_time = timestamp_value + elif isinstance(timestamp_value, (int, float)): + # Unix timestamp (assume milliseconds if > 10 digits) + if timestamp_value > 10000000000: + # Milliseconds since epoch + interaction_time = datetime.fromtimestamp(timestamp_value / 1000.0) + else: + # Seconds since epoch + interaction_time = datetime.fromtimestamp(timestamp_value) + elif isinstance(timestamp_value, str): + # Try different string formats + if timestamp_value.isdigit(): + # Numeric string - treat as Unix timestamp + ts_int = int(timestamp_value) + if ts_int > 10000000000: + interaction_time = datetime.fromtimestamp(ts_int / 1000.0) + else: + interaction_time = datetime.fromtimestamp(ts_int) + else: + # ISO format string + interaction_time = datetime.fromisoformat(timestamp_value) + elif hasattr(timestamp_value, 'to_native'): + # Neo4j datetime object - 
convert to Python datetime + interaction_time = timestamp_value.to_native() + elif hasattr(timestamp_value, 'year') and hasattr(timestamp_value, 'month'): + # Datetime-like object - extract components + try: + interaction_time = datetime( + year=timestamp_value.year, + month=timestamp_value.month, + day=timestamp_value.day, + hour=getattr(timestamp_value, 'hour', 0), + minute=getattr(timestamp_value, 'minute', 0), + second=getattr(timestamp_value, 'second', 0), + microsecond=getattr(timestamp_value, 'microsecond', 0) + ) + except (AttributeError, ValueError): + pass + + if interaction_time is None: + # Last resort: try converting to string and parsing + str_value = str(timestamp_value) + if str_value.isdigit(): + ts_int = int(str_value) + if ts_int > 10000000000: + interaction_time = datetime.fromtimestamp(ts_int / 1000.0) + else: + interaction_time = datetime.fromtimestamp(ts_int) + else: + interaction_time = datetime.fromisoformat(str_value) + + if interaction_time is None: + raise ValueError(f"Could not parse timestamp: {timestamp_value}") + + # Make sure it's timezone-naive for comparison + if interaction_time.tzinfo is not None: + interaction_time = interaction_time.replace(tzinfo=None) + + interaction_nodes[node_id] = { + 'node': node, + 'timestamp': interaction_time, + 'in_window': interaction_time >= cutoff_time + } + interaction_count += 1 + if interaction_time >= cutoff_time: + interactions_in_window += 1 + except (ValueError, TypeError, AttributeError, OSError) as e: + logger.warning(f"Failed to parse timestamp for interaction node {node_id}: {e}") + logger.debug(f"Timestamp value type: {type(timestamp_value)}, value: {timestamp_value}") + + # Process edges to find graph elements used in interactions + for edge in subgraph.edges: + relationship_type = edge.attributes.get('relationship_type') + + # Look for 'used_graph_element_to_answer' edges + if relationship_type == 'used_graph_element_to_answer': + # node1 should be the CogneeUserInteraction, node2 is the graph element + source_id = str(edge.node1.id) + target_id = str(edge.node2.id) + + # Check if source is an interaction node in our time window + if source_id in interaction_nodes: + interaction_data = interaction_nodes[source_id] + + if interaction_data['in_window']: + # Count the graph element (target node) being used + node_frequencies[target_id] = node_frequencies.get(target_id, 0) + 1 + + # Also track what type of element it is for analytics + target_node = subgraph.get_node(target_id) + if target_node: + element_type = target_node.attributes.get('type') or target_node.attributes.get('node_type') + if element_type: + relationship_type_frequencies[element_type] = relationship_type_frequencies.get(element_type, 0) + 1 + + # Also track general edge usage patterns + elif relationship_type and relationship_type != 'used_graph_element_to_answer': + # Check if either endpoint is referenced in a recent interaction + source_id = str(edge.node1.id) + target_id = str(edge.node2.id) + + # If this edge connects to any frequently accessed nodes, track the edge type + if source_id in node_frequencies or target_id in node_frequencies: + edge_key = f"{relationship_type}:{source_id}:{target_id}" + edge_frequencies[edge_key] = edge_frequencies.get(edge_key, 0) + 1 + + # Filter frequencies above threshold + filtered_node_frequencies = { + node_id: freq for node_id, freq in node_frequencies.items() + if freq >= min_interaction_threshold + } + + filtered_edge_frequencies = { + edge_key: freq for edge_key, freq in edge_frequencies.items() + 
if freq >= min_interaction_threshold + } + + logger.info( + f"Processed {interactions_in_window}/{interaction_count} interactions in time window" + ) + logger.info( + f"Found {len(filtered_node_frequencies)} nodes and {len(filtered_edge_frequencies)} edges " + f"above threshold (min: {min_interaction_threshold})" + ) + logger.info(f"Element type distribution: {relationship_type_frequencies}") + + return { + 'node_frequencies': filtered_node_frequencies, + 'edge_frequencies': filtered_edge_frequencies, + 'element_type_frequencies': relationship_type_frequencies, + 'total_interactions': interaction_count, + 'interactions_in_window': interactions_in_window, + 'time_window_days': time_window.days, + 'last_processed_timestamp': current_time.isoformat(), + 'cutoff_timestamp': cutoff_time.isoformat() + } + + +async def add_frequency_weights( + graph_adapter: GraphDBInterface, + usage_frequencies: Dict[str, Any] +) -> None: + """ + Add frequency weights to graph nodes and edges using the graph adapter. + + Uses direct Cypher queries for Neo4j adapter compatibility. + Writes frequency_weight properties back to the graph for use in: + - Ranking frequently referenced entities higher during retrieval + - Adjusting scoring for completion strategies + - Exposing usage metrics in dashboards or audits + + :param graph_adapter: Graph database adapter interface + :param usage_frequencies: Calculated usage frequencies from extract_usage_frequency + """ + node_frequencies = usage_frequencies.get('node_frequencies', {}) + edge_frequencies = usage_frequencies.get('edge_frequencies', {}) + + logger.info(f"Adding frequency weights to {len(node_frequencies)} nodes") + + # Check adapter type and use appropriate method + adapter_type = type(graph_adapter).__name__ + logger.info(f"Using adapter: {adapter_type}") + + nodes_updated = 0 + nodes_failed = 0 + + # Determine which method to use based on adapter type + use_neo4j_cypher = adapter_type == 'Neo4jAdapter' and hasattr(graph_adapter, 'query') + use_kuzu_query = adapter_type == 'KuzuAdapter' and hasattr(graph_adapter, 'query') + use_get_update = hasattr(graph_adapter, 'get_node_by_id') and hasattr(graph_adapter, 'update_node_properties') + + # Method 1: Neo4j Cypher with SET (creates properties on the fly) + if use_neo4j_cypher: + try: + logger.info("Using Neo4j Cypher SET method") + last_updated = usage_frequencies.get('last_processed_timestamp') + + for node_id, frequency in node_frequencies.items(): + try: + query = """ + MATCH (n) + WHERE n.id = $node_id + SET n.frequency_weight = $frequency, + n.frequency_updated_at = $updated_at + RETURN n.id as id + """ + + result = await graph_adapter.query( + query, + params={ + 'node_id': node_id, + 'frequency': frequency, + 'updated_at': last_updated + } + ) + + if result and len(result) > 0: + nodes_updated += 1 + else: + logger.warning(f"Node {node_id} not found or not updated") + nodes_failed += 1 + + except Exception as e: + logger.error(f"Error updating node {node_id}: {e}") + nodes_failed += 1 + + logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed") + + except Exception as e: + logger.error(f"Neo4j Cypher update failed: {e}") + use_neo4j_cypher = False + + # Method 2: Kuzu - use get_node + add_node (updates via re-adding with same ID) + elif use_kuzu_query and hasattr(graph_adapter, 'get_node') and hasattr(graph_adapter, 'add_node'): + logger.info("Using Kuzu get_node + add_node method") + last_updated = usage_frequencies.get('last_processed_timestamp') + + for node_id, frequency 
in node_frequencies.items():
+            try:
+                # Get the existing node (returns a dict)
+                existing_node_dict = await graph_adapter.get_node(node_id)
+
+                if existing_node_dict:
+                    # Update the dict with new properties
+                    existing_node_dict['frequency_weight'] = frequency
+                    existing_node_dict['frequency_updated_at'] = last_updated
+
+                    # Kuzu's add_node may accept the dict directly rather than a
+                    # Node object, so try the dict first and fall back to a Node
+                    try:
+                        await graph_adapter.add_node(existing_node_dict)
+                        nodes_updated += 1
+                    except Exception as dict_error:
+                        logger.debug(f"Dict add failed, trying Node object: {dict_error}")
+
+                        try:
+                            from cognee.infrastructure.engine import Node
+
+                            # Try different Node constructor patterns
+                            try:
+                                # Pattern 1: properties only
+                                node_obj = Node(existing_node_dict)
+                            except Exception:
+                                # Pattern 2: type plus properties
+                                node_obj = Node(
+                                    type=existing_node_dict.get('type', 'Unknown'),
+                                    **existing_node_dict
+                                )
+
+                            await graph_adapter.add_node(node_obj)
+                            nodes_updated += 1
+                        except Exception as node_error:
+                            logger.error(f"Both dict and Node object failed: {node_error}")
+                            nodes_failed += 1
+                else:
+                    logger.warning(f"Node {node_id} not found in graph")
+                    nodes_failed += 1
+
+            except Exception as e:
+                logger.error(f"Error updating node {node_id}: {e}")
+                nodes_failed += 1
+
+        logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed")
+
+    # Method 3: Generic get_node_by_id + update_node_properties
+    elif use_get_update:
+        logger.info("Using get/update method for adapter")
+        for node_id, frequency in node_frequencies.items():
+            try:
+                # Get current node data
+                node_data = await graph_adapter.get_node_by_id(node_id)
+
+                if node_data:
+                    # Extract the properties dict and add frequency_weight
+                    if isinstance(node_data, dict):
+                        properties = node_data.get('properties', {})
+                    else:
+                        properties = getattr(node_data, 'properties', {}) or {}
+
+                    # Update with frequency weight
+                    properties['frequency_weight'] = frequency
+                    properties['frequency_updated_at'] = usage_frequencies.get('last_processed_timestamp')
+
+                    # Write back via adapter
+                    await graph_adapter.update_node_properties(node_id, properties)
+                    nodes_updated += 1
+                else:
+                    logger.warning(f"Node {node_id} not found in graph")
+                    nodes_failed += 1
+
+            except Exception as e:
+                logger.error(f"Error updating node {node_id}: {e}")
+                nodes_failed += 1
+
+        logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed")
+
+    # If no method is available
+    if not use_neo4j_cypher and not use_kuzu_query and not use_get_update:
+        logger.error(f"Adapter {adapter_type} does not support required update methods")
logger.error("Required: either 'query' method or both 'get_node_by_id' and 'update_node_properties'") + return + + # Update edge frequencies + # Note: Edge property updates are backend-specific + if edge_frequencies: + logger.info(f"Processing {len(edge_frequencies)} edge frequency entries") + + edges_updated = 0 + edges_failed = 0 + + for edge_key, frequency in edge_frequencies.items(): + try: + # Parse edge key: "relationship_type:source_id:target_id" + parts = edge_key.split(':', 2) + if len(parts) == 3: + relationship_type, source_id, target_id = parts + + # Try to update edge if adapter supports it + if hasattr(graph_adapter, 'update_edge_properties'): + edge_properties = { + 'frequency_weight': frequency, + 'frequency_updated_at': usage_frequencies.get('last_processed_timestamp') + } + + await graph_adapter.update_edge_properties( + source_id, + target_id, + relationship_type, + edge_properties + ) + edges_updated += 1 + else: + # Fallback: store in metadata or log + logger.debug( + f"Adapter doesn't support update_edge_properties for " + f"{relationship_type} ({source_id} -> {target_id})" + ) + + except Exception as e: + logger.error(f"Error updating edge {edge_key}: {e}") + edges_failed += 1 + + if edges_updated > 0: + logger.info(f"Edge update complete: {edges_updated} succeeded, {edges_failed} failed") + else: + logger.info( + "Edge frequency updates skipped (adapter may not support edge property updates)" + ) + + # Store aggregate statistics as metadata if supported + if hasattr(graph_adapter, 'set_metadata'): + try: + metadata = { + 'element_type_frequencies': usage_frequencies.get('element_type_frequencies', {}), + 'total_interactions': usage_frequencies.get('total_interactions', 0), + 'interactions_in_window': usage_frequencies.get('interactions_in_window', 0), + 'last_frequency_update': usage_frequencies.get('last_processed_timestamp') + } + await graph_adapter.set_metadata('usage_frequency_stats', metadata) + logger.info("Stored usage frequency statistics as metadata") + except Exception as e: + logger.warning(f"Could not store usage statistics as metadata: {e}") + + +async def create_usage_frequency_pipeline( + graph_adapter: GraphDBInterface, + time_window: timedelta = timedelta(days=7), + min_interaction_threshold: int = 1, + batch_size: int = 100 +) -> tuple: + """ + Create memify pipeline entry for usage frequency tracking. + + This follows the same pattern as feedback enrichment flows, allowing + the frequency update to run end-to-end in a custom memify pipeline. 
+ + Use case example: + extraction_tasks, enrichment_tasks = await create_usage_frequency_pipeline( + graph_adapter=my_adapter, + time_window=timedelta(days=30), + min_interaction_threshold=2 + ) + + # Run in memify pipeline + pipeline = Pipeline(extraction_tasks + enrichment_tasks) + results = await pipeline.run() + + :param graph_adapter: Graph database adapter + :param time_window: Time window for counting interactions (default: 7 days) + :param min_interaction_threshold: Minimum interactions to track (default: 1) + :param batch_size: Batch size for processing (default: 100) + :return: Tuple of (extraction_tasks, enrichment_tasks) + """ + logger.info("Creating usage frequency pipeline") + logger.info(f"Config: time_window={time_window}, threshold={min_interaction_threshold}") + + extraction_tasks = [ + Task( + extract_usage_frequency, + time_window=time_window, + min_interaction_threshold=min_interaction_threshold + ) + ] + + enrichment_tasks = [ + Task( + add_frequency_weights, + graph_adapter=graph_adapter, + task_config={"batch_size": batch_size} + ) + ] + + return extraction_tasks, enrichment_tasks + + +async def run_usage_frequency_update( + graph_adapter: GraphDBInterface, + subgraphs: List[CogneeGraph], + time_window: timedelta = timedelta(days=7), + min_interaction_threshold: int = 1 +) -> Dict[str, Any]: + """ + Convenience function to run the complete usage frequency update pipeline. + + This is the main entry point for updating frequency weights on graph elements + based on CogneeUserInteraction data from cognee.search(save_interaction=True). + + Example usage: + # After running searches with save_interaction=True + from cognee.tasks.memify.extract_usage_frequency import run_usage_frequency_update + + # Get the graph with interactions + graph = await get_cognee_graph_with_interactions() + + # Update frequency weights + stats = await run_usage_frequency_update( + graph_adapter=graph_adapter, + subgraphs=[graph], + time_window=timedelta(days=30), # Last 30 days + min_interaction_threshold=2 # At least 2 uses + ) + + print(f"Updated {len(stats['node_frequencies'])} nodes") + + :param graph_adapter: Graph database adapter + :param subgraphs: List of CogneeGraph instances with interaction data + :param time_window: Time window for counting interactions + :param min_interaction_threshold: Minimum interactions to track + :return: Usage frequency statistics + """ + logger.info("Starting usage frequency update") + + try: + # Extract frequencies from interaction data + usage_frequencies = await extract_usage_frequency( + subgraphs=subgraphs, + time_window=time_window, + min_interaction_threshold=min_interaction_threshold + ) + + # Add frequency weights back to the graph + await add_frequency_weights( + graph_adapter=graph_adapter, + usage_frequencies=usage_frequencies + ) + + logger.info("Usage frequency update completed successfully") + logger.info( + f"Summary: {usage_frequencies['interactions_in_window']} interactions processed, " + f"{len(usage_frequencies['node_frequencies'])} nodes weighted" + ) + + return usage_frequencies + + except Exception as e: + logger.error(f"Error during usage frequency update: {str(e)}") + raise + + +async def get_most_frequent_elements( + graph_adapter: GraphDBInterface, + top_n: int = 10, + element_type: Optional[str] = None +) -> List[Dict[str, Any]]: + """ + Retrieve the most frequently accessed graph elements. + + Useful for analytics dashboards and understanding user behavior. 
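+
+    A minimal Neo4j-flavored sketch of what an adapter-specific implementation
+    could run (assumes the same Cypher `query` method used by
+    add_frequency_weights; other adapters will differ):
+
+        results = await graph_adapter.query(
+            "MATCH (n) WHERE n.frequency_weight IS NOT NULL "
+            "RETURN n.id AS id, n.frequency_weight AS weight "
+            "ORDER BY n.frequency_weight DESC LIMIT $top_n",
+            params={'top_n': top_n}
+        )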
+
+    :param graph_adapter: Graph database adapter
+    :param top_n: Number of top elements to return
+    :param element_type: Optional filter by element type
+    :return: List of elements with their frequency weights
+    """
+    logger.info(f"Retrieving top {top_n} most frequent elements")
+
+    # This would need to be implemented based on the specific graph adapter's query capabilities
+    # Pseudocode:
+    # results = await graph_adapter.query_nodes_by_property(
+    #     property_name='frequency_weight',
+    #     order_by='DESC',
+    #     limit=top_n,
+    #     filters={'type': element_type} if element_type else None
+    # )
+
+    logger.warning("get_most_frequent_elements needs adapter-specific implementation")
+    return []
\ No newline at end of file
diff --git a/cognee/tests/test_extract_usage_frequency.py b/cognee/tests/test_extract_usage_frequency.py
new file mode 100644
index 000000000..c4a3e0448
--- /dev/null
+++ b/cognee/tests/test_extract_usage_frequency.py
@@ -0,0 +1,313 @@
+"""
+Test Suite: Usage Frequency Tracking
+
+Comprehensive tests for the usage frequency tracking implementation.
+Tests cover extraction logic, adapter integration, edge cases, and end-to-end workflows.
+
+Run with:
+    pytest test_extract_usage_frequency.py -v
+
+Or without pytest:
+    python test_extract_usage_frequency.py
+"""
+
+import unittest
+from datetime import datetime, timedelta
+
+# Mock imports for testing without full Cognee setup
+try:
+    from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
+    from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
+    from cognee.tasks.memify.extract_usage_frequency import (
+        extract_usage_frequency,
+        add_frequency_weights,
+        run_usage_frequency_update
+    )
+    COGNEE_AVAILABLE = True
+except ImportError:
+    COGNEE_AVAILABLE = False
+    print("⚠ Cognee not fully available - some tests will be skipped")
+
+
+# IsolatedAsyncioTestCase is required so the async test methods are awaited;
+# under plain unittest.TestCase they would return unawaited coroutines and
+# "pass" without executing any assertions.
+class TestUsageFrequencyExtraction(unittest.IsolatedAsyncioTestCase):
+    """Test the core frequency extraction logic."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        if not COGNEE_AVAILABLE:
+            self.skipTest("Cognee modules not available")
+
+    def create_mock_graph(self, num_interactions: int = 3, num_elements: int = 5):
+        """Create a mock graph with interactions and elements."""
+        graph = CogneeGraph()
+
+        # Create interaction nodes
+        current_time = datetime.now()
+        for i in range(num_interactions):
+            interaction_node = Node(
+                id=f"interaction_{i}",
+                node_type="CogneeUserInteraction",
+                attributes={
+                    'type': 'CogneeUserInteraction',
+                    'query_text': f'Test query {i}',
+                    'timestamp': int((current_time - timedelta(hours=i)).timestamp() * 1000)
+                }
+            )
+            graph.add_node(interaction_node)
+
+        # Create graph element nodes
+        for i in range(num_elements):
+            element_node = Node(
+                id=f"element_{i}",
+                node_type="DocumentChunk",
+                attributes={
+                    'type': 'DocumentChunk',
+                    'text': f'Element content {i}'
+                }
+            )
+            graph.add_node(element_node)
+
+        # Create usage edges (interactions reference elements)
+        for i in range(num_interactions):
+            # Each interaction uses two elements
+            for j in range(2):
+                element_idx = (i + j) % num_elements
+                edge = Edge(
+                    node1=graph.get_node(f"interaction_{i}"),
+                    node2=graph.get_node(f"element_{element_idx}"),
+                    edge_type="used_graph_element_to_answer",
+                    attributes={'relationship_type': 'used_graph_element_to_answer'}
+                )
+                graph.add_edge(edge)
+
+        return graph
+
+    async def test_basic_frequency_extraction(self):
+        """Test basic frequency extraction with simple graph."""
+        graph = 
self.create_mock_graph(num_interactions=3, num_elements=5) + + result = await extract_usage_frequency( + subgraphs=[graph], + time_window=timedelta(days=7), + min_interaction_threshold=1 + ) + + self.assertIn('node_frequencies', result) + self.assertIn('total_interactions', result) + self.assertEqual(result['total_interactions'], 3) + self.assertGreater(len(result['node_frequencies']), 0) + + async def test_time_window_filtering(self): + """Test that time window correctly filters old interactions.""" + graph = CogneeGraph() + + current_time = datetime.now() + + # Add recent interaction (within window) + recent_node = Node( + id="recent_interaction", + node_type="CogneeUserInteraction", + attributes={ + 'type': 'CogneeUserInteraction', + 'timestamp': int(current_time.timestamp() * 1000) + } + ) + graph.add_node(recent_node) + + # Add old interaction (outside window) + old_node = Node( + id="old_interaction", + node_type="CogneeUserInteraction", + attributes={ + 'type': 'CogneeUserInteraction', + 'timestamp': int((current_time - timedelta(days=10)).timestamp() * 1000) + } + ) + graph.add_node(old_node) + + # Add element + element = Node(id="element_1", node_type="DocumentChunk", attributes={'type': 'DocumentChunk'}) + graph.add_node(element) + + # Add edges + graph.add_edge(Edge( + node1=recent_node, node2=element, + edge_type="used_graph_element_to_answer", + attributes={'relationship_type': 'used_graph_element_to_answer'} + )) + graph.add_edge(Edge( + node1=old_node, node2=element, + edge_type="used_graph_element_to_answer", + attributes={'relationship_type': 'used_graph_element_to_answer'} + )) + + # Extract with 7-day window + result = await extract_usage_frequency( + subgraphs=[graph], + time_window=timedelta(days=7), + min_interaction_threshold=1 + ) + + # Should only count recent interaction + self.assertEqual(result['interactions_in_window'], 1) + self.assertEqual(result['total_interactions'], 2) + + async def test_threshold_filtering(self): + """Test that minimum threshold filters low-frequency nodes.""" + graph = self.create_mock_graph(num_interactions=5, num_elements=10) + + # Extract with threshold of 3 + result = await extract_usage_frequency( + subgraphs=[graph], + time_window=timedelta(days=7), + min_interaction_threshold=3 + ) + + # Only nodes with 3+ accesses should be included + for node_id, freq in result['node_frequencies'].items(): + self.assertGreaterEqual(freq, 3) + + async def test_element_type_tracking(self): + """Test that element types are properly tracked.""" + graph = CogneeGraph() + + # Create interaction + interaction = Node( + id="interaction_1", + node_type="CogneeUserInteraction", + attributes={ + 'type': 'CogneeUserInteraction', + 'timestamp': int(datetime.now().timestamp() * 1000) + } + ) + graph.add_node(interaction) + + # Create elements of different types + chunk = Node(id="chunk_1", node_type="DocumentChunk", attributes={'type': 'DocumentChunk'}) + entity = Node(id="entity_1", node_type="Entity", attributes={'type': 'Entity'}) + + graph.add_node(chunk) + graph.add_node(entity) + + # Add edges + for element in [chunk, entity]: + graph.add_edge(Edge( + node1=interaction, node2=element, + edge_type="used_graph_element_to_answer", + attributes={'relationship_type': 'used_graph_element_to_answer'} + )) + + result = await extract_usage_frequency( + subgraphs=[graph], + time_window=timedelta(days=7) + ) + + # Check element types were tracked + self.assertIn('element_type_frequencies', result) + types = result['element_type_frequencies'] + 
self.assertIn('DocumentChunk', types)
+        self.assertIn('Entity', types)
+
+    async def test_empty_graph(self):
+        """Test handling of empty graph."""
+        graph = CogneeGraph()
+
+        result = await extract_usage_frequency(
+            subgraphs=[graph],
+            time_window=timedelta(days=7)
+        )
+
+        self.assertEqual(result['total_interactions'], 0)
+        self.assertEqual(len(result['node_frequencies']), 0)
+
+    async def test_no_interactions_in_window(self):
+        """Test handling when all interactions are outside time window."""
+        graph = CogneeGraph()
+
+        # Add old interaction
+        old_time = datetime.now() - timedelta(days=30)
+        old_interaction = Node(
+            id="old_interaction",
+            node_type="CogneeUserInteraction",
+            attributes={
+                'type': 'CogneeUserInteraction',
+                'timestamp': int(old_time.timestamp() * 1000)
+            }
+        )
+        graph.add_node(old_interaction)
+
+        result = await extract_usage_frequency(
+            subgraphs=[graph],
+            time_window=timedelta(days=7)
+        )
+
+        self.assertEqual(result['interactions_in_window'], 0)
+        self.assertEqual(result['total_interactions'], 1)
+
+
+class TestIntegration(unittest.IsolatedAsyncioTestCase):
+    """Integration tests for the complete workflow."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        if not COGNEE_AVAILABLE:
+            self.skipTest("Cognee modules not available")
+
+    async def test_end_to_end_workflow(self):
+        """Test the complete end-to-end frequency tracking workflow."""
+        # This would require a full Cognee setup with database.
+        # Skipped in unit tests; run extract_usage_frequency_example.py instead.
+        self.skipTest("E2E test - run extract_usage_frequency_example.py instead")
+
+
+# ============================================================================
+# Test Runner
+# ============================================================================
+
+def main():
+    """Run all tests."""
+    if not COGNEE_AVAILABLE:
+        print("⚠ Cognee not available - skipping tests")
+        print("Install with: pip install cognee[neo4j]")
+        return 0
+
+    print("=" * 80)
+    print("Running Usage Frequency Tests")
+    print("=" * 80)
+    print()
+
+    # Create test suite
+    loader = unittest.TestLoader()
+    suite = unittest.TestSuite()
+
+    # Add tests
+    suite.addTests(loader.loadTestsFromTestCase(TestUsageFrequencyExtraction))
+    suite.addTests(loader.loadTestsFromTestCase(TestIntegration))
+
+    # Run tests
+    runner = unittest.TextTestRunner(verbosity=2)
+    result = runner.run(suite)
+
+    # Summary
+    print()
+    print("=" * 80)
+    print("Test Summary")
+    print("=" * 80)
+    print(f"Tests run: {result.testsRun}")
+    print(f"Successes: {result.testsRun - len(result.failures) - len(result.errors)}")
+    print(f"Failures: {len(result.failures)}")
+    print(f"Errors: {len(result.errors)}")
+    print(f"Skipped: {len(result.skipped)}")
+
+    return 0 if result.wasSuccessful() else 1
+
+
+if __name__ == "__main__":
+    exit(main())
\ No newline at end of file
diff --git a/examples/python/extract_usage_frequency_example.py b/examples/python/extract_usage_frequency_example.py
new file mode 100644
index 000000000..3e39886a7
--- /dev/null
+++ b/examples/python/extract_usage_frequency_example.py
@@ -0,0 +1,474 @@
+#!/usr/bin/env python3
+"""
+End-to-End Example: Usage Frequency Tracking in Cognee
+
+This example demonstrates the complete workflow for tracking and analyzing
+how frequently different graph elements are accessed through user searches.
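+
+In short, the pattern demonstrated here is (a sketch; the steps below show the
+full flow with error handling):
+
+    results = await cognee.search(
+        query_type=SearchType.GRAPH_COMPLETION,
+        query_text="What is machine learning?",
+        save_interaction=True,
+    )
+    stats = await run_usage_frequency_update(
+        graph_adapter=await get_graph_engine(),
+        subgraphs=[graph],
+        time_window=timedelta(days=7),
+    )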
+ +Features demonstrated: +- Setting up a knowledge base +- Running searches with interaction tracking (save_interaction=True) +- Extracting usage frequencies from interaction data +- Applying frequency weights to graph nodes +- Analyzing and visualizing the results + +Use cases: +- Ranking search results by popularity +- Identifying "hot topics" in your knowledge base +- Understanding user behavior and interests +- Improving retrieval based on usage patterns +""" + +import asyncio +import os +from datetime import timedelta +from typing import List, Dict, Any +from dotenv import load_dotenv + +import cognee +from cognee.api.v1.search import SearchType +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph +from cognee.tasks.memify.extract_usage_frequency import run_usage_frequency_update + +# Load environment variables +load_dotenv() + + +# ============================================================================ +# STEP 1: Setup and Configuration +# ============================================================================ + +async def setup_knowledge_base(): + """ + Create a fresh knowledge base with sample content. + + In a real application, you would: + - Load documents from files, databases, or APIs + - Process larger datasets + - Organize content by datasets/categories + """ + print("=" * 80) + print("STEP 1: Setting up knowledge base") + print("=" * 80) + + # Reset state for clean demo (optional in production) + print("\nResetting Cognee state...") + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + print("✓ Reset complete") + + # Sample content: AI/ML educational material + documents = [ + """ + Machine Learning Fundamentals: + Machine learning is a subset of artificial intelligence that enables systems + to learn and improve from experience without being explicitly programmed. + The three main types are supervised learning, unsupervised learning, and + reinforcement learning. + """, + """ + Neural Networks Explained: + Neural networks are computing systems inspired by biological neural networks. + They consist of layers of interconnected nodes (neurons) that process information + through weighted connections. Deep learning uses neural networks with many layers + to automatically learn hierarchical representations of data. + """, + """ + Natural Language Processing: + NLP enables computers to understand, interpret, and generate human language. + Modern NLP uses transformer architectures like BERT and GPT, which have + revolutionized tasks such as translation, summarization, and question answering. + """, + """ + Computer Vision Applications: + Computer vision allows machines to interpret visual information from the world. + Convolutional neural networks (CNNs) are particularly effective for image + recognition, object detection, and image segmentation tasks. 
+ """, + ] + + print(f"\nAdding {len(documents)} documents to knowledge base...") + await cognee.add(documents, dataset_name="ai_ml_fundamentals") + print("✓ Documents added") + + # Build knowledge graph + print("\nBuilding knowledge graph (cognify)...") + await cognee.cognify() + print("✓ Knowledge graph built") + + print("\n" + "=" * 80) + + +# ============================================================================ +# STEP 2: Simulate User Searches with Interaction Tracking +# ============================================================================ + +async def simulate_user_searches(queries: List[str]): + """ + Simulate users searching the knowledge base. + + The key parameter is save_interaction=True, which creates: + - CogneeUserInteraction nodes (one per search) + - used_graph_element_to_answer edges (connecting queries to relevant nodes) + + Args: + queries: List of search queries to simulate + + Returns: + Number of successful searches + """ + print("=" * 80) + print("STEP 2: Simulating user searches with interaction tracking") + print("=" * 80) + + successful_searches = 0 + + for i, query in enumerate(queries, 1): + print(f"\nSearch {i}/{len(queries)}: '{query}'") + try: + results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text=query, + save_interaction=True, # ← THIS IS CRITICAL! + top_k=5 + ) + successful_searches += 1 + + # Show snippet of results + result_preview = str(results)[:100] if results else "No results" + print(f" ✓ Completed ({result_preview}...)") + + except Exception as e: + print(f" ✗ Failed: {e}") + + print(f"\n✓ Completed {successful_searches}/{len(queries)} searches") + print("=" * 80) + + return successful_searches + + +# ============================================================================ +# STEP 3: Extract and Apply Usage Frequencies +# ============================================================================ + +async def extract_and_apply_frequencies( + time_window_days: int = 7, + min_threshold: int = 1 +) -> Dict[str, Any]: + """ + Extract usage frequencies from interactions and apply them to the graph. + + This function: + 1. Retrieves the graph with interaction data + 2. Counts how often each node was accessed + 3. 
Writes frequency_weight property back to nodes + + Args: + time_window_days: Only count interactions from last N days + min_threshold: Minimum accesses to track (filter out rarely used nodes) + + Returns: + Dictionary with statistics about the frequency update + """ + print("=" * 80) + print("STEP 3: Extracting and applying usage frequencies") + print("=" * 80) + + # Get graph adapter + graph_engine = await get_graph_engine() + + # Retrieve graph with interactions + print("\nRetrieving graph from database...") + graph = CogneeGraph() + await graph.project_graph_from_db( + adapter=graph_engine, + node_properties_to_project=[ + "type", "node_type", "timestamp", "created_at", + "text", "name", "query_text", "frequency_weight" + ], + edge_properties_to_project=["relationship_type", "timestamp"], + directed=True, + ) + + print(f"✓ Retrieved: {len(graph.nodes)} nodes, {len(graph.edges)} edges") + + # Count interaction nodes + interaction_nodes = [ + n for n in graph.nodes.values() + if n.attributes.get('type') == 'CogneeUserInteraction' or + n.attributes.get('node_type') == 'CogneeUserInteraction' + ] + print(f"✓ Found {len(interaction_nodes)} interaction nodes") + + # Run frequency extraction and update + print(f"\nExtracting frequencies (time window: {time_window_days} days)...") + stats = await run_usage_frequency_update( + graph_adapter=graph_engine, + subgraphs=[graph], + time_window=timedelta(days=time_window_days), + min_interaction_threshold=min_threshold + ) + + print(f"\n✓ Frequency extraction complete!") + print(f" - Interactions processed: {stats['interactions_in_window']}/{stats['total_interactions']}") + print(f" - Nodes weighted: {len(stats['node_frequencies'])}") + print(f" - Element types tracked: {stats.get('element_type_frequencies', {})}") + + print("=" * 80) + + return stats + + +# ============================================================================ +# STEP 4: Analyze and Display Results +# ============================================================================ + +async def analyze_results(stats: Dict[str, Any]): + """ + Analyze and display the frequency tracking results. + + Shows: + - Top most frequently accessed nodes + - Element type distribution + - Verification that weights were written to database + + Args: + stats: Statistics from frequency extraction + """ + print("=" * 80) + print("STEP 4: Analyzing usage frequency results") + print("=" * 80) + + # Display top nodes by frequency + if stats['node_frequencies']: + print("\n📊 Top 10 Most Frequently Accessed Elements:") + print("-" * 80) + + sorted_nodes = sorted( + stats['node_frequencies'].items(), + key=lambda x: x[1], + reverse=True + ) + + # Get graph to display node details + graph_engine = await get_graph_engine() + graph = CogneeGraph() + await graph.project_graph_from_db( + adapter=graph_engine, + node_properties_to_project=["type", "text", "name"], + edge_properties_to_project=[], + directed=True, + ) + + for i, (node_id, frequency) in enumerate(sorted_nodes[:10], 1): + node = graph.get_node(node_id) + if node: + node_type = node.attributes.get('type', 'Unknown') + text = node.attributes.get('text') or node.attributes.get('name') or '' + text_preview = text[:60] + "..." if len(text) > 60 else text + + print(f"\n{i}. Frequency: {frequency} accesses") + print(f" Type: {node_type}") + print(f" Content: {text_preview}") + else: + print(f"\n{i}. 
Frequency: {frequency} accesses") + print(f" Node ID: {node_id[:50]}...") + + # Display element type distribution + if stats.get('element_type_frequencies'): + print("\n\n📈 Element Type Distribution:") + print("-" * 80) + type_dist = stats['element_type_frequencies'] + for elem_type, count in sorted(type_dist.items(), key=lambda x: x[1], reverse=True): + print(f" {elem_type}: {count} accesses") + + # Verify weights in database (Neo4j only) + print("\n\n🔍 Verifying weights in database...") + print("-" * 80) + + graph_engine = await get_graph_engine() + adapter_type = type(graph_engine).__name__ + + if adapter_type == 'Neo4jAdapter': + try: + result = await graph_engine.query(""" + MATCH (n) + WHERE n.frequency_weight IS NOT NULL + RETURN count(n) as weighted_count + """) + + count = result[0]['weighted_count'] if result else 0 + if count > 0: + print(f"✓ {count} nodes have frequency_weight in Neo4j database") + + # Show sample + sample = await graph_engine.query(""" + MATCH (n) + WHERE n.frequency_weight IS NOT NULL + RETURN n.frequency_weight as weight, labels(n) as labels + ORDER BY n.frequency_weight DESC + LIMIT 3 + """) + + print("\nSample weighted nodes:") + for row in sample: + print(f" - Weight: {row['weight']}, Type: {row['labels']}") + else: + print("⚠ No nodes with frequency_weight found in database") + except Exception as e: + print(f"Could not verify in Neo4j: {e}") + else: + print(f"Database verification not implemented for {adapter_type}") + + print("\n" + "=" * 80) + + +# ============================================================================ +# STEP 5: Demonstrate Usage in Retrieval +# ============================================================================ + +async def demonstrate_retrieval_usage(): + """ + Demonstrate how frequency weights can be used in retrieval. + + Note: This is a conceptual demonstration. To actually use frequency + weights in ranking, you would need to modify the retrieval/completion + strategies to incorporate the frequency_weight property. + """ + print("=" * 80) + print("STEP 5: How to use frequency weights in retrieval") + print("=" * 80) + + print(""" + Frequency weights can be used to improve search results: + + 1. RANKING BOOST: + - Multiply relevance scores by frequency_weight + - Prioritize frequently accessed nodes in results + + 2. COMPLETION STRATEGIES: + - Adjust triplet importance based on usage + - Filter out rarely accessed information + + 3. ANALYTICS: + - Track trending topics over time + - Understand user interests and behavior + - Identify knowledge gaps (low-frequency nodes) + + 4. ADAPTIVE RETRIEVAL: + - Personalize results based on team usage patterns + - Surface popular answers faster + + Example Cypher query with frequency boost (Neo4j): + + MATCH (n) + WHERE n.text CONTAINS $search_term + RETURN n, n.frequency_weight as boost + ORDER BY (n.relevance_score * COALESCE(n.frequency_weight, 1)) DESC + LIMIT 10 + + To integrate this into Cognee, you would modify the completion + strategy to include frequency_weight in the scoring function. + """) + + print("=" * 80) + + +# ============================================================================ +# MAIN: Run Complete Example +# ============================================================================ + +async def main(): + """ + Run the complete end-to-end usage frequency tracking example. 
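+
+    Steps: build the knowledge base, run searches with save_interaction=True,
+    extract and apply frequency weights, then analyze the results. Requires a
+    configured graph database and a valid LLM_API_KEY.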
+ """ + print("\n") + print("╔" + "=" * 78 + "╗") + print("║" + " " * 78 + "║") + print("║" + " Usage Frequency Tracking - End-to-End Example".center(78) + "║") + print("║" + " " * 78 + "║") + print("╚" + "=" * 78 + "╝") + print("\n") + + # Configuration check + print("Configuration:") + print(f" Graph Provider: {os.getenv('GRAPH_DATABASE_PROVIDER')}") + print(f" Graph Handler: {os.getenv('GRAPH_DATASET_HANDLER')}") + print(f" LLM Provider: {os.getenv('LLM_PROVIDER')}") + + # Verify LLM key is set + if not os.getenv('LLM_API_KEY') or os.getenv('LLM_API_KEY') == 'sk-your-key-here': + print("\n⚠ WARNING: LLM_API_KEY not set in .env file") + print(" Set your API key to run searches") + return + + print("\n") + + try: + # Step 1: Setup + await setup_knowledge_base() + + # Step 2: Simulate searches + # Note: Repeat queries increase frequency for those topics + queries = [ + "What is machine learning?", + "Explain neural networks", + "How does deep learning work?", + "Tell me about neural networks", # Repeat - increases frequency + "What are transformers in NLP?", + "Explain neural networks again", # Another repeat + "How does computer vision work?", + "What is reinforcement learning?", + "Tell me more about neural networks", # Third repeat + ] + + successful_searches = await simulate_user_searches(queries) + + if successful_searches == 0: + print("⚠ No searches completed - cannot demonstrate frequency tracking") + return + + # Step 3: Extract frequencies + stats = await extract_and_apply_frequencies( + time_window_days=7, + min_threshold=1 + ) + + # Step 4: Analyze results + await analyze_results(stats) + + # Step 5: Show usage examples + await demonstrate_retrieval_usage() + + # Summary + print("\n") + print("╔" + "=" * 78 + "╗") + print("║" + " " * 78 + "║") + print("║" + " Example Complete!".center(78) + "║") + print("║" + " " * 78 + "║") + print("╚" + "=" * 78 + "╝") + print("\n") + + print("Summary:") + print(f" ✓ Documents added: 4") + print(f" ✓ Searches performed: {successful_searches}") + print(f" ✓ Interactions tracked: {stats['interactions_in_window']}") + print(f" ✓ Nodes weighted: {len(stats['node_frequencies'])}") + + print("\nNext steps:") + print(" 1. Open Neo4j Browser (http://localhost:7474) to explore the graph") + print(" 2. Modify retrieval strategies to use frequency_weight") + print(" 3. Build analytics dashboards using element_type_frequencies") + print(" 4. Run periodic frequency updates to track trends over time") + + print("\n") + + except Exception as e: + print(f"\n✗ Example failed: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file