chore: adds new Unit tests for retrievers

This commit is contained in:
hajdul88 2025-12-12 14:44:41 +01:00
parent 127d9860df
commit fd23c75c09
14 changed files with 4454 additions and 1205 deletions

View file

@@ -1,201 +1,183 @@
import os
import pytest
import pathlib
from typing import List
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from unittest.mock import AsyncMock, patch, MagicMock
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
class DocumentChunkWithEntities(DataPoint):
    """Test-only payload schema for the ``DocumentChunk_text`` collection.

    Mirrors the fields of a document chunk so the vector collection can be
    created in tests without ingesting real documents.
    """

    # Raw chunk text; listed in metadata["index_fields"] so it is the embedded field.
    text: str
    chunk_size: int
    chunk_index: int
    cut_type: str
    is_part_of: Document
    # NOTE(review): default is None but the annotation is a non-Optional
    # List[Entity] — presumably tolerated by DataPoint's model validation;
    # consider Optional[List[Entity]]. TODO confirm.
    contains: List[Entity] = None
    metadata: dict = {"index_fields": ["text"]}
@pytest.fixture
def mock_vector_engine():
    """Yield an AsyncMock stand-in for the vector engine with `search` pre-mocked."""
    fake_engine = AsyncMock()
    # Explicitly install `search` so tests can set return_value / side_effect on it.
    fake_engine.search = AsyncMock()
    return fake_engine
