chore: remove inconsistencies in node properties between task, e2e example, and test code

Christina_Raichel_Francis 2026-01-05 22:22:47 +00:00
parent 931c5f3096
commit e0c7e68dd6
3 changed files with 1141 additions and 105 deletions


@@ -1,8 +1,12 @@
# cognee/tasks/memify/extract_usage_frequency.py
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
from cognee.shared.logging_utils import get_logger
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.pipelines.tasks.task import Task
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
logger = get_logger("extract_usage_frequency")
async def extract_usage_frequency(
subgraphs: List[CogneeGraph],
@@ -10,35 +14,93 @@ async def extract_usage_frequency(
min_interaction_threshold: int = 1
) -> Dict[str, Any]:
"""
Extract usage frequency from CogneeUserInteraction nodes
Extract usage frequency from CogneeUserInteraction nodes.
:param subgraphs: List of graph subgraphs
:param time_window: Time window to consider for interactions
:param min_interaction_threshold: Minimum interactions to track
:return: Dictionary of usage frequencies
When save_interaction=True in cognee.search(), the system creates:
- CogneeUserInteraction nodes (representing the query/answer interaction)
- used_graph_element_to_answer edges (connecting interactions to graph elements used)
This function tallies how often each graph element is referenced via these edges,
enabling frequency-based ranking in downstream retrievers.
:param subgraphs: List of CogneeGraph instances containing interaction data
:param time_window: Time window to consider for interactions (default: 7 days)
:param min_interaction_threshold: Minimum interactions to track (default: 1)
:return: Dictionary containing node frequencies, edge frequencies, and metadata
"""
current_time = datetime.now()
cutoff_time = current_time - time_window
# Track frequencies for graph elements (nodes and edges)
node_frequencies = {}
edge_frequencies = {}
relationship_type_frequencies = {}
# Track interaction metadata
interaction_count = 0
interactions_in_window = 0
logger.info(f"Extracting usage frequencies from {len(subgraphs)} subgraphs")
logger.info(f"Time window: {time_window}, Cutoff: {cutoff_time.isoformat()}")
for subgraph in subgraphs:
# Filter CogneeUserInteraction nodes within time window
user_interactions = [
interaction for interaction in subgraph.nodes
if (interaction.get('type') == 'CogneeUserInteraction' and
current_time - datetime.fromisoformat(interaction.get('timestamp', current_time.isoformat())) <= time_window)
]
# Find all CogneeUserInteraction nodes
interaction_nodes = {}
for node_id, node in subgraph.nodes.items():
node_type = node.attributes.get('type') or node.attributes.get('node_type')
if node_type == 'CogneeUserInteraction':
# Parse and validate timestamp
timestamp_str = node.attributes.get('timestamp') or node.attributes.get('created_at')
if timestamp_str:
try:
interaction_time = datetime.fromisoformat(timestamp_str)
interaction_nodes[node_id] = {
'node': node,
'timestamp': interaction_time,
'in_window': interaction_time >= cutoff_time
}
interaction_count += 1
if interaction_time >= cutoff_time:
interactions_in_window += 1
except (ValueError, TypeError) as e:
logger.warning(f"Failed to parse timestamp for interaction node {node_id}: {e}")
# Count node and edge frequencies
for interaction in user_interactions:
target_node_id = interaction.get('target_node_id')
edge_type = interaction.get('edge_type')
# Process edges to find graph elements used in interactions
for edge in subgraph.edges:
relationship_type = edge.attributes.get('relationship_type')
if target_node_id:
node_frequencies[target_node_id] = node_frequencies.get(target_node_id, 0) + 1
# Look for 'used_graph_element_to_answer' edges
if relationship_type == 'used_graph_element_to_answer':
# node1 should be the CogneeUserInteraction, node2 is the graph element
source_id = str(edge.node1.id)
target_id = str(edge.node2.id)
# Check if source is an interaction node in our time window
if source_id in interaction_nodes:
interaction_data = interaction_nodes[source_id]
if interaction_data['in_window']:
# Count the graph element (target node) being used
node_frequencies[target_id] = node_frequencies.get(target_id, 0) + 1
# Also track what type of element it is for analytics
target_node = subgraph.get_node(target_id)
if target_node:
element_type = target_node.attributes.get('type') or target_node.attributes.get('node_type')
if element_type:
relationship_type_frequencies[element_type] = relationship_type_frequencies.get(element_type, 0) + 1
if edge_type:
edge_frequencies[edge_type] = edge_frequencies.get(edge_type, 0) + 1
# Also track general edge usage patterns
elif relationship_type and relationship_type != 'used_graph_element_to_answer':
# Check if either endpoint is referenced in a recent interaction
source_id = str(edge.node1.id)
target_id = str(edge.node2.id)
# If this edge connects to any frequently accessed nodes, track the edge type
if source_id in node_frequencies or target_id in node_frequencies:
edge_key = f"{relationship_type}:{source_id}:{target_id}"
edge_frequencies[edge_key] = edge_frequencies.get(edge_key, 0) + 1
# Filter frequencies above threshold
filtered_node_frequencies = {
@@ -47,55 +109,292 @@ async def extract_usage_frequency(
}
filtered_edge_frequencies = {
edge_type: freq for edge_type, freq in edge_frequencies.items()
edge_key: freq for edge_key, freq in edge_frequencies.items()
if freq >= min_interaction_threshold
}
logger.info(
f"Processed {interactions_in_window}/{interaction_count} interactions in time window"
)
logger.info(
f"Found {len(filtered_node_frequencies)} nodes and {len(filtered_edge_frequencies)} edges "
f"above threshold (min: {min_interaction_threshold})"
)
logger.info(f"Element type distribution: {relationship_type_frequencies}")
return {
'node_frequencies': filtered_node_frequencies,
'edge_frequencies': filtered_edge_frequencies,
'last_processed_timestamp': current_time.isoformat()
'element_type_frequencies': relationship_type_frequencies,
'total_interactions': interaction_count,
'interactions_in_window': interactions_in_window,
'time_window_days': time_window.days,
'last_processed_timestamp': current_time.isoformat(),
'cutoff_timestamp': cutoff_time.isoformat()
}
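# --- Illustrative sketch (not part of this change): a minimal call and the
# approximate shape of the returned dictionary, assuming `graph` is a
# CogneeGraph whose CogneeUserInteraction nodes are linked to elements via
# 'used_graph_element_to_answer' edges. `_example_extract_frequencies` is a
# hypothetical helper, not an API of this module. ---
async def _example_extract_frequencies(graph: CogneeGraph) -> None:
    frequencies = await extract_usage_frequency(
        subgraphs=[graph],
        time_window=timedelta(days=7),
        min_interaction_threshold=1,
    )
    # e.g. {'node_frequencies': {'chunk_1': 3}, 'edge_frequencies': {...},
    #       'element_type_frequencies': {'DocumentChunk': 3},
    #       'total_interactions': 5, 'interactions_in_window': 5, ...}
    for node_id, count in sorted(
        frequencies["node_frequencies"].items(), key=lambda kv: kv[1], reverse=True
    ):
        logger.info(f"{node_id} used in {count} interactions")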
async def add_frequency_weights(
graph_adapter,
graph_adapter: GraphDBInterface,
usage_frequencies: Dict[str, Any]
) -> None:
"""
Add frequency weights to graph nodes and edges
Add frequency weights to graph nodes and edges using the graph adapter.
:param graph_adapter: Graph database adapter
:param usage_frequencies: Calculated usage frequencies
Uses the "get → tweak dict → update" contract consistent with graph adapters.
Writes frequency_weight properties back to the graph for use in:
- Ranking frequently referenced entities higher during retrieval
- Adjusting scoring for completion strategies
- Exposing usage metrics in dashboards or audits
:param graph_adapter: Graph database adapter interface
:param usage_frequencies: Calculated usage frequencies from extract_usage_frequency
"""
# Update node frequencies
for node_id, frequency in usage_frequencies['node_frequencies'].items():
node_frequencies = usage_frequencies.get('node_frequencies', {})
edge_frequencies = usage_frequencies.get('edge_frequencies', {})
logger.info(f"Adding frequency weights to {len(node_frequencies)} nodes")
# Update node frequencies using get → tweak → update pattern
nodes_updated = 0
nodes_failed = 0
for node_id, frequency in node_frequencies.items():
try:
node = graph_adapter.get_node(node_id)
if node:
node_properties = node.get_properties() or {}
node_properties['frequency_weight'] = frequency
graph_adapter.update_node(node_id, node_properties)
# Get current node data
node_data = await graph_adapter.get_node_by_id(node_id)
if node_data:
# Tweak the properties dict - add frequency_weight
if isinstance(node_data, dict):
properties = node_data.get('properties', {})
else:
# Handle case where node_data might be a node object
properties = getattr(node_data, 'properties', {}) or {}
# Update with frequency weight
properties['frequency_weight'] = frequency
# Also store when this was last updated
properties['frequency_updated_at'] = usage_frequencies.get('last_processed_timestamp')
# Write back via adapter
await graph_adapter.update_node_properties(node_id, properties)
nodes_updated += 1
else:
logger.warning(f"Node {node_id} not found in graph")
nodes_failed += 1
except Exception as e:
print(f"Error updating node {node_id}: {e}")
logger.error(f"Error updating node {node_id}: {e}")
nodes_failed += 1
# Note: Edge frequency update might require backend-specific implementation
print("Edge frequency update might need backend-specific handling")
logger.info(
f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed"
)
# Update edge frequencies
# Note: Edge property updates are backend-specific
if edge_frequencies:
logger.info(f"Processing {len(edge_frequencies)} edge frequency entries")
edges_updated = 0
edges_failed = 0
for edge_key, frequency in edge_frequencies.items():
try:
# Parse edge key: "relationship_type:source_id:target_id"
parts = edge_key.split(':', 2)
if len(parts) == 3:
relationship_type, source_id, target_id = parts
# Try to update edge if adapter supports it
if hasattr(graph_adapter, 'update_edge_properties'):
edge_properties = {
'frequency_weight': frequency,
'frequency_updated_at': usage_frequencies.get('last_processed_timestamp')
}
await graph_adapter.update_edge_properties(
source_id,
target_id,
relationship_type,
edge_properties
)
edges_updated += 1
else:
# Fallback: store in metadata or log
logger.debug(
f"Adapter doesn't support update_edge_properties for "
f"{relationship_type} ({source_id} -> {target_id})"
)
except Exception as e:
logger.error(f"Error updating edge {edge_key}: {e}")
edges_failed += 1
if edges_updated > 0:
logger.info(f"Edge update complete: {edges_updated} succeeded, {edges_failed} failed")
else:
logger.info(
"Edge frequency updates skipped (adapter may not support edge property updates)"
)
# Store aggregate statistics as metadata if supported
if hasattr(graph_adapter, 'set_metadata'):
try:
metadata = {
'element_type_frequencies': usage_frequencies.get('element_type_frequencies', {}),
'total_interactions': usage_frequencies.get('total_interactions', 0),
'interactions_in_window': usage_frequencies.get('interactions_in_window', 0),
'last_frequency_update': usage_frequencies.get('last_processed_timestamp')
}
await graph_adapter.set_metadata('usage_frequency_stats', metadata)
logger.info("Stored usage frequency statistics as metadata")
except Exception as e:
logger.warning(f"Could not store usage statistics as metadata: {e}")
def usage_frequency_pipeline_entry(graph_adapter):
async def create_usage_frequency_pipeline(
graph_adapter: GraphDBInterface,
time_window: timedelta = timedelta(days=7),
min_interaction_threshold: int = 1,
batch_size: int = 100
) -> tuple:
"""
Memify pipeline entry for usage frequency tracking
Create memify pipeline entry for usage frequency tracking.
This follows the same pattern as feedback enrichment flows, allowing
the frequency update to run end-to-end in a custom memify pipeline.
Use case example:
extraction_tasks, enrichment_tasks = await create_usage_frequency_pipeline(
graph_adapter=my_adapter,
time_window=timedelta(days=30),
min_interaction_threshold=2
)
# Run in memify pipeline
pipeline = Pipeline(extraction_tasks + enrichment_tasks)
results = await pipeline.run()
:param graph_adapter: Graph database adapter
:return: Usage frequency results
:param time_window: Time window for counting interactions (default: 7 days)
:param min_interaction_threshold: Minimum interactions to track (default: 1)
:param batch_size: Batch size for processing (default: 100)
:return: Tuple of (extraction_tasks, enrichment_tasks)
"""
logger.info("Creating usage frequency pipeline")
logger.info(f"Config: time_window={time_window}, threshold={min_interaction_threshold}")
extraction_tasks = [
Task(extract_usage_frequency,
time_window=timedelta(days=7),
min_interaction_threshold=1)
Task(
extract_usage_frequency,
time_window=time_window,
min_interaction_threshold=min_interaction_threshold
)
]
enrichment_tasks = [
Task(add_frequency_weights, task_config={"batch_size": 1})
Task(
add_frequency_weights,
graph_adapter=graph_adapter,
task_config={"batch_size": batch_size}
)
]
return extraction_tasks, enrichment_tasks
return extraction_tasks, enrichment_tasks
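# --- Illustrative sketch (not part of this change): handing the task lists to
# a memify run. The `cognee.memify(extraction_tasks, enrichment_tasks)` call
# shape is an assumption carried over from the earlier example code and may
# differ between cognee versions. ---
async def _example_memify_entry(graph_adapter: GraphDBInterface) -> None:
    import cognee

    extraction_tasks, enrichment_tasks = await create_usage_frequency_pipeline(
        graph_adapter=graph_adapter,
        time_window=timedelta(days=30),
        min_interaction_threshold=2,
    )
    # Assumed call shape; adapt to the memify API of the installed version.
    await cognee.memify(extraction_tasks, enrichment_tasks)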
async def run_usage_frequency_update(
graph_adapter: GraphDBInterface,
subgraphs: List[CogneeGraph],
time_window: timedelta = timedelta(days=7),
min_interaction_threshold: int = 1
) -> Dict[str, Any]:
"""
Convenience function to run the complete usage frequency update pipeline.
This is the main entry point for updating frequency weights on graph elements
based on CogneeUserInteraction data from cognee.search(save_interaction=True).
Example usage:
# After running searches with save_interaction=True
from cognee.tasks.memify.extract_usage_frequency import run_usage_frequency_update
# Get the graph with interactions
graph = await get_cognee_graph_with_interactions()
# Update frequency weights
stats = await run_usage_frequency_update(
graph_adapter=graph_adapter,
subgraphs=[graph],
time_window=timedelta(days=30), # Last 30 days
min_interaction_threshold=2 # At least 2 uses
)
print(f"Updated {len(stats['node_frequencies'])} nodes")
:param graph_adapter: Graph database adapter
:param subgraphs: List of CogneeGraph instances with interaction data
:param time_window: Time window for counting interactions
:param min_interaction_threshold: Minimum interactions to track
:return: Usage frequency statistics
"""
logger.info("Starting usage frequency update")
try:
# Extract frequencies from interaction data
usage_frequencies = await extract_usage_frequency(
subgraphs=subgraphs,
time_window=time_window,
min_interaction_threshold=min_interaction_threshold
)
# Add frequency weights back to the graph
await add_frequency_weights(
graph_adapter=graph_adapter,
usage_frequencies=usage_frequencies
)
logger.info("Usage frequency update completed successfully")
logger.info(
f"Summary: {usage_frequencies['interactions_in_window']} interactions processed, "
f"{len(usage_frequencies['node_frequencies'])} nodes weighted"
)
return usage_frequencies
except Exception as e:
logger.error(f"Error during usage frequency update: {str(e)}")
raise
async def get_most_frequent_elements(
graph_adapter: GraphDBInterface,
top_n: int = 10,
element_type: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
Retrieve the most frequently accessed graph elements.
Useful for analytics dashboards and understanding user behavior.
:param graph_adapter: Graph database adapter
:param top_n: Number of top elements to return
:param element_type: Optional filter by element type
:return: List of elements with their frequency weights
"""
logger.info(f"Retrieving top {top_n} most frequent elements")
# This would need to be implemented based on the specific graph adapter's query capabilities
# Pseudocode:
# results = await graph_adapter.query_nodes_by_property(
# property_name='frequency_weight',
# order_by='DESC',
# limit=top_n,
# filters={'type': element_type} if element_type else None
# )
logger.warning("get_most_frequent_elements needs adapter-specific implementation")
return []
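# --- Illustrative sketch (not part of this change): until an adapter-specific
# query is implemented, a similar result can be approximated by projecting the
# graph and sorting in memory on frequency_weight, as the e2e example does.
# The helper name and the projected property list are illustrative. ---
async def _example_top_elements_via_projection(
    graph_adapter: GraphDBInterface, top_n: int = 10
) -> List[Dict[str, Any]]:
    graph = CogneeGraph()
    await graph.project_graph_from_db(
        adapter=graph_adapter,
        node_properties_to_project=["type", "node_type", "name", "frequency_weight"],
        edge_properties_to_project=["relationship_type"],
        directed=True,
    )
    weighted = [
        {
            "id": node_id,
            "type": node.attributes.get("type") or node.attributes.get("node_type"),
            "frequency_weight": node.attributes.get("frequency_weight"),
        }
        for node_id, node in graph.nodes.items()
        if node.attributes.get("frequency_weight") is not None
    ]
    weighted.sort(key=lambda item: item["frequency_weight"], reverse=True)
    return weighted[:top_n]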


@@ -1,42 +1,503 @@
# cognee/tests/test_usage_frequency.py
"""
Test suite for usage frequency tracking functionality.
Tests cover:
- Frequency extraction from CogneeUserInteraction nodes
- Time window filtering
- Frequency weight application to graph
- Edge cases and error handling
"""
import pytest
import asyncio
from datetime import datetime, timedelta
from cognee.tasks.memify.extract_usage_frequency import extract_usage_frequency, add_frequency_weights
from unittest.mock import AsyncMock, MagicMock, patch
from typing import Dict, Any
from cognee.tasks.memify.extract_usage_frequency import (
extract_usage_frequency,
add_frequency_weights,
create_usage_frequency_pipeline,
run_usage_frequency_update,
)
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
def create_mock_node(node_id: str, attributes: Dict[str, Any]) -> Node:
"""Helper to create mock Node objects."""
node = Node(node_id, attributes)
return node
def create_mock_edge(node1: Node, node2: Node, relationship_type: str, attributes: Dict[str, Any] = None) -> Edge:
"""Helper to create mock Edge objects."""
edge_attrs = attributes or {}
edge_attrs['relationship_type'] = relationship_type
edge = Edge(node1, node2, attributes=edge_attrs, directed=True)
return edge
def create_interaction_graph(
interaction_count: int = 3,
target_nodes: list = None,
time_offset_hours: int = 0
) -> CogneeGraph:
"""
Create a mock CogneeGraph with interaction nodes.
:param interaction_count: Number of interactions to create
:param target_nodes: List of target node IDs to reference
:param time_offset_hours: Hours to offset timestamp (negative = past)
:return: CogneeGraph with mocked interaction data
"""
graph = CogneeGraph(directed=True)
if target_nodes is None:
target_nodes = ['node1', 'node2', 'node3']
# Create some target graph element nodes
element_nodes = {}
for i, node_id in enumerate(target_nodes):
element_node = create_mock_node(
node_id,
{
'type': 'DocumentChunk',
'text': f'This is content for {node_id}',
'name': f'Element {i+1}'
}
)
graph.add_node(element_node)
element_nodes[node_id] = element_node
# Create interaction nodes and edges
timestamp = datetime.now() + timedelta(hours=time_offset_hours)
for i in range(interaction_count):
# Create interaction node
interaction_id = f'interaction_{i}'
target_id = target_nodes[i % len(target_nodes)]
interaction_node = create_mock_node(
interaction_id,
{
'type': 'CogneeUserInteraction',
'timestamp': timestamp.isoformat(),
'query_text': f'Sample query {i}',
'target_node_id': target_id # Also store in attributes for completeness
}
)
graph.add_node(interaction_node)
# Create edge from interaction to target element
target_element = element_nodes[target_id]
edge = create_mock_edge(
interaction_node,
target_element,
'used_graph_element_to_answer',
{'timestamp': timestamp.isoformat()}
)
graph.add_edge(edge)
return graph
@pytest.mark.asyncio
async def test_extract_usage_frequency():
# Mock CogneeGraph with user interactions
mock_subgraphs = [{
'nodes': [
{
'type': 'CogneeUserInteraction',
'target_node_id': 'node1',
'edge_type': 'viewed',
'timestamp': datetime.now().isoformat()
},
{
'type': 'CogneeUserInteraction',
'target_node_id': 'node1',
'edge_type': 'viewed',
'timestamp': datetime.now().isoformat()
},
{
'type': 'CogneeUserInteraction',
'target_node_id': 'node2',
'edge_type': 'referenced',
'timestamp': datetime.now().isoformat()
}
]
}]
# Test frequency extraction
async def test_extract_usage_frequency_basic():
"""Test basic frequency extraction with simple interaction data."""
# Create mock graph with 3 interactions
# node1 referenced twice, node2 referenced once
mock_graph = create_interaction_graph(
interaction_count=3,
target_nodes=['node1', 'node1', 'node2']
)
# Extract frequencies
result = await extract_usage_frequency(
mock_subgraphs,
time_window=timedelta(days=1),
subgraphs=[mock_graph],
time_window=timedelta(days=1),
min_interaction_threshold=1
)
assert 'node1' in result['node_frequencies']
# Assertions
assert 'node_frequencies' in result
assert 'edge_frequencies' in result
assert result['node_frequencies']['node1'] == 2
assert result['edge_frequencies']['viewed'] == 2
assert result['node_frequencies']['node2'] == 1
assert result['total_interactions'] == 3
assert result['interactions_in_window'] == 3
@pytest.mark.asyncio
async def test_extract_usage_frequency_time_window():
"""Test that time window filtering works correctly."""
# Create two graphs: one recent, one old
recent_graph = create_interaction_graph(
interaction_count=2,
target_nodes=['node1', 'node2'],
time_offset_hours=-1 # 1 hour ago
)
old_graph = create_interaction_graph(
interaction_count=2,
target_nodes=['node3', 'node4'],
time_offset_hours=-200 # 200 hours ago (> 7 days)
)
# Extract with 7-day window
result = await extract_usage_frequency(
subgraphs=[recent_graph, old_graph],
time_window=timedelta(days=7),
min_interaction_threshold=1
)
# Only recent interactions should be counted
assert result['total_interactions'] == 4 # All interactions found
assert result['interactions_in_window'] == 2 # Only recent ones counted
assert 'node1' in result['node_frequencies']
assert 'node2' in result['node_frequencies']
assert 'node3' not in result['node_frequencies'] # Too old
assert 'node4' not in result['node_frequencies'] # Too old
@pytest.mark.asyncio
async def test_extract_usage_frequency_threshold():
"""Test minimum interaction threshold filtering."""
# Create graph where node1 has 3 interactions, node2 has 1
mock_graph = create_interaction_graph(
interaction_count=4,
target_nodes=['node1', 'node1', 'node1', 'node2']
)
# Extract with threshold of 2
result = await extract_usage_frequency(
subgraphs=[mock_graph],
time_window=timedelta(days=1),
min_interaction_threshold=2
)
# Only node1 should be in results (3 >= 2)
assert 'node1' in result['node_frequencies']
assert result['node_frequencies']['node1'] == 3
assert 'node2' not in result['node_frequencies'] # Below threshold
@pytest.mark.asyncio
async def test_extract_usage_frequency_multiple_graphs():
"""Test extraction across multiple subgraphs."""
graph1 = create_interaction_graph(
interaction_count=2,
target_nodes=['node1', 'node2']
)
graph2 = create_interaction_graph(
interaction_count=2,
target_nodes=['node1', 'node3']
)
result = await extract_usage_frequency(
subgraphs=[graph1, graph2],
time_window=timedelta(days=1),
min_interaction_threshold=1
)
# node1 should have frequency of 2 (once from each graph)
assert result['node_frequencies']['node1'] == 2
assert result['node_frequencies']['node2'] == 1
assert result['node_frequencies']['node3'] == 1
assert result['total_interactions'] == 4
@pytest.mark.asyncio
async def test_extract_usage_frequency_empty_graph():
"""Test handling of empty graphs."""
empty_graph = CogneeGraph(directed=True)
result = await extract_usage_frequency(
subgraphs=[empty_graph],
time_window=timedelta(days=1),
min_interaction_threshold=1
)
assert result['node_frequencies'] == {}
assert result['edge_frequencies'] == {}
assert result['total_interactions'] == 0
assert result['interactions_in_window'] == 0
@pytest.mark.asyncio
async def test_extract_usage_frequency_invalid_timestamps():
"""Test handling of invalid timestamp formats."""
graph = CogneeGraph(directed=True)
# Create interaction with invalid timestamp
bad_interaction = create_mock_node(
'bad_interaction',
{
'type': 'CogneeUserInteraction',
'timestamp': 'not-a-valid-timestamp',
'target_node_id': 'node1'
}
)
graph.add_node(bad_interaction)
# Should not crash, just skip invalid interaction
result = await extract_usage_frequency(
subgraphs=[graph],
time_window=timedelta(days=1),
min_interaction_threshold=1
)
assert result['total_interactions'] == 0 # Invalid interaction not counted
@pytest.mark.asyncio
async def test_extract_usage_frequency_element_type_tracking():
"""Test that element type frequencies are tracked."""
graph = CogneeGraph(directed=True)
# Create different types of target nodes
chunk_node = create_mock_node('chunk1', {'type': 'DocumentChunk', 'text': 'content'})
entity_node = create_mock_node('entity1', {'type': 'Entity', 'name': 'Alice'})
graph.add_node(chunk_node)
graph.add_node(entity_node)
# Create interactions pointing to each
timestamp = datetime.now().isoformat()
for i, target in enumerate([chunk_node, chunk_node, entity_node]):
interaction = create_mock_node(
f'interaction_{i}',
{'type': 'CogneeUserInteraction', 'timestamp': timestamp}
)
graph.add_node(interaction)
edge = create_mock_edge(interaction, target, 'used_graph_element_to_answer')
graph.add_edge(edge)
result = await extract_usage_frequency(
subgraphs=[graph],
time_window=timedelta(days=1),
min_interaction_threshold=1
)
# Check element type frequencies
assert 'element_type_frequencies' in result
assert result['element_type_frequencies']['DocumentChunk'] == 2
assert result['element_type_frequencies']['Entity'] == 1
@pytest.mark.asyncio
async def test_add_frequency_weights():
"""Test adding frequency weights to graph via adapter."""
# Mock graph adapter
mock_adapter = AsyncMock()
mock_adapter.get_node_by_id = AsyncMock(return_value={
'id': 'node1',
'properties': {'type': 'DocumentChunk', 'text': 'content'}
})
mock_adapter.update_node_properties = AsyncMock()
# Mock usage frequencies
usage_frequencies = {
'node_frequencies': {'node1': 5, 'node2': 3},
'edge_frequencies': {},
'last_processed_timestamp': datetime.now().isoformat()
}
# Add weights
await add_frequency_weights(mock_adapter, usage_frequencies)
# Verify adapter methods were called
assert mock_adapter.get_node_by_id.call_count == 2
assert mock_adapter.update_node_properties.call_count == 2
# Verify the properties passed to update include frequency_weight
calls = mock_adapter.update_node_properties.call_args_list
properties_updated = calls[0][0][1] # Second argument of first call
assert 'frequency_weight' in properties_updated
assert properties_updated['frequency_weight'] == 5
@pytest.mark.asyncio
async def test_add_frequency_weights_node_not_found():
"""Test handling when node is not found in graph."""
mock_adapter = AsyncMock()
mock_adapter.get_node_by_id = AsyncMock(return_value=None) # Node not found
mock_adapter.update_node_properties = AsyncMock()
usage_frequencies = {
'node_frequencies': {'nonexistent_node': 5},
'edge_frequencies': {},
'last_processed_timestamp': datetime.now().isoformat()
}
# Should not crash
await add_frequency_weights(mock_adapter, usage_frequencies)
# Update should not be called since node wasn't found
assert mock_adapter.update_node_properties.call_count == 0
@pytest.mark.asyncio
async def test_add_frequency_weights_with_metadata_support():
"""Test that metadata is stored when adapter supports it."""
mock_adapter = AsyncMock()
mock_adapter.get_node_by_id = AsyncMock(return_value={'properties': {}})
mock_adapter.update_node_properties = AsyncMock()
mock_adapter.set_metadata = AsyncMock() # Adapter supports metadata
usage_frequencies = {
'node_frequencies': {'node1': 5},
'edge_frequencies': {},
'element_type_frequencies': {'DocumentChunk': 5},
'total_interactions': 10,
'interactions_in_window': 8,
'last_processed_timestamp': datetime.now().isoformat()
}
await add_frequency_weights(mock_adapter, usage_frequencies)
# Verify metadata was stored
mock_adapter.set_metadata.assert_called_once()
metadata_key, metadata_value = mock_adapter.set_metadata.call_args[0]
assert metadata_key == 'usage_frequency_stats'
assert 'total_interactions' in metadata_value
assert metadata_value['total_interactions'] == 10
@pytest.mark.asyncio
async def test_create_usage_frequency_pipeline():
"""Test pipeline creation returns correct task structure."""
mock_adapter = AsyncMock()
extraction_tasks, enrichment_tasks = await create_usage_frequency_pipeline(
graph_adapter=mock_adapter,
time_window=timedelta(days=7),
min_interaction_threshold=2,
batch_size=50
)
# Verify task structure
assert len(extraction_tasks) == 1
assert len(enrichment_tasks) == 1
# Verify extraction task
extraction_task = extraction_tasks[0]
assert hasattr(extraction_task, 'function')
# Verify enrichment task
enrichment_task = enrichment_tasks[0]
assert hasattr(enrichment_task, 'function')
@pytest.mark.asyncio
async def test_run_usage_frequency_update_integration():
"""Test the full end-to-end update process."""
# Create mock graph with interactions
mock_graph = create_interaction_graph(
interaction_count=5,
target_nodes=['node1', 'node1', 'node2', 'node3', 'node1']
)
# Mock adapter
mock_adapter = AsyncMock()
mock_adapter.get_node_by_id = AsyncMock(return_value={'properties': {}})
mock_adapter.update_node_properties = AsyncMock()
# Run the full update
stats = await run_usage_frequency_update(
graph_adapter=mock_adapter,
subgraphs=[mock_graph],
time_window=timedelta(days=1),
min_interaction_threshold=1
)
# Verify stats
assert stats['total_interactions'] == 5
assert stats['node_frequencies']['node1'] == 3
assert stats['node_frequencies']['node2'] == 1
assert stats['node_frequencies']['node3'] == 1
# Verify adapter was called to update nodes
assert mock_adapter.update_node_properties.call_count == 3 # 3 unique nodes
@pytest.mark.asyncio
async def test_extract_usage_frequency_no_used_graph_element_edges():
"""Test handling when there are interactions but no proper edges."""
graph = CogneeGraph(directed=True)
# Create interaction node
interaction = create_mock_node(
'interaction1',
{
'type': 'CogneeUserInteraction',
'timestamp': datetime.now().isoformat(),
'target_node_id': 'node1'
}
)
graph.add_node(interaction)
# Don't add any edges - interaction is orphaned
result = await extract_usage_frequency(
subgraphs=[graph],
time_window=timedelta(days=1),
min_interaction_threshold=1
)
# Should find the interaction but no frequencies (no edges)
assert result['total_interactions'] == 1
assert result['node_frequencies'] == {}
@pytest.mark.asyncio
async def test_extract_usage_frequency_alternative_timestamp_field():
"""Test that 'created_at' field works as fallback for timestamp."""
graph = CogneeGraph(directed=True)
target = create_mock_node('target1', {'type': 'DocumentChunk'})
graph.add_node(target)
# Use 'created_at' instead of 'timestamp'
interaction = create_mock_node(
'interaction1',
{
'type': 'CogneeUserInteraction',
'created_at': datetime.now().isoformat() # Alternative field
}
)
graph.add_node(interaction)
edge = create_mock_edge(interaction, target, 'used_graph_element_to_answer')
graph.add_edge(edge)
result = await extract_usage_frequency(
subgraphs=[graph],
time_window=timedelta(days=1),
min_interaction_threshold=1
)
# Should still work with created_at
assert result['total_interactions'] == 1
assert 'target1' in result['node_frequencies']
def test_imports():
"""Test that all required modules can be imported."""
from cognee.tasks.memify.extract_usage_frequency import (
extract_usage_frequency,
add_frequency_weights,
create_usage_frequency_pipeline,
run_usage_frequency_update,
)
assert extract_usage_frequency is not None
assert add_frequency_weights is not None
assert create_usage_frequency_pipeline is not None
assert run_usage_frequency_update is not None
if __name__ == "__main__":
pytest.main([__file__, "-v"])


@@ -1,49 +1,325 @@
# cognee/examples/usage_frequency_example.py
"""
End-to-end example demonstrating usage frequency tracking in Cognee.
This example shows how to:
1. Add data and build a knowledge graph
2. Run searches with save_interaction=True to track usage
3. Extract and apply frequency weights using the memify pipeline
4. Query and analyze the frequency data
The frequency weights can be used to:
- Rank frequently referenced entities higher during retrieval
- Adjust scoring for completion strategies
- Expose usage metrics in dashboards or audits
"""
import asyncio
from datetime import timedelta
from typing import List
import cognee
from cognee.api.v1.search import SearchType
from cognee.tasks.memify.extract_usage_frequency import usage_frequency_pipeline_entry
from cognee.tasks.memify.extract_usage_frequency import (
create_usage_frequency_pipeline,
run_usage_frequency_update,
)
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.shared.logging_utils import get_logger
async def main():
# Reset cognee state
logger = get_logger("usage_frequency_example")
async def setup_knowledge_base():
"""Set up a fresh knowledge base with sample data."""
logger.info("Setting up knowledge base...")
# Reset cognee state for clean slate
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
# Sample conversation
# Sample conversation about AI/ML topics
conversation = [
"Alice discusses machine learning",
"Bob asks about neural networks",
"Alice explains deep learning concepts",
"Bob wants more details about neural networks"
"Alice discusses machine learning algorithms and their applications in computer vision.",
"Bob asks about neural networks and how they differ from traditional algorithms.",
"Alice explains deep learning concepts including CNNs and transformers.",
"Bob wants more details about neural networks and backpropagation.",
"Alice describes reinforcement learning and its use in robotics.",
"Bob inquires about natural language processing and transformers.",
]
# Add conversation and cognify
await cognee.add(conversation)
# Add conversation data and build knowledge graph
logger.info("Adding conversation data...")
await cognee.add(conversation, dataset_name="ai_ml_conversation")
logger.info("Building knowledge graph (cognify)...")
await cognee.cognify()
logger.info("Knowledge base setup complete")
# Perform some searches to generate interactions
for query in ["machine learning", "neural networks", "deep learning"]:
await cognee.search(
async def simulate_user_searches():
"""Simulate multiple user searches to generate interaction data."""
logger.info("Simulating user searches with save_interaction=True...")
# Different queries that will create CogneeUserInteraction nodes
queries = [
"What is machine learning?",
"Explain neural networks",
"Tell me about deep learning",
"What are neural networks?", # Repeat to increase frequency
"How does machine learning work?",
"Describe transformers in NLP",
"What is reinforcement learning?",
"Explain neural networks again", # Another repeat
]
search_count = 0
for query in queries:
try:
logger.info(f"Searching: '{query}'")
results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text=query,
save_interaction=True, # Critical: saves interaction to graph
top_k=5
)
search_count += 1
logger.debug(f"Search completed, got {len(results) if results else 0} results")
except Exception as e:
logger.warning(f"Search failed for '{query}': {e}")
logger.info(f"Completed {search_count} searches with interactions saved")
return search_count
async def retrieve_interaction_graph() -> List[CogneeGraph]:
"""Retrieve the graph containing interaction nodes."""
logger.info("Retrieving graph with interaction data...")
graph_engine = await get_graph_engine()
graph = CogneeGraph()
# Project the full graph including CogneeUserInteraction nodes
await graph.project_graph_from_db(
adapter=graph_engine,
node_properties_to_project=["type", "node_type", "timestamp", "created_at", "text", "name"],
edge_properties_to_project=["relationship_type", "timestamp", "created_at"],
directed=True,
)
logger.info(f"Retrieved graph: {len(graph.nodes)} nodes, {len(graph.edges)} edges")
# Count interaction nodes for verification
interaction_count = sum(
1 for node in graph.nodes.values()
if node.attributes.get('type') == 'CogneeUserInteraction' or
node.attributes.get('node_type') == 'CogneeUserInteraction'
)
logger.info(f"Found {interaction_count} CogneeUserInteraction nodes in graph")
return [graph]
async def run_frequency_pipeline_method1():
"""Method 1: Using the pipeline creation function."""
logger.info("\n=== Method 1: Using create_usage_frequency_pipeline ===")
graph_engine = await get_graph_engine()
subgraphs = await retrieve_interaction_graph()
# Create the pipeline tasks
extraction_tasks, enrichment_tasks = await create_usage_frequency_pipeline(
graph_adapter=graph_engine,
time_window=timedelta(days=30), # Last 30 days
min_interaction_threshold=1, # Count all interactions
batch_size=100
)
logger.info("Running extraction tasks...")
# Note: In real memify pipeline, these would be executed by the pipeline runner
# For this example, we'll execute them manually
for task in extraction_tasks:
if hasattr(task, 'function'):
result = await task.function(
subgraphs=subgraphs,
time_window=timedelta(days=30),
min_interaction_threshold=1
)
logger.info(f"Extraction result: {result.get('interactions_in_window')} interactions processed")
logger.info("Running enrichment tasks...")
for task in enrichment_tasks:
if hasattr(task, 'function'):
await task.function(
graph_adapter=graph_engine,
usage_frequencies=result
)
return result
async def run_frequency_pipeline_method2():
"""Method 2: Using the convenience function."""
logger.info("\n=== Method 2: Using run_usage_frequency_update ===")
graph_engine = await get_graph_engine()
subgraphs = await retrieve_interaction_graph()
# Run the complete pipeline in one call
stats = await run_usage_frequency_update(
graph_adapter=graph_engine,
subgraphs=subgraphs,
time_window=timedelta(days=30),
min_interaction_threshold=1
)
logger.info("Frequency update statistics:")
logger.info(f" Total interactions: {stats['total_interactions']}")
logger.info(f" Interactions in window: {stats['interactions_in_window']}")
logger.info(f" Nodes with frequency weights: {len(stats['node_frequencies'])}")
logger.info(f" Element types: {stats.get('element_type_frequencies', {})}")
return stats
async def analyze_frequency_weights():
"""Analyze and display the frequency weights that were added."""
logger.info("\n=== Analyzing Frequency Weights ===")
graph_engine = await get_graph_engine()
graph = CogneeGraph()
# Project graph with frequency weights
await graph.project_graph_from_db(
adapter=graph_engine,
node_properties_to_project=[
"type",
"node_type",
"text",
"name",
"frequency_weight", # Our added property
"frequency_updated_at"
],
edge_properties_to_project=["relationship_type"],
directed=True,
)
# Find nodes with frequency weights
weighted_nodes = []
for node_id, node in graph.nodes.items():
freq_weight = node.attributes.get('frequency_weight')
if freq_weight is not None:
weighted_nodes.append({
'id': node_id,
'type': node.attributes.get('type') or node.attributes.get('node_type'),
'text': node.attributes.get('text', '')[:100], # First 100 chars
'name': node.attributes.get('name', ''),
'frequency_weight': freq_weight,
'updated_at': node.attributes.get('frequency_updated_at')
})
# Sort by frequency (descending)
weighted_nodes.sort(key=lambda x: x['frequency_weight'], reverse=True)
logger.info(f"\nFound {len(weighted_nodes)} nodes with frequency weights:")
logger.info("\nTop 10 Most Frequently Referenced Elements:")
logger.info("-" * 80)
for i, node in enumerate(weighted_nodes[:10], 1):
logger.info(f"\n{i}. Frequency: {node['frequency_weight']}")
logger.info(f" Type: {node['type']}")
logger.info(f" Name: {node['name']}")
logger.info(f" Text: {node['text']}")
logger.info(f" ID: {node['id'][:50]}...")
return weighted_nodes
async def demonstrate_retrieval_with_frequencies():
"""Demonstrate how frequency weights can be used in retrieval."""
logger.info("\n=== Demonstrating Retrieval with Frequency Weights ===")
# This is a conceptual demonstration of how frequency weights
# could be used to boost search results
query = "neural networks"
logger.info(f"Searching for: '{query}'")
try:
# Standard search
standard_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text=query,
save_interaction=True
save_interaction=False, # Don't add more interactions
top_k=5
)
logger.info(f"Standard search returned {len(standard_results) if standard_results else 0} results")
# Note: To actually use frequency_weight in scoring, you would need to:
# 1. Modify the retrieval/ranking logic to consider frequency_weight
# 2. Add frequency_weight as a scoring factor in the completion strategy
# 3. Use it in analytics dashboards to show popular topics
logger.info("\nFrequency weights can now be used for:")
logger.info(" - Boosting frequently-accessed nodes in search rankings")
logger.info(" - Adjusting triplet importance scores")
logger.info(" - Building usage analytics dashboards")
logger.info(" - Identifying 'hot' topics in the knowledge graph")
except Exception as e:
logger.warning(f"Demonstration search failed: {e}")
# Run usage frequency tracking
await cognee.memify(
*usage_frequency_pipeline_entry(cognee.graph_adapter)
)
# Search and display frequency weights
results = await cognee.search(
query_text="Find nodes with frequency weights",
query_type=SearchType.NODE_PROPERTIES,
properties=["frequency_weight"]
)
async def main():
"""Main execution flow."""
logger.info("=" * 80)
logger.info("Usage Frequency Tracking Example")
logger.info("=" * 80)
try:
# Step 1: Setup knowledge base
await setup_knowledge_base()
# Step 2: Simulate user searches with save_interaction=True
search_count = await simulate_user_searches()
if search_count == 0:
logger.warning("No searches completed - cannot demonstrate frequency tracking")
return
# Step 3: Run frequency extraction and enrichment
# You can use either method - both accomplish the same thing
# Option A: Using the convenience function (recommended)
stats = await run_frequency_pipeline_method2()
# Option B: Using the pipeline creation function (for custom pipelines)
# stats = await run_frequency_pipeline_method1()
# Step 4: Analyze the results
weighted_nodes = await analyze_frequency_weights()
# Step 5: Demonstrate retrieval usage
await demonstrate_retrieval_with_frequencies()
# Summary
logger.info("\n" + "=" * 80)
logger.info("SUMMARY")
logger.info("=" * 80)
logger.info(f"Searches performed: {search_count}")
logger.info(f"Interactions tracked: {stats.get('interactions_in_window', 0)}")
logger.info(f"Nodes weighted: {len(weighted_nodes)}")
logger.info(f"Time window: {stats.get('time_window_days', 0)} days")
logger.info("\nFrequency weights have been added to the graph!")
logger.info("These can now be used in retrieval, ranking, and analytics.")
logger.info("=" * 80)
except Exception as e:
logger.error(f"Example failed: {e}", exc_info=True)
raise
print("Nodes with Frequency Weights:")
for result in results[0]["search_result"][0]:
print(result)
if __name__ == "__main__":
asyncio.run(main())