chore: ruff format and refactor on contributor PR

This commit is contained in:
Igor Ilic 2026-01-13 15:10:21 +01:00
parent 0d2f66fa1d
commit dce51efbe3
3 changed files with 408 additions and 387 deletions

View file

@ -12,7 +12,7 @@ logger = get_logger("extract_usage_frequency")
async def extract_usage_frequency( async def extract_usage_frequency(
subgraphs: List[CogneeGraph], subgraphs: List[CogneeGraph],
time_window: timedelta = timedelta(days=7), time_window: timedelta = timedelta(days=7),
min_interaction_threshold: int = 1 min_interaction_threshold: int = 1,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Extract usage frequency from CogneeUserInteraction nodes. Extract usage frequency from CogneeUserInteraction nodes.
@ -48,11 +48,13 @@ async def extract_usage_frequency(
# Find all CogneeUserInteraction nodes # Find all CogneeUserInteraction nodes
interaction_nodes = {} interaction_nodes = {}
for node_id, node in subgraph.nodes.items(): for node_id, node in subgraph.nodes.items():
node_type = node.attributes.get('type') or node.attributes.get('node_type') node_type = node.attributes.get("type") or node.attributes.get("node_type")
if node_type == 'CogneeUserInteraction': if node_type == "CogneeUserInteraction":
# Parse and validate timestamp # Parse and validate timestamp
timestamp_value = node.attributes.get('timestamp') or node.attributes.get('created_at') timestamp_value = node.attributes.get("timestamp") or node.attributes.get(
"created_at"
)
if timestamp_value is not None: if timestamp_value is not None:
try: try:
# Handle various timestamp formats # Handle various timestamp formats
@ -81,20 +83,20 @@ async def extract_usage_frequency(
else: else:
# ISO format string # ISO format string
interaction_time = datetime.fromisoformat(timestamp_value) interaction_time = datetime.fromisoformat(timestamp_value)
elif hasattr(timestamp_value, 'to_native'): elif hasattr(timestamp_value, "to_native"):
# Neo4j datetime object - convert to Python datetime # Neo4j datetime object - convert to Python datetime
interaction_time = timestamp_value.to_native() interaction_time = timestamp_value.to_native()
elif hasattr(timestamp_value, 'year') and hasattr(timestamp_value, 'month'): elif hasattr(timestamp_value, "year") and hasattr(timestamp_value, "month"):
# Datetime-like object - extract components # Datetime-like object - extract components
try: try:
interaction_time = datetime( interaction_time = datetime(
year=timestamp_value.year, year=timestamp_value.year,
month=timestamp_value.month, month=timestamp_value.month,
day=timestamp_value.day, day=timestamp_value.day,
hour=getattr(timestamp_value, 'hour', 0), hour=getattr(timestamp_value, "hour", 0),
minute=getattr(timestamp_value, 'minute', 0), minute=getattr(timestamp_value, "minute", 0),
second=getattr(timestamp_value, 'second', 0), second=getattr(timestamp_value, "second", 0),
microsecond=getattr(timestamp_value, 'microsecond', 0) microsecond=getattr(timestamp_value, "microsecond", 0),
) )
except (AttributeError, ValueError): except (AttributeError, ValueError):
pass pass
@ -119,23 +121,27 @@ async def extract_usage_frequency(
interaction_time = interaction_time.replace(tzinfo=None) interaction_time = interaction_time.replace(tzinfo=None)
interaction_nodes[node_id] = { interaction_nodes[node_id] = {
'node': node, "node": node,
'timestamp': interaction_time, "timestamp": interaction_time,
'in_window': interaction_time >= cutoff_time "in_window": interaction_time >= cutoff_time,
} }
interaction_count += 1 interaction_count += 1
if interaction_time >= cutoff_time: if interaction_time >= cutoff_time:
interactions_in_window += 1 interactions_in_window += 1
except (ValueError, TypeError, AttributeError, OSError) as e: except (ValueError, TypeError, AttributeError, OSError) as e:
logger.warning(f"Failed to parse timestamp for interaction node {node_id}: {e}") logger.warning(
logger.debug(f"Timestamp value type: {type(timestamp_value)}, value: {timestamp_value}") f"Failed to parse timestamp for interaction node {node_id}: {e}"
)
logger.debug(
f"Timestamp value type: {type(timestamp_value)}, value: {timestamp_value}"
)
# Process edges to find graph elements used in interactions # Process edges to find graph elements used in interactions
for edge in subgraph.edges: for edge in subgraph.edges:
relationship_type = edge.attributes.get('relationship_type') relationship_type = edge.attributes.get("relationship_type")
# Look for 'used_graph_element_to_answer' edges # Look for 'used_graph_element_to_answer' edges
if relationship_type == 'used_graph_element_to_answer': if relationship_type == "used_graph_element_to_answer":
# node1 should be the CogneeUserInteraction, node2 is the graph element # node1 should be the CogneeUserInteraction, node2 is the graph element
source_id = str(edge.node1.id) source_id = str(edge.node1.id)
target_id = str(edge.node2.id) target_id = str(edge.node2.id)
@ -144,19 +150,23 @@ async def extract_usage_frequency(
if source_id in interaction_nodes: if source_id in interaction_nodes:
interaction_data = interaction_nodes[source_id] interaction_data = interaction_nodes[source_id]
if interaction_data['in_window']: if interaction_data["in_window"]:
# Count the graph element (target node) being used # Count the graph element (target node) being used
node_frequencies[target_id] = node_frequencies.get(target_id, 0) + 1 node_frequencies[target_id] = node_frequencies.get(target_id, 0) + 1
# Also track what type of element it is for analytics # Also track what type of element it is for analytics
target_node = subgraph.get_node(target_id) target_node = subgraph.get_node(target_id)
if target_node: if target_node:
element_type = target_node.attributes.get('type') or target_node.attributes.get('node_type') element_type = target_node.attributes.get(
"type"
) or target_node.attributes.get("node_type")
if element_type: if element_type:
relationship_type_frequencies[element_type] = relationship_type_frequencies.get(element_type, 0) + 1 relationship_type_frequencies[element_type] = (
relationship_type_frequencies.get(element_type, 0) + 1
)
# Also track general edge usage patterns # Also track general edge usage patterns
elif relationship_type and relationship_type != 'used_graph_element_to_answer': elif relationship_type and relationship_type != "used_graph_element_to_answer":
# Check if either endpoint is referenced in a recent interaction # Check if either endpoint is referenced in a recent interaction
source_id = str(edge.node1.id) source_id = str(edge.node1.id)
target_id = str(edge.node2.id) target_id = str(edge.node2.id)
@ -168,12 +178,14 @@ async def extract_usage_frequency(
# Filter frequencies above threshold # Filter frequencies above threshold
filtered_node_frequencies = { filtered_node_frequencies = {
node_id: freq for node_id, freq in node_frequencies.items() node_id: freq
for node_id, freq in node_frequencies.items()
if freq >= min_interaction_threshold if freq >= min_interaction_threshold
} }
filtered_edge_frequencies = { filtered_edge_frequencies = {
edge_key: freq for edge_key, freq in edge_frequencies.items() edge_key: freq
for edge_key, freq in edge_frequencies.items()
if freq >= min_interaction_threshold if freq >= min_interaction_threshold
} }
@ -187,20 +199,19 @@ async def extract_usage_frequency(
logger.info(f"Element type distribution: {relationship_type_frequencies}") logger.info(f"Element type distribution: {relationship_type_frequencies}")
return { return {
'node_frequencies': filtered_node_frequencies, "node_frequencies": filtered_node_frequencies,
'edge_frequencies': filtered_edge_frequencies, "edge_frequencies": filtered_edge_frequencies,
'element_type_frequencies': relationship_type_frequencies, "element_type_frequencies": relationship_type_frequencies,
'total_interactions': interaction_count, "total_interactions": interaction_count,
'interactions_in_window': interactions_in_window, "interactions_in_window": interactions_in_window,
'time_window_days': time_window.days, "time_window_days": time_window.days,
'last_processed_timestamp': current_time.isoformat(), "last_processed_timestamp": current_time.isoformat(),
'cutoff_timestamp': cutoff_time.isoformat() "cutoff_timestamp": cutoff_time.isoformat(),
} }
async def add_frequency_weights( async def add_frequency_weights(
graph_adapter: GraphDBInterface, graph_adapter: GraphDBInterface, usage_frequencies: Dict[str, Any]
usage_frequencies: Dict[str, Any]
) -> None: ) -> None:
""" """
Add frequency weights to graph nodes and edges using the graph adapter. Add frequency weights to graph nodes and edges using the graph adapter.
@ -214,8 +225,8 @@ async def add_frequency_weights(
:param graph_adapter: Graph database adapter interface :param graph_adapter: Graph database adapter interface
:param usage_frequencies: Calculated usage frequencies from extract_usage_frequency :param usage_frequencies: Calculated usage frequencies from extract_usage_frequency
""" """
node_frequencies = usage_frequencies.get('node_frequencies', {}) node_frequencies = usage_frequencies.get("node_frequencies", {})
edge_frequencies = usage_frequencies.get('edge_frequencies', {}) edge_frequencies = usage_frequencies.get("edge_frequencies", {})
logger.info(f"Adding frequency weights to {len(node_frequencies)} nodes") logger.info(f"Adding frequency weights to {len(node_frequencies)} nodes")
@ -227,15 +238,17 @@ async def add_frequency_weights(
nodes_failed = 0 nodes_failed = 0
# Determine which method to use based on adapter type # Determine which method to use based on adapter type
use_neo4j_cypher = adapter_type == 'Neo4jAdapter' and hasattr(graph_adapter, 'query') use_neo4j_cypher = adapter_type == "Neo4jAdapter" and hasattr(graph_adapter, "query")
use_kuzu_query = adapter_type == 'KuzuAdapter' and hasattr(graph_adapter, 'query') use_kuzu_query = adapter_type == "KuzuAdapter" and hasattr(graph_adapter, "query")
use_get_update = hasattr(graph_adapter, 'get_node_by_id') and hasattr(graph_adapter, 'update_node_properties') use_get_update = hasattr(graph_adapter, "get_node_by_id") and hasattr(
graph_adapter, "update_node_properties"
)
# Method 1: Neo4j Cypher with SET (creates properties on the fly) # Method 1: Neo4j Cypher with SET (creates properties on the fly)
if use_neo4j_cypher: if use_neo4j_cypher:
try: try:
logger.info("Using Neo4j Cypher SET method") logger.info("Using Neo4j Cypher SET method")
last_updated = usage_frequencies.get('last_processed_timestamp') last_updated = usage_frequencies.get("last_processed_timestamp")
for node_id, frequency in node_frequencies.items(): for node_id, frequency in node_frequencies.items():
try: try:
@ -250,10 +263,10 @@ async def add_frequency_weights(
result = await graph_adapter.query( result = await graph_adapter.query(
query, query,
params={ params={
'node_id': node_id, "node_id": node_id,
'frequency': frequency, "frequency": frequency,
'updated_at': last_updated "updated_at": last_updated,
} },
) )
if result and len(result) > 0: if result and len(result) > 0:
@ -273,9 +286,11 @@ async def add_frequency_weights(
use_neo4j_cypher = False use_neo4j_cypher = False
# Method 2: Kuzu - use get_node + add_node (updates via re-adding with same ID) # Method 2: Kuzu - use get_node + add_node (updates via re-adding with same ID)
elif use_kuzu_query and hasattr(graph_adapter, 'get_node') and hasattr(graph_adapter, 'add_node'): elif (
use_kuzu_query and hasattr(graph_adapter, "get_node") and hasattr(graph_adapter, "add_node")
):
logger.info("Using Kuzu get_node + add_node method") logger.info("Using Kuzu get_node + add_node method")
last_updated = usage_frequencies.get('last_processed_timestamp') last_updated = usage_frequencies.get("last_processed_timestamp")
for node_id, frequency in node_frequencies.items(): for node_id, frequency in node_frequencies.items():
try: try:
@ -284,8 +299,8 @@ async def add_frequency_weights(
if existing_node_dict: if existing_node_dict:
# Update the dict with new properties # Update the dict with new properties
existing_node_dict['frequency_weight'] = frequency existing_node_dict["frequency_weight"] = frequency
existing_node_dict['frequency_updated_at'] = last_updated existing_node_dict["frequency_updated_at"] = last_updated
# Kuzu's add_node likely just takes the dict directly, not a Node object # Kuzu's add_node likely just takes the dict directly, not a Node object
# Try passing the dict directly first # Try passing the dict directly first
@ -298,15 +313,16 @@ async def add_frequency_weights(
try: try:
from cognee.infrastructure.engine import Node from cognee.infrastructure.engine import Node
# Try different Node constructor patterns # Try different Node constructor patterns
try: try:
# Pattern 1: Just properties # Pattern 1: Just properties
node_obj = Node(existing_node_dict) node_obj = Node(existing_node_dict)
except: except Exception:
# Pattern 2: Type and properties # Pattern 2: Type and properties
node_obj = Node( node_obj = Node(
type=existing_node_dict.get('type', 'Unknown'), type=existing_node_dict.get("type", "Unknown"),
**existing_node_dict **existing_node_dict,
) )
await graph_adapter.add_node(node_obj) await graph_adapter.add_node(node_obj)
@ -335,13 +351,15 @@ async def add_frequency_weights(
if node_data: if node_data:
# Tweak the properties dict - add frequency_weight # Tweak the properties dict - add frequency_weight
if isinstance(node_data, dict): if isinstance(node_data, dict):
properties = node_data.get('properties', {}) properties = node_data.get("properties", {})
else: else:
properties = getattr(node_data, 'properties', {}) or {} properties = getattr(node_data, "properties", {}) or {}
# Update with frequency weight # Update with frequency weight
properties['frequency_weight'] = frequency properties["frequency_weight"] = frequency
properties['frequency_updated_at'] = usage_frequencies.get('last_processed_timestamp') properties["frequency_updated_at"] = usage_frequencies.get(
"last_processed_timestamp"
)
# Write back via adapter # Write back via adapter
await graph_adapter.update_node_properties(node_id, properties) await graph_adapter.update_node_properties(node_id, properties)
@ -363,13 +381,15 @@ async def add_frequency_weights(
if node_data: if node_data:
# Tweak the properties dict - add frequency_weight # Tweak the properties dict - add frequency_weight
if isinstance(node_data, dict): if isinstance(node_data, dict):
properties = node_data.get('properties', {}) properties = node_data.get("properties", {})
else: else:
properties = getattr(node_data, 'properties', {}) or {} properties = getattr(node_data, "properties", {}) or {}
# Update with frequency weight # Update with frequency weight
properties['frequency_weight'] = frequency properties["frequency_weight"] = frequency
properties['frequency_updated_at'] = usage_frequencies.get('last_processed_timestamp') properties["frequency_updated_at"] = usage_frequencies.get(
"last_processed_timestamp"
)
# Write back via adapter # Write back via adapter
await graph_adapter.update_node_properties(node_id, properties) await graph_adapter.update_node_properties(node_id, properties)
@ -385,7 +405,9 @@ async def add_frequency_weights(
# If no method is available # If no method is available
if not use_neo4j_cypher and not use_kuzu_query and not use_get_update: if not use_neo4j_cypher and not use_kuzu_query and not use_get_update:
logger.error(f"Adapter {adapter_type} does not support required update methods") logger.error(f"Adapter {adapter_type} does not support required update methods")
logger.error("Required: either 'query' method or both 'get_node_by_id' and 'update_node_properties'") logger.error(
"Required: either 'query' method or both 'get_node_by_id' and 'update_node_properties'"
)
return return
# Update edge frequencies # Update edge frequencies
@ -399,22 +421,21 @@ async def add_frequency_weights(
for edge_key, frequency in edge_frequencies.items(): for edge_key, frequency in edge_frequencies.items():
try: try:
# Parse edge key: "relationship_type:source_id:target_id" # Parse edge key: "relationship_type:source_id:target_id"
parts = edge_key.split(':', 2) parts = edge_key.split(":", 2)
if len(parts) == 3: if len(parts) == 3:
relationship_type, source_id, target_id = parts relationship_type, source_id, target_id = parts
# Try to update edge if adapter supports it # Try to update edge if adapter supports it
if hasattr(graph_adapter, 'update_edge_properties'): if hasattr(graph_adapter, "update_edge_properties"):
edge_properties = { edge_properties = {
'frequency_weight': frequency, "frequency_weight": frequency,
'frequency_updated_at': usage_frequencies.get('last_processed_timestamp') "frequency_updated_at": usage_frequencies.get(
"last_processed_timestamp"
),
} }
await graph_adapter.update_edge_properties( await graph_adapter.update_edge_properties(
source_id, source_id, target_id, relationship_type, edge_properties
target_id,
relationship_type,
edge_properties
) )
edges_updated += 1 edges_updated += 1
else: else:
@ -436,15 +457,15 @@ async def add_frequency_weights(
) )
# Store aggregate statistics as metadata if supported # Store aggregate statistics as metadata if supported
if hasattr(graph_adapter, 'set_metadata'): if hasattr(graph_adapter, "set_metadata"):
try: try:
metadata = { metadata = {
'element_type_frequencies': usage_frequencies.get('element_type_frequencies', {}), "element_type_frequencies": usage_frequencies.get("element_type_frequencies", {}),
'total_interactions': usage_frequencies.get('total_interactions', 0), "total_interactions": usage_frequencies.get("total_interactions", 0),
'interactions_in_window': usage_frequencies.get('interactions_in_window', 0), "interactions_in_window": usage_frequencies.get("interactions_in_window", 0),
'last_frequency_update': usage_frequencies.get('last_processed_timestamp') "last_frequency_update": usage_frequencies.get("last_processed_timestamp"),
} }
await graph_adapter.set_metadata('usage_frequency_stats', metadata) await graph_adapter.set_metadata("usage_frequency_stats", metadata)
logger.info("Stored usage frequency statistics as metadata") logger.info("Stored usage frequency statistics as metadata")
except Exception as e: except Exception as e:
logger.warning(f"Could not store usage statistics as metadata: {e}") logger.warning(f"Could not store usage statistics as metadata: {e}")
@ -454,7 +475,7 @@ async def create_usage_frequency_pipeline(
graph_adapter: GraphDBInterface, graph_adapter: GraphDBInterface,
time_window: timedelta = timedelta(days=7), time_window: timedelta = timedelta(days=7),
min_interaction_threshold: int = 1, min_interaction_threshold: int = 1,
batch_size: int = 100 batch_size: int = 100,
) -> tuple: ) -> tuple:
""" """
Create memify pipeline entry for usage frequency tracking. Create memify pipeline entry for usage frequency tracking.
@ -486,7 +507,7 @@ async def create_usage_frequency_pipeline(
Task( Task(
extract_usage_frequency, extract_usage_frequency,
time_window=time_window, time_window=time_window,
min_interaction_threshold=min_interaction_threshold min_interaction_threshold=min_interaction_threshold,
) )
] ]
@ -494,7 +515,7 @@ async def create_usage_frequency_pipeline(
Task( Task(
add_frequency_weights, add_frequency_weights,
graph_adapter=graph_adapter, graph_adapter=graph_adapter,
task_config={"batch_size": batch_size} task_config={"batch_size": batch_size},
) )
] ]
@ -505,7 +526,7 @@ async def run_usage_frequency_update(
graph_adapter: GraphDBInterface, graph_adapter: GraphDBInterface,
subgraphs: List[CogneeGraph], subgraphs: List[CogneeGraph],
time_window: timedelta = timedelta(days=7), time_window: timedelta = timedelta(days=7),
min_interaction_threshold: int = 1 min_interaction_threshold: int = 1,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Convenience function to run the complete usage frequency update pipeline. Convenience function to run the complete usage frequency update pipeline.
@ -543,13 +564,12 @@ async def run_usage_frequency_update(
usage_frequencies = await extract_usage_frequency( usage_frequencies = await extract_usage_frequency(
subgraphs=subgraphs, subgraphs=subgraphs,
time_window=time_window, time_window=time_window,
min_interaction_threshold=min_interaction_threshold min_interaction_threshold=min_interaction_threshold,
) )
# Add frequency weights back to the graph # Add frequency weights back to the graph
await add_frequency_weights( await add_frequency_weights(
graph_adapter=graph_adapter, graph_adapter=graph_adapter, usage_frequencies=usage_frequencies
usage_frequencies=usage_frequencies
) )
logger.info("Usage frequency update completed successfully") logger.info("Usage frequency update completed successfully")
@ -566,9 +586,7 @@ async def run_usage_frequency_update(
async def get_most_frequent_elements( async def get_most_frequent_elements(
graph_adapter: GraphDBInterface, graph_adapter: GraphDBInterface, top_n: int = 10, element_type: Optional[str] = None
top_n: int = 10,
element_type: Optional[str] = None
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
Retrieve the most frequently accessed graph elements. Retrieve the most frequently accessed graph elements.

View file

@ -23,8 +23,9 @@ try:
from cognee.tasks.memify.extract_usage_frequency import ( from cognee.tasks.memify.extract_usage_frequency import (
extract_usage_frequency, extract_usage_frequency,
add_frequency_weights, add_frequency_weights,
run_usage_frequency_update run_usage_frequency_update,
) )
COGNEE_AVAILABLE = True COGNEE_AVAILABLE = True
except ImportError: except ImportError:
COGNEE_AVAILABLE = False COGNEE_AVAILABLE = False
@ -50,10 +51,10 @@ class TestUsageFrequencyExtraction(unittest.TestCase):
id=f"interaction_{i}", id=f"interaction_{i}",
node_type="CogneeUserInteraction", node_type="CogneeUserInteraction",
attributes={ attributes={
'type': 'CogneeUserInteraction', "type": "CogneeUserInteraction",
'query_text': f'Test query {i}', "query_text": f"Test query {i}",
'timestamp': int((current_time - timedelta(hours=i)).timestamp() * 1000) "timestamp": int((current_time - timedelta(hours=i)).timestamp() * 1000),
} },
) )
graph.add_node(interaction_node) graph.add_node(interaction_node)
@ -62,10 +63,7 @@ class TestUsageFrequencyExtraction(unittest.TestCase):
element_node = Node( element_node = Node(
id=f"element_{i}", id=f"element_{i}",
node_type="DocumentChunk", node_type="DocumentChunk",
attributes={ attributes={"type": "DocumentChunk", "text": f"Element content {i}"},
'type': 'DocumentChunk',
'text': f'Element content {i}'
}
) )
graph.add_node(element_node) graph.add_node(element_node)
@ -78,7 +76,7 @@ class TestUsageFrequencyExtraction(unittest.TestCase):
node1=graph.get_node(f"interaction_{i}"), node1=graph.get_node(f"interaction_{i}"),
node2=graph.get_node(f"element_{element_idx}"), node2=graph.get_node(f"element_{element_idx}"),
edge_type="used_graph_element_to_answer", edge_type="used_graph_element_to_answer",
attributes={'relationship_type': 'used_graph_element_to_answer'} attributes={"relationship_type": "used_graph_element_to_answer"},
) )
graph.add_edge(edge) graph.add_edge(edge)
@ -89,15 +87,13 @@ class TestUsageFrequencyExtraction(unittest.TestCase):
graph = self.create_mock_graph(num_interactions=3, num_elements=5) graph = self.create_mock_graph(num_interactions=3, num_elements=5)
result = await extract_usage_frequency( result = await extract_usage_frequency(
subgraphs=[graph], subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=1
time_window=timedelta(days=7),
min_interaction_threshold=1
) )
self.assertIn('node_frequencies', result) self.assertIn("node_frequencies", result)
self.assertIn('total_interactions', result) self.assertIn("total_interactions", result)
self.assertEqual(result['total_interactions'], 3) self.assertEqual(result["total_interactions"], 3)
self.assertGreater(len(result['node_frequencies']), 0) self.assertGreater(len(result["node_frequencies"]), 0)
async def test_time_window_filtering(self): async def test_time_window_filtering(self):
"""Test that time window correctly filters old interactions.""" """Test that time window correctly filters old interactions."""
@ -110,9 +106,9 @@ class TestUsageFrequencyExtraction(unittest.TestCase):
id="recent_interaction", id="recent_interaction",
node_type="CogneeUserInteraction", node_type="CogneeUserInteraction",
attributes={ attributes={
'type': 'CogneeUserInteraction', "type": "CogneeUserInteraction",
'timestamp': int(current_time.timestamp() * 1000) "timestamp": int(current_time.timestamp() * 1000),
} },
) )
graph.add_node(recent_node) graph.add_node(recent_node)
@ -121,38 +117,44 @@ class TestUsageFrequencyExtraction(unittest.TestCase):
id="old_interaction", id="old_interaction",
node_type="CogneeUserInteraction", node_type="CogneeUserInteraction",
attributes={ attributes={
'type': 'CogneeUserInteraction', "type": "CogneeUserInteraction",
'timestamp': int((current_time - timedelta(days=10)).timestamp() * 1000) "timestamp": int((current_time - timedelta(days=10)).timestamp() * 1000),
} },
) )
graph.add_node(old_node) graph.add_node(old_node)
# Add element # Add element
element = Node(id="element_1", node_type="DocumentChunk", attributes={'type': 'DocumentChunk'}) element = Node(
id="element_1", node_type="DocumentChunk", attributes={"type": "DocumentChunk"}
)
graph.add_node(element) graph.add_node(element)
# Add edges # Add edges
graph.add_edge(Edge( graph.add_edge(
node1=recent_node, node2=element, Edge(
edge_type="used_graph_element_to_answer", node1=recent_node,
attributes={'relationship_type': 'used_graph_element_to_answer'} node2=element,
)) edge_type="used_graph_element_to_answer",
graph.add_edge(Edge( attributes={"relationship_type": "used_graph_element_to_answer"},
node1=old_node, node2=element, )
edge_type="used_graph_element_to_answer", )
attributes={'relationship_type': 'used_graph_element_to_answer'} graph.add_edge(
)) Edge(
node1=old_node,
node2=element,
edge_type="used_graph_element_to_answer",
attributes={"relationship_type": "used_graph_element_to_answer"},
)
)
# Extract with 7-day window # Extract with 7-day window
result = await extract_usage_frequency( result = await extract_usage_frequency(
subgraphs=[graph], subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=1
time_window=timedelta(days=7),
min_interaction_threshold=1
) )
# Should only count recent interaction # Should only count recent interaction
self.assertEqual(result['interactions_in_window'], 1) self.assertEqual(result["interactions_in_window"], 1)
self.assertEqual(result['total_interactions'], 2) self.assertEqual(result["total_interactions"], 2)
async def test_threshold_filtering(self): async def test_threshold_filtering(self):
"""Test that minimum threshold filters low-frequency nodes.""" """Test that minimum threshold filters low-frequency nodes."""
@ -160,13 +162,11 @@ class TestUsageFrequencyExtraction(unittest.TestCase):
# Extract with threshold of 3 # Extract with threshold of 3
result = await extract_usage_frequency( result = await extract_usage_frequency(
subgraphs=[graph], subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=3
time_window=timedelta(days=7),
min_interaction_threshold=3
) )
# Only nodes with 3+ accesses should be included # Only nodes with 3+ accesses should be included
for node_id, freq in result['node_frequencies'].items(): for node_id, freq in result["node_frequencies"].items():
self.assertGreaterEqual(freq, 3) self.assertGreaterEqual(freq, 3)
async def test_element_type_tracking(self): async def test_element_type_tracking(self):
@ -178,49 +178,46 @@ class TestUsageFrequencyExtraction(unittest.TestCase):
id="interaction_1", id="interaction_1",
node_type="CogneeUserInteraction", node_type="CogneeUserInteraction",
attributes={ attributes={
'type': 'CogneeUserInteraction', "type": "CogneeUserInteraction",
'timestamp': int(datetime.now().timestamp() * 1000) "timestamp": int(datetime.now().timestamp() * 1000),
} },
) )
graph.add_node(interaction) graph.add_node(interaction)
# Create elements of different types # Create elements of different types
chunk = Node(id="chunk_1", node_type="DocumentChunk", attributes={'type': 'DocumentChunk'}) chunk = Node(id="chunk_1", node_type="DocumentChunk", attributes={"type": "DocumentChunk"})
entity = Node(id="entity_1", node_type="Entity", attributes={'type': 'Entity'}) entity = Node(id="entity_1", node_type="Entity", attributes={"type": "Entity"})
graph.add_node(chunk) graph.add_node(chunk)
graph.add_node(entity) graph.add_node(entity)
# Add edges # Add edges
for element in [chunk, entity]: for element in [chunk, entity]:
graph.add_edge(Edge( graph.add_edge(
node1=interaction, node2=element, Edge(
edge_type="used_graph_element_to_answer", node1=interaction,
attributes={'relationship_type': 'used_graph_element_to_answer'} node2=element,
)) edge_type="used_graph_element_to_answer",
attributes={"relationship_type": "used_graph_element_to_answer"},
)
)
result = await extract_usage_frequency( result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7))
subgraphs=[graph],
time_window=timedelta(days=7)
)
# Check element types were tracked # Check element types were tracked
self.assertIn('element_type_frequencies', result) self.assertIn("element_type_frequencies", result)
types = result['element_type_frequencies'] types = result["element_type_frequencies"]
self.assertIn('DocumentChunk', types) self.assertIn("DocumentChunk", types)
self.assertIn('Entity', types) self.assertIn("Entity", types)
async def test_empty_graph(self): async def test_empty_graph(self):
"""Test handling of empty graph.""" """Test handling of empty graph."""
graph = CogneeGraph() graph = CogneeGraph()
result = await extract_usage_frequency( result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7))
subgraphs=[graph],
time_window=timedelta(days=7)
)
self.assertEqual(result['total_interactions'], 0) self.assertEqual(result["total_interactions"], 0)
self.assertEqual(len(result['node_frequencies']), 0) self.assertEqual(len(result["node_frequencies"]), 0)
async def test_no_interactions_in_window(self): async def test_no_interactions_in_window(self):
"""Test handling when all interactions are outside time window.""" """Test handling when all interactions are outside time window."""
@ -232,19 +229,16 @@ class TestUsageFrequencyExtraction(unittest.TestCase):
id="old_interaction", id="old_interaction",
node_type="CogneeUserInteraction", node_type="CogneeUserInteraction",
attributes={ attributes={
'type': 'CogneeUserInteraction', "type": "CogneeUserInteraction",
'timestamp': int(old_time.timestamp() * 1000) "timestamp": int(old_time.timestamp() * 1000),
} },
) )
graph.add_node(old_interaction) graph.add_node(old_interaction)
result = await extract_usage_frequency( result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7))
subgraphs=[graph],
time_window=timedelta(days=7)
)
self.assertEqual(result['interactions_in_window'], 0) self.assertEqual(result["interactions_in_window"], 0)
self.assertEqual(result['total_interactions'], 1) self.assertEqual(result["total_interactions"], 1)
class TestIntegration(unittest.TestCase): class TestIntegration(unittest.TestCase):
@ -266,6 +260,7 @@ class TestIntegration(unittest.TestCase):
# Test Runner # Test Runner
# ============================================================================ # ============================================================================
def run_async_test(test_func): def run_async_test(test_func):
"""Helper to run async test functions.""" """Helper to run async test functions."""
asyncio.run(test_func()) asyncio.run(test_func())

View file

@ -39,6 +39,7 @@ load_dotenv()
# STEP 1: Setup and Configuration # STEP 1: Setup and Configuration
# ============================================================================ # ============================================================================
async def setup_knowledge_base(): async def setup_knowledge_base():
""" """
Create a fresh knowledge base with sample content. Create a fresh knowledge base with sample content.
@ -104,6 +105,7 @@ async def setup_knowledge_base():
# STEP 2: Simulate User Searches with Interaction Tracking # STEP 2: Simulate User Searches with Interaction Tracking
# ============================================================================ # ============================================================================
async def simulate_user_searches(queries: List[str]): async def simulate_user_searches(queries: List[str]):
""" """
Simulate users searching the knowledge base. Simulate users searching the knowledge base.
@ -131,7 +133,7 @@ async def simulate_user_searches(queries: List[str]):
query_type=SearchType.GRAPH_COMPLETION, query_type=SearchType.GRAPH_COMPLETION,
query_text=query, query_text=query,
save_interaction=True, # ← THIS IS CRITICAL! save_interaction=True, # ← THIS IS CRITICAL!
top_k=5 top_k=5,
) )
successful_searches += 1 successful_searches += 1
@ -152,9 +154,9 @@ async def simulate_user_searches(queries: List[str]):
# STEP 3: Extract and Apply Usage Frequencies # STEP 3: Extract and Apply Usage Frequencies
# ============================================================================ # ============================================================================
async def extract_and_apply_frequencies( async def extract_and_apply_frequencies(
time_window_days: int = 7, time_window_days: int = 7, min_threshold: int = 1
min_threshold: int = 1
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Extract usage frequencies from interactions and apply them to the graph. Extract usage frequencies from interactions and apply them to the graph.
@ -184,8 +186,14 @@ async def extract_and_apply_frequencies(
await graph.project_graph_from_db( await graph.project_graph_from_db(
adapter=graph_engine, adapter=graph_engine,
node_properties_to_project=[ node_properties_to_project=[
"type", "node_type", "timestamp", "created_at", "type",
"text", "name", "query_text", "frequency_weight" "node_type",
"timestamp",
"created_at",
"text",
"name",
"query_text",
"frequency_weight",
], ],
edge_properties_to_project=["relationship_type", "timestamp"], edge_properties_to_project=["relationship_type", "timestamp"],
directed=True, directed=True,
@ -195,9 +203,10 @@ async def extract_and_apply_frequencies(
# Count interaction nodes # Count interaction nodes
interaction_nodes = [ interaction_nodes = [
n for n in graph.nodes.values() n
if n.attributes.get('type') == 'CogneeUserInteraction' or for n in graph.nodes.values()
n.attributes.get('node_type') == 'CogneeUserInteraction' if n.attributes.get("type") == "CogneeUserInteraction"
or n.attributes.get("node_type") == "CogneeUserInteraction"
] ]
print(f"✓ Found {len(interaction_nodes)} interaction nodes") print(f"✓ Found {len(interaction_nodes)} interaction nodes")
@ -207,11 +216,13 @@ async def extract_and_apply_frequencies(
graph_adapter=graph_engine, graph_adapter=graph_engine,
subgraphs=[graph], subgraphs=[graph],
time_window=timedelta(days=time_window_days), time_window=timedelta(days=time_window_days),
min_interaction_threshold=min_threshold min_interaction_threshold=min_threshold,
) )
print(f"\n✓ Frequency extraction complete!") print("\n✓ Frequency extraction complete!")
print(f" - Interactions processed: {stats['interactions_in_window']}/{stats['total_interactions']}") print(
f" - Interactions processed: {stats['interactions_in_window']}/{stats['total_interactions']}"
)
print(f" - Nodes weighted: {len(stats['node_frequencies'])}") print(f" - Nodes weighted: {len(stats['node_frequencies'])}")
print(f" - Element types tracked: {stats.get('element_type_frequencies', {})}") print(f" - Element types tracked: {stats.get('element_type_frequencies', {})}")
@ -224,6 +235,7 @@ async def extract_and_apply_frequencies(
# STEP 4: Analyze and Display Results # STEP 4: Analyze and Display Results
# ============================================================================ # ============================================================================
async def analyze_results(stats: Dict[str, Any]): async def analyze_results(stats: Dict[str, Any]):
""" """
Analyze and display the frequency tracking results. Analyze and display the frequency tracking results.
@ -241,15 +253,11 @@ async def analyze_results(stats: Dict[str, Any]):
print("=" * 80) print("=" * 80)
# Display top nodes by frequency # Display top nodes by frequency
if stats['node_frequencies']: if stats["node_frequencies"]:
print("\n📊 Top 10 Most Frequently Accessed Elements:") print("\n📊 Top 10 Most Frequently Accessed Elements:")
print("-" * 80) print("-" * 80)
sorted_nodes = sorted( sorted_nodes = sorted(stats["node_frequencies"].items(), key=lambda x: x[1], reverse=True)
stats['node_frequencies'].items(),
key=lambda x: x[1],
reverse=True
)
# Get graph to display node details # Get graph to display node details
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
@ -264,8 +272,8 @@ async def analyze_results(stats: Dict[str, Any]):
for i, (node_id, frequency) in enumerate(sorted_nodes[:10], 1): for i, (node_id, frequency) in enumerate(sorted_nodes[:10], 1):
node = graph.get_node(node_id) node = graph.get_node(node_id)
if node: if node:
node_type = node.attributes.get('type', 'Unknown') node_type = node.attributes.get("type", "Unknown")
text = node.attributes.get('text') or node.attributes.get('name') or '' text = node.attributes.get("text") or node.attributes.get("name") or ""
text_preview = text[:60] + "..." if len(text) > 60 else text text_preview = text[:60] + "..." if len(text) > 60 else text
print(f"\n{i}. Frequency: {frequency} accesses") print(f"\n{i}. Frequency: {frequency} accesses")
@ -276,10 +284,10 @@ async def analyze_results(stats: Dict[str, Any]):
print(f" Node ID: {node_id[:50]}...") print(f" Node ID: {node_id[:50]}...")
# Display element type distribution # Display element type distribution
if stats.get('element_type_frequencies'): if stats.get("element_type_frequencies"):
print("\n\n📈 Element Type Distribution:") print("\n\n📈 Element Type Distribution:")
print("-" * 80) print("-" * 80)
type_dist = stats['element_type_frequencies'] type_dist = stats["element_type_frequencies"]
for elem_type, count in sorted(type_dist.items(), key=lambda x: x[1], reverse=True): for elem_type, count in sorted(type_dist.items(), key=lambda x: x[1], reverse=True):
print(f" {elem_type}: {count} accesses") print(f" {elem_type}: {count} accesses")
@ -290,7 +298,7 @@ async def analyze_results(stats: Dict[str, Any]):
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
adapter_type = type(graph_engine).__name__ adapter_type = type(graph_engine).__name__
if adapter_type == 'Neo4jAdapter': if adapter_type == "Neo4jAdapter":
try: try:
result = await graph_engine.query(""" result = await graph_engine.query("""
MATCH (n) MATCH (n)
@ -298,7 +306,7 @@ async def analyze_results(stats: Dict[str, Any]):
RETURN count(n) as weighted_count RETURN count(n) as weighted_count
""") """)
count = result[0]['weighted_count'] if result else 0 count = result[0]["weighted_count"] if result else 0
if count > 0: if count > 0:
print(f"{count} nodes have frequency_weight in Neo4j database") print(f"{count} nodes have frequency_weight in Neo4j database")
@ -328,6 +336,7 @@ async def analyze_results(stats: Dict[str, Any]):
# STEP 5: Demonstrate Usage in Retrieval # STEP 5: Demonstrate Usage in Retrieval
# ============================================================================ # ============================================================================
async def demonstrate_retrieval_usage(): async def demonstrate_retrieval_usage():
""" """
Demonstrate how frequency weights can be used in retrieval. Demonstrate how frequency weights can be used in retrieval.
@ -379,6 +388,7 @@ async def demonstrate_retrieval_usage():
# MAIN: Run Complete Example # MAIN: Run Complete Example
# ============================================================================ # ============================================================================
async def main(): async def main():
""" """
Run the complete end-to-end usage frequency tracking example. Run the complete end-to-end usage frequency tracking example.
@ -398,7 +408,7 @@ async def main():
print(f" LLM Provider: {os.getenv('LLM_PROVIDER')}") print(f" LLM Provider: {os.getenv('LLM_PROVIDER')}")
# Verify LLM key is set # Verify LLM key is set
if not os.getenv('LLM_API_KEY') or os.getenv('LLM_API_KEY') == 'sk-your-key-here': if not os.getenv("LLM_API_KEY") or os.getenv("LLM_API_KEY") == "sk-your-key-here":
print("\n⚠ WARNING: LLM_API_KEY not set in .env file") print("\n⚠ WARNING: LLM_API_KEY not set in .env file")
print(" Set your API key to run searches") print(" Set your API key to run searches")
return return
@ -430,10 +440,7 @@ async def main():
return return
# Step 3: Extract frequencies # Step 3: Extract frequencies
stats = await extract_and_apply_frequencies( stats = await extract_and_apply_frequencies(time_window_days=7, min_threshold=1)
time_window_days=7,
min_threshold=1
)
# Step 4: Analyze results # Step 4: Analyze results
await analyze_results(stats) await analyze_results(stats)
@ -451,7 +458,7 @@ async def main():
print("\n") print("\n")
print("Summary:") print("Summary:")
print(f" ✓ Documents added: 4") print(" ✓ Documents added: 4")
print(f" ✓ Searches performed: {successful_searches}") print(f" ✓ Searches performed: {successful_searches}")
print(f" ✓ Interactions tracked: {stats['interactions_in_window']}") print(f" ✓ Interactions tracked: {stats['interactions_in_window']}")
print(f" ✓ Nodes weighted: {len(stats['node_frequencies'])}") print(f" ✓ Nodes weighted: {len(stats['node_frequencies'])}")
@ -467,6 +474,7 @@ async def main():
except Exception as e: except Exception as e:
print(f"\n✗ Example failed: {e}") print(f"\n✗ Example failed: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()