class TestChunksRetriever:
@pytest.mark.asyncio
async def test_chunk_context_simple(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_simple"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_simple"
)
cognee.config.data_root_directory(data_directory_path)
@pytest.mark.asyncio
async def test_get_context_success(mock_vector_engine):
"""Test successful retrieval of chunk context."""
mock_result1 = MagicMock()
mock_result1.payload = {"text": "Steve Rodger", "chunk_index": 0}
mock_result2 = MagicMock()
mock_result2.payload = {"text": "Mike Broski", "chunk_index": 1}
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
document = TextDocument(
name="Steve Rodger's career",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
retriever = ChunksRetriever(top_k=5)
chunk1 = DocumentChunk(
text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
with patch(
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
entities = [chunk1, chunk2, chunk3]
assert len(context) == 2
assert context[0]["text"] == "Steve Rodger"
assert context[1]["text"] == "Mike Broski"
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5)
await add_data_points(entities)
retriever = ChunksRetriever()
@pytest.mark.asyncio
async def test_get_context_collection_not_found_error(mock_vector_engine):
"""Test that CollectionNotFoundError is converted to NoDataError."""
mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found")
context = await retriever.get_context("Mike")
retriever = ChunksRetriever()
assert context[0]["text"] == "Mike Broski", "Failed to get Mike Broski"
with patch(
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
with pytest.raises(NoDataError, match="No data found"):
await retriever.get_context("test query")
@pytest.mark.asyncio
async def test_chunk_context_complex(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_complex"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_complex"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
@pytest.mark.asyncio
async def test_get_context_empty_results(mock_vector_engine):
"""Test that empty list is returned when no chunks are found."""
mock_vector_engine.search.return_value = []
document1 = TextDocument(
name="Employee List",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
retriever = ChunksRetriever()
document2 = TextDocument(
name="Car List",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
with patch(
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
chunk1 = DocumentChunk(
text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
assert context == []
chunk4 = DocumentChunk(
text="Range Rover",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk5 = DocumentChunk(
text="Hyundai",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk6 = DocumentChunk(
text="Chrysler",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]
@pytest.mark.asyncio
async def test_get_context_top_k_limit(mock_vector_engine):
"""Test that top_k parameter limits the number of results."""
mock_results = [MagicMock() for _ in range(3)]
for i, result in enumerate(mock_results):
result.payload = {"text": f"Chunk {i}"}
await add_data_points(entities)
mock_vector_engine.search.return_value = mock_results
retriever = ChunksRetriever(top_k=20)
retriever = ChunksRetriever(top_k=3)
context = await retriever.get_context("Christina")
with patch(
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer"
assert len(context) == 3
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3)
@pytest.mark.asyncio
async def test_chunk_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
@pytest.mark.asyncio
async def test_get_completion_with_context(mock_vector_engine):
"""Test get_completion returns provided context."""
retriever = ChunksRetriever()
retriever = ChunksRetriever()
provided_context = [{"text": "Steve Rodger"}, {"text": "Mike Broski"}]
completion = await retriever.get_completion("test query", context=provided_context)
with pytest.raises(NoDataError):
await retriever.get_context("Christina Mayer")
assert completion == provided_context
vector_engine = get_vector_engine()
await vector_engine.create_collection(
"DocumentChunk_text", payload_schema=DocumentChunkWithEntities
)
context = await retriever.get_context("Christina Mayer")
assert len(context) == 0, "Found chunks when none should exist"
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_vector_engine):
    """When no context is supplied, get_completion falls back to vector search."""
    hit = MagicMock()
    hit.payload = {"text": "Steve Rodger"}
    mock_vector_engine.search.return_value = [hit]
    retriever = ChunksRetriever()
    with patch(
        "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        completion = await retriever.get_completion("test query")
    # The single search hit's payload is surfaced as the completion context.
    assert len(completion) == 1
    assert completion[0]["text"] == "Steve Rodger"
@pytest.mark.asyncio
async def test_init_defaults():
    """A ChunksRetriever built with no arguments defaults to top_k == 5."""
    default_retriever = ChunksRetriever()
    assert default_retriever.top_k == 5
@pytest.mark.asyncio
async def test_init_custom_top_k():
    """An explicit top_k value is stored unchanged on the retriever."""
    custom_retriever = ChunksRetriever(top_k=10)
    assert custom_retriever.top_k == 10
@pytest.mark.asyncio
async def test_init_none_top_k():
    """Passing top_k=None is accepted and preserved as-is (no silent default)."""
    unbounded_retriever = ChunksRetriever(top_k=None)
    assert unbounded_retriever.top_k is None
@pytest.mark.asyncio
async def test_get_context_empty_payload(mock_vector_engine):
    """A search hit with an empty payload is passed through, not dropped."""
    empty_hit = MagicMock()
    empty_hit.payload = {}
    mock_vector_engine.search.return_value = [empty_hit]
    retriever = ChunksRetriever()
    with patch(
        "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        context = await retriever.get_context("test query")
    # Exactly one entry, and it is the untouched empty payload.
    assert len(context) == 1
    assert context[0] == {}
@pytest.mark.asyncio
async def test_get_completion_with_session_id(mock_vector_engine):
    """Supplying a session_id must not change completion retrieval results."""
    hit = MagicMock()
    hit.payload = {"text": "Steve Rodger"}
    mock_vector_engine.search.return_value = [hit]
    retriever = ChunksRetriever()
    with patch(
        "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        completion = await retriever.get_completion(
            "test query", session_id="test_session"
        )
    assert len(completion) == 1
    assert completion[0]["text"] == "Steve Rodger"

View file

@@ -152,3 +152,341 @@ class TestConversationHistoryUtils:
assert result is True
call_kwargs = mock_cache.add_qa.call_args.kwargs
assert call_kwargs["session_id"] == "default_session"
@pytest.mark.asyncio
async def test_save_conversation_history_no_user_id(self):
    """With no user bound to the session, saving history must report failure."""
    # Clear the session user so the save path has nobody to attribute the QA to.
    session_user.set(None)
    with patch(
        "cognee.modules.retrieval.utils.session_cache.CacheConfig"
    ) as config_cls:
        enabled_config = MagicMock()
        enabled_config.caching = True
        config_cls.return_value = enabled_config
        from cognee.modules.retrieval.utils.session_cache import (
            save_conversation_history,
        )
        saved = await save_conversation_history(
            query="Test question",
            context_summary="Test context",
            answer="Test answer",
        )
    assert saved is False
@pytest.mark.asyncio
async def test_save_conversation_history_caching_disabled(self):
    """No history is written when the caching feature flag is off."""
    session_user.set(create_mock_user())
    with patch(
        "cognee.modules.retrieval.utils.session_cache.CacheConfig"
    ) as config_cls:
        disabled_config = MagicMock()
        disabled_config.caching = False
        config_cls.return_value = disabled_config
        from cognee.modules.retrieval.utils.session_cache import (
            save_conversation_history,
        )
        saved = await save_conversation_history(
            query="Test question",
            context_summary="Test context",
            answer="Test answer",
        )
    assert saved is False
@pytest.mark.asyncio
async def test_save_conversation_history_cache_engine_none(self):
    """Test save_conversation_history returns False when cache_engine is None."""
    user = create_mock_user()
    session_user.set(user)
    # Patch the factory on its defining module so the save path gets no engine.
    cache_module = importlib.import_module(
        "cognee.infrastructure.databases.cache.get_cache_engine"
    )
    with patch.object(cache_module, "get_cache_engine", return_value=None):
        with patch(
            "cognee.modules.retrieval.utils.session_cache.CacheConfig"
        ) as MockCacheConfig:
            mock_config = MagicMock()
            mock_config.caching = True  # feature on, so only the missing engine can fail it
            MockCacheConfig.return_value = mock_config
            # NOTE(review): imported inside the patch context — presumably so the
            # function binds the patched CacheConfig; confirm against the module.
            from cognee.modules.retrieval.utils.session_cache import (
                save_conversation_history,
            )
            result = await save_conversation_history(
                query="Test question",
                context_summary="Test context",
                answer="Test answer",
            )
            assert result is False
@pytest.mark.asyncio
async def test_save_conversation_history_cache_connection_error(self):
    """Test save_conversation_history handles CacheConnectionError gracefully."""
    user = create_mock_user()
    session_user.set(user)
    from cognee.infrastructure.databases.exceptions import CacheConnectionError
    # Cache engine whose write path raises a connection failure.
    mock_cache = create_mock_cache_engine([])
    mock_cache.add_qa = AsyncMock(side_effect=CacheConnectionError("Connection failed"))
    cache_module = importlib.import_module(
        "cognee.infrastructure.databases.cache.get_cache_engine"
    )
    with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
        with patch(
            "cognee.modules.retrieval.utils.session_cache.CacheConfig"
        ) as MockCacheConfig:
            mock_config = MagicMock()
            mock_config.caching = True
            MockCacheConfig.return_value = mock_config
            from cognee.modules.retrieval.utils.session_cache import (
                save_conversation_history,
            )
            result = await save_conversation_history(
                query="Test question",
                context_summary="Test context",
                answer="Test answer",
            )
            # Connection problems must not propagate; the call just reports failure.
            assert result is False
@pytest.mark.asyncio
async def test_save_conversation_history_generic_exception(self):
    """Arbitrary backend errors are swallowed; the call simply reports failure."""
    session_user.set(create_mock_user())
    failing_cache = create_mock_cache_engine([])
    failing_cache.add_qa = AsyncMock(side_effect=ValueError("Unexpected error"))
    cache_module = importlib.import_module(
        "cognee.infrastructure.databases.cache.get_cache_engine"
    )
    with (
        patch.object(cache_module, "get_cache_engine", return_value=failing_cache),
        patch(
            "cognee.modules.retrieval.utils.session_cache.CacheConfig"
        ) as config_cls,
    ):
        enabled_config = MagicMock()
        enabled_config.caching = True
        config_cls.return_value = enabled_config
        from cognee.modules.retrieval.utils.session_cache import (
            save_conversation_history,
        )
        saved = await save_conversation_history(
            query="Test question",
            context_summary="Test context",
            answer="Test answer",
        )
    assert saved is False
@pytest.mark.asyncio
async def test_get_conversation_history_no_user_id(self):
    """With no session user there is no history to read: expect an empty string."""
    session_user.set(None)
    with patch(
        "cognee.modules.retrieval.utils.session_cache.CacheConfig"
    ) as config_cls:
        enabled_config = MagicMock()
        enabled_config.caching = True
        config_cls.return_value = enabled_config
        from cognee.modules.retrieval.utils.session_cache import (
            get_conversation_history,
        )
        history = await get_conversation_history(session_id="test_session")
    assert history == ""
@pytest.mark.asyncio
async def test_get_conversation_history_caching_disabled(self):
    """Test get_conversation_history returns empty string when caching is disabled."""
    user = create_mock_user()
    session_user.set(user)
    # Turn the caching feature flag off via the config the module constructs.
    with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig:
        mock_config = MagicMock()
        mock_config.caching = False
        MockCacheConfig.return_value = mock_config
        # NOTE(review): late import — presumably so the patched CacheConfig is
        # the one resolved at call time; confirm against the module under test.
        from cognee.modules.retrieval.utils.session_cache import (
            get_conversation_history,
        )
        result = await get_conversation_history(session_id="test_session")
        assert result == ""
@pytest.mark.asyncio
async def test_get_conversation_history_default_session(self):
    """Test get_conversation_history uses 'default_session' when session_id is None."""
    user = create_mock_user()
    session_user.set(user)
    mock_cache = create_mock_cache_engine([])
    # Patch the factory on its defining module so the lookup receives the mock engine.
    cache_module = importlib.import_module(
        "cognee.infrastructure.databases.cache.get_cache_engine"
    )
    with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
        with patch(
            "cognee.modules.retrieval.utils.session_cache.CacheConfig"
        ) as MockCacheConfig:
            mock_config = MagicMock()
            mock_config.caching = True  # caching enabled so the lookup proceeds
            MockCacheConfig.return_value = mock_config
            from cognee.modules.retrieval.utils.session_cache import (
                get_conversation_history,
            )
            await get_conversation_history(session_id=None)
            # The fallback session name must be forwarded to the cache engine.
            mock_cache.get_latest_qa.assert_called_once_with(str(user.id), "default_session")
@pytest.mark.asyncio
async def test_get_conversation_history_cache_engine_none(self):
    """A missing cache engine yields an empty history rather than an error."""
    session_user.set(create_mock_user())
    cache_module = importlib.import_module(
        "cognee.infrastructure.databases.cache.get_cache_engine"
    )
    with (
        patch.object(cache_module, "get_cache_engine", return_value=None),
        patch(
            "cognee.modules.retrieval.utils.session_cache.CacheConfig"
        ) as config_cls,
    ):
        enabled_config = MagicMock()
        enabled_config.caching = True
        config_cls.return_value = enabled_config
        from cognee.modules.retrieval.utils.session_cache import (
            get_conversation_history,
        )
        history = await get_conversation_history(session_id="test_session")
    assert history == ""
@pytest.mark.asyncio
async def test_get_conversation_history_cache_connection_error(self):
    """Test get_conversation_history handles CacheConnectionError gracefully."""
    user = create_mock_user()
    session_user.set(user)
    from cognee.infrastructure.databases.exceptions import CacheConnectionError
    # Cache engine whose read path raises a connection failure.
    mock_cache = create_mock_cache_engine([])
    mock_cache.get_latest_qa = AsyncMock(side_effect=CacheConnectionError("Connection failed"))
    cache_module = importlib.import_module(
        "cognee.infrastructure.databases.cache.get_cache_engine"
    )
    with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
        with patch(
            "cognee.modules.retrieval.utils.session_cache.CacheConfig"
        ) as MockCacheConfig:
            mock_config = MagicMock()
            mock_config.caching = True
            MockCacheConfig.return_value = mock_config
            from cognee.modules.retrieval.utils.session_cache import (
                get_conversation_history,
            )
            result = await get_conversation_history(session_id="test_session")
            # Connection problems must not propagate; history degrades to empty.
            assert result == ""
@pytest.mark.asyncio
async def test_get_conversation_history_generic_exception(self):
    """Arbitrary backend errors degrade the history to an empty string."""
    session_user.set(create_mock_user())
    failing_cache = create_mock_cache_engine([])
    failing_cache.get_latest_qa = AsyncMock(side_effect=ValueError("Unexpected error"))
    cache_module = importlib.import_module(
        "cognee.infrastructure.databases.cache.get_cache_engine"
    )
    with (
        patch.object(cache_module, "get_cache_engine", return_value=failing_cache),
        patch(
            "cognee.modules.retrieval.utils.session_cache.CacheConfig"
        ) as config_cls,
    ):
        enabled_config = MagicMock()
        enabled_config.caching = True
        config_cls.return_value = enabled_config
        from cognee.modules.retrieval.utils.session_cache import (
            get_conversation_history,
        )
        history = await get_conversation_history(session_id="test_session")
    assert history == ""
@pytest.mark.asyncio
async def test_get_conversation_history_missing_keys(self):
    """Test get_conversation_history handles missing keys in history entries."""
    user = create_mock_user()
    session_user.set(user)
    # Three deliberately incomplete entries: question-only, answer/context-only, empty.
    mock_history = [
        {
            "time": "2024-01-15 10:30:45",
            "question": "What is AI?",
        },
        {
            "context": "AI is artificial intelligence",
            "answer": "AI stands for Artificial Intelligence",
        },
        {},
    ]
    mock_cache = create_mock_cache_engine(mock_history)
    cache_module = importlib.import_module(
        "cognee.infrastructure.databases.cache.get_cache_engine"
    )
    with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
        with patch(
            "cognee.modules.retrieval.utils.session_cache.CacheConfig"
        ) as MockCacheConfig:
            mock_config = MagicMock()
            mock_config.caching = True
            MockCacheConfig.return_value = mock_config
            from cognee.modules.retrieval.utils.session_cache import (
                get_conversation_history,
            )
            result = await get_conversation_history(session_id="test_session")
            # Present fields are rendered; a missing timestamp falls back to
            # the "Unknown time" placeholder instead of raising.
            assert "Previous conversation:" in result
            assert "[2024-01-15 10:30:45]" in result
            assert "QUESTION: What is AI?" in result
            assert "Unknown time" in result
            assert "CONTEXT: AI is artificial intelligence" in result
            assert "ANSWER: AI stands for Artificial Intelligence" in result

View file

@@ -1,177 +1,469 @@
import os
import pytest
import pathlib
from typing import Optional, Union
from unittest.mock import AsyncMock, patch, MagicMock
from uuid import UUID
import cognee
from cognee.low_level import setup, DataPoint
from cognee.tasks.storage import add_data_points
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
)
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
class TestGraphCompletionWithContextExtensionRetriever:
@pytest.mark.asyncio
async def test_graph_completion_extension_context_simple(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent,
".cognee_system/test_graph_completion_extension_context_simple",
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent,
".data_storage/test_graph_completion_extension_context_simple",
)
cognee.config.data_root_directory(data_directory_path)
@pytest.fixture
def mock_edge():
"""Create a mock edge."""
edge = MagicMock(spec=Edge)
return edge
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
class Company(DataPoint):
name: str
@pytest.mark.asyncio
async def test_get_triplets_inherited(mock_edge):
"""Test that get_triplets is inherited from parent class."""
retriever = GraphCompletionContextExtensionRetriever()
class Person(DataPoint):
name: str
works_for: Company
with patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
):
triplets = await retriever.get_triplets("test query")
company1 = Company(name="Figma")
company2 = Company(name="Canva")
person1 = Person(name="Steve Rodger", works_for=company1)
person2 = Person(name="Ike Loma", works_for=company1)
person3 = Person(name="Jason Statham", works_for=company1)
person4 = Person(name="Mike Broski", works_for=company2)
person5 = Person(name="Christina Mayer", works_for=company2)
assert len(triplets) == 1
assert triplets[0] == mock_edge
entities = [company1, company2, person1, person2, person3, person4, person5]
await add_data_points(entities)
@pytest.mark.asyncio
async def test_init_defaults():
"""Test GraphCompletionContextExtensionRetriever initialization with defaults."""
retriever = GraphCompletionContextExtensionRetriever()
retriever = GraphCompletionContextExtensionRetriever()
assert retriever.top_k == 5
assert retriever.user_prompt_path == "graph_context_for_question.txt"
assert retriever.system_prompt_path == "answer_simple_question.txt"
context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
@pytest.mark.asyncio
async def test_init_custom_params():
"""Test GraphCompletionContextExtensionRetriever initialization with custom parameters."""
retriever = GraphCompletionContextExtensionRetriever(
top_k=10,
user_prompt_path="custom_user.txt",
system_prompt_path="custom_system.txt",
system_prompt="Custom prompt",
node_type=str,
node_name=["node1"],
save_interaction=True,
wide_search_top_k=200,
triplet_distance_penalty=5.0,
)
answer = await retriever.get_completion("Who works at Canva?")
assert retriever.top_k == 10
assert retriever.user_prompt_path == "custom_user.txt"
assert retriever.system_prompt_path == "custom_system.txt"
assert retriever.system_prompt == "Custom prompt"
assert retriever.node_type is str
assert retriever.node_name == ["node1"]
assert retriever.save_interaction is True
assert retriever.wide_search_top_k == 200
assert retriever.triplet_distance_penalty == 5.0
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_edge):
"""Test get_completion retrieves context when not provided."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", context_extension_rounds=1)
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_provided_context(mock_edge):
"""Test get_completion uses provided context."""
retriever = GraphCompletionContextExtensionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
"test query", context=[mock_edge], context_extension_rounds=1
)
@pytest.mark.asyncio
async def test_graph_completion_extension_context_complex(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent,
".cognee_system/test_graph_completion_extension_context_complex",
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent,
".data_storage/test_graph_completion_extension_context_complex",
)
cognee.config.data_root_directory(data_directory_path)
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
class Company(DataPoint):
name: str
metadata: dict = {"index_fields": ["name"]}
@pytest.mark.asyncio
async def test_get_completion_context_extension_rounds(mock_edge):
"""Test get_completion with multiple context extension rounds."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
class Car(DataPoint):
brand: str
model: str
year: int
retriever = GraphCompletionContextExtensionRetriever()
class Location(DataPoint):
country: str
city: str
# Create a second edge for extension rounds
mock_edge2 = MagicMock(spec=Edge)
class Home(DataPoint):
location: Location
rooms: int
sqm: int
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(
retriever,
"get_context",
new_callable=AsyncMock,
side_effect=[[mock_edge], [mock_edge2]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
side_effect=["Resolved context", "Extended context"], # Different contexts
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
"Generated answer",
], # Query for extension, then final answer
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
class Person(DataPoint):
name: str
works_for: Company
owns: Optional[list[Union[Car, Home]]] = None
completion = await retriever.get_completion("test query", context_extension_rounds=1)
company1 = Company(name="Figma")
company2 = Company(name="Canva")
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
person1 = Person(name="Mike Rodger", works_for=company1)
person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
person2 = Person(name="Ike Loma", works_for=company1)
person2.owns = [
Car(brand="Tesla", model="Model S", year=2021),
Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
]
@pytest.mark.asyncio
async def test_get_completion_context_extension_stops_early(mock_edge):
"""Test get_completion stops early when no new triplets found."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
person3 = Person(name="Jason Statham", works_for=company1)
retriever = GraphCompletionContextExtensionRetriever()
person4 = Person(name="Mike Broski", works_for=company2)
person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
with (
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
"Generated answer",
],
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
person5 = Person(name="Christina Mayer", works_for=company2)
person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
entities = [company1, company2, person1, person2, person3, person4, person5]
await add_data_points(entities)
retriever = GraphCompletionContextExtensionRetriever(top_k=20)
context = await resolve_edges_to_text(
await retriever.get_context("Who works at Figma and drives Tesla?")
# When get_context returns same triplets, the loop should stop early
completion = await retriever.get_completion(
"test query", context=[mock_edge], context_extension_rounds=4
)
print(context)
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
answer = await retriever.get_completion("Who works at Figma?")
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_edge):
"""Test get_completion with session caching enabled."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
retriever = GraphCompletionContextExtensionRetriever()
mock_user = MagicMock()
mock_user.id = "test-user-id"
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.get_conversation_history",
return_value="Previous conversation",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.summarize_text",
return_value="Context summary",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
"Generated answer",
], # Extension query, then final answer
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.save_conversation_history",
) as mock_save,
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user"
) as mock_session_user,
):
mock_config = MagicMock()
mock_config.caching = True
mock_cache_config.return_value = mock_config
mock_session_user.get.return_value = mock_user
completion = await retriever.get_completion(
"test query", session_id="test_session", context_extension_rounds=1
)
@pytest.mark.asyncio
async def test_get_graph_completion_extension_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent,
".cognee_system/test_get_graph_completion_extension_context_on_empty_graph",
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
mock_save.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction(mock_edge):
"""Test get_completion with save_interaction enabled."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
mock_graph_engine.add_edges = AsyncMock()
retriever = GraphCompletionContextExtensionRetriever(save_interaction=True)
mock_node1 = MagicMock()
mock_node2 = MagicMock()
mock_edge.node1 = mock_node1
mock_edge.node2 = mock_node2
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
"Generated answer",
], # Extension query, then final answer
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
side_effect=[
UUID("550e8400-e29b-41d4-a716-446655440000"),
UUID("550e8400-e29b-41d4-a716-446655440001"),
],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.add_data_points",
) as mock_add_data,
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
"test query", context=[mock_edge], context_extension_rounds=1
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent,
".data_storage/test_get_graph_completion_extension_context_on_empty_graph",
assert isinstance(completion, list)
assert len(completion) == 1
mock_add_data.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_edge):
"""Test get_completion with custom response model."""
from pydantic import BaseModel
class TestModel(BaseModel):
answer: str
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
TestModel(answer="Test answer"),
], # Extension query, then final answer
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
"test query", response_model=TestModel, context_extension_rounds=1
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
assert isinstance(completion, list)
assert len(completion) == 1
assert isinstance(completion[0], TestModel)
retriever = GraphCompletionContextExtensionRetriever()
await setup()
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_edge):
"""Test get_completion with session config but no user ID."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
context = await retriever.get_context("Who works at Figma?")
assert context == [], "Context should be empty on an empty graph"
retriever = GraphCompletionContextExtensionRetriever()
answer = await retriever.get_completion("Who works at Figma?")
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
"Generated answer",
], # Extension query, then final answer
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user"
) as mock_session_user,
):
mock_config = MagicMock()
mock_config.caching = True
mock_cache_config.return_value = mock_config
mock_session_user.get.return_value = None # No user
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
completion = await retriever.get_completion("test query", context_extension_rounds=1)
assert isinstance(completion, list)
assert len(completion) == 1
@pytest.mark.asyncio
async def test_get_completion_zero_extension_rounds(mock_edge):
    """Test get_completion with zero context extension rounds.

    With ``context_extension_rounds=0`` the retriever should skip the
    extension loop entirely and produce a single final answer from the
    initially retrieved context.
    """
    # Graph engine reports a non-empty graph so retrieval proceeds.
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionContextExtensionRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        # Initial context comes straight from the mocked get_context.
        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        # Single return_value (not side_effect): with zero rounds only the
        # final-answer call to generate_completion is expected.
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Disable conversation caching so no session plumbing is exercised.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion("test query", context_extension_rounds=0)
    assert isinstance(completion, list)
    assert len(completion) == 1

View file

@ -1,170 +1,688 @@
import os
import pytest
import pathlib
from typing import Optional, Union
from unittest.mock import AsyncMock, patch, MagicMock
from uuid import UUID
import cognee
from cognee.low_level import setup, DataPoint
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from cognee.modules.retrieval.graph_completion_cot_retriever import (
GraphCompletionCotRetriever,
_as_answer_text,
)
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.infrastructure.llm.LLMGateway import LLMGateway
class TestGraphCompletionCoTRetriever:
@pytest.mark.asyncio
async def test_graph_completion_cot_context_simple(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_cot_context_simple"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_simple"
)
cognee.config.data_root_directory(data_directory_path)
@pytest.fixture
def mock_edge():
    """Provide a MagicMock that stands in for a graph ``Edge``."""
    return MagicMock(spec=Edge)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
class Company(DataPoint):
name: str
@pytest.mark.asyncio
async def test_get_triplets_inherited(mock_edge):
"""Test that get_triplets is inherited from parent class."""
retriever = GraphCompletionCotRetriever()
class Person(DataPoint):
name: str
works_for: Company
with patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
):
triplets = await retriever.get_triplets("test query")
company1 = Company(name="Figma")
company2 = Company(name="Canva")
person1 = Person(name="Steve Rodger", works_for=company1)
person2 = Person(name="Ike Loma", works_for=company1)
person3 = Person(name="Jason Statham", works_for=company1)
person4 = Person(name="Mike Broski", works_for=company2)
person5 = Person(name="Christina Mayer", works_for=company2)
assert len(triplets) == 1
assert triplets[0] == mock_edge
entities = [company1, company2, person1, person2, person3, person4, person5]
await add_data_points(entities)
@pytest.mark.asyncio
async def test_init_custom_params():
"""Test GraphCompletionCotRetriever initialization with custom parameters."""
retriever = GraphCompletionCotRetriever(
top_k=10,
user_prompt_path="custom_user.txt",
system_prompt_path="custom_system.txt",
validation_user_prompt_path="custom_validation_user.txt",
validation_system_prompt_path="custom_validation_system.txt",
followup_system_prompt_path="custom_followup_system.txt",
followup_user_prompt_path="custom_followup_user.txt",
)
retriever = GraphCompletionCotRetriever()
assert retriever.top_k == 10
assert retriever.user_prompt_path == "custom_user.txt"
assert retriever.system_prompt_path == "custom_system.txt"
assert retriever.validation_user_prompt_path == "custom_validation_user.txt"
assert retriever.validation_system_prompt_path == "custom_validation_system.txt"
assert retriever.followup_system_prompt_path == "custom_followup_system.txt"
assert retriever.followup_user_prompt_path == "custom_followup_user.txt"
context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
@pytest.mark.asyncio
async def test_init_defaults():
"""Test GraphCompletionCotRetriever initialization with defaults."""
retriever = GraphCompletionCotRetriever()
answer = await retriever.get_completion("Who works at Canva?")
assert retriever.validation_user_prompt_path == "cot_validation_user_prompt.txt"
assert retriever.validation_system_prompt_path == "cot_validation_system_prompt.txt"
assert retriever.followup_system_prompt_path == "cot_followup_system_prompt.txt"
assert retriever.followup_user_prompt_path == "cot_followup_user_prompt.txt"
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_run_cot_completion_round_zero_with_context(mock_edge):
"""Test _run_cot_completion round 0 with provided context."""
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
return_value="Rendered prompt",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
return_value="System prompt",
),
patch.object(
LLMGateway,
"acreate_structured_output",
new_callable=AsyncMock,
side_effect=["validation_result", "followup_question"],
),
):
completion, context_text, triplets = await retriever._run_cot_completion(
query="test query",
context=[mock_edge],
max_iter=1,
)
@pytest.mark.asyncio
async def test_graph_completion_cot_context_complex(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent,
".cognee_system/test_graph_completion_cot_context_complex",
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_complex"
)
cognee.config.data_root_directory(data_directory_path)
assert completion == "Generated answer"
assert context_text == "Resolved context"
assert len(triplets) >= 1
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
class Company(DataPoint):
name: str
metadata: dict = {"index_fields": ["name"]}
@pytest.mark.asyncio
async def test_run_cot_completion_round_zero_without_context(mock_edge):
"""Test _run_cot_completion round 0 without provided context."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
class Car(DataPoint):
brand: str
model: str
year: int
retriever = GraphCompletionCotRetriever()
class Location(DataPoint):
country: str
city: str
class Home(DataPoint):
location: Location
rooms: int
sqm: int
class Person(DataPoint):
name: str
works_for: Company
owns: Optional[list[Union[Car, Home]]] = None
company1 = Company(name="Figma")
company2 = Company(name="Canva")
person1 = Person(name="Mike Rodger", works_for=company1)
person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
person2 = Person(name="Ike Loma", works_for=company1)
person2.owns = [
Car(brand="Tesla", model="Model S", year=2021),
Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
]
person3 = Person(name="Jason Statham", works_for=company1)
person4 = Person(name="Mike Broski", works_for=company2)
person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
person5 = Person(name="Christina Mayer", works_for=company2)
person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
entities = [company1, company2, person1, person2, person3, person4, person5]
await add_data_points(entities)
retriever = GraphCompletionCotRetriever(top_k=20)
context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
print(context)
assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
):
completion, context_text, triplets = await retriever._run_cot_completion(
query="test query",
context=None,
max_iter=1,
)
@pytest.mark.asyncio
async def test_get_graph_completion_cot_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent,
".cognee_system/test_get_graph_completion_cot_context_on_empty_graph",
assert completion == "Generated answer"
assert context_text == "Resolved context"
assert len(triplets) >= 1
@pytest.mark.asyncio
async def test_run_cot_completion_multiple_rounds(mock_edge):
"""Test _run_cot_completion with multiple rounds."""
retriever = GraphCompletionCotRetriever()
mock_edge2 = MagicMock(spec=Edge)
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
patch.object(
retriever,
"get_context",
new_callable=AsyncMock,
side_effect=[[mock_edge], [mock_edge2]],
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
return_value="Rendered prompt",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
return_value="System prompt",
),
patch.object(
LLMGateway,
"acreate_structured_output",
new_callable=AsyncMock,
side_effect=[
"validation_result",
"followup_question",
"validation_result2",
"followup_question2",
],
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
return_value="Generated answer",
),
):
completion, context_text, triplets = await retriever._run_cot_completion(
query="test query",
context=[mock_edge],
max_iter=2,
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent,
".data_storage/test_get_graph_completion_cot_context_on_empty_graph",
assert completion == "Generated answer"
assert context_text == "Resolved context"
assert len(triplets) >= 1
@pytest.mark.asyncio
async def test_run_cot_completion_with_conversation_history(mock_edge):
"""Test _run_cot_completion with conversation history."""
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
) as mock_generate,
):
completion, context_text, triplets = await retriever._run_cot_completion(
query="test query",
context=[mock_edge],
conversation_history="Previous conversation",
max_iter=1,
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
assert completion == "Generated answer"
call_kwargs = mock_generate.call_args[1]
assert call_kwargs.get("conversation_history") == "Previous conversation"
retriever = GraphCompletionCotRetriever()
await setup()
@pytest.mark.asyncio
async def test_run_cot_completion_with_response_model(mock_edge):
"""Test _run_cot_completion with custom response model."""
from pydantic import BaseModel
context = await retriever.get_context("Who works at Figma?")
assert context == [], "Context should be empty on an empty graph"
class TestModel(BaseModel):
answer: str
answer = await retriever.get_completion("Who works at Figma?")
retriever = GraphCompletionCotRetriever()
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value=TestModel(answer="Test answer"),
),
):
completion, context_text, triplets = await retriever._run_cot_completion(
query="test query",
context=[mock_edge],
response_model=TestModel,
max_iter=1,
)
assert isinstance(completion, TestModel)
assert completion.answer == "Test answer"
@pytest.mark.asyncio
async def test_run_cot_completion_empty_conversation_history(mock_edge):
    """An empty conversation-history string must be forwarded as ``None``."""
    retriever = GraphCompletionCotRetriever()
    resolve_target = "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text"
    generate_target = "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion"
    with (
        patch(resolve_target, return_value="Resolved context"),
        patch(generate_target, return_value="Generated answer") as generate_mock,
    ):
        completion, _context_text, _triplets = await retriever._run_cot_completion(
            query="test query",
            context=[mock_edge],
            conversation_history="",
            max_iter=1,
        )
    assert completion == "Generated answer"
    # The retriever should normalize "" to None before calling generate_completion.
    passed_history = generate_mock.call_args[1].get("conversation_history")
    assert passed_history is None
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_edge):
    """Test get_completion retrieves context when not provided.

    When no ``context`` argument is passed, the CoT retriever must fetch
    triplets itself and still return a single-element answer list.
    """
    # Non-empty graph so the retrieval path is taken.
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionCotRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[mock_edge],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
            return_value="Rendered prompt",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
            return_value="System prompt",
        ),
        # One CoT iteration consumes a validation result then a follow-up
        # question from the structured-output LLM call.
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            side_effect=["validation_result", "followup_question"],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Caching disabled: no session history should be read or written.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion("test query", max_iter=1)
    assert isinstance(completion, list)
    assert len(completion) == 1
    assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_provided_context(mock_edge):
    """get_completion should answer directly from caller-supplied triplets."""
    retriever = GraphCompletionCotRetriever()
    # Disable session caching for this run.
    cache_settings = MagicMock()
    cache_settings.caching = False
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig",
            return_value=cache_settings,
        ),
    ):
        completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1)
    # Exactly one string answer is produced from the provided context.
    assert isinstance(completion, list)
    assert completion == ["Generated answer"]
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_edge):
    """Test get_completion with session caching enabled.

    With caching on and a logged-in session user, the retriever should read
    prior conversation history and persist the new exchange exactly once.
    """
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionCotRetriever()
    # Session user resolved from the context variable below.
    mock_user = MagicMock()
    mock_user.id = "test-user-id"
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[mock_edge],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.get_conversation_history",
            return_value="Previous conversation",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.summarize_text",
            return_value="Context summary",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.save_conversation_history",
        ) as mock_save,
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
        ) as mock_cache_config,
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.session_user"
        ) as mock_session_user,
    ):
        # Enable caching and expose the mock user via session_user.get().
        mock_config = MagicMock()
        mock_config.caching = True
        mock_cache_config.return_value = mock_config
        mock_session_user.get.return_value = mock_user
        completion = await retriever.get_completion(
            "test query", session_id="test_session", max_iter=1
        )
    assert isinstance(completion, list)
    assert len(completion) == 1
    assert completion[0] == "Generated answer"
    # The exchange must be written back to the session history exactly once.
    mock_save.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction(mock_edge):
    """Test get_completion with save_interaction enabled.

    When ``save_interaction=True`` and a context is supplied, the retriever
    should persist the question/answer interaction via ``add_data_points``.
    """
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    mock_graph_engine.add_edges = AsyncMock()
    retriever = GraphCompletionCotRetriever(save_interaction=True)
    # The saved interaction references both endpoints of the context edge.
    mock_node1 = MagicMock()
    mock_node2 = MagicMock()
    mock_edge.node1 = mock_node1
    mock_edge.node2 = mock_node2
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
            return_value="Rendered prompt",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
            return_value="System prompt",
        ),
        # One CoT iteration: validation result, then follow-up question.
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            side_effect=["validation_result", "followup_question"],
        ),
        # Stable UUIDs for the two nodes referenced by the saved interaction.
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
            side_effect=[
                UUID("550e8400-e29b-41d4-a716-446655440000"),
                UUID("550e8400-e29b-41d4-a716-446655440001"),
            ],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.add_data_points",
        ) as mock_add_data,
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        # Pass context so save_interaction condition is met
        completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1)
    assert isinstance(completion, list)
    assert len(completion) == 1
    # Persisting the interaction is the behavior under test.
    mock_add_data.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_edge):
    """Test get_completion with custom response model.

    Passing ``response_model`` should make the retriever return the typed
    model instance instead of a plain string answer.
    """
    from pydantic import BaseModel

    # Minimal structured-response schema for this test.
    class TestModel(BaseModel):
        answer: str

    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionCotRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[mock_edge],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        # The LLM layer yields a typed model rather than a string.
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value=TestModel(answer="Test answer"),
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion(
            "test query", response_model=TestModel, max_iter=1
        )
    assert isinstance(completion, list)
    assert len(completion) == 1
    # The typed instance must survive to the caller unmodified.
    assert isinstance(completion[0], TestModel)
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_edge):
    """Test get_completion with session config but no user ID.

    Even with caching enabled, a missing session user must not break the
    completion flow; a normal one-element answer is still returned.
    """
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionCotRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[mock_edge],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
        ) as mock_cache_config,
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.session_user"
        ) as mock_session_user,
    ):
        # Caching on, but the session context holds no authenticated user.
        mock_config = MagicMock()
        mock_config.caching = True
        mock_cache_config.return_value = mock_config
        mock_session_user.get.return_value = None  # No user
        completion = await retriever.get_completion("test query", max_iter=1)
    assert isinstance(completion, list)
    assert len(completion) == 1
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction_no_context(mock_edge):
    """Test get_completion with save_interaction but no context provided.

    With ``context=None`` the interaction-saving branch is not taken
    (no add_data_points patch is needed), but the completion flow must
    still succeed and return one answer.
    """
    retriever = GraphCompletionCotRetriever(save_interaction=True)
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
        # Context comes from the retriever itself since none is provided.
        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
            return_value="Rendered prompt",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
            return_value="System prompt",
        ),
        # One CoT iteration: validation result, then follow-up question.
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            side_effect=["validation_result", "followup_question"],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion("test query", context=None, max_iter=1)
    assert isinstance(completion, list)
    assert len(completion) == 1
@pytest.mark.asyncio
async def test_as_answer_text_with_typeerror():
    """_as_answer_text falls back to str() for values json.dumps rejects."""
    unserializable_value = {1, 2, 3}
    rendered = _as_answer_text(unserializable_value)
    assert isinstance(rendered, str)
    assert rendered == str(unserializable_value)
@pytest.mark.asyncio
async def test_as_answer_text_with_string():
    """A plain string passes through _as_answer_text unchanged."""
    assert _as_answer_text("test string") == "test string"
@pytest.mark.asyncio
async def test_as_answer_text_with_dict():
    """A dict is rendered to a string that mentions its keys and values."""
    payload = {"key": "value", "number": 42}
    rendered = _as_answer_text(payload)
    assert isinstance(rendered, str)
    for fragment in ("key", "value"):
        assert fragment in rendered
@pytest.mark.asyncio
async def test_as_answer_text_with_basemodel():
    """A pydantic model is rendered as a tagged structured-response string."""
    from pydantic import BaseModel

    class TestModel(BaseModel):
        answer: str

    rendered = _as_answer_text(TestModel(answer="test answer"))
    assert isinstance(rendered, str)
    assert "[Structured Response]" in rendered
    assert "test answer" in rendered

View file

@ -1,223 +1,648 @@
import os
import pytest
import pathlib
from typing import Optional, Union
from unittest.mock import AsyncMock, patch, MagicMock
from uuid import UUID
import cognee
from cognee.low_level import setup, DataPoint
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
class TestGraphCompletionRetriever:
@pytest.mark.asyncio
async def test_graph_completion_context_simple(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_simple"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_simple"
)
cognee.config.data_root_directory(data_directory_path)
@pytest.fixture
def mock_edge():
    """Provide a MagicMock standing in for a graph Edge."""
    return MagicMock(spec=Edge)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
class Company(DataPoint):
name: str
description: str
@pytest.mark.asyncio
async def test_get_triplets_success(mock_edge):
"""Test successful retrieval of triplets."""
retriever = GraphCompletionRetriever(top_k=5)
class Person(DataPoint):
name: str
description: str
works_for: Company
with patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
) as mock_search:
triplets = await retriever.get_triplets("test query")
company1 = Company(name="Figma", description="Figma is a company")
company2 = Company(name="Canva", description="Canvas is a company")
person1 = Person(
name="Steve Rodger",
description="This is description about Steve Rodger",
works_for=company1,
)
person2 = Person(
name="Ike Loma", description="This is description about Ike Loma", works_for=company1
)
person3 = Person(
name="Jason Statham",
description="This is description about Jason Statham",
works_for=company1,
)
person4 = Person(
name="Mike Broski",
description="This is description about Mike Broski",
works_for=company2,
)
person5 = Person(
name="Christina Mayer",
description="This is description about Christina Mayer",
works_for=company2,
assert len(triplets) == 1
assert triplets[0] == mock_edge
mock_search.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_triplets_empty_results():
    """get_triplets yields an empty list when the search finds nothing."""
    retriever = GraphCompletionRetriever()
    search_patch = patch(
        "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
        return_value=[],
    )
    with search_patch:
        result = await retriever.get_triplets("test query")
    assert result == []
@pytest.mark.asyncio
async def test_get_triplets_top_k_parameter():
    """The retriever's top_k is forwarded to brute_force_triplet_search."""
    retriever = GraphCompletionRetriever(top_k=10)
    with patch(
        "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
        return_value=[],
    ) as search_mock:
        await retriever.get_triplets("test query")
    assert search_mock.call_args.kwargs["top_k"] == 10
