refactor: add test and example script

This commit is contained in:
Christina_Raichel_Francis 2025-12-17 18:02:35 +00:00
parent ee29dd1f81
commit 931c5f3096
3 changed files with 189 additions and 4 deletions

View file

@ -1,7 +1,101 @@
# cognee/tasks/memify/extract_usage_frequency.py
from typing import List, Dict, Any
from datetime import datetime, timedelta
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.pipelines.tasks.task import Task
async def extract_subgraph(subgraphs: list[CogneeGraph]):
async def extract_usage_frequency(
subgraphs: List[CogneeGraph],
time_window: timedelta = timedelta(days=7),
min_interaction_threshold: int = 1
) -> Dict[str, Any]:
"""
Extract usage frequency from CogneeUserInteraction nodes
:param subgraphs: List of graph subgraphs
:param time_window: Time window to consider for interactions
:param min_interaction_threshold: Minimum interactions to track
:return: Dictionary of usage frequencies
"""
current_time = datetime.now()
node_frequencies = {}
edge_frequencies = {}
for subgraph in subgraphs:
for edge in subgraph.edges:
yield edge
# Filter CogneeUserInteraction nodes within time window
user_interactions = [
interaction for interaction in subgraph.nodes
if (interaction.get('type') == 'CogneeUserInteraction' and
current_time - datetime.fromisoformat(interaction.get('timestamp', current_time.isoformat())) <= time_window)
]
# Count node and edge frequencies
for interaction in user_interactions:
target_node_id = interaction.get('target_node_id')
edge_type = interaction.get('edge_type')
if target_node_id:
node_frequencies[target_node_id] = node_frequencies.get(target_node_id, 0) + 1
if edge_type:
edge_frequencies[edge_type] = edge_frequencies.get(edge_type, 0) + 1
# Filter frequencies above threshold
filtered_node_frequencies = {
node_id: freq for node_id, freq in node_frequencies.items()
if freq >= min_interaction_threshold
}
filtered_edge_frequencies = {
edge_type: freq for edge_type, freq in edge_frequencies.items()
if freq >= min_interaction_threshold
}
return {
'node_frequencies': filtered_node_frequencies,
'edge_frequencies': filtered_edge_frequencies,
'last_processed_timestamp': current_time.isoformat()
}
async def add_frequency_weights(
graph_adapter,
usage_frequencies: Dict[str, Any]
) -> None:
"""
Add frequency weights to graph nodes and edges
:param graph_adapter: Graph database adapter
:param usage_frequencies: Calculated usage frequencies
"""
# Update node frequencies
for node_id, frequency in usage_frequencies['node_frequencies'].items():
try:
node = graph_adapter.get_node(node_id)
if node:
node_properties = node.get_properties() or {}
node_properties['frequency_weight'] = frequency
graph_adapter.update_node(node_id, node_properties)
except Exception as e:
print(f"Error updating node {node_id}: {e}")
# Note: Edge frequency update might require backend-specific implementation
print("Edge frequency update might need backend-specific handling")
def usage_frequency_pipeline_entry(graph_adapter):
"""
Memify pipeline entry for usage frequency tracking
:param graph_adapter: Graph database adapter
:return: Usage frequency results
"""
extraction_tasks = [
Task(extract_usage_frequency,
time_window=timedelta(days=7),
min_interaction_threshold=1)
]
enrichment_tasks = [
Task(add_frequency_weights, task_config={"batch_size": 1})
]
return extraction_tasks, enrichment_tasks

View file

@ -0,0 +1,42 @@
# cognee/tests/test_usage_frequency.py
import pytest
import asyncio
from datetime import datetime, timedelta
from cognee.tasks.memify.extract_usage_frequency import extract_usage_frequency, add_frequency_weights
@pytest.mark.asyncio
async def test_extract_usage_frequency():
# Mock CogneeGraph with user interactions
mock_subgraphs = [{
'nodes': [
{
'type': 'CogneeUserInteraction',
'target_node_id': 'node1',
'edge_type': 'viewed',
'timestamp': datetime.now().isoformat()
},
{
'type': 'CogneeUserInteraction',
'target_node_id': 'node1',
'edge_type': 'viewed',
'timestamp': datetime.now().isoformat()
},
{
'type': 'CogneeUserInteraction',
'target_node_id': 'node2',
'edge_type': 'referenced',
'timestamp': datetime.now().isoformat()
}
]
}]
# Test frequency extraction
result = await extract_usage_frequency(
mock_subgraphs,
time_window=timedelta(days=1),
min_interaction_threshold=1
)
assert 'node1' in result['node_frequencies']
assert result['node_frequencies']['node1'] == 2
assert result['edge_frequencies']['viewed'] == 2

View file

@ -0,0 +1,49 @@
# cognee/examples/usage_frequency_example.py
import asyncio
import cognee
from cognee.api.v1.search import SearchType
from cognee.tasks.memify.extract_usage_frequency import usage_frequency_pipeline_entry
async def main():
# Reset cognee state
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
# Sample conversation
conversation = [
"Alice discusses machine learning",
"Bob asks about neural networks",
"Alice explains deep learning concepts",
"Bob wants more details about neural networks"
]
# Add conversation and cognify
await cognee.add(conversation)
await cognee.cognify()
# Perform some searches to generate interactions
for query in ["machine learning", "neural networks", "deep learning"]:
await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text=query,
save_interaction=True
)
# Run usage frequency tracking
await cognee.memify(
*usage_frequency_pipeline_entry(cognee.graph_adapter)
)
# Search and display frequency weights
results = await cognee.search(
query_text="Find nodes with frequency weights",
query_type=SearchType.NODE_PROPERTIES,
properties=["frequency_weight"]
)
print("Nodes with Frequency Weights:")
for result in results[0]["search_result"][0]:
print(result)
if __name__ == "__main__":
asyncio.run(main())