diff --git a/cognee/tasks/memify/extract_usage_frequency.py b/cognee/tasks/memify/extract_usage_frequency.py index 7e437bd18..5d7dcde60 100644 --- a/cognee/tasks/memify/extract_usage_frequency.py +++ b/cognee/tasks/memify/extract_usage_frequency.py @@ -10,20 +10,20 @@ logger = get_logger("extract_usage_frequency") async def extract_usage_frequency( - subgraphs: List[CogneeGraph], + subgraphs: List[CogneeGraph], time_window: timedelta = timedelta(days=7), - min_interaction_threshold: int = 1 + min_interaction_threshold: int = 1, ) -> Dict[str, Any]: """ Extract usage frequency from CogneeUserInteraction nodes. - + When save_interaction=True in cognee.search(), the system creates: - CogneeUserInteraction nodes (representing the query/answer interaction) - used_graph_element_to_answer edges (connecting interactions to graph elements used) - + This function tallies how often each graph element is referenced via these edges, enabling frequency-based ranking in downstream retrievers. - + :param subgraphs: List of CogneeGraph instances containing interaction data :param time_window: Time window to consider for interactions (default: 7 days) :param min_interaction_threshold: Minimum interactions to track (default: 1) @@ -31,33 +31,35 @@ async def extract_usage_frequency( """ current_time = datetime.now() cutoff_time = current_time - time_window - + # Track frequencies for graph elements (nodes and edges) node_frequencies = {} edge_frequencies = {} relationship_type_frequencies = {} - + # Track interaction metadata interaction_count = 0 interactions_in_window = 0 - + logger.info(f"Extracting usage frequencies from {len(subgraphs)} subgraphs") logger.info(f"Time window: {time_window}, Cutoff: {cutoff_time.isoformat()}") - + for subgraph in subgraphs: # Find all CogneeUserInteraction nodes interaction_nodes = {} for node_id, node in subgraph.nodes.items(): - node_type = node.attributes.get('type') or node.attributes.get('node_type') - - if node_type == 'CogneeUserInteraction': + node_type = node.attributes.get("type") or node.attributes.get("node_type") + + if node_type == "CogneeUserInteraction": # Parse and validate timestamp - timestamp_value = node.attributes.get('timestamp') or node.attributes.get('created_at') + timestamp_value = node.attributes.get("timestamp") or node.attributes.get( + "created_at" + ) if timestamp_value is not None: try: # Handle various timestamp formats interaction_time = None - + if isinstance(timestamp_value, datetime): # Already a Python datetime interaction_time = timestamp_value @@ -81,24 +83,24 @@ async def extract_usage_frequency( else: # ISO format string interaction_time = datetime.fromisoformat(timestamp_value) - elif hasattr(timestamp_value, 'to_native'): + elif hasattr(timestamp_value, "to_native"): # Neo4j datetime object - convert to Python datetime interaction_time = timestamp_value.to_native() - elif hasattr(timestamp_value, 'year') and hasattr(timestamp_value, 'month'): + elif hasattr(timestamp_value, "year") and hasattr(timestamp_value, "month"): # Datetime-like object - extract components try: interaction_time = datetime( year=timestamp_value.year, month=timestamp_value.month, day=timestamp_value.day, - hour=getattr(timestamp_value, 'hour', 0), - minute=getattr(timestamp_value, 'minute', 0), - second=getattr(timestamp_value, 'second', 0), - microsecond=getattr(timestamp_value, 'microsecond', 0) + hour=getattr(timestamp_value, "hour", 0), + minute=getattr(timestamp_value, "minute", 0), + second=getattr(timestamp_value, "second", 0), + 
microsecond=getattr(timestamp_value, "microsecond", 0), ) except (AttributeError, ValueError): pass - + if interaction_time is None: # Last resort: try converting to string and parsing str_value = str(timestamp_value) @@ -110,73 +112,83 @@ async def extract_usage_frequency( interaction_time = datetime.fromtimestamp(ts_int) else: interaction_time = datetime.fromisoformat(str_value) - + if interaction_time is None: raise ValueError(f"Could not parse timestamp: {timestamp_value}") - + # Make sure it's timezone-naive for comparison if interaction_time.tzinfo is not None: interaction_time = interaction_time.replace(tzinfo=None) - + interaction_nodes[node_id] = { - 'node': node, - 'timestamp': interaction_time, - 'in_window': interaction_time >= cutoff_time + "node": node, + "timestamp": interaction_time, + "in_window": interaction_time >= cutoff_time, } interaction_count += 1 if interaction_time >= cutoff_time: interactions_in_window += 1 except (ValueError, TypeError, AttributeError, OSError) as e: - logger.warning(f"Failed to parse timestamp for interaction node {node_id}: {e}") - logger.debug(f"Timestamp value type: {type(timestamp_value)}, value: {timestamp_value}") - + logger.warning( + f"Failed to parse timestamp for interaction node {node_id}: {e}" + ) + logger.debug( + f"Timestamp value type: {type(timestamp_value)}, value: {timestamp_value}" + ) + # Process edges to find graph elements used in interactions for edge in subgraph.edges: - relationship_type = edge.attributes.get('relationship_type') - + relationship_type = edge.attributes.get("relationship_type") + # Look for 'used_graph_element_to_answer' edges - if relationship_type == 'used_graph_element_to_answer': + if relationship_type == "used_graph_element_to_answer": # node1 should be the CogneeUserInteraction, node2 is the graph element source_id = str(edge.node1.id) target_id = str(edge.node2.id) - + # Check if source is an interaction node in our time window if source_id in interaction_nodes: interaction_data = interaction_nodes[source_id] - - if interaction_data['in_window']: + + if interaction_data["in_window"]: # Count the graph element (target node) being used node_frequencies[target_id] = node_frequencies.get(target_id, 0) + 1 - + # Also track what type of element it is for analytics target_node = subgraph.get_node(target_id) if target_node: - element_type = target_node.attributes.get('type') or target_node.attributes.get('node_type') + element_type = target_node.attributes.get( + "type" + ) or target_node.attributes.get("node_type") if element_type: - relationship_type_frequencies[element_type] = relationship_type_frequencies.get(element_type, 0) + 1 - + relationship_type_frequencies[element_type] = ( + relationship_type_frequencies.get(element_type, 0) + 1 + ) + # Also track general edge usage patterns - elif relationship_type and relationship_type != 'used_graph_element_to_answer': + elif relationship_type and relationship_type != "used_graph_element_to_answer": # Check if either endpoint is referenced in a recent interaction source_id = str(edge.node1.id) target_id = str(edge.node2.id) - + # If this edge connects to any frequently accessed nodes, track the edge type if source_id in node_frequencies or target_id in node_frequencies: edge_key = f"{relationship_type}:{source_id}:{target_id}" edge_frequencies[edge_key] = edge_frequencies.get(edge_key, 0) + 1 - + # Filter frequencies above threshold filtered_node_frequencies = { - node_id: freq for node_id, freq in node_frequencies.items() + node_id: freq + for node_id, 
freq in node_frequencies.items() if freq >= min_interaction_threshold } - + filtered_edge_frequencies = { - edge_key: freq for edge_key, freq in edge_frequencies.items() + edge_key: freq + for edge_key, freq in edge_frequencies.items() if freq >= min_interaction_threshold } - + logger.info( f"Processed {interactions_in_window}/{interaction_count} interactions in time window" ) @@ -185,58 +197,59 @@ async def extract_usage_frequency( f"above threshold (min: {min_interaction_threshold})" ) logger.info(f"Element type distribution: {relationship_type_frequencies}") - + return { - 'node_frequencies': filtered_node_frequencies, - 'edge_frequencies': filtered_edge_frequencies, - 'element_type_frequencies': relationship_type_frequencies, - 'total_interactions': interaction_count, - 'interactions_in_window': interactions_in_window, - 'time_window_days': time_window.days, - 'last_processed_timestamp': current_time.isoformat(), - 'cutoff_timestamp': cutoff_time.isoformat() + "node_frequencies": filtered_node_frequencies, + "edge_frequencies": filtered_edge_frequencies, + "element_type_frequencies": relationship_type_frequencies, + "total_interactions": interaction_count, + "interactions_in_window": interactions_in_window, + "time_window_days": time_window.days, + "last_processed_timestamp": current_time.isoformat(), + "cutoff_timestamp": cutoff_time.isoformat(), } async def add_frequency_weights( - graph_adapter: GraphDBInterface, - usage_frequencies: Dict[str, Any] + graph_adapter: GraphDBInterface, usage_frequencies: Dict[str, Any] ) -> None: """ Add frequency weights to graph nodes and edges using the graph adapter. - + Uses direct Cypher queries for Neo4j adapter compatibility. Writes frequency_weight properties back to the graph for use in: - Ranking frequently referenced entities higher during retrieval - Adjusting scoring for completion strategies - Exposing usage metrics in dashboards or audits - + :param graph_adapter: Graph database adapter interface :param usage_frequencies: Calculated usage frequencies from extract_usage_frequency """ - node_frequencies = usage_frequencies.get('node_frequencies', {}) - edge_frequencies = usage_frequencies.get('edge_frequencies', {}) - + node_frequencies = usage_frequencies.get("node_frequencies", {}) + edge_frequencies = usage_frequencies.get("edge_frequencies", {}) + logger.info(f"Adding frequency weights to {len(node_frequencies)} nodes") - + # Check adapter type and use appropriate method adapter_type = type(graph_adapter).__name__ logger.info(f"Using adapter: {adapter_type}") - + nodes_updated = 0 nodes_failed = 0 - + # Determine which method to use based on adapter type - use_neo4j_cypher = adapter_type == 'Neo4jAdapter' and hasattr(graph_adapter, 'query') - use_kuzu_query = adapter_type == 'KuzuAdapter' and hasattr(graph_adapter, 'query') - use_get_update = hasattr(graph_adapter, 'get_node_by_id') and hasattr(graph_adapter, 'update_node_properties') - + use_neo4j_cypher = adapter_type == "Neo4jAdapter" and hasattr(graph_adapter, "query") + use_kuzu_query = adapter_type == "KuzuAdapter" and hasattr(graph_adapter, "query") + use_get_update = hasattr(graph_adapter, "get_node_by_id") and hasattr( + graph_adapter, "update_node_properties" + ) + # Method 1: Neo4j Cypher with SET (creates properties on the fly) if use_neo4j_cypher: try: logger.info("Using Neo4j Cypher SET method") - last_updated = usage_frequencies.get('last_processed_timestamp') - + last_updated = usage_frequencies.get("last_processed_timestamp") + for node_id, frequency in 
node_frequencies.items(): try: query = """ @@ -246,47 +259,49 @@ async def add_frequency_weights( n.frequency_updated_at = $updated_at RETURN n.id as id """ - + result = await graph_adapter.query( query, params={ - 'node_id': node_id, - 'frequency': frequency, - 'updated_at': last_updated - } + "node_id": node_id, + "frequency": frequency, + "updated_at": last_updated, + }, ) - + if result and len(result) > 0: nodes_updated += 1 else: logger.warning(f"Node {node_id} not found or not updated") nodes_failed += 1 - + except Exception as e: logger.error(f"Error updating node {node_id}: {e}") nodes_failed += 1 - + logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed") - + except Exception as e: logger.error(f"Neo4j Cypher update failed: {e}") use_neo4j_cypher = False - + # Method 2: Kuzu - use get_node + add_node (updates via re-adding with same ID) - elif use_kuzu_query and hasattr(graph_adapter, 'get_node') and hasattr(graph_adapter, 'add_node'): + elif ( + use_kuzu_query and hasattr(graph_adapter, "get_node") and hasattr(graph_adapter, "add_node") + ): logger.info("Using Kuzu get_node + add_node method") - last_updated = usage_frequencies.get('last_processed_timestamp') - + last_updated = usage_frequencies.get("last_processed_timestamp") + for node_id, frequency in node_frequencies.items(): try: # Get the existing node (returns a dict) existing_node_dict = await graph_adapter.get_node(node_id) - + if existing_node_dict: # Update the dict with new properties - existing_node_dict['frequency_weight'] = frequency - existing_node_dict['frequency_updated_at'] = last_updated - + existing_node_dict["frequency_weight"] = frequency + existing_node_dict["frequency_updated_at"] = last_updated + # Kuzu's add_node likely just takes the dict directly, not a Node object # Try passing the dict directly first try: @@ -295,20 +310,21 @@ async def add_frequency_weights( except Exception as dict_error: # If dict doesn't work, try creating a Node object logger.debug(f"Dict add failed, trying Node object: {dict_error}") - + try: from cognee.infrastructure.engine import Node + # Try different Node constructor patterns try: # Pattern 1: Just properties node_obj = Node(existing_node_dict) - except: + except Exception: # Pattern 2: Type and properties node_obj = Node( - type=existing_node_dict.get('type', 'Unknown'), - **existing_node_dict + type=existing_node_dict.get("type", "Unknown"), + **existing_node_dict, ) - + await graph_adapter.add_node(node_obj) nodes_updated += 1 except Exception as node_error: @@ -317,13 +333,13 @@ async def add_frequency_weights( else: logger.warning(f"Node {node_id} not found in graph") nodes_failed += 1 - + except Exception as e: logger.error(f"Error updating node {node_id}: {e}") nodes_failed += 1 - + logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed") - + # Method 3: Generic get_node_by_id + update_node_properties elif use_get_update: logger.info("Using get/update method for adapter") @@ -331,90 +347,95 @@ async def add_frequency_weights( try: # Get current node data node_data = await graph_adapter.get_node_by_id(node_id) - + if node_data: # Tweak the properties dict - add frequency_weight if isinstance(node_data, dict): - properties = node_data.get('properties', {}) + properties = node_data.get("properties", {}) else: - properties = getattr(node_data, 'properties', {}) or {} - + properties = getattr(node_data, "properties", {}) or {} + # Update with frequency weight - properties['frequency_weight'] = frequency - 
properties['frequency_updated_at'] = usage_frequencies.get('last_processed_timestamp') - + properties["frequency_weight"] = frequency + properties["frequency_updated_at"] = usage_frequencies.get( + "last_processed_timestamp" + ) + # Write back via adapter await graph_adapter.update_node_properties(node_id, properties) nodes_updated += 1 else: logger.warning(f"Node {node_id} not found in graph") nodes_failed += 1 - + except Exception as e: logger.error(f"Error updating node {node_id}: {e}") nodes_failed += 1 - + logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed") for node_id, frequency in node_frequencies.items(): try: # Get current node data node_data = await graph_adapter.get_node_by_id(node_id) - + if node_data: # Tweak the properties dict - add frequency_weight if isinstance(node_data, dict): - properties = node_data.get('properties', {}) + properties = node_data.get("properties", {}) else: - properties = getattr(node_data, 'properties', {}) or {} - + properties = getattr(node_data, "properties", {}) or {} + # Update with frequency weight - properties['frequency_weight'] = frequency - properties['frequency_updated_at'] = usage_frequencies.get('last_processed_timestamp') - + properties["frequency_weight"] = frequency + properties["frequency_updated_at"] = usage_frequencies.get( + "last_processed_timestamp" + ) + # Write back via adapter await graph_adapter.update_node_properties(node_id, properties) nodes_updated += 1 else: logger.warning(f"Node {node_id} not found in graph") nodes_failed += 1 - + except Exception as e: logger.error(f"Error updating node {node_id}: {e}") nodes_failed += 1 - + # If no method is available if not use_neo4j_cypher and not use_kuzu_query and not use_get_update: logger.error(f"Adapter {adapter_type} does not support required update methods") - logger.error("Required: either 'query' method or both 'get_node_by_id' and 'update_node_properties'") + logger.error( + "Required: either 'query' method or both 'get_node_by_id' and 'update_node_properties'" + ) return - + # Update edge frequencies # Note: Edge property updates are backend-specific if edge_frequencies: logger.info(f"Processing {len(edge_frequencies)} edge frequency entries") - + edges_updated = 0 edges_failed = 0 - + for edge_key, frequency in edge_frequencies.items(): try: # Parse edge key: "relationship_type:source_id:target_id" - parts = edge_key.split(':', 2) + parts = edge_key.split(":", 2) if len(parts) == 3: relationship_type, source_id, target_id = parts - + # Try to update edge if adapter supports it - if hasattr(graph_adapter, 'update_edge_properties'): + if hasattr(graph_adapter, "update_edge_properties"): edge_properties = { - 'frequency_weight': frequency, - 'frequency_updated_at': usage_frequencies.get('last_processed_timestamp') + "frequency_weight": frequency, + "frequency_updated_at": usage_frequencies.get( + "last_processed_timestamp" + ), } - + await graph_adapter.update_edge_properties( - source_id, - target_id, - relationship_type, - edge_properties + source_id, target_id, relationship_type, edge_properties ) edges_updated += 1 else: @@ -423,28 +444,28 @@ async def add_frequency_weights( f"Adapter doesn't support update_edge_properties for " f"{relationship_type} ({source_id} -> {target_id})" ) - + except Exception as e: logger.error(f"Error updating edge {edge_key}: {e}") edges_failed += 1 - + if edges_updated > 0: logger.info(f"Edge update complete: {edges_updated} succeeded, {edges_failed} failed") else: logger.info( "Edge frequency updates 
skipped (adapter may not support edge property updates)" ) - + # Store aggregate statistics as metadata if supported - if hasattr(graph_adapter, 'set_metadata'): + if hasattr(graph_adapter, "set_metadata"): try: metadata = { - 'element_type_frequencies': usage_frequencies.get('element_type_frequencies', {}), - 'total_interactions': usage_frequencies.get('total_interactions', 0), - 'interactions_in_window': usage_frequencies.get('interactions_in_window', 0), - 'last_frequency_update': usage_frequencies.get('last_processed_timestamp') + "element_type_frequencies": usage_frequencies.get("element_type_frequencies", {}), + "total_interactions": usage_frequencies.get("total_interactions", 0), + "interactions_in_window": usage_frequencies.get("interactions_in_window", 0), + "last_frequency_update": usage_frequencies.get("last_processed_timestamp"), } - await graph_adapter.set_metadata('usage_frequency_stats', metadata) + await graph_adapter.set_metadata("usage_frequency_stats", metadata) logger.info("Stored usage frequency statistics as metadata") except Exception as e: logger.warning(f"Could not store usage statistics as metadata: {e}") @@ -454,25 +475,25 @@ async def create_usage_frequency_pipeline( graph_adapter: GraphDBInterface, time_window: timedelta = timedelta(days=7), min_interaction_threshold: int = 1, - batch_size: int = 100 + batch_size: int = 100, ) -> tuple: """ Create memify pipeline entry for usage frequency tracking. - + This follows the same pattern as feedback enrichment flows, allowing the frequency update to run end-to-end in a custom memify pipeline. - + Use case example: extraction_tasks, enrichment_tasks = await create_usage_frequency_pipeline( graph_adapter=my_adapter, time_window=timedelta(days=30), min_interaction_threshold=2 ) - + # Run in memify pipeline pipeline = Pipeline(extraction_tasks + enrichment_tasks) results = await pipeline.run() - + :param graph_adapter: Graph database adapter :param time_window: Time window for counting interactions (default: 7 days) :param min_interaction_threshold: Minimum interactions to track (default: 1) @@ -481,23 +502,23 @@ async def create_usage_frequency_pipeline( """ logger.info("Creating usage frequency pipeline") logger.info(f"Config: time_window={time_window}, threshold={min_interaction_threshold}") - + extraction_tasks = [ Task( extract_usage_frequency, time_window=time_window, - min_interaction_threshold=min_interaction_threshold + min_interaction_threshold=min_interaction_threshold, ) ] - + enrichment_tasks = [ Task( add_frequency_weights, graph_adapter=graph_adapter, - task_config={"batch_size": batch_size} + task_config={"batch_size": batch_size}, ) ] - + return extraction_tasks, enrichment_tasks @@ -505,21 +526,21 @@ async def run_usage_frequency_update( graph_adapter: GraphDBInterface, subgraphs: List[CogneeGraph], time_window: timedelta = timedelta(days=7), - min_interaction_threshold: int = 1 + min_interaction_threshold: int = 1, ) -> Dict[str, Any]: """ Convenience function to run the complete usage frequency update pipeline. - + This is the main entry point for updating frequency weights on graph elements based on CogneeUserInteraction data from cognee.search(save_interaction=True). 
- + Example usage: # After running searches with save_interaction=True from cognee.tasks.memify.extract_usage_frequency import run_usage_frequency_update - + # Get the graph with interactions graph = await get_cognee_graph_with_interactions() - + # Update frequency weights stats = await run_usage_frequency_update( graph_adapter=graph_adapter, @@ -527,9 +548,9 @@ async def run_usage_frequency_update( time_window=timedelta(days=30), # Last 30 days min_interaction_threshold=2 # At least 2 uses ) - + print(f"Updated {len(stats['node_frequencies'])} nodes") - + :param graph_adapter: Graph database adapter :param subgraphs: List of CogneeGraph instances with interaction data :param time_window: Time window for counting interactions @@ -537,51 +558,48 @@ async def run_usage_frequency_update( :return: Usage frequency statistics """ logger.info("Starting usage frequency update") - + try: # Extract frequencies from interaction data usage_frequencies = await extract_usage_frequency( subgraphs=subgraphs, time_window=time_window, - min_interaction_threshold=min_interaction_threshold + min_interaction_threshold=min_interaction_threshold, ) - + # Add frequency weights back to the graph await add_frequency_weights( - graph_adapter=graph_adapter, - usage_frequencies=usage_frequencies + graph_adapter=graph_adapter, usage_frequencies=usage_frequencies ) - + logger.info("Usage frequency update completed successfully") logger.info( f"Summary: {usage_frequencies['interactions_in_window']} interactions processed, " f"{len(usage_frequencies['node_frequencies'])} nodes weighted" ) - + return usage_frequencies - + except Exception as e: logger.error(f"Error during usage frequency update: {str(e)}") raise async def get_most_frequent_elements( - graph_adapter: GraphDBInterface, - top_n: int = 10, - element_type: Optional[str] = None + graph_adapter: GraphDBInterface, top_n: int = 10, element_type: Optional[str] = None ) -> List[Dict[str, Any]]: """ Retrieve the most frequently accessed graph elements. - + Useful for analytics dashboards and understanding user behavior. 
- + :param graph_adapter: Graph database adapter :param top_n: Number of top elements to return :param element_type: Optional filter by element type :return: List of elements with their frequency weights """ logger.info(f"Retrieving top {top_n} most frequent elements") - + # This would need to be implemented based on the specific graph adapter's query capabilities # Pseudocode: # results = await graph_adapter.query_nodes_by_property( @@ -590,6 +608,6 @@ async def get_most_frequent_elements( # limit=top_n, # filters={'type': element_type} if element_type else None # ) - + logger.warning("get_most_frequent_elements needs adapter-specific implementation") - return [] \ No newline at end of file + return [] diff --git a/cognee/tests/test_extract_usage_frequency.py b/cognee/tests/test_extract_usage_frequency.py index c4a3e0448..a4b12dd0d 100644 --- a/cognee/tests/test_extract_usage_frequency.py +++ b/cognee/tests/test_extract_usage_frequency.py @@ -6,7 +6,7 @@ Tests cover extraction logic, adapter integration, edge cases, and end-to-end wo Run with: pytest test_usage_frequency_comprehensive.py -v - + Or without pytest: python test_usage_frequency_comprehensive.py """ @@ -23,8 +23,9 @@ try: from cognee.tasks.memify.extract_usage_frequency import ( extract_usage_frequency, add_frequency_weights, - run_usage_frequency_update + run_usage_frequency_update, ) + COGNEE_AVAILABLE = True except ImportError: COGNEE_AVAILABLE = False @@ -33,16 +34,16 @@ except ImportError: class TestUsageFrequencyExtraction(unittest.TestCase): """Test the core frequency extraction logic.""" - + def setUp(self): """Set up test fixtures.""" if not COGNEE_AVAILABLE: self.skipTest("Cognee modules not available") - + def create_mock_graph(self, num_interactions: int = 3, num_elements: int = 5): """Create a mock graph with interactions and elements.""" graph = CogneeGraph() - + # Create interaction nodes current_time = datetime.now() for i in range(num_interactions): @@ -50,25 +51,22 @@ class TestUsageFrequencyExtraction(unittest.TestCase): id=f"interaction_{i}", node_type="CogneeUserInteraction", attributes={ - 'type': 'CogneeUserInteraction', - 'query_text': f'Test query {i}', - 'timestamp': int((current_time - timedelta(hours=i)).timestamp() * 1000) - } + "type": "CogneeUserInteraction", + "query_text": f"Test query {i}", + "timestamp": int((current_time - timedelta(hours=i)).timestamp() * 1000), + }, ) graph.add_node(interaction_node) - + # Create graph element nodes for i in range(num_elements): element_node = Node( id=f"element_{i}", node_type="DocumentChunk", - attributes={ - 'type': 'DocumentChunk', - 'text': f'Element content {i}' - } + attributes={"type": "DocumentChunk", "text": f"Element content {i}"}, ) graph.add_node(element_node) - + # Create usage edges (interactions reference elements) for i in range(num_interactions): # Each interaction uses 2-3 elements @@ -78,183 +76,179 @@ class TestUsageFrequencyExtraction(unittest.TestCase): node1=graph.get_node(f"interaction_{i}"), node2=graph.get_node(f"element_{element_idx}"), edge_type="used_graph_element_to_answer", - attributes={'relationship_type': 'used_graph_element_to_answer'} + attributes={"relationship_type": "used_graph_element_to_answer"}, ) graph.add_edge(edge) - + return graph - + async def test_basic_frequency_extraction(self): """Test basic frequency extraction with simple graph.""" graph = self.create_mock_graph(num_interactions=3, num_elements=5) - + result = await extract_usage_frequency( - subgraphs=[graph], - time_window=timedelta(days=7), - 
min_interaction_threshold=1 + subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=1 ) - - self.assertIn('node_frequencies', result) - self.assertIn('total_interactions', result) - self.assertEqual(result['total_interactions'], 3) - self.assertGreater(len(result['node_frequencies']), 0) - + + self.assertIn("node_frequencies", result) + self.assertIn("total_interactions", result) + self.assertEqual(result["total_interactions"], 3) + self.assertGreater(len(result["node_frequencies"]), 0) + async def test_time_window_filtering(self): """Test that time window correctly filters old interactions.""" graph = CogneeGraph() - + current_time = datetime.now() - + # Add recent interaction (within window) recent_node = Node( id="recent_interaction", node_type="CogneeUserInteraction", attributes={ - 'type': 'CogneeUserInteraction', - 'timestamp': int(current_time.timestamp() * 1000) - } + "type": "CogneeUserInteraction", + "timestamp": int(current_time.timestamp() * 1000), + }, ) graph.add_node(recent_node) - + # Add old interaction (outside window) old_node = Node( id="old_interaction", node_type="CogneeUserInteraction", attributes={ - 'type': 'CogneeUserInteraction', - 'timestamp': int((current_time - timedelta(days=10)).timestamp() * 1000) - } + "type": "CogneeUserInteraction", + "timestamp": int((current_time - timedelta(days=10)).timestamp() * 1000), + }, ) graph.add_node(old_node) - + # Add element - element = Node(id="element_1", node_type="DocumentChunk", attributes={'type': 'DocumentChunk'}) + element = Node( + id="element_1", node_type="DocumentChunk", attributes={"type": "DocumentChunk"} + ) graph.add_node(element) - + # Add edges - graph.add_edge(Edge( - node1=recent_node, node2=element, - edge_type="used_graph_element_to_answer", - attributes={'relationship_type': 'used_graph_element_to_answer'} - )) - graph.add_edge(Edge( - node1=old_node, node2=element, - edge_type="used_graph_element_to_answer", - attributes={'relationship_type': 'used_graph_element_to_answer'} - )) - + graph.add_edge( + Edge( + node1=recent_node, + node2=element, + edge_type="used_graph_element_to_answer", + attributes={"relationship_type": "used_graph_element_to_answer"}, + ) + ) + graph.add_edge( + Edge( + node1=old_node, + node2=element, + edge_type="used_graph_element_to_answer", + attributes={"relationship_type": "used_graph_element_to_answer"}, + ) + ) + # Extract with 7-day window result = await extract_usage_frequency( - subgraphs=[graph], - time_window=timedelta(days=7), - min_interaction_threshold=1 + subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=1 ) - + # Should only count recent interaction - self.assertEqual(result['interactions_in_window'], 1) - self.assertEqual(result['total_interactions'], 2) - + self.assertEqual(result["interactions_in_window"], 1) + self.assertEqual(result["total_interactions"], 2) + async def test_threshold_filtering(self): """Test that minimum threshold filters low-frequency nodes.""" graph = self.create_mock_graph(num_interactions=5, num_elements=10) - + # Extract with threshold of 3 result = await extract_usage_frequency( - subgraphs=[graph], - time_window=timedelta(days=7), - min_interaction_threshold=3 + subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=3 ) - + # Only nodes with 3+ accesses should be included - for node_id, freq in result['node_frequencies'].items(): + for node_id, freq in result["node_frequencies"].items(): self.assertGreaterEqual(freq, 3) - + async def 
test_element_type_tracking(self): """Test that element types are properly tracked.""" graph = CogneeGraph() - + # Create interaction interaction = Node( id="interaction_1", node_type="CogneeUserInteraction", attributes={ - 'type': 'CogneeUserInteraction', - 'timestamp': int(datetime.now().timestamp() * 1000) - } + "type": "CogneeUserInteraction", + "timestamp": int(datetime.now().timestamp() * 1000), + }, ) graph.add_node(interaction) - + # Create elements of different types - chunk = Node(id="chunk_1", node_type="DocumentChunk", attributes={'type': 'DocumentChunk'}) - entity = Node(id="entity_1", node_type="Entity", attributes={'type': 'Entity'}) - + chunk = Node(id="chunk_1", node_type="DocumentChunk", attributes={"type": "DocumentChunk"}) + entity = Node(id="entity_1", node_type="Entity", attributes={"type": "Entity"}) + graph.add_node(chunk) graph.add_node(entity) - + # Add edges for element in [chunk, entity]: - graph.add_edge(Edge( - node1=interaction, node2=element, - edge_type="used_graph_element_to_answer", - attributes={'relationship_type': 'used_graph_element_to_answer'} - )) - - result = await extract_usage_frequency( - subgraphs=[graph], - time_window=timedelta(days=7) - ) - + graph.add_edge( + Edge( + node1=interaction, + node2=element, + edge_type="used_graph_element_to_answer", + attributes={"relationship_type": "used_graph_element_to_answer"}, + ) + ) + + result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7)) + # Check element types were tracked - self.assertIn('element_type_frequencies', result) - types = result['element_type_frequencies'] - self.assertIn('DocumentChunk', types) - self.assertIn('Entity', types) - + self.assertIn("element_type_frequencies", result) + types = result["element_type_frequencies"] + self.assertIn("DocumentChunk", types) + self.assertIn("Entity", types) + async def test_empty_graph(self): """Test handling of empty graph.""" graph = CogneeGraph() - - result = await extract_usage_frequency( - subgraphs=[graph], - time_window=timedelta(days=7) - ) - - self.assertEqual(result['total_interactions'], 0) - self.assertEqual(len(result['node_frequencies']), 0) - + + result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7)) + + self.assertEqual(result["total_interactions"], 0) + self.assertEqual(len(result["node_frequencies"]), 0) + async def test_no_interactions_in_window(self): """Test handling when all interactions are outside time window.""" graph = CogneeGraph() - + # Add old interaction old_time = datetime.now() - timedelta(days=30) old_interaction = Node( id="old_interaction", node_type="CogneeUserInteraction", attributes={ - 'type': 'CogneeUserInteraction', - 'timestamp': int(old_time.timestamp() * 1000) - } + "type": "CogneeUserInteraction", + "timestamp": int(old_time.timestamp() * 1000), + }, ) graph.add_node(old_interaction) - - result = await extract_usage_frequency( - subgraphs=[graph], - time_window=timedelta(days=7) - ) - - self.assertEqual(result['interactions_in_window'], 0) - self.assertEqual(result['total_interactions'], 1) + + result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7)) + + self.assertEqual(result["interactions_in_window"], 0) + self.assertEqual(result["total_interactions"], 1) class TestIntegration(unittest.TestCase): """Integration tests for the complete workflow.""" - + def setUp(self): """Set up test fixtures.""" if not COGNEE_AVAILABLE: self.skipTest("Cognee modules not available") - + async def 
test_end_to_end_workflow(self): """Test the complete end-to-end frequency tracking workflow.""" # This would require a full Cognee setup with database @@ -266,6 +260,7 @@ class TestIntegration(unittest.TestCase): # Test Runner # ============================================================================ + def run_async_test(test_func): """Helper to run async test functions.""" asyncio.run(test_func()) @@ -277,24 +272,24 @@ def main(): print("⚠ Cognee not available - skipping tests") print("Install with: pip install cognee[neo4j]") return - + print("=" * 80) print("Running Usage Frequency Tests") print("=" * 80) print() - + # Create test suite loader = unittest.TestLoader() suite = unittest.TestSuite() - + # Add tests suite.addTests(loader.loadTestsFromTestCase(TestUsageFrequencyExtraction)) suite.addTests(loader.loadTestsFromTestCase(TestIntegration)) - + # Run tests runner = unittest.TextTestRunner(verbosity=2) result = runner.run(suite) - + # Summary print() print("=" * 80) @@ -305,9 +300,9 @@ def main(): print(f"Failures: {len(result.failures)}") print(f"Errors: {len(result.errors)}") print(f"Skipped: {len(result.skipped)}") - + return 0 if result.wasSuccessful() else 1 if __name__ == "__main__": - exit(main()) \ No newline at end of file + exit(main()) diff --git a/examples/python/extract_usage_frequency_example.py b/examples/python/extract_usage_frequency_example.py index 3e39886a7..b1068ae38 100644 --- a/examples/python/extract_usage_frequency_example.py +++ b/examples/python/extract_usage_frequency_example.py @@ -39,10 +39,11 @@ load_dotenv() # STEP 1: Setup and Configuration # ============================================================================ + async def setup_knowledge_base(): """ Create a fresh knowledge base with sample content. - + In a real application, you would: - Load documents from files, databases, or APIs - Process larger datasets @@ -51,13 +52,13 @@ async def setup_knowledge_base(): print("=" * 80) print("STEP 1: Setting up knowledge base") print("=" * 80) - + # Reset state for clean demo (optional in production) print("\nResetting Cognee state...") await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) print("✓ Reset complete") - + # Sample content: AI/ML educational material documents = [ """ @@ -87,16 +88,16 @@ async def setup_knowledge_base(): recognition, object detection, and image segmentation tasks. """, ] - + print(f"\nAdding {len(documents)} documents to knowledge base...") await cognee.add(documents, dataset_name="ai_ml_fundamentals") print("✓ Documents added") - + # Build knowledge graph print("\nBuilding knowledge graph (cognify)...") await cognee.cognify() print("✓ Knowledge graph built") - + print("\n" + "=" * 80) @@ -104,26 +105,27 @@ async def setup_knowledge_base(): # STEP 2: Simulate User Searches with Interaction Tracking # ============================================================================ + async def simulate_user_searches(queries: List[str]): """ Simulate users searching the knowledge base. 
- + The key parameter is save_interaction=True, which creates: - CogneeUserInteraction nodes (one per search) - used_graph_element_to_answer edges (connecting queries to relevant nodes) - + Args: queries: List of search queries to simulate - + Returns: Number of successful searches """ print("=" * 80) print("STEP 2: Simulating user searches with interaction tracking") print("=" * 80) - + successful_searches = 0 - + for i, query in enumerate(queries, 1): print(f"\nSearch {i}/{len(queries)}: '{query}'") try: @@ -131,20 +133,20 @@ async def simulate_user_searches(queries: List[str]): query_type=SearchType.GRAPH_COMPLETION, query_text=query, save_interaction=True, # ← THIS IS CRITICAL! - top_k=5 + top_k=5, ) successful_searches += 1 - + # Show snippet of results result_preview = str(results)[:100] if results else "No results" print(f" ✓ Completed ({result_preview}...)") - + except Exception as e: print(f" ✗ Failed: {e}") - + print(f"\n✓ Completed {successful_searches}/{len(queries)} searches") print("=" * 80) - + return successful_searches @@ -152,71 +154,80 @@ async def simulate_user_searches(queries: List[str]): # STEP 3: Extract and Apply Usage Frequencies # ============================================================================ + async def extract_and_apply_frequencies( - time_window_days: int = 7, - min_threshold: int = 1 + time_window_days: int = 7, min_threshold: int = 1 ) -> Dict[str, Any]: """ Extract usage frequencies from interactions and apply them to the graph. - + This function: 1. Retrieves the graph with interaction data 2. Counts how often each node was accessed 3. Writes frequency_weight property back to nodes - + Args: time_window_days: Only count interactions from last N days min_threshold: Minimum accesses to track (filter out rarely used nodes) - + Returns: Dictionary with statistics about the frequency update """ print("=" * 80) print("STEP 3: Extracting and applying usage frequencies") print("=" * 80) - + # Get graph adapter graph_engine = await get_graph_engine() - + # Retrieve graph with interactions print("\nRetrieving graph from database...") graph = CogneeGraph() await graph.project_graph_from_db( adapter=graph_engine, node_properties_to_project=[ - "type", "node_type", "timestamp", "created_at", - "text", "name", "query_text", "frequency_weight" + "type", + "node_type", + "timestamp", + "created_at", + "text", + "name", + "query_text", + "frequency_weight", ], edge_properties_to_project=["relationship_type", "timestamp"], directed=True, ) - + print(f"✓ Retrieved: {len(graph.nodes)} nodes, {len(graph.edges)} edges") - + # Count interaction nodes interaction_nodes = [ - n for n in graph.nodes.values() - if n.attributes.get('type') == 'CogneeUserInteraction' or - n.attributes.get('node_type') == 'CogneeUserInteraction' + n + for n in graph.nodes.values() + if n.attributes.get("type") == "CogneeUserInteraction" + or n.attributes.get("node_type") == "CogneeUserInteraction" ] print(f"✓ Found {len(interaction_nodes)} interaction nodes") - + # Run frequency extraction and update print(f"\nExtracting frequencies (time window: {time_window_days} days)...") stats = await run_usage_frequency_update( graph_adapter=graph_engine, subgraphs=[graph], time_window=timedelta(days=time_window_days), - min_interaction_threshold=min_threshold + min_interaction_threshold=min_threshold, + ) + + print("\n✓ Frequency extraction complete!") + print( + f" - Interactions processed: {stats['interactions_in_window']}/{stats['total_interactions']}" ) - - print(f"\n✓ Frequency extraction 
complete!") - print(f" - Interactions processed: {stats['interactions_in_window']}/{stats['total_interactions']}") print(f" - Nodes weighted: {len(stats['node_frequencies'])}") print(f" - Element types tracked: {stats.get('element_type_frequencies', {})}") - + print("=" * 80) - + return stats @@ -224,33 +235,30 @@ async def extract_and_apply_frequencies( # STEP 4: Analyze and Display Results # ============================================================================ + async def analyze_results(stats: Dict[str, Any]): """ Analyze and display the frequency tracking results. - + Shows: - Top most frequently accessed nodes - Element type distribution - Verification that weights were written to database - + Args: stats: Statistics from frequency extraction """ print("=" * 80) print("STEP 4: Analyzing usage frequency results") print("=" * 80) - + # Display top nodes by frequency - if stats['node_frequencies']: + if stats["node_frequencies"]: print("\n📊 Top 10 Most Frequently Accessed Elements:") print("-" * 80) - - sorted_nodes = sorted( - stats['node_frequencies'].items(), - key=lambda x: x[1], - reverse=True - ) - + + sorted_nodes = sorted(stats["node_frequencies"].items(), key=lambda x: x[1], reverse=True) + # Get graph to display node details graph_engine = await get_graph_engine() graph = CogneeGraph() @@ -260,48 +268,48 @@ async def analyze_results(stats: Dict[str, Any]): edge_properties_to_project=[], directed=True, ) - + for i, (node_id, frequency) in enumerate(sorted_nodes[:10], 1): node = graph.get_node(node_id) if node: - node_type = node.attributes.get('type', 'Unknown') - text = node.attributes.get('text') or node.attributes.get('name') or '' + node_type = node.attributes.get("type", "Unknown") + text = node.attributes.get("text") or node.attributes.get("name") or "" text_preview = text[:60] + "..." if len(text) > 60 else text - + print(f"\n{i}. Frequency: {frequency} accesses") print(f" Type: {node_type}") print(f" Content: {text_preview}") else: print(f"\n{i}. 
Frequency: {frequency} accesses") print(f" Node ID: {node_id[:50]}...") - + # Display element type distribution - if stats.get('element_type_frequencies'): + if stats.get("element_type_frequencies"): print("\n\n📈 Element Type Distribution:") print("-" * 80) - type_dist = stats['element_type_frequencies'] + type_dist = stats["element_type_frequencies"] for elem_type, count in sorted(type_dist.items(), key=lambda x: x[1], reverse=True): print(f" {elem_type}: {count} accesses") - + # Verify weights in database (Neo4j only) print("\n\n🔍 Verifying weights in database...") print("-" * 80) - + graph_engine = await get_graph_engine() adapter_type = type(graph_engine).__name__ - - if adapter_type == 'Neo4jAdapter': + + if adapter_type == "Neo4jAdapter": try: result = await graph_engine.query(""" MATCH (n) WHERE n.frequency_weight IS NOT NULL RETURN count(n) as weighted_count """) - - count = result[0]['weighted_count'] if result else 0 + + count = result[0]["weighted_count"] if result else 0 if count > 0: print(f"✓ {count} nodes have frequency_weight in Neo4j database") - + # Show sample sample = await graph_engine.query(""" MATCH (n) @@ -310,7 +318,7 @@ async def analyze_results(stats: Dict[str, Any]): ORDER BY n.frequency_weight DESC LIMIT 3 """) - + print("\nSample weighted nodes:") for row in sample: print(f" - Weight: {row['weight']}, Type: {row['labels']}") @@ -320,7 +328,7 @@ async def analyze_results(stats: Dict[str, Any]): print(f"Could not verify in Neo4j: {e}") else: print(f"Database verification not implemented for {adapter_type}") - + print("\n" + "=" * 80) @@ -328,10 +336,11 @@ async def analyze_results(stats: Dict[str, Any]): # STEP 5: Demonstrate Usage in Retrieval # ============================================================================ + async def demonstrate_retrieval_usage(): """ Demonstrate how frequency weights can be used in retrieval. - + Note: This is a conceptual demonstration. To actually use frequency weights in ranking, you would need to modify the retrieval/completion strategies to incorporate the frequency_weight property. @@ -339,39 +348,39 @@ async def demonstrate_retrieval_usage(): print("=" * 80) print("STEP 5: How to use frequency weights in retrieval") print("=" * 80) - + print(""" Frequency weights can be used to improve search results: - + 1. RANKING BOOST: - Multiply relevance scores by frequency_weight - Prioritize frequently accessed nodes in results - + 2. COMPLETION STRATEGIES: - Adjust triplet importance based on usage - Filter out rarely accessed information - + 3. ANALYTICS: - Track trending topics over time - Understand user interests and behavior - Identify knowledge gaps (low-frequency nodes) - + 4. ADAPTIVE RETRIEVAL: - Personalize results based on team usage patterns - Surface popular answers faster - + Example Cypher query with frequency boost (Neo4j): - + MATCH (n) WHERE n.text CONTAINS $search_term RETURN n, n.frequency_weight as boost ORDER BY (n.relevance_score * COALESCE(n.frequency_weight, 1)) DESC LIMIT 10 - + To integrate this into Cognee, you would modify the completion strategy to include frequency_weight in the scoring function. """) - + print("=" * 80) @@ -379,6 +388,7 @@ async def demonstrate_retrieval_usage(): # MAIN: Run Complete Example # ============================================================================ + async def main(): """ Run the complete end-to-end usage frequency tracking example. 
@@ -390,25 +400,25 @@ async def main(): print("║" + " " * 78 + "║") print("╚" + "=" * 78 + "╝") print("\n") - + # Configuration check print("Configuration:") print(f" Graph Provider: {os.getenv('GRAPH_DATABASE_PROVIDER')}") print(f" Graph Handler: {os.getenv('GRAPH_DATASET_HANDLER')}") print(f" LLM Provider: {os.getenv('LLM_PROVIDER')}") - + # Verify LLM key is set - if not os.getenv('LLM_API_KEY') or os.getenv('LLM_API_KEY') == 'sk-your-key-here': + if not os.getenv("LLM_API_KEY") or os.getenv("LLM_API_KEY") == "sk-your-key-here": print("\n⚠ WARNING: LLM_API_KEY not set in .env file") print(" Set your API key to run searches") return - + print("\n") - + try: # Step 1: Setup await setup_knowledge_base() - + # Step 2: Simulate searches # Note: Repeat queries increase frequency for those topics queries = [ @@ -422,25 +432,22 @@ async def main(): "What is reinforcement learning?", "Tell me more about neural networks", # Third repeat ] - + successful_searches = await simulate_user_searches(queries) - + if successful_searches == 0: print("⚠ No searches completed - cannot demonstrate frequency tracking") return - + # Step 3: Extract frequencies - stats = await extract_and_apply_frequencies( - time_window_days=7, - min_threshold=1 - ) - + stats = await extract_and_apply_frequencies(time_window_days=7, min_threshold=1) + # Step 4: Analyze results await analyze_results(stats) - + # Step 5: Show usage examples await demonstrate_retrieval_usage() - + # Summary print("\n") print("╔" + "=" * 78 + "╗") @@ -449,26 +456,27 @@ async def main(): print("║" + " " * 78 + "║") print("╚" + "=" * 78 + "╝") print("\n") - + print("Summary:") - print(f" ✓ Documents added: 4") + print(" ✓ Documents added: 4") print(f" ✓ Searches performed: {successful_searches}") print(f" ✓ Interactions tracked: {stats['interactions_in_window']}") print(f" ✓ Nodes weighted: {len(stats['node_frequencies'])}") - + print("\nNext steps:") print(" 1. Open Neo4j Browser (http://localhost:7474) to explore the graph") print(" 2. Modify retrieval strategies to use frequency_weight") print(" 3. Build analytics dashboards using element_type_frequencies") print(" 4. Run periodic frequency updates to track trends over time") - + print("\n") - + except Exception as e: print(f"\n✗ Example failed: {e}") import traceback + traceback.print_exc() if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main())
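
The millisecond-vs-ISO timestamp handling inside extract_usage_frequency is spread across several hunks above and partly elided by the diff context. The condensed sketch below restates the idea only, under stated assumptions: the tests in this patch write epoch-milliseconds (int(timestamp() * 1000)), other writers may store ISO-8601 strings or datetime objects, and the helper name plus the 1e12 "looks like milliseconds" cutoff are illustrative choices for this sketch, not the patch's exact logic.

from datetime import datetime

def parse_interaction_timestamp(value):
    """Normalize a CogneeUserInteraction timestamp to a naive datetime (sketch only)."""
    if isinstance(value, datetime):
        return value if value.tzinfo is None else value.replace(tzinfo=None)
    text = str(value)
    if text.isdigit():
        ts = int(text)
        # Epoch values above ~1e12 cannot plausibly be seconds (that would be
        # tens of thousands of years in the future), so treat them as milliseconds.
        return datetime.fromtimestamp(ts / 1000 if ts > 1_000_000_000_000 else ts)
    parsed = datetime.fromisoformat(text)
    # Strip tzinfo so comparisons against the naive cutoff_time work, as the task does.
    return parsed if parsed.tzinfo is None else parsed.replace(tzinfo=None)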
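
get_most_frequent_elements is left in the patch as a stub with pseudocode and a warning that it needs an adapter-specific implementation. A minimal Neo4j-flavoured sketch follows, reusing the async query(query, params=...) call pattern that add_frequency_weights already uses for the Neo4jAdapter; the function name, the exact Cypher, and the label post-filter are assumptions, not part of the patch.

async def get_most_frequent_elements_neo4j(graph_adapter, top_n=10, element_type=None):
    # Nodes only carry frequency_weight after add_frequency_weights has run,
    # so filtering on that property is enough to find tracked elements.
    cypher = """
        MATCH (n)
        WHERE n.frequency_weight IS NOT NULL
        RETURN n.id AS id, labels(n) AS labels, n.frequency_weight AS frequency_weight
        ORDER BY n.frequency_weight DESC
        LIMIT $top_n
    """
    rows = await graph_adapter.query(cypher, params={"top_n": top_n})
    if element_type:
        # Post-filter on labels; the dict-style row shape mirrors what the
        # example script already reads back from graph_engine.query().
        rows = [row for row in rows if element_type in (row.get("labels") or [])]
    return rows[:top_n]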
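
demonstrate_retrieval_usage describes the ranking-boost idea (multiply an existing relevance score by frequency_weight) only conceptually. A small hypothetical sketch of that re-ranking step follows; the helper name and the (attributes, relevance) input shape are assumptions rather than an existing Cognee API, and any real integration would live inside the completion/retrieval strategies themselves.

def rerank_by_frequency(scored_nodes, default_weight=1.0):
    # scored_nodes: iterable of (attributes_dict, relevance_score) pairs from an
    # existing retriever; nodes never touched by the frequency task keep a
    # neutral weight of 1.0 so their relative ordering is unchanged.
    def boosted_score(item):
        attributes, relevance = item
        weight = attributes.get("frequency_weight") or default_weight
        return relevance * weight

    return sorted(scored_nodes, key=boosted_score, reverse=True)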