feat: feat to support issue #1458 frequency weights addition for neo4j backend
This commit is contained in:
parent
e0c7e68dd6
commit
d09b6df241
3 changed files with 926 additions and 772 deletions
|
|
@ -1,3 +1,4 @@
|
|||
# cognee/tasks/memify/extract_usage_frequency.py
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
|
@ -51,10 +52,72 @@ async def extract_usage_frequency(
|
|||
|
||||
if node_type == 'CogneeUserInteraction':
|
||||
# Parse and validate timestamp
|
||||
timestamp_str = node.attributes.get('timestamp') or node.attributes.get('created_at')
|
||||
if timestamp_str:
|
||||
timestamp_value = node.attributes.get('timestamp') or node.attributes.get('created_at')
|
||||
if timestamp_value is not None:
|
||||
try:
|
||||
interaction_time = datetime.fromisoformat(timestamp_str)
|
||||
# Handle various timestamp formats
|
||||
interaction_time = None
|
||||
|
||||
if isinstance(timestamp_value, datetime):
|
||||
# Already a Python datetime
|
||||
interaction_time = timestamp_value
|
||||
elif isinstance(timestamp_value, (int, float)):
|
||||
# Unix timestamp (assume milliseconds if > 10 digits)
|
||||
if timestamp_value > 10000000000:
|
||||
# Milliseconds since epoch
|
||||
interaction_time = datetime.fromtimestamp(timestamp_value / 1000.0)
|
||||
else:
|
||||
# Seconds since epoch
|
||||
interaction_time = datetime.fromtimestamp(timestamp_value)
|
||||
elif isinstance(timestamp_value, str):
|
||||
# Try different string formats
|
||||
if timestamp_value.isdigit():
|
||||
# Numeric string - treat as Unix timestamp
|
||||
ts_int = int(timestamp_value)
|
||||
if ts_int > 10000000000:
|
||||
interaction_time = datetime.fromtimestamp(ts_int / 1000.0)
|
||||
else:
|
||||
interaction_time = datetime.fromtimestamp(ts_int)
|
||||
else:
|
||||
# ISO format string
|
||||
interaction_time = datetime.fromisoformat(timestamp_value)
|
||||
elif hasattr(timestamp_value, 'to_native'):
|
||||
# Neo4j datetime object - convert to Python datetime
|
||||
interaction_time = timestamp_value.to_native()
|
||||
elif hasattr(timestamp_value, 'year') and hasattr(timestamp_value, 'month'):
|
||||
# Datetime-like object - extract components
|
||||
try:
|
||||
interaction_time = datetime(
|
||||
year=timestamp_value.year,
|
||||
month=timestamp_value.month,
|
||||
day=timestamp_value.day,
|
||||
hour=getattr(timestamp_value, 'hour', 0),
|
||||
minute=getattr(timestamp_value, 'minute', 0),
|
||||
second=getattr(timestamp_value, 'second', 0),
|
||||
microsecond=getattr(timestamp_value, 'microsecond', 0)
|
||||
)
|
||||
except (AttributeError, ValueError):
|
||||
pass
|
||||
|
||||
if interaction_time is None:
|
||||
# Last resort: try converting to string and parsing
|
||||
str_value = str(timestamp_value)
|
||||
if str_value.isdigit():
|
||||
ts_int = int(str_value)
|
||||
if ts_int > 10000000000:
|
||||
interaction_time = datetime.fromtimestamp(ts_int / 1000.0)
|
||||
else:
|
||||
interaction_time = datetime.fromtimestamp(ts_int)
|
||||
else:
|
||||
interaction_time = datetime.fromisoformat(str_value)
|
||||
|
||||
if interaction_time is None:
|
||||
raise ValueError(f"Could not parse timestamp: {timestamp_value}")
|
||||
|
||||
# Make sure it's timezone-naive for comparison
|
||||
if interaction_time.tzinfo is not None:
|
||||
interaction_time = interaction_time.replace(tzinfo=None)
|
||||
|
||||
interaction_nodes[node_id] = {
|
||||
'node': node,
|
||||
'timestamp': interaction_time,
|
||||
|
|
@ -63,8 +126,9 @@ async def extract_usage_frequency(
|
|||
interaction_count += 1
|
||||
if interaction_time >= cutoff_time:
|
||||
interactions_in_window += 1
|
||||
except (ValueError, TypeError) as e:
|
||||
except (ValueError, TypeError, AttributeError, OSError) as e:
|
||||
logger.warning(f"Failed to parse timestamp for interaction node {node_id}: {e}")
|
||||
logger.debug(f"Timestamp value type: {type(timestamp_value)}, value: {timestamp_value}")
|
||||
|
||||
# Process edges to find graph elements used in interactions
|
||||
for edge in subgraph.edges:
|
||||
|
|
@ -141,7 +205,7 @@ async def add_frequency_weights(
|
|||
"""
|
||||
Add frequency weights to graph nodes and edges using the graph adapter.
|
||||
|
||||
Uses the "get → tweak dict → update" contract consistent with graph adapters.
|
||||
Uses direct Cypher queries for Neo4j adapter compatibility.
|
||||
Writes frequency_weight properties back to the graph for use in:
|
||||
- Ranking frequently referenced entities higher during retrieval
|
||||
- Adjusting scoring for completion strategies
|
||||
|
|
@ -155,43 +219,174 @@ async def add_frequency_weights(
|
|||
|
||||
logger.info(f"Adding frequency weights to {len(node_frequencies)} nodes")
|
||||
|
||||
# Update node frequencies using get → tweak → update pattern
|
||||
# Check adapter type and use appropriate method
|
||||
adapter_type = type(graph_adapter).__name__
|
||||
logger.info(f"Using adapter: {adapter_type}")
|
||||
|
||||
nodes_updated = 0
|
||||
nodes_failed = 0
|
||||
|
||||
for node_id, frequency in node_frequencies.items():
|
||||
# Determine which method to use based on adapter type
|
||||
use_neo4j_cypher = adapter_type == 'Neo4jAdapter' and hasattr(graph_adapter, 'query')
|
||||
use_kuzu_query = adapter_type == 'KuzuAdapter' and hasattr(graph_adapter, 'query')
|
||||
use_get_update = hasattr(graph_adapter, 'get_node_by_id') and hasattr(graph_adapter, 'update_node_properties')
|
||||
|
||||
# Method 1: Neo4j Cypher with SET (creates properties on the fly)
|
||||
if use_neo4j_cypher:
|
||||
try:
|
||||
# Get current node data
|
||||
node_data = await graph_adapter.get_node_by_id(node_id)
|
||||
logger.info("Using Neo4j Cypher SET method")
|
||||
last_updated = usage_frequencies.get('last_processed_timestamp')
|
||||
|
||||
if node_data:
|
||||
# Tweak the properties dict - add frequency_weight
|
||||
if isinstance(node_data, dict):
|
||||
properties = node_data.get('properties', {})
|
||||
for node_id, frequency in node_frequencies.items():
|
||||
try:
|
||||
query = """
|
||||
MATCH (n)
|
||||
WHERE n.id = $node_id
|
||||
SET n.frequency_weight = $frequency,
|
||||
n.frequency_updated_at = $updated_at
|
||||
RETURN n.id as id
|
||||
"""
|
||||
|
||||
result = await graph_adapter.query(
|
||||
query,
|
||||
params={
|
||||
'node_id': node_id,
|
||||
'frequency': frequency,
|
||||
'updated_at': last_updated
|
||||
}
|
||||
)
|
||||
|
||||
if result and len(result) > 0:
|
||||
nodes_updated += 1
|
||||
else:
|
||||
logger.warning(f"Node {node_id} not found or not updated")
|
||||
nodes_failed += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating node {node_id}: {e}")
|
||||
nodes_failed += 1
|
||||
|
||||
logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Neo4j Cypher update failed: {e}")
|
||||
use_neo4j_cypher = False
|
||||
|
||||
# Method 2: Kuzu - use get_node + add_node (updates via re-adding with same ID)
|
||||
elif use_kuzu_query and hasattr(graph_adapter, 'get_node') and hasattr(graph_adapter, 'add_node'):
|
||||
logger.info("Using Kuzu get_node + add_node method")
|
||||
last_updated = usage_frequencies.get('last_processed_timestamp')
|
||||
|
||||
for node_id, frequency in node_frequencies.items():
|
||||
try:
|
||||
# Get the existing node (returns a dict)
|
||||
existing_node_dict = await graph_adapter.get_node(node_id)
|
||||
|
||||
if existing_node_dict:
|
||||
# Update the dict with new properties
|
||||
existing_node_dict['frequency_weight'] = frequency
|
||||
existing_node_dict['frequency_updated_at'] = last_updated
|
||||
|
||||
# Kuzu's add_node likely just takes the dict directly, not a Node object
|
||||
# Try passing the dict directly first
|
||||
try:
|
||||
await graph_adapter.add_node(existing_node_dict)
|
||||
nodes_updated += 1
|
||||
except Exception as dict_error:
|
||||
# If dict doesn't work, try creating a Node object
|
||||
logger.debug(f"Dict add failed, trying Node object: {dict_error}")
|
||||
|
||||
try:
|
||||
from cognee.infrastructure.engine import Node
|
||||
# Try different Node constructor patterns
|
||||
try:
|
||||
# Pattern 1: Just properties
|
||||
node_obj = Node(existing_node_dict)
|
||||
except:
|
||||
# Pattern 2: Type and properties
|
||||
node_obj = Node(
|
||||
type=existing_node_dict.get('type', 'Unknown'),
|
||||
**existing_node_dict
|
||||
)
|
||||
|
||||
await graph_adapter.add_node(node_obj)
|
||||
nodes_updated += 1
|
||||
except Exception as node_error:
|
||||
logger.error(f"Both dict and Node object failed: {node_error}")
|
||||
nodes_failed += 1
|
||||
else:
|
||||
# Handle case where node_data might be a node object
|
||||
properties = getattr(node_data, 'properties', {}) or {}
|
||||
|
||||
# Update with frequency weight
|
||||
properties['frequency_weight'] = frequency
|
||||
|
||||
# Also store when this was last updated
|
||||
properties['frequency_updated_at'] = usage_frequencies.get('last_processed_timestamp')
|
||||
|
||||
# Write back via adapter
|
||||
await graph_adapter.update_node_properties(node_id, properties)
|
||||
nodes_updated += 1
|
||||
else:
|
||||
logger.warning(f"Node {node_id} not found in graph")
|
||||
logger.warning(f"Node {node_id} not found in graph")
|
||||
nodes_failed += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating node {node_id}: {e}")
|
||||
nodes_failed += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating node {node_id}: {e}")
|
||||
nodes_failed += 1
|
||||
logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed")
|
||||
|
||||
logger.info(
|
||||
f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed"
|
||||
)
|
||||
# Method 3: Generic get_node_by_id + update_node_properties
|
||||
elif use_get_update:
|
||||
logger.info("Using get/update method for adapter")
|
||||
for node_id, frequency in node_frequencies.items():
|
||||
try:
|
||||
# Get current node data
|
||||
node_data = await graph_adapter.get_node_by_id(node_id)
|
||||
|
||||
if node_data:
|
||||
# Tweak the properties dict - add frequency_weight
|
||||
if isinstance(node_data, dict):
|
||||
properties = node_data.get('properties', {})
|
||||
else:
|
||||
properties = getattr(node_data, 'properties', {}) or {}
|
||||
|
||||
# Update with frequency weight
|
||||
properties['frequency_weight'] = frequency
|
||||
properties['frequency_updated_at'] = usage_frequencies.get('last_processed_timestamp')
|
||||
|
||||
# Write back via adapter
|
||||
await graph_adapter.update_node_properties(node_id, properties)
|
||||
nodes_updated += 1
|
||||
else:
|
||||
logger.warning(f"Node {node_id} not found in graph")
|
||||
nodes_failed += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating node {node_id}: {e}")
|
||||
nodes_failed += 1
|
||||
|
||||
logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed")
|
||||
for node_id, frequency in node_frequencies.items():
|
||||
try:
|
||||
# Get current node data
|
||||
node_data = await graph_adapter.get_node_by_id(node_id)
|
||||
|
||||
if node_data:
|
||||
# Tweak the properties dict - add frequency_weight
|
||||
if isinstance(node_data, dict):
|
||||
properties = node_data.get('properties', {})
|
||||
else:
|
||||
properties = getattr(node_data, 'properties', {}) or {}
|
||||
|
||||
# Update with frequency weight
|
||||
properties['frequency_weight'] = frequency
|
||||
properties['frequency_updated_at'] = usage_frequencies.get('last_processed_timestamp')
|
||||
|
||||
# Write back via adapter
|
||||
await graph_adapter.update_node_properties(node_id, properties)
|
||||
nodes_updated += 1
|
||||
else:
|
||||
logger.warning(f"Node {node_id} not found in graph")
|
||||
nodes_failed += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating node {node_id}: {e}")
|
||||
nodes_failed += 1
|
||||
|
||||
# If no method is available
|
||||
if not use_neo4j_cypher and not use_kuzu_query and not use_get_update:
|
||||
logger.error(f"Adapter {adapter_type} does not support required update methods")
|
||||
logger.error("Required: either 'query' method or both 'get_node_by_id' and 'update_node_properties'")
|
||||
return
|
||||
|
||||
# Update edge frequencies
|
||||
# Note: Edge property updates are backend-specific
|
||||
|
|
|
|||
|
|
@ -1,503 +1,313 @@
|
|||
# cognee/tests/test_usage_frequency.py
|
||||
"""
|
||||
Test suite for usage frequency tracking functionality.
|
||||
Test Suite: Usage Frequency Tracking
|
||||
|
||||
Tests cover:
|
||||
- Frequency extraction from CogneeUserInteraction nodes
|
||||
- Time window filtering
|
||||
- Frequency weight application to graph
|
||||
- Edge cases and error handling
|
||||
Comprehensive tests for the usage frequency tracking implementation.
|
||||
Tests cover extraction logic, adapter integration, edge cases, and end-to-end workflows.
|
||||
|
||||
Run with:
|
||||
pytest test_usage_frequency_comprehensive.py -v
|
||||
|
||||
Or without pytest:
|
||||
python test_usage_frequency_comprehensive.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from typing import Dict, Any
|
||||
from typing import List, Dict
|
||||
|
||||
from cognee.tasks.memify.extract_usage_frequency import (
|
||||
extract_usage_frequency,
|
||||
add_frequency_weights,
|
||||
create_usage_frequency_pipeline,
|
||||
run_usage_frequency_update,
|
||||
)
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
|
||||
|
||||
|
||||
def create_mock_node(node_id: str, attributes: Dict[str, Any]) -> Node:
|
||||
"""Helper to create mock Node objects."""
|
||||
node = Node(node_id, attributes)
|
||||
return node
|
||||
|
||||
|
||||
def create_mock_edge(node1: Node, node2: Node, relationship_type: str, attributes: Dict[str, Any] = None) -> Edge:
|
||||
"""Helper to create mock Edge objects."""
|
||||
edge_attrs = attributes or {}
|
||||
edge_attrs['relationship_type'] = relationship_type
|
||||
edge = Edge(node1, node2, attributes=edge_attrs, directed=True)
|
||||
return edge
|
||||
|
||||
|
||||
def create_interaction_graph(
|
||||
interaction_count: int = 3,
|
||||
target_nodes: list = None,
|
||||
time_offset_hours: int = 0
|
||||
) -> CogneeGraph:
|
||||
"""
|
||||
Create a mock CogneeGraph with interaction nodes.
|
||||
|
||||
:param interaction_count: Number of interactions to create
|
||||
:param target_nodes: List of target node IDs to reference
|
||||
:param time_offset_hours: Hours to offset timestamp (negative = past)
|
||||
:return: CogneeGraph with mocked interaction data
|
||||
"""
|
||||
graph = CogneeGraph(directed=True)
|
||||
|
||||
if target_nodes is None:
|
||||
target_nodes = ['node1', 'node2', 'node3']
|
||||
|
||||
# Create some target graph element nodes
|
||||
element_nodes = {}
|
||||
for i, node_id in enumerate(target_nodes):
|
||||
element_node = create_mock_node(
|
||||
node_id,
|
||||
{
|
||||
'type': 'DocumentChunk',
|
||||
'text': f'This is content for {node_id}',
|
||||
'name': f'Element {i+1}'
|
||||
}
|
||||
)
|
||||
graph.add_node(element_node)
|
||||
element_nodes[node_id] = element_node
|
||||
|
||||
# Create interaction nodes and edges
|
||||
timestamp = datetime.now() + timedelta(hours=time_offset_hours)
|
||||
|
||||
for i in range(interaction_count):
|
||||
# Create interaction node
|
||||
interaction_id = f'interaction_{i}'
|
||||
target_id = target_nodes[i % len(target_nodes)]
|
||||
|
||||
interaction_node = create_mock_node(
|
||||
interaction_id,
|
||||
{
|
||||
'type': 'CogneeUserInteraction',
|
||||
'timestamp': timestamp.isoformat(),
|
||||
'query_text': f'Sample query {i}',
|
||||
'target_node_id': target_id # Also store in attributes for completeness
|
||||
}
|
||||
)
|
||||
graph.add_node(interaction_node)
|
||||
|
||||
# Create edge from interaction to target element
|
||||
target_element = element_nodes[target_id]
|
||||
edge = create_mock_edge(
|
||||
interaction_node,
|
||||
target_element,
|
||||
'used_graph_element_to_answer',
|
||||
{'timestamp': timestamp.isoformat()}
|
||||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_usage_frequency_basic():
|
||||
"""Test basic frequency extraction with simple interaction data."""
|
||||
# Create mock graph with 3 interactions
|
||||
# node1 referenced twice, node2 referenced once
|
||||
mock_graph = create_interaction_graph(
|
||||
interaction_count=3,
|
||||
target_nodes=['node1', 'node1', 'node2']
|
||||
)
|
||||
|
||||
# Extract frequencies
|
||||
result = await extract_usage_frequency(
|
||||
subgraphs=[mock_graph],
|
||||
time_window=timedelta(days=1),
|
||||
min_interaction_threshold=1
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert 'node_frequencies' in result
|
||||
assert 'edge_frequencies' in result
|
||||
assert result['node_frequencies']['node1'] == 2
|
||||
assert result['node_frequencies']['node2'] == 1
|
||||
assert result['total_interactions'] == 3
|
||||
assert result['interactions_in_window'] == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_usage_frequency_time_window():
|
||||
"""Test that time window filtering works correctly."""
|
||||
# Create two graphs: one recent, one old
|
||||
recent_graph = create_interaction_graph(
|
||||
interaction_count=2,
|
||||
target_nodes=['node1', 'node2'],
|
||||
time_offset_hours=-1 # 1 hour ago
|
||||
)
|
||||
|
||||
old_graph = create_interaction_graph(
|
||||
interaction_count=2,
|
||||
target_nodes=['node3', 'node4'],
|
||||
time_offset_hours=-200 # 200 hours ago (> 7 days)
|
||||
)
|
||||
|
||||
# Extract with 7-day window
|
||||
result = await extract_usage_frequency(
|
||||
subgraphs=[recent_graph, old_graph],
|
||||
time_window=timedelta(days=7),
|
||||
min_interaction_threshold=1
|
||||
)
|
||||
|
||||
# Only recent interactions should be counted
|
||||
assert result['total_interactions'] == 4 # All interactions found
|
||||
assert result['interactions_in_window'] == 2 # Only recent ones counted
|
||||
assert 'node1' in result['node_frequencies']
|
||||
assert 'node2' in result['node_frequencies']
|
||||
assert 'node3' not in result['node_frequencies'] # Too old
|
||||
assert 'node4' not in result['node_frequencies'] # Too old
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_usage_frequency_threshold():
|
||||
"""Test minimum interaction threshold filtering."""
|
||||
# Create graph where node1 has 3 interactions, node2 has 1
|
||||
mock_graph = create_interaction_graph(
|
||||
interaction_count=4,
|
||||
target_nodes=['node1', 'node1', 'node1', 'node2']
|
||||
)
|
||||
|
||||
# Extract with threshold of 2
|
||||
result = await extract_usage_frequency(
|
||||
subgraphs=[mock_graph],
|
||||
time_window=timedelta(days=1),
|
||||
min_interaction_threshold=2
|
||||
)
|
||||
|
||||
# Only node1 should be in results (3 >= 2)
|
||||
assert 'node1' in result['node_frequencies']
|
||||
assert result['node_frequencies']['node1'] == 3
|
||||
assert 'node2' not in result['node_frequencies'] # Below threshold
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_usage_frequency_multiple_graphs():
|
||||
"""Test extraction across multiple subgraphs."""
|
||||
graph1 = create_interaction_graph(
|
||||
interaction_count=2,
|
||||
target_nodes=['node1', 'node2']
|
||||
)
|
||||
|
||||
graph2 = create_interaction_graph(
|
||||
interaction_count=2,
|
||||
target_nodes=['node1', 'node3']
|
||||
)
|
||||
|
||||
result = await extract_usage_frequency(
|
||||
subgraphs=[graph1, graph2],
|
||||
time_window=timedelta(days=1),
|
||||
min_interaction_threshold=1
|
||||
)
|
||||
|
||||
# node1 should have frequency of 2 (once from each graph)
|
||||
assert result['node_frequencies']['node1'] == 2
|
||||
assert result['node_frequencies']['node2'] == 1
|
||||
assert result['node_frequencies']['node3'] == 1
|
||||
assert result['total_interactions'] == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_usage_frequency_empty_graph():
|
||||
"""Test handling of empty graphs."""
|
||||
empty_graph = CogneeGraph(directed=True)
|
||||
|
||||
result = await extract_usage_frequency(
|
||||
subgraphs=[empty_graph],
|
||||
time_window=timedelta(days=1),
|
||||
min_interaction_threshold=1
|
||||
)
|
||||
|
||||
assert result['node_frequencies'] == {}
|
||||
assert result['edge_frequencies'] == {}
|
||||
assert result['total_interactions'] == 0
|
||||
assert result['interactions_in_window'] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_usage_frequency_invalid_timestamps():
|
||||
"""Test handling of invalid timestamp formats."""
|
||||
graph = CogneeGraph(directed=True)
|
||||
|
||||
# Create interaction with invalid timestamp
|
||||
bad_interaction = create_mock_node(
|
||||
'bad_interaction',
|
||||
{
|
||||
'type': 'CogneeUserInteraction',
|
||||
'timestamp': 'not-a-valid-timestamp',
|
||||
'target_node_id': 'node1'
|
||||
}
|
||||
)
|
||||
graph.add_node(bad_interaction)
|
||||
|
||||
# Should not crash, just skip invalid interaction
|
||||
result = await extract_usage_frequency(
|
||||
subgraphs=[graph],
|
||||
time_window=timedelta(days=1),
|
||||
min_interaction_threshold=1
|
||||
)
|
||||
|
||||
assert result['total_interactions'] == 0 # Invalid interaction not counted
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_usage_frequency_element_type_tracking():
|
||||
"""Test that element type frequencies are tracked."""
|
||||
graph = CogneeGraph(directed=True)
|
||||
|
||||
# Create different types of target nodes
|
||||
chunk_node = create_mock_node('chunk1', {'type': 'DocumentChunk', 'text': 'content'})
|
||||
entity_node = create_mock_node('entity1', {'type': 'Entity', 'name': 'Alice'})
|
||||
|
||||
graph.add_node(chunk_node)
|
||||
graph.add_node(entity_node)
|
||||
|
||||
# Create interactions pointing to each
|
||||
timestamp = datetime.now().isoformat()
|
||||
|
||||
for i, target in enumerate([chunk_node, chunk_node, entity_node]):
|
||||
interaction = create_mock_node(
|
||||
f'interaction_{i}',
|
||||
{'type': 'CogneeUserInteraction', 'timestamp': timestamp}
|
||||
)
|
||||
graph.add_node(interaction)
|
||||
|
||||
edge = create_mock_edge(interaction, target, 'used_graph_element_to_answer')
|
||||
graph.add_edge(edge)
|
||||
|
||||
result = await extract_usage_frequency(
|
||||
subgraphs=[graph],
|
||||
time_window=timedelta(days=1),
|
||||
min_interaction_threshold=1
|
||||
)
|
||||
|
||||
# Check element type frequencies
|
||||
assert 'element_type_frequencies' in result
|
||||
assert result['element_type_frequencies']['DocumentChunk'] == 2
|
||||
assert result['element_type_frequencies']['Entity'] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_frequency_weights():
|
||||
"""Test adding frequency weights to graph via adapter."""
|
||||
# Mock graph adapter
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.get_node_by_id = AsyncMock(return_value={
|
||||
'id': 'node1',
|
||||
'properties': {'type': 'DocumentChunk', 'text': 'content'}
|
||||
})
|
||||
mock_adapter.update_node_properties = AsyncMock()
|
||||
|
||||
# Mock usage frequencies
|
||||
usage_frequencies = {
|
||||
'node_frequencies': {'node1': 5, 'node2': 3},
|
||||
'edge_frequencies': {},
|
||||
'last_processed_timestamp': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Add weights
|
||||
await add_frequency_weights(mock_adapter, usage_frequencies)
|
||||
|
||||
# Verify adapter methods were called
|
||||
assert mock_adapter.get_node_by_id.call_count == 2
|
||||
assert mock_adapter.update_node_properties.call_count == 2
|
||||
|
||||
# Verify the properties passed to update include frequency_weight
|
||||
calls = mock_adapter.update_node_properties.call_args_list
|
||||
properties_updated = calls[0][0][1] # Second argument of first call
|
||||
assert 'frequency_weight' in properties_updated
|
||||
assert properties_updated['frequency_weight'] == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_frequency_weights_node_not_found():
|
||||
"""Test handling when node is not found in graph."""
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.get_node_by_id = AsyncMock(return_value=None) # Node not found
|
||||
mock_adapter.update_node_properties = AsyncMock()
|
||||
|
||||
usage_frequencies = {
|
||||
'node_frequencies': {'nonexistent_node': 5},
|
||||
'edge_frequencies': {},
|
||||
'last_processed_timestamp': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Should not crash
|
||||
await add_frequency_weights(mock_adapter, usage_frequencies)
|
||||
|
||||
# Update should not be called since node wasn't found
|
||||
assert mock_adapter.update_node_properties.call_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_frequency_weights_with_metadata_support():
|
||||
"""Test that metadata is stored when adapter supports it."""
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.get_node_by_id = AsyncMock(return_value={'properties': {}})
|
||||
mock_adapter.update_node_properties = AsyncMock()
|
||||
mock_adapter.set_metadata = AsyncMock() # Adapter supports metadata
|
||||
|
||||
usage_frequencies = {
|
||||
'node_frequencies': {'node1': 5},
|
||||
'edge_frequencies': {},
|
||||
'element_type_frequencies': {'DocumentChunk': 5},
|
||||
'total_interactions': 10,
|
||||
'interactions_in_window': 8,
|
||||
'last_processed_timestamp': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
await add_frequency_weights(mock_adapter, usage_frequencies)
|
||||
|
||||
# Verify metadata was stored
|
||||
mock_adapter.set_metadata.assert_called_once()
|
||||
metadata_key, metadata_value = mock_adapter.set_metadata.call_args[0]
|
||||
assert metadata_key == 'usage_frequency_stats'
|
||||
assert 'total_interactions' in metadata_value
|
||||
assert metadata_value['total_interactions'] == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_usage_frequency_pipeline():
|
||||
"""Test pipeline creation returns correct task structure."""
|
||||
mock_adapter = AsyncMock()
|
||||
|
||||
extraction_tasks, enrichment_tasks = await create_usage_frequency_pipeline(
|
||||
graph_adapter=mock_adapter,
|
||||
time_window=timedelta(days=7),
|
||||
min_interaction_threshold=2,
|
||||
batch_size=50
|
||||
)
|
||||
|
||||
# Verify task structure
|
||||
assert len(extraction_tasks) == 1
|
||||
assert len(enrichment_tasks) == 1
|
||||
|
||||
# Verify extraction task
|
||||
extraction_task = extraction_tasks[0]
|
||||
assert hasattr(extraction_task, 'function')
|
||||
|
||||
# Verify enrichment task
|
||||
enrichment_task = enrichment_tasks[0]
|
||||
assert hasattr(enrichment_task, 'function')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_usage_frequency_update_integration():
|
||||
"""Test the full end-to-end update process."""
|
||||
# Create mock graph with interactions
|
||||
mock_graph = create_interaction_graph(
|
||||
interaction_count=5,
|
||||
target_nodes=['node1', 'node1', 'node2', 'node3', 'node1']
|
||||
)
|
||||
|
||||
# Mock adapter
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.get_node_by_id = AsyncMock(return_value={'properties': {}})
|
||||
mock_adapter.update_node_properties = AsyncMock()
|
||||
|
||||
# Run the full update
|
||||
stats = await run_usage_frequency_update(
|
||||
graph_adapter=mock_adapter,
|
||||
subgraphs=[mock_graph],
|
||||
time_window=timedelta(days=1),
|
||||
min_interaction_threshold=1
|
||||
)
|
||||
|
||||
# Verify stats
|
||||
assert stats['total_interactions'] == 5
|
||||
assert stats['node_frequencies']['node1'] == 3
|
||||
assert stats['node_frequencies']['node2'] == 1
|
||||
assert stats['node_frequencies']['node3'] == 1
|
||||
|
||||
# Verify adapter was called to update nodes
|
||||
assert mock_adapter.update_node_properties.call_count == 3 # 3 unique nodes
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_usage_frequency_no_used_graph_element_edges():
|
||||
"""Test handling when there are interactions but no proper edges."""
|
||||
graph = CogneeGraph(directed=True)
|
||||
|
||||
# Create interaction node
|
||||
interaction = create_mock_node(
|
||||
'interaction1',
|
||||
{
|
||||
'type': 'CogneeUserInteraction',
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'target_node_id': 'node1'
|
||||
}
|
||||
)
|
||||
graph.add_node(interaction)
|
||||
|
||||
# Don't add any edges - interaction is orphaned
|
||||
|
||||
result = await extract_usage_frequency(
|
||||
subgraphs=[graph],
|
||||
time_window=timedelta(days=1),
|
||||
min_interaction_threshold=1
|
||||
)
|
||||
|
||||
# Should find the interaction but no frequencies (no edges)
|
||||
assert result['total_interactions'] == 1
|
||||
assert result['node_frequencies'] == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_usage_frequency_alternative_timestamp_field():
|
||||
"""Test that 'created_at' field works as fallback for timestamp."""
|
||||
graph = CogneeGraph(directed=True)
|
||||
|
||||
target = create_mock_node('target1', {'type': 'DocumentChunk'})
|
||||
graph.add_node(target)
|
||||
|
||||
# Use 'created_at' instead of 'timestamp'
|
||||
interaction = create_mock_node(
|
||||
'interaction1',
|
||||
{
|
||||
'type': 'CogneeUserInteraction',
|
||||
'created_at': datetime.now().isoformat() # Alternative field
|
||||
}
|
||||
)
|
||||
graph.add_node(interaction)
|
||||
|
||||
edge = create_mock_edge(interaction, target, 'used_graph_element_to_answer')
|
||||
graph.add_edge(edge)
|
||||
|
||||
result = await extract_usage_frequency(
|
||||
subgraphs=[graph],
|
||||
time_window=timedelta(days=1),
|
||||
min_interaction_threshold=1
|
||||
)
|
||||
|
||||
# Should still work with created_at
|
||||
assert result['total_interactions'] == 1
|
||||
assert 'target1' in result['node_frequencies']
|
||||
|
||||
|
||||
def test_imports():
|
||||
"""Test that all required modules can be imported."""
|
||||
# Mock imports for testing without full Cognee setup
|
||||
try:
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
|
||||
from cognee.tasks.memify.extract_usage_frequency import (
|
||||
extract_usage_frequency,
|
||||
add_frequency_weights,
|
||||
create_usage_frequency_pipeline,
|
||||
run_usage_frequency_update,
|
||||
run_usage_frequency_update
|
||||
)
|
||||
COGNEE_AVAILABLE = True
|
||||
except ImportError:
|
||||
COGNEE_AVAILABLE = False
|
||||
print("⚠ Cognee not fully available - some tests will be skipped")
|
||||
|
||||
|
||||
class TestUsageFrequencyExtraction(unittest.TestCase):
|
||||
"""Test the core frequency extraction logic."""
|
||||
|
||||
assert extract_usage_frequency is not None
|
||||
assert add_frequency_weights is not None
|
||||
assert create_usage_frequency_pipeline is not None
|
||||
assert run_usage_frequency_update is not None
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
if not COGNEE_AVAILABLE:
|
||||
self.skipTest("Cognee modules not available")
|
||||
|
||||
def create_mock_graph(self, num_interactions: int = 3, num_elements: int = 5):
|
||||
"""Create a mock graph with interactions and elements."""
|
||||
graph = CogneeGraph()
|
||||
|
||||
# Create interaction nodes
|
||||
current_time = datetime.now()
|
||||
for i in range(num_interactions):
|
||||
interaction_node = Node(
|
||||
id=f"interaction_{i}",
|
||||
node_type="CogneeUserInteraction",
|
||||
attributes={
|
||||
'type': 'CogneeUserInteraction',
|
||||
'query_text': f'Test query {i}',
|
||||
'timestamp': int((current_time - timedelta(hours=i)).timestamp() * 1000)
|
||||
}
|
||||
)
|
||||
graph.add_node(interaction_node)
|
||||
|
||||
# Create graph element nodes
|
||||
for i in range(num_elements):
|
||||
element_node = Node(
|
||||
id=f"element_{i}",
|
||||
node_type="DocumentChunk",
|
||||
attributes={
|
||||
'type': 'DocumentChunk',
|
||||
'text': f'Element content {i}'
|
||||
}
|
||||
)
|
||||
graph.add_node(element_node)
|
||||
|
||||
# Create usage edges (interactions reference elements)
|
||||
for i in range(num_interactions):
|
||||
# Each interaction uses 2-3 elements
|
||||
for j in range(2):
|
||||
element_idx = (i + j) % num_elements
|
||||
edge = Edge(
|
||||
node1=graph.get_node(f"interaction_{i}"),
|
||||
node2=graph.get_node(f"element_{element_idx}"),
|
||||
edge_type="used_graph_element_to_answer",
|
||||
attributes={'relationship_type': 'used_graph_element_to_answer'}
|
||||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
return graph
|
||||
|
||||
async def test_basic_frequency_extraction(self):
|
||||
"""Test basic frequency extraction with simple graph."""
|
||||
graph = self.create_mock_graph(num_interactions=3, num_elements=5)
|
||||
|
||||
result = await extract_usage_frequency(
|
||||
subgraphs=[graph],
|
||||
time_window=timedelta(days=7),
|
||||
min_interaction_threshold=1
|
||||
)
|
||||
|
||||
self.assertIn('node_frequencies', result)
|
||||
self.assertIn('total_interactions', result)
|
||||
self.assertEqual(result['total_interactions'], 3)
|
||||
self.assertGreater(len(result['node_frequencies']), 0)
|
||||
|
||||
async def test_time_window_filtering(self):
|
||||
"""Test that time window correctly filters old interactions."""
|
||||
graph = CogneeGraph()
|
||||
|
||||
current_time = datetime.now()
|
||||
|
||||
# Add recent interaction (within window)
|
||||
recent_node = Node(
|
||||
id="recent_interaction",
|
||||
node_type="CogneeUserInteraction",
|
||||
attributes={
|
||||
'type': 'CogneeUserInteraction',
|
||||
'timestamp': int(current_time.timestamp() * 1000)
|
||||
}
|
||||
)
|
||||
graph.add_node(recent_node)
|
||||
|
||||
# Add old interaction (outside window)
|
||||
old_node = Node(
|
||||
id="old_interaction",
|
||||
node_type="CogneeUserInteraction",
|
||||
attributes={
|
||||
'type': 'CogneeUserInteraction',
|
||||
'timestamp': int((current_time - timedelta(days=10)).timestamp() * 1000)
|
||||
}
|
||||
)
|
||||
graph.add_node(old_node)
|
||||
|
||||
# Add element
|
||||
element = Node(id="element_1", node_type="DocumentChunk", attributes={'type': 'DocumentChunk'})
|
||||
graph.add_node(element)
|
||||
|
||||
# Add edges
|
||||
graph.add_edge(Edge(
|
||||
node1=recent_node, node2=element,
|
||||
edge_type="used_graph_element_to_answer",
|
||||
attributes={'relationship_type': 'used_graph_element_to_answer'}
|
||||
))
|
||||
graph.add_edge(Edge(
|
||||
node1=old_node, node2=element,
|
||||
edge_type="used_graph_element_to_answer",
|
||||
attributes={'relationship_type': 'used_graph_element_to_answer'}
|
||||
))
|
||||
|
||||
# Extract with 7-day window
|
||||
result = await extract_usage_frequency(
|
||||
subgraphs=[graph],
|
||||
time_window=timedelta(days=7),
|
||||
min_interaction_threshold=1
|
||||
)
|
||||
|
||||
# Should only count recent interaction
|
||||
self.assertEqual(result['interactions_in_window'], 1)
|
||||
self.assertEqual(result['total_interactions'], 2)
|
||||
|
||||
async def test_threshold_filtering(self):
|
||||
"""Test that minimum threshold filters low-frequency nodes."""
|
||||
graph = self.create_mock_graph(num_interactions=5, num_elements=10)
|
||||
|
||||
# Extract with threshold of 3
|
||||
result = await extract_usage_frequency(
|
||||
subgraphs=[graph],
|
||||
time_window=timedelta(days=7),
|
||||
min_interaction_threshold=3
|
||||
)
|
||||
|
||||
# Only nodes with 3+ accesses should be included
|
||||
for node_id, freq in result['node_frequencies'].items():
|
||||
self.assertGreaterEqual(freq, 3)
|
||||
|
||||
async def test_element_type_tracking(self):
|
||||
"""Test that element types are properly tracked."""
|
||||
graph = CogneeGraph()
|
||||
|
||||
# Create interaction
|
||||
interaction = Node(
|
||||
id="interaction_1",
|
||||
node_type="CogneeUserInteraction",
|
||||
attributes={
|
||||
'type': 'CogneeUserInteraction',
|
||||
'timestamp': int(datetime.now().timestamp() * 1000)
|
||||
}
|
||||
)
|
||||
graph.add_node(interaction)
|
||||
|
||||
# Create elements of different types
|
||||
chunk = Node(id="chunk_1", node_type="DocumentChunk", attributes={'type': 'DocumentChunk'})
|
||||
entity = Node(id="entity_1", node_type="Entity", attributes={'type': 'Entity'})
|
||||
|
||||
graph.add_node(chunk)
|
||||
graph.add_node(entity)
|
||||
|
||||
# Add edges
|
||||
for element in [chunk, entity]:
|
||||
graph.add_edge(Edge(
|
||||
node1=interaction, node2=element,
|
||||
edge_type="used_graph_element_to_answer",
|
||||
attributes={'relationship_type': 'used_graph_element_to_answer'}
|
||||
))
|
||||
|
||||
result = await extract_usage_frequency(
|
||||
subgraphs=[graph],
|
||||
time_window=timedelta(days=7)
|
||||
)
|
||||
|
||||
# Check element types were tracked
|
||||
self.assertIn('element_type_frequencies', result)
|
||||
types = result['element_type_frequencies']
|
||||
self.assertIn('DocumentChunk', types)
|
||||
self.assertIn('Entity', types)
|
||||
|
||||
async def test_empty_graph(self):
|
||||
"""Test handling of empty graph."""
|
||||
graph = CogneeGraph()
|
||||
|
||||
result = await extract_usage_frequency(
|
||||
subgraphs=[graph],
|
||||
time_window=timedelta(days=7)
|
||||
)
|
||||
|
||||
self.assertEqual(result['total_interactions'], 0)
|
||||
self.assertEqual(len(result['node_frequencies']), 0)
|
||||
|
||||
async def test_no_interactions_in_window(self):
|
||||
"""Test handling when all interactions are outside time window."""
|
||||
graph = CogneeGraph()
|
||||
|
||||
# Add old interaction
|
||||
old_time = datetime.now() - timedelta(days=30)
|
||||
old_interaction = Node(
|
||||
id="old_interaction",
|
||||
node_type="CogneeUserInteraction",
|
||||
attributes={
|
||||
'type': 'CogneeUserInteraction',
|
||||
'timestamp': int(old_time.timestamp() * 1000)
|
||||
}
|
||||
)
|
||||
graph.add_node(old_interaction)
|
||||
|
||||
result = await extract_usage_frequency(
|
||||
subgraphs=[graph],
|
||||
time_window=timedelta(days=7)
|
||||
)
|
||||
|
||||
self.assertEqual(result['interactions_in_window'], 0)
|
||||
self.assertEqual(result['total_interactions'], 1)
|
||||
|
||||
|
||||
class TestIntegration(unittest.TestCase):
|
||||
"""Integration tests for the complete workflow."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
if not COGNEE_AVAILABLE:
|
||||
self.skipTest("Cognee modules not available")
|
||||
|
||||
async def test_end_to_end_workflow(self):
|
||||
"""Test the complete end-to-end frequency tracking workflow."""
|
||||
# This would require a full Cognee setup with database
|
||||
# Skipped in unit tests, run as part of example_usage_frequency_e2e.py
|
||||
self.skipTest("E2E test - run example_usage_frequency_e2e.py instead")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Runner
|
||||
# ============================================================================
|
||||
|
||||
def run_async_test(test_func):
|
||||
"""Helper to run async test functions."""
|
||||
asyncio.run(test_func())
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests."""
|
||||
if not COGNEE_AVAILABLE:
|
||||
print("⚠ Cognee not available - skipping tests")
|
||||
print("Install with: pip install cognee[neo4j]")
|
||||
return
|
||||
|
||||
print("=" * 80)
|
||||
print("Running Usage Frequency Tests")
|
||||
print("=" * 80)
|
||||
print()
|
||||
|
||||
# Create test suite
|
||||
loader = unittest.TestLoader()
|
||||
suite = unittest.TestSuite()
|
||||
|
||||
# Add tests
|
||||
suite.addTests(loader.loadTestsFromTestCase(TestUsageFrequencyExtraction))
|
||||
suite.addTests(loader.loadTestsFromTestCase(TestIntegration))
|
||||
|
||||
# Run tests
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(suite)
|
||||
|
||||
# Summary
|
||||
print()
|
||||
print("=" * 80)
|
||||
print("Test Summary")
|
||||
print("=" * 80)
|
||||
print(f"Tests run: {result.testsRun}")
|
||||
print(f"Successes: {result.testsRun - len(result.failures) - len(result.errors)}")
|
||||
print(f"Failures: {len(result.failures)}")
|
||||
print(f"Errors: {len(result.errors)}")
|
||||
print(f"Skipped: {len(result.skipped)}")
|
||||
|
||||
return 0 if result.wasSuccessful() else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
exit(main())
|
||||
|
|
@ -1,324 +1,473 @@
|
|||
# cognee/examples/usage_frequency_example.py
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
End-to-end example demonstrating usage frequency tracking in Cognee.
|
||||
End-to-End Example: Usage Frequency Tracking in Cognee
|
||||
|
||||
This example shows how to:
|
||||
1. Add data and build a knowledge graph
|
||||
2. Run searches with save_interaction=True to track usage
|
||||
3. Extract and apply frequency weights using the memify pipeline
|
||||
4. Query and analyze the frequency data
|
||||
This example demonstrates the complete workflow for tracking and analyzing
|
||||
how frequently different graph elements are accessed through user searches.
|
||||
|
||||
The frequency weights can be used to:
|
||||
- Rank frequently referenced entities higher during retrieval
|
||||
- Adjust scoring for completion strategies
|
||||
- Expose usage metrics in dashboards or audits
|
||||
Features demonstrated:
|
||||
- Setting up a knowledge base
|
||||
- Running searches with interaction tracking (save_interaction=True)
|
||||
- Extracting usage frequencies from interaction data
|
||||
- Applying frequency weights to graph nodes
|
||||
- Analyzing and visualizing the results
|
||||
|
||||
Use cases:
|
||||
- Ranking search results by popularity
|
||||
- Identifying "hot topics" in your knowledge base
|
||||
- Understanding user behavior and interests
|
||||
- Improving retrieval based on usage patterns
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from typing import List
|
||||
from typing import List, Dict, Any
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import cognee
|
||||
from cognee.api.v1.search import SearchType
|
||||
from cognee.tasks.memify.extract_usage_frequency import (
|
||||
create_usage_frequency_pipeline,
|
||||
run_usage_frequency_update,
|
||||
)
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.tasks.memify.extract_usage_frequency import run_usage_frequency_update
|
||||
|
||||
logger = get_logger("usage_frequency_example")
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STEP 1: Setup and Configuration
|
||||
# ============================================================================
|
||||
|
||||
async def setup_knowledge_base():
|
||||
"""Set up a fresh knowledge base with sample data."""
|
||||
logger.info("Setting up knowledge base...")
|
||||
"""
|
||||
Create a fresh knowledge base with sample content.
|
||||
|
||||
# Reset cognee state for clean slate
|
||||
In a real application, you would:
|
||||
- Load documents from files, databases, or APIs
|
||||
- Process larger datasets
|
||||
- Organize content by datasets/categories
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("STEP 1: Setting up knowledge base")
|
||||
print("=" * 80)
|
||||
|
||||
# Reset state for clean demo (optional in production)
|
||||
print("\nResetting Cognee state...")
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
# Sample conversation about AI/ML topics
|
||||
conversation = [
|
||||
"Alice discusses machine learning algorithms and their applications in computer vision.",
|
||||
"Bob asks about neural networks and how they differ from traditional algorithms.",
|
||||
"Alice explains deep learning concepts including CNNs and transformers.",
|
||||
"Bob wants more details about neural networks and backpropagation.",
|
||||
"Alice describes reinforcement learning and its use in robotics.",
|
||||
"Bob inquires about natural language processing and transformers.",
|
||||
]
|
||||
|
||||
# Add conversation data and build knowledge graph
|
||||
logger.info("Adding conversation data...")
|
||||
await cognee.add(conversation, dataset_name="ai_ml_conversation")
|
||||
print("✓ Reset complete")
|
||||
|
||||
logger.info("Building knowledge graph (cognify)...")
|
||||
# Sample content: AI/ML educational material
|
||||
documents = [
|
||||
"""
|
||||
Machine Learning Fundamentals:
|
||||
Machine learning is a subset of artificial intelligence that enables systems
|
||||
to learn and improve from experience without being explicitly programmed.
|
||||
The three main types are supervised learning, unsupervised learning, and
|
||||
reinforcement learning.
|
||||
""",
|
||||
"""
|
||||
Neural Networks Explained:
|
||||
Neural networks are computing systems inspired by biological neural networks.
|
||||
They consist of layers of interconnected nodes (neurons) that process information
|
||||
through weighted connections. Deep learning uses neural networks with many layers
|
||||
to automatically learn hierarchical representations of data.
|
||||
""",
|
||||
"""
|
||||
Natural Language Processing:
|
||||
NLP enables computers to understand, interpret, and generate human language.
|
||||
Modern NLP uses transformer architectures like BERT and GPT, which have
|
||||
revolutionized tasks such as translation, summarization, and question answering.
|
||||
""",
|
||||
"""
|
||||
Computer Vision Applications:
|
||||
Computer vision allows machines to interpret visual information from the world.
|
||||
Convolutional neural networks (CNNs) are particularly effective for image
|
||||
recognition, object detection, and image segmentation tasks.
|
||||
""",
|
||||
]
|
||||
|
||||
print(f"\nAdding {len(documents)} documents to knowledge base...")
|
||||
await cognee.add(documents, dataset_name="ai_ml_fundamentals")
|
||||
print("✓ Documents added")
|
||||
|
||||
# Build knowledge graph
|
||||
print("\nBuilding knowledge graph (cognify)...")
|
||||
await cognee.cognify()
|
||||
print("✓ Knowledge graph built")
|
||||
|
||||
logger.info("Knowledge base setup complete")
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
|
||||
async def simulate_user_searches():
|
||||
"""Simulate multiple user searches to generate interaction data."""
|
||||
logger.info("Simulating user searches with save_interaction=True...")
|
||||
# ============================================================================
|
||||
# STEP 2: Simulate User Searches with Interaction Tracking
|
||||
# ============================================================================
|
||||
|
||||
async def simulate_user_searches(queries: List[str]):
|
||||
"""
|
||||
Simulate users searching the knowledge base.
|
||||
|
||||
# Different queries that will create CogneeUserInteraction nodes
|
||||
queries = [
|
||||
"What is machine learning?",
|
||||
"Explain neural networks",
|
||||
"Tell me about deep learning",
|
||||
"What are neural networks?", # Repeat to increase frequency
|
||||
"How does machine learning work?",
|
||||
"Describe transformers in NLP",
|
||||
"What is reinforcement learning?",
|
||||
"Explain neural networks again", # Another repeat
|
||||
]
|
||||
|
||||
search_count = 0
|
||||
for query in queries:
|
||||
The key parameter is save_interaction=True, which creates:
|
||||
- CogneeUserInteraction nodes (one per search)
|
||||
- used_graph_element_to_answer edges (connecting queries to relevant nodes)
|
||||
|
||||
Args:
|
||||
queries: List of search queries to simulate
|
||||
|
||||
Returns:
|
||||
Number of successful searches
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("STEP 2: Simulating user searches with interaction tracking")
|
||||
print("=" * 80)
|
||||
|
||||
successful_searches = 0
|
||||
|
||||
for i, query in enumerate(queries, 1):
|
||||
print(f"\nSearch {i}/{len(queries)}: '{query}'")
|
||||
try:
|
||||
logger.info(f"Searching: '{query}'")
|
||||
results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text=query,
|
||||
save_interaction=True, # Critical: saves interaction to graph
|
||||
save_interaction=True, # ← THIS IS CRITICAL!
|
||||
top_k=5
|
||||
)
|
||||
search_count += 1
|
||||
logger.debug(f"Search completed, got {len(results) if results else 0} results")
|
||||
successful_searches += 1
|
||||
|
||||
# Show snippet of results
|
||||
result_preview = str(results)[:100] if results else "No results"
|
||||
print(f" ✓ Completed ({result_preview}...)")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Search failed for '{query}': {e}")
|
||||
|
||||
logger.info(f"Completed {search_count} searches with interactions saved")
|
||||
return search_count
|
||||
|
||||
|
||||
async def retrieve_interaction_graph() -> List[CogneeGraph]:
|
||||
"""Retrieve the graph containing interaction nodes."""
|
||||
logger.info("Retrieving graph with interaction data...")
|
||||
print(f" ✗ Failed: {e}")
|
||||
|
||||
print(f"\n✓ Completed {successful_searches}/{len(queries)} searches")
|
||||
print("=" * 80)
|
||||
|
||||
return successful_searches
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STEP 3: Extract and Apply Usage Frequencies
|
||||
# ============================================================================
|
||||
|
||||
async def extract_and_apply_frequencies(
|
||||
time_window_days: int = 7,
|
||||
min_threshold: int = 1
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract usage frequencies from interactions and apply them to the graph.
|
||||
|
||||
This function:
|
||||
1. Retrieves the graph with interaction data
|
||||
2. Counts how often each node was accessed
|
||||
3. Writes frequency_weight property back to nodes
|
||||
|
||||
Args:
|
||||
time_window_days: Only count interactions from last N days
|
||||
min_threshold: Minimum accesses to track (filter out rarely used nodes)
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics about the frequency update
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("STEP 3: Extracting and applying usage frequencies")
|
||||
print("=" * 80)
|
||||
|
||||
# Get graph adapter
|
||||
graph_engine = await get_graph_engine()
|
||||
graph = CogneeGraph()
|
||||
|
||||
# Project the full graph including CogneeUserInteraction nodes
|
||||
# Retrieve graph with interactions
|
||||
print("\nRetrieving graph from database...")
|
||||
graph = CogneeGraph()
|
||||
await graph.project_graph_from_db(
|
||||
adapter=graph_engine,
|
||||
node_properties_to_project=["type", "node_type", "timestamp", "created_at", "text", "name"],
|
||||
edge_properties_to_project=["relationship_type", "timestamp", "created_at"],
|
||||
node_properties_to_project=[
|
||||
"type", "node_type", "timestamp", "created_at",
|
||||
"text", "name", "query_text", "frequency_weight"
|
||||
],
|
||||
edge_properties_to_project=["relationship_type", "timestamp"],
|
||||
directed=True,
|
||||
)
|
||||
|
||||
logger.info(f"Retrieved graph: {len(graph.nodes)} nodes, {len(graph.edges)} edges")
|
||||
print(f"✓ Retrieved: {len(graph.nodes)} nodes, {len(graph.edges)} edges")
|
||||
|
||||
# Count interaction nodes for verification
|
||||
interaction_count = sum(
|
||||
1 for node in graph.nodes.values()
|
||||
if node.attributes.get('type') == 'CogneeUserInteraction' or
|
||||
node.attributes.get('node_type') == 'CogneeUserInteraction'
|
||||
)
|
||||
logger.info(f"Found {interaction_count} CogneeUserInteraction nodes in graph")
|
||||
# Count interaction nodes
|
||||
interaction_nodes = [
|
||||
n for n in graph.nodes.values()
|
||||
if n.attributes.get('type') == 'CogneeUserInteraction' or
|
||||
n.attributes.get('node_type') == 'CogneeUserInteraction'
|
||||
]
|
||||
print(f"✓ Found {len(interaction_nodes)} interaction nodes")
|
||||
|
||||
return [graph]
|
||||
|
||||
|
||||
async def run_frequency_pipeline_method1():
|
||||
"""Method 1: Using the pipeline creation function."""
|
||||
logger.info("\n=== Method 1: Using create_usage_frequency_pipeline ===")
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
subgraphs = await retrieve_interaction_graph()
|
||||
|
||||
# Create the pipeline tasks
|
||||
extraction_tasks, enrichment_tasks = await create_usage_frequency_pipeline(
|
||||
graph_adapter=graph_engine,
|
||||
time_window=timedelta(days=30), # Last 30 days
|
||||
min_interaction_threshold=1, # Count all interactions
|
||||
batch_size=100
|
||||
)
|
||||
|
||||
logger.info("Running extraction tasks...")
|
||||
# Note: In real memify pipeline, these would be executed by the pipeline runner
|
||||
# For this example, we'll execute them manually
|
||||
for task in extraction_tasks:
|
||||
if hasattr(task, 'function'):
|
||||
result = await task.function(
|
||||
subgraphs=subgraphs,
|
||||
time_window=timedelta(days=30),
|
||||
min_interaction_threshold=1
|
||||
)
|
||||
logger.info(f"Extraction result: {result.get('interactions_in_window')} interactions processed")
|
||||
|
||||
logger.info("Running enrichment tasks...")
|
||||
for task in enrichment_tasks:
|
||||
if hasattr(task, 'function'):
|
||||
await task.function(
|
||||
graph_adapter=graph_engine,
|
||||
usage_frequencies=result
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def run_frequency_pipeline_method2():
|
||||
"""Method 2: Using the convenience function."""
|
||||
logger.info("\n=== Method 2: Using run_usage_frequency_update ===")
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
subgraphs = await retrieve_interaction_graph()
|
||||
|
||||
# Run the complete pipeline in one call
|
||||
# Run frequency extraction and update
|
||||
print(f"\nExtracting frequencies (time window: {time_window_days} days)...")
|
||||
stats = await run_usage_frequency_update(
|
||||
graph_adapter=graph_engine,
|
||||
subgraphs=subgraphs,
|
||||
time_window=timedelta(days=30),
|
||||
min_interaction_threshold=1
|
||||
subgraphs=[graph],
|
||||
time_window=timedelta(days=time_window_days),
|
||||
min_interaction_threshold=min_threshold
|
||||
)
|
||||
|
||||
logger.info("Frequency update statistics:")
|
||||
logger.info(f" Total interactions: {stats['total_interactions']}")
|
||||
logger.info(f" Interactions in window: {stats['interactions_in_window']}")
|
||||
logger.info(f" Nodes with frequency weights: {len(stats['node_frequencies'])}")
|
||||
logger.info(f" Element types: {stats.get('element_type_frequencies', {})}")
|
||||
print(f"\n✓ Frequency extraction complete!")
|
||||
print(f" - Interactions processed: {stats['interactions_in_window']}/{stats['total_interactions']}")
|
||||
print(f" - Nodes weighted: {len(stats['node_frequencies'])}")
|
||||
print(f" - Element types tracked: {stats.get('element_type_frequencies', {})}")
|
||||
|
||||
print("=" * 80)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def analyze_frequency_weights():
|
||||
"""Analyze and display the frequency weights that were added."""
|
||||
logger.info("\n=== Analyzing Frequency Weights ===")
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
graph = CogneeGraph()
|
||||
|
||||
# Project graph with frequency weights
|
||||
await graph.project_graph_from_db(
|
||||
adapter=graph_engine,
|
||||
node_properties_to_project=[
|
||||
"type",
|
||||
"node_type",
|
||||
"text",
|
||||
"name",
|
||||
"frequency_weight", # Our added property
|
||||
"frequency_updated_at"
|
||||
],
|
||||
edge_properties_to_project=["relationship_type"],
|
||||
directed=True,
|
||||
)
|
||||
|
||||
# Find nodes with frequency weights
|
||||
weighted_nodes = []
|
||||
for node_id, node in graph.nodes.items():
|
||||
freq_weight = node.attributes.get('frequency_weight')
|
||||
if freq_weight is not None:
|
||||
weighted_nodes.append({
|
||||
'id': node_id,
|
||||
'type': node.attributes.get('type') or node.attributes.get('node_type'),
|
||||
'text': node.attributes.get('text', '')[:100], # First 100 chars
|
||||
'name': node.attributes.get('name', ''),
|
||||
'frequency_weight': freq_weight,
|
||||
'updated_at': node.attributes.get('frequency_updated_at')
|
||||
})
|
||||
|
||||
# Sort by frequency (descending)
|
||||
weighted_nodes.sort(key=lambda x: x['frequency_weight'], reverse=True)
|
||||
|
||||
logger.info(f"\nFound {len(weighted_nodes)} nodes with frequency weights:")
|
||||
logger.info("\nTop 10 Most Frequently Referenced Elements:")
|
||||
logger.info("-" * 80)
|
||||
|
||||
for i, node in enumerate(weighted_nodes[:10], 1):
|
||||
logger.info(f"\n{i}. Frequency: {node['frequency_weight']}")
|
||||
logger.info(f" Type: {node['type']}")
|
||||
logger.info(f" Name: {node['name']}")
|
||||
logger.info(f" Text: {node['text']}")
|
||||
logger.info(f" ID: {node['id'][:50]}...")
|
||||
|
||||
return weighted_nodes
|
||||
# ============================================================================
|
||||
# STEP 4: Analyze and Display Results
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def demonstrate_retrieval_with_frequencies():
|
||||
"""Demonstrate how frequency weights can be used in retrieval."""
|
||||
logger.info("\n=== Demonstrating Retrieval with Frequency Weights ===")
|
||||
async def analyze_results(stats: Dict[str, Any]):
|
||||
"""
|
||||
Analyze and display the frequency tracking results.
|
||||
|
||||
# This is a conceptual demonstration of how frequency weights
|
||||
# could be used to boost search results
|
||||
Shows:
|
||||
- Top most frequently accessed nodes
|
||||
- Element type distribution
|
||||
- Verification that weights were written to database
|
||||
|
||||
query = "neural networks"
|
||||
logger.info(f"Searching for: '{query}'")
|
||||
Args:
|
||||
stats: Statistics from frequency extraction
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("STEP 4: Analyzing usage frequency results")
|
||||
print("=" * 80)
|
||||
|
||||
try:
|
||||
# Standard search
|
||||
standard_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text=query,
|
||||
save_interaction=False, # Don't add more interactions
|
||||
top_k=5
|
||||
# Display top nodes by frequency
|
||||
if stats['node_frequencies']:
|
||||
print("\n📊 Top 10 Most Frequently Accessed Elements:")
|
||||
print("-" * 80)
|
||||
|
||||
sorted_nodes = sorted(
|
||||
stats['node_frequencies'].items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True
|
||||
)
|
||||
|
||||
logger.info(f"Standard search returned {len(standard_results) if standard_results else 0} results")
|
||||
# Get graph to display node details
|
||||
graph_engine = await get_graph_engine()
|
||||
graph = CogneeGraph()
|
||||
await graph.project_graph_from_db(
|
||||
adapter=graph_engine,
|
||||
node_properties_to_project=["type", "text", "name"],
|
||||
edge_properties_to_project=[],
|
||||
directed=True,
|
||||
)
|
||||
|
||||
# Note: To actually use frequency_weight in scoring, you would need to:
|
||||
# 1. Modify the retrieval/ranking logic to consider frequency_weight
|
||||
# 2. Add frequency_weight as a scoring factor in the completion strategy
|
||||
# 3. Use it in analytics dashboards to show popular topics
|
||||
|
||||
logger.info("\nFrequency weights can now be used for:")
|
||||
logger.info(" - Boosting frequently-accessed nodes in search rankings")
|
||||
logger.info(" - Adjusting triplet importance scores")
|
||||
logger.info(" - Building usage analytics dashboards")
|
||||
logger.info(" - Identifying 'hot' topics in the knowledge graph")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Demonstration search failed: {e}")
|
||||
for i, (node_id, frequency) in enumerate(sorted_nodes[:10], 1):
|
||||
node = graph.get_node(node_id)
|
||||
if node:
|
||||
node_type = node.attributes.get('type', 'Unknown')
|
||||
text = node.attributes.get('text') or node.attributes.get('name') or ''
|
||||
text_preview = text[:60] + "..." if len(text) > 60 else text
|
||||
|
||||
print(f"\n{i}. Frequency: {frequency} accesses")
|
||||
print(f" Type: {node_type}")
|
||||
print(f" Content: {text_preview}")
|
||||
else:
|
||||
print(f"\n{i}. Frequency: {frequency} accesses")
|
||||
print(f" Node ID: {node_id[:50]}...")
|
||||
|
||||
# Display element type distribution
|
||||
if stats.get('element_type_frequencies'):
|
||||
print("\n\n📈 Element Type Distribution:")
|
||||
print("-" * 80)
|
||||
type_dist = stats['element_type_frequencies']
|
||||
for elem_type, count in sorted(type_dist.items(), key=lambda x: x[1], reverse=True):
|
||||
print(f" {elem_type}: {count} accesses")
|
||||
|
||||
# Verify weights in database (Neo4j only)
|
||||
print("\n\n🔍 Verifying weights in database...")
|
||||
print("-" * 80)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
adapter_type = type(graph_engine).__name__
|
||||
|
||||
if adapter_type == 'Neo4jAdapter':
|
||||
try:
|
||||
result = await graph_engine.query("""
|
||||
MATCH (n)
|
||||
WHERE n.frequency_weight IS NOT NULL
|
||||
RETURN count(n) as weighted_count
|
||||
""")
|
||||
|
||||
count = result[0]['weighted_count'] if result else 0
|
||||
if count > 0:
|
||||
print(f"✓ {count} nodes have frequency_weight in Neo4j database")
|
||||
|
||||
# Show sample
|
||||
sample = await graph_engine.query("""
|
||||
MATCH (n)
|
||||
WHERE n.frequency_weight IS NOT NULL
|
||||
RETURN n.frequency_weight as weight, labels(n) as labels
|
||||
ORDER BY n.frequency_weight DESC
|
||||
LIMIT 3
|
||||
""")
|
||||
|
||||
print("\nSample weighted nodes:")
|
||||
for row in sample:
|
||||
print(f" - Weight: {row['weight']}, Type: {row['labels']}")
|
||||
else:
|
||||
print("⚠ No nodes with frequency_weight found in database")
|
||||
except Exception as e:
|
||||
print(f"Could not verify in Neo4j: {e}")
|
||||
else:
|
||||
print(f"Database verification not implemented for {adapter_type}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STEP 5: Demonstrate Usage in Retrieval
|
||||
# ============================================================================
|
||||
|
||||
async def demonstrate_retrieval_usage():
|
||||
"""
|
||||
Demonstrate how frequency weights can be used in retrieval.
|
||||
|
||||
Note: This is a conceptual demonstration. To actually use frequency
|
||||
weights in ranking, you would need to modify the retrieval/completion
|
||||
strategies to incorporate the frequency_weight property.
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("STEP 5: How to use frequency weights in retrieval")
|
||||
print("=" * 80)
|
||||
|
||||
print("""
|
||||
Frequency weights can be used to improve search results:
|
||||
|
||||
1. RANKING BOOST:
|
||||
- Multiply relevance scores by frequency_weight
|
||||
- Prioritize frequently accessed nodes in results
|
||||
|
||||
2. COMPLETION STRATEGIES:
|
||||
- Adjust triplet importance based on usage
|
||||
- Filter out rarely accessed information
|
||||
|
||||
3. ANALYTICS:
|
||||
- Track trending topics over time
|
||||
- Understand user interests and behavior
|
||||
- Identify knowledge gaps (low-frequency nodes)
|
||||
|
||||
4. ADAPTIVE RETRIEVAL:
|
||||
- Personalize results based on team usage patterns
|
||||
- Surface popular answers faster
|
||||
|
||||
Example Cypher query with frequency boost (Neo4j):
|
||||
|
||||
MATCH (n)
|
||||
WHERE n.text CONTAINS $search_term
|
||||
RETURN n, n.frequency_weight as boost
|
||||
ORDER BY (n.relevance_score * COALESCE(n.frequency_weight, 1)) DESC
|
||||
LIMIT 10
|
||||
|
||||
To integrate this into Cognee, you would modify the completion
|
||||
strategy to include frequency_weight in the scoring function.
|
||||
""")
|
||||
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MAIN: Run Complete Example
|
||||
# ============================================================================
|
||||
|
||||
async def main():
|
||||
"""Main execution flow."""
|
||||
logger.info("=" * 80)
|
||||
logger.info("Usage Frequency Tracking Example")
|
||||
logger.info("=" * 80)
|
||||
"""
|
||||
Run the complete end-to-end usage frequency tracking example.
|
||||
"""
|
||||
print("\n")
|
||||
print("╔" + "=" * 78 + "╗")
|
||||
print("║" + " " * 78 + "║")
|
||||
print("║" + " Usage Frequency Tracking - End-to-End Example".center(78) + "║")
|
||||
print("║" + " " * 78 + "║")
|
||||
print("╚" + "=" * 78 + "╝")
|
||||
print("\n")
|
||||
|
||||
# Configuration check
|
||||
print("Configuration:")
|
||||
print(f" Graph Provider: {os.getenv('GRAPH_DATABASE_PROVIDER')}")
|
||||
print(f" Graph Handler: {os.getenv('GRAPH_DATASET_HANDLER')}")
|
||||
print(f" LLM Provider: {os.getenv('LLM_PROVIDER')}")
|
||||
|
||||
# Verify LLM key is set
|
||||
if not os.getenv('LLM_API_KEY') or os.getenv('LLM_API_KEY') == 'sk-your-key-here':
|
||||
print("\n⚠ WARNING: LLM_API_KEY not set in .env file")
|
||||
print(" Set your API key to run searches")
|
||||
return
|
||||
|
||||
print("\n")
|
||||
|
||||
try:
|
||||
# Step 1: Setup knowledge base
|
||||
# Step 1: Setup
|
||||
await setup_knowledge_base()
|
||||
|
||||
# Step 2: Simulate user searches with save_interaction=True
|
||||
search_count = await simulate_user_searches()
|
||||
# Step 2: Simulate searches
|
||||
# Note: Repeat queries increase frequency for those topics
|
||||
queries = [
|
||||
"What is machine learning?",
|
||||
"Explain neural networks",
|
||||
"How does deep learning work?",
|
||||
"Tell me about neural networks", # Repeat - increases frequency
|
||||
"What are transformers in NLP?",
|
||||
"Explain neural networks again", # Another repeat
|
||||
"How does computer vision work?",
|
||||
"What is reinforcement learning?",
|
||||
"Tell me more about neural networks", # Third repeat
|
||||
]
|
||||
|
||||
if search_count == 0:
|
||||
logger.warning("No searches completed - cannot demonstrate frequency tracking")
|
||||
successful_searches = await simulate_user_searches(queries)
|
||||
|
||||
if successful_searches == 0:
|
||||
print("⚠ No searches completed - cannot demonstrate frequency tracking")
|
||||
return
|
||||
|
||||
# Step 3: Run frequency extraction and enrichment
|
||||
# You can use either method - both accomplish the same thing
|
||||
# Step 3: Extract frequencies
|
||||
stats = await extract_and_apply_frequencies(
|
||||
time_window_days=7,
|
||||
min_threshold=1
|
||||
)
|
||||
|
||||
# Option A: Using the convenience function (recommended)
|
||||
stats = await run_frequency_pipeline_method2()
|
||||
# Step 4: Analyze results
|
||||
await analyze_results(stats)
|
||||
|
||||
# Option B: Using the pipeline creation function (for custom pipelines)
|
||||
# stats = await run_frequency_pipeline_method1()
|
||||
|
||||
# Step 4: Analyze the results
|
||||
weighted_nodes = await analyze_frequency_weights()
|
||||
|
||||
# Step 5: Demonstrate retrieval usage
|
||||
await demonstrate_retrieval_with_frequencies()
|
||||
# Step 5: Show usage examples
|
||||
await demonstrate_retrieval_usage()
|
||||
|
||||
# Summary
|
||||
logger.info("\n" + "=" * 80)
|
||||
logger.info("SUMMARY")
|
||||
logger.info("=" * 80)
|
||||
logger.info(f"Searches performed: {search_count}")
|
||||
logger.info(f"Interactions tracked: {stats.get('interactions_in_window', 0)}")
|
||||
logger.info(f"Nodes weighted: {len(weighted_nodes)}")
|
||||
logger.info(f"Time window: {stats.get('time_window_days', 0)} days")
|
||||
logger.info("\nFrequency weights have been added to the graph!")
|
||||
logger.info("These can now be used in retrieval, ranking, and analytics.")
|
||||
logger.info("=" * 80)
|
||||
print("\n")
|
||||
print("╔" + "=" * 78 + "╗")
|
||||
print("║" + " " * 78 + "║")
|
||||
print("║" + " Example Complete!".center(78) + "║")
|
||||
print("║" + " " * 78 + "║")
|
||||
print("╚" + "=" * 78 + "╝")
|
||||
print("\n")
|
||||
|
||||
print("Summary:")
|
||||
print(f" ✓ Documents added: 4")
|
||||
print(f" ✓ Searches performed: {successful_searches}")
|
||||
print(f" ✓ Interactions tracked: {stats['interactions_in_window']}")
|
||||
print(f" ✓ Nodes weighted: {len(stats['node_frequencies'])}")
|
||||
|
||||
print("\nNext steps:")
|
||||
print(" 1. Open Neo4j Browser (http://localhost:7474) to explore the graph")
|
||||
print(" 2. Modify retrieval strategies to use frequency_weight")
|
||||
print(" 3. Build analytics dashboards using element_type_frequencies")
|
||||
print(" 4. Run periodic frequency updates to track trends over time")
|
||||
|
||||
print("\n")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Example failed: {e}", exc_info=True)
|
||||
raise
|
||||
print(f"\n✗ Example failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue