This commit implements a comprehensive optimization of the Maximal Marginal Relevance (MMR) calculation in the search utilities. The key improvements include:

## Algorithm Improvements

- **True MMR Implementation**: Replaced the previous diversity-aware scoring with a proper iterative MMR algorithm that greedily selects documents one at a time
- **Vectorized Operations**: Leveraged numpy's optimized BLAS operations through matrix multiplication instead of individual dot products
- **Adaptive Strategy**: Uses different optimization strategies for small (≤100) and large datasets to balance performance and memory usage

## Performance Optimizations

- **Memory Efficiency**: Reduced memory complexity from O(n²) to O(n) for large datasets
- **BLAS Optimization**: Proper use of matrix multiplication leverages optimized BLAS libraries
- **Batch Normalization**: Added `normalize_embeddings_batch()` for efficient L2 normalization of multiple embeddings at once
- **Early Termination**: Stops selection when no candidates meet the minimum score threshold

## Key Changes

- `maximal_marginal_relevance()`: Complete rewrite with a proper iterative MMR algorithm
- `normalize_embeddings_batch()`: New function for efficient batch normalization
- `_mmr_small_dataset()`: Optimized implementation for small datasets using precomputed similarity matrices
- Added a comprehensive test suite with 9 test cases covering edge cases, correctness, and performance scenarios

## Benefits