@pytest.mark.asyncio
async def test_get_context_success(mock_edge):
    """On a non-empty graph, get_context returns the triplets found."""
    retriever = GraphCompletionRetriever()
    graph_engine_stub = AsyncMock()
    graph_engine_stub.is_empty = AsyncMock(return_value=False)

    engine_patch = patch(
        "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
        return_value=graph_engine_stub,
    )
    search_patch = patch(
        "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
        return_value=[mock_edge],
    )
    with engine_patch, search_patch:
        result = await retriever.get_context("test query")

    assert isinstance(result, list)
    assert result == [mock_edge]
@pytest.mark.asyncio
async def test_get_context_empty_results():
    """get_context returns an empty list when no triplets match the query."""
    retriever = GraphCompletionRetriever()
    graph_engine_stub = AsyncMock()
    graph_engine_stub.is_empty = AsyncMock(return_value=False)

    engine_patch = patch(
        "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
        return_value=graph_engine_stub,
    )
    search_patch = patch(
        "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
        return_value=[],
    )
    with engine_patch, search_patch:
        result = await retriever.get_context("test query")

    assert result == []
@pytest.mark.asyncio
async def test_get_context_empty_graph():
    """get_context short-circuits to an empty list when the graph is empty."""
    retriever = GraphCompletionRetriever()
    graph_engine_stub = AsyncMock()
    graph_engine_stub.is_empty = AsyncMock(return_value=True)

    with patch(
        "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
        return_value=graph_engine_stub,
    ):
        result = await retriever.get_context("test query")

    assert result == []
@pytest.mark.asyncio
async def test_resolve_edges_to_text(mock_edge):
    """The retriever method delegates to the module-level resolver."""
    retriever = GraphCompletionRetriever()
    with patch(
        "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
        return_value="Resolved text",
    ) as resolver_mock:
        rendered = await retriever.resolve_edges_to_text([mock_edge])
    assert rendered == "Resolved text"
    resolver_mock.assert_awaited_once_with([mock_edge])
@pytest.mark.asyncio
async def test_init_defaults():
    """A retriever built without arguments exposes the documented defaults."""
    retriever = GraphCompletionRetriever()
    assert retriever.node_type is None
    assert retriever.node_name is None
    assert retriever.top_k == 5
    assert (retriever.user_prompt_path, retriever.system_prompt_path) == (
        "graph_context_for_question.txt",
        "answer_simple_question.txt",
    )
@pytest.mark.asyncio
async def test_init_custom_params():
    """Every constructor argument lands unchanged on the instance."""
    retriever = GraphCompletionRetriever(
        top_k=10,
        user_prompt_path="custom_user.txt",
        system_prompt_path="custom_system.txt",
        system_prompt="Custom prompt",
        node_type=str,
        node_name=["node1"],
        save_interaction=True,
        wide_search_top_k=200,
        triplet_distance_penalty=5.0,
    )
    assert retriever.node_type is str
    assert retriever.save_interaction is True
    assert retriever.node_name == ["node1"]
    assert (retriever.top_k, retriever.wide_search_top_k) == (10, 200)
    assert retriever.triplet_distance_penalty == 5.0
    assert (
        retriever.user_prompt_path,
        retriever.system_prompt_path,
        retriever.system_prompt,
    ) == ("custom_user.txt", "custom_system.txt", "Custom prompt")
@pytest.mark.asyncio
async def test_init_none_top_k():
    """Passing top_k=None falls back to the default of 5."""
    assert GraphCompletionRetriever(top_k=None).top_k == 5
