Merge branch 'dev' into feature/cog-3213-docs-set-up-guide-script-tests

This commit is contained in:
Andrej Milicevic 2025-12-17 13:52:45 +01:00
commit 929d88557e
13 changed files with 5657 additions and 0 deletions

View file

@ -0,0 +1,183 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
@pytest.fixture
def mock_vector_engine():
"""Create a mock vector engine."""
engine = AsyncMock()
engine.search = AsyncMock()
return engine
@pytest.mark.asyncio
async def test_get_context_success(mock_vector_engine):
"""Test successful retrieval of chunk context."""
mock_result1 = MagicMock()
mock_result1.payload = {"text": "Steve Rodger", "chunk_index": 0}
mock_result2 = MagicMock()
mock_result2.payload = {"text": "Mike Broski", "chunk_index": 1}
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
retriever = ChunksRetriever(top_k=5)
with patch(
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert len(context) == 2
assert context[0]["text"] == "Steve Rodger"
assert context[1]["text"] == "Mike Broski"
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5)
@pytest.mark.asyncio
async def test_get_context_collection_not_found_error(mock_vector_engine):
"""Test that CollectionNotFoundError is converted to NoDataError."""
mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found")
retriever = ChunksRetriever()
with patch(
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
with pytest.raises(NoDataError, match="No data found"):
await retriever.get_context("test query")
@pytest.mark.asyncio
async def test_get_context_empty_results(mock_vector_engine):
"""Test that empty list is returned when no chunks are found."""
mock_vector_engine.search.return_value = []
retriever = ChunksRetriever()
with patch(
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert context == []
@pytest.mark.asyncio
async def test_get_context_top_k_limit(mock_vector_engine):
"""Test that top_k parameter limits the number of results."""
mock_results = [MagicMock() for _ in range(3)]
for i, result in enumerate(mock_results):
result.payload = {"text": f"Chunk {i}"}
mock_vector_engine.search.return_value = mock_results
retriever = ChunksRetriever(top_k=3)
with patch(
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert len(context) == 3
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3)
@pytest.mark.asyncio
async def test_get_completion_with_context(mock_vector_engine):
"""Test get_completion returns provided context."""
retriever = ChunksRetriever()
provided_context = [{"text": "Steve Rodger"}, {"text": "Mike Broski"}]
completion = await retriever.get_completion("test query", context=provided_context)
assert completion == provided_context
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_vector_engine):
"""Test get_completion retrieves context when not provided."""
mock_result = MagicMock()
mock_result.payload = {"text": "Steve Rodger"}
mock_vector_engine.search.return_value = [mock_result]
retriever = ChunksRetriever()
with patch(
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
completion = await retriever.get_completion("test query")
assert len(completion) == 1
assert completion[0]["text"] == "Steve Rodger"
@pytest.mark.asyncio
async def test_init_defaults():
"""Test ChunksRetriever initialization with defaults."""
retriever = ChunksRetriever()
assert retriever.top_k == 5
@pytest.mark.asyncio
async def test_init_custom_top_k():
"""Test ChunksRetriever initialization with custom top_k."""
retriever = ChunksRetriever(top_k=10)
assert retriever.top_k == 10
@pytest.mark.asyncio
async def test_init_none_top_k():
"""Test ChunksRetriever initialization with None top_k."""
retriever = ChunksRetriever(top_k=None)
assert retriever.top_k is None
@pytest.mark.asyncio
async def test_get_context_empty_payload(mock_vector_engine):
"""Test get_context handles empty payload."""
mock_result = MagicMock()
mock_result.payload = {}
mock_vector_engine.search.return_value = [mock_result]
retriever = ChunksRetriever()
with patch(
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert len(context) == 1
assert context[0] == {}
@pytest.mark.asyncio
async def test_get_completion_with_session_id(mock_vector_engine):
"""Test get_completion with session_id parameter."""
mock_result = MagicMock()
mock_result.payload = {"text": "Steve Rodger"}
mock_vector_engine.search.return_value = [mock_result]
retriever = ChunksRetriever()
with patch(
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
completion = await retriever.get_completion("test query", session_id="test_session")
assert len(completion) == 1
assert completion[0]["text"] == "Steve Rodger"

View file

@ -0,0 +1,492 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from cognee.context_global_variables import session_user
import importlib
def create_mock_cache_engine(qa_history=None):
mock_cache = AsyncMock()
if qa_history is None:
qa_history = []
mock_cache.get_latest_qa = AsyncMock(return_value=qa_history)
mock_cache.add_qa = AsyncMock(return_value=None)
return mock_cache
def create_mock_user():
mock_user = MagicMock()
mock_user.id = "test-user-id-123"
return mock_user
class TestConversationHistoryUtils:
@pytest.mark.asyncio
async def test_get_conversation_history_returns_empty_when_no_history(self):
user = create_mock_user()
session_user.set(user)
mock_cache = create_mock_cache_engine([])
cache_module = importlib.import_module(
"cognee.infrastructure.databases.cache.get_cache_engine"
)
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
from cognee.modules.retrieval.utils.session_cache import get_conversation_history
result = await get_conversation_history(session_id="test_session")
assert result == ""
@pytest.mark.asyncio
async def test_get_conversation_history_formats_history_correctly(self):
"""Test get_conversation_history formats Q&A history with correct structure."""
user = create_mock_user()
session_user.set(user)
mock_history = [
{
"time": "2024-01-15 10:30:45",
"question": "What is AI?",
"context": "AI is artificial intelligence",
"answer": "AI stands for Artificial Intelligence",
}
]
mock_cache = create_mock_cache_engine(mock_history)
# Import the real module to patch safely
cache_module = importlib.import_module(
"cognee.infrastructure.databases.cache.get_cache_engine"
)
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
with patch(
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
) as MockCacheConfig:
mock_config = MagicMock()
mock_config.caching = True
MockCacheConfig.return_value = mock_config
from cognee.modules.retrieval.utils.session_cache import (
get_conversation_history,
)
result = await get_conversation_history(session_id="test_session")
assert "Previous conversation:" in result
assert "[2024-01-15 10:30:45]" in result
assert "QUESTION: What is AI?" in result
assert "CONTEXT: AI is artificial intelligence" in result
assert "ANSWER: AI stands for Artificial Intelligence" in result
@pytest.mark.asyncio
async def test_save_to_session_cache_saves_correctly(self):
"""Test save_conversation_history calls add_qa with correct parameters."""
user = create_mock_user()
session_user.set(user)
mock_cache = create_mock_cache_engine([])
cache_module = importlib.import_module(
"cognee.infrastructure.databases.cache.get_cache_engine"
)
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
with patch(
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
) as MockCacheConfig:
mock_config = MagicMock()
mock_config.caching = True
MockCacheConfig.return_value = mock_config
from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history,
)
result = await save_conversation_history(
query="What is Python?",
context_summary="Python is a programming language",
answer="Python is a high-level programming language",
session_id="my_session",
)
assert result is True
mock_cache.add_qa.assert_called_once()
call_kwargs = mock_cache.add_qa.call_args.kwargs
assert call_kwargs["question"] == "What is Python?"
assert call_kwargs["context"] == "Python is a programming language"
assert call_kwargs["answer"] == "Python is a high-level programming language"
assert call_kwargs["session_id"] == "my_session"
@pytest.mark.asyncio
async def test_save_to_session_cache_uses_default_session_when_none(self):
"""Test save_conversation_history uses 'default_session' when session_id is None."""
user = create_mock_user()
session_user.set(user)
mock_cache = create_mock_cache_engine([])
cache_module = importlib.import_module(
"cognee.infrastructure.databases.cache.get_cache_engine"
)
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
with patch(
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
) as MockCacheConfig:
mock_config = MagicMock()
mock_config.caching = True
MockCacheConfig.return_value = mock_config
from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history,
)
result = await save_conversation_history(
query="Test question",
context_summary="Test context",
answer="Test answer",
session_id=None,
)
assert result is True
call_kwargs = mock_cache.add_qa.call_args.kwargs
assert call_kwargs["session_id"] == "default_session"
@pytest.mark.asyncio
async def test_save_conversation_history_no_user_id(self):
"""Test save_conversation_history returns False when user_id is None."""
session_user.set(None)
with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig:
mock_config = MagicMock()
mock_config.caching = True
MockCacheConfig.return_value = mock_config
from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history,
)
result = await save_conversation_history(
query="Test question",
context_summary="Test context",
answer="Test answer",
)
assert result is False
@pytest.mark.asyncio
async def test_save_conversation_history_caching_disabled(self):
"""Test save_conversation_history returns False when caching is disabled."""
user = create_mock_user()
session_user.set(user)
with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig:
mock_config = MagicMock()
mock_config.caching = False
MockCacheConfig.return_value = mock_config
from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history,
)
result = await save_conversation_history(
query="Test question",
context_summary="Test context",
answer="Test answer",
)
assert result is False
@pytest.mark.asyncio
async def test_save_conversation_history_cache_engine_none(self):
"""Test save_conversation_history returns False when cache_engine is None."""
user = create_mock_user()
session_user.set(user)
cache_module = importlib.import_module(
"cognee.infrastructure.databases.cache.get_cache_engine"
)
with patch.object(cache_module, "get_cache_engine", return_value=None):
with patch(
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
) as MockCacheConfig:
mock_config = MagicMock()
mock_config.caching = True
MockCacheConfig.return_value = mock_config
from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history,
)
result = await save_conversation_history(
query="Test question",
context_summary="Test context",
answer="Test answer",
)
assert result is False
@pytest.mark.asyncio
async def test_save_conversation_history_cache_connection_error(self):
"""Test save_conversation_history handles CacheConnectionError gracefully."""
user = create_mock_user()
session_user.set(user)
from cognee.infrastructure.databases.exceptions import CacheConnectionError
mock_cache = create_mock_cache_engine([])
mock_cache.add_qa = AsyncMock(side_effect=CacheConnectionError("Connection failed"))
cache_module = importlib.import_module(
"cognee.infrastructure.databases.cache.get_cache_engine"
)
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
with patch(
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
) as MockCacheConfig:
mock_config = MagicMock()
mock_config.caching = True
MockCacheConfig.return_value = mock_config
from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history,
)
result = await save_conversation_history(
query="Test question",
context_summary="Test context",
answer="Test answer",
)
assert result is False
@pytest.mark.asyncio
async def test_save_conversation_history_generic_exception(self):
"""Test save_conversation_history handles generic exceptions gracefully."""
user = create_mock_user()
session_user.set(user)
mock_cache = create_mock_cache_engine([])
mock_cache.add_qa = AsyncMock(side_effect=ValueError("Unexpected error"))
cache_module = importlib.import_module(
"cognee.infrastructure.databases.cache.get_cache_engine"
)
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
with patch(
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
) as MockCacheConfig:
mock_config = MagicMock()
mock_config.caching = True
MockCacheConfig.return_value = mock_config
from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history,
)
result = await save_conversation_history(
query="Test question",
context_summary="Test context",
answer="Test answer",
)
assert result is False
@pytest.mark.asyncio
async def test_get_conversation_history_no_user_id(self):
"""Test get_conversation_history returns empty string when user_id is None."""
session_user.set(None)
with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig:
mock_config = MagicMock()
mock_config.caching = True
MockCacheConfig.return_value = mock_config
from cognee.modules.retrieval.utils.session_cache import (
get_conversation_history,
)
result = await get_conversation_history(session_id="test_session")
assert result == ""
@pytest.mark.asyncio
async def test_get_conversation_history_caching_disabled(self):
"""Test get_conversation_history returns empty string when caching is disabled."""
user = create_mock_user()
session_user.set(user)
with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig:
mock_config = MagicMock()
mock_config.caching = False
MockCacheConfig.return_value = mock_config
from cognee.modules.retrieval.utils.session_cache import (
get_conversation_history,
)
result = await get_conversation_history(session_id="test_session")
assert result == ""
@pytest.mark.asyncio
async def test_get_conversation_history_default_session(self):
"""Test get_conversation_history uses 'default_session' when session_id is None."""
user = create_mock_user()
session_user.set(user)
mock_cache = create_mock_cache_engine([])
cache_module = importlib.import_module(
"cognee.infrastructure.databases.cache.get_cache_engine"
)
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
with patch(
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
) as MockCacheConfig:
mock_config = MagicMock()
mock_config.caching = True
MockCacheConfig.return_value = mock_config
from cognee.modules.retrieval.utils.session_cache import (
get_conversation_history,
)
await get_conversation_history(session_id=None)
mock_cache.get_latest_qa.assert_called_once_with(str(user.id), "default_session")
@pytest.mark.asyncio
async def test_get_conversation_history_cache_engine_none(self):
"""Test get_conversation_history returns empty string when cache_engine is None."""
user = create_mock_user()
session_user.set(user)
cache_module = importlib.import_module(
"cognee.infrastructure.databases.cache.get_cache_engine"
)
with patch.object(cache_module, "get_cache_engine", return_value=None):
with patch(
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
) as MockCacheConfig:
mock_config = MagicMock()
mock_config.caching = True
MockCacheConfig.return_value = mock_config
from cognee.modules.retrieval.utils.session_cache import (
get_conversation_history,
)
result = await get_conversation_history(session_id="test_session")
assert result == ""
@pytest.mark.asyncio
async def test_get_conversation_history_cache_connection_error(self):
"""Test get_conversation_history handles CacheConnectionError gracefully."""
user = create_mock_user()
session_user.set(user)
from cognee.infrastructure.databases.exceptions import CacheConnectionError
mock_cache = create_mock_cache_engine([])
mock_cache.get_latest_qa = AsyncMock(side_effect=CacheConnectionError("Connection failed"))
cache_module = importlib.import_module(
"cognee.infrastructure.databases.cache.get_cache_engine"
)
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
with patch(
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
) as MockCacheConfig:
mock_config = MagicMock()
mock_config.caching = True
MockCacheConfig.return_value = mock_config
from cognee.modules.retrieval.utils.session_cache import (
get_conversation_history,
)
result = await get_conversation_history(session_id="test_session")
assert result == ""
@pytest.mark.asyncio
async def test_get_conversation_history_generic_exception(self):
"""Test get_conversation_history handles generic exceptions gracefully."""
user = create_mock_user()
session_user.set(user)
mock_cache = create_mock_cache_engine([])
mock_cache.get_latest_qa = AsyncMock(side_effect=ValueError("Unexpected error"))
cache_module = importlib.import_module(
"cognee.infrastructure.databases.cache.get_cache_engine"
)
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
with patch(
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
) as MockCacheConfig:
mock_config = MagicMock()
mock_config.caching = True
MockCacheConfig.return_value = mock_config
from cognee.modules.retrieval.utils.session_cache import (
get_conversation_history,
)
result = await get_conversation_history(session_id="test_session")
assert result == ""
@pytest.mark.asyncio
async def test_get_conversation_history_missing_keys(self):
"""Test get_conversation_history handles missing keys in history entries."""
user = create_mock_user()
session_user.set(user)
mock_history = [
{
"time": "2024-01-15 10:30:45",
"question": "What is AI?",
},
{
"context": "AI is artificial intelligence",
"answer": "AI stands for Artificial Intelligence",
},
{},
]
mock_cache = create_mock_cache_engine(mock_history)
cache_module = importlib.import_module(
"cognee.infrastructure.databases.cache.get_cache_engine"
)
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
with patch(
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
) as MockCacheConfig:
mock_config = MagicMock()
mock_config.caching = True
MockCacheConfig.return_value = mock_config
from cognee.modules.retrieval.utils.session_cache import (
get_conversation_history,
)
result = await get_conversation_history(session_id="test_session")
assert "Previous conversation:" in result
assert "[2024-01-15 10:30:45]" in result
assert "QUESTION: What is AI?" in result
assert "Unknown time" in result
assert "CONTEXT: AI is artificial intelligence" in result
assert "ANSWER: AI stands for Artificial Intelligence" in result

