chore: ruff format and refactor on contributor PR

Igor Ilic 2026-01-13 15:10:21 +01:00
parent 0d2f66fa1d
commit dce51efbe3
3 changed files with 408 additions and 387 deletions

View file: cognee/tasks/memify/extract_usage_frequency.py

@@ -10,20 +10,20 @@ logger = get_logger("extract_usage_frequency")
async def extract_usage_frequency(
subgraphs: List[CogneeGraph],
time_window: timedelta = timedelta(days=7),
min_interaction_threshold: int = 1
min_interaction_threshold: int = 1,
) -> Dict[str, Any]:
"""
Extract usage frequency from CogneeUserInteraction nodes.
When save_interaction=True in cognee.search(), the system creates:
- CogneeUserInteraction nodes (representing the query/answer interaction)
- used_graph_element_to_answer edges (connecting interactions to graph elements used)
This function tallies how often each graph element is referenced via these edges,
enabling frequency-based ranking in downstream retrievers.
:param subgraphs: List of CogneeGraph instances containing interaction data
:param time_window: Time window to consider for interactions (default: 7 days)
:param min_interaction_threshold: Minimum interactions to track (default: 1)
@@ -31,33 +31,35 @@ async def extract_usage_frequency(
"""
current_time = datetime.now()
cutoff_time = current_time - time_window
# Track frequencies for graph elements (nodes and edges)
node_frequencies = {}
edge_frequencies = {}
relationship_type_frequencies = {}
# Track interaction metadata
interaction_count = 0
interactions_in_window = 0
logger.info(f"Extracting usage frequencies from {len(subgraphs)} subgraphs")
logger.info(f"Time window: {time_window}, Cutoff: {cutoff_time.isoformat()}")
for subgraph in subgraphs:
# Find all CogneeUserInteraction nodes
interaction_nodes = {}
for node_id, node in subgraph.nodes.items():
node_type = node.attributes.get('type') or node.attributes.get('node_type')
if node_type == 'CogneeUserInteraction':
node_type = node.attributes.get("type") or node.attributes.get("node_type")
if node_type == "CogneeUserInteraction":
# Parse and validate timestamp
timestamp_value = node.attributes.get('timestamp') or node.attributes.get('created_at')
timestamp_value = node.attributes.get("timestamp") or node.attributes.get(
"created_at"
)
if timestamp_value is not None:
try:
# Handle various timestamp formats
interaction_time = None
if isinstance(timestamp_value, datetime):
# Already a Python datetime
interaction_time = timestamp_value
@@ -81,24 +83,24 @@ async def extract_usage_frequency(
else:
# ISO format string
interaction_time = datetime.fromisoformat(timestamp_value)
elif hasattr(timestamp_value, 'to_native'):
elif hasattr(timestamp_value, "to_native"):
# Neo4j datetime object - convert to Python datetime
interaction_time = timestamp_value.to_native()
elif hasattr(timestamp_value, 'year') and hasattr(timestamp_value, 'month'):
elif hasattr(timestamp_value, "year") and hasattr(timestamp_value, "month"):
# Datetime-like object - extract components
try:
interaction_time = datetime(
year=timestamp_value.year,
month=timestamp_value.month,
day=timestamp_value.day,
hour=getattr(timestamp_value, 'hour', 0),
minute=getattr(timestamp_value, 'minute', 0),
second=getattr(timestamp_value, 'second', 0),
microsecond=getattr(timestamp_value, 'microsecond', 0)
hour=getattr(timestamp_value, "hour", 0),
minute=getattr(timestamp_value, "minute", 0),
second=getattr(timestamp_value, "second", 0),
microsecond=getattr(timestamp_value, "microsecond", 0),
)
except (AttributeError, ValueError):
pass
if interaction_time is None:
# Last resort: try converting to string and parsing
str_value = str(timestamp_value)
@@ -110,73 +112,83 @@ async def extract_usage_frequency(
interaction_time = datetime.fromtimestamp(ts_int)
else:
interaction_time = datetime.fromisoformat(str_value)
if interaction_time is None:
raise ValueError(f"Could not parse timestamp: {timestamp_value}")
# Make sure it's timezone-naive for comparison
if interaction_time.tzinfo is not None:
interaction_time = interaction_time.replace(tzinfo=None)
interaction_nodes[node_id] = {
'node': node,
'timestamp': interaction_time,
'in_window': interaction_time >= cutoff_time
"node": node,
"timestamp": interaction_time,
"in_window": interaction_time >= cutoff_time,
}
interaction_count += 1
if interaction_time >= cutoff_time:
interactions_in_window += 1
except (ValueError, TypeError, AttributeError, OSError) as e:
logger.warning(f"Failed to parse timestamp for interaction node {node_id}: {e}")
logger.debug(f"Timestamp value type: {type(timestamp_value)}, value: {timestamp_value}")
logger.warning(
f"Failed to parse timestamp for interaction node {node_id}: {e}"
)
logger.debug(
f"Timestamp value type: {type(timestamp_value)}, value: {timestamp_value}"
)
# Process edges to find graph elements used in interactions
for edge in subgraph.edges:
relationship_type = edge.attributes.get('relationship_type')
relationship_type = edge.attributes.get("relationship_type")
# Look for 'used_graph_element_to_answer' edges
if relationship_type == 'used_graph_element_to_answer':
if relationship_type == "used_graph_element_to_answer":
# node1 should be the CogneeUserInteraction, node2 is the graph element
source_id = str(edge.node1.id)
target_id = str(edge.node2.id)
# Check if source is an interaction node in our time window
if source_id in interaction_nodes:
interaction_data = interaction_nodes[source_id]
if interaction_data['in_window']:
if interaction_data["in_window"]:
# Count the graph element (target node) being used
node_frequencies[target_id] = node_frequencies.get(target_id, 0) + 1
# Also track what type of element it is for analytics
target_node = subgraph.get_node(target_id)
if target_node:
element_type = target_node.attributes.get('type') or target_node.attributes.get('node_type')
element_type = target_node.attributes.get(
"type"
) or target_node.attributes.get("node_type")
if element_type:
relationship_type_frequencies[element_type] = relationship_type_frequencies.get(element_type, 0) + 1
relationship_type_frequencies[element_type] = (
relationship_type_frequencies.get(element_type, 0) + 1
)
# Also track general edge usage patterns
elif relationship_type and relationship_type != 'used_graph_element_to_answer':
elif relationship_type and relationship_type != "used_graph_element_to_answer":
# Check if either endpoint is referenced in a recent interaction
source_id = str(edge.node1.id)
target_id = str(edge.node2.id)
# If this edge connects to any frequently accessed nodes, track the edge type
if source_id in node_frequencies or target_id in node_frequencies:
edge_key = f"{relationship_type}:{source_id}:{target_id}"
edge_frequencies[edge_key] = edge_frequencies.get(edge_key, 0) + 1
# Filter frequencies above threshold
filtered_node_frequencies = {
node_id: freq for node_id, freq in node_frequencies.items()
node_id: freq
for node_id, freq in node_frequencies.items()
if freq >= min_interaction_threshold
}
filtered_edge_frequencies = {
edge_key: freq for edge_key, freq in edge_frequencies.items()
edge_key: freq
for edge_key, freq in edge_frequencies.items()
if freq >= min_interaction_threshold
}
logger.info(
f"Processed {interactions_in_window}/{interaction_count} interactions in time window"
)
@@ -185,58 +197,59 @@ async def extract_usage_frequency(
f"above threshold (min: {min_interaction_threshold})"
)
logger.info(f"Element type distribution: {relationship_type_frequencies}")
return {
'node_frequencies': filtered_node_frequencies,
'edge_frequencies': filtered_edge_frequencies,
'element_type_frequencies': relationship_type_frequencies,
'total_interactions': interaction_count,
'interactions_in_window': interactions_in_window,
'time_window_days': time_window.days,
'last_processed_timestamp': current_time.isoformat(),
'cutoff_timestamp': cutoff_time.isoformat()
"node_frequencies": filtered_node_frequencies,
"edge_frequencies": filtered_edge_frequencies,
"element_type_frequencies": relationship_type_frequencies,
"total_interactions": interaction_count,
"interactions_in_window": interactions_in_window,
"time_window_days": time_window.days,
"last_processed_timestamp": current_time.isoformat(),
"cutoff_timestamp": cutoff_time.isoformat(),
}
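For orientation, here is a minimal calling sketch for the extractor above. It assumes a CogneeGraph has already been projected (as the example script later in this commit does) and reads only keys that the return dict above actually contains:

async def print_frequency_summary(graph: CogneeGraph) -> None:
    # Sketch only: run the extractor on one projected graph and report
    # the most-used elements, most frequent first.
    stats = await extract_usage_frequency(
        subgraphs=[graph],
        time_window=timedelta(days=7),
        min_interaction_threshold=1,
    )
    print(f"{stats['interactions_in_window']}/{stats['total_interactions']} interactions in window")
    for node_id, freq in sorted(
        stats["node_frequencies"].items(), key=lambda kv: kv[1], reverse=True
    ):
        print(f"{node_id}: used {freq}x")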
async def add_frequency_weights(
graph_adapter: GraphDBInterface,
usage_frequencies: Dict[str, Any]
graph_adapter: GraphDBInterface, usage_frequencies: Dict[str, Any]
) -> None:
"""
Add frequency weights to graph nodes and edges using the graph adapter.
Uses direct Cypher queries for Neo4j adapter compatibility.
Writes frequency_weight properties back to the graph for use in:
- Ranking frequently referenced entities higher during retrieval
- Adjusting scoring for completion strategies
- Exposing usage metrics in dashboards or audits
:param graph_adapter: Graph database adapter interface
:param usage_frequencies: Calculated usage frequencies from extract_usage_frequency
"""
node_frequencies = usage_frequencies.get('node_frequencies', {})
edge_frequencies = usage_frequencies.get('edge_frequencies', {})
node_frequencies = usage_frequencies.get("node_frequencies", {})
edge_frequencies = usage_frequencies.get("edge_frequencies", {})
logger.info(f"Adding frequency weights to {len(node_frequencies)} nodes")
# Check adapter type and use appropriate method
adapter_type = type(graph_adapter).__name__
logger.info(f"Using adapter: {adapter_type}")
nodes_updated = 0
nodes_failed = 0
# Determine which method to use based on adapter type
use_neo4j_cypher = adapter_type == 'Neo4jAdapter' and hasattr(graph_adapter, 'query')
use_kuzu_query = adapter_type == 'KuzuAdapter' and hasattr(graph_adapter, 'query')
use_get_update = hasattr(graph_adapter, 'get_node_by_id') and hasattr(graph_adapter, 'update_node_properties')
use_neo4j_cypher = adapter_type == "Neo4jAdapter" and hasattr(graph_adapter, "query")
use_kuzu_query = adapter_type == "KuzuAdapter" and hasattr(graph_adapter, "query")
use_get_update = hasattr(graph_adapter, "get_node_by_id") and hasattr(
graph_adapter, "update_node_properties"
)
# Method 1: Neo4j Cypher with SET (creates properties on the fly)
if use_neo4j_cypher:
try:
logger.info("Using Neo4j Cypher SET method")
last_updated = usage_frequencies.get('last_processed_timestamp')
last_updated = usage_frequencies.get("last_processed_timestamp")
for node_id, frequency in node_frequencies.items():
try:
query = """
@@ -246,47 +259,49 @@ async def add_frequency_weights(
n.frequency_updated_at = $updated_at
RETURN n.id as id
"""
result = await graph_adapter.query(
query,
params={
'node_id': node_id,
'frequency': frequency,
'updated_at': last_updated
}
"node_id": node_id,
"frequency": frequency,
"updated_at": last_updated,
},
)
if result and len(result) > 0:
nodes_updated += 1
else:
logger.warning(f"Node {node_id} not found or not updated")
nodes_failed += 1
except Exception as e:
logger.error(f"Error updating node {node_id}: {e}")
nodes_failed += 1
logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed")
except Exception as e:
logger.error(f"Neo4j Cypher update failed: {e}")
use_neo4j_cypher = False
# Method 2: Kuzu - use get_node + add_node (updates via re-adding with same ID)
elif use_kuzu_query and hasattr(graph_adapter, 'get_node') and hasattr(graph_adapter, 'add_node'):
elif (
use_kuzu_query and hasattr(graph_adapter, "get_node") and hasattr(graph_adapter, "add_node")
):
logger.info("Using Kuzu get_node + add_node method")
last_updated = usage_frequencies.get('last_processed_timestamp')
last_updated = usage_frequencies.get("last_processed_timestamp")
for node_id, frequency in node_frequencies.items():
try:
# Get the existing node (returns a dict)
existing_node_dict = await graph_adapter.get_node(node_id)
if existing_node_dict:
# Update the dict with new properties
existing_node_dict['frequency_weight'] = frequency
existing_node_dict['frequency_updated_at'] = last_updated
existing_node_dict["frequency_weight"] = frequency
existing_node_dict["frequency_updated_at"] = last_updated
# Kuzu's add_node likely just takes the dict directly, not a Node object
# Try passing the dict directly first
try:
@@ -295,20 +310,21 @@ async def add_frequency_weights(
except Exception as dict_error:
# If dict doesn't work, try creating a Node object
logger.debug(f"Dict add failed, trying Node object: {dict_error}")
try:
from cognee.infrastructure.engine import Node
# Try different Node constructor patterns
try:
# Pattern 1: Just properties
node_obj = Node(existing_node_dict)
except:
except Exception:
# Pattern 2: Type and properties
node_obj = Node(
type=existing_node_dict.get('type', 'Unknown'),
**existing_node_dict
type=existing_node_dict.get("type", "Unknown"),
**existing_node_dict,
)
await graph_adapter.add_node(node_obj)
nodes_updated += 1
except Exception as node_error:
@@ -317,13 +333,13 @@ async def add_frequency_weights(
else:
logger.warning(f"Node {node_id} not found in graph")
nodes_failed += 1
except Exception as e:
logger.error(f"Error updating node {node_id}: {e}")
nodes_failed += 1
logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed")
# Method 3: Generic get_node_by_id + update_node_properties
elif use_get_update:
logger.info("Using get/update method for adapter")
@@ -331,90 +347,95 @@ async def add_frequency_weights(
try:
# Get current node data
node_data = await graph_adapter.get_node_by_id(node_id)
if node_data:
# Tweak the properties dict - add frequency_weight
if isinstance(node_data, dict):
properties = node_data.get('properties', {})
properties = node_data.get("properties", {})
else:
properties = getattr(node_data, 'properties', {}) or {}
properties = getattr(node_data, "properties", {}) or {}
# Update with frequency weight
properties['frequency_weight'] = frequency
properties['frequency_updated_at'] = usage_frequencies.get('last_processed_timestamp')
properties["frequency_weight"] = frequency
properties["frequency_updated_at"] = usage_frequencies.get(
"last_processed_timestamp"
)
# Write back via adapter
await graph_adapter.update_node_properties(node_id, properties)
nodes_updated += 1
else:
logger.warning(f"Node {node_id} not found in graph")
nodes_failed += 1
except Exception as e:
logger.error(f"Error updating node {node_id}: {e}")
nodes_failed += 1
logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed")
for node_id, frequency in node_frequencies.items():
try:
# Get current node data
node_data = await graph_adapter.get_node_by_id(node_id)
if node_data:
# Tweak the properties dict - add frequency_weight
if isinstance(node_data, dict):
properties = node_data.get('properties', {})
properties = node_data.get("properties", {})
else:
properties = getattr(node_data, 'properties', {}) or {}
properties = getattr(node_data, "properties", {}) or {}
# Update with frequency weight
properties['frequency_weight'] = frequency
properties['frequency_updated_at'] = usage_frequencies.get('last_processed_timestamp')
properties["frequency_weight"] = frequency
properties["frequency_updated_at"] = usage_frequencies.get(
"last_processed_timestamp"
)
# Write back via adapter
await graph_adapter.update_node_properties(node_id, properties)
nodes_updated += 1
else:
logger.warning(f"Node {node_id} not found in graph")
nodes_failed += 1
except Exception as e:
logger.error(f"Error updating node {node_id}: {e}")
nodes_failed += 1
# If no method is available
if not use_neo4j_cypher and not use_kuzu_query and not use_get_update:
logger.error(f"Adapter {adapter_type} does not support required update methods")
logger.error("Required: either 'query' method or both 'get_node_by_id' and 'update_node_properties'")
logger.error(
"Required: either 'query' method or both 'get_node_by_id' and 'update_node_properties'"
)
return
# Update edge frequencies
# Note: Edge property updates are backend-specific
if edge_frequencies:
logger.info(f"Processing {len(edge_frequencies)} edge frequency entries")
edges_updated = 0
edges_failed = 0
for edge_key, frequency in edge_frequencies.items():
try:
# Parse edge key: "relationship_type:source_id:target_id"
parts = edge_key.split(':', 2)
parts = edge_key.split(":", 2)
if len(parts) == 3:
relationship_type, source_id, target_id = parts
# Try to update edge if adapter supports it
if hasattr(graph_adapter, 'update_edge_properties'):
if hasattr(graph_adapter, "update_edge_properties"):
edge_properties = {
'frequency_weight': frequency,
'frequency_updated_at': usage_frequencies.get('last_processed_timestamp')
"frequency_weight": frequency,
"frequency_updated_at": usage_frequencies.get(
"last_processed_timestamp"
),
}
await graph_adapter.update_edge_properties(
source_id,
target_id,
relationship_type,
edge_properties
source_id, target_id, relationship_type, edge_properties
)
edges_updated += 1
else:
@@ -423,28 +444,28 @@ async def add_frequency_weights(
f"Adapter doesn't support update_edge_properties for "
f"{relationship_type} ({source_id} -> {target_id})"
)
except Exception as e:
logger.error(f"Error updating edge {edge_key}: {e}")
edges_failed += 1
if edges_updated > 0:
logger.info(f"Edge update complete: {edges_updated} succeeded, {edges_failed} failed")
else:
logger.info(
"Edge frequency updates skipped (adapter may not support edge property updates)"
)
# Store aggregate statistics as metadata if supported
if hasattr(graph_adapter, 'set_metadata'):
if hasattr(graph_adapter, "set_metadata"):
try:
metadata = {
'element_type_frequencies': usage_frequencies.get('element_type_frequencies', {}),
'total_interactions': usage_frequencies.get('total_interactions', 0),
'interactions_in_window': usage_frequencies.get('interactions_in_window', 0),
'last_frequency_update': usage_frequencies.get('last_processed_timestamp')
"element_type_frequencies": usage_frequencies.get("element_type_frequencies", {}),
"total_interactions": usage_frequencies.get("total_interactions", 0),
"interactions_in_window": usage_frequencies.get("interactions_in_window", 0),
"last_frequency_update": usage_frequencies.get("last_processed_timestamp"),
}
await graph_adapter.set_metadata('usage_frequency_stats', metadata)
await graph_adapter.set_metadata("usage_frequency_stats", metadata)
logger.info("Stored usage frequency statistics as metadata")
except Exception as e:
logger.warning(f"Could not store usage statistics as metadata: {e}")
@@ -454,25 +475,25 @@ async def create_usage_frequency_pipeline(
graph_adapter: GraphDBInterface,
time_window: timedelta = timedelta(days=7),
min_interaction_threshold: int = 1,
batch_size: int = 100
batch_size: int = 100,
) -> tuple:
"""
Create memify pipeline entry for usage frequency tracking.
This follows the same pattern as feedback enrichment flows, allowing
the frequency update to run end-to-end in a custom memify pipeline.
Use case example:
extraction_tasks, enrichment_tasks = await create_usage_frequency_pipeline(
graph_adapter=my_adapter,
time_window=timedelta(days=30),
min_interaction_threshold=2
)
# Run in memify pipeline
pipeline = Pipeline(extraction_tasks + enrichment_tasks)
results = await pipeline.run()
:param graph_adapter: Graph database adapter
:param time_window: Time window for counting interactions (default: 7 days)
:param min_interaction_threshold: Minimum interactions to track (default: 1)
@@ -481,23 +502,23 @@ async def create_usage_frequency_pipeline(
"""
logger.info("Creating usage frequency pipeline")
logger.info(f"Config: time_window={time_window}, threshold={min_interaction_threshold}")
extraction_tasks = [
Task(
extract_usage_frequency,
time_window=time_window,
min_interaction_threshold=min_interaction_threshold
min_interaction_threshold=min_interaction_threshold,
)
]
enrichment_tasks = [
Task(
add_frequency_weights,
graph_adapter=graph_adapter,
task_config={"batch_size": batch_size}
task_config={"batch_size": batch_size},
)
]
return extraction_tasks, enrichment_tasks
@@ -505,21 +526,21 @@ async def run_usage_frequency_update(
graph_adapter: GraphDBInterface,
subgraphs: List[CogneeGraph],
time_window: timedelta = timedelta(days=7),
min_interaction_threshold: int = 1
min_interaction_threshold: int = 1,
) -> Dict[str, Any]:
"""
Convenience function to run the complete usage frequency update pipeline.
This is the main entry point for updating frequency weights on graph elements
based on CogneeUserInteraction data from cognee.search(save_interaction=True).
Example usage:
# After running searches with save_interaction=True
from cognee.tasks.memify.extract_usage_frequency import run_usage_frequency_update
# Get the graph with interactions
graph = await get_cognee_graph_with_interactions()
# Update frequency weights
stats = await run_usage_frequency_update(
graph_adapter=graph_adapter,
@@ -527,9 +548,9 @@ async def run_usage_frequency_update(
time_window=timedelta(days=30), # Last 30 days
min_interaction_threshold=2 # At least 2 uses
)
print(f"Updated {len(stats['node_frequencies'])} nodes")
:param graph_adapter: Graph database adapter
:param subgraphs: List of CogneeGraph instances with interaction data
:param time_window: Time window for counting interactions
@@ -537,51 +558,48 @@ async def run_usage_frequency_update(
:return: Usage frequency statistics
"""
logger.info("Starting usage frequency update")
try:
# Extract frequencies from interaction data
usage_frequencies = await extract_usage_frequency(
subgraphs=subgraphs,
time_window=time_window,
min_interaction_threshold=min_interaction_threshold
min_interaction_threshold=min_interaction_threshold,
)
# Add frequency weights back to the graph
await add_frequency_weights(
graph_adapter=graph_adapter,
usage_frequencies=usage_frequencies
graph_adapter=graph_adapter, usage_frequencies=usage_frequencies
)
logger.info("Usage frequency update completed successfully")
logger.info(
f"Summary: {usage_frequencies['interactions_in_window']} interactions processed, "
f"{len(usage_frequencies['node_frequencies'])} nodes weighted"
)
return usage_frequencies
except Exception as e:
logger.error(f"Error during usage frequency update: {str(e)}")
raise
async def get_most_frequent_elements(
graph_adapter: GraphDBInterface,
top_n: int = 10,
element_type: Optional[str] = None
graph_adapter: GraphDBInterface, top_n: int = 10, element_type: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
Retrieve the most frequently accessed graph elements.
Useful for analytics dashboards and understanding user behavior.
:param graph_adapter: Graph database adapter
:param top_n: Number of top elements to return
:param element_type: Optional filter by element type
:return: List of elements with their frequency weights
"""
logger.info(f"Retrieving top {top_n} most frequent elements")
# This would need to be implemented based on the specific graph adapter's query capabilities
# Pseudocode:
# results = await graph_adapter.query_nodes_by_property(
@@ -590,6 +608,6 @@ async def get_most_frequent_elements(
# limit=top_n,
# filters={'type': element_type} if element_type else None
# )
logger.warning("get_most_frequent_elements needs adapter-specific implementation")
return []
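Since get_most_frequent_elements is left as pseudocode above, here is one possible Neo4j-only implementation; a sketch assuming the Neo4jAdapter.query(query, params=...) signature used earlier in this file, and that weighted nodes carry id, type, and frequency_weight properties:

async def get_most_frequent_elements_neo4j(
    graph_adapter: GraphDBInterface, top_n: int = 10, element_type: Optional[str] = None
) -> List[Dict[str, Any]]:
    # Sketch for Neo4j only; other backends need their own query dialect.
    query = """
    MATCH (n)
    WHERE n.frequency_weight IS NOT NULL
    AND ($element_type IS NULL OR n.type = $element_type)
    RETURN n.id AS id, n.type AS type, n.frequency_weight AS frequency_weight
    ORDER BY n.frequency_weight DESC
    LIMIT $top_n
    """
    return await graph_adapter.query(
        query, params={"element_type": element_type, "top_n": top_n}
    )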

View file: test_usage_frequency_comprehensive.py

@@ -6,7 +6,7 @@ Tests cover extraction logic, adapter integration, edge cases, and end-to-end wo
Run with:
pytest test_usage_frequency_comprehensive.py -v
Or without pytest:
python test_usage_frequency_comprehensive.py
"""
@@ -23,8 +23,9 @@ try:
from cognee.tasks.memify.extract_usage_frequency import (
extract_usage_frequency,
add_frequency_weights,
run_usage_frequency_update
run_usage_frequency_update,
)
COGNEE_AVAILABLE = True
except ImportError:
COGNEE_AVAILABLE = False
@@ -33,16 +34,16 @@ except ImportError:
class TestUsageFrequencyExtraction(unittest.TestCase):
"""Test the core frequency extraction logic."""
def setUp(self):
"""Set up test fixtures."""
if not COGNEE_AVAILABLE:
self.skipTest("Cognee modules not available")
def create_mock_graph(self, num_interactions: int = 3, num_elements: int = 5):
"""Create a mock graph with interactions and elements."""
graph = CogneeGraph()
# Create interaction nodes
current_time = datetime.now()
for i in range(num_interactions):
@@ -50,25 +51,22 @@ class TestUsageFrequencyExtraction(unittest.TestCase):
id=f"interaction_{i}",
node_type="CogneeUserInteraction",
attributes={
'type': 'CogneeUserInteraction',
'query_text': f'Test query {i}',
'timestamp': int((current_time - timedelta(hours=i)).timestamp() * 1000)
}
"type": "CogneeUserInteraction",
"query_text": f"Test query {i}",
"timestamp": int((current_time - timedelta(hours=i)).timestamp() * 1000),
},
)
graph.add_node(interaction_node)
# Create graph element nodes
for i in range(num_elements):
element_node = Node(
id=f"element_{i}",
node_type="DocumentChunk",
attributes={
'type': 'DocumentChunk',
'text': f'Element content {i}'
}
attributes={"type": "DocumentChunk", "text": f"Element content {i}"},
)
graph.add_node(element_node)
# Create usage edges (interactions reference elements)
for i in range(num_interactions):
# Each interaction uses 2-3 elements
@@ -78,183 +76,179 @@ class TestUsageFrequencyExtraction(unittest.TestCase):
node1=graph.get_node(f"interaction_{i}"),
node2=graph.get_node(f"element_{element_idx}"),
edge_type="used_graph_element_to_answer",
attributes={'relationship_type': 'used_graph_element_to_answer'}
attributes={"relationship_type": "used_graph_element_to_answer"},
)
graph.add_edge(edge)
return graph
async def test_basic_frequency_extraction(self):
"""Test basic frequency extraction with simple graph."""
graph = self.create_mock_graph(num_interactions=3, num_elements=5)
result = await extract_usage_frequency(
subgraphs=[graph],
time_window=timedelta(days=7),
min_interaction_threshold=1
subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=1
)
self.assertIn('node_frequencies', result)
self.assertIn('total_interactions', result)
self.assertEqual(result['total_interactions'], 3)
self.assertGreater(len(result['node_frequencies']), 0)
self.assertIn("node_frequencies", result)
self.assertIn("total_interactions", result)
self.assertEqual(result["total_interactions"], 3)
self.assertGreater(len(result["node_frequencies"]), 0)
async def test_time_window_filtering(self):
"""Test that time window correctly filters old interactions."""
graph = CogneeGraph()
current_time = datetime.now()
# Add recent interaction (within window)
recent_node = Node(
id="recent_interaction",
node_type="CogneeUserInteraction",
attributes={
'type': 'CogneeUserInteraction',
'timestamp': int(current_time.timestamp() * 1000)
}
"type": "CogneeUserInteraction",
"timestamp": int(current_time.timestamp() * 1000),
},
)
graph.add_node(recent_node)
# Add old interaction (outside window)
old_node = Node(
id="old_interaction",
node_type="CogneeUserInteraction",
attributes={
'type': 'CogneeUserInteraction',
'timestamp': int((current_time - timedelta(days=10)).timestamp() * 1000)
}
"type": "CogneeUserInteraction",
"timestamp": int((current_time - timedelta(days=10)).timestamp() * 1000),
},
)
graph.add_node(old_node)
# Add element
element = Node(id="element_1", node_type="DocumentChunk", attributes={'type': 'DocumentChunk'})
element = Node(
id="element_1", node_type="DocumentChunk", attributes={"type": "DocumentChunk"}
)
graph.add_node(element)
# Add edges
graph.add_edge(Edge(
node1=recent_node, node2=element,
edge_type="used_graph_element_to_answer",
attributes={'relationship_type': 'used_graph_element_to_answer'}
))
graph.add_edge(Edge(
node1=old_node, node2=element,
edge_type="used_graph_element_to_answer",
attributes={'relationship_type': 'used_graph_element_to_answer'}
))
graph.add_edge(
Edge(
node1=recent_node,
node2=element,
edge_type="used_graph_element_to_answer",
attributes={"relationship_type": "used_graph_element_to_answer"},
)
)
graph.add_edge(
Edge(
node1=old_node,
node2=element,
edge_type="used_graph_element_to_answer",
attributes={"relationship_type": "used_graph_element_to_answer"},
)
)
# Extract with 7-day window
result = await extract_usage_frequency(
subgraphs=[graph],
time_window=timedelta(days=7),
min_interaction_threshold=1
subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=1
)
# Should only count recent interaction
self.assertEqual(result['interactions_in_window'], 1)
self.assertEqual(result['total_interactions'], 2)
self.assertEqual(result["interactions_in_window"], 1)
self.assertEqual(result["total_interactions"], 2)
async def test_threshold_filtering(self):
"""Test that minimum threshold filters low-frequency nodes."""
graph = self.create_mock_graph(num_interactions=5, num_elements=10)
# Extract with threshold of 3
result = await extract_usage_frequency(
subgraphs=[graph],
time_window=timedelta(days=7),
min_interaction_threshold=3
subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=3
)
# Only nodes with 3+ accesses should be included
for node_id, freq in result['node_frequencies'].items():
for node_id, freq in result["node_frequencies"].items():
self.assertGreaterEqual(freq, 3)
async def test_element_type_tracking(self):
"""Test that element types are properly tracked."""
graph = CogneeGraph()
# Create interaction
interaction = Node(
id="interaction_1",
node_type="CogneeUserInteraction",
attributes={
'type': 'CogneeUserInteraction',
'timestamp': int(datetime.now().timestamp() * 1000)
}
"type": "CogneeUserInteraction",
"timestamp": int(datetime.now().timestamp() * 1000),
},
)
graph.add_node(interaction)
# Create elements of different types
chunk = Node(id="chunk_1", node_type="DocumentChunk", attributes={'type': 'DocumentChunk'})
entity = Node(id="entity_1", node_type="Entity", attributes={'type': 'Entity'})
chunk = Node(id="chunk_1", node_type="DocumentChunk", attributes={"type": "DocumentChunk"})
entity = Node(id="entity_1", node_type="Entity", attributes={"type": "Entity"})
graph.add_node(chunk)
graph.add_node(entity)
# Add edges
for element in [chunk, entity]:
graph.add_edge(Edge(
node1=interaction, node2=element,
edge_type="used_graph_element_to_answer",
attributes={'relationship_type': 'used_graph_element_to_answer'}
))
result = await extract_usage_frequency(
subgraphs=[graph],
time_window=timedelta(days=7)
)
graph.add_edge(
Edge(
node1=interaction,
node2=element,
edge_type="used_graph_element_to_answer",
attributes={"relationship_type": "used_graph_element_to_answer"},
)
)
result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7))
# Check element types were tracked
self.assertIn('element_type_frequencies', result)
types = result['element_type_frequencies']
self.assertIn('DocumentChunk', types)
self.assertIn('Entity', types)
self.assertIn("element_type_frequencies", result)
types = result["element_type_frequencies"]
self.assertIn("DocumentChunk", types)
self.assertIn("Entity", types)
async def test_empty_graph(self):
"""Test handling of empty graph."""
graph = CogneeGraph()
result = await extract_usage_frequency(
subgraphs=[graph],
time_window=timedelta(days=7)
)
self.assertEqual(result['total_interactions'], 0)
self.assertEqual(len(result['node_frequencies']), 0)
result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7))
self.assertEqual(result["total_interactions"], 0)
self.assertEqual(len(result["node_frequencies"]), 0)
async def test_no_interactions_in_window(self):
"""Test handling when all interactions are outside time window."""
graph = CogneeGraph()
# Add old interaction
old_time = datetime.now() - timedelta(days=30)
old_interaction = Node(
id="old_interaction",
node_type="CogneeUserInteraction",
attributes={
'type': 'CogneeUserInteraction',
'timestamp': int(old_time.timestamp() * 1000)
}
"type": "CogneeUserInteraction",
"timestamp": int(old_time.timestamp() * 1000),
},
)
graph.add_node(old_interaction)
result = await extract_usage_frequency(
subgraphs=[graph],
time_window=timedelta(days=7)
)
self.assertEqual(result['interactions_in_window'], 0)
self.assertEqual(result['total_interactions'], 1)
result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7))
self.assertEqual(result["interactions_in_window"], 0)
self.assertEqual(result["total_interactions"], 1)
class TestIntegration(unittest.TestCase):
"""Integration tests for the complete workflow."""
def setUp(self):
"""Set up test fixtures."""
if not COGNEE_AVAILABLE:
self.skipTest("Cognee modules not available")
async def test_end_to_end_workflow(self):
"""Test the complete end-to-end frequency tracking workflow."""
# This would require a full Cognee setup with a database
@@ -266,6 +260,7 @@ class TestIntegration(unittest.TestCase):
# Test Runner
# ============================================================================
def run_async_test(test_func):
"""Helper to run async test functions."""
asyncio.run(test_func())
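Note that the async test methods above are not awaited by unittest's default runner, so they pass vacuously unless driven through run_async_test or an async-aware TestCase. A minimal sketch of the latter, reusing this file's imports and skip guard:

class TestUsageFrequencyAsyncio(unittest.IsolatedAsyncioTestCase):
    """Async-aware variant: the runner awaits each async test itself."""

    async def asyncSetUp(self):
        if not COGNEE_AVAILABLE:
            self.skipTest("Cognee modules not available")

    async def test_empty_graph_returns_zero_interactions(self):
        result = await extract_usage_frequency(subgraphs=[CogneeGraph()])
        self.assertEqual(result["total_interactions"], 0)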
@@ -277,24 +272,24 @@ def main():
print("⚠ Cognee not available - skipping tests")
print("Install with: pip install cognee[neo4j]")
return
print("=" * 80)
print("Running Usage Frequency Tests")
print("=" * 80)
print()
# Create test suite
loader = unittest.TestLoader()
suite = unittest.TestSuite()
# Add tests
suite.addTests(loader.loadTestsFromTestCase(TestUsageFrequencyExtraction))
suite.addTests(loader.loadTestsFromTestCase(TestIntegration))
# Run tests
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(suite)
# Summary
print()
print("=" * 80)
@@ -305,9 +300,9 @@ def main():
print(f"Failures: {len(result.failures)}")
print(f"Errors: {len(result.errors)}")
print(f"Skipped: {len(result.skipped)}")
return 0 if result.wasSuccessful() else 1
if __name__ == "__main__":
exit(main())

View file

@@ -39,10 +39,11 @@ load_dotenv()
# STEP 1: Setup and Configuration
# ============================================================================
async def setup_knowledge_base():
"""
Create a fresh knowledge base with sample content.
In a real application, you would:
- Load documents from files, databases, or APIs
- Process larger datasets
@@ -51,13 +52,13 @@ async def setup_knowledge_base():
print("=" * 80)
print("STEP 1: Setting up knowledge base")
print("=" * 80)
# Reset state for clean demo (optional in production)
print("\nResetting Cognee state...")
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
print("✓ Reset complete")
# Sample content: AI/ML educational material
documents = [
"""
@@ -87,16 +88,16 @@ async def setup_knowledge_base():
recognition, object detection, and image segmentation tasks.
""",
]
print(f"\nAdding {len(documents)} documents to knowledge base...")
await cognee.add(documents, dataset_name="ai_ml_fundamentals")
print("✓ Documents added")
# Build knowledge graph
print("\nBuilding knowledge graph (cognify)...")
await cognee.cognify()
print("✓ Knowledge graph built")
print("\n" + "=" * 80)
@@ -104,26 +105,27 @@ async def setup_knowledge_base():
# STEP 2: Simulate User Searches with Interaction Tracking
# ============================================================================
async def simulate_user_searches(queries: List[str]):
"""
Simulate users searching the knowledge base.
The key parameter is save_interaction=True, which creates:
- CogneeUserInteraction nodes (one per search)
- used_graph_element_to_answer edges (connecting queries to relevant nodes)
Args:
queries: List of search queries to simulate
Returns:
Number of successful searches
"""
print("=" * 80)
print("STEP 2: Simulating user searches with interaction tracking")
print("=" * 80)
successful_searches = 0
for i, query in enumerate(queries, 1):
print(f"\nSearch {i}/{len(queries)}: '{query}'")
try:
@@ -131,20 +133,20 @@ async def simulate_user_searches(queries: List[str]):
query_type=SearchType.GRAPH_COMPLETION,
query_text=query,
save_interaction=True, # ← THIS IS CRITICAL!
top_k=5
top_k=5,
)
successful_searches += 1
# Show snippet of results
result_preview = str(results)[:100] if results else "No results"
print(f" ✓ Completed ({result_preview}...)")
except Exception as e:
print(f" ✗ Failed: {e}")
print(f"\n✓ Completed {successful_searches}/{len(queries)} searches")
print("=" * 80)
return successful_searches
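To verify that save_interaction actually persisted anything, the interaction nodes can be counted directly; a sketch assuming the Neo4j backend and the query(...) usage shown in analyze_results below (the type property is the one the extraction code in this commit checks):

async def count_saved_interactions() -> int:
    # Sketch: count CogneeUserInteraction nodes written by save_interaction=True.
    graph_engine = await get_graph_engine()
    result = await graph_engine.query(
        "MATCH (n) WHERE n.type = 'CogneeUserInteraction' RETURN count(n) AS c"
    )
    return result[0]["c"] if result else 0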
@@ -152,71 +154,80 @@ async def simulate_user_searches(queries: List[str]):
# STEP 3: Extract and Apply Usage Frequencies
# ============================================================================
async def extract_and_apply_frequencies(
time_window_days: int = 7,
min_threshold: int = 1
time_window_days: int = 7, min_threshold: int = 1
) -> Dict[str, Any]:
"""
Extract usage frequencies from interactions and apply them to the graph.
This function:
1. Retrieves the graph with interaction data
2. Counts how often each node was accessed
3. Writes frequency_weight property back to nodes
Args:
time_window_days: Only count interactions from last N days
min_threshold: Minimum accesses to track (filter out rarely used nodes)
Returns:
Dictionary with statistics about the frequency update
"""
print("=" * 80)
print("STEP 3: Extracting and applying usage frequencies")
print("=" * 80)
# Get graph adapter
graph_engine = await get_graph_engine()
# Retrieve graph with interactions
print("\nRetrieving graph from database...")
graph = CogneeGraph()
await graph.project_graph_from_db(
adapter=graph_engine,
node_properties_to_project=[
"type", "node_type", "timestamp", "created_at",
"text", "name", "query_text", "frequency_weight"
"type",
"node_type",
"timestamp",
"created_at",
"text",
"name",
"query_text",
"frequency_weight",
],
edge_properties_to_project=["relationship_type", "timestamp"],
directed=True,
)
print(f"✓ Retrieved: {len(graph.nodes)} nodes, {len(graph.edges)} edges")
# Count interaction nodes
interaction_nodes = [
n for n in graph.nodes.values()
if n.attributes.get('type') == 'CogneeUserInteraction' or
n.attributes.get('node_type') == 'CogneeUserInteraction'
n
for n in graph.nodes.values()
if n.attributes.get("type") == "CogneeUserInteraction"
or n.attributes.get("node_type") == "CogneeUserInteraction"
]
print(f"✓ Found {len(interaction_nodes)} interaction nodes")
# Run frequency extraction and update
print(f"\nExtracting frequencies (time window: {time_window_days} days)...")
stats = await run_usage_frequency_update(
graph_adapter=graph_engine,
subgraphs=[graph],
time_window=timedelta(days=time_window_days),
min_interaction_threshold=min_threshold
min_interaction_threshold=min_threshold,
)
print("\n✓ Frequency extraction complete!")
print(
f" - Interactions processed: {stats['interactions_in_window']}/{stats['total_interactions']}"
)
print(f"\n✓ Frequency extraction complete!")
print(f" - Interactions processed: {stats['interactions_in_window']}/{stats['total_interactions']}")
print(f" - Nodes weighted: {len(stats['node_frequencies'])}")
print(f" - Element types tracked: {stats.get('element_type_frequencies', {})}")
print("=" * 80)
return stats
@@ -224,33 +235,30 @@ async def extract_and_apply_frequencies(
# STEP 4: Analyze and Display Results
# ============================================================================
async def analyze_results(stats: Dict[str, Any]):
"""
Analyze and display the frequency tracking results.
Shows:
- Top most frequently accessed nodes
- Element type distribution
- Verification that weights were written to database
Args:
stats: Statistics from frequency extraction
"""
print("=" * 80)
print("STEP 4: Analyzing usage frequency results")
print("=" * 80)
# Display top nodes by frequency
if stats['node_frequencies']:
if stats["node_frequencies"]:
print("\n📊 Top 10 Most Frequently Accessed Elements:")
print("-" * 80)
sorted_nodes = sorted(
stats['node_frequencies'].items(),
key=lambda x: x[1],
reverse=True
)
sorted_nodes = sorted(stats["node_frequencies"].items(), key=lambda x: x[1], reverse=True)
# Get graph to display node details
graph_engine = await get_graph_engine()
graph = CogneeGraph()
@@ -260,48 +268,48 @@ async def analyze_results(stats: Dict[str, Any]):
edge_properties_to_project=[],
directed=True,
)
for i, (node_id, frequency) in enumerate(sorted_nodes[:10], 1):
node = graph.get_node(node_id)
if node:
node_type = node.attributes.get('type', 'Unknown')
text = node.attributes.get('text') or node.attributes.get('name') or ''
node_type = node.attributes.get("type", "Unknown")
text = node.attributes.get("text") or node.attributes.get("name") or ""
text_preview = text[:60] + "..." if len(text) > 60 else text
print(f"\n{i}. Frequency: {frequency} accesses")
print(f" Type: {node_type}")
print(f" Content: {text_preview}")
else:
print(f"\n{i}. Frequency: {frequency} accesses")
print(f" Node ID: {node_id[:50]}...")
# Display element type distribution
if stats.get('element_type_frequencies'):
if stats.get("element_type_frequencies"):
print("\n\n📈 Element Type Distribution:")
print("-" * 80)
type_dist = stats['element_type_frequencies']
type_dist = stats["element_type_frequencies"]
for elem_type, count in sorted(type_dist.items(), key=lambda x: x[1], reverse=True):
print(f" {elem_type}: {count} accesses")
# Verify weights in database (Neo4j only)
print("\n\n🔍 Verifying weights in database...")
print("-" * 80)
graph_engine = await get_graph_engine()
adapter_type = type(graph_engine).__name__
if adapter_type == 'Neo4jAdapter':
if adapter_type == "Neo4jAdapter":
try:
result = await graph_engine.query("""
MATCH (n)
WHERE n.frequency_weight IS NOT NULL
RETURN count(n) as weighted_count
""")
count = result[0]['weighted_count'] if result else 0
count = result[0]["weighted_count"] if result else 0
if count > 0:
print(f"{count} nodes have frequency_weight in Neo4j database")
# Show sample
sample = await graph_engine.query("""
MATCH (n)
@@ -310,7 +318,7 @@ async def analyze_results(stats: Dict[str, Any]):
ORDER BY n.frequency_weight DESC
LIMIT 3
""")
print("\nSample weighted nodes:")
for row in sample:
print(f" - Weight: {row['weight']}, Type: {row['labels']}")
@@ -320,7 +328,7 @@ async def analyze_results(stats: Dict[str, Any]):
print(f"Could not verify in Neo4j: {e}")
else:
print(f"Database verification not implemented for {adapter_type}")
print("\n" + "=" * 80)
@@ -328,10 +336,11 @@ async def analyze_results(stats: Dict[str, Any]):
# STEP 5: Demonstrate Usage in Retrieval
# ============================================================================
async def demonstrate_retrieval_usage():
"""
Demonstrate how frequency weights can be used in retrieval.
Note: This is a conceptual demonstration. To actually use frequency
weights in ranking, you would need to modify the retrieval/completion
strategies to incorporate the frequency_weight property.
@@ -339,39 +348,39 @@ async def demonstrate_retrieval_usage():
print("=" * 80)
print("STEP 5: How to use frequency weights in retrieval")
print("=" * 80)
print("""
Frequency weights can be used to improve search results:
1. RANKING BOOST:
- Multiply relevance scores by frequency_weight
- Prioritize frequently accessed nodes in results
2. COMPLETION STRATEGIES:
- Adjust triplet importance based on usage
- Filter out rarely accessed information
3. ANALYTICS:
- Track trending topics over time
- Understand user interests and behavior
- Identify knowledge gaps (low-frequency nodes)
4. ADAPTIVE RETRIEVAL:
- Personalize results based on team usage patterns
- Surface popular answers faster
Example Cypher query with frequency boost (Neo4j):
MATCH (n)
WHERE n.text CONTAINS $search_term
RETURN n, n.frequency_weight as boost
ORDER BY (n.relevance_score * COALESCE(n.frequency_weight, 1)) DESC
LIMIT 10
To integrate this into Cognee, you would modify the completion
strategy to include frequency_weight in the scoring function.
""")
print("=" * 80)
@@ -379,6 +388,7 @@ async def demonstrate_retrieval_usage():
# MAIN: Run Complete Example
# ============================================================================
async def main():
"""
Run the complete end-to-end usage frequency tracking example.
@@ -390,25 +400,25 @@ async def main():
print("" + " " * 78 + "")
print("" + "=" * 78 + "")
print("\n")
# Configuration check
print("Configuration:")
print(f" Graph Provider: {os.getenv('GRAPH_DATABASE_PROVIDER')}")
print(f" Graph Handler: {os.getenv('GRAPH_DATASET_HANDLER')}")
print(f" LLM Provider: {os.getenv('LLM_PROVIDER')}")
# Verify LLM key is set
if not os.getenv('LLM_API_KEY') or os.getenv('LLM_API_KEY') == 'sk-your-key-here':
if not os.getenv("LLM_API_KEY") or os.getenv("LLM_API_KEY") == "sk-your-key-here":
print("\n⚠ WARNING: LLM_API_KEY not set in .env file")
print(" Set your API key to run searches")
return
print("\n")
try:
# Step 1: Setup
await setup_knowledge_base()
# Step 2: Simulate searches
# Note: repeated queries increase the frequency counts for those topics
queries = [
@@ -422,25 +432,22 @@ async def main():
"What is reinforcement learning?",
"Tell me more about neural networks", # Third repeat
]
successful_searches = await simulate_user_searches(queries)
if successful_searches == 0:
print("⚠ No searches completed - cannot demonstrate frequency tracking")
return
# Step 3: Extract frequencies
stats = await extract_and_apply_frequencies(
time_window_days=7,
min_threshold=1
)
stats = await extract_and_apply_frequencies(time_window_days=7, min_threshold=1)
# Step 4: Analyze results
await analyze_results(stats)
# Step 5: Show usage examples
await demonstrate_retrieval_usage()
# Summary
print("\n")
print("" + "=" * 78 + "")
@@ -449,26 +456,27 @@ async def main():
print("" + " " * 78 + "")
print("" + "=" * 78 + "")
print("\n")
print("Summary:")
print(f" ✓ Documents added: 4")
print(" ✓ Documents added: 4")
print(f" ✓ Searches performed: {successful_searches}")
print(f" ✓ Interactions tracked: {stats['interactions_in_window']}")
print(f" ✓ Nodes weighted: {len(stats['node_frequencies'])}")
print("\nNext steps:")
print(" 1. Open Neo4j Browser (http://localhost:7474) to explore the graph")
print(" 2. Modify retrieval strategies to use frequency_weight")
print(" 3. Build analytics dashboards using element_type_frequencies")
print(" 4. Run periodic frequency updates to track trends over time")
print("\n")
except Exception as e:
print(f"\n✗ Example failed: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
asyncio.run(main())
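Finally, for the "periodic frequency updates" follow-up suggested in the summary, a minimal scheduling sketch built from the helpers above; the cadence is an arbitrary assumption:

async def periodic_frequency_updates(interval_hours: int = 24) -> None:
    # Sketch: refresh frequency weights on a fixed cadence so that
    # trending topics stay visible in the weights over time.
    while True:
        await extract_and_apply_frequencies(time_window_days=7, min_threshold=1)
        await asyncio.sleep(interval_hours * 3600)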