@pytest.mark.asyncio
async def test_convert_retrieved_objects_to_context(mock_edge):
    """convert_retrieved_objects_to_context delegates to the edge resolver."""
    retriever = GraphCompletionRetriever()
    with patch(
        "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
        return_value="Resolved text",
    ) as resolver_mock:
        rendered = await retriever.convert_retrieved_objects_to_context([mock_edge])
    assert rendered == "Resolved text"
    resolver_mock.assert_awaited_once_with([mock_edge])
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_edge):
    """Test get_completion retrieves context when not provided.

    Graph engine, triplet search, context rendering, and answer generation are
    all stubbed; the call without ``context=`` must run the retrieval path and
    return a single-element list with the stubbed answer.
    """
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)  # graph has data
    retriever = GraphCompletionRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        # Search yields one edge, which gets rendered into the context string.
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[mock_edge],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        mock_config = MagicMock()
        mock_config.caching = False  # keep the cache/session path disabled
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion("test query")
        assert isinstance(completion, list)
        assert len(completion) == 1
        assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_provided_context(mock_edge):
    """Test get_completion uses provided context.

    With ``context=[mock_edge]`` no graph engine or triplet search patches are
    needed — the provided edges are rendered and answered directly.
    """
    retriever = GraphCompletionRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        mock_config = MagicMock()
        mock_config.caching = False  # cache/session path stays disabled
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion("test query", context=[mock_edge])
        assert isinstance(completion, list)
        assert len(completion) == 1
        assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_edge):
    """Test get_completion with session caching enabled.

    CacheConfig is stubbed with ``caching = True`` and a session user is
    present, so conversation history is read, the context is summarized, and
    the history must be saved exactly once after the answer is generated.
    """
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionRetriever()
    mock_user = MagicMock()
    mock_user.id = "test-user-id"  # identifies the session owner
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[mock_edge],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        # Session plumbing: prior history, context summary for caching.
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_conversation_history",
            return_value="Previous conversation",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.summarize_text",
            return_value="Context summary",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.save_conversation_history",
        ) as mock_save,
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
        ) as mock_cache_config,
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.session_user"
        ) as mock_session_user,
    ):
        mock_config = MagicMock()
        mock_config.caching = True  # enable the session/caching branch
        mock_cache_config.return_value = mock_config
        mock_session_user.get.return_value = mock_user
        completion = await retriever.get_completion("test query", session_id="test_session")
        assert isinstance(completion, list)
        assert len(completion) == 1
        assert completion[0] == "Generated answer"
        mock_save.assert_awaited_once()  # history persisted exactly once
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_edge):
    """Test get_completion with custom response model.

    generate_completion is stubbed to return a pydantic instance; the result
    list must carry that instance through unchanged.
    """
    from pydantic import BaseModel

    # Minimal structured-response schema for this test only.
    class TestModel(BaseModel):
        answer: str

    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[mock_edge],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        # The stubbed completion is a model instance, not a string.
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
            return_value=TestModel(answer="Test answer"),
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion("test query", response_model=TestModel)
        assert isinstance(completion, list)
        assert len(completion) == 1
        assert isinstance(completion[0], TestModel)
@pytest.mark.asyncio
async def test_get_completion_empty_context(mock_edge):
    """Test get_completion with empty context.

    Triplet search returns no edges and the rendered context is the empty
    string; the call must still produce a single-element completion list.
    """
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        # No triplets found for the query.
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion("test query")
        assert isinstance(completion, list)
        assert len(completion) == 1