View file

@ -0,0 +1,469 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from uuid import UUID
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
)
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@pytest.fixture
def mock_edge():
"""Create a mock edge."""
edge = MagicMock(spec=Edge)
return edge
@pytest.mark.asyncio
async def test_get_triplets_inherited(mock_edge):
"""Test that get_triplets is inherited from parent class."""
retriever = GraphCompletionContextExtensionRetriever()
with patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
):
triplets = await retriever.get_triplets("test query")
assert len(triplets) == 1
assert triplets[0] == mock_edge
@pytest.mark.asyncio
async def test_init_defaults():
"""Test GraphCompletionContextExtensionRetriever initialization with defaults."""
retriever = GraphCompletionContextExtensionRetriever()
assert retriever.top_k == 5
assert retriever.user_prompt_path == "graph_context_for_question.txt"
assert retriever.system_prompt_path == "answer_simple_question.txt"
@pytest.mark.asyncio
async def test_init_custom_params():
"""Test GraphCompletionContextExtensionRetriever initialization with custom parameters."""
retriever = GraphCompletionContextExtensionRetriever(
top_k=10,
user_prompt_path="custom_user.txt",
system_prompt_path="custom_system.txt",
system_prompt="Custom prompt",
node_type=str,
node_name=["node1"],
save_interaction=True,
wide_search_top_k=200,
triplet_distance_penalty=5.0,
)
assert retriever.top_k == 10
assert retriever.user_prompt_path == "custom_user.txt"
assert retriever.system_prompt_path == "custom_system.txt"
assert retriever.system_prompt == "Custom prompt"
assert retriever.node_type is str
assert retriever.node_name == ["node1"]
assert retriever.save_interaction is True
assert retriever.wide_search_top_k == 200
assert retriever.triplet_distance_penalty == 5.0
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_edge):
"""Test get_completion retrieves context when not provided."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", context_extension_rounds=1)
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_provided_context(mock_edge):
"""Test get_completion uses provided context."""
retriever = GraphCompletionContextExtensionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
"test query", context=[mock_edge], context_extension_rounds=1
)
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_context_extension_rounds(mock_edge):
"""Test get_completion with multiple context extension rounds."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
# Create a second edge for extension rounds
mock_edge2 = MagicMock(spec=Edge)
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(
retriever,
"get_context",
new_callable=AsyncMock,
side_effect=[[mock_edge], [mock_edge2]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
side_effect=["Resolved context", "Extended context"], # Different contexts
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
"Generated answer",
], # Query for extension, then final answer
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", context_extension_rounds=1)
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_context_extension_stops_early(mock_edge):
"""Test get_completion stops early when no new triplets found."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
with (
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
"Generated answer",
],
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
# When get_context returns same triplets, the loop should stop early
completion = await retriever.get_completion(
"test query", context=[mock_edge], context_extension_rounds=4
)
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_edge):
"""Test get_completion with session caching enabled."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
mock_user = MagicMock()
mock_user.id = "test-user-id"
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.get_conversation_history",
return_value="Previous conversation",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.summarize_text",
return_value="Context summary",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
"Generated answer",
], # Extension query, then final answer
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.save_conversation_history",
) as mock_save,
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user"
) as mock_session_user,
):
mock_config = MagicMock()
mock_config.caching = True
mock_cache_config.return_value = mock_config
mock_session_user.get.return_value = mock_user
completion = await retriever.get_completion(
"test query", session_id="test_session", context_extension_rounds=1
)
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
mock_save.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction(mock_edge):
"""Test get_completion with save_interaction enabled."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
mock_graph_engine.add_edges = AsyncMock()
retriever = GraphCompletionContextExtensionRetriever(save_interaction=True)
mock_node1 = MagicMock()
mock_node2 = MagicMock()
mock_edge.node1 = mock_node1
mock_edge.node2 = mock_node2
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
"Generated answer",
], # Extension query, then final answer
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
side_effect=[
UUID("550e8400-e29b-41d4-a716-446655440000"),
UUID("550e8400-e29b-41d4-a716-446655440001"),
],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.add_data_points",
) as mock_add_data,
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
"test query", context=[mock_edge], context_extension_rounds=1
)
assert isinstance(completion, list)
assert len(completion) == 1
mock_add_data.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_edge):
"""Test get_completion with custom response model."""
from pydantic import BaseModel
class TestModel(BaseModel):
answer: str
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
TestModel(answer="Test answer"),
], # Extension query, then final answer
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
"test query", response_model=TestModel, context_extension_rounds=1
)
assert isinstance(completion, list)
assert len(completion) == 1
assert isinstance(completion[0], TestModel)
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_edge):
"""Test get_completion with session config but no user ID."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
"Generated answer",
], # Extension query, then final answer
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user"
) as mock_session_user,
):
mock_config = MagicMock()
mock_config.caching = True
mock_cache_config.return_value = mock_config
mock_session_user.get.return_value = None # No user
completion = await retriever.get_completion("test query", context_extension_rounds=1)
assert isinstance(completion, list)
assert len(completion) == 1
@pytest.mark.asyncio
async def test_get_completion_zero_extension_rounds(mock_edge):
"""Test get_completion with zero context extension rounds."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", context_extension_rounds=0)
assert isinstance(completion, list)
assert len(completion) == 1

View file

@ -0,0 +1,688 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from uuid import UUID
from cognee.modules.retrieval.graph_completion_cot_retriever import (
GraphCompletionCotRetriever,
_as_answer_text,
)
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.infrastructure.llm.LLMGateway import LLMGateway
@pytest.fixture
def mock_edge():
"""Create a mock edge."""
edge = MagicMock(spec=Edge)
return edge
@pytest.mark.asyncio
async def test_get_triplets_inherited(mock_edge):
"""Test that get_triplets is inherited from parent class."""
retriever = GraphCompletionCotRetriever()
with patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
):
triplets = await retriever.get_triplets("test query")
assert len(triplets) == 1
assert triplets[0] == mock_edge
@pytest.mark.asyncio
async def test_init_custom_params():
"""Test GraphCompletionCotRetriever initialization with custom parameters."""
retriever = GraphCompletionCotRetriever(
top_k=10,
user_prompt_path="custom_user.txt",
system_prompt_path="custom_system.txt",
validation_user_prompt_path="custom_validation_user.txt",
validation_system_prompt_path="custom_validation_system.txt",
followup_system_prompt_path="custom_followup_system.txt",
followup_user_prompt_path="custom_followup_user.txt",
)
assert retriever.top_k == 10
assert retriever.user_prompt_path == "custom_user.txt"
assert retriever.system_prompt_path == "custom_system.txt"
assert retriever.validation_user_prompt_path == "custom_validation_user.txt"
assert retriever.validation_system_prompt_path == "custom_validation_system.txt"
assert retriever.followup_system_prompt_path == "custom_followup_system.txt"
assert retriever.followup_user_prompt_path == "custom_followup_user.txt"
@pytest.mark.asyncio
async def test_init_defaults():
"""Test GraphCompletionCotRetriever initialization with defaults."""
retriever = GraphCompletionCotRetriever()
assert retriever.validation_user_prompt_path == "cot_validation_user_prompt.txt"
assert retriever.validation_system_prompt_path == "cot_validation_system_prompt.txt"
assert retriever.followup_system_prompt_path == "cot_followup_system_prompt.txt"
assert retriever.followup_user_prompt_path == "cot_followup_user_prompt.txt"
@pytest.mark.asyncio
async def test_run_cot_completion_round_zero_with_context(mock_edge):
"""Test _run_cot_completion round 0 with provided context."""
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
return_value="Rendered prompt",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
return_value="System prompt",
),
patch.object(
LLMGateway,
"acreate_structured_output",
new_callable=AsyncMock,
side_effect=["validation_result", "followup_question"],
),
):
completion, context_text, triplets = await retriever._run_cot_completion(
query="test query",
context=[mock_edge],
max_iter=1,
)
assert completion == "Generated answer"
assert context_text == "Resolved context"
assert len(triplets) >= 1
@pytest.mark.asyncio
async def test_run_cot_completion_round_zero_without_context(mock_edge):
"""Test _run_cot_completion round 0 without provided context."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
):
completion, context_text, triplets = await retriever._run_cot_completion(
query="test query",
context=None,
max_iter=1,
)
assert completion == "Generated answer"
assert context_text == "Resolved context"
assert len(triplets) >= 1
@pytest.mark.asyncio
async def test_run_cot_completion_multiple_rounds(mock_edge):
"""Test _run_cot_completion with multiple rounds."""
retriever = GraphCompletionCotRetriever()
mock_edge2 = MagicMock(spec=Edge)
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
patch.object(
retriever,
"get_context",
new_callable=AsyncMock,
side_effect=[[mock_edge], [mock_edge2]],
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
return_value="Rendered prompt",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
return_value="System prompt",
),
patch.object(
LLMGateway,
"acreate_structured_output",
new_callable=AsyncMock,
side_effect=[
"validation_result",
"followup_question",
"validation_result2",
"followup_question2",
],
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
return_value="Generated answer",
),
):
completion, context_text, triplets = await retriever._run_cot_completion(
query="test query",
context=[mock_edge],
max_iter=2,
)
assert completion == "Generated answer"
assert context_text == "Resolved context"
assert len(triplets) >= 1
@pytest.mark.asyncio
async def test_run_cot_completion_with_conversation_history(mock_edge):
"""Test _run_cot_completion with conversation history."""
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
) as mock_generate,
):
completion, context_text, triplets = await retriever._run_cot_completion(
query="test query",
context=[mock_edge],
conversation_history="Previous conversation",
max_iter=1,
)
assert completion == "Generated answer"
call_kwargs = mock_generate.call_args[1]
assert call_kwargs.get("conversation_history") == "Previous conversation"
@pytest.mark.asyncio
async def test_run_cot_completion_with_response_model(mock_edge):
"""Test _run_cot_completion with custom response model."""
from pydantic import BaseModel
class TestModel(BaseModel):
answer: str
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value=TestModel(answer="Test answer"),
),
):
completion, context_text, triplets = await retriever._run_cot_completion(
query="test query",
context=[mock_edge],
response_model=TestModel,
max_iter=1,
)
assert isinstance(completion, TestModel)
assert completion.answer == "Test answer"
@pytest.mark.asyncio
async def test_run_cot_completion_empty_conversation_history(mock_edge):
"""Test _run_cot_completion with empty conversation history."""
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
) as mock_generate,
):
completion, context_text, triplets = await retriever._run_cot_completion(
query="test query",
context=[mock_edge],
conversation_history="",
max_iter=1,
)
assert completion == "Generated answer"
# Verify conversation_history was passed as None when empty
call_kwargs = mock_generate.call_args[1]
assert call_kwargs.get("conversation_history") is None
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_edge):
"""Test get_completion retrieves context when not provided."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
return_value="Rendered prompt",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
return_value="System prompt",
),
patch.object(
LLMGateway,
"acreate_structured_output",
new_callable=AsyncMock,
side_effect=["validation_result", "followup_question"],
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", max_iter=1)
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_provided_context(mock_edge):
"""Test get_completion uses provided context."""
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1)
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_edge):
"""Test get_completion with session caching enabled."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionCotRetriever()
mock_user = MagicMock()
mock_user.id = "test-user-id"
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.get_conversation_history",
return_value="Previous conversation",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.summarize_text",
return_value="Context summary",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.save_conversation_history",
) as mock_save,
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
) as mock_cache_config,
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.session_user"
) as mock_session_user,
):
mock_config = MagicMock()
mock_config.caching = True
mock_cache_config.return_value = mock_config
mock_session_user.get.return_value = mock_user
completion = await retriever.get_completion(
"test query", session_id="test_session", max_iter=1
)
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
mock_save.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction(mock_edge):
"""Test get_completion with save_interaction enabled."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
mock_graph_engine.add_edges = AsyncMock()
retriever = GraphCompletionCotRetriever(save_interaction=True)
mock_node1 = MagicMock()
mock_node2 = MagicMock()
mock_edge.node1 = mock_node1
mock_edge.node2 = mock_node2
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
return_value="Rendered prompt",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
return_value="System prompt",
),
patch.object(
LLMGateway,
"acreate_structured_output",
new_callable=AsyncMock,
side_effect=["validation_result", "followup_question"],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
side_effect=[
UUID("550e8400-e29b-41d4-a716-446655440000"),
UUID("550e8400-e29b-41d4-a716-446655440001"),
],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.add_data_points",
) as mock_add_data,
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
# Pass context so save_interaction condition is met
completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1)
assert isinstance(completion, list)
assert len(completion) == 1
mock_add_data.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_edge):
"""Test get_completion with custom response model."""
from pydantic import BaseModel
class TestModel(BaseModel):
answer: str
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value=TestModel(answer="Test answer"),
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
"test query", response_model=TestModel, max_iter=1
)
assert isinstance(completion, list)
assert len(completion) == 1
assert isinstance(completion[0], TestModel)
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_edge):
"""Test get_completion with session config but no user ID."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
) as mock_cache_config,
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.session_user"
) as mock_session_user,
):
mock_config = MagicMock()
mock_config.caching = True
mock_cache_config.return_value = mock_config
mock_session_user.get.return_value = None # No user
completion = await retriever.get_completion("test query", max_iter=1)
assert isinstance(completion, list)
assert len(completion) == 1
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction_no_context(mock_edge):
"""Test get_completion with save_interaction but no context provided."""
retriever = GraphCompletionCotRetriever(save_interaction=True)
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
return_value="Rendered prompt",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
return_value="System prompt",
),
patch.object(
LLMGateway,
"acreate_structured_output",
new_callable=AsyncMock,
side_effect=["validation_result", "followup_question"],
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", context=None, max_iter=1)
assert isinstance(completion, list)
assert len(completion) == 1
@pytest.mark.asyncio
async def test_as_answer_text_with_typeerror():
"""Test _as_answer_text handles TypeError when json.dumps fails."""
non_serializable = {1, 2, 3}
result = _as_answer_text(non_serializable)
assert isinstance(result, str)
assert result == str(non_serializable)
@pytest.mark.asyncio
async def test_as_answer_text_with_string():
"""Test _as_answer_text with string input."""
result = _as_answer_text("test string")
assert result == "test string"
@pytest.mark.asyncio
async def test_as_answer_text_with_dict():
"""Test _as_answer_text with dictionary input."""
test_dict = {"key": "value", "number": 42}
result = _as_answer_text(test_dict)
assert isinstance(result, str)
assert "key" in result
assert "value" in result
@pytest.mark.asyncio
async def test_as_answer_text_with_basemodel():
"""Test _as_answer_text with Pydantic BaseModel input."""
from pydantic import BaseModel
class TestModel(BaseModel):
answer: str
test_model = TestModel(answer="test answer")
result = _as_answer_text(test_model)
assert isinstance(result, str)
assert "[Structured Response]" in result
assert "test answer" in result

View file

@ -0,0 +1,648 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from uuid import UUID
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@pytest.fixture
def mock_edge():
"""Create a mock edge."""
edge = MagicMock(spec=Edge)
return edge
@pytest.mark.asyncio
async def test_get_triplets_success(mock_edge):
"""Test successful retrieval of triplets."""
retriever = GraphCompletionRetriever(top_k=5)
with patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
) as mock_search:
triplets = await retriever.get_triplets("test query")
assert len(triplets) == 1
assert triplets[0] == mock_edge
mock_search.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_triplets_empty_results():
"""Test that empty list is returned when no triplets are found."""
retriever = GraphCompletionRetriever()
with patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[],
):
triplets = await retriever.get_triplets("test query")
assert triplets == []
@pytest.mark.asyncio
async def test_get_triplets_top_k_parameter():
"""Test that top_k parameter is passed to brute_force_triplet_search."""
retriever = GraphCompletionRetriever(top_k=10)
with patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[],
) as mock_search:
await retriever.get_triplets("test query")
call_kwargs = mock_search.call_args[1]
assert call_kwargs["top_k"] == 10
@pytest.mark.asyncio
async def test_get_context_success(mock_edge):
"""Test successful retrieval of context."""
retriever = GraphCompletionRetriever()
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
):
context = await retriever.get_context("test query")
assert isinstance(context, list)
assert len(context) == 1
assert context[0] == mock_edge
@pytest.mark.asyncio
async def test_get_context_empty_results():
"""Test that empty list is returned when no context is found."""
retriever = GraphCompletionRetriever()
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[],
),
):
context = await retriever.get_context("test query")
assert context == []
@pytest.mark.asyncio
async def test_get_context_empty_graph():
"""Test that empty list is returned when graph is empty."""
retriever = GraphCompletionRetriever()
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=True)
with patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
):
context = await retriever.get_context("test query")
assert context == []
@pytest.mark.asyncio
async def test_resolve_edges_to_text(mock_edge):
"""Test resolve_edges_to_text method."""
retriever = GraphCompletionRetriever()
with patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved text",
) as mock_resolve:
result = await retriever.resolve_edges_to_text([mock_edge])
assert result == "Resolved text"
mock_resolve.assert_awaited_once_with([mock_edge])
@pytest.mark.asyncio
async def test_init_defaults():
"""Test GraphCompletionRetriever initialization with defaults."""
retriever = GraphCompletionRetriever()
assert retriever.top_k == 5
assert retriever.user_prompt_path == "graph_context_for_question.txt"
assert retriever.system_prompt_path == "answer_simple_question.txt"
assert retriever.node_type is None
assert retriever.node_name is None
@pytest.mark.asyncio
async def test_init_custom_params():
"""Test GraphCompletionRetriever initialization with custom parameters."""
retriever = GraphCompletionRetriever(
top_k=10,
user_prompt_path="custom_user.txt",
system_prompt_path="custom_system.txt",
system_prompt="Custom prompt",
node_type=str,
node_name=["node1"],
save_interaction=True,
wide_search_top_k=200,
triplet_distance_penalty=5.0,
)
assert retriever.top_k == 10
assert retriever.user_prompt_path == "custom_user.txt"
assert retriever.system_prompt_path == "custom_system.txt"
assert retriever.system_prompt == "Custom prompt"
assert retriever.node_type is str
assert retriever.node_name == ["node1"]
assert retriever.save_interaction is True
assert retriever.wide_search_top_k == 200
assert retriever.triplet_distance_penalty == 5.0
@pytest.mark.asyncio
async def test_init_none_top_k():
"""Test GraphCompletionRetriever initialization with None top_k."""
retriever = GraphCompletionRetriever(top_k=None)
assert retriever.top_k == 5 # None defaults to 5
@pytest.mark.asyncio
async def test_convert_retrieved_objects_to_context(mock_edge):
"""Test convert_retrieved_objects_to_context method."""
retriever = GraphCompletionRetriever()
with patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved text",
) as mock_resolve:
result = await retriever.convert_retrieved_objects_to_context([mock_edge])
assert result == "Resolved text"
mock_resolve.assert_awaited_once_with([mock_edge])
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_edge):
"""Test get_completion retrieves context when not provided."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query")
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_provided_context(mock_edge):
"""Test get_completion uses provided context."""
retriever = GraphCompletionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", context=[mock_edge])
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_edge):
"""Test get_completion with session caching enabled."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionRetriever()
mock_user = MagicMock()
mock_user.id = "test-user-id"
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_conversation_history",
return_value="Previous conversation",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.summarize_text",
return_value="Context summary",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.save_conversation_history",
) as mock_save,
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
patch(
"cognee.modules.retrieval.graph_completion_retriever.session_user"
) as mock_session_user,
):
mock_config = MagicMock()
mock_config.caching = True
mock_cache_config.return_value = mock_config
mock_session_user.get.return_value = mock_user
completion = await retriever.get_completion("test query", session_id="test_session")
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
mock_save.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_edge):
"""Test get_completion with custom response model."""
from pydantic import BaseModel
class TestModel(BaseModel):
answer: str
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value=TestModel(answer="Test answer"),
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", response_model=TestModel)
assert isinstance(completion, list)
assert len(completion) == 1
assert isinstance(completion[0], TestModel)
@pytest.mark.asyncio
async def test_get_completion_empty_context(mock_edge):
"""Test get_completion with empty context."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query")
assert isinstance(completion, list)
assert len(completion) == 1
@pytest.mark.asyncio
async def test_save_qa(mock_edge):
"""Test save_qa method."""
mock_graph_engine = AsyncMock()
mock_graph_engine.add_edges = AsyncMock()
retriever = GraphCompletionRetriever()
mock_node1 = MagicMock()
mock_node2 = MagicMock()
mock_edge.node1 = mock_node1
mock_edge.node2 = mock_node2
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
side_effect=["uuid1", "uuid2"],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.add_data_points",
) as mock_add_data,
):
await retriever.save_qa(
question="Test question",
answer="Test answer",
context="Test context",
triplets=[mock_edge],
)
mock_add_data.assert_awaited_once()
mock_graph_engine.add_edges.assert_awaited_once()
@pytest.mark.asyncio
async def test_save_qa_no_triplet_ids(mock_edge):
"""Test save_qa when triplets have no extractable IDs."""
mock_graph_engine = AsyncMock()
mock_graph_engine.add_edges = AsyncMock()
retriever = GraphCompletionRetriever()
mock_node1 = MagicMock()
mock_node2 = MagicMock()
mock_edge.node1 = mock_node1
mock_edge.node2 = mock_node2
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
return_value=None,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.add_data_points",
) as mock_add_data,
):
await retriever.save_qa(
question="Test question",
answer="Test answer",
context="Test context",
triplets=[mock_edge],
)
mock_add_data.assert_awaited_once()
mock_graph_engine.add_edges.assert_not_called()
@pytest.mark.asyncio
async def test_save_qa_empty_triplets():
"""Test save_qa with empty triplets list."""
mock_graph_engine = AsyncMock()
mock_graph_engine.add_edges = AsyncMock()
retriever = GraphCompletionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.add_data_points",
) as mock_add_data,
):
await retriever.save_qa(
question="Test question",
answer="Test answer",
context="Test context",
triplets=[],
)
mock_add_data.assert_awaited_once()
mock_graph_engine.add_edges.assert_not_called()
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction_no_completion(mock_edge):
"""Test get_completion with save_interaction but no completion."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionRetriever(save_interaction=True)
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value=None, # No completion
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query")
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] is None
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction_no_context(mock_edge):
"""Test get_completion with save_interaction but no context provided."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionRetriever(save_interaction=True)
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", context=None)
assert isinstance(completion, list)
assert len(completion) == 1
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction_all_conditions_met(mock_edge):
"""Test get_completion with save_interaction when all conditions are met (line 216)."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionRetriever(save_interaction=True)
mock_node1 = MagicMock()
mock_node2 = MagicMock()
mock_edge.node1 = mock_node1
mock_edge.node2 = mock_node2
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
side_effect=[
UUID("550e8400-e29b-41d4-a716-446655440000"),
UUID("550e8400-e29b-41d4-a716-446655440001"),
],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.add_data_points",
) as mock_add_data,
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", context=[mock_edge])
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
mock_add_data.assert_awaited_once()

