chore: removed inconsistency in node properties btw task, e2e example and test codes
This commit is contained in:
parent
931c5f3096
commit
e0c7e68dd6
3 changed files with 1141 additions and 105 deletions
|
|
@ -1,8 +1,12 @@
|
||||||
# cognee/tasks/memify/extract_usage_frequency.py
|
from typing import List, Dict, Any, Optional
|
||||||
from typing import List, Dict, Any
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||||
from cognee.modules.pipelines.tasks.task import Task
|
from cognee.modules.pipelines.tasks.task import Task
|
||||||
|
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||||
|
|
||||||
|
logger = get_logger("extract_usage_frequency")
|
||||||
|
|
||||||
|
|
||||||
async def extract_usage_frequency(
|
async def extract_usage_frequency(
|
||||||
subgraphs: List[CogneeGraph],
|
subgraphs: List[CogneeGraph],
|
||||||
|
|
@ -10,35 +14,93 @@ async def extract_usage_frequency(
|
||||||
min_interaction_threshold: int = 1
|
min_interaction_threshold: int = 1
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Extract usage frequency from CogneeUserInteraction nodes
|
Extract usage frequency from CogneeUserInteraction nodes.
|
||||||
|
|
||||||
:param subgraphs: List of graph subgraphs
|
When save_interaction=True in cognee.search(), the system creates:
|
||||||
:param time_window: Time window to consider for interactions
|
- CogneeUserInteraction nodes (representing the query/answer interaction)
|
||||||
:param min_interaction_threshold: Minimum interactions to track
|
- used_graph_element_to_answer edges (connecting interactions to graph elements used)
|
||||||
:return: Dictionary of usage frequencies
|
|
||||||
|
This function tallies how often each graph element is referenced via these edges,
|
||||||
|
enabling frequency-based ranking in downstream retrievers.
|
||||||
|
|
||||||
|
:param subgraphs: List of CogneeGraph instances containing interaction data
|
||||||
|
:param time_window: Time window to consider for interactions (default: 7 days)
|
||||||
|
:param min_interaction_threshold: Minimum interactions to track (default: 1)
|
||||||
|
:return: Dictionary containing node frequencies, edge frequencies, and metadata
|
||||||
"""
|
"""
|
||||||
current_time = datetime.now()
|
current_time = datetime.now()
|
||||||
|
cutoff_time = current_time - time_window
|
||||||
|
|
||||||
|
# Track frequencies for graph elements (nodes and edges)
|
||||||
node_frequencies = {}
|
node_frequencies = {}
|
||||||
edge_frequencies = {}
|
edge_frequencies = {}
|
||||||
|
relationship_type_frequencies = {}
|
||||||
|
|
||||||
|
# Track interaction metadata
|
||||||
|
interaction_count = 0
|
||||||
|
interactions_in_window = 0
|
||||||
|
|
||||||
|
logger.info(f"Extracting usage frequencies from {len(subgraphs)} subgraphs")
|
||||||
|
logger.info(f"Time window: {time_window}, Cutoff: {cutoff_time.isoformat()}")
|
||||||
|
|
||||||
for subgraph in subgraphs:
|
for subgraph in subgraphs:
|
||||||
# Filter CogneeUserInteraction nodes within time window
|
# Find all CogneeUserInteraction nodes
|
||||||
user_interactions = [
|
interaction_nodes = {}
|
||||||
interaction for interaction in subgraph.nodes
|
for node_id, node in subgraph.nodes.items():
|
||||||
if (interaction.get('type') == 'CogneeUserInteraction' and
|
node_type = node.attributes.get('type') or node.attributes.get('node_type')
|
||||||
current_time - datetime.fromisoformat(interaction.get('timestamp', current_time.isoformat())) <= time_window)
|
|
||||||
]
|
if node_type == 'CogneeUserInteraction':
|
||||||
|
# Parse and validate timestamp
|
||||||
|
timestamp_str = node.attributes.get('timestamp') or node.attributes.get('created_at')
|
||||||
|
if timestamp_str:
|
||||||
|
try:
|
||||||
|
interaction_time = datetime.fromisoformat(timestamp_str)
|
||||||
|
interaction_nodes[node_id] = {
|
||||||
|
'node': node,
|
||||||
|
'timestamp': interaction_time,
|
||||||
|
'in_window': interaction_time >= cutoff_time
|
||||||
|
}
|
||||||
|
interaction_count += 1
|
||||||
|
if interaction_time >= cutoff_time:
|
||||||
|
interactions_in_window += 1
|
||||||
|
except (ValueError, TypeError) as e:
|
||||||
|
logger.warning(f"Failed to parse timestamp for interaction node {node_id}: {e}")
|
||||||
|
|
||||||
# Count node and edge frequencies
|
# Process edges to find graph elements used in interactions
|
||||||
for interaction in user_interactions:
|
for edge in subgraph.edges:
|
||||||
target_node_id = interaction.get('target_node_id')
|
relationship_type = edge.attributes.get('relationship_type')
|
||||||
edge_type = interaction.get('edge_type')
|
|
||||||
|
|
||||||
if target_node_id:
|
# Look for 'used_graph_element_to_answer' edges
|
||||||
node_frequencies[target_node_id] = node_frequencies.get(target_node_id, 0) + 1
|
if relationship_type == 'used_graph_element_to_answer':
|
||||||
|
# node1 should be the CogneeUserInteraction, node2 is the graph element
|
||||||
|
source_id = str(edge.node1.id)
|
||||||
|
target_id = str(edge.node2.id)
|
||||||
|
|
||||||
|
# Check if source is an interaction node in our time window
|
||||||
|
if source_id in interaction_nodes:
|
||||||
|
interaction_data = interaction_nodes[source_id]
|
||||||
|
|
||||||
|
if interaction_data['in_window']:
|
||||||
|
# Count the graph element (target node) being used
|
||||||
|
node_frequencies[target_id] = node_frequencies.get(target_id, 0) + 1
|
||||||
|
|
||||||
|
# Also track what type of element it is for analytics
|
||||||
|
target_node = subgraph.get_node(target_id)
|
||||||
|
if target_node:
|
||||||
|
element_type = target_node.attributes.get('type') or target_node.attributes.get('node_type')
|
||||||
|
if element_type:
|
||||||
|
relationship_type_frequencies[element_type] = relationship_type_frequencies.get(element_type, 0) + 1
|
||||||
|
|
||||||
if edge_type:
|
# Also track general edge usage patterns
|
||||||
edge_frequencies[edge_type] = edge_frequencies.get(edge_type, 0) + 1
|
elif relationship_type and relationship_type != 'used_graph_element_to_answer':
|
||||||
|
# Check if either endpoint is referenced in a recent interaction
|
||||||
|
source_id = str(edge.node1.id)
|
||||||
|
target_id = str(edge.node2.id)
|
||||||
|
|
||||||
|
# If this edge connects to any frequently accessed nodes, track the edge type
|
||||||
|
if source_id in node_frequencies or target_id in node_frequencies:
|
||||||
|
edge_key = f"{relationship_type}:{source_id}:{target_id}"
|
||||||
|
edge_frequencies[edge_key] = edge_frequencies.get(edge_key, 0) + 1
|
||||||
|
|
||||||
# Filter frequencies above threshold
|
# Filter frequencies above threshold
|
||||||
filtered_node_frequencies = {
|
filtered_node_frequencies = {
|
||||||
|
|
@ -47,55 +109,292 @@ async def extract_usage_frequency(
|
||||||
}
|
}
|
||||||
|
|
||||||
filtered_edge_frequencies = {
|
filtered_edge_frequencies = {
|
||||||
edge_type: freq for edge_type, freq in edge_frequencies.items()
|
edge_key: freq for edge_key, freq in edge_frequencies.items()
|
||||||
if freq >= min_interaction_threshold
|
if freq >= min_interaction_threshold
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Processed {interactions_in_window}/{interaction_count} interactions in time window"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Found {len(filtered_node_frequencies)} nodes and {len(filtered_edge_frequencies)} edges "
|
||||||
|
f"above threshold (min: {min_interaction_threshold})"
|
||||||
|
)
|
||||||
|
logger.info(f"Element type distribution: {relationship_type_frequencies}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'node_frequencies': filtered_node_frequencies,
|
'node_frequencies': filtered_node_frequencies,
|
||||||
'edge_frequencies': filtered_edge_frequencies,
|
'edge_frequencies': filtered_edge_frequencies,
|
||||||
'last_processed_timestamp': current_time.isoformat()
|
'element_type_frequencies': relationship_type_frequencies,
|
||||||
|
'total_interactions': interaction_count,
|
||||||
|
'interactions_in_window': interactions_in_window,
|
||||||
|
'time_window_days': time_window.days,
|
||||||
|
'last_processed_timestamp': current_time.isoformat(),
|
||||||
|
'cutoff_timestamp': cutoff_time.isoformat()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def add_frequency_weights(
|
async def add_frequency_weights(
|
||||||
graph_adapter,
|
graph_adapter: GraphDBInterface,
|
||||||
usage_frequencies: Dict[str, Any]
|
usage_frequencies: Dict[str, Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Add frequency weights to graph nodes and edges
|
Add frequency weights to graph nodes and edges using the graph adapter.
|
||||||
|
|
||||||
:param graph_adapter: Graph database adapter
|
Uses the "get → tweak dict → update" contract consistent with graph adapters.
|
||||||
:param usage_frequencies: Calculated usage frequencies
|
Writes frequency_weight properties back to the graph for use in:
|
||||||
|
- Ranking frequently referenced entities higher during retrieval
|
||||||
|
- Adjusting scoring for completion strategies
|
||||||
|
- Exposing usage metrics in dashboards or audits
|
||||||
|
|
||||||
|
:param graph_adapter: Graph database adapter interface
|
||||||
|
:param usage_frequencies: Calculated usage frequencies from extract_usage_frequency
|
||||||
"""
|
"""
|
||||||
# Update node frequencies
|
node_frequencies = usage_frequencies.get('node_frequencies', {})
|
||||||
for node_id, frequency in usage_frequencies['node_frequencies'].items():
|
edge_frequencies = usage_frequencies.get('edge_frequencies', {})
|
||||||
|
|
||||||
|
logger.info(f"Adding frequency weights to {len(node_frequencies)} nodes")
|
||||||
|
|
||||||
|
# Update node frequencies using get → tweak → update pattern
|
||||||
|
nodes_updated = 0
|
||||||
|
nodes_failed = 0
|
||||||
|
|
||||||
|
for node_id, frequency in node_frequencies.items():
|
||||||
try:
|
try:
|
||||||
node = graph_adapter.get_node(node_id)
|
# Get current node data
|
||||||
if node:
|
node_data = await graph_adapter.get_node_by_id(node_id)
|
||||||
node_properties = node.get_properties() or {}
|
|
||||||
node_properties['frequency_weight'] = frequency
|
if node_data:
|
||||||
graph_adapter.update_node(node_id, node_properties)
|
# Tweak the properties dict - add frequency_weight
|
||||||
|
if isinstance(node_data, dict):
|
||||||
|
properties = node_data.get('properties', {})
|
||||||
|
else:
|
||||||
|
# Handle case where node_data might be a node object
|
||||||
|
properties = getattr(node_data, 'properties', {}) or {}
|
||||||
|
|
||||||
|
# Update with frequency weight
|
||||||
|
properties['frequency_weight'] = frequency
|
||||||
|
|
||||||
|
# Also store when this was last updated
|
||||||
|
properties['frequency_updated_at'] = usage_frequencies.get('last_processed_timestamp')
|
||||||
|
|
||||||
|
# Write back via adapter
|
||||||
|
await graph_adapter.update_node_properties(node_id, properties)
|
||||||
|
nodes_updated += 1
|
||||||
|
else:
|
||||||
|
logger.warning(f"Node {node_id} not found in graph")
|
||||||
|
nodes_failed += 1
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error updating node {node_id}: {e}")
|
logger.error(f"Error updating node {node_id}: {e}")
|
||||||
|
nodes_failed += 1
|
||||||
|
|
||||||
# Note: Edge frequency update might require backend-specific implementation
|
logger.info(
|
||||||
print("Edge frequency update might need backend-specific handling")
|
f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update edge frequencies
|
||||||
|
# Note: Edge property updates are backend-specific
|
||||||
|
if edge_frequencies:
|
||||||
|
logger.info(f"Processing {len(edge_frequencies)} edge frequency entries")
|
||||||
|
|
||||||
|
edges_updated = 0
|
||||||
|
edges_failed = 0
|
||||||
|
|
||||||
|
for edge_key, frequency in edge_frequencies.items():
|
||||||
|
try:
|
||||||
|
# Parse edge key: "relationship_type:source_id:target_id"
|
||||||
|
parts = edge_key.split(':', 2)
|
||||||
|
if len(parts) == 3:
|
||||||
|
relationship_type, source_id, target_id = parts
|
||||||
|
|
||||||
|
# Try to update edge if adapter supports it
|
||||||
|
if hasattr(graph_adapter, 'update_edge_properties'):
|
||||||
|
edge_properties = {
|
||||||
|
'frequency_weight': frequency,
|
||||||
|
'frequency_updated_at': usage_frequencies.get('last_processed_timestamp')
|
||||||
|
}
|
||||||
|
|
||||||
|
await graph_adapter.update_edge_properties(
|
||||||
|
source_id,
|
||||||
|
target_id,
|
||||||
|
relationship_type,
|
||||||
|
edge_properties
|
||||||
|
)
|
||||||
|
edges_updated += 1
|
||||||
|
else:
|
||||||
|
# Fallback: store in metadata or log
|
||||||
|
logger.debug(
|
||||||
|
f"Adapter doesn't support update_edge_properties for "
|
||||||
|
f"{relationship_type} ({source_id} -> {target_id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating edge {edge_key}: {e}")
|
||||||
|
edges_failed += 1
|
||||||
|
|
||||||
|
if edges_updated > 0:
|
||||||
|
logger.info(f"Edge update complete: {edges_updated} succeeded, {edges_failed} failed")
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Edge frequency updates skipped (adapter may not support edge property updates)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store aggregate statistics as metadata if supported
|
||||||
|
if hasattr(graph_adapter, 'set_metadata'):
|
||||||
|
try:
|
||||||
|
metadata = {
|
||||||
|
'element_type_frequencies': usage_frequencies.get('element_type_frequencies', {}),
|
||||||
|
'total_interactions': usage_frequencies.get('total_interactions', 0),
|
||||||
|
'interactions_in_window': usage_frequencies.get('interactions_in_window', 0),
|
||||||
|
'last_frequency_update': usage_frequencies.get('last_processed_timestamp')
|
||||||
|
}
|
||||||
|
await graph_adapter.set_metadata('usage_frequency_stats', metadata)
|
||||||
|
logger.info("Stored usage frequency statistics as metadata")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not store usage statistics as metadata: {e}")
|
||||||
|
|
||||||
def usage_frequency_pipeline_entry(graph_adapter):
|
|
||||||
|
async def create_usage_frequency_pipeline(
|
||||||
|
graph_adapter: GraphDBInterface,
|
||||||
|
time_window: timedelta = timedelta(days=7),
|
||||||
|
min_interaction_threshold: int = 1,
|
||||||
|
batch_size: int = 100
|
||||||
|
) -> tuple:
|
||||||
"""
|
"""
|
||||||
Memify pipeline entry for usage frequency tracking
|
Create memify pipeline entry for usage frequency tracking.
|
||||||
|
|
||||||
|
This follows the same pattern as feedback enrichment flows, allowing
|
||||||
|
the frequency update to run end-to-end in a custom memify pipeline.
|
||||||
|
|
||||||
|
Use case example:
|
||||||
|
extraction_tasks, enrichment_tasks = await create_usage_frequency_pipeline(
|
||||||
|
graph_adapter=my_adapter,
|
||||||
|
time_window=timedelta(days=30),
|
||||||
|
min_interaction_threshold=2
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run in memify pipeline
|
||||||
|
pipeline = Pipeline(extraction_tasks + enrichment_tasks)
|
||||||
|
results = await pipeline.run()
|
||||||
|
|
||||||
:param graph_adapter: Graph database adapter
|
:param graph_adapter: Graph database adapter
|
||||||
:return: Usage frequency results
|
:param time_window: Time window for counting interactions (default: 7 days)
|
||||||
|
:param min_interaction_threshold: Minimum interactions to track (default: 1)
|
||||||
|
:param batch_size: Batch size for processing (default: 100)
|
||||||
|
:return: Tuple of (extraction_tasks, enrichment_tasks)
|
||||||
"""
|
"""
|
||||||
|
logger.info("Creating usage frequency pipeline")
|
||||||
|
logger.info(f"Config: time_window={time_window}, threshold={min_interaction_threshold}")
|
||||||
|
|
||||||
extraction_tasks = [
|
extraction_tasks = [
|
||||||
Task(extract_usage_frequency,
|
Task(
|
||||||
time_window=timedelta(days=7),
|
extract_usage_frequency,
|
||||||
min_interaction_threshold=1)
|
time_window=time_window,
|
||||||
|
min_interaction_threshold=min_interaction_threshold
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
enrichment_tasks = [
|
enrichment_tasks = [
|
||||||
Task(add_frequency_weights, task_config={"batch_size": 1})
|
Task(
|
||||||
|
add_frequency_weights,
|
||||||
|
graph_adapter=graph_adapter,
|
||||||
|
task_config={"batch_size": batch_size}
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
return extraction_tasks, enrichment_tasks
|
return extraction_tasks, enrichment_tasks
|
||||||
|
|
||||||
|
|
||||||
|
async def run_usage_frequency_update(
|
||||||
|
graph_adapter: GraphDBInterface,
|
||||||
|
subgraphs: List[CogneeGraph],
|
||||||
|
time_window: timedelta = timedelta(days=7),
|
||||||
|
min_interaction_threshold: int = 1
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Convenience function to run the complete usage frequency update pipeline.
|
||||||
|
|
||||||
|
This is the main entry point for updating frequency weights on graph elements
|
||||||
|
based on CogneeUserInteraction data from cognee.search(save_interaction=True).
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
# After running searches with save_interaction=True
|
||||||
|
from cognee.tasks.memify.extract_usage_frequency import run_usage_frequency_update
|
||||||
|
|
||||||
|
# Get the graph with interactions
|
||||||
|
graph = await get_cognee_graph_with_interactions()
|
||||||
|
|
||||||
|
# Update frequency weights
|
||||||
|
stats = await run_usage_frequency_update(
|
||||||
|
graph_adapter=graph_adapter,
|
||||||
|
subgraphs=[graph],
|
||||||
|
time_window=timedelta(days=30), # Last 30 days
|
||||||
|
min_interaction_threshold=2 # At least 2 uses
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Updated {len(stats['node_frequencies'])} nodes")
|
||||||
|
|
||||||
|
:param graph_adapter: Graph database adapter
|
||||||
|
:param subgraphs: List of CogneeGraph instances with interaction data
|
||||||
|
:param time_window: Time window for counting interactions
|
||||||
|
:param min_interaction_threshold: Minimum interactions to track
|
||||||
|
:return: Usage frequency statistics
|
||||||
|
"""
|
||||||
|
logger.info("Starting usage frequency update")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Extract frequencies from interaction data
|
||||||
|
usage_frequencies = await extract_usage_frequency(
|
||||||
|
subgraphs=subgraphs,
|
||||||
|
time_window=time_window,
|
||||||
|
min_interaction_threshold=min_interaction_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add frequency weights back to the graph
|
||||||
|
await add_frequency_weights(
|
||||||
|
graph_adapter=graph_adapter,
|
||||||
|
usage_frequencies=usage_frequencies
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Usage frequency update completed successfully")
|
||||||
|
logger.info(
|
||||||
|
f"Summary: {usage_frequencies['interactions_in_window']} interactions processed, "
|
||||||
|
f"{len(usage_frequencies['node_frequencies'])} nodes weighted"
|
||||||
|
)
|
||||||
|
|
||||||
|
return usage_frequencies
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during usage frequency update: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def get_most_frequent_elements(
|
||||||
|
graph_adapter: GraphDBInterface,
|
||||||
|
top_n: int = 10,
|
||||||
|
element_type: Optional[str] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Retrieve the most frequently accessed graph elements.
|
||||||
|
|
||||||
|
Useful for analytics dashboards and understanding user behavior.
|
||||||
|
|
||||||
|
:param graph_adapter: Graph database adapter
|
||||||
|
:param top_n: Number of top elements to return
|
||||||
|
:param element_type: Optional filter by element type
|
||||||
|
:return: List of elements with their frequency weights
|
||||||
|
"""
|
||||||
|
logger.info(f"Retrieving top {top_n} most frequent elements")
|
||||||
|
|
||||||
|
# This would need to be implemented based on the specific graph adapter's query capabilities
|
||||||
|
# Pseudocode:
|
||||||
|
# results = await graph_adapter.query_nodes_by_property(
|
||||||
|
# property_name='frequency_weight',
|
||||||
|
# order_by='DESC',
|
||||||
|
# limit=top_n,
|
||||||
|
# filters={'type': element_type} if element_type else None
|
||||||
|
# )
|
||||||
|
|
||||||
|
logger.warning("get_most_frequent_elements needs adapter-specific implementation")
|
||||||
|
return []
|
||||||
|
|
@ -1,42 +1,503 @@
|
||||||
# cognee/tests/test_usage_frequency.py
|
# cognee/tests/test_usage_frequency.py
|
||||||
|
"""
|
||||||
|
Test suite for usage frequency tracking functionality.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- Frequency extraction from CogneeUserInteraction nodes
|
||||||
|
- Time window filtering
|
||||||
|
- Frequency weight application to graph
|
||||||
|
- Edge cases and error handling
|
||||||
|
"""
|
||||||
import pytest
|
import pytest
|
||||||
import asyncio
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from cognee.tasks.memify.extract_usage_frequency import extract_usage_frequency, add_frequency_weights
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from cognee.tasks.memify.extract_usage_frequency import (
|
||||||
|
extract_usage_frequency,
|
||||||
|
add_frequency_weights,
|
||||||
|
create_usage_frequency_pipeline,
|
||||||
|
run_usage_frequency_update,
|
||||||
|
)
|
||||||
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||||
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_node(node_id: str, attributes: Dict[str, Any]) -> Node:
|
||||||
|
"""Helper to create mock Node objects."""
|
||||||
|
node = Node(node_id, attributes)
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_edge(node1: Node, node2: Node, relationship_type: str, attributes: Dict[str, Any] = None) -> Edge:
|
||||||
|
"""Helper to create mock Edge objects."""
|
||||||
|
edge_attrs = attributes or {}
|
||||||
|
edge_attrs['relationship_type'] = relationship_type
|
||||||
|
edge = Edge(node1, node2, attributes=edge_attrs, directed=True)
|
||||||
|
return edge
|
||||||
|
|
||||||
|
|
||||||
|
def create_interaction_graph(
|
||||||
|
interaction_count: int = 3,
|
||||||
|
target_nodes: list = None,
|
||||||
|
time_offset_hours: int = 0
|
||||||
|
) -> CogneeGraph:
|
||||||
|
"""
|
||||||
|
Create a mock CogneeGraph with interaction nodes.
|
||||||
|
|
||||||
|
:param interaction_count: Number of interactions to create
|
||||||
|
:param target_nodes: List of target node IDs to reference
|
||||||
|
:param time_offset_hours: Hours to offset timestamp (negative = past)
|
||||||
|
:return: CogneeGraph with mocked interaction data
|
||||||
|
"""
|
||||||
|
graph = CogneeGraph(directed=True)
|
||||||
|
|
||||||
|
if target_nodes is None:
|
||||||
|
target_nodes = ['node1', 'node2', 'node3']
|
||||||
|
|
||||||
|
# Create some target graph element nodes
|
||||||
|
element_nodes = {}
|
||||||
|
for i, node_id in enumerate(target_nodes):
|
||||||
|
element_node = create_mock_node(
|
||||||
|
node_id,
|
||||||
|
{
|
||||||
|
'type': 'DocumentChunk',
|
||||||
|
'text': f'This is content for {node_id}',
|
||||||
|
'name': f'Element {i+1}'
|
||||||
|
}
|
||||||
|
)
|
||||||
|
graph.add_node(element_node)
|
||||||
|
element_nodes[node_id] = element_node
|
||||||
|
|
||||||
|
# Create interaction nodes and edges
|
||||||
|
timestamp = datetime.now() + timedelta(hours=time_offset_hours)
|
||||||
|
|
||||||
|
for i in range(interaction_count):
|
||||||
|
# Create interaction node
|
||||||
|
interaction_id = f'interaction_{i}'
|
||||||
|
target_id = target_nodes[i % len(target_nodes)]
|
||||||
|
|
||||||
|
interaction_node = create_mock_node(
|
||||||
|
interaction_id,
|
||||||
|
{
|
||||||
|
'type': 'CogneeUserInteraction',
|
||||||
|
'timestamp': timestamp.isoformat(),
|
||||||
|
'query_text': f'Sample query {i}',
|
||||||
|
'target_node_id': target_id # Also store in attributes for completeness
|
||||||
|
}
|
||||||
|
)
|
||||||
|
graph.add_node(interaction_node)
|
||||||
|
|
||||||
|
# Create edge from interaction to target element
|
||||||
|
target_element = element_nodes[target_id]
|
||||||
|
edge = create_mock_edge(
|
||||||
|
interaction_node,
|
||||||
|
target_element,
|
||||||
|
'used_graph_element_to_answer',
|
||||||
|
{'timestamp': timestamp.isoformat()}
|
||||||
|
)
|
||||||
|
graph.add_edge(edge)
|
||||||
|
|
||||||
|
return graph
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_extract_usage_frequency():
|
async def test_extract_usage_frequency_basic():
|
||||||
# Mock CogneeGraph with user interactions
|
"""Test basic frequency extraction with simple interaction data."""
|
||||||
mock_subgraphs = [{
|
# Create mock graph with 3 interactions
|
||||||
'nodes': [
|
# node1 referenced twice, node2 referenced once
|
||||||
{
|
mock_graph = create_interaction_graph(
|
||||||
'type': 'CogneeUserInteraction',
|
interaction_count=3,
|
||||||
'target_node_id': 'node1',
|
target_nodes=['node1', 'node1', 'node2']
|
||||||
'edge_type': 'viewed',
|
)
|
||||||
'timestamp': datetime.now().isoformat()
|
|
||||||
},
|
# Extract frequencies
|
||||||
{
|
|
||||||
'type': 'CogneeUserInteraction',
|
|
||||||
'target_node_id': 'node1',
|
|
||||||
'edge_type': 'viewed',
|
|
||||||
'timestamp': datetime.now().isoformat()
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'type': 'CogneeUserInteraction',
|
|
||||||
'target_node_id': 'node2',
|
|
||||||
'edge_type': 'referenced',
|
|
||||||
'timestamp': datetime.now().isoformat()
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}]
|
|
||||||
|
|
||||||
# Test frequency extraction
|
|
||||||
result = await extract_usage_frequency(
|
result = await extract_usage_frequency(
|
||||||
mock_subgraphs,
|
subgraphs=[mock_graph],
|
||||||
time_window=timedelta(days=1),
|
time_window=timedelta(days=1),
|
||||||
min_interaction_threshold=1
|
min_interaction_threshold=1
|
||||||
)
|
)
|
||||||
|
|
||||||
assert 'node1' in result['node_frequencies']
|
# Assertions
|
||||||
|
assert 'node_frequencies' in result
|
||||||
|
assert 'edge_frequencies' in result
|
||||||
assert result['node_frequencies']['node1'] == 2
|
assert result['node_frequencies']['node1'] == 2
|
||||||
assert result['edge_frequencies']['viewed'] == 2
|
assert result['node_frequencies']['node2'] == 1
|
||||||
|
assert result['total_interactions'] == 3
|
||||||
|
assert result['interactions_in_window'] == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_usage_frequency_time_window():
|
||||||
|
"""Test that time window filtering works correctly."""
|
||||||
|
# Create two graphs: one recent, one old
|
||||||
|
recent_graph = create_interaction_graph(
|
||||||
|
interaction_count=2,
|
||||||
|
target_nodes=['node1', 'node2'],
|
||||||
|
time_offset_hours=-1 # 1 hour ago
|
||||||
|
)
|
||||||
|
|
||||||
|
old_graph = create_interaction_graph(
|
||||||
|
interaction_count=2,
|
||||||
|
target_nodes=['node3', 'node4'],
|
||||||
|
time_offset_hours=-200 # 200 hours ago (> 7 days)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract with 7-day window
|
||||||
|
result = await extract_usage_frequency(
|
||||||
|
subgraphs=[recent_graph, old_graph],
|
||||||
|
time_window=timedelta(days=7),
|
||||||
|
min_interaction_threshold=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only recent interactions should be counted
|
||||||
|
assert result['total_interactions'] == 4 # All interactions found
|
||||||
|
assert result['interactions_in_window'] == 2 # Only recent ones counted
|
||||||
|
assert 'node1' in result['node_frequencies']
|
||||||
|
assert 'node2' in result['node_frequencies']
|
||||||
|
assert 'node3' not in result['node_frequencies'] # Too old
|
||||||
|
assert 'node4' not in result['node_frequencies'] # Too old
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_usage_frequency_threshold():
|
||||||
|
"""Test minimum interaction threshold filtering."""
|
||||||
|
# Create graph where node1 has 3 interactions, node2 has 1
|
||||||
|
mock_graph = create_interaction_graph(
|
||||||
|
interaction_count=4,
|
||||||
|
target_nodes=['node1', 'node1', 'node1', 'node2']
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract with threshold of 2
|
||||||
|
result = await extract_usage_frequency(
|
||||||
|
subgraphs=[mock_graph],
|
||||||
|
time_window=timedelta(days=1),
|
||||||
|
min_interaction_threshold=2
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only node1 should be in results (3 >= 2)
|
||||||
|
assert 'node1' in result['node_frequencies']
|
||||||
|
assert result['node_frequencies']['node1'] == 3
|
||||||
|
assert 'node2' not in result['node_frequencies'] # Below threshold
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_usage_frequency_multiple_graphs():
|
||||||
|
"""Test extraction across multiple subgraphs."""
|
||||||
|
graph1 = create_interaction_graph(
|
||||||
|
interaction_count=2,
|
||||||
|
target_nodes=['node1', 'node2']
|
||||||
|
)
|
||||||
|
|
||||||
|
graph2 = create_interaction_graph(
|
||||||
|
interaction_count=2,
|
||||||
|
target_nodes=['node1', 'node3']
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await extract_usage_frequency(
|
||||||
|
subgraphs=[graph1, graph2],
|
||||||
|
time_window=timedelta(days=1),
|
||||||
|
min_interaction_threshold=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# node1 should have frequency of 2 (once from each graph)
|
||||||
|
assert result['node_frequencies']['node1'] == 2
|
||||||
|
assert result['node_frequencies']['node2'] == 1
|
||||||
|
assert result['node_frequencies']['node3'] == 1
|
||||||
|
assert result['total_interactions'] == 4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_usage_frequency_empty_graph():
|
||||||
|
"""Test handling of empty graphs."""
|
||||||
|
empty_graph = CogneeGraph(directed=True)
|
||||||
|
|
||||||
|
result = await extract_usage_frequency(
|
||||||
|
subgraphs=[empty_graph],
|
||||||
|
time_window=timedelta(days=1),
|
||||||
|
min_interaction_threshold=1
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result['node_frequencies'] == {}
|
||||||
|
assert result['edge_frequencies'] == {}
|
||||||
|
assert result['total_interactions'] == 0
|
||||||
|
assert result['interactions_in_window'] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_usage_frequency_invalid_timestamps():
|
||||||
|
"""Test handling of invalid timestamp formats."""
|
||||||
|
graph = CogneeGraph(directed=True)
|
||||||
|
|
||||||
|
# Create interaction with invalid timestamp
|
||||||
|
bad_interaction = create_mock_node(
|
||||||
|
'bad_interaction',
|
||||||
|
{
|
||||||
|
'type': 'CogneeUserInteraction',
|
||||||
|
'timestamp': 'not-a-valid-timestamp',
|
||||||
|
'target_node_id': 'node1'
|
||||||
|
}
|
||||||
|
)
|
||||||
|
graph.add_node(bad_interaction)
|
||||||
|
|
||||||
|
# Should not crash, just skip invalid interaction
|
||||||
|
result = await extract_usage_frequency(
|
||||||
|
subgraphs=[graph],
|
||||||
|
time_window=timedelta(days=1),
|
||||||
|
min_interaction_threshold=1
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result['total_interactions'] == 0 # Invalid interaction not counted
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_extract_usage_frequency_element_type_tracking():
    """Per-element-type tallies are produced alongside node frequencies."""
    graph = CogneeGraph(directed=True)

    # Two distinct kinds of target nodes.
    chunk = create_mock_node('chunk1', {'type': 'DocumentChunk', 'text': 'content'})
    entity = create_mock_node('entity1', {'type': 'Entity', 'name': 'Alice'})
    graph.add_node(chunk)
    graph.add_node(entity)

    # Three interactions: the chunk is referenced twice, the entity once.
    now_iso = datetime.now().isoformat()
    for index, target in enumerate([chunk, chunk, entity]):
        interaction = create_mock_node(
            f'interaction_{index}',
            {'type': 'CogneeUserInteraction', 'timestamp': now_iso}
        )
        graph.add_node(interaction)
        graph.add_edge(create_mock_edge(interaction, target, 'used_graph_element_to_answer'))

    result = await extract_usage_frequency(
        subgraphs=[graph],
        time_window=timedelta(days=1),
        min_interaction_threshold=1
    )

    # Type-level tallies mirror how often each node type was referenced.
    assert 'element_type_frequencies' in result
    assert result['element_type_frequencies']['DocumentChunk'] == 2
    assert result['element_type_frequencies']['Entity'] == 1
|
@pytest.mark.asyncio
async def test_add_frequency_weights():
    """Frequency weights are written back to the graph through the adapter."""
    # Adapter stub that always resolves the looked-up node.
    adapter = AsyncMock()
    adapter.get_node_by_id = AsyncMock(return_value={
        'id': 'node1',
        'properties': {'type': 'DocumentChunk', 'text': 'content'}
    })
    adapter.update_node_properties = AsyncMock()

    frequencies = {
        'node_frequencies': {'node1': 5, 'node2': 3},
        'edge_frequencies': {},
        'last_processed_timestamp': datetime.now().isoformat()
    }

    await add_frequency_weights(adapter, frequencies)

    # One lookup + one update per tracked node.
    assert adapter.get_node_by_id.call_count == 2
    assert adapter.update_node_properties.call_count == 2

    # The first update must carry the frequency weight for 'node1'.
    first_call_properties = adapter.update_node_properties.call_args_list[0][0][1]
    assert 'frequency_weight' in first_call_properties
    assert first_call_properties['frequency_weight'] == 5
|
@pytest.mark.asyncio
async def test_add_frequency_weights_node_not_found():
    """Missing nodes are skipped silently instead of raising."""
    adapter = AsyncMock()
    adapter.get_node_by_id = AsyncMock(return_value=None)  # lookup always misses
    adapter.update_node_properties = AsyncMock()

    frequencies = {
        'node_frequencies': {'nonexistent_node': 5},
        'edge_frequencies': {},
        'last_processed_timestamp': datetime.now().isoformat()
    }

    # Must complete without raising.
    await add_frequency_weights(adapter, frequencies)

    # No update is issued for a node that was never found.
    assert adapter.update_node_properties.call_count == 0
|
@pytest.mark.asyncio
async def test_add_frequency_weights_with_metadata_support():
    """Aggregate stats are persisted via set_metadata when the adapter has it."""
    adapter = AsyncMock()
    adapter.get_node_by_id = AsyncMock(return_value={'properties': {}})
    adapter.update_node_properties = AsyncMock()
    adapter.set_metadata = AsyncMock()  # adapter advertises metadata support

    frequencies = {
        'node_frequencies': {'node1': 5},
        'edge_frequencies': {},
        'element_type_frequencies': {'DocumentChunk': 5},
        'total_interactions': 10,
        'interactions_in_window': 8,
        'last_processed_timestamp': datetime.now().isoformat()
    }

    await add_frequency_weights(adapter, frequencies)

    # Exactly one metadata write, keyed by the stats identifier.
    adapter.set_metadata.assert_called_once()
    key, value = adapter.set_metadata.call_args[0]
    assert key == 'usage_frequency_stats'
    assert 'total_interactions' in value
    assert value['total_interactions'] == 10
|
@pytest.mark.asyncio
async def test_create_usage_frequency_pipeline():
    """The pipeline factory yields one extraction and one enrichment task."""
    adapter = AsyncMock()

    extraction_tasks, enrichment_tasks = await create_usage_frequency_pipeline(
        graph_adapter=adapter,
        time_window=timedelta(days=7),
        min_interaction_threshold=2,
        batch_size=50
    )

    # One task on each side of the pipeline.
    assert len(extraction_tasks) == 1
    assert len(enrichment_tasks) == 1

    # Both tasks expose their callable under `function`.
    assert hasattr(extraction_tasks[0], 'function')
    assert hasattr(enrichment_tasks[0], 'function')
|
@pytest.mark.asyncio
async def test_run_usage_frequency_update_integration():
    """End-to-end: extraction feeds enrichment and updates the adapter."""
    # Five interactions: node1 is referenced three times, node2 and node3 once.
    graph = create_interaction_graph(
        interaction_count=5,
        target_nodes=['node1', 'node1', 'node2', 'node3', 'node1']
    )

    adapter = AsyncMock()
    adapter.get_node_by_id = AsyncMock(return_value={'properties': {}})
    adapter.update_node_properties = AsyncMock()

    stats = await run_usage_frequency_update(
        graph_adapter=adapter,
        subgraphs=[graph],
        time_window=timedelta(days=1),
        min_interaction_threshold=1
    )

    # Frequencies mirror how the targets were distributed.
    assert stats['total_interactions'] == 5
    assert stats['node_frequencies']['node1'] == 3
    assert stats['node_frequencies']['node2'] == 1
    assert stats['node_frequencies']['node3'] == 1

    # One property update per unique node.
    assert adapter.update_node_properties.call_count == 3
|
@pytest.mark.asyncio
async def test_extract_usage_frequency_no_used_graph_element_edges():
    """An orphaned interaction is counted but yields no frequencies."""
    graph = CogneeGraph(directed=True)

    orphan = create_mock_node(
        'interaction1',
        {
            'type': 'CogneeUserInteraction',
            'timestamp': datetime.now().isoformat(),
            'target_node_id': 'node1'
        }
    )
    graph.add_node(orphan)
    # Deliberately no 'used_graph_element_to_answer' edges are added.

    result = await extract_usage_frequency(
        subgraphs=[graph],
        time_window=timedelta(days=1),
        min_interaction_threshold=1
    )

    # The interaction itself is seen, but no element gets a frequency.
    assert result['total_interactions'] == 1
    assert result['node_frequencies'] == {}
|
@pytest.mark.asyncio
async def test_extract_usage_frequency_alternative_timestamp_field():
    """'created_at' is accepted as a fallback when 'timestamp' is absent."""
    graph = CogneeGraph(directed=True)

    target = create_mock_node('target1', {'type': 'DocumentChunk'})
    graph.add_node(target)

    # The interaction carries only 'created_at', not 'timestamp'.
    interaction = create_mock_node(
        'interaction1',
        {
            'type': 'CogneeUserInteraction',
            'created_at': datetime.now().isoformat()
        }
    )
    graph.add_node(interaction)
    graph.add_edge(create_mock_edge(interaction, target, 'used_graph_element_to_answer'))

    result = await extract_usage_frequency(
        subgraphs=[graph],
        time_window=timedelta(days=1),
        min_interaction_threshold=1
    )

    # The fallback timestamp keeps the interaction inside the window.
    assert result['total_interactions'] == 1
    assert 'target1' in result['node_frequencies']
|
def test_imports():
    """All public entry points of the task module are importable."""
    from cognee.tasks.memify.extract_usage_frequency import (
        extract_usage_frequency,
        add_frequency_weights,
        create_usage_frequency_pipeline,
        run_usage_frequency_update,
    )

    # Each exported callable must resolve to a real object.
    for exported in (
        extract_usage_frequency,
        add_frequency_weights,
        create_usage_frequency_pipeline,
        run_usage_frequency_update,
    ):
        assert exported is not None
||||||
|
if __name__ == "__main__":
    # Allow running this test module directly (verbose pytest session).
    pytest.main([__file__, "-v"])
|
@ -1,49 +1,325 @@
|
||||||
# cognee/examples/usage_frequency_example.py
|
# cognee/examples/usage_frequency_example.py
|
||||||
|
"""
|
||||||
|
End-to-end example demonstrating usage frequency tracking in Cognee.
|
||||||
|
|
||||||
|
This example shows how to:
|
||||||
|
1. Add data and build a knowledge graph
|
||||||
|
2. Run searches with save_interaction=True to track usage
|
||||||
|
3. Extract and apply frequency weights using the memify pipeline
|
||||||
|
4. Query and analyze the frequency data
|
||||||
|
|
||||||
|
The frequency weights can be used to:
|
||||||
|
- Rank frequently referenced entities higher during retrieval
|
||||||
|
- Adjust scoring for completion strategies
|
||||||
|
- Expose usage metrics in dashboards or audits
|
||||||
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.api.v1.search import SearchType
|
from cognee.api.v1.search import SearchType
|
||||||
from cognee.tasks.memify.extract_usage_frequency import usage_frequency_pipeline_entry
|
from cognee.tasks.memify.extract_usage_frequency import (
|
||||||
|
create_usage_frequency_pipeline,
|
||||||
|
run_usage_frequency_update,
|
||||||
|
)
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
async def main():
|
logger = get_logger("usage_frequency_example")
|
||||||
# Reset cognee state
|
|
||||||
|
|
||||||
|
async def setup_knowledge_base():
    """Set up a fresh knowledge base with sample data."""
    logger.info("Setting up knowledge base...")

    # Reset cognee state for clean slate
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    # Sample conversation about AI/ML topics
    conversation = [
        "Alice discusses machine learning algorithms and their applications in computer vision.",
        "Bob asks about neural networks and how they differ from traditional algorithms.",
        "Alice explains deep learning concepts including CNNs and transformers.",
        "Bob wants more details about neural networks and backpropagation.",
        "Alice describes reinforcement learning and its use in robotics.",
        "Bob inquires about natural language processing and transformers.",
    ]

    # Add conversation data and build knowledge graph
    logger.info("Adding conversation data...")
    await cognee.add(conversation, dataset_name="ai_ml_conversation")

    logger.info("Building knowledge graph (cognify)...")
    await cognee.cognify()

    logger.info("Knowledge base setup complete")
|
async def simulate_user_searches():
    """Simulate multiple user searches to generate interaction data.

    Returns:
        The number of searches that completed successfully.
    """
    logger.info("Simulating user searches with save_interaction=True...")

    # Different queries that will create CogneeUserInteraction nodes;
    # repeats are intentional so some elements accumulate higher frequency.
    queries = [
        "What is machine learning?",
        "Explain neural networks",
        "Tell me about deep learning",
        "What are neural networks?",  # Repeat to increase frequency
        "How does machine learning work?",
        "Describe transformers in NLP",
        "What is reinforcement learning?",
        "Explain neural networks again",  # Another repeat
    ]

    search_count = 0
    for query in queries:
        try:
            logger.info(f"Searching: '{query}'")
            results = await cognee.search(
                query_type=SearchType.GRAPH_COMPLETION,
                query_text=query,
                save_interaction=True,  # Critical: saves interaction to graph
                top_k=5
            )
            search_count += 1
            logger.debug(f"Search completed, got {len(results) if results else 0} results")
        except Exception as e:
            # Best-effort simulation: a failed search is logged and skipped.
            logger.warning(f"Search failed for '{query}': {e}")

    logger.info(f"Completed {search_count} searches with interactions saved")
    return search_count
async def retrieve_interaction_graph() -> List[CogneeGraph]:
    """Retrieve the graph containing interaction nodes.

    Returns:
        A single-element list holding the projected CogneeGraph.
    """
    logger.info("Retrieving graph with interaction data...")

    graph_engine = await get_graph_engine()
    graph = CogneeGraph()

    # Project the full graph including CogneeUserInteraction nodes
    await graph.project_graph_from_db(
        adapter=graph_engine,
        node_properties_to_project=["type", "node_type", "timestamp", "created_at", "text", "name"],
        edge_properties_to_project=["relationship_type", "timestamp", "created_at"],
        directed=True,
    )

    logger.info(f"Retrieved graph: {len(graph.nodes)} nodes, {len(graph.edges)} edges")

    # Count interaction nodes for verification; the type may live under
    # either 'type' or 'node_type' depending on how the node was stored.
    interaction_count = sum(
        1 for node in graph.nodes.values()
        if node.attributes.get('type') == 'CogneeUserInteraction' or
        node.attributes.get('node_type') == 'CogneeUserInteraction'
    )
    logger.info(f"Found {interaction_count} CogneeUserInteraction nodes in graph")

    return [graph]
async def run_frequency_pipeline_method1():
    """Method 1: Using the pipeline creation function.

    Builds the extraction/enrichment tasks via create_usage_frequency_pipeline
    and executes them manually (a real memify pipeline would run them itself).

    Returns:
        The extraction result dict, or None if no extraction task exposed a
        callable `function` attribute.
    """
    logger.info("\n=== Method 1: Using create_usage_frequency_pipeline ===")

    graph_engine = await get_graph_engine()
    subgraphs = await retrieve_interaction_graph()

    # Create the pipeline tasks
    extraction_tasks, enrichment_tasks = await create_usage_frequency_pipeline(
        graph_adapter=graph_engine,
        time_window=timedelta(days=30),  # Last 30 days
        min_interaction_threshold=1,  # Count all interactions
        batch_size=100
    )

    logger.info("Running extraction tasks...")
    # Note: In real memify pipeline, these would be executed by the pipeline runner
    # For this example, we'll execute them manually
    result = None  # Bug fix: `result` was unbound when no extraction task ran,
    # which made the enrichment loop below raise NameError.
    for task in extraction_tasks:
        if hasattr(task, 'function'):
            result = await task.function(
                subgraphs=subgraphs,
                time_window=timedelta(days=30),
                min_interaction_threshold=1
            )
            logger.info(f"Extraction result: {result.get('interactions_in_window')} interactions processed")

    logger.info("Running enrichment tasks...")
    if result is not None:
        for task in enrichment_tasks:
            if hasattr(task, 'function'):
                await task.function(
                    graph_adapter=graph_engine,
                    usage_frequencies=result
                )
    else:
        logger.warning("No extraction result produced; skipping enrichment tasks")

    return result
async def run_frequency_pipeline_method2():
    """Method 2: Using the convenience function.

    Returns:
        The statistics dict produced by run_usage_frequency_update.
    """
    logger.info("\n=== Method 2: Using run_usage_frequency_update ===")

    graph_engine = await get_graph_engine()
    subgraphs = await retrieve_interaction_graph()

    # Run the complete pipeline in one call
    stats = await run_usage_frequency_update(
        graph_adapter=graph_engine,
        subgraphs=subgraphs,
        time_window=timedelta(days=30),
        min_interaction_threshold=1
    )

    logger.info("Frequency update statistics:")
    logger.info(f" Total interactions: {stats['total_interactions']}")
    logger.info(f" Interactions in window: {stats['interactions_in_window']}")
    logger.info(f" Nodes with frequency weights: {len(stats['node_frequencies'])}")
    logger.info(f" Element types: {stats.get('element_type_frequencies', {})}")

    return stats
async def analyze_frequency_weights():
    """Analyze and display the frequency weights that were added.

    Returns:
        Nodes carrying a frequency weight, sorted by weight descending.
    """
    logger.info("\n=== Analyzing Frequency Weights ===")

    graph_engine = await get_graph_engine()
    graph = CogneeGraph()

    # Project graph with frequency weights
    await graph.project_graph_from_db(
        adapter=graph_engine,
        node_properties_to_project=[
            "type",
            "node_type",
            "text",
            "name",
            "frequency_weight",  # Our added property
            "frequency_updated_at"
        ],
        edge_properties_to_project=["relationship_type"],
        directed=True,
    )

    # Find nodes with frequency weights
    weighted_nodes = []
    for node_id, node in graph.nodes.items():
        freq_weight = node.attributes.get('frequency_weight')
        if freq_weight is not None:
            weighted_nodes.append({
                'id': node_id,
                'type': node.attributes.get('type') or node.attributes.get('node_type'),
                'text': node.attributes.get('text', '')[:100],  # First 100 chars
                'name': node.attributes.get('name', ''),
                'frequency_weight': freq_weight,
                'updated_at': node.attributes.get('frequency_updated_at')
            })

    # Sort by frequency (descending)
    weighted_nodes.sort(key=lambda x: x['frequency_weight'], reverse=True)

    logger.info(f"\nFound {len(weighted_nodes)} nodes with frequency weights:")
    logger.info("\nTop 10 Most Frequently Referenced Elements:")
    logger.info("-" * 80)

    for i, node in enumerate(weighted_nodes[:10], 1):
        logger.info(f"\n{i}. Frequency: {node['frequency_weight']}")
        logger.info(f" Type: {node['type']}")
        logger.info(f" Name: {node['name']}")
        logger.info(f" Text: {node['text']}")
        logger.info(f" ID: {node['id'][:50]}...")

    return weighted_nodes
async def demonstrate_retrieval_with_frequencies():
    """Demonstrate how frequency weights can be used in retrieval."""
    logger.info("\n=== Demonstrating Retrieval with Frequency Weights ===")

    # This is a conceptual demonstration of how frequency weights
    # could be used to boost search results

    query = "neural networks"
    logger.info(f"Searching for: '{query}'")

    try:
        # Standard search
        standard_results = await cognee.search(
            query_type=SearchType.GRAPH_COMPLETION,
            query_text=query,
            save_interaction=False,  # Don't add more interactions
            top_k=5
        )

        logger.info(f"Standard search returned {len(standard_results) if standard_results else 0} results")

        # Note: To actually use frequency_weight in scoring, you would need to:
        # 1. Modify the retrieval/ranking logic to consider frequency_weight
        # 2. Add frequency_weight as a scoring factor in the completion strategy
        # 3. Use it in analytics dashboards to show popular topics

        logger.info("\nFrequency weights can now be used for:")
        logger.info(" - Boosting frequently-accessed nodes in search rankings")
        logger.info(" - Adjusting triplet importance scores")
        logger.info(" - Building usage analytics dashboards")
        logger.info(" - Identifying 'hot' topics in the knowledge graph")

    except Exception as e:
        # Demonstration only: a failed search is reported, not fatal.
        logger.warning(f"Demonstration search failed: {e}")
async def main():
    """Main execution flow."""
    logger.info("=" * 80)
    logger.info("Usage Frequency Tracking Example")
    logger.info("=" * 80)

    try:
        # Step 1: Setup knowledge base
        await setup_knowledge_base()

        # Step 2: Simulate user searches with save_interaction=True
        search_count = await simulate_user_searches()

        if search_count == 0:
            logger.warning("No searches completed - cannot demonstrate frequency tracking")
            return

        # Step 3: Run frequency extraction and enrichment
        # You can use either method - both accomplish the same thing

        # Option A: Using the convenience function (recommended)
        stats = await run_frequency_pipeline_method2()

        # Option B: Using the pipeline creation function (for custom pipelines)
        # stats = await run_frequency_pipeline_method1()

        # Step 4: Analyze the results
        weighted_nodes = await analyze_frequency_weights()

        # Step 5: Demonstrate retrieval usage
        await demonstrate_retrieval_with_frequencies()

        # Summary
        logger.info("\n" + "=" * 80)
        logger.info("SUMMARY")
        logger.info("=" * 80)
        logger.info(f"Searches performed: {search_count}")
        logger.info(f"Interactions tracked: {stats.get('interactions_in_window', 0)}")
        logger.info(f"Nodes weighted: {len(weighted_nodes)}")
        logger.info(f"Time window: {stats.get('time_window_days', 0)} days")
        logger.info("\nFrequency weights have been added to the graph!")
        logger.info("These can now be used in retrieval, ranking, and analytics.")
        logger.info("=" * 80)

    except Exception as e:
        # Surface the failure in the log, then re-raise for the caller.
        logger.error(f"Example failed: {e}", exc_info=True)
        raise
||||||
if __name__ == "__main__":
    # Script entry point: drive the async example to completion.
    asyncio.run(main())
||||||
Loading…
Add table
Reference in a new issue