@pytest.mark.asyncio
async def test_save_qa(mock_edge):
"""Test save_qa method."""
mock_graph_engine = AsyncMock()
mock_graph_engine.add_edges = AsyncMock()
retriever = GraphCompletionRetriever()
mock_node1 = MagicMock()
mock_node2 = MagicMock()
mock_edge.node1 = mock_node1
mock_edge.node2 = mock_node2
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
side_effect=["uuid1", "uuid2"],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.add_data_points",
) as mock_add_data,
):
await retriever.save_qa(
question="Test question",
answer="Test answer",
context="Test context",
triplets=[mock_edge],
)
entities = [company1, company2, person1, person2, person3, person4, person5]
mock_add_data.assert_awaited_once()
mock_graph_engine.add_edges.assert_awaited_once()
await add_data_points(entities)
retriever = GraphCompletionRetriever()
@pytest.mark.asyncio
async def test_save_qa_no_triplet_ids(mock_edge):
"""Test save_qa when triplets have no extractable IDs."""
mock_graph_engine = AsyncMock()
mock_graph_engine.add_edges = AsyncMock()
context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
retriever = GraphCompletionRetriever()
# Ensure the top-level sections are present
assert "Nodes:" in context, "Missing 'Nodes:' section in context"
assert "Connections:" in context, "Missing 'Connections:' section in context"
mock_node1 = MagicMock()
mock_node2 = MagicMock()
mock_edge.node1 = mock_node1
mock_edge.node2 = mock_node2
# --- Nodes headers ---
assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger"
assert "Node: Figma" in context, "Missing node header for Figma"
assert "Node: Ike Loma" in context, "Missing node header for Ike Loma"
assert "Node: Jason Statham" in context, "Missing node header for Jason Statham"
assert "Node: Mike Broski" in context, "Missing node header for Mike Broski"
assert "Node: Canva" in context, "Missing node header for Canva"
assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer"
# --- Node contents ---
assert (
"__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__"
in context
), "Description block for Steve Rodger altered"
assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, (
"Description block for Figma altered"
)
assert (
"__node_content_start__\nThis is description about Ike Loma\n__node_content_end__"
in context
), "Description block for Ike Loma altered"
assert (
"__node_content_start__\nThis is description about Jason Statham\n__node_content_end__"
in context
), "Description block for Jason Statham altered"
assert (
"__node_content_start__\nThis is description about Mike Broski\n__node_content_end__"
in context
), "Description block for Mike Broski altered"
assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, (
"Description block for Canva altered"
)
assert (
"__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__"
in context
), "Description block for Christina Mayer altered"
# --- Connections ---
assert "Steve Rodger --[works_for]--> Figma" in context, (
"Connection Steve Rodger→Figma missing or changed"
)
assert "Ike Loma --[works_for]--> Figma" in context, (
"Connection Ike Loma→Figma missing or changed"
)
assert "Jason Statham --[works_for]--> Figma" in context, (
"Connection Jason Statham→Figma missing or changed"
)
assert "Mike Broski --[works_for]--> Canva" in context, (
"Connection Mike Broski→Canva missing or changed"
)
assert "Christina Mayer --[works_for]--> Canva" in context, (
"Connection Christina Mayer→Canva missing or changed"
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
return_value=None,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.add_data_points",
) as mock_add_data,
):
await retriever.save_qa(
question="Test question",
answer="Test answer",
context="Test context",
triplets=[mock_edge],
)
@pytest.mark.asyncio
async def test_graph_completion_context_complex(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_complex"
mock_add_data.assert_awaited_once()
mock_graph_engine.add_edges.assert_not_called()
@pytest.mark.asyncio
async def test_save_qa_empty_triplets():
"""Test save_qa with empty triplets list."""
mock_graph_engine = AsyncMock()
mock_graph_engine.add_edges = AsyncMock()
retriever = GraphCompletionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.add_data_points",
) as mock_add_data,
):
await retriever.save_qa(
question="Test question",
answer="Test answer",
context="Test context",
triplets=[],
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_complex"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
mock_add_data.assert_awaited_once()
mock_graph_engine.add_edges.assert_not_called()
class Company(DataPoint):
name: str
metadata: dict = {"index_fields": ["name"]}
class Car(DataPoint):
brand: str
model: str
year: int
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction_no_completion(mock_edge):
"""Test get_completion with save_interaction but no completion."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
class Location(DataPoint):
country: str
city: str
retriever = GraphCompletionRetriever(save_interaction=True)
class Home(DataPoint):
location: Location
rooms: int
sqm: int
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value=None, # No completion
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
class Person(DataPoint):
name: str
works_for: Company
owns: Optional[list[Union[Car, Home]]] = None
completion = await retriever.get_completion("test query")
company1 = Company(name="Figma")
company2 = Company(name="Canva")
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] is None
person1 = Person(name="Mike Rodger", works_for=company1)
person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
person2 = Person(name="Ike Loma", works_for=company1)
person2.owns = [
Car(brand="Tesla", model="Model S", year=2021),
Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
]
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction_no_context(mock_edge):
"""Test get_completion with save_interaction but no context provided."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
person3 = Person(name="Jason Statham", works_for=company1)
retriever = GraphCompletionRetriever(save_interaction=True)
person4 = Person(name="Mike Broski", works_for=company2)
person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
person5 = Person(name="Christina Mayer", works_for=company2)
person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
completion = await retriever.get_completion("test query", context=None)
entities = [company1, company2, person1, person2, person3, person4, person5]
assert isinstance(completion, list)
assert len(completion) == 1
await add_data_points(entities)
retriever = GraphCompletionRetriever(top_k=20)
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction_all_conditions_met(mock_edge):
"""Test get_completion with save_interaction when all conditions are met (line 216)."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
retriever = GraphCompletionRetriever(save_interaction=True)
print(context)
mock_node1 = MagicMock()
mock_node2 = MagicMock()
mock_edge.node1 = mock_node1
mock_edge.node2 = mock_node2
assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
side_effect=[
UUID("550e8400-e29b-41d4-a716-446655440000"),
UUID("550e8400-e29b-41d4-a716-446655440001"),
],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.add_data_points",
) as mock_add_data,
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
@pytest.mark.asyncio
async def test_get_graph_completion_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent,
".cognee_system/test_get_graph_completion_context_on_empty_graph",
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent,
".data_storage/test_get_graph_completion_context_on_empty_graph",
)
cognee.config.data_root_directory(data_directory_path)
completion = await retriever.get_completion("test query", context=[mock_edge])
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
retriever = GraphCompletionRetriever()
await setup()
context = await retriever.get_context("Who works at Figma?")
assert context == [], "Context should be empty on an empty graph"
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
mock_add_data.assert_awaited_once()

View file

@ -1,205 +1,321 @@
import os
from typing import List
import pytest
import pathlib
import cognee
from unittest.mock import AsyncMock, patch, MagicMock
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
class DocumentChunkWithEntities(DataPoint):
    """Test-local DataPoint mirroring a document chunk that can hold entities."""

    text: str  # chunk body; indexed for vector search via metadata below
    chunk_size: int
    chunk_index: int
    cut_type: str
    is_part_of: Document
    # NOTE(review): default is None despite the List annotation — presumably
    # meaning "no entities extracted"; confirm whether Optional[List[Entity]]
    # is intended if type checking is enforced on DataPoint fields.
    contains: List[Entity] = None
    metadata: dict = {"index_fields": ["text"]}  # marks `text` as the embedded field
@pytest.fixture
def mock_vector_engine():
    """Provide an AsyncMock vector engine with an awaitable search method."""
    mock_engine = AsyncMock()
    mock_engine.search = AsyncMock()
    return mock_engine
class TestRAGCompletionRetriever:
@pytest.mark.asyncio
async def test_rag_completion_context_simple(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_simple"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_simple"
)
cognee.config.data_root_directory(data_directory_path)
@pytest.mark.asyncio
async def test_get_context_success(mock_vector_engine):
"""Test successful retrieval of context."""
mock_result1 = MagicMock()
mock_result1.payload = {"text": "Steve Rodger"}
mock_result2 = MagicMock()
mock_result2.payload = {"text": "Mike Broski"}
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
document = TextDocument(
name="Steve Rodger's career",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
retriever = CompletionRetriever(top_k=2)
chunk1 = DocumentChunk(
text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
with patch(
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
entities = [chunk1, chunk2, chunk3]
assert context == "Steve Rodger\nMike Broski"
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2)
await add_data_points(entities)
retriever = CompletionRetriever()
@pytest.mark.asyncio
async def test_get_context_collection_not_found_error(mock_vector_engine):
"""Test that CollectionNotFoundError is converted to NoDataError."""
mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found")
context = await retriever.get_context("Mike")
retriever = CompletionRetriever()
assert context == "Mike Broski", "Failed to get Mike Broski"
with patch(
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
with pytest.raises(NoDataError, match="No data found"):
await retriever.get_context("test query")
@pytest.mark.asyncio
async def test_rag_completion_context_complex(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_complex"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_complex"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
@pytest.mark.asyncio
async def test_get_context_empty_results(mock_vector_engine):
"""Test that empty string is returned when no chunks are found."""
mock_vector_engine.search.return_value = []
document1 = TextDocument(
name="Employee List",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
retriever = CompletionRetriever()
document2 = TextDocument(
name="Car List",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
with patch(
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
chunk1 = DocumentChunk(
text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
assert context == ""
chunk4 = DocumentChunk(
text="Range Rover",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk5 = DocumentChunk(
text="Hyundai",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk6 = DocumentChunk(
text="Chrysler",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]
@pytest.mark.asyncio
async def test_get_context_top_k_limit(mock_vector_engine):
"""Test that top_k parameter limits the number of results."""
mock_results = [MagicMock() for _ in range(2)]
for i, result in enumerate(mock_results):
result.payload = {"text": f"Chunk {i}"}
await add_data_points(entities)
mock_vector_engine.search.return_value = mock_results
# TODO: top_k doesn't affect the output, it should be fixed.
retriever = CompletionRetriever(top_k=20)
retriever = CompletionRetriever(top_k=2)
context = await retriever.get_context("Christina")
with patch(
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer"
assert context == "Chunk 0\nChunk 1"
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2)
@pytest.mark.asyncio
async def test_get_rag_completion_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent,
".cognee_system/test_get_rag_completion_context_on_empty_graph",
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent,
".data_storage/test_get_rag_completion_context_on_empty_graph",
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
@pytest.mark.asyncio
async def test_get_context_single_chunk(mock_vector_engine):
"""Test get_context with single chunk result."""
mock_result = MagicMock()
mock_result.payload = {"text": "Single chunk text"}
mock_vector_engine.search.return_value = [mock_result]
retriever = CompletionRetriever()
retriever = CompletionRetriever()
with pytest.raises(NoDataError):
await retriever.get_context("Christina Mayer")
with patch(
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
vector_engine = get_vector_engine()
await vector_engine.create_collection(
"DocumentChunk_text", payload_schema=DocumentChunkWithEntities
)
assert context == "Single chunk text"
context = await retriever.get_context("Christina Mayer")
assert context == "", "Returned context should be empty on an empty graph"
@pytest.mark.asyncio
async def test_get_completion_without_session(mock_vector_engine):
    """Verify get_completion when session caching is disabled."""
    search_hit = MagicMock()
    search_hit.payload = {"text": "Chunk text"}
    mock_vector_engine.search.return_value = [search_hit]

    retriever = CompletionRetriever()

    with (
        patch(
            "cognee.modules.retrieval.completion_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as cache_config_cls,
    ):
        # Caching off -> no conversation history should be involved.
        cache_config = MagicMock()
        cache_config.caching = False
        cache_config_cls.return_value = cache_config

        completion = await retriever.get_completion("test query")

    assert isinstance(completion, list)
    assert len(completion) == 1
    assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_provided_context(mock_vector_engine):
    """Verify get_completion uses a caller-supplied context without retrieval."""
    retriever = CompletionRetriever()

    with (
        patch(
            "cognee.modules.retrieval.completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as cache_config_cls,
    ):
        cache_config = MagicMock()
        cache_config.caching = False
        cache_config_cls.return_value = cache_config

        completion = await retriever.get_completion("test query", context="Provided context")

    assert isinstance(completion, list)
    assert len(completion) == 1
    assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_vector_engine):
    """Session caching on: history is consulted and the answer is persisted."""
    search_hit = MagicMock()
    search_hit.payload = {"text": "Chunk text"}
    mock_vector_engine.search.return_value = [search_hit]

    fake_user = MagicMock()
    fake_user.id = "test-user-id"

    retriever = CompletionRetriever()

    with (
        patch(
            "cognee.modules.retrieval.completion_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.completion_retriever.get_conversation_history",
            return_value="Previous conversation",
        ),
        patch(
            "cognee.modules.retrieval.completion_retriever.summarize_text",
            return_value="Context summary",
        ),
        patch(
            "cognee.modules.retrieval.completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.completion_retriever.save_conversation_history",
        ) as mock_save,
        patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as cache_config_cls,
        patch("cognee.modules.retrieval.completion_retriever.session_user") as session_user_var,
    ):
        cache_config = MagicMock()
        cache_config.caching = True
        cache_config_cls.return_value = cache_config
        session_user_var.get.return_value = fake_user

        completion = await retriever.get_completion("test query", session_id="test_session")

    assert isinstance(completion, list)
    assert len(completion) == 1
    assert completion[0] == "Generated answer"
    # The answer must be written back into the conversation history.
    mock_save.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_vector_engine):
    """Caching enabled but no session user resolved: completion still succeeds."""
    search_hit = MagicMock()
    search_hit.payload = {"text": "Chunk text"}
    mock_vector_engine.search.return_value = [search_hit]

    retriever = CompletionRetriever()

    with (
        patch(
            "cognee.modules.retrieval.completion_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as cache_config_cls,
        patch("cognee.modules.retrieval.completion_retriever.session_user") as session_user_var,
    ):
        cache_config = MagicMock()
        cache_config.caching = True
        cache_config_cls.return_value = cache_config
        session_user_var.get.return_value = None  # No user

        completion = await retriever.get_completion("test query")

    assert isinstance(completion, list)
    assert len(completion) == 1
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_vector_engine):
    """Test get_completion with a custom pydantic response model.

    Verifies that the structured object returned by generate_completion is
    passed through unchanged — both its type AND its content.
    """
    from pydantic import BaseModel

    class TestModel(BaseModel):
        answer: str

    mock_result = MagicMock()
    mock_result.payload = {"text": "Chunk text"}
    mock_vector_engine.search.return_value = [mock_result]

    retriever = CompletionRetriever()

    with (
        patch(
            "cognee.modules.retrieval.completion_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.completion_retriever.generate_completion",
            return_value=TestModel(answer="Test answer"),
        ),
        patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config,
    ):
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config

        completion = await retriever.get_completion("test query", response_model=TestModel)

    assert isinstance(completion, list)
    assert len(completion) == 1
    assert isinstance(completion[0], TestModel)
    # Previously only the type was asserted; also pin the content so a
    # regression that returns an empty/default model is caught.
    assert completion[0].answer == "Test answer"
@pytest.mark.asyncio
async def test_init_defaults():
    """CompletionRetriever should come up with its documented defaults."""
    retriever = CompletionRetriever()

    assert retriever.top_k == 1
    assert retriever.system_prompt is None
    assert retriever.user_prompt_path == "context_for_question.txt"
    assert retriever.system_prompt_path == "answer_simple_question.txt"
@pytest.mark.asyncio
async def test_init_custom_params():
    """Custom constructor arguments must be stored verbatim."""
    retriever = CompletionRetriever(
        user_prompt_path="custom_user.txt",
        system_prompt_path="custom_system.txt",
        system_prompt="Custom prompt",
        top_k=10,
    )

    assert retriever.top_k == 10
    assert retriever.system_prompt == "Custom prompt"
    assert retriever.user_prompt_path == "custom_user.txt"
    assert retriever.system_prompt_path == "custom_system.txt"
@pytest.mark.asyncio
async def test_get_context_missing_text_key(mock_vector_engine):
    """A search payload without a "text" key should surface as a KeyError."""
    bad_hit = MagicMock()
    bad_hit.payload = {"other_key": "value"}
    mock_vector_engine.search.return_value = [bad_hit]

    retriever = CompletionRetriever()

    with patch(
        "cognee.modules.retrieval.completion_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        with pytest.raises(KeyError):
            await retriever.get_context("test query")

View file

@ -1,204 +0,0 @@
import asyncio
import pytest
import cognee
import pathlib
import os
from pydantic import BaseModel
from cognee.low_level import setup, DataPoint
from cognee.tasks.storage import add_data_points
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor
from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
)
from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
class TestAnswer(BaseModel):
    """Structured answer schema used to exercise response_model handling."""

    answer: str
    explanation: str
def _assert_string_answer(answer: list[str]):
assert isinstance(answer, list), f"Expected str, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), "Items should be strings"
assert all(item.strip() for item in answer), "Items should not be empty"
def _assert_structured_answer(answer: list[TestAnswer]) -> None:
    """Assert that *answer* is a list of fully populated TestAnswer objects."""
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    assert all(isinstance(entry, TestAnswer) for entry in answer), "Items should be TestAnswer"
    assert all(entry.answer.strip() for entry in answer), "Answer text should not be empty"
    assert all(entry.explanation.strip() for entry in answer), "Explanation should not be empty"
async def _test_get_structured_graph_completion_cot():
    """Exercise GraphCompletionCotRetriever with string and structured output."""
    retriever = GraphCompletionCotRetriever()

    # Default (string) response model.
    _assert_string_answer(await retriever.get_completion("Who works at Figma?"))

    # Structured response model.
    _assert_structured_answer(
        await retriever.get_completion("Who works at Figma?", response_model=TestAnswer)
    )
async def _test_get_structured_graph_completion():
    """Exercise GraphCompletionRetriever with string and structured output."""
    retriever = GraphCompletionRetriever()

    # Default (string) response model.
    _assert_string_answer(await retriever.get_completion("Who works at Figma?"))

    # Structured response model.
    _assert_structured_answer(
        await retriever.get_completion("Who works at Figma?", response_model=TestAnswer)
    )
async def _test_get_structured_graph_completion_temporal():
    """Exercise TemporalRetriever with string and structured output."""
    retriever = TemporalRetriever()

    # Default (string) response model.
    _assert_string_answer(await retriever.get_completion("When did Steve start working at Figma?"))

    # Structured response model (query string kept verbatim, including the
    # doubled "?" from the original test).
    _assert_structured_answer(
        await retriever.get_completion(
            "When did Steve start working at Figma??", response_model=TestAnswer
        )
    )
async def _test_get_structured_graph_completion_rag():
    """Exercise CompletionRetriever (plain RAG) with string and structured output."""
    retriever = CompletionRetriever()

    # Default (string) response model.
    _assert_string_answer(await retriever.get_completion("Where does Steve work?"))

    # Structured response model.
    _assert_structured_answer(
        await retriever.get_completion("Where does Steve work?", response_model=TestAnswer)
    )
async def _test_get_structured_graph_completion_context_extension():
    """Exercise GraphCompletionContextExtensionRetriever with both output modes."""
    retriever = GraphCompletionContextExtensionRetriever()

    # Default (string) response model.
    _assert_string_answer(await retriever.get_completion("Who works at Figma?"))

    # Structured response model.
    _assert_structured_answer(
        await retriever.get_completion("Who works at Figma?", response_model=TestAnswer)
    )
async def _test_get_structured_entity_completion():
    """Exercise EntityCompletionRetriever with string and structured output."""
    retriever = EntityCompletionRetriever(DummyEntityExtractor(), DummyContextProvider())

    # Default (string) response model.
    _assert_string_answer(await retriever.get_completion("Who is Albert Einstein?"))

    # Structured response model.
    _assert_structured_answer(
        await retriever.get_completion("Who is Albert Einstein?", response_model=TestAnswer)
    )
class TestStructuredOutputCompletion:
    """Integration test: every retriever must support structured output."""

    @pytest.mark.asyncio
    async def test_get_structured_completion(self):
        """Seed a small knowledge graph, then run every retriever helper.

        Each `_test_get_structured_*` helper asks for both a plain-string
        completion and a TestAnswer-structured completion against the
        seeded data.
        """
        # Point cognee at per-test system/data directories so state does
        # not leak between tests.
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
        )
        cognee.config.data_root_directory(data_directory_path)

        # Reset all stored data and metadata, then re-initialize storage.
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        # Minimal ad-hoc schema: a person employed by a company.
        class Company(DataPoint):
            name: str

        class Person(DataPoint):
            name: str
            works_for: Company
            works_since: int

        company1 = Company(name="Figma")
        person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015)

        entities = [company1, person1]

        await add_data_points(entities)

        # Document chunks feed the RAG-style retrievers.
        document = TextDocument(
            name="Steve Rodger's career",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        chunk1 = DocumentChunk(
            text="Steve Rodger",
            chunk_size=2,
            chunk_index=0,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )
        chunk2 = DocumentChunk(
            text="Mike Broski",
            chunk_size=2,
            chunk_index=1,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )
        chunk3 = DocumentChunk(
            text="Christina Mayer",
            chunk_size=2,
            chunk_index=2,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )

        entities = [chunk1, chunk2, chunk3]

        await add_data_points(entities)

        # A named entity for the EntityCompletionRetriever path.
        entity_type = EntityType(name="Person", description="A human individual")
        entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")

        entities = [entity]

        await add_data_points(entities)

        # Run every retriever variant against the seeded graph.
        await _test_get_structured_graph_completion_cot()
        await _test_get_structured_graph_completion()
        await _test_get_structured_graph_completion_temporal()
        await _test_get_structured_graph_completion_rag()
        await _test_get_structured_graph_completion_context_extension()
        await _test_get_structured_entity_completion()

View file

@ -1,159 +1,193 @@
import os
import pytest
import pathlib
from unittest.mock import AsyncMock, patch, MagicMock
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.tasks.summarization.models import TextSummary
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
class TestSummariesRetriever:
@pytest.mark.asyncio
async def test_chunk_context(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context"
)
cognee.config.data_root_directory(data_directory_path)
@pytest.fixture
def mock_vector_engine():
    """Vector engine stub whose async search can be programmed per test."""
    fake_engine = AsyncMock()
    fake_engine.search = AsyncMock()
    return fake_engine
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
document1 = TextDocument(
name="Employee List",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
@pytest.mark.asyncio
async def test_get_context_success(mock_vector_engine):
"""Test successful retrieval of summary context."""
mock_result1 = MagicMock()
mock_result1.payload = {"text": "S.R.", "made_from": "chunk1"}
mock_result2 = MagicMock()
mock_result2.payload = {"text": "M.B.", "made_from": "chunk2"}
document2 = TextDocument(
name="Car List",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
chunk1 = DocumentChunk(
text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk1_summary = TextSummary(
text="S.R.",
made_from=chunk1,
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk2_summary = TextSummary(
text="M.B.",
made_from=chunk2,
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk3_summary = TextSummary(
text="C.M.",
made_from=chunk3,
)
chunk4 = DocumentChunk(
text="Range Rover",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk4_summary = TextSummary(
text="R.R.",
made_from=chunk4,
)
chunk5 = DocumentChunk(
text="Hyundai",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk5_summary = TextSummary(
text="H.Y.",
made_from=chunk5,
)
chunk6 = DocumentChunk(
text="Chrysler",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk6_summary = TextSummary(
text="C.H.",
made_from=chunk6,
)
retriever = SummariesRetriever(top_k=5)
entities = [
chunk1_summary,
chunk2_summary,
chunk3_summary,
chunk4_summary,
chunk5_summary,
chunk6_summary,
]
with patch(
"cognee.modules.retrieval.summaries_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
await add_data_points(entities)
assert len(context) == 2
assert context[0]["text"] == "S.R."
assert context[1]["text"] == "M.B."
mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=5)
retriever = SummariesRetriever(top_k=20)
context = await retriever.get_context("Christina")
@pytest.mark.asyncio
async def test_get_context_collection_not_found_error(mock_vector_engine):
"""Test that CollectionNotFoundError is converted to NoDataError."""
mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found")
assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer"
retriever = SummariesRetriever()
@pytest.mark.asyncio
async def test_chunk_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph"
)
cognee.config.data_root_directory(data_directory_path)
with patch(
"cognee.modules.retrieval.summaries_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
with pytest.raises(NoDataError, match="No data found"):
await retriever.get_context("test query")
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
retriever = SummariesRetriever()
@pytest.mark.asyncio
async def test_get_context_empty_results(mock_vector_engine):
"""Test that empty list is returned when no summaries are found."""
mock_vector_engine.search.return_value = []
with pytest.raises(NoDataError):
await retriever.get_context("Christina Mayer")
retriever = SummariesRetriever()
vector_engine = get_vector_engine()
await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary)
with patch(
"cognee.modules.retrieval.summaries_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
context = await retriever.get_context("Christina Mayer")
assert context == [], "Returned context should be empty on an empty graph"
assert context == []
@pytest.mark.asyncio
async def test_get_context_top_k_limit(mock_vector_engine):
    """top_k should be forwarded to the vector search as its limit."""
    hits = []
    for index in range(3):
        hit = MagicMock()
        hit.payload = {"text": f"Summary {index}"}
        hits.append(hit)
    mock_vector_engine.search.return_value = hits

    retriever = SummariesRetriever(top_k=3)

    with patch(
        "cognee.modules.retrieval.summaries_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        context = await retriever.get_context("test query")

    assert len(context) == 3
    mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=3)
@pytest.mark.asyncio
async def test_get_completion_with_context(mock_vector_engine):
    """A provided context should be returned unchanged by get_completion."""
    provided_context = [{"text": "S.R."}, {"text": "M.B."}]

    retriever = SummariesRetriever()
    completion = await retriever.get_completion("test query", context=provided_context)

    assert completion == provided_context
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_vector_engine):
    """get_completion should fall back to vector search when no context is given."""
    summary_hit = MagicMock()
    summary_hit.payload = {"text": "S.R."}
    mock_vector_engine.search.return_value = [summary_hit]

    retriever = SummariesRetriever()

    with patch(
        "cognee.modules.retrieval.summaries_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        completion = await retriever.get_completion("test query")

    assert len(completion) == 1
    assert completion[0]["text"] == "S.R."
@pytest.mark.asyncio
async def test_init_defaults():
    """SummariesRetriever defaults to a top_k of 5."""
    retriever = SummariesRetriever()

    assert retriever.top_k == 5
@pytest.mark.asyncio
async def test_init_custom_top_k():
    """A custom top_k passed to the constructor must be stored verbatim."""
    retriever = SummariesRetriever(top_k=10)

    assert retriever.top_k == 10
@pytest.mark.asyncio
async def test_get_context_empty_payload(mock_vector_engine):
    """Results with an empty payload should be passed through as-is."""
    empty_hit = MagicMock()
    empty_hit.payload = {}
    mock_vector_engine.search.return_value = [empty_hit]

    retriever = SummariesRetriever()

    with patch(
        "cognee.modules.retrieval.summaries_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        context = await retriever.get_context("test query")

    assert len(context) == 1
    assert context[0] == {}
@pytest.mark.asyncio
async def test_get_completion_with_session_id(mock_vector_engine):
    """A session_id argument should not change the retrieved completion."""
    summary_hit = MagicMock()
    summary_hit.payload = {"text": "S.R."}
    mock_vector_engine.search.return_value = [summary_hit]

    retriever = SummariesRetriever()

    with patch(
        "cognee.modules.retrieval.summaries_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        completion = await retriever.get_completion("test query", session_id="test_session")

    assert len(completion) == 1
    assert completion[0]["text"] == "S.R."
@pytest.mark.asyncio
async def test_get_completion_with_kwargs(mock_vector_engine):
    """Unknown keyword arguments should be accepted without error."""
    summary_hit = MagicMock()
    summary_hit.payload = {"text": "S.R."}
    mock_vector_engine.search.return_value = [summary_hit]

    retriever = SummariesRetriever()

    with patch(
        "cognee.modules.retrieval.summaries_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        completion = await retriever.get_completion("test query", extra_param="value")

    assert len(completion) == 1

View file

@ -1,7 +1,12 @@
from types import SimpleNamespace
import pytest
import os
from unittest.mock import AsyncMock, patch, MagicMock
from datetime import datetime
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
from cognee.tasks.temporal_graph.models import QueryInterval, Timestamp
from cognee.infrastructure.llm import LLMGateway
# Test TemporalRetriever initialization defaults and overrides
@ -140,85 +145,561 @@ async def test_filter_top_k_events_error_handling():
await tr.filter_top_k_events([{}], [])
class _FakeRetriever(TemporalRetriever):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._calls = []
@pytest.fixture
def mock_graph_engine():
    """Graph engine stub with programmable time-id and event collection."""
    fake_engine = AsyncMock()
    fake_engine.collect_time_ids = AsyncMock()
    fake_engine.collect_events = AsyncMock()
    return fake_engine
async def extract_time_from_query(self, query: str):
if "both" in query:
@pytest.fixture
def mock_vector_engine():
    """Vector engine stub with a stubbed embedding engine and async search."""
    fake_engine = AsyncMock()
    fake_engine.embedding_engine = AsyncMock()
    fake_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    fake_engine.search = AsyncMock()
    return fake_engine
@pytest.mark.asyncio
async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine):
    """A (from, to) interval routes retrieval through the event pipeline."""
    retriever = TemporalRetriever(top_k=5)

    mock_graph_engine.collect_time_ids.return_value = ["e1", "e2"]
    mock_graph_engine.collect_events.return_value = [
        {
            "events": [
                {"id": "e1", "description": "Event 1"},
                {"id": "e2", "description": "Event 2"},
            ]
        }
    ]
    closer_hit = SimpleNamespace(payload={"id": "e2"}, score=0.05)
    farther_hit = SimpleNamespace(payload={"id": "e1"}, score=0.10)
    mock_vector_engine.search.return_value = [closer_hit, farther_hit]

    with (
        patch.object(
            retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
    ):
        context = await retriever.get_context("What happened in 2024?")

    assert isinstance(context, str)
    assert len(context) > 0
    assert "Event" in context
@pytest.mark.asyncio
async def test_get_context_fallback_to_triplets_no_time(mock_graph_engine):
    """Without an extractable interval, retrieval falls back to triplets."""
    retriever = TemporalRetriever()

    with (
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch.object(
            retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}]
        ) as mock_get_triplets,
        patch.object(
            retriever, "resolve_edges_to_text", return_value="triplet text"
        ) as mock_resolve,
    ):
        # Simulate a query from which no time interval can be extracted.
        async def stub_extract_time(query):
            return None, None

        retriever.extract_time_from_query = stub_extract_time

        context = await retriever.get_context("test query")

    assert context == "triplet text"
    mock_get_triplets.assert_awaited_once_with("test query")
    mock_resolve.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_context_no_events_found(mock_graph_engine):
"""Test get_context falls back to triplets when no events are found."""
retriever = TemporalRetriever()
mock_graph_engine.collect_time_ids.return_value = []
with (
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(
retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}]
) as mock_get_triplets,
patch.object(
retriever, "resolve_edges_to_text", return_value="triplet text"
) as mock_resolve,
):
async def mock_extract_time(query):
return "2024-01-01", "2024-12-31"
if "from_only" in query:
return "2024-01-01", None
if "to_only" in query:
return None, "2024-12-31"
return None, None
async def get_triplets(self, query: str):
self._calls.append(("get_triplets", query))
return [{"s": "a", "p": "b", "o": "c"}]
retriever.extract_time_from_query = mock_extract_time
async def resolve_edges_to_text(self, triplets):
self._calls.append(("resolve_edges_to_text", len(triplets)))
return "edges->text"
context = await retriever.get_context("test query")
async def _fake_graph_collect_ids(self, **kwargs):
return ["e1", "e2"]
assert context == "triplet text"
mock_get_triplets.assert_awaited_once_with("test query")
mock_resolve.assert_awaited_once()
async def _fake_graph_collect_events(self, ids):
return [
@pytest.mark.asyncio
async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine):
    """An open-ended interval (only a start time) still collects events."""
    retriever = TemporalRetriever(top_k=5)

    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [
        {
            "events": [
                {"id": "e1", "description": "Event 1"},
            ]
        }
    ]
    event_hit = SimpleNamespace(payload={"id": "e1"}, score=0.05)
    mock_vector_engine.search.return_value = [event_hit]

    with (
        patch.object(retriever, "extract_time_from_query", return_value=("2024-01-01", None)),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
    ):
        context = await retriever.get_context("What happened after 2024?")

    assert isinstance(context, str)
    assert "Event 1" in context
@pytest.mark.asyncio
async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine):
    """An open-ended interval (only an end time) still collects events."""
    retriever = TemporalRetriever(top_k=5)

    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [
        {
            "events": [
                {"id": "e1", "description": "Event 1"},
            ]
        }
    ]
    event_hit = SimpleNamespace(payload={"id": "e1"}, score=0.05)
    mock_vector_engine.search.return_value = [event_hit]

    with (
        patch.object(retriever, "extract_time_from_query", return_value=(None, "2024-12-31")),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
    ):
        context = await retriever.get_context("What happened before 2024?")

    assert isinstance(context, str)
    assert "Event 1" in context
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_graph_engine, mock_vector_engine):
    """get_completion builds its own temporal context when none is supplied."""
    retriever = TemporalRetriever()

    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [
        {
            "events": [
                {"id": "e1", "description": "Event 1"},
            ]
        }
    ]
    event_hit = SimpleNamespace(payload={"id": "e1"}, score=0.05)
    mock_vector_engine.search.return_value = [event_hit]

    with (
        patch.object(
            retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as cache_config_cls,
    ):
        cache_config = MagicMock()
        cache_config.caching = False
        cache_config_cls.return_value = cache_config

        completion = await retriever.get_completion("What happened in 2024?")

    assert isinstance(completion, list)
    assert len(completion) == 1
    assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_provided_context():
    """get_completion must use a caller-provided context verbatim."""
    retriever = TemporalRetriever()

    with (
        patch(
            "cognee.modules.retrieval.temporal_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as cache_config_cls,
    ):
        cache_config = MagicMock()
        cache_config.caching = False
        cache_config_cls.return_value = cache_config

        completion = await retriever.get_completion("test query", context="Provided context")

    assert isinstance(completion, list)
    assert len(completion) == 1
    assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine):
    """Test get_completion with session caching enabled."""
    retriever = TemporalRetriever()
    # One event in the graph; vector search returns it as the top hit.
    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [
        {
            "events": [
                {"id": "e1", "description": "Event 1"},
            ]
        }
    ]
    mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
    mock_vector_engine.search.return_value = [mock_result]
    # A resolved session user is required for history load/save to run.
    mock_user = MagicMock()
    mock_user.id = "test-user-id"
    with (
        patch.object(
            retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_conversation_history",
            return_value="Previous conversation",
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.summarize_text",
            return_value="Context summary",
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.save_conversation_history",
        ) as mock_save,
        patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
        patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user,
    ):
        # Enable caching so the session branch (history + save) is exercised.
        mock_config = MagicMock()
        mock_config.caching = True
        mock_cache_config.return_value = mock_config
        mock_session_user.get.return_value = mock_user
        completion = await retriever.get_completion(
            "What happened in 2024?", session_id="test_session"
        )
        assert isinstance(completion, list)
        assert len(completion) == 1
        assert completion[0] == "Generated answer"
        # Session caching active -> the interaction must be persisted once.
        mock_save.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_vector_engine):
    """Test get_completion with session config but no user ID."""
    retriever = TemporalRetriever()
    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [
        {
            "events": [
                {"id": "e1", "description": "Event 1"},
            ]
        }
    ]
    mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
    mock_vector_engine.search.return_value = [mock_result]
    with (
        patch.object(
            retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
        patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user,
    ):
        # Caching is on, but there is no session user: completion must still
        # succeed, just without touching conversation history.
        mock_config = MagicMock()
        mock_config.caching = True
        mock_cache_config.return_value = mock_config
        mock_session_user.get.return_value = None  # No user
        completion = await retriever.get_completion("What happened in 2024?")
        assert isinstance(completion, list)
        assert len(completion) == 1
@pytest.mark.asyncio
async def test_get_completion_context_retrieved_but_empty(mock_graph_engine):
"""Test get_completion when get_context returns empty string."""
retriever = TemporalRetriever()
with (
patch.object(
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
) as mock_get_vector,
patch.object(retriever, "filter_top_k_events", return_value=[]),
):
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
mock_get_vector.return_value = mock_vector_engine
mock_graph_engine.collect_time_ids.return_value = ["e1"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": "E1"},
{"id": "e2", "description": "E2"},
{"id": "e3", "description": "E3"},
{"id": "e1", "description": ""},
]
}
]
async def _fake_vector_embed(self, texts):
assert isinstance(texts, list) and texts
return [[0.0, 1.0, 2.0]]
with pytest.raises((UnboundLocalError, NameError)):
await retriever.get_completion("test query")
async def _fake_vector_search(self, **kwargs):
return [
SimpleNamespace(payload={"id": "e2"}, score=0.05),
SimpleNamespace(payload={"id": "e1"}, score=0.10),
]
async def get_context(self, query: str):
time_from, time_to = await self.extract_time_from_query(query)
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_graph_engine, mock_vector_engine):
"""Test get_completion with custom response model."""
from pydantic import BaseModel
if not (time_from or time_to):
triplets = await self.get_triplets(query)
return await self.resolve_edges_to_text(triplets)
class TestModel(BaseModel):
answer: str
ids = await self._fake_graph_collect_ids(time_from=time_from, time_to=time_to)
relevant_events = await self._fake_graph_collect_events(ids)
retriever = TemporalRetriever()
_ = await self._fake_vector_embed([query])
vector_search_results = await self._fake_vector_search(
collection_name="Event_name", query_vector=[0.0], limit=0
mock_graph_engine.collect_time_ids.return_value = ["e1"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": "Event 1"},
]
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
patch.object(
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.generate_completion",
return_value=TestModel(answer="Test answer"),
),
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
"What happened in 2024?", response_model=TestModel
)
top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)
return self.descriptions_to_string(top_k_events)
assert isinstance(completion, list)
assert len(completion) == 1
assert isinstance(completion[0], TestModel)
# Test get_context fallback to triplets when no time is extracted
@pytest.mark.asyncio
async def test_fake_get_context_falls_back_to_triplets_when_no_time():
tr = _FakeRetriever(top_k=2)
ctx = await tr.get_context("no_time")
assert ctx == "edges->text"
assert tr._calls[0][0] == "get_triplets"
assert tr._calls[1][0] == "resolve_edges_to_text"
# Fix: this async test was missing the @pytest.mark.asyncio marker, so
# pytest-asyncio would not run it (it would be skipped or warned about).
@pytest.mark.asyncio
async def test_extract_time_from_query_relative_path():
    """Test extract_time_from_query with relative prompt path."""
    retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")
    mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
    mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
    mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)
    with (
        # isabs -> False forces the relative-path branch of prompt resolution.
        patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
        patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
        patch(
            "cognee.modules.retrieval.temporal_retriever.render_prompt",
            return_value="System prompt",
        ),
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            return_value=mock_interval,
        ),
    ):
        mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
        time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")
        assert time_from == mock_timestamp_from
        assert time_to == mock_timestamp_to
# Test get_context when time is extracted and vector ranking is applied
@pytest.mark.asyncio
async def test_fake_get_context_with_time_filters_and_vector_ranking():
tr = _FakeRetriever(top_k=2)
ctx = await tr.get_context("both time")
assert ctx.startswith("E2")
assert "#####################" in ctx
assert "E1" in ctx and "E3" not in ctx
# Fix: this async test was missing the @pytest.mark.asyncio marker, so
# pytest-asyncio would not run it (it would be skipped or warned about).
@pytest.mark.asyncio
async def test_extract_time_from_query_absolute_path():
    """Test extract_time_from_query with absolute prompt path."""
    retriever = TemporalRetriever(
        time_extraction_prompt_path="/absolute/path/to/extract_query_time.txt"
    )
    mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
    mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
    mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)
    with (
        # isabs -> True forces the absolute-path branch: the prompt directory
        # and file name are split via dirname/basename.
        patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=True),
        patch(
            "cognee.modules.retrieval.temporal_retriever.os.path.dirname",
            return_value="/absolute/path/to",
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.os.path.basename",
            return_value="extract_query_time.txt",
        ),
        patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
        patch(
            "cognee.modules.retrieval.temporal_retriever.render_prompt",
            return_value="System prompt",
        ),
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            return_value=mock_interval,
        ),
    ):
        mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
        time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")
        assert time_from == mock_timestamp_from
        assert time_to == mock_timestamp_to
@pytest.mark.asyncio
async def test_extract_time_from_query_with_none_values():
    """Test extract_time_from_query when interval has None values."""
    retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")
    # LLM extracts no time bounds: both endpoints of the interval are None.
    mock_interval = QueryInterval(starts_at=None, ends_at=None)
    with (
        patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
        patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
        patch(
            "cognee.modules.retrieval.temporal_retriever.render_prompt",
            return_value="System prompt",
        ),
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            return_value=mock_interval,
        ),
    ):
        mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
        time_from, time_to = await retriever.extract_time_from_query("What happened?")
        # None endpoints must pass through unchanged (no defaulting).
        assert time_from is None
        assert time_to is None

View file

@ -1,12 +1,14 @@
import pytest
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, patch, MagicMock
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
brute_force_triplet_search,
get_memory_fragment,
format_triplets,
)
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
class MockScoredResult:
@ -354,20 +356,30 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation
@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
"""Test that get_memory_fragment returns empty graph when entity not found."""
"""Test that get_memory_fragment returns empty graph when entity not found (line 85)."""
mock_graph_engine = AsyncMock()
mock_graph_engine.project_graph_from_db = AsyncMock(
# Create a mock fragment that will raise EntityNotFoundError when project_graph_from_db is called
mock_fragment = MagicMock(spec=CogneeGraph)
mock_fragment.project_graph_from_db = AsyncMock(
side_effect=EntityNotFoundError("Entity not found")
)
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
return_value=mock_graph_engine,
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.CogneeGraph",
return_value=mock_fragment,
),
):
fragment = await get_memory_fragment()
result = await get_memory_fragment()
assert isinstance(fragment, CogneeGraph)
assert len(fragment.nodes) == 0
# Fragment should be returned even though EntityNotFoundError was raised (pass statement on line 85)
assert result == mock_fragment
mock_fragment.project_graph_from_db.assert_awaited_once()
@pytest.mark.asyncio
@ -606,3 +618,200 @@ async def test_brute_force_triplet_search_mixed_empty_collections():
call_kwargs = mock_get_fragment_fn.call_args[1]
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
def test_format_triplets():
    """Test format_triplets function."""
    source, target, edge = MagicMock(), MagicMock(), MagicMock()
    source.attributes = {"name": "Node1", "type": "Entity", "id": "n1"}
    target.attributes = {"name": "Node2", "type": "Entity", "id": "n2"}
    edge.attributes = {"relationship_name": "relates_to", "edge_text": "connects"}
    edge.node1 = source
    edge.node2 = target

    result = format_triplets([edge])

    # Both node names, the relation, and the edge text must survive formatting.
    assert isinstance(result, str)
    for expected in ("Node1", "Node2", "relates_to", "connects"):
        assert expected in result
def test_format_triplets_with_none_values():
    """Test format_triplets filters out None values."""
    mock_edge = MagicMock()
    mock_node1 = MagicMock()
    mock_node2 = MagicMock()
    mock_node1.attributes = {"name": "Node1", "type": None, "id": "n1"}
    mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": None}
    mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": None}
    mock_edge.node1 = mock_node1
    mock_edge.node2 = mock_node2

    result = format_triplets([mock_edge])

    assert "Node1" in result
    assert "Node2" in result
    assert "relates_to" in result
    # Fix: the original `assert "None" not in result or result.count("None") == 0`
    # was a tautology (both operands are the same condition); assert it directly.
    assert "None" not in result
def test_format_triplets_with_nested_dict():
    """Test format_triplets handles nested dict attributes (lines 23-35)."""
    source, target, edge = MagicMock(), MagicMock(), MagicMock()
    # Node metadata is itself a dict, exercising the nested-dict path.
    source.attributes = {"name": "Node1", "metadata": {"type": "Entity", "id": "n1"}}
    target.attributes = {"name": "Node2", "metadata": {"type": "Entity", "id": "n2"}}
    edge.attributes = {"relationship_name": "relates_to"}
    edge.node1 = source
    edge.node2 = target

    result = format_triplets([edge])

    assert isinstance(result, str)
    for expected in ("Node1", "Node2", "relates_to"):
        assert expected in result
@pytest.mark.asyncio
async def test_brute_force_triplet_search_vector_engine_init_error():
    """Test brute_force_triplet_search handles vector engine initialization error (lines 145-147)."""
    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine"
    ) as mock_get_vector_engine:
        # Engine construction blows up; the search must surface it as RuntimeError.
        mock_get_vector_engine.side_effect = Exception("Initialization error")

        with pytest.raises(RuntimeError, match="Initialization error"):
            await brute_force_triplet_search(query="test query")
@pytest.mark.asyncio
async def test_brute_force_triplet_search_collection_not_found_error():
    """Test brute_force_triplet_search handles CollectionNotFoundError in search (lines 156-157)."""
    mock_vector_engine = AsyncMock()
    mock_embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine = mock_embedding_engine
    mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    # First searched collection is missing; subsequent searches return no hits.
    mock_vector_engine.search = AsyncMock(
        side_effect=[
            CollectionNotFoundError("Collection not found"),
            [],
            [],
        ]
    )
    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=CogneeGraph(),
        ),
    ):
        result = await brute_force_triplet_search(
            query="test query", collections=["missing_collection", "existing_collection"]
        )
        # A missing collection is tolerated; with no hits anywhere, result is empty.
        assert result == []
@pytest.mark.asyncio
async def test_brute_force_triplet_search_generic_exception():
    """Test brute_force_triplet_search handles generic exceptions (lines 209-217)."""
    vector_engine = AsyncMock()
    vector_engine.embedding_engine = AsyncMock()
    vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    # Any non-CollectionNotFoundError from search must propagate to the caller.
    vector_engine.search = AsyncMock(side_effect=Exception("Generic error"))

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=vector_engine,
    ):
        with pytest.raises(Exception, match="Generic error"):
            await brute_force_triplet_search(query="test query")
@pytest.mark.asyncio
async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_none():
    """Test brute_force_triplet_search sets relevant_ids_to_filter to None when node_name is provided (line 191)."""
    mock_vector_engine = AsyncMock()
    mock_embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine = mock_embedding_engine
    mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"})
    mock_vector_engine.search = AsyncMock(return_value=[mock_result])
    # Fragment stub: all graph-side operations succeed and yield no triplets.
    mock_fragment = AsyncMock()
    mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock()
    mock_fragment.map_vector_distances_to_graph_edges = AsyncMock()
    mock_fragment.calculate_top_triplet_importances = AsyncMock(return_value=[])
    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment,
    ):
        await brute_force_triplet_search(query="test query", node_name=["Node1"])
        assert mock_get_fragment.called
        # With an explicit node_name, id filtering is disabled (None).
        call_kwargs = mock_get_fragment.call_args.kwargs if mock_get_fragment.call_args else {}
        assert call_kwargs.get("relevant_ids_to_filter") is None
@pytest.mark.asyncio
async def test_brute_force_triplet_search_collection_not_found_at_top_level():
    """Test brute_force_triplet_search handles CollectionNotFoundError at top level (line 210)."""
    mock_vector_engine = AsyncMock()
    mock_embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine = mock_embedding_engine
    mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"})
    mock_vector_engine.search = AsyncMock(return_value=[mock_result])
    mock_fragment = AsyncMock()
    mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock()
    mock_fragment.map_vector_distances_to_graph_edges = AsyncMock()
    # The error fires late, during importance calculation, not during search.
    mock_fragment.calculate_top_triplet_importances = AsyncMock(
        side_effect=CollectionNotFoundError("Collection not found")
    )
    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ),
    ):
        result = await brute_force_triplet_search(query="test query")
        # Top-level CollectionNotFoundError degrades to an empty result.
        assert result == []

View file

@ -0,0 +1,343 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from typing import Type
class TestGenerateCompletion:
    """Unit tests for generate_completion: prompt rendering, system-prompt
    resolution (explicit vs read from file), conversation-history prefixing,
    and response_model pass-through to the LLM gateway."""

    @pytest.mark.asyncio
    async def test_generate_completion_with_system_prompt(self):
        """Test generate_completion with provided system_prompt."""
        mock_llm_response = "Generated answer"
        with (
            patch(
                "cognee.modules.retrieval.utils.completion.render_prompt",
                return_value="User prompt text",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ) as mock_llm,
        ):
            # Imported inside the patch context so patched names are in effect.
            from cognee.modules.retrieval.utils.completion import generate_completion

            result = await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
                system_prompt="Custom system prompt",
            )
            assert result == mock_llm_response
            # Explicit system_prompt wins over system_prompt_path.
            mock_llm.assert_awaited_once_with(
                text_input="User prompt text",
                system_prompt="Custom system prompt",
                response_model=str,
            )

    @pytest.mark.asyncio
    async def test_generate_completion_without_system_prompt(self):
        """Test generate_completion reads system_prompt from file when not provided."""
        mock_llm_response = "Generated answer"
        with (
            patch(
                "cognee.modules.retrieval.utils.completion.render_prompt",
                return_value="User prompt text",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.read_query_prompt",
                return_value="System prompt from file",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ) as mock_llm,
        ):
            from cognee.modules.retrieval.utils.completion import generate_completion

            result = await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
            )
            assert result == mock_llm_response
            mock_llm.assert_awaited_once_with(
                text_input="User prompt text",
                system_prompt="System prompt from file",
                response_model=str,
            )

    @pytest.mark.asyncio
    async def test_generate_completion_with_conversation_history(self):
        """Test generate_completion includes conversation_history in system_prompt."""
        mock_llm_response = "Generated answer"
        with (
            patch(
                "cognee.modules.retrieval.utils.completion.render_prompt",
                return_value="User prompt text",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.read_query_prompt",
                return_value="System prompt from file",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ) as mock_llm,
        ):
            from cognee.modules.retrieval.utils.completion import generate_completion

            result = await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
                conversation_history="Previous conversation:\nQ: What is ML?\nA: ML is machine learning",
            )
            assert result == mock_llm_response
            # History is prepended, joined to the task prompt via "\nTASK:".
            expected_system_prompt = (
                "Previous conversation:\nQ: What is ML?\nA: ML is machine learning"
                + "\nTASK:"
                + "System prompt from file"
            )
            mock_llm.assert_awaited_once_with(
                text_input="User prompt text",
                system_prompt=expected_system_prompt,
                response_model=str,
            )

    @pytest.mark.asyncio
    async def test_generate_completion_with_conversation_history_and_custom_system_prompt(self):
        """Test generate_completion includes conversation_history with custom system_prompt."""
        mock_llm_response = "Generated answer"
        with (
            patch(
                "cognee.modules.retrieval.utils.completion.render_prompt",
                return_value="User prompt text",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ) as mock_llm,
        ):
            from cognee.modules.retrieval.utils.completion import generate_completion

            result = await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
                system_prompt="Custom system prompt",
                conversation_history="Previous conversation:\nQ: What is ML?\nA: ML is machine learning",
            )
            assert result == mock_llm_response
            expected_system_prompt = (
                "Previous conversation:\nQ: What is ML?\nA: ML is machine learning"
                + "\nTASK:"
                + "Custom system prompt"
            )
            mock_llm.assert_awaited_once_with(
                text_input="User prompt text",
                system_prompt=expected_system_prompt,
                response_model=str,
            )

    @pytest.mark.asyncio
    async def test_generate_completion_with_response_model(self):
        """Test generate_completion with custom response_model."""
        mock_response_model = MagicMock()
        mock_llm_response = {"answer": "Generated answer"}
        with (
            patch(
                "cognee.modules.retrieval.utils.completion.render_prompt",
                return_value="User prompt text",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.read_query_prompt",
                return_value="System prompt from file",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ) as mock_llm,
        ):
            from cognee.modules.retrieval.utils.completion import generate_completion

            result = await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
                response_model=mock_response_model,
            )
            assert result == mock_llm_response
            # The caller-supplied response_model must reach the gateway unchanged.
            mock_llm.assert_awaited_once_with(
                text_input="User prompt text",
                system_prompt="System prompt from file",
                response_model=mock_response_model,
            )

    @pytest.mark.asyncio
    async def test_generate_completion_render_prompt_args(self):
        """Test generate_completion passes correct args to render_prompt."""
        mock_llm_response = "Generated answer"
        with (
            patch(
                "cognee.modules.retrieval.utils.completion.render_prompt",
                return_value="User prompt text",
            ) as mock_render,
            patch(
                "cognee.modules.retrieval.utils.completion.read_query_prompt",
                return_value="System prompt from file",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ),
        ):
            from cognee.modules.retrieval.utils.completion import generate_completion

            await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
            )
            # render_prompt receives the template path plus question/context vars.
            mock_render.assert_called_once_with(
                "user_prompt.txt",
                {"question": "What is AI?", "context": "AI is artificial intelligence"},
            )
class TestSummarizeText:
    """Unit tests for summarize_text: explicit system prompt, file-based
    prompt loading, and default vs custom prompt paths."""

    @pytest.mark.asyncio
    async def test_summarize_text_with_system_prompt(self):
        """Test summarize_text with provided system_prompt."""
        mock_llm_response = "Summary text"
        with patch(
            "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
            new_callable=AsyncMock,
            return_value=mock_llm_response,
        ) as mock_llm:
            # Imported inside the patch context so patched names are in effect.
            from cognee.modules.retrieval.utils.completion import summarize_text

            result = await summarize_text(
                text="Long text to summarize",
                system_prompt_path="summarize_search_results.txt",
                system_prompt="Custom summary prompt",
            )
            assert result == mock_llm_response
            # Explicit system_prompt wins; the file path is ignored.
            mock_llm.assert_awaited_once_with(
                text_input="Long text to summarize",
                system_prompt="Custom summary prompt",
                response_model=str,
            )

    @pytest.mark.asyncio
    async def test_summarize_text_without_system_prompt(self):
        """Test summarize_text reads system_prompt from file when not provided."""
        mock_llm_response = "Summary text"
        with (
            patch(
                "cognee.modules.retrieval.utils.completion.read_query_prompt",
                return_value="System prompt from file",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ) as mock_llm,
        ):
            from cognee.modules.retrieval.utils.completion import summarize_text

            result = await summarize_text(
                text="Long text to summarize",
                system_prompt_path="summarize_search_results.txt",
            )
            assert result == mock_llm_response
            mock_llm.assert_awaited_once_with(
                text_input="Long text to summarize",
                system_prompt="System prompt from file",
                response_model=str,
            )

    @pytest.mark.asyncio
    async def test_summarize_text_default_prompt_path(self):
        """Test summarize_text uses default prompt path when not provided."""
        mock_llm_response = "Summary text"
        with (
            patch(
                "cognee.modules.retrieval.utils.completion.read_query_prompt",
                return_value="Default system prompt",
            ) as mock_read,
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ) as mock_llm,
        ):
            from cognee.modules.retrieval.utils.completion import summarize_text

            result = await summarize_text(text="Long text to summarize")
            assert result == mock_llm_response
            # Default path is the packaged summarize prompt.
            mock_read.assert_called_once_with("summarize_search_results.txt")
            mock_llm.assert_awaited_once_with(
                text_input="Long text to summarize",
                system_prompt="Default system prompt",
                response_model=str,
            )

    @pytest.mark.asyncio
    async def test_summarize_text_custom_prompt_path(self):
        """Test summarize_text uses custom prompt path when provided."""
        mock_llm_response = "Summary text"
        with (
            patch(
                "cognee.modules.retrieval.utils.completion.read_query_prompt",
                return_value="Custom system prompt",
            ) as mock_read,
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ) as mock_llm,
        ):
            from cognee.modules.retrieval.utils.completion import summarize_text

            result = await summarize_text(
                text="Long text to summarize",
                system_prompt_path="custom_prompt.txt",
            )
            assert result == mock_llm_response
            mock_read.assert_called_once_with("custom_prompt.txt")
            mock_llm.assert_awaited_once_with(
                text_input="Long text to summarize",
                system_prompt="Custom system prompt",
                response_model=str,
            )

View file

@ -0,0 +1,157 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@pytest.fixture
def mock_edge():
    """Create a mock edge."""
    # spec=Edge keeps attribute access constrained to the Edge interface.
    return MagicMock(spec=Edge)
class TestGraphSummaryCompletionRetriever:
@pytest.mark.asyncio
async def test_init_defaults(self):
"""Test GraphSummaryCompletionRetriever initialization with defaults."""
retriever = GraphSummaryCompletionRetriever()
assert retriever.summarize_prompt_path == "summarize_search_results.txt"
assert retriever.user_prompt_path == "graph_context_for_question.txt"
assert retriever.system_prompt_path == "answer_simple_question.txt"
assert retriever.top_k == 5
assert retriever.save_interaction is False
@pytest.mark.asyncio
async def test_init_custom_params(self):
"""Test GraphSummaryCompletionRetriever initialization with custom parameters."""
retriever = GraphSummaryCompletionRetriever(
user_prompt_path="custom_user.txt",
system_prompt_path="custom_system.txt",
summarize_prompt_path="custom_summarize.txt",
system_prompt="Custom system prompt",
top_k=10,
save_interaction=True,
wide_search_top_k=200,
triplet_distance_penalty=2.5,
)
assert retriever.summarize_prompt_path == "custom_summarize.txt"
assert retriever.user_prompt_path == "custom_user.txt"
assert retriever.system_prompt_path == "custom_system.txt"
assert retriever.top_k == 10
assert retriever.save_interaction is True
    @pytest.mark.asyncio
    async def test_resolve_edges_to_text_calls_super_and_summarizes(self, mock_edge):
        """Test resolve_edges_to_text calls super method and then summarizes."""
        retriever = GraphSummaryCompletionRetriever(
            summarize_prompt_path="custom_summarize.txt",
            system_prompt="Custom system prompt",
        )
        with (
            # Parent resolves edges to text; this subclass then summarizes it.
            patch(
                "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text",
                new_callable=AsyncMock,
                return_value="Resolved edges text",
            ) as mock_super_resolve,
            patch(
                "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text",
                new_callable=AsyncMock,
                return_value="Summarized text",
            ) as mock_summarize,
        ):
            result = await retriever.resolve_edges_to_text([mock_edge])
            assert result == "Summarized text"
            mock_super_resolve.assert_awaited_once_with([mock_edge])
            # Summarization receives the parent's text plus the retriever's
            # configured prompt path and explicit system prompt.
            mock_summarize.assert_awaited_once_with(
                "Resolved edges text",
                "custom_summarize.txt",
                "Custom system prompt",
            )
@pytest.mark.asyncio
async def test_resolve_edges_to_text_with_default_system_prompt(self, mock_edge):
    """Without an explicit system prompt, summarize_text must receive None."""
    retriever = GraphSummaryCompletionRetriever()
    resolve_target = (
        "cognee.modules.retrieval.graph_completion_retriever."
        "GraphCompletionRetriever.resolve_edges_to_text"
    )
    summarize_target = "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text"
    with patch(resolve_target, new_callable=AsyncMock, return_value="Resolved edges text"):
        with patch(
            summarize_target, new_callable=AsyncMock, return_value="Summarized text"
        ) as summarize_mock:
            await retriever.resolve_edges_to_text([mock_edge])
    # Default prompt path plus a None system prompt are forwarded unchanged.
    summarize_mock.assert_awaited_once_with(
        "Resolved edges text",
        "summarize_search_results.txt",
        None,
    )
@pytest.mark.asyncio
async def test_resolve_edges_to_text_with_empty_edges(self):
    """An empty edge list is still passed through the summarization step."""
    retriever = GraphSummaryCompletionRetriever()
    resolve_target = (
        "cognee.modules.retrieval.graph_completion_retriever."
        "GraphCompletionRetriever.resolve_edges_to_text"
    )
    summarize_target = "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text"
    with patch(resolve_target, new_callable=AsyncMock, return_value=""):
        with patch(
            summarize_target, new_callable=AsyncMock, return_value="Empty summary"
        ) as summarize_mock:
            summary = await retriever.resolve_edges_to_text([])
    assert summary == "Empty summary"
    # Even an empty resolved text goes to the summarizer with default args.
    summarize_mock.assert_awaited_once_with(
        "",
        "summarize_search_results.txt",
        None,
    )
@pytest.mark.asyncio
async def test_resolve_edges_to_text_with_multiple_edges(self, mock_edge):
    """Several edges are resolved together and summarized as one text."""
    retriever = GraphSummaryCompletionRetriever()
    edge_batch = [mock_edge, MagicMock(spec=Edge), MagicMock(spec=Edge)]
    resolve_target = (
        "cognee.modules.retrieval.graph_completion_retriever."
        "GraphCompletionRetriever.resolve_edges_to_text"
    )
    summarize_target = "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text"
    with patch(resolve_target, new_callable=AsyncMock, return_value="Multiple edges resolved text"):
        with patch(
            summarize_target, new_callable=AsyncMock, return_value="Multiple edges summarized"
        ) as summarize_mock:
            summary = await retriever.resolve_edges_to_text(edge_batch)
    assert summary == "Multiple edges summarized"
    # One summarize call covers the whole batch, not one call per edge.
    summarize_mock.assert_awaited_once_with(
        "Multiple edges resolved text",
        "summarize_search_results.txt",
        None,
    )

View file

@ -0,0 +1,312 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from uuid import UUID, NAMESPACE_OID, uuid5
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation, UserFeedbackSentiment
from cognee.modules.engine.models import NodeSet
@pytest.fixture
def mock_feedback_evaluation():
    """Stub UserFeedbackEvaluation carrying a positive sentiment and a 4.5 score."""
    stub = MagicMock(spec=UserFeedbackEvaluation)
    stub.evaluation = MagicMock()
    stub.evaluation.value = "positive"
    stub.score = 4.5
    return stub
@pytest.fixture
def mock_graph_engine():
    """Async graph-engine stub: no prior interactions, write methods are awaitable no-ops."""
    stub_engine = AsyncMock()
    stub_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
    stub_engine.add_edges = AsyncMock()
    stub_engine.apply_feedback_weight = AsyncMock()
    return stub_engine
class TestUserQAFeedback:
    """Unit tests for UserQAFeedback.

    Covers: constructor defaults, feedback-node creation, linking feedback to
    the last-k user interactions via graph edges, and applying the LLM-derived
    sentiment score as an edge weight. All collaborators (LLM gateway, graph
    engine, data-point storage, edge indexing) are mocked.
    """

    @pytest.mark.asyncio
    async def test_init_default(self):
        """Test UserQAFeedback initialization with default last_k."""
        retriever = UserQAFeedback()
        assert retriever.last_k == 1

    @pytest.mark.asyncio
    async def test_init_custom_last_k(self):
        """Test UserQAFeedback initialization with custom last_k."""
        retriever = UserQAFeedback(last_k=5)
        assert retriever.last_k == 5

    @pytest.mark.asyncio
    async def test_add_feedback_success_with_relationships(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """Test add_feedback successfully creates feedback with relationships."""
        interaction_id_1 = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
        interaction_id_2 = str(UUID("550e8400-e29b-41d4-a716-446655440001"))
        # Two prior interactions exist, so edges + weighting must both happen.
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(
            return_value=[interaction_id_1, interaction_id_2]
        )
        feedback_text = "This answer was helpful"
        with (
            patch(
                "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_feedback_evaluation,
            ),
            patch(
                "cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
                return_value=mock_graph_engine,
            ),
            patch(
                "cognee.modules.retrieval.user_qa_feedback.add_data_points",
                new_callable=AsyncMock,
            ) as mock_add_data,
            patch(
                "cognee.modules.retrieval.user_qa_feedback.index_graph_edges",
                new_callable=AsyncMock,
            ) as mock_index_edges,
        ):
            retriever = UserQAFeedback(last_k=2)
            result = await retriever.add_feedback(feedback_text)
            assert result == [feedback_text]
            mock_add_data.assert_awaited_once()
            mock_graph_engine.add_edges.assert_awaited_once()
            mock_index_edges.assert_awaited_once()
            mock_graph_engine.apply_feedback_weight.assert_awaited_once()
            # Verify add_edges was called with correct relationships
            # (edge tuple shape: source id, target id, name, properties).
            call_args = mock_graph_engine.add_edges.call_args[0][0]
            assert len(call_args) == 2
            assert call_args[0][0] == uuid5(NAMESPACE_OID, name=feedback_text)
            assert call_args[0][1] == UUID(interaction_id_1)
            assert call_args[0][2] == "gives_feedback_to"
            assert call_args[0][3]["relationship_name"] == "gives_feedback_to"
            assert call_args[0][3]["ontology_valid"] is False
            # Verify apply_feedback_weight was called with correct node IDs
            weight_call_args = mock_graph_engine.apply_feedback_weight.call_args[1]["node_ids"]
            assert len(weight_call_args) == 2
            assert interaction_id_1 in weight_call_args
            assert interaction_id_2 in weight_call_args

    @pytest.mark.asyncio
    async def test_add_feedback_success_no_relationships(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """Test add_feedback successfully creates feedback without relationships."""
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
        feedback_text = "This answer was helpful"
        with (
            patch(
                "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_feedback_evaluation,
            ),
            patch(
                "cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
                return_value=mock_graph_engine,
            ),
            patch(
                "cognee.modules.retrieval.user_qa_feedback.add_data_points",
                new_callable=AsyncMock,
            ) as mock_add_data,
            patch(
                "cognee.modules.retrieval.user_qa_feedback.index_graph_edges",
                new_callable=AsyncMock,
            ) as mock_index_edges,
        ):
            retriever = UserQAFeedback(last_k=1)
            result = await retriever.add_feedback(feedback_text)
            assert result == [feedback_text]
            mock_add_data.assert_awaited_once()
            # Should not call add_edges or index_graph_edges when no relationships
            mock_graph_engine.add_edges.assert_not_awaited()
            mock_index_edges.assert_not_awaited()
            mock_graph_engine.apply_feedback_weight.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_add_feedback_creates_correct_feedback_node(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """Test add_feedback creates CogneeUserFeedback with correct attributes."""
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
        feedback_text = "This was a negative experience"
        # Override the fixture's positive defaults to exercise the negative path.
        mock_feedback_evaluation.evaluation.value = "negative"
        mock_feedback_evaluation.score = -3.0
        with (
            patch(
                "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_feedback_evaluation,
            ),
            patch(
                "cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
                return_value=mock_graph_engine,
            ),
            patch(
                "cognee.modules.retrieval.user_qa_feedback.add_data_points",
                new_callable=AsyncMock,
            ) as mock_add_data,
        ):
            retriever = UserQAFeedback()
            await retriever.add_feedback(feedback_text)
            # Verify add_data_points was called with correct CogneeUserFeedback
            call_args = mock_add_data.call_args[1]["data_points"]
            assert len(call_args) == 1
            feedback_node = call_args[0]
            # Node id is deterministic: uuid5 of the feedback text itself.
            assert feedback_node.id == uuid5(NAMESPACE_OID, name=feedback_text)
            assert feedback_node.feedback == feedback_text
            assert feedback_node.sentiment == "negative"
            assert feedback_node.score == -3.0
            assert isinstance(feedback_node.belongs_to_set, NodeSet)
            assert feedback_node.belongs_to_set.name == "UserQAFeedbacks"

    @pytest.mark.asyncio
    async def test_add_feedback_calls_llm_with_correct_prompt(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """Test add_feedback calls LLM with correct sentiment analysis prompt."""
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
        feedback_text = "Great answer!"
        with (
            patch(
                "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_feedback_evaluation,
            ) as mock_llm,
            patch(
                "cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
                return_value=mock_graph_engine,
            ),
            patch(
                "cognee.modules.retrieval.user_qa_feedback.add_data_points",
                new_callable=AsyncMock,
            ),
        ):
            retriever = UserQAFeedback()
            await retriever.add_feedback(feedback_text)
            mock_llm.assert_awaited_once()
            call_kwargs = mock_llm.call_args[1]
            assert call_kwargs["text_input"] == feedback_text
            assert "sentiment analysis assistant" in call_kwargs["system_prompt"]
            assert call_kwargs["response_model"] == UserFeedbackEvaluation

    @pytest.mark.asyncio
    async def test_add_feedback_uses_last_k_parameter(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """Test add_feedback uses last_k parameter when getting interaction IDs."""
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
        feedback_text = "Test feedback"
        with (
            patch(
                "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_feedback_evaluation,
            ),
            patch(
                "cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
                return_value=mock_graph_engine,
            ),
            patch(
                "cognee.modules.retrieval.user_qa_feedback.add_data_points",
                new_callable=AsyncMock,
            ),
        ):
            retriever = UserQAFeedback(last_k=5)
            await retriever.add_feedback(feedback_text)
            # last_k must be forwarded verbatim as the lookup limit.
            mock_graph_engine.get_last_user_interaction_ids.assert_awaited_once_with(limit=5)

    @pytest.mark.asyncio
    async def test_add_feedback_with_single_interaction(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """Test add_feedback with single interaction ID."""
        interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id])
        feedback_text = "Test feedback"
        with (
            patch(
                "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_feedback_evaluation,
            ),
            patch(
                "cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
                return_value=mock_graph_engine,
            ),
            patch(
                "cognee.modules.retrieval.user_qa_feedback.add_data_points",
                new_callable=AsyncMock,
            ),
            patch(
                "cognee.modules.retrieval.user_qa_feedback.index_graph_edges",
                new_callable=AsyncMock,
            ),
        ):
            retriever = UserQAFeedback()
            result = await retriever.add_feedback(feedback_text)
            assert result == [feedback_text]
            # Should create relationship for the interaction
            call_args = mock_graph_engine.add_edges.call_args[0][0]
            assert len(call_args) == 1
            assert call_args[0][1] == UUID(interaction_id)

    @pytest.mark.asyncio
    async def test_add_feedback_applies_weight_correctly(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """Test add_feedback applies feedback weight with correct score."""
        interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id])
        mock_feedback_evaluation.score = 4.5
        feedback_text = "Positive feedback"
        with (
            patch(
                "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_feedback_evaluation,
            ),
            patch(
                "cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
                return_value=mock_graph_engine,
            ),
            patch(
                "cognee.modules.retrieval.user_qa_feedback.add_data_points",
                new_callable=AsyncMock,
            ),
            patch(
                "cognee.modules.retrieval.user_qa_feedback.index_graph_edges",
                new_callable=AsyncMock,
            ),
        ):
            retriever = UserQAFeedback()
            await retriever.add_feedback(feedback_text)
            # The LLM's numeric score becomes the edge weight, unmodified.
            mock_graph_engine.apply_feedback_weight.assert_awaited_once_with(
                node_ids=[interaction_id], weight=4.5
            )

View file

@ -81,3 +81,249 @@ async def test_get_context_collection_not_found_error(mock_vector_engine):
):
with pytest.raises(NoDataError, match="No data found"):
await retriever.get_context("test query")
@pytest.mark.asyncio
async def test_get_context_empty_payload_text(mock_vector_engine):
    """A search hit whose payload lacks a 'text' key must surface as a KeyError."""
    search_hit = MagicMock()
    search_hit.payload = {}
    mock_vector_engine.search.return_value = [search_hit]
    retriever = TripletRetriever()
    with patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        with pytest.raises(KeyError):
            await retriever.get_context("test query")
@pytest.mark.asyncio
async def test_get_context_single_triplet(mock_vector_engine):
    """A single search hit is returned as the bare triplet text."""
    search_hit = MagicMock()
    search_hit.payload = {"text": "Single triplet"}
    mock_vector_engine.search.return_value = [search_hit]
    retriever = TripletRetriever()
    with patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        retrieved_context = await retriever.get_context("test query")
    assert retrieved_context == "Single triplet"
@pytest.mark.asyncio
async def test_init_defaults():
    """A bare TripletRetriever carries the documented default configuration."""
    retriever = TripletRetriever()
    assert retriever.system_prompt is None
    assert retriever.top_k == 5
    assert retriever.system_prompt_path == "answer_simple_question.txt"
    assert retriever.user_prompt_path == "context_for_question.txt"
@pytest.mark.asyncio
async def test_init_custom_params():
    """Constructor overrides are stored verbatim on the retriever."""
    retriever = TripletRetriever(
        user_prompt_path="custom_user.txt",
        system_prompt_path="custom_system.txt",
        system_prompt="Custom prompt",
        top_k=10,
    )
    assert retriever.top_k == 10
    assert retriever.system_prompt == "Custom prompt"
    assert retriever.system_prompt_path == "custom_system.txt"
    assert retriever.user_prompt_path == "custom_user.txt"
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_vector_engine):
    """When no context is supplied, get_completion retrieves one itself."""
    search_hit = MagicMock()
    search_hit.payload = {"text": "Test triplet"}
    mock_vector_engine.has_collection.return_value = True
    mock_vector_engine.search.return_value = [search_hit]
    retriever = TripletRetriever()
    with (
        patch(
            "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.triplet_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as cache_config_cls,
    ):
        # Caching disabled: no session bookkeeping should occur.
        cache_config = MagicMock()
        cache_config.caching = False
        cache_config_cls.return_value = cache_config
        completion = await retriever.get_completion("test query")
    assert isinstance(completion, list)
    assert completion == ["Generated answer"]
@pytest.mark.asyncio
async def test_get_completion_with_provided_context(mock_vector_engine):
    """A caller-provided context is used directly, skipping retrieval."""
    retriever = TripletRetriever()
    with (
        patch(
            "cognee.modules.retrieval.triplet_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as cache_config_cls,
    ):
        cache_config = MagicMock()
        cache_config.caching = False
        cache_config_cls.return_value = cache_config
        completion = await retriever.get_completion("test query", context="Provided context")
    assert isinstance(completion, list)
    assert completion == ["Generated answer"]
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_vector_engine):
    """Test get_completion with session caching enabled."""
    mock_result = MagicMock()
    mock_result.payload = {"text": "Test triplet"}
    mock_vector_engine.has_collection.return_value = True
    mock_vector_engine.search.return_value = [mock_result]
    retriever = TripletRetriever()
    # A session user must be present for the caching path to persist history.
    mock_user = MagicMock()
    mock_user.id = "test-user-id"
    with (
        patch(
            "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.triplet_retriever.get_conversation_history",
            return_value="Previous conversation",
        ),
        patch(
            "cognee.modules.retrieval.triplet_retriever.summarize_text",
            return_value="Context summary",
        ),
        patch(
            "cognee.modules.retrieval.triplet_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.triplet_retriever.save_conversation_history",
        ) as mock_save,
        patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
        patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user,
    ):
        # Caching on + user present: history is read, summarized, and re-saved.
        mock_config = MagicMock()
        mock_config.caching = True
        mock_cache_config.return_value = mock_config
        mock_session_user.get.return_value = mock_user
        completion = await retriever.get_completion("test query", session_id="test_session")
        assert isinstance(completion, list)
        assert len(completion) == 1
        assert completion[0] == "Generated answer"
        # The new exchange must be written back to the conversation history.
        mock_save.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_vector_engine):
    """Test get_completion with session config but no user ID."""
    # NOTE(review): despite the docstring, no session_id is passed below —
    # presumably the no-user branch is reachable via caching config alone;
    # confirm this matches the retriever's intended behavior.
    mock_result = MagicMock()
    mock_result.payload = {"text": "Test triplet"}
    mock_vector_engine.has_collection.return_value = True
    mock_vector_engine.search.return_value = [mock_result]
    retriever = TripletRetriever()
    with (
        patch(
            "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.triplet_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
        patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user,
    ):
        # Caching enabled but the session context-var yields no user.
        mock_config = MagicMock()
        mock_config.caching = True
        mock_cache_config.return_value = mock_config
        mock_session_user.get.return_value = None  # No user
        completion = await retriever.get_completion("test query")
        assert isinstance(completion, list)
        assert len(completion) == 1
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_vector_engine):
    """A custom pydantic response model flows through to the completion result."""
    from pydantic import BaseModel

    class AnswerModel(BaseModel):
        answer: str

    search_hit = MagicMock()
    search_hit.payload = {"text": "Test triplet"}
    mock_vector_engine.has_collection.return_value = True
    mock_vector_engine.search.return_value = [search_hit]
    retriever = TripletRetriever()
    with (
        patch(
            "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.triplet_retriever.generate_completion",
            return_value=AnswerModel(answer="Test answer"),
        ),
        patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as cache_config_cls,
    ):
        cache_config = MagicMock()
        cache_config.caching = False
        cache_config_cls.return_value = cache_config
        completion = await retriever.get_completion("test query", response_model=AnswerModel)
    assert isinstance(completion, list)
    assert len(completion) == 1
    assert isinstance(completion[0], AnswerModel)
@pytest.mark.asyncio
async def test_init_none_top_k():
    """Passing top_k=None must fall back to the default of 5."""
    fallback_retriever = TripletRetriever(top_k=None)
    assert fallback_retriever.top_k == 5