- **Correctness**: Now implements the true MMR algorithm instead of approximate diversity scoring
- **Memory Usage**: O(n) memory complexity vs O(n²) for the original implementation
- **Scalability**: Better performance characteristics for large datasets
- **Maintainability**: Cleaner, more readable code with comprehensive test coverage

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
340 lines
13 KiB
Python
340 lines
13 KiB
Python
from unittest.mock import AsyncMock, patch
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from graphiti_core.nodes import EntityNode
|
|
from graphiti_core.search.search_filters import SearchFilters
|
|
from graphiti_core.search.search_utils import hybrid_node_search, maximal_marginal_relevance, normalize_embeddings_batch
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_hybrid_node_search_deduplication():
    """A node returned by both search methods appears once in the merged result.

    Two queries and two embeddings are supplied, and each patched search
    function is expected to be called twice. The node with uuid '1' is
    returned by both the full-text and the similarity mock.
    """

    def make_node(uuid: str, name: str) -> EntityNode:
        # All fixture nodes share the same label and group.
        return EntityNode(uuid=uuid, name=name, labels=['Entity'], group_id='1')

    driver = AsyncMock()

    with (
        patch('graphiti_core.search.search_utils.node_fulltext_search') as fulltext_mock,
        patch('graphiti_core.search.search_utils.node_similarity_search') as similarity_mock,
    ):
        # side_effect lists yield one return value per successive call.
        fulltext_mock.side_effect = [
            [make_node('1', 'Alice')],
            [make_node('2', 'Bob')],
        ]
        similarity_mock.side_effect = [
            [make_node('1', 'Alice')],
            [make_node('3', 'Charlie')],
        ]

        results = await hybrid_node_search(
            ['Alice', 'Bob'],
            [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
            driver,
            SearchFilters(),
        )

        # Four nodes were returned in total, but 'Alice' is counted once.
        assert len(results) == 3
        assert {node.uuid for node in results} == {'1', '2', '3'}
        assert {node.name for node in results} == {'Alice', 'Bob', 'Charlie'}

        # One call per query / per embedding.
        assert fulltext_mock.call_count == 2
        assert similarity_mock.call_count == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_hybrid_node_search_empty_results():
    """When both patched searches return nothing, the merged result is empty."""
    driver = AsyncMock()

    with (
        patch('graphiti_core.search.search_utils.node_fulltext_search') as fulltext_mock,
        patch('graphiti_core.search.search_utils.node_similarity_search') as similarity_mock,
    ):
        fulltext_mock.return_value = []
        similarity_mock.return_value = []

        results = await hybrid_node_search(
            ['NonExistent'],
            [[0.1, 0.2, 0.3]],
            driver,
            SearchFilters(),
        )

        assert len(results) == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_hybrid_node_search_only_fulltext():
    """With an empty embeddings list, only the full-text search is called."""
    driver = AsyncMock()
    alice = EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')

    with (
        patch('graphiti_core.search.search_utils.node_fulltext_search') as fulltext_mock,
        patch('graphiti_core.search.search_utils.node_similarity_search') as similarity_mock,
    ):
        fulltext_mock.return_value = [alice]
        similarity_mock.return_value = []

        # One query, no embeddings.
        results = await hybrid_node_search(['Alice'], [], driver, SearchFilters())

        assert len(results) == 1
        assert results[0].name == 'Alice'
        assert fulltext_mock.call_count == 1
        # The similarity search must not be invoked at all.
        assert similarity_mock.call_count == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_hybrid_node_search_with_limit():
    """A limit is forwarded to each search method, not applied to the merge.

    Each mock returns two distinct nodes regardless of the limit, and the
    test expects all four to come back from ``hybrid_node_search``.
    """

    def make_node(uuid: str, name: str) -> EntityNode:
        return EntityNode(uuid=uuid, name=name, labels=['Entity'], group_id='1')

    driver = AsyncMock()
    embedding = [0.1, 0.2, 0.3]
    limit = 1

    with (
        patch('graphiti_core.search.search_utils.node_fulltext_search') as fulltext_mock,
        patch('graphiti_core.search.search_utils.node_similarity_search') as similarity_mock,
    ):
        fulltext_mock.return_value = [make_node('1', 'Alice'), make_node('2', 'Bob')]
        similarity_mock.return_value = [make_node('3', 'Charlie'), make_node('4', 'David')]

        results = await hybrid_node_search(
            ['Test'], [embedding], driver, SearchFilters(), ['1'], limit
        )

        # The mocks ignore the limit and nothing is trimmed after merging,
        # so every mocked node is expected in the output.
        assert len(results) == 4
        assert fulltext_mock.call_count == 1
        assert similarity_mock.call_count == 1

        # The search functions are expected to receive 2 as their last
        # argument when the caller passes a limit of 1.
        fulltext_mock.assert_called_with(driver, 'Test', SearchFilters(), ['1'], 2)
        similarity_mock.assert_called_with(driver, embedding, SearchFilters(), ['1'], 2)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_hybrid_node_search_with_limit_and_duplicates():
    """Limit forwarding and de-duplication work together.

    Each mock returns two nodes; 'Alice' (uuid '1') is present in both lists,
    so three unique nodes are expected in the merged result.
    """

    def make_node(uuid: str, name: str) -> EntityNode:
        return EntityNode(uuid=uuid, name=name, labels=['Entity'], group_id='1')

    driver = AsyncMock()
    embedding = [0.1, 0.2, 0.3]
    limit = 2

    with (
        patch('graphiti_core.search.search_utils.node_fulltext_search') as fulltext_mock,
        patch('graphiti_core.search.search_utils.node_similarity_search') as similarity_mock,
    ):
        fulltext_mock.return_value = [make_node('1', 'Alice'), make_node('2', 'Bob')]
        similarity_mock.return_value = [
            make_node('1', 'Alice'),  # duplicate of the full-text hit
            make_node('3', 'Charlie'),
        ]

        results = await hybrid_node_search(
            ['Test'], [embedding], driver, SearchFilters(), ['1'], limit
        )

        # 2 full-text hits + 2 similarity hits - 1 shared node = 3 results.
        assert len(results) == 3
        assert {node.name for node in results} == {'Alice', 'Bob', 'Charlie'}
        assert fulltext_mock.call_count == 1
        assert similarity_mock.call_count == 1

        # The search functions are expected to receive 4 as their last
        # argument when the caller passes a limit of 2.
        fulltext_mock.assert_called_with(driver, 'Test', SearchFilters(), ['1'], 4)
        similarity_mock.assert_called_with(driver, embedding, SearchFilters(), ['1'], 4)
|
|
|
|
|
|
def test_normalize_embeddings_batch():
    """Rows are scaled to unit L2 norm; an all-zero row is left as zeros."""
    raw = np.array([[3.0, 4.0], [1.0, 0.0], [0.0, 0.0]])
    unit = normalize_embeddings_batch(raw)

    # Expected rows: (3, 4) / 5, an already-unit vector, and the zero vector.
    expected_rows = ([0.6, 0.8], [1.0, 0.0], [0.0, 0.0])
    for index, expected in enumerate(expected_rows):
        assert np.allclose(unit[index], expected)

    # Norms are 1 for the non-zero rows and 0 for the zero row.
    row_norms = np.linalg.norm(unit, axis=1)
    for index, expected_norm in enumerate((1.0, 1.0, 0.0)):
        assert np.allclose(row_norms[index], expected_norm)
|
|
|
|
|
|
def test_maximal_marginal_relevance_empty_candidates():
    """An empty candidate mapping yields an empty list."""
    no_candidates = {}
    selected = maximal_marginal_relevance([1.0, 0.0, 0.0], no_candidates)
    assert selected == []
|
|
|
|
|
|
def test_maximal_marginal_relevance_single_candidate():
    """A lone candidate identical to the query is returned as the only id."""
    query_vector = [1.0, 0.0, 0.0]
    lone_candidate = {"doc1": [1.0, 0.0, 0.0]}

    selected = maximal_marginal_relevance(query_vector, lone_candidate)

    assert selected == ["doc1"]
|
|
|
|
|
|
def test_maximal_marginal_relevance_basic_functionality():
    """Check selection order at the two extremes of ``mmr_lambda``."""
    query_vector = [1.0, 0.0, 0.0]
    docs = {
        "doc1": [1.0, 0.0, 0.0],  # identical to the query
        "doc2": [0.0, 1.0, 0.0],  # orthogonal to the query
        "doc3": [0.8, 0.0, 0.0],  # same direction as the query, smaller magnitude
    }
    # NOTE(review): doc1 and doc3 point the same way, so a cosine-based
    # relevance would rate them equally and "doc1 first" would then depend on
    # tie-breaking — confirm against the implementation.

    # mmr_lambda=1.0: relevance only.
    relevance_only = maximal_marginal_relevance(query_vector, docs, mmr_lambda=1.0)
    assert relevance_only[0] == "doc1"

    # mmr_lambda=0.0: diversity only. doc1 is still expected first, followed
    # by the orthogonal doc2.
    diversity_only = maximal_marginal_relevance(query_vector, docs, mmr_lambda=0.0)
    assert diversity_only[0] == "doc1"
    assert diversity_only[1] == "doc2"
|
|
|
|
|
|
def test_maximal_marginal_relevance_diversity_effect():
    """The second pick flips between a near-duplicate and an orthogonal doc."""
    query_vector = [1.0, 0.0, 0.0]
    docs = {
        "doc1": [1.0, 0.0, 0.0],  # identical to the query
        "doc2": [0.9, 0.0, 0.0],  # same direction as doc1
        "doc3": [0.0, 1.0, 0.0],  # orthogonal to both doc1 and the query
    }

    # Relevance-heavy weighting: the near-duplicate doc2 is expected second.
    favour_relevance = maximal_marginal_relevance(query_vector, docs, mmr_lambda=0.9)
    assert favour_relevance[:2] == ["doc1", "doc2"]

    # Diversity-heavy weighting: the orthogonal doc3 is expected second.
    favour_diversity = maximal_marginal_relevance(query_vector, docs, mmr_lambda=0.1)
    assert favour_diversity[:2] == ["doc1", "doc3"]
|
|
|
|
|
|
def test_maximal_marginal_relevance_min_score_threshold():
    """``min_score`` controls how many candidates are returned."""
    query_vector = [1.0, 0.0, 0.0]
    docs = {
        "doc1": [1.0, 0.0, 0.0],  # identical to the query
        "doc2": [0.0, 1.0, 0.0],  # orthogonal to the query
        "doc3": [-1.0, 0.0, 0.0],  # opposite to the query
    }

    # A strict threshold is expected to keep only doc1.
    strict = maximal_marginal_relevance(query_vector, docs, min_score=0.5)
    assert len(strict) == 1
    assert strict[0] == "doc1"

    # A permissive threshold is expected to let at least two documents through.
    permissive = maximal_marginal_relevance(query_vector, docs, min_score=-0.5)
    assert len(permissive) >= 2
|
|
|
|
|
|
def test_maximal_marginal_relevance_max_results():
    """``max_results`` caps the output length without padding it."""
    query_vector = [1.0, 0.0, 0.0]
    # All four candidates lie along the query axis and differ only in length.
    # NOTE(review): with cosine-based relevance these would tie, so "doc1
    # first" may rely on tie-breaking — confirm against the implementation.
    docs = {
        "doc1": [1.0, 0.0, 0.0],
        "doc2": [0.8, 0.0, 0.0],
        "doc3": [0.6, 0.0, 0.0],
        "doc4": [0.4, 0.0, 0.0],
    }

    # Cap below the number of candidates.
    top_two = maximal_marginal_relevance(query_vector, docs, max_results=2)
    assert len(top_two) == 2
    assert top_two[0] == "doc1"

    # Cap above the number of candidates: all four are expected back.
    uncapped = maximal_marginal_relevance(query_vector, docs, max_results=10)
    assert len(uncapped) == 4
|
|
|
|
|
|
def test_maximal_marginal_relevance_deterministic():
    """Repeated calls with the same inputs give the same output."""
    query_vector = [1.0, 0.0, 0.0]
    docs = {
        "doc1": [1.0, 0.0, 0.0],
        "doc2": [0.0, 1.0, 0.0],
        "doc3": [0.0, 0.0, 1.0],
    }

    # Five independent runs; every later run must equal the first.
    runs = [maximal_marginal_relevance(query_vector, docs) for _ in range(5)]
    baseline = runs[0]

    assert all(run == baseline for run in runs[1:])
|
|
|
|
|
|
def test_maximal_marginal_relevance_normalized_inputs():
    """Unit-length and non-unit-length vectors can be mixed in one call."""
    query_vector = [3.0, 4.0]  # length 5, not unit length
    docs = {
        "doc1": [6.0, 8.0],  # same direction as the query, length 10
        "doc2": [0.6, 0.8],  # same direction as the query, unit length
        "doc3": [0.0, 1.0],  # different direction, unit length
    }

    ranked = maximal_marginal_relevance(query_vector, docs)

    # Either of the two query-aligned documents is accepted in first place,
    # and all three candidates are expected in the output.
    assert ranked[0] in ["doc1", "doc2"]
    assert len(ranked) == 3
|
|
|
|
|
|
def test_maximal_marginal_relevance_edge_cases():
    """Cover an all-zero query vector and a set of identical candidates."""
    # Case 1: zero query vector. Both candidates are still expected back.
    zero_query = [0.0, 0.0, 0.0]
    two_docs = {
        "doc1": [1.0, 0.0, 0.0],
        "doc2": [0.0, 1.0, 0.0],
    }
    from_zero_query = maximal_marginal_relevance(zero_query, two_docs)
    assert len(from_zero_query) == 2

    # Case 2: three candidates that are all identical to the query.
    unit_query = [1.0, 0.0, 0.0]
    identical_docs = {
        "doc1": [1.0, 0.0, 0.0],
        "doc2": [1.0, 0.0, 0.0],
        "doc3": [1.0, 0.0, 0.0],
    }
    from_identical = maximal_marginal_relevance(unit_query, identical_docs, mmr_lambda=0.5)
    # Only a non-empty result is required here; how many of the duplicates
    # survive is not pinned by this test.
    assert len(from_identical) >= 1
|