View file

@ -0,0 +1,321 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
@pytest.fixture
def mock_vector_engine():
"""Create a mock vector engine."""
engine = AsyncMock()
engine.search = AsyncMock()
return engine
@pytest.mark.asyncio
async def test_get_context_success(mock_vector_engine):
"""Test successful retrieval of context."""
mock_result1 = MagicMock()
mock_result1.payload = {"text": "Steve Rodger"}
mock_result2 = MagicMock()
mock_result2.payload = {"text": "Mike Broski"}
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
retriever = CompletionRetriever(top_k=2)
with patch(
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert context == "Steve Rodger\nMike Broski"
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2)
@pytest.mark.asyncio
async def test_get_context_collection_not_found_error(mock_vector_engine):
"""Test that CollectionNotFoundError is converted to NoDataError."""
mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found")
retriever = CompletionRetriever()
with patch(
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
with pytest.raises(NoDataError, match="No data found"):
await retriever.get_context("test query")
@pytest.mark.asyncio
async def test_get_context_empty_results(mock_vector_engine):
"""Test that empty string is returned when no chunks are found."""
mock_vector_engine.search.return_value = []
retriever = CompletionRetriever()
with patch(
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert context == ""
@pytest.mark.asyncio
async def test_get_context_top_k_limit(mock_vector_engine):
"""Test that top_k parameter limits the number of results."""
mock_results = [MagicMock() for _ in range(2)]
for i, result in enumerate(mock_results):
result.payload = {"text": f"Chunk {i}"}
mock_vector_engine.search.return_value = mock_results
retriever = CompletionRetriever(top_k=2)
with patch(
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert context == "Chunk 0\nChunk 1"
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2)
@pytest.mark.asyncio
async def test_get_context_single_chunk(mock_vector_engine):
"""Test get_context with single chunk result."""
mock_result = MagicMock()
mock_result.payload = {"text": "Single chunk text"}
mock_vector_engine.search.return_value = [mock_result]
retriever = CompletionRetriever()
with patch(
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert context == "Single chunk text"
@pytest.mark.asyncio
async def test_get_completion_without_session(mock_vector_engine):
"""Test get_completion without session caching."""
mock_result = MagicMock()
mock_result.payload = {"text": "Chunk text"}
mock_vector_engine.search.return_value = [mock_result]
retriever = CompletionRetriever()
with (
patch(
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.completion_retriever.generate_completion",
return_value="Generated answer",
),
patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query")
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_provided_context(mock_vector_engine):
"""Test get_completion with provided context."""
retriever = CompletionRetriever()
with (
patch(
"cognee.modules.retrieval.completion_retriever.generate_completion",
return_value="Generated answer",
),
patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", context="Provided context")
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_vector_engine):
"""Test get_completion with session caching enabled."""
mock_result = MagicMock()
mock_result.payload = {"text": "Chunk text"}
mock_vector_engine.search.return_value = [mock_result]
retriever = CompletionRetriever()
mock_user = MagicMock()
mock_user.id = "test-user-id"
with (
patch(
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.completion_retriever.get_conversation_history",
return_value="Previous conversation",
),
patch(
"cognee.modules.retrieval.completion_retriever.summarize_text",
return_value="Context summary",
),
patch(
"cognee.modules.retrieval.completion_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.completion_retriever.save_conversation_history",
) as mock_save,
patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config,
patch("cognee.modules.retrieval.completion_retriever.session_user") as mock_session_user,
):
mock_config = MagicMock()
mock_config.caching = True
mock_cache_config.return_value = mock_config
mock_session_user.get.return_value = mock_user
completion = await retriever.get_completion("test query", session_id="test_session")
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
mock_save.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_vector_engine):
"""Test get_completion with session config but no user ID."""
mock_result = MagicMock()
mock_result.payload = {"text": "Chunk text"}
mock_vector_engine.search.return_value = [mock_result]
retriever = CompletionRetriever()
with (
patch(
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.completion_retriever.generate_completion",
return_value="Generated answer",
),
patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config,
patch("cognee.modules.retrieval.completion_retriever.session_user") as mock_session_user,
):
mock_config = MagicMock()
mock_config.caching = True
mock_cache_config.return_value = mock_config
mock_session_user.get.return_value = None # No user
completion = await retriever.get_completion("test query")
assert isinstance(completion, list)
assert len(completion) == 1
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_vector_engine):
"""Test get_completion with custom response model."""
from pydantic import BaseModel
class TestModel(BaseModel):
answer: str
mock_result = MagicMock()
mock_result.payload = {"text": "Chunk text"}
mock_vector_engine.search.return_value = [mock_result]
retriever = CompletionRetriever()
with (
patch(
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.completion_retriever.generate_completion",
return_value=TestModel(answer="Test answer"),
),
patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", response_model=TestModel)
assert isinstance(completion, list)
assert len(completion) == 1
assert isinstance(completion[0], TestModel)
@pytest.mark.asyncio
async def test_init_defaults():
"""Test CompletionRetriever initialization with defaults."""
retriever = CompletionRetriever()
assert retriever.user_prompt_path == "context_for_question.txt"
assert retriever.system_prompt_path == "answer_simple_question.txt"
assert retriever.top_k == 1
assert retriever.system_prompt is None
@pytest.mark.asyncio
async def test_init_custom_params():
"""Test CompletionRetriever initialization with custom parameters."""
retriever = CompletionRetriever(
user_prompt_path="custom_user.txt",
system_prompt_path="custom_system.txt",
system_prompt="Custom prompt",
top_k=10,
)
assert retriever.user_prompt_path == "custom_user.txt"
assert retriever.system_prompt_path == "custom_system.txt"
assert retriever.system_prompt == "Custom prompt"
assert retriever.top_k == 10
@pytest.mark.asyncio
async def test_get_context_missing_text_key(mock_vector_engine):
"""Test get_context handles missing text key in payload."""
mock_result = MagicMock()
mock_result.payload = {"other_key": "value"}
mock_vector_engine.search.return_value = [mock_result]
retriever = CompletionRetriever()
with patch(
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
with pytest.raises(KeyError):
await retriever.get_context("test query")

View file

@ -0,0 +1,193 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
@pytest.fixture
def mock_vector_engine():
"""Create a mock vector engine."""
engine = AsyncMock()
engine.search = AsyncMock()
return engine
@pytest.mark.asyncio
async def test_get_context_success(mock_vector_engine):
"""Test successful retrieval of summary context."""
mock_result1 = MagicMock()
mock_result1.payload = {"text": "S.R.", "made_from": "chunk1"}
mock_result2 = MagicMock()
mock_result2.payload = {"text": "M.B.", "made_from": "chunk2"}
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
retriever = SummariesRetriever(top_k=5)
with patch(
"cognee.modules.retrieval.summaries_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert len(context) == 2
assert context[0]["text"] == "S.R."
assert context[1]["text"] == "M.B."
mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=5)
@pytest.mark.asyncio
async def test_get_context_collection_not_found_error(mock_vector_engine):
"""Test that CollectionNotFoundError is converted to NoDataError."""
mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found")
retriever = SummariesRetriever()
with patch(
"cognee.modules.retrieval.summaries_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
with pytest.raises(NoDataError, match="No data found"):
await retriever.get_context("test query")
@pytest.mark.asyncio
async def test_get_context_empty_results(mock_vector_engine):
"""Test that empty list is returned when no summaries are found."""
mock_vector_engine.search.return_value = []
retriever = SummariesRetriever()
with patch(
"cognee.modules.retrieval.summaries_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert context == []
@pytest.mark.asyncio
async def test_get_context_top_k_limit(mock_vector_engine):
"""Test that top_k parameter limits the number of results."""
mock_results = [MagicMock() for _ in range(3)]
for i, result in enumerate(mock_results):
result.payload = {"text": f"Summary {i}"}
mock_vector_engine.search.return_value = mock_results
retriever = SummariesRetriever(top_k=3)
with patch(
"cognee.modules.retrieval.summaries_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert len(context) == 3
mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=3)
@pytest.mark.asyncio
async def test_get_completion_with_context(mock_vector_engine):
"""Test get_completion returns provided context."""
retriever = SummariesRetriever()
provided_context = [{"text": "S.R."}, {"text": "M.B."}]
completion = await retriever.get_completion("test query", context=provided_context)
assert completion == provided_context
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_vector_engine):
"""Test get_completion retrieves context when not provided."""
mock_result = MagicMock()
mock_result.payload = {"text": "S.R."}
mock_vector_engine.search.return_value = [mock_result]
retriever = SummariesRetriever()
with patch(
"cognee.modules.retrieval.summaries_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
completion = await retriever.get_completion("test query")
assert len(completion) == 1
assert completion[0]["text"] == "S.R."
@pytest.mark.asyncio
async def test_init_defaults():
"""Test SummariesRetriever initialization with defaults."""
retriever = SummariesRetriever()
assert retriever.top_k == 5
@pytest.mark.asyncio
async def test_init_custom_top_k():
"""Test SummariesRetriever initialization with custom top_k."""
retriever = SummariesRetriever(top_k=10)
assert retriever.top_k == 10
@pytest.mark.asyncio
async def test_get_context_empty_payload(mock_vector_engine):
"""Test get_context handles empty payload."""
mock_result = MagicMock()
mock_result.payload = {}
mock_vector_engine.search.return_value = [mock_result]
retriever = SummariesRetriever()
with patch(
"cognee.modules.retrieval.summaries_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert len(context) == 1
assert context[0] == {}
@pytest.mark.asyncio
async def test_get_completion_with_session_id(mock_vector_engine):
"""Test get_completion with session_id parameter."""
mock_result = MagicMock()
mock_result.payload = {"text": "S.R."}
mock_vector_engine.search.return_value = [mock_result]
retriever = SummariesRetriever()
with patch(
"cognee.modules.retrieval.summaries_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
completion = await retriever.get_completion("test query", session_id="test_session")
assert len(completion) == 1
assert completion[0]["text"] == "S.R."
@pytest.mark.asyncio
async def test_get_completion_with_kwargs(mock_vector_engine):
"""Test get_completion accepts additional kwargs."""
mock_result = MagicMock()
mock_result.payload = {"text": "S.R."}
mock_vector_engine.search.return_value = [mock_result]
retriever = SummariesRetriever()
with patch(
"cognee.modules.retrieval.summaries_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
completion = await retriever.get_completion("test query", extra_param="value")
assert len(completion) == 1

View file

@ -0,0 +1,705 @@
from types import SimpleNamespace
import pytest
import os
from unittest.mock import AsyncMock, patch, MagicMock
from datetime import datetime
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
from cognee.tasks.temporal_graph.models import QueryInterval, Timestamp
from cognee.infrastructure.llm import LLMGateway
# Test TemporalRetriever initialization defaults and overrides
def test_init_defaults_and_overrides():
tr = TemporalRetriever()
assert tr.top_k == 5
assert tr.user_prompt_path == "graph_context_for_question.txt"
assert tr.system_prompt_path == "answer_simple_question.txt"
assert tr.time_extraction_prompt_path == "extract_query_time.txt"
tr2 = TemporalRetriever(
top_k=3,
user_prompt_path="u.txt",
system_prompt_path="s.txt",
time_extraction_prompt_path="t.txt",
)
assert tr2.top_k == 3
assert tr2.user_prompt_path == "u.txt"
assert tr2.system_prompt_path == "s.txt"
assert tr2.time_extraction_prompt_path == "t.txt"
# Test descriptions_to_string with basic and empty results
def test_descriptions_to_string_basic_and_empty():
tr = TemporalRetriever()
results = [
{"description": " First "},
{"nope": "no description"},
{"description": "Second"},
{"description": ""},
{"description": " Third line "},
]
s = tr.descriptions_to_string(results)
assert s == "First\n#####################\nSecond\n#####################\nThird line"
assert tr.descriptions_to_string([]) == ""
# Test filter_top_k_events sorts and limits correctly
@pytest.mark.asyncio
async def test_filter_top_k_events_sorts_and_limits():
tr = TemporalRetriever(top_k=2)
relevant_events = [
{
"events": [
{"id": "e1", "description": "E1"},
{"id": "e2", "description": "E2"},
{"id": "e3", "description": "E3 - not in vector results"},
]
}
]
scored_results = [
SimpleNamespace(payload={"id": "e2"}, score=0.10),
SimpleNamespace(payload={"id": "e1"}, score=0.20),
]
top = await tr.filter_top_k_events(relevant_events, scored_results)
assert [e["id"] for e in top] == ["e2", "e1"]
assert all("score" in e for e in top)
assert top[0]["score"] == 0.10
assert top[1]["score"] == 0.20
# Test filter_top_k_events handles unknown ids as infinite scores
@pytest.mark.asyncio
async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k():
tr = TemporalRetriever(top_k=2)
relevant_events = [
{
"events": [
{"id": "known1", "description": "Known 1"},
{"id": "unknown", "description": "Unknown"},
{"id": "known2", "description": "Known 2"},
]
}
]
scored_results = [
SimpleNamespace(payload={"id": "known2"}, score=0.05),
SimpleNamespace(payload={"id": "known1"}, score=0.50),
]
top = await tr.filter_top_k_events(relevant_events, scored_results)
assert [e["id"] for e in top] == ["known2", "known1"]
assert all(e["score"] != float("inf") for e in top)
# Test descriptions_to_string with unicode and newlines
def test_descriptions_to_string_unicode_and_newlines():
tr = TemporalRetriever()
results = [
{"description": "Line A\nwith newline"},
{"description": "This is a description"},
]
s = tr.descriptions_to_string(results)
assert "Line A\nwith newline" in s
assert "This is a description" in s
assert s.count("#####################") == 1
# Test filter_top_k_events when top_k is larger than available events
@pytest.mark.asyncio
async def test_filter_top_k_events_limits_when_top_k_exceeds_events():
tr = TemporalRetriever(top_k=10)
relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}]
scored_results = [
SimpleNamespace(payload={"id": "a"}, score=0.1),
SimpleNamespace(payload={"id": "b"}, score=0.2),
]
out = await tr.filter_top_k_events(relevant_events, scored_results)
assert [e["id"] for e in out] == ["a", "b"]
# Test filter_top_k_events when scored_results is empty
@pytest.mark.asyncio
async def test_filter_top_k_events_handles_empty_scored_results():
tr = TemporalRetriever(top_k=2)
relevant_events = [{"events": [{"id": "x"}, {"id": "y"}]}]
scored_results = []
out = await tr.filter_top_k_events(relevant_events, scored_results)
assert [e["id"] for e in out] == ["x", "y"]
assert all(e["score"] == float("inf") for e in out)
# Test filter_top_k_events error handling for missing structure
@pytest.mark.asyncio
async def test_filter_top_k_events_error_handling():
tr = TemporalRetriever(top_k=2)
with pytest.raises((KeyError, TypeError)):
await tr.filter_top_k_events([{}], [])
@pytest.fixture
def mock_graph_engine():
"""Create a mock graph engine."""
engine = AsyncMock()
engine.collect_time_ids = AsyncMock()
engine.collect_events = AsyncMock()
return engine
@pytest.fixture
def mock_vector_engine():
"""Create a mock vector engine."""
engine = AsyncMock()
engine.embedding_engine = AsyncMock()
engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
engine.search = AsyncMock()
return engine
@pytest.mark.asyncio
async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine):
"""Test get_context when time range is extracted from query."""
retriever = TemporalRetriever(top_k=5)
mock_graph_engine.collect_time_ids.return_value = ["e1", "e2"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": "Event 1"},
{"id": "e2", "description": "Event 2"},
]
}
]
mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05)
mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10)
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
with (
patch.object(
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
):
context = await retriever.get_context("What happened in 2024?")
assert isinstance(context, str)
assert len(context) > 0
assert "Event" in context
@pytest.mark.asyncio
async def test_get_context_fallback_to_triplets_no_time(mock_graph_engine):
"""Test get_context falls back to triplets when no time is extracted."""
retriever = TemporalRetriever()
with (
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(
retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}]
) as mock_get_triplets,
patch.object(
retriever, "resolve_edges_to_text", return_value="triplet text"
) as mock_resolve,
):
async def mock_extract_time(query):
return None, None
retriever.extract_time_from_query = mock_extract_time
context = await retriever.get_context("test query")
assert context == "triplet text"
mock_get_triplets.assert_awaited_once_with("test query")
mock_resolve.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_context_no_events_found(mock_graph_engine):
"""Test get_context falls back to triplets when no events are found."""
retriever = TemporalRetriever()
mock_graph_engine.collect_time_ids.return_value = []
with (
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(
retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}]
) as mock_get_triplets,
patch.object(
retriever, "resolve_edges_to_text", return_value="triplet text"
) as mock_resolve,
):
async def mock_extract_time(query):
return "2024-01-01", "2024-12-31"
retriever.extract_time_from_query = mock_extract_time
context = await retriever.get_context("test query")
assert context == "triplet text"
mock_get_triplets.assert_awaited_once_with("test query")
mock_resolve.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine):
"""Test get_context with only time_from."""
retriever = TemporalRetriever(top_k=5)
mock_graph_engine.collect_time_ids.return_value = ["e1"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": "Event 1"},
]
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
patch.object(retriever, "extract_time_from_query", return_value=("2024-01-01", None)),
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
):
context = await retriever.get_context("What happened after 2024?")
assert isinstance(context, str)
assert "Event 1" in context
@pytest.mark.asyncio
async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine):
"""Test get_context with only time_to."""
retriever = TemporalRetriever(top_k=5)
mock_graph_engine.collect_time_ids.return_value = ["e1"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": "Event 1"},
]
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
patch.object(retriever, "extract_time_from_query", return_value=(None, "2024-12-31")),
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
):
context = await retriever.get_context("What happened before 2024?")
assert isinstance(context, str)
assert "Event 1" in context
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_graph_engine, mock_vector_engine):
"""Test get_completion retrieves context when not provided."""
retriever = TemporalRetriever()
mock_graph_engine.collect_time_ids.return_value = ["e1"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": "Event 1"},
]
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
patch.object(
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.generate_completion",
return_value="Generated answer",
),
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("What happened in 2024?")
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_provided_context():
"""Test get_completion uses provided context."""
retriever = TemporalRetriever()
with (
patch(
"cognee.modules.retrieval.temporal_retriever.generate_completion",
return_value="Generated answer",
),
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", context="Provided context")
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine):
"""Test get_completion with session caching enabled."""
retriever = TemporalRetriever()
mock_graph_engine.collect_time_ids.return_value = ["e1"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": "Event 1"},
]
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
mock_user = MagicMock()
mock_user.id = "test-user-id"
with (
patch.object(
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_conversation_history",
return_value="Previous conversation",
),
patch(
"cognee.modules.retrieval.temporal_retriever.summarize_text",
return_value="Context summary",
),
patch(
"cognee.modules.retrieval.temporal_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.temporal_retriever.save_conversation_history",
) as mock_save,
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user,
):
mock_config = MagicMock()
mock_config.caching = True
mock_cache_config.return_value = mock_config
mock_session_user.get.return_value = mock_user
completion = await retriever.get_completion(
"What happened in 2024?", session_id="test_session"
)
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
mock_save.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_vector_engine):
"""Test get_completion with session config but no user ID."""
retriever = TemporalRetriever()
mock_graph_engine.collect_time_ids.return_value = ["e1"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": "Event 1"},
]
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
patch.object(
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.generate_completion",
return_value="Generated answer",
),
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user,
):
mock_config = MagicMock()
mock_config.caching = True
mock_cache_config.return_value = mock_config
mock_session_user.get.return_value = None # No user
completion = await retriever.get_completion("What happened in 2024?")
assert isinstance(completion, list)
assert len(completion) == 1
@pytest.mark.asyncio
async def test_get_completion_context_retrieved_but_empty(mock_graph_engine):
"""Test get_completion when get_context returns empty string."""
retriever = TemporalRetriever()
with (
patch.object(
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
) as mock_get_vector,
patch.object(retriever, "filter_top_k_events", return_value=[]),
):
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
mock_get_vector.return_value = mock_vector_engine
mock_graph_engine.collect_time_ids.return_value = ["e1"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": ""},
]
}
]
with pytest.raises((UnboundLocalError, NameError)):
await retriever.get_completion("test query")
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_graph_engine, mock_vector_engine):
"""Test get_completion with custom response model."""
from pydantic import BaseModel
class TestModel(BaseModel):
answer: str
retriever = TemporalRetriever()
mock_graph_engine.collect_time_ids.return_value = ["e1"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": "Event 1"},
]
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
patch.object(
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.generate_completion",
return_value=TestModel(answer="Test answer"),
),
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
"What happened in 2024?", response_model=TestModel
)
assert isinstance(completion, list)
assert len(completion) == 1
assert isinstance(completion[0], TestModel)
@pytest.mark.asyncio
async def test_extract_time_from_query_relative_path():
"""Test extract_time_from_query with relative prompt path."""
retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")
mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)
with (
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
patch(
"cognee.modules.retrieval.temporal_retriever.render_prompt",
return_value="System prompt",
),
patch.object(
LLMGateway,
"acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_interval,
),
):
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")
assert time_from == mock_timestamp_from
assert time_to == mock_timestamp_to
@pytest.mark.asyncio
async def test_extract_time_from_query_absolute_path():
"""Test extract_time_from_query with absolute prompt path."""
retriever = TemporalRetriever(
time_extraction_prompt_path="/absolute/path/to/extract_query_time.txt"
)
mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)
with (
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=True),
patch(
"cognee.modules.retrieval.temporal_retriever.os.path.dirname",
return_value="/absolute/path/to",
),
patch(
"cognee.modules.retrieval.temporal_retriever.os.path.basename",
return_value="extract_query_time.txt",
),
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
patch(
"cognee.modules.retrieval.temporal_retriever.render_prompt",
return_value="System prompt",
),
patch.object(
LLMGateway,
"acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_interval,
),
):
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")
assert time_from == mock_timestamp_from
assert time_to == mock_timestamp_to
@pytest.mark.asyncio
async def test_extract_time_from_query_with_none_values():
"""Test extract_time_from_query when interval has None values."""
retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")
mock_interval = QueryInterval(starts_at=None, ends_at=None)
with (
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
patch(
"cognee.modules.retrieval.temporal_retriever.render_prompt",
return_value="System prompt",
),
patch.object(
LLMGateway,
"acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_interval,
),
):
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
time_from, time_to = await retriever.extract_time_from_query("What happened?")
assert time_from is None
assert time_to is None

View file

@ -0,0 +1,817 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
brute_force_triplet_search,
get_memory_fragment,
format_triplets,
)
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
class MockScoredResult:
"""Mock class for vector search results."""
def __init__(self, id, score, payload=None):
self.id = id
self.score = score
self.payload = payload or {}
@pytest.mark.asyncio
async def test_brute_force_triplet_search_empty_query():
"""Test that empty query raises ValueError."""
with pytest.raises(ValueError, match="The query must be a non-empty string."):
await brute_force_triplet_search(query="")
@pytest.mark.asyncio
async def test_brute_force_triplet_search_none_query():
"""Test that None query raises ValueError."""
with pytest.raises(ValueError, match="The query must be a non-empty string."):
await brute_force_triplet_search(query=None)
@pytest.mark.asyncio
async def test_brute_force_triplet_search_negative_top_k():
"""Test that negative top_k raises ValueError."""
with pytest.raises(ValueError, match="top_k must be a positive integer."):
await brute_force_triplet_search(query="test query", top_k=-1)
@pytest.mark.asyncio
async def test_brute_force_triplet_search_zero_top_k():
"""Test that zero top_k raises ValueError."""
with pytest.raises(ValueError, match="top_k must be a positive integer."):
await brute_force_triplet_search(query="test query", top_k=0)
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_limit_global_search():
"""Test that wide_search_limit is applied for global search (node_name=None)."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(
query="test",
node_name=None, # Global search
wide_search_top_k=75,
)
for call in mock_vector_engine.search.call_args_list:
assert call[1]["limit"] == 75
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_limit_filtered_search():
"""Test that wide_search_limit is None for filtered search (node_name provided)."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(
query="test",
node_name=["Node1"],
wide_search_top_k=50,
)
for call in mock_vector_engine.search.call_args_list:
assert call[1]["limit"] is None
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_default():
"""Test that wide_search_top_k defaults to 100."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query="test", node_name=None)
for call in mock_vector_engine.search.call_args_list:
assert call[1]["limit"] == 100
@pytest.mark.asyncio
async def test_brute_force_triplet_search_default_collections():
"""Test that default collections are used when none provided."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query="test")
expected_collections = [
"Entity_name",
"TextSummary_text",
"EntityType_name",
"DocumentChunk_text",
"EdgeType_relationship_name",
]
call_collections = [
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
]
assert call_collections == expected_collections
@pytest.mark.asyncio
async def test_brute_force_triplet_search_custom_collections():
"""Test that custom collections are used when provided."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
custom_collections = ["CustomCol1", "CustomCol2"]
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query="test", collections=custom_collections)
call_collections = [
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
]
assert set(call_collections) == set(custom_collections) | {"EdgeType_relationship_name"}
@pytest.mark.asyncio
async def test_brute_force_triplet_search_always_includes_edge_collection():
"""Test that EdgeType_relationship_name is always searched even when not in collections."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
collections_without_edge = ["Entity_name", "TextSummary_text"]
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query="test", collections=collections_without_edge)
call_collections = [
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
]
assert "EdgeType_relationship_name" in call_collections
assert set(call_collections) == set(collections_without_edge) | {
"EdgeType_relationship_name"
}
@pytest.mark.asyncio
async def test_brute_force_triplet_search_all_collections_empty():
"""Test that empty list is returned when all collections return no results."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
results = await brute_force_triplet_search(query="test")
assert results == []
# Tests for query embedding
@pytest.mark.asyncio
async def test_brute_force_triplet_search_embeds_query():
"""Test that query is embedded before searching."""
query_text = "test query"
expected_vector = [0.1, 0.2, 0.3]
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query=query_text)
mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text])
for call in mock_vector_engine.search.call_args_list:
assert call[1]["query_vector"] == expected_vector
@pytest.mark.asyncio
async def test_brute_force_triplet_search_extracts_node_ids_global_search():
"""Test that node IDs are extracted from search results for global search."""
scored_results = [
MockScoredResult("node1", 0.95),
MockScoredResult("node2", 0.87),
MockScoredResult("node3", 0.92),
]
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=scored_results)
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment_fn,
):
await brute_force_triplet_search(query="test", node_name=None)
call_kwargs = mock_get_fragment_fn.call_args[1]
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
@pytest.mark.asyncio
async def test_brute_force_triplet_search_reuses_provided_fragment():
"""Test that provided memory fragment is reused instead of creating new one."""
provided_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment"
) as mock_get_fragment,
):
await brute_force_triplet_search(
query="test",
memory_fragment=provided_fragment,
node_name=["node"],
)
mock_get_fragment.assert_not_called()
@pytest.mark.asyncio
async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
"""Test that memory fragment is created when not provided."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment,
):
await brute_force_triplet_search(query="test", node_name=["node"])
mock_get_fragment.assert_called_once()
@pytest.mark.asyncio
async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation():
"""Test that custom top_k is passed to importance calculation."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
),
):
custom_top_k = 15
await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])
mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
"""Test that get_memory_fragment returns empty graph when entity not found (line 85)."""
mock_graph_engine = AsyncMock()
# Create a mock fragment that will raise EntityNotFoundError when project_graph_from_db is called
mock_fragment = MagicMock(spec=CogneeGraph)
mock_fragment.project_graph_from_db = AsyncMock(
side_effect=EntityNotFoundError("Entity not found")
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.CogneeGraph",
return_value=mock_fragment,
),
):
result = await get_memory_fragment()
# Fragment should be returned even though EntityNotFoundError was raised (pass statement on line 85)
assert result == mock_fragment
mock_fragment.project_graph_from_db.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_error():
"""Test that get_memory_fragment returns empty graph on generic error."""
mock_graph_engine = AsyncMock()
mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error"))
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
return_value=mock_graph_engine,
):
fragment = await get_memory_fragment()
assert isinstance(fragment, CogneeGraph)
assert len(fragment.nodes) == 0
@pytest.mark.asyncio
async def test_brute_force_triplet_search_deduplicates_node_ids():
"""Test that duplicate node IDs across collections are deduplicated."""
def search_side_effect(*args, **kwargs):
collection_name = kwargs.get("collection_name")
if collection_name == "Entity_name":
return [
MockScoredResult("node1", 0.95),
MockScoredResult("node2", 0.87),
]
elif collection_name == "TextSummary_text":
return [
MockScoredResult("node1", 0.90),
MockScoredResult("node3", 0.92),
]
else:
return []
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment_fn,
):
await brute_force_triplet_search(query="test", node_name=None)
call_kwargs = mock_get_fragment_fn.call_args[1]
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
assert len(call_kwargs["relevant_ids_to_filter"]) == 3
@pytest.mark.asyncio
async def test_brute_force_triplet_search_excludes_edge_collection():
"""Test that EdgeType_relationship_name collection is excluded from ID extraction."""
def search_side_effect(*args, **kwargs):
collection_name = kwargs.get("collection_name")
if collection_name == "Entity_name":
return [MockScoredResult("node1", 0.95)]
elif collection_name == "EdgeType_relationship_name":
return [MockScoredResult("edge1", 0.88)]
else:
return []
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment_fn,
):
await brute_force_triplet_search(
query="test",
node_name=None,
collections=["Entity_name", "EdgeType_relationship_name"],
)
call_kwargs = mock_get_fragment_fn.call_args[1]
assert call_kwargs["relevant_ids_to_filter"] == ["node1"]
@pytest.mark.asyncio
async def test_brute_force_triplet_search_skips_nodes_without_ids():
"""Test that nodes without ID attribute are skipped."""
class ScoredResultNoId:
"""Mock result without id attribute."""
def __init__(self, score):
self.score = score
def search_side_effect(*args, **kwargs):
collection_name = kwargs.get("collection_name")
if collection_name == "Entity_name":
return [
MockScoredResult("node1", 0.95),
ScoredResultNoId(0.90),
MockScoredResult("node2", 0.87),
]
else:
return []
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment_fn,
):
await brute_force_triplet_search(query="test", node_name=None)
call_kwargs = mock_get_fragment_fn.call_args[1]
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
@pytest.mark.asyncio
async def test_brute_force_triplet_search_handles_tuple_results():
"""Test that both list and tuple results are handled correctly."""
def search_side_effect(*args, **kwargs):
collection_name = kwargs.get("collection_name")
if collection_name == "Entity_name":
return (
MockScoredResult("node1", 0.95),
MockScoredResult("node2", 0.87),
)
else:
return []
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment_fn,
):
await brute_force_triplet_search(query="test", node_name=None)
call_kwargs = mock_get_fragment_fn.call_args[1]
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
@pytest.mark.asyncio
async def test_brute_force_triplet_search_mixed_empty_collections():
"""Test ID extraction with mixed empty and non-empty collections."""
def search_side_effect(*args, **kwargs):
collection_name = kwargs.get("collection_name")
if collection_name == "Entity_name":
return [MockScoredResult("node1", 0.95)]
elif collection_name == "TextSummary_text":
return []
elif collection_name == "EntityType_name":
return [MockScoredResult("node2", 0.92)]
else:
return []
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment_fn,
):
await brute_force_triplet_search(query="test", node_name=None)
call_kwargs = mock_get_fragment_fn.call_args[1]
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
def test_format_triplets():
"""Test format_triplets function."""
mock_edge = MagicMock()
mock_node1 = MagicMock()
mock_node2 = MagicMock()
mock_node1.attributes = {"name": "Node1", "type": "Entity", "id": "n1"}
mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": "n2"}
mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": "connects"}
mock_edge.node1 = mock_node1
mock_edge.node2 = mock_node2
result = format_triplets([mock_edge])
assert isinstance(result, str)
assert "Node1" in result
assert "Node2" in result
assert "relates_to" in result
assert "connects" in result
def test_format_triplets_with_none_values():
"""Test format_triplets filters out None values."""
mock_edge = MagicMock()
mock_node1 = MagicMock()
mock_node2 = MagicMock()
mock_node1.attributes = {"name": "Node1", "type": None, "id": "n1"}
mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": None}
mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": None}
mock_edge.node1 = mock_node1
mock_edge.node2 = mock_node2
result = format_triplets([mock_edge])
assert "Node1" in result
assert "Node2" in result
assert "relates_to" in result
assert "None" not in result or result.count("None") == 0
def test_format_triplets_with_nested_dict():
"""Test format_triplets handles nested dict attributes (lines 23-35)."""
mock_edge = MagicMock()
mock_node1 = MagicMock()
mock_node2 = MagicMock()
mock_node1.attributes = {"name": "Node1", "metadata": {"type": "Entity", "id": "n1"}}
mock_node2.attributes = {"name": "Node2", "metadata": {"type": "Entity", "id": "n2"}}
mock_edge.attributes = {"relationship_name": "relates_to"}
mock_edge.node1 = mock_node1
mock_edge.node2 = mock_node2
result = format_triplets([mock_edge])
assert isinstance(result, str)
assert "Node1" in result
assert "Node2" in result
assert "relates_to" in result
@pytest.mark.asyncio
async def test_brute_force_triplet_search_vector_engine_init_error():
"""Test brute_force_triplet_search handles vector engine initialization error (lines 145-147)."""
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine"
) as mock_get_vector_engine,
):
mock_get_vector_engine.side_effect = Exception("Initialization error")
with pytest.raises(RuntimeError, match="Initialization error"):
await brute_force_triplet_search(query="test query")
@pytest.mark.asyncio
async def test_brute_force_triplet_search_collection_not_found_error():
"""Test brute_force_triplet_search handles CollectionNotFoundError in search (lines 156-157)."""
mock_vector_engine = AsyncMock()
mock_embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine = mock_embedding_engine
mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(
side_effect=[
CollectionNotFoundError("Collection not found"),
[],
[],
]
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=CogneeGraph(),
),
):
result = await brute_force_triplet_search(
query="test query", collections=["missing_collection", "existing_collection"]
)
assert result == []
@pytest.mark.asyncio
async def test_brute_force_triplet_search_generic_exception():
"""Test brute_force_triplet_search handles generic exceptions (lines 209-217)."""
mock_vector_engine = AsyncMock()
mock_embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine = mock_embedding_engine
mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(side_effect=Exception("Generic error"))
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
):
with pytest.raises(Exception, match="Generic error"):
await brute_force_triplet_search(query="test query")
@pytest.mark.asyncio
async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_none():
"""Test brute_force_triplet_search sets relevant_ids_to_filter to None when node_name is provided (line 191)."""
mock_vector_engine = AsyncMock()
mock_embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine = mock_embedding_engine
mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"})
mock_vector_engine.search = AsyncMock(return_value=[mock_result])
mock_fragment = AsyncMock()
mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock()
mock_fragment.map_vector_distances_to_graph_edges = AsyncMock()
mock_fragment.calculate_top_triplet_importances = AsyncMock(return_value=[])
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment,
):
await brute_force_triplet_search(query="test query", node_name=["Node1"])
assert mock_get_fragment.called
call_kwargs = mock_get_fragment.call_args.kwargs if mock_get_fragment.call_args else {}
assert call_kwargs.get("relevant_ids_to_filter") is None
@pytest.mark.asyncio
async def test_brute_force_triplet_search_collection_not_found_at_top_level():
"""Test brute_force_triplet_search handles CollectionNotFoundError at top level (line 210)."""
mock_vector_engine = AsyncMock()
mock_embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine = mock_embedding_engine
mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"})
mock_vector_engine.search = AsyncMock(return_value=[mock_result])
mock_fragment = AsyncMock()
mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock()
mock_fragment.map_vector_distances_to_graph_edges = AsyncMock()
mock_fragment.calculate_top_triplet_importances = AsyncMock(
side_effect=CollectionNotFoundError("Collection not found")
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
),
):
result = await brute_force_triplet_search(query="test query")
assert result == []

View file

@ -0,0 +1,343 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from typing import Type
class TestGenerateCompletion:
@pytest.mark.asyncio
async def test_generate_completion_with_system_prompt(self):
"""Test generate_completion with provided system_prompt."""
mock_llm_response = "Generated answer"
with (
patch(
"cognee.modules.retrieval.utils.completion.render_prompt",
return_value="User prompt text",
),
patch(
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_llm_response,
) as mock_llm,
):
from cognee.modules.retrieval.utils.completion import generate_completion
result = await generate_completion(
query="What is AI?",
context="AI is artificial intelligence",
user_prompt_path="user_prompt.txt",
system_prompt_path="system_prompt.txt",
system_prompt="Custom system prompt",
)
assert result == mock_llm_response
mock_llm.assert_awaited_once_with(
text_input="User prompt text",
system_prompt="Custom system prompt",
response_model=str,
)
@pytest.mark.asyncio
async def test_generate_completion_without_system_prompt(self):
"""Test generate_completion reads system_prompt from file when not provided."""
mock_llm_response = "Generated answer"
with (
patch(
"cognee.modules.retrieval.utils.completion.render_prompt",
return_value="User prompt text",
),
patch(
"cognee.modules.retrieval.utils.completion.read_query_prompt",
return_value="System prompt from file",
),
patch(
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_llm_response,
) as mock_llm,
):
from cognee.modules.retrieval.utils.completion import generate_completion
result = await generate_completion(
query="What is AI?",
context="AI is artificial intelligence",
user_prompt_path="user_prompt.txt",
system_prompt_path="system_prompt.txt",
)
assert result == mock_llm_response
mock_llm.assert_awaited_once_with(
text_input="User prompt text",
system_prompt="System prompt from file",
response_model=str,
)
@pytest.mark.asyncio
async def test_generate_completion_with_conversation_history(self):
"""Test generate_completion includes conversation_history in system_prompt."""
mock_llm_response = "Generated answer"
with (
patch(
"cognee.modules.retrieval.utils.completion.render_prompt",
return_value="User prompt text",
),
patch(
"cognee.modules.retrieval.utils.completion.read_query_prompt",
return_value="System prompt from file",
),
patch(
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_llm_response,
) as mock_llm,
):
from cognee.modules.retrieval.utils.completion import generate_completion
result = await generate_completion(
query="What is AI?",
context="AI is artificial intelligence",
user_prompt_path="user_prompt.txt",
system_prompt_path="system_prompt.txt",
conversation_history="Previous conversation:\nQ: What is ML?\nA: ML is machine learning",
)
assert result == mock_llm_response
expected_system_prompt = (
"Previous conversation:\nQ: What is ML?\nA: ML is machine learning"
+ "\nTASK:"
+ "System prompt from file"
)
mock_llm.assert_awaited_once_with(
text_input="User prompt text",
system_prompt=expected_system_prompt,
response_model=str,
)
@pytest.mark.asyncio
async def test_generate_completion_with_conversation_history_and_custom_system_prompt(self):
"""Test generate_completion includes conversation_history with custom system_prompt."""
mock_llm_response = "Generated answer"
with (
patch(
"cognee.modules.retrieval.utils.completion.render_prompt",
return_value="User prompt text",
),
patch(
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_llm_response,
) as mock_llm,
):
from cognee.modules.retrieval.utils.completion import generate_completion
result = await generate_completion(
query="What is AI?",
context="AI is artificial intelligence",
user_prompt_path="user_prompt.txt",
system_prompt_path="system_prompt.txt",
system_prompt="Custom system prompt",
conversation_history="Previous conversation:\nQ: What is ML?\nA: ML is machine learning",
)
assert result == mock_llm_response
expected_system_prompt = (
"Previous conversation:\nQ: What is ML?\nA: ML is machine learning"
+ "\nTASK:"
+ "Custom system prompt"
)
mock_llm.assert_awaited_once_with(
text_input="User prompt text",
system_prompt=expected_system_prompt,
response_model=str,
)
@pytest.mark.asyncio
async def test_generate_completion_with_response_model(self):
"""Test generate_completion with custom response_model."""
mock_response_model = MagicMock()
mock_llm_response = {"answer": "Generated answer"}
with (
patch(
"cognee.modules.retrieval.utils.completion.render_prompt",
return_value="User prompt text",
),
patch(
"cognee.modules.retrieval.utils.completion.read_query_prompt",
return_value="System prompt from file",
),
patch(
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_llm_response,
) as mock_llm,
):
from cognee.modules.retrieval.utils.completion import generate_completion
result = await generate_completion(
query="What is AI?",
context="AI is artificial intelligence",
user_prompt_path="user_prompt.txt",
system_prompt_path="system_prompt.txt",
response_model=mock_response_model,
)
assert result == mock_llm_response
mock_llm.assert_awaited_once_with(
text_input="User prompt text",
system_prompt="System prompt from file",
response_model=mock_response_model,
)
@pytest.mark.asyncio
async def test_generate_completion_render_prompt_args(self):
"""Test generate_completion passes correct args to render_prompt."""
mock_llm_response = "Generated answer"
with (
patch(
"cognee.modules.retrieval.utils.completion.render_prompt",
return_value="User prompt text",
) as mock_render,
patch(
"cognee.modules.retrieval.utils.completion.read_query_prompt",
return_value="System prompt from file",
),
patch(
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_llm_response,
),
):
from cognee.modules.retrieval.utils.completion import generate_completion
await generate_completion(
query="What is AI?",
context="AI is artificial intelligence",
user_prompt_path="user_prompt.txt",
system_prompt_path="system_prompt.txt",
)
mock_render.assert_called_once_with(
"user_prompt.txt",
{"question": "What is AI?", "context": "AI is artificial intelligence"},
)
class TestSummarizeText:
@pytest.mark.asyncio
async def test_summarize_text_with_system_prompt(self):
"""Test summarize_text with provided system_prompt."""
mock_llm_response = "Summary text"
with patch(
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_llm_response,
) as mock_llm:
from cognee.modules.retrieval.utils.completion import summarize_text
result = await summarize_text(
text="Long text to summarize",
system_prompt_path="summarize_search_results.txt",
system_prompt="Custom summary prompt",
)
assert result == mock_llm_response
mock_llm.assert_awaited_once_with(
text_input="Long text to summarize",
system_prompt="Custom summary prompt",
response_model=str,
)
@pytest.mark.asyncio
async def test_summarize_text_without_system_prompt(self):
"""Test summarize_text reads system_prompt from file when not provided."""
mock_llm_response = "Summary text"
with (
patch(
"cognee.modules.retrieval.utils.completion.read_query_prompt",
return_value="System prompt from file",
),
patch(
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_llm_response,
) as mock_llm,
):
from cognee.modules.retrieval.utils.completion import summarize_text
result = await summarize_text(
text="Long text to summarize",
system_prompt_path="summarize_search_results.txt",
)
assert result == mock_llm_response
mock_llm.assert_awaited_once_with(
text_input="Long text to summarize",
system_prompt="System prompt from file",
response_model=str,
)
@pytest.mark.asyncio
async def test_summarize_text_default_prompt_path(self):
"""Test summarize_text uses default prompt path when not provided."""
mock_llm_response = "Summary text"
with (
patch(
"cognee.modules.retrieval.utils.completion.read_query_prompt",
return_value="Default system prompt",
) as mock_read,
patch(
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_llm_response,
) as mock_llm,
):
from cognee.modules.retrieval.utils.completion import summarize_text
result = await summarize_text(text="Long text to summarize")
assert result == mock_llm_response
mock_read.assert_called_once_with("summarize_search_results.txt")
mock_llm.assert_awaited_once_with(
text_input="Long text to summarize",
system_prompt="Default system prompt",
response_model=str,
)
@pytest.mark.asyncio
async def test_summarize_text_custom_prompt_path(self):
"""Test summarize_text uses custom prompt path when provided."""
mock_llm_response = "Summary text"
with (
patch(
"cognee.modules.retrieval.utils.completion.read_query_prompt",
return_value="Custom system prompt",
) as mock_read,
patch(
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_llm_response,
) as mock_llm,
):
from cognee.modules.retrieval.utils.completion import summarize_text
result = await summarize_text(
text="Long text to summarize",
system_prompt_path="custom_prompt.txt",
)
assert result == mock_llm_response
mock_read.assert_called_once_with("custom_prompt.txt")
mock_llm.assert_awaited_once_with(
text_input="Long text to summarize",
system_prompt="Custom system prompt",
response_model=str,
)

View file

@ -0,0 +1,157 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@pytest.fixture
def mock_edge():
"""Create a mock edge."""
edge = MagicMock(spec=Edge)
return edge
class TestGraphSummaryCompletionRetriever:
@pytest.mark.asyncio
async def test_init_defaults(self):
"""Test GraphSummaryCompletionRetriever initialization with defaults."""
retriever = GraphSummaryCompletionRetriever()
assert retriever.summarize_prompt_path == "summarize_search_results.txt"
assert retriever.user_prompt_path == "graph_context_for_question.txt"
assert retriever.system_prompt_path == "answer_simple_question.txt"
assert retriever.top_k == 5
assert retriever.save_interaction is False
@pytest.mark.asyncio
async def test_init_custom_params(self):
"""Test GraphSummaryCompletionRetriever initialization with custom parameters."""
retriever = GraphSummaryCompletionRetriever(
user_prompt_path="custom_user.txt",
system_prompt_path="custom_system.txt",
summarize_prompt_path="custom_summarize.txt",
system_prompt="Custom system prompt",
top_k=10,
save_interaction=True,
wide_search_top_k=200,
triplet_distance_penalty=2.5,
)
assert retriever.summarize_prompt_path == "custom_summarize.txt"
assert retriever.user_prompt_path == "custom_user.txt"
assert retriever.system_prompt_path == "custom_system.txt"
assert retriever.top_k == 10
assert retriever.save_interaction is True
@pytest.mark.asyncio
async def test_resolve_edges_to_text_calls_super_and_summarizes(self, mock_edge):
"""Test resolve_edges_to_text calls super method and then summarizes."""
retriever = GraphSummaryCompletionRetriever(
summarize_prompt_path="custom_summarize.txt",
system_prompt="Custom system prompt",
)
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text",
new_callable=AsyncMock,
return_value="Resolved edges text",
) as mock_super_resolve,
patch(
"cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text",
new_callable=AsyncMock,
return_value="Summarized text",
) as mock_summarize,
):
result = await retriever.resolve_edges_to_text([mock_edge])
assert result == "Summarized text"
mock_super_resolve.assert_awaited_once_with([mock_edge])
mock_summarize.assert_awaited_once_with(
"Resolved edges text",
"custom_summarize.txt",
"Custom system prompt",
)
@pytest.mark.asyncio
async def test_resolve_edges_to_text_with_default_system_prompt(self, mock_edge):
"""Test resolve_edges_to_text uses None for system_prompt when not provided."""
retriever = GraphSummaryCompletionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text",
new_callable=AsyncMock,
return_value="Resolved edges text",
),
patch(
"cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text",
new_callable=AsyncMock,
return_value="Summarized text",
) as mock_summarize,
):
await retriever.resolve_edges_to_text([mock_edge])
mock_summarize.assert_awaited_once_with(
"Resolved edges text",
"summarize_search_results.txt",
None,
)
@pytest.mark.asyncio
async def test_resolve_edges_to_text_with_empty_edges(self):
"""Test resolve_edges_to_text handles empty edges list."""
retriever = GraphSummaryCompletionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text",
new_callable=AsyncMock,
return_value="",
),
patch(
"cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text",
new_callable=AsyncMock,
return_value="Empty summary",
) as mock_summarize,
):
result = await retriever.resolve_edges_to_text([])
assert result == "Empty summary"
mock_summarize.assert_awaited_once_with(
"",
"summarize_search_results.txt",
None,
)
@pytest.mark.asyncio
async def test_resolve_edges_to_text_with_multiple_edges(self, mock_edge):
"""Test resolve_edges_to_text handles multiple edges."""
retriever = GraphSummaryCompletionRetriever()
mock_edge2 = MagicMock(spec=Edge)
mock_edge3 = MagicMock(spec=Edge)
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text",
new_callable=AsyncMock,
return_value="Multiple edges resolved text",
),
patch(
"cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text",
new_callable=AsyncMock,
return_value="Multiple edges summarized",
) as mock_summarize,
):
result = await retriever.resolve_edges_to_text([mock_edge, mock_edge2, mock_edge3])
assert result == "Multiple edges summarized"
mock_summarize.assert_awaited_once_with(
"Multiple edges resolved text",
"summarize_search_results.txt",
None,
)

View file

@ -0,0 +1,312 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from uuid import UUID, NAMESPACE_OID, uuid5
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation, UserFeedbackSentiment
from cognee.modules.engine.models import NodeSet
@pytest.fixture
def mock_feedback_evaluation():
"""Create a mock feedback evaluation."""
evaluation = MagicMock(spec=UserFeedbackEvaluation)
evaluation.evaluation = MagicMock()
evaluation.evaluation.value = "positive"
evaluation.score = 4.5
return evaluation
@pytest.fixture
def mock_graph_engine():
"""Create a mock graph engine."""
engine = AsyncMock()
engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
engine.add_edges = AsyncMock()
engine.apply_feedback_weight = AsyncMock()
return engine
class TestUserQAFeedback:
@pytest.mark.asyncio
async def test_init_default(self):
"""Test UserQAFeedback initialization with default last_k."""
retriever = UserQAFeedback()
assert retriever.last_k == 1
@pytest.mark.asyncio
async def test_init_custom_last_k(self):
"""Test UserQAFeedback initialization with custom last_k."""
retriever = UserQAFeedback(last_k=5)
assert retriever.last_k == 5
@pytest.mark.asyncio
async def test_add_feedback_success_with_relationships(
self, mock_feedback_evaluation, mock_graph_engine
):
"""Test add_feedback successfully creates feedback with relationships."""
interaction_id_1 = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
interaction_id_2 = str(UUID("550e8400-e29b-41d4-a716-446655440001"))
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(
return_value=[interaction_id_1, interaction_id_2]
)
feedback_text = "This answer was helpful"
with (
patch(
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_feedback_evaluation,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
new_callable=AsyncMock,
) as mock_add_data,
patch(
"cognee.modules.retrieval.user_qa_feedback.index_graph_edges",
new_callable=AsyncMock,
) as mock_index_edges,
):
retriever = UserQAFeedback(last_k=2)
result = await retriever.add_feedback(feedback_text)
assert result == [feedback_text]
mock_add_data.assert_awaited_once()
mock_graph_engine.add_edges.assert_awaited_once()
mock_index_edges.assert_awaited_once()
mock_graph_engine.apply_feedback_weight.assert_awaited_once()
# Verify add_edges was called with correct relationships
call_args = mock_graph_engine.add_edges.call_args[0][0]
assert len(call_args) == 2
assert call_args[0][0] == uuid5(NAMESPACE_OID, name=feedback_text)
assert call_args[0][1] == UUID(interaction_id_1)
assert call_args[0][2] == "gives_feedback_to"
assert call_args[0][3]["relationship_name"] == "gives_feedback_to"
assert call_args[0][3]["ontology_valid"] is False
# Verify apply_feedback_weight was called with correct node IDs
weight_call_args = mock_graph_engine.apply_feedback_weight.call_args[1]["node_ids"]
assert len(weight_call_args) == 2
assert interaction_id_1 in weight_call_args
assert interaction_id_2 in weight_call_args
@pytest.mark.asyncio
async def test_add_feedback_success_no_relationships(
self, mock_feedback_evaluation, mock_graph_engine
):
"""Test add_feedback successfully creates feedback without relationships."""
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
feedback_text = "This answer was helpful"
with (
patch(
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_feedback_evaluation,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
new_callable=AsyncMock,
) as mock_add_data,
patch(
"cognee.modules.retrieval.user_qa_feedback.index_graph_edges",
new_callable=AsyncMock,
) as mock_index_edges,
):
retriever = UserQAFeedback(last_k=1)
result = await retriever.add_feedback(feedback_text)
assert result == [feedback_text]
mock_add_data.assert_awaited_once()
# Should not call add_edges or index_graph_edges when no relationships
mock_graph_engine.add_edges.assert_not_awaited()
mock_index_edges.assert_not_awaited()
mock_graph_engine.apply_feedback_weight.assert_not_awaited()
@pytest.mark.asyncio
async def test_add_feedback_creates_correct_feedback_node(
self, mock_feedback_evaluation, mock_graph_engine
):
"""Test add_feedback creates CogneeUserFeedback with correct attributes."""
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
feedback_text = "This was a negative experience"
mock_feedback_evaluation.evaluation.value = "negative"
mock_feedback_evaluation.score = -3.0
with (
patch(
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_feedback_evaluation,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
new_callable=AsyncMock,
) as mock_add_data,
):
retriever = UserQAFeedback()
await retriever.add_feedback(feedback_text)
# Verify add_data_points was called with correct CogneeUserFeedback
call_args = mock_add_data.call_args[1]["data_points"]
assert len(call_args) == 1
feedback_node = call_args[0]
assert feedback_node.id == uuid5(NAMESPACE_OID, name=feedback_text)
assert feedback_node.feedback == feedback_text
assert feedback_node.sentiment == "negative"
assert feedback_node.score == -3.0
assert isinstance(feedback_node.belongs_to_set, NodeSet)
assert feedback_node.belongs_to_set.name == "UserQAFeedbacks"
@pytest.mark.asyncio
async def test_add_feedback_calls_llm_with_correct_prompt(
self, mock_feedback_evaluation, mock_graph_engine
):
"""Test add_feedback calls LLM with correct sentiment analysis prompt."""
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
feedback_text = "Great answer!"
with (
patch(
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_feedback_evaluation,
) as mock_llm,
patch(
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
new_callable=AsyncMock,
),
):
retriever = UserQAFeedback()
await retriever.add_feedback(feedback_text)
mock_llm.assert_awaited_once()
call_kwargs = mock_llm.call_args[1]
assert call_kwargs["text_input"] == feedback_text
assert "sentiment analysis assistant" in call_kwargs["system_prompt"]
assert call_kwargs["response_model"] == UserFeedbackEvaluation
@pytest.mark.asyncio
async def test_add_feedback_uses_last_k_parameter(
self, mock_feedback_evaluation, mock_graph_engine
):
"""Test add_feedback uses last_k parameter when getting interaction IDs."""
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
feedback_text = "Test feedback"
with (
patch(
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_feedback_evaluation,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
new_callable=AsyncMock,
),
):
retriever = UserQAFeedback(last_k=5)
await retriever.add_feedback(feedback_text)
mock_graph_engine.get_last_user_interaction_ids.assert_awaited_once_with(limit=5)
@pytest.mark.asyncio
async def test_add_feedback_with_single_interaction(
self, mock_feedback_evaluation, mock_graph_engine
):
"""Test add_feedback with single interaction ID."""
interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id])
feedback_text = "Test feedback"
with (
patch(
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_feedback_evaluation,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
new_callable=AsyncMock,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.index_graph_edges",
new_callable=AsyncMock,
),
):
retriever = UserQAFeedback()
result = await retriever.add_feedback(feedback_text)
assert result == [feedback_text]
# Should create relationship for the interaction
call_args = mock_graph_engine.add_edges.call_args[0][0]
assert len(call_args) == 1
assert call_args[0][1] == UUID(interaction_id)
@pytest.mark.asyncio
async def test_add_feedback_applies_weight_correctly(
self, mock_feedback_evaluation, mock_graph_engine
):
"""Test add_feedback applies feedback weight with correct score."""
interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id])
mock_feedback_evaluation.score = 4.5
feedback_text = "Positive feedback"
with (
patch(
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_feedback_evaluation,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
new_callable=AsyncMock,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.index_graph_edges",
new_callable=AsyncMock,
),
):
retriever = UserQAFeedback()
await retriever.add_feedback(feedback_text)
mock_graph_engine.apply_feedback_weight.assert_awaited_once_with(
node_ids=[interaction_id], weight=4.5
)

View file

@ -0,0 +1,329 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
@pytest.fixture
def mock_vector_engine():
"""Create a mock vector engine."""
engine = AsyncMock()
engine.has_collection = AsyncMock(return_value=True)
engine.search = AsyncMock()
return engine
@pytest.mark.asyncio
async def test_get_context_success(mock_vector_engine):
"""Test successful retrieval of triplet context."""
mock_result1 = MagicMock()
mock_result1.payload = {"text": "Alice knows Bob"}
mock_result2 = MagicMock()
mock_result2.payload = {"text": "Bob works at Tech Corp"}
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
retriever = TripletRetriever(top_k=5)
with patch(
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert context == "Alice knows Bob\nBob works at Tech Corp"
mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5)
@pytest.mark.asyncio
async def test_get_context_no_collection(mock_vector_engine):
"""Test that NoDataError is raised when Triplet_text collection doesn't exist."""
mock_vector_engine.has_collection.return_value = False
retriever = TripletRetriever()
with patch(
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
with pytest.raises(NoDataError, match="create_triplet_embeddings"):
await retriever.get_context("test query")
@pytest.mark.asyncio
async def test_get_context_empty_results(mock_vector_engine):
"""Test that empty string is returned when no triplets are found."""
mock_vector_engine.search.return_value = []
retriever = TripletRetriever()
with patch(
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert context == ""
@pytest.mark.asyncio
async def test_get_context_collection_not_found_error(mock_vector_engine):
"""Test that CollectionNotFoundError is converted to NoDataError."""
mock_vector_engine.has_collection.side_effect = CollectionNotFoundError("Collection not found")
retriever = TripletRetriever()
with patch(
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
with pytest.raises(NoDataError, match="No data found"):
await retriever.get_context("test query")
@pytest.mark.asyncio
async def test_get_context_empty_payload_text(mock_vector_engine):
"""Test get_context handles missing text in payload."""
mock_result = MagicMock()
mock_result.payload = {}
mock_vector_engine.search.return_value = [mock_result]
retriever = TripletRetriever()
with patch(
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
with pytest.raises(KeyError):
await retriever.get_context("test query")
@pytest.mark.asyncio
async def test_get_context_single_triplet(mock_vector_engine):
"""Test get_context with single triplet result."""
mock_result = MagicMock()
mock_result.payload = {"text": "Single triplet"}
mock_vector_engine.search.return_value = [mock_result]
retriever = TripletRetriever()
with patch(
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert context == "Single triplet"
@pytest.mark.asyncio
async def test_init_defaults():
"""Test TripletRetriever initialization with defaults."""
retriever = TripletRetriever()
assert retriever.user_prompt_path == "context_for_question.txt"
assert retriever.system_prompt_path == "answer_simple_question.txt"
assert retriever.top_k == 5 # Default is 5
assert retriever.system_prompt is None
@pytest.mark.asyncio
async def test_init_custom_params():
"""Test TripletRetriever initialization with custom parameters."""
retriever = TripletRetriever(
user_prompt_path="custom_user.txt",
system_prompt_path="custom_system.txt",
system_prompt="Custom prompt",
top_k=10,
)
assert retriever.user_prompt_path == "custom_user.txt"
assert retriever.system_prompt_path == "custom_system.txt"
assert retriever.system_prompt == "Custom prompt"
assert retriever.top_k == 10
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_vector_engine):
"""Test get_completion retrieves context when not provided."""
mock_result = MagicMock()
mock_result.payload = {"text": "Test triplet"}
mock_vector_engine.has_collection.return_value = True
mock_vector_engine.search.return_value = [mock_result]
retriever = TripletRetriever()
with (
patch(
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.triplet_retriever.generate_completion",
return_value="Generated answer",
),
patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query")
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_provided_context(mock_vector_engine):
"""Test get_completion uses provided context."""
retriever = TripletRetriever()
with (
patch(
"cognee.modules.retrieval.triplet_retriever.generate_completion",
return_value="Generated answer",
),
patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", context="Provided context")
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_vector_engine):
"""Test get_completion with session caching enabled."""
mock_result = MagicMock()
mock_result.payload = {"text": "Test triplet"}
mock_vector_engine.has_collection.return_value = True
mock_vector_engine.search.return_value = [mock_result]
retriever = TripletRetriever()
mock_user = MagicMock()
mock_user.id = "test-user-id"
with (
patch(
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.triplet_retriever.get_conversation_history",
return_value="Previous conversation",
),
patch(
"cognee.modules.retrieval.triplet_retriever.summarize_text",
return_value="Context summary",
),
patch(
"cognee.modules.retrieval.triplet_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.triplet_retriever.save_conversation_history",
) as mock_save,
patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user,
):
mock_config = MagicMock()
mock_config.caching = True
mock_cache_config.return_value = mock_config
mock_session_user.get.return_value = mock_user
completion = await retriever.get_completion("test query", session_id="test_session")
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
mock_save.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_vector_engine):
"""Test get_completion with session config but no user ID."""
mock_result = MagicMock()
mock_result.payload = {"text": "Test triplet"}
mock_vector_engine.has_collection.return_value = True
mock_vector_engine.search.return_value = [mock_result]
retriever = TripletRetriever()
with (
patch(
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.triplet_retriever.generate_completion",
return_value="Generated answer",
),
patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user,
):
mock_config = MagicMock()
mock_config.caching = True
mock_cache_config.return_value = mock_config
mock_session_user.get.return_value = None # No user
completion = await retriever.get_completion("test query")
assert isinstance(completion, list)
assert len(completion) == 1
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_vector_engine):
"""Test get_completion with custom response model."""
from pydantic import BaseModel
class TestModel(BaseModel):
answer: str
mock_result = MagicMock()
mock_result.payload = {"text": "Test triplet"}
mock_vector_engine.has_collection.return_value = True
mock_vector_engine.search.return_value = [mock_result]
retriever = TripletRetriever()
with (
patch(
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.triplet_retriever.generate_completion",
return_value=TestModel(answer="Test answer"),
),
patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", response_model=TestModel)
assert isinstance(completion, list)
assert len(completion) == 1
assert isinstance(completion[0], TestModel)
@pytest.mark.asyncio
async def test_init_none_top_k():
"""Test TripletRetriever initialization with None top_k."""
retriever = TripletRetriever(top_k=None)
assert retriever.top_k == 5