From fd23c75c09d1cc22a406b72fc723a7975a6a4cca Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Fri, 12 Dec 2025 14:44:41 +0100 Subject: [PATCH 01/10] chore: adds new Unit tests for retrievers --- .../retrieval/chunks_retriever_test.py | 322 ++++--- .../retrieval/conversation_history_test.py | 338 ++++++++ ...letion_retriever_context_extension_test.py | 552 +++++++++--- .../graph_completion_retriever_cot_test.py | 794 +++++++++++++++--- .../graph_completion_retriever_test.py | 793 +++++++++++++---- .../rag_completion_retriever_test.py | 462 ++++++---- .../retrieval/structured_output_test.py | 204 ----- .../retrieval/summaries_retriever_test.py | 312 ++++--- .../retrieval/temporal_retriever_test.py | 597 +++++++++++-- .../test_brute_force_triplet_search.py | 227 ++++- .../unit/modules/retrieval/test_completion.py | 343 ++++++++ ...test_graph_summary_completion_retriever.py | 157 ++++ .../retrieval/test_user_qa_feedback.py | 312 +++++++ .../retrieval/triplet_retriever_test.py | 246 ++++++ 14 files changed, 4454 insertions(+), 1205 deletions(-) delete mode 100644 cognee/tests/unit/modules/retrieval/structured_output_test.py create mode 100644 cognee/tests/unit/modules/retrieval/test_completion.py create mode 100644 cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py create mode 100644 cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py diff --git a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py index 44786f79d..98bfd48fe 100644 --- a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py @@ -1,201 +1,183 @@ -import os import pytest -import pathlib -from typing import List -import cognee -from cognee.low_level import setup -from cognee.tasks.storage import add_data_points -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.chunking.models import DocumentChunk -from cognee.modules.data.processing.document_types import TextDocument -from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from unittest.mock import AsyncMock, patch, MagicMock + from cognee.modules.retrieval.chunks_retriever import ChunksRetriever -from cognee.infrastructure.engine import DataPoint -from cognee.modules.data.processing.document_types import Document -from cognee.modules.engine.models import Entity +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError -class DocumentChunkWithEntities(DataPoint): - text: str - chunk_size: int - chunk_index: int - cut_type: str - is_part_of: Document - contains: List[Entity] = None - - metadata: dict = {"index_fields": ["text"]} +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.search = AsyncMock() + return engine -class TestChunksRetriever: - @pytest.mark.asyncio - async def test_chunk_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) +@pytest.mark.asyncio +async def test_get_context_success(mock_vector_engine): + """Test successful retrieval of chunk context.""" + mock_result1 = MagicMock() + mock_result1.payload = {"text": "Steve Rodger", "chunk_index": 0} + mock_result2 = MagicMock() + mock_result2.payload = {"text": "Mike Broski", "chunk_index": 1} - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() + mock_vector_engine.search.return_value = [mock_result1, mock_result2] - document = TextDocument( - name="Steve Rodger's career", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) + retriever = ChunksRetriever(top_k=5) - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - entities = [chunk1, chunk2, chunk3] + assert len(context) == 2 + assert context[0]["text"] == "Steve Rodger" + assert context[1]["text"] == "Mike Broski" + mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5) - await add_data_points(entities) - retriever = ChunksRetriever() +@pytest.mark.asyncio +async def test_get_context_collection_not_found_error(mock_vector_engine): + """Test that CollectionNotFoundError is converted to NoDataError.""" + mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found") - context = await retriever.get_context("Mike") + retriever = ChunksRetriever() - assert context[0]["text"] == "Mike Broski", "Failed to get Mike Broski" + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(NoDataError, match="No data found"): + await retriever.get_context("test query") - @pytest.mark.asyncio - async def test_chunk_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_complex" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() +@pytest.mark.asyncio +async def test_get_context_empty_results(mock_vector_engine): + """Test that empty list is returned when no chunks are found.""" + mock_vector_engine.search.return_value = [] - document1 = TextDocument( - name="Employee List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) + retriever = ChunksRetriever() - document2 = TextDocument( - name="Car List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) + assert context == [] - chunk4 = DocumentChunk( - text="Range Rover", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk5 = DocumentChunk( - text="Hyundai", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk6 = DocumentChunk( - text="Chrysler", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6] +@pytest.mark.asyncio +async def test_get_context_top_k_limit(mock_vector_engine): + """Test that top_k parameter limits the number of results.""" + mock_results = [MagicMock() for _ in range(3)] + for i, result in enumerate(mock_results): + result.payload = {"text": f"Chunk {i}"} - await add_data_points(entities) + mock_vector_engine.search.return_value = mock_results - retriever = ChunksRetriever(top_k=20) + retriever = ChunksRetriever(top_k=3) - context = await retriever.get_context("Christina") + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer" + assert len(context) == 3 + mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3) - @pytest.mark.asyncio - async def test_chunk_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph" - ) - cognee.config.data_root_directory(data_directory_path) - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) +@pytest.mark.asyncio +async def test_get_completion_with_context(mock_vector_engine): + """Test get_completion returns provided context.""" + retriever = ChunksRetriever() - retriever = ChunksRetriever() + provided_context = [{"text": "Steve Rodger"}, {"text": "Mike Broski"}] + completion = await retriever.get_completion("test query", context=provided_context) - with pytest.raises(NoDataError): - await retriever.get_context("Christina Mayer") + assert completion == provided_context - vector_engine = get_vector_engine() - await vector_engine.create_collection( - "DocumentChunk_text", payload_schema=DocumentChunkWithEntities - ) - context = await retriever.get_context("Christina Mayer") - assert len(context) == 0, "Found chunks when none should exist" +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_vector_engine): + """Test get_completion retrieves context when not provided.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Steve Rodger"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = ChunksRetriever() + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query") + + assert len(completion) == 1 + assert completion[0]["text"] == "Steve Rodger" + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test ChunksRetriever initialization with defaults.""" + retriever = ChunksRetriever() + + assert retriever.top_k == 5 + + +@pytest.mark.asyncio +async def test_init_custom_top_k(): + """Test ChunksRetriever initialization with custom top_k.""" + retriever = ChunksRetriever(top_k=10) + + assert retriever.top_k == 10 + + +@pytest.mark.asyncio +async def test_init_none_top_k(): + """Test ChunksRetriever initialization with None top_k.""" + retriever = ChunksRetriever(top_k=None) + + assert retriever.top_k is None + + +@pytest.mark.asyncio +async def test_get_context_empty_payload(mock_vector_engine): + """Test get_context handles empty payload.""" + mock_result = MagicMock() + mock_result.payload = {} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = ChunksRetriever() + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 1 + assert context[0] == {} + + +@pytest.mark.asyncio +async def test_get_completion_with_session_id(mock_vector_engine): + """Test get_completion with session_id parameter.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Steve Rodger"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = ChunksRetriever() + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query", session_id="test_session") + + assert len(completion) == 1 + assert completion[0]["text"] == "Steve Rodger" diff --git a/cognee/tests/unit/modules/retrieval/conversation_history_test.py b/cognee/tests/unit/modules/retrieval/conversation_history_test.py index d464a99d8..f1ce9b370 100644 --- a/cognee/tests/unit/modules/retrieval/conversation_history_test.py +++ b/cognee/tests/unit/modules/retrieval/conversation_history_test.py @@ -152,3 +152,341 @@ class TestConversationHistoryUtils: assert result is True call_kwargs = mock_cache.add_qa.call_args.kwargs assert call_kwargs["session_id"] == "default_session" + + @pytest.mark.asyncio + async def test_save_conversation_history_no_user_id(self): + """Test save_conversation_history returns False when user_id is None.""" + session_user.set(None) + + with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_save_conversation_history_caching_disabled(self): + """Test save_conversation_history returns False when caching is disabled.""" + user = create_mock_user() + session_user.set(user) + + with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = False + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_save_conversation_history_cache_engine_none(self): + """Test save_conversation_history returns False when cache_engine is None.""" + user = create_mock_user() + session_user.set(user) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=None): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_save_conversation_history_cache_connection_error(self): + """Test save_conversation_history handles CacheConnectionError gracefully.""" + user = create_mock_user() + session_user.set(user) + + from cognee.infrastructure.databases.exceptions import CacheConnectionError + + mock_cache = create_mock_cache_engine([]) + mock_cache.add_qa = AsyncMock(side_effect=CacheConnectionError("Connection failed")) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_save_conversation_history_generic_exception(self): + """Test save_conversation_history handles generic exceptions gracefully.""" + user = create_mock_user() + session_user.set(user) + + mock_cache = create_mock_cache_engine([]) + mock_cache.add_qa = AsyncMock(side_effect=ValueError("Unexpected error")) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_get_conversation_history_no_user_id(self): + """Test get_conversation_history returns empty string when user_id is None.""" + session_user.set(None) + + with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_caching_disabled(self): + """Test get_conversation_history returns empty string when caching is disabled.""" + user = create_mock_user() + session_user.set(user) + + with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = False + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_default_session(self): + """Test get_conversation_history uses 'default_session' when session_id is None.""" + user = create_mock_user() + session_user.set(user) + + mock_cache = create_mock_cache_engine([]) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + await get_conversation_history(session_id=None) + + mock_cache.get_latest_qa.assert_called_once_with(str(user.id), "default_session") + + @pytest.mark.asyncio + async def test_get_conversation_history_cache_engine_none(self): + """Test get_conversation_history returns empty string when cache_engine is None.""" + user = create_mock_user() + session_user.set(user) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=None): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_cache_connection_error(self): + """Test get_conversation_history handles CacheConnectionError gracefully.""" + user = create_mock_user() + session_user.set(user) + + from cognee.infrastructure.databases.exceptions import CacheConnectionError + + mock_cache = create_mock_cache_engine([]) + mock_cache.get_latest_qa = AsyncMock(side_effect=CacheConnectionError("Connection failed")) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_generic_exception(self): + """Test get_conversation_history handles generic exceptions gracefully.""" + user = create_mock_user() + session_user.set(user) + + mock_cache = create_mock_cache_engine([]) + mock_cache.get_latest_qa = AsyncMock(side_effect=ValueError("Unexpected error")) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_missing_keys(self): + """Test get_conversation_history handles missing keys in history entries.""" + user = create_mock_user() + session_user.set(user) + + mock_history = [ + { + "time": "2024-01-15 10:30:45", + "question": "What is AI?", + }, + { + "context": "AI is artificial intelligence", + "answer": "AI stands for Artificial Intelligence", + }, + {}, + ] + mock_cache = create_mock_cache_engine(mock_history) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert "Previous conversation:" in result + assert "[2024-01-15 10:30:45]" in result + assert "QUESTION: What is AI?" in result + assert "Unknown time" in result + assert "CONTEXT: AI is artificial intelligence" in result + assert "ANSWER: AI stands for Artificial Intelligence" in result diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 0e21fe351..6a9b07d38 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -1,177 +1,469 @@ -import os import pytest -import pathlib -from typing import Optional, Union +from unittest.mock import AsyncMock, patch, MagicMock +from uuid import UUID -import cognee -from cognee.low_level import setup, DataPoint -from cognee.tasks.storage import add_data_points -from cognee.modules.graph.utils import resolve_edges_to_text from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( GraphCompletionContextExtensionRetriever, ) +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge -class TestGraphCompletionWithContextExtensionRetriever: - @pytest.mark.asyncio - async def test_graph_completion_extension_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_graph_completion_extension_context_simple", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_graph_completion_extension_context_simple", - ) - cognee.config.data_root_directory(data_directory_path) +@pytest.fixture +def mock_edge(): + """Create a mock edge.""" + edge = MagicMock(spec=Edge) + return edge - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - class Company(DataPoint): - name: str +@pytest.mark.asyncio +async def test_get_triplets_inherited(mock_edge): + """Test that get_triplets is inherited from parent class.""" + retriever = GraphCompletionContextExtensionRetriever() - class Person(DataPoint): - name: str - works_for: Company + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ): + triplets = await retriever.get_triplets("test query") - company1 = Company(name="Figma") - company2 = Company(name="Canva") - person1 = Person(name="Steve Rodger", works_for=company1) - person2 = Person(name="Ike Loma", works_for=company1) - person3 = Person(name="Jason Statham", works_for=company1) - person4 = Person(name="Mike Broski", works_for=company2) - person5 = Person(name="Christina Mayer", works_for=company2) + assert len(triplets) == 1 + assert triplets[0] == mock_edge - entities = [company1, company2, person1, person2, person3, person4, person5] - await add_data_points(entities) +@pytest.mark.asyncio +async def test_init_defaults(): + """Test GraphCompletionContextExtensionRetriever initialization with defaults.""" + retriever = GraphCompletionContextExtensionRetriever() - retriever = GraphCompletionContextExtensionRetriever() + assert retriever.top_k == 5 + assert retriever.user_prompt_path == "graph_context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" - context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) - assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" - assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test GraphCompletionContextExtensionRetriever initialization with custom parameters.""" + retriever = GraphCompletionContextExtensionRetriever( + top_k=10, + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + node_type=str, + node_name=["node1"], + save_interaction=True, + wide_search_top_k=200, + triplet_distance_penalty=5.0, + ) - answer = await retriever.get_completion("Who works at Canva?") + assert retriever.top_k == 10 + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.node_type is str + assert retriever.node_name == ["node1"] + assert retriever.save_interaction is True + assert retriever.wide_search_top_k == 200 + assert retriever.triplet_distance_penalty == 5.0 - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_edge): + """Test get_completion retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context_extension_rounds=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_edge): + """Test get_completion uses provided context.""" + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "test query", context=[mock_edge], context_extension_rounds=1 ) - @pytest.mark.asyncio - async def test_graph_completion_extension_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_graph_completion_extension_context_complex", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_graph_completion_extension_context_complex", - ) - cognee.config.data_root_directory(data_directory_path) + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - class Company(DataPoint): - name: str - metadata: dict = {"index_fields": ["name"]} +@pytest.mark.asyncio +async def test_get_completion_context_extension_rounds(mock_edge): + """Test get_completion with multiple context extension rounds.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) - class Car(DataPoint): - brand: str - model: str - year: int + retriever = GraphCompletionContextExtensionRetriever() - class Location(DataPoint): - country: str - city: str + # Create a second edge for extension rounds + mock_edge2 = MagicMock(spec=Edge) - class Home(DataPoint): - location: Location - rooms: int - sqm: int + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + side_effect=[[mock_edge], [mock_edge2]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + side_effect=["Resolved context", "Extended context"], # Different contexts + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], # Query for extension, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config - class Person(DataPoint): - name: str - works_for: Company - owns: Optional[list[Union[Car, Home]]] = None + completion = await retriever.get_completion("test query", context_extension_rounds=1) - company1 = Company(name="Figma") - company2 = Company(name="Canva") + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" - person1 = Person(name="Mike Rodger", works_for=company1) - person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] - person2 = Person(name="Ike Loma", works_for=company1) - person2.owns = [ - Car(brand="Tesla", model="Model S", year=2021), - Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), - ] +@pytest.mark.asyncio +async def test_get_completion_context_extension_stops_early(mock_edge): + """Test get_completion stops early when no new triplets found.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) - person3 = Person(name="Jason Statham", works_for=company1) + retriever = GraphCompletionContextExtensionRetriever() - person4 = Person(name="Mike Broski", works_for=company2) - person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + with ( + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config - person5 = Person(name="Christina Mayer", works_for=company2) - person5.owns = [Car(brand="Honda", model="Civic", year=2023)] - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionContextExtensionRetriever(top_k=20) - - context = await resolve_edges_to_text( - await retriever.get_context("Who works at Figma and drives Tesla?") + # When get_context returns same triplets, the loop should stop early + completion = await retriever.get_completion( + "test query", context=[mock_edge], context_extension_rounds=4 ) - print(context) + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" - assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" - assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" - assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" - answer = await retriever.get_completion("Who works at Figma?") +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_edge): + """Test get_completion with session caching enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" + retriever = GraphCompletionContextExtensionRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.save_conversation_history", + ) as mock_save, + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion( + "test query", session_id="test_session", context_extension_rounds=1 ) - @pytest.mark.asyncio - async def test_get_graph_completion_extension_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph", + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction(mock_edge): + """Test get_completion with save_interaction enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionContextExtensionRetriever(save_interaction=True) + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "test query", context=[mock_edge], context_extension_rounds=1 ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_graph_completion_extension_context_on_empty_graph", + + assert isinstance(completion, list) + assert len(completion) == 1 + mock_add_data.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_edge): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + TestModel(answer="Test answer"), + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "test query", response_model=TestModel, context_extension_rounds=1 ) - cognee.config.data_root_directory(data_directory_path) - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) - retriever = GraphCompletionContextExtensionRetriever() - await setup() +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_edge): + """Test get_completion with session config but no user ID.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) - context = await retriever.get_context("Who works at Figma?") - assert context == [], "Context should be empty on an empty graph" + retriever = GraphCompletionContextExtensionRetriever() - answer = await retriever.get_completion("Who works at Figma?") + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) + completion = await retriever.get_completion("test query", context_extension_rounds=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_zero_extension_rounds(mock_edge): + """Test get_completion with zero context extension rounds.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context_extension_rounds=0) + + assert isinstance(completion, list) + assert len(completion) == 1 diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 206cfaf84..9f3147512 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -1,170 +1,688 @@ -import os import pytest -import pathlib -from typing import Optional, Union +from unittest.mock import AsyncMock, patch, MagicMock +from uuid import UUID -import cognee -from cognee.low_level import setup, DataPoint -from cognee.modules.graph.utils import resolve_edges_to_text -from cognee.tasks.storage import add_data_points -from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever +from cognee.modules.retrieval.graph_completion_cot_retriever import ( + GraphCompletionCotRetriever, + _as_answer_text, +) +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.infrastructure.llm.LLMGateway import LLMGateway -class TestGraphCompletionCoTRetriever: - @pytest.mark.asyncio - async def test_graph_completion_cot_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_cot_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) +@pytest.fixture +def mock_edge(): + """Create a mock edge.""" + edge = MagicMock(spec=Edge) + return edge - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - class Company(DataPoint): - name: str +@pytest.mark.asyncio +async def test_get_triplets_inherited(mock_edge): + """Test that get_triplets is inherited from parent class.""" + retriever = GraphCompletionCotRetriever() - class Person(DataPoint): - name: str - works_for: Company + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ): + triplets = await retriever.get_triplets("test query") - company1 = Company(name="Figma") - company2 = Company(name="Canva") - person1 = Person(name="Steve Rodger", works_for=company1) - person2 = Person(name="Ike Loma", works_for=company1) - person3 = Person(name="Jason Statham", works_for=company1) - person4 = Person(name="Mike Broski", works_for=company2) - person5 = Person(name="Christina Mayer", works_for=company2) + assert len(triplets) == 1 + assert triplets[0] == mock_edge - entities = [company1, company2, person1, person2, person3, person4, person5] - await add_data_points(entities) +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test GraphCompletionCotRetriever initialization with custom parameters.""" + retriever = GraphCompletionCotRetriever( + top_k=10, + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + validation_user_prompt_path="custom_validation_user.txt", + validation_system_prompt_path="custom_validation_system.txt", + followup_system_prompt_path="custom_followup_system.txt", + followup_user_prompt_path="custom_followup_user.txt", + ) - retriever = GraphCompletionCotRetriever() + assert retriever.top_k == 10 + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.validation_user_prompt_path == "custom_validation_user.txt" + assert retriever.validation_system_prompt_path == "custom_validation_system.txt" + assert retriever.followup_system_prompt_path == "custom_followup_system.txt" + assert retriever.followup_user_prompt_path == "custom_followup_user.txt" - context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) - assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" - assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" +@pytest.mark.asyncio +async def test_init_defaults(): + """Test GraphCompletionCotRetriever initialization with defaults.""" + retriever = GraphCompletionCotRetriever() - answer = await retriever.get_completion("Who works at Canva?") + assert retriever.validation_user_prompt_path == "cot_validation_user_prompt.txt" + assert retriever.validation_system_prompt_path == "cot_validation_system_prompt.txt" + assert retriever.followup_system_prompt_path == "cot_followup_system_prompt.txt" + assert retriever.followup_user_prompt_path == "cot_followup_user_prompt.txt" - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" + +@pytest.mark.asyncio +async def test_run_cot_completion_round_zero_with_context(mock_edge): + """Test _run_cot_completion round 0 with provided context.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + max_iter=1, ) - @pytest.mark.asyncio - async def test_graph_completion_cot_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_graph_completion_cot_context_complex", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) + assert completion == "Generated answer" + assert context_text == "Resolved context" + assert len(triplets) >= 1 - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - class Company(DataPoint): - name: str - metadata: dict = {"index_fields": ["name"]} +@pytest.mark.asyncio +async def test_run_cot_completion_round_zero_without_context(mock_edge): + """Test _run_cot_completion round 0 without provided context.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) - class Car(DataPoint): - brand: str - model: str - year: int + retriever = GraphCompletionCotRetriever() - class Location(DataPoint): - country: str - city: str - - class Home(DataPoint): - location: Location - rooms: int - sqm: int - - class Person(DataPoint): - name: str - works_for: Company - owns: Optional[list[Union[Car, Home]]] = None - - company1 = Company(name="Figma") - company2 = Company(name="Canva") - - person1 = Person(name="Mike Rodger", works_for=company1) - person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] - - person2 = Person(name="Ike Loma", works_for=company1) - person2.owns = [ - Car(brand="Tesla", model="Model S", year=2021), - Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), - ] - - person3 = Person(name="Jason Statham", works_for=company1) - - person4 = Person(name="Mike Broski", works_for=company2) - person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] - - person5 = Person(name="Christina Mayer", works_for=company2) - person5.owns = [Car(brand="Honda", model="Civic", year=2023)] - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionCotRetriever(top_k=20) - - context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?")) - - print(context) - - assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" - assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" - assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" - - answer = await retriever.get_completion("Who works at Figma?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=None, + max_iter=1, ) - @pytest.mark.asyncio - async def test_get_graph_completion_cot_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph", + assert completion == "Generated answer" + assert context_text == "Resolved context" + assert len(triplets) >= 1 + + +@pytest.mark.asyncio +async def test_run_cot_completion_multiple_rounds(mock_edge): + """Test _run_cot_completion with multiple rounds.""" + retriever = GraphCompletionCotRetriever() + + mock_edge2 = MagicMock(spec=Edge) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + side_effect=[[mock_edge], [mock_edge2]], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=[ + "validation_result", + "followup_question", + "validation_result2", + "followup_question2", + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + max_iter=2, ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_graph_completion_cot_context_on_empty_graph", + + assert completion == "Generated answer" + assert context_text == "Resolved context" + assert len(triplets) >= 1 + + +@pytest.mark.asyncio +async def test_run_cot_completion_with_conversation_history(mock_edge): + """Test _run_cot_completion with conversation history.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ) as mock_generate, + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + conversation_history="Previous conversation", + max_iter=1, ) - cognee.config.data_root_directory(data_directory_path) - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) + assert completion == "Generated answer" + call_kwargs = mock_generate.call_args[1] + assert call_kwargs.get("conversation_history") == "Previous conversation" - retriever = GraphCompletionCotRetriever() - await setup() +@pytest.mark.asyncio +async def test_run_cot_completion_with_response_model(mock_edge): + """Test _run_cot_completion with custom response model.""" + from pydantic import BaseModel - context = await retriever.get_context("Who works at Figma?") - assert context == [], "Context should be empty on an empty graph" + class TestModel(BaseModel): + answer: str - answer = await retriever.get_completion("Who works at Figma?") + retriever = GraphCompletionCotRetriever() - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + response_model=TestModel, + max_iter=1, ) + + assert isinstance(completion, TestModel) + assert completion.answer == "Test answer" + + +@pytest.mark.asyncio +async def test_run_cot_completion_empty_conversation_history(mock_edge): + """Test _run_cot_completion with empty conversation history.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ) as mock_generate, + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + conversation_history="", + max_iter=1, + ) + + assert completion == "Generated answer" + # Verify conversation_history was passed as None when empty + call_kwargs = mock_generate.call_args[1] + assert call_kwargs.get("conversation_history") is None + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_edge): + """Test get_completion retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_edge): + """Test get_completion uses provided context.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_edge): + """Test get_completion with session caching enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.save_conversation_history", + ) as mock_save, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion( + "test query", session_id="test_session", max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction(mock_edge): + """Test get_completion with save_interaction enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionCotRetriever(save_interaction=True) + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + # Pass context so save_interaction condition is met + completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + mock_add_data.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_edge): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "test query", response_model=TestModel, max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_edge): + """Test get_completion with session config but no user ID.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("test query", max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_no_context(mock_edge): + """Test get_completion with save_interaction but no context provided.""" + retriever = GraphCompletionCotRetriever(save_interaction=True) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=None, max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_as_answer_text_with_typeerror(): + """Test _as_answer_text handles TypeError when json.dumps fails.""" + non_serializable = {1, 2, 3} + + result = _as_answer_text(non_serializable) + + assert isinstance(result, str) + assert result == str(non_serializable) + + +@pytest.mark.asyncio +async def test_as_answer_text_with_string(): + """Test _as_answer_text with string input.""" + result = _as_answer_text("test string") + assert result == "test string" + + +@pytest.mark.asyncio +async def test_as_answer_text_with_dict(): + """Test _as_answer_text with dictionary input.""" + test_dict = {"key": "value", "number": 42} + result = _as_answer_text(test_dict) + assert isinstance(result, str) + assert "key" in result + assert "value" in result + + +@pytest.mark.asyncio +async def test_as_answer_text_with_basemodel(): + """Test _as_answer_text with Pydantic BaseModel input.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + test_model = TestModel(answer="test answer") + result = _as_answer_text(test_model) + + assert isinstance(result, str) + assert "[Structured Response]" in result + assert "test answer" in result diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index f462baced..c22f30fd0 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -1,223 +1,648 @@ -import os import pytest -import pathlib -from typing import Optional, Union +from unittest.mock import AsyncMock, patch, MagicMock +from uuid import UUID -import cognee -from cognee.low_level import setup, DataPoint -from cognee.modules.graph.utils import resolve_edges_to_text -from cognee.tasks.storage import add_data_points from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge -class TestGraphCompletionRetriever: - @pytest.mark.asyncio - async def test_graph_completion_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) +@pytest.fixture +def mock_edge(): + """Create a mock edge.""" + edge = MagicMock(spec=Edge) + return edge - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - class Company(DataPoint): - name: str - description: str +@pytest.mark.asyncio +async def test_get_triplets_success(mock_edge): + """Test successful retrieval of triplets.""" + retriever = GraphCompletionRetriever(top_k=5) - class Person(DataPoint): - name: str - description: str - works_for: Company + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ) as mock_search: + triplets = await retriever.get_triplets("test query") - company1 = Company(name="Figma", description="Figma is a company") - company2 = Company(name="Canva", description="Canvas is a company") - person1 = Person( - name="Steve Rodger", - description="This is description about Steve Rodger", - works_for=company1, - ) - person2 = Person( - name="Ike Loma", description="This is description about Ike Loma", works_for=company1 - ) - person3 = Person( - name="Jason Statham", - description="This is description about Jason Statham", - works_for=company1, - ) - person4 = Person( - name="Mike Broski", - description="This is description about Mike Broski", - works_for=company2, - ) - person5 = Person( - name="Christina Mayer", - description="This is description about Christina Mayer", - works_for=company2, + assert len(triplets) == 1 + assert triplets[0] == mock_edge + mock_search.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_triplets_empty_results(): + """Test that empty list is returned when no triplets are found.""" + retriever = GraphCompletionRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[], + ): + triplets = await retriever.get_triplets("test query") + + assert triplets == [] + + +@pytest.mark.asyncio +async def test_get_triplets_top_k_parameter(): + """Test that top_k parameter is passed to brute_force_triplet_search.""" + retriever = GraphCompletionRetriever(top_k=10) + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[], + ) as mock_search: + await retriever.get_triplets("test query") + + call_kwargs = mock_search.call_args[1] + assert call_kwargs["top_k"] == 10 + + +@pytest.mark.asyncio +async def test_get_context_success(mock_edge): + """Test successful retrieval of context.""" + retriever = GraphCompletionRetriever() + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + ): + context = await retriever.get_context("test query") + + assert isinstance(context, list) + assert len(context) == 1 + assert context[0] == mock_edge + + +@pytest.mark.asyncio +async def test_get_context_empty_results(): + """Test that empty list is returned when no context is found.""" + retriever = GraphCompletionRetriever() + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[], + ), + ): + context = await retriever.get_context("test query") + + assert context == [] + + +@pytest.mark.asyncio +async def test_get_context_empty_graph(): + """Test that empty list is returned when graph is empty.""" + retriever = GraphCompletionRetriever() + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=True) + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ): + context = await retriever.get_context("test query") + + assert context == [] + + +@pytest.mark.asyncio +async def test_resolve_edges_to_text(mock_edge): + """Test resolve_edges_to_text method.""" + retriever = GraphCompletionRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved text", + ) as mock_resolve: + result = await retriever.resolve_edges_to_text([mock_edge]) + + assert result == "Resolved text" + mock_resolve.assert_awaited_once_with([mock_edge]) + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test GraphCompletionRetriever initialization with defaults.""" + retriever = GraphCompletionRetriever() + + assert retriever.top_k == 5 + assert retriever.user_prompt_path == "graph_context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + assert retriever.node_type is None + assert retriever.node_name is None + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test GraphCompletionRetriever initialization with custom parameters.""" + retriever = GraphCompletionRetriever( + top_k=10, + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + node_type=str, + node_name=["node1"], + save_interaction=True, + wide_search_top_k=200, + triplet_distance_penalty=5.0, + ) + + assert retriever.top_k == 10 + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.node_type is str + assert retriever.node_name == ["node1"] + assert retriever.save_interaction is True + assert retriever.wide_search_top_k == 200 + assert retriever.triplet_distance_penalty == 5.0 + + +@pytest.mark.asyncio +async def test_init_none_top_k(): + """Test GraphCompletionRetriever initialization with None top_k.""" + retriever = GraphCompletionRetriever(top_k=None) + + assert retriever.top_k == 5 # None defaults to 5 + + +@pytest.mark.asyncio +async def test_convert_retrieved_objects_to_context(mock_edge): + """Test convert_retrieved_objects_to_context method.""" + retriever = GraphCompletionRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved text", + ) as mock_resolve: + result = await retriever.convert_retrieved_objects_to_context([mock_edge]) + + assert result == "Resolved text" + mock_resolve.assert_awaited_once_with([mock_edge]) + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_edge): + """Test get_completion retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_edge): + """Test get_completion uses provided context.""" + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=[mock_edge]) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_edge): + """Test get_completion with session caching enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.save_conversation_history", + ) as mock_save, + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion("test query", session_id="test_session") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_edge): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", response_model=TestModel) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_empty_context(mock_edge): + """Test get_completion with empty context.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_save_qa(mock_edge): + """Test save_qa method.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionRetriever() + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=["uuid1", "uuid2"], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + ): + await retriever.save_qa( + question="Test question", + answer="Test answer", + context="Test context", + triplets=[mock_edge], ) - entities = [company1, company2, person1, person2, person3, person4, person5] + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_awaited_once() - await add_data_points(entities) - retriever = GraphCompletionRetriever() +@pytest.mark.asyncio +async def test_save_qa_no_triplet_ids(mock_edge): + """Test save_qa when triplets have no extractable IDs.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.add_edges = AsyncMock() - context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) + retriever = GraphCompletionRetriever() - # Ensure the top-level sections are present - assert "Nodes:" in context, "Missing 'Nodes:' section in context" - assert "Connections:" in context, "Missing 'Connections:' section in context" + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 - # --- Nodes headers --- - assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger" - assert "Node: Figma" in context, "Missing node header for Figma" - assert "Node: Ike Loma" in context, "Missing node header for Ike Loma" - assert "Node: Jason Statham" in context, "Missing node header for Jason Statham" - assert "Node: Mike Broski" in context, "Missing node header for Mike Broski" - assert "Node: Canva" in context, "Missing node header for Canva" - assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer" - - # --- Node contents --- - assert ( - "__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__" - in context - ), "Description block for Steve Rodger altered" - assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, ( - "Description block for Figma altered" - ) - assert ( - "__node_content_start__\nThis is description about Ike Loma\n__node_content_end__" - in context - ), "Description block for Ike Loma altered" - assert ( - "__node_content_start__\nThis is description about Jason Statham\n__node_content_end__" - in context - ), "Description block for Jason Statham altered" - assert ( - "__node_content_start__\nThis is description about Mike Broski\n__node_content_end__" - in context - ), "Description block for Mike Broski altered" - assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, ( - "Description block for Canva altered" - ) - assert ( - "__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__" - in context - ), "Description block for Christina Mayer altered" - - # --- Connections --- - assert "Steve Rodger --[works_for]--> Figma" in context, ( - "Connection Steve Rodger→Figma missing or changed" - ) - assert "Ike Loma --[works_for]--> Figma" in context, ( - "Connection Ike Loma→Figma missing or changed" - ) - assert "Jason Statham --[works_for]--> Figma" in context, ( - "Connection Jason Statham→Figma missing or changed" - ) - assert "Mike Broski --[works_for]--> Canva" in context, ( - "Connection Mike Broski→Canva missing or changed" - ) - assert "Christina Mayer --[works_for]--> Canva" in context, ( - "Connection Christina Mayer→Canva missing or changed" + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + return_value=None, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + ): + await retriever.save_qa( + question="Test question", + answer="Test answer", + context="Test context", + triplets=[mock_edge], ) - @pytest.mark.asyncio - async def test_graph_completion_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_complex" + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_not_called() + + +@pytest.mark.asyncio +async def test_save_qa_empty_triplets(): + """Test save_qa with empty triplets list.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + ): + await retriever.save_qa( + question="Test question", + answer="Test answer", + context="Test context", + triplets=[], ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_not_called() - class Company(DataPoint): - name: str - metadata: dict = {"index_fields": ["name"]} - class Car(DataPoint): - brand: str - model: str - year: int +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_no_completion(mock_edge): + """Test get_completion with save_interaction but no completion.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) - class Location(DataPoint): - country: str - city: str + retriever = GraphCompletionRetriever(save_interaction=True) - class Home(DataPoint): - location: Location - rooms: int - sqm: int + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value=None, # No completion + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config - class Person(DataPoint): - name: str - works_for: Company - owns: Optional[list[Union[Car, Home]]] = None + completion = await retriever.get_completion("test query") - company1 = Company(name="Figma") - company2 = Company(name="Canva") + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] is None - person1 = Person(name="Mike Rodger", works_for=company1) - person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] - person2 = Person(name="Ike Loma", works_for=company1) - person2.owns = [ - Car(brand="Tesla", model="Model S", year=2021), - Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), - ] +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_no_context(mock_edge): + """Test get_completion with save_interaction but no context provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) - person3 = Person(name="Jason Statham", works_for=company1) + retriever = GraphCompletionRetriever(save_interaction=True) - person4 = Person(name="Mike Broski", works_for=company2) - person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config - person5 = Person(name="Christina Mayer", works_for=company2) - person5.owns = [Car(brand="Honda", model="Civic", year=2023)] + completion = await retriever.get_completion("test query", context=None) - entities = [company1, company2, person1, person2, person3, person4, person5] + assert isinstance(completion, list) + assert len(completion) == 1 - await add_data_points(entities) - retriever = GraphCompletionRetriever(top_k=20) +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_all_conditions_met(mock_edge): + """Test get_completion with save_interaction when all conditions are met (line 216).""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) - context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?")) + retriever = GraphCompletionRetriever(save_interaction=True) - print(context) + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 - assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" - assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" - assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config - @pytest.mark.asyncio - async def test_get_graph_completion_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_graph_completion_context_on_empty_graph", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_graph_completion_context_on_empty_graph", - ) - cognee.config.data_root_directory(data_directory_path) + completion = await retriever.get_completion("test query", context=[mock_edge]) - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = GraphCompletionRetriever() - - await setup() - - context = await retriever.get_context("Who works at Figma?") - assert context == [], "Context should be empty on an empty graph" + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_add_data.assert_awaited_once() diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py index 9bfed68f3..e998d419d 100644 --- a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py @@ -1,205 +1,321 @@ -import os -from typing import List import pytest -import pathlib -import cognee +from unittest.mock import AsyncMock, patch, MagicMock -from cognee.low_level import setup -from cognee.tasks.storage import add_data_points -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.chunking.models import DocumentChunk -from cognee.modules.data.processing.document_types import TextDocument -from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.completion_retriever import CompletionRetriever -from cognee.infrastructure.engine import DataPoint -from cognee.modules.data.processing.document_types import Document -from cognee.modules.engine.models import Entity +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError -class DocumentChunkWithEntities(DataPoint): - text: str - chunk_size: int - chunk_index: int - cut_type: str - is_part_of: Document - contains: List[Entity] = None - - metadata: dict = {"index_fields": ["text"]} +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.search = AsyncMock() + return engine -class TestRAGCompletionRetriever: - @pytest.mark.asyncio - async def test_rag_completion_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) +@pytest.mark.asyncio +async def test_get_context_success(mock_vector_engine): + """Test successful retrieval of context.""" + mock_result1 = MagicMock() + mock_result1.payload = {"text": "Steve Rodger"} + mock_result2 = MagicMock() + mock_result2.payload = {"text": "Mike Broski"} - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() + mock_vector_engine.search.return_value = [mock_result1, mock_result2] - document = TextDocument( - name="Steve Rodger's career", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) + retriever = CompletionRetriever(top_k=2) - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - entities = [chunk1, chunk2, chunk3] + assert context == "Steve Rodger\nMike Broski" + mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2) - await add_data_points(entities) - retriever = CompletionRetriever() +@pytest.mark.asyncio +async def test_get_context_collection_not_found_error(mock_vector_engine): + """Test that CollectionNotFoundError is converted to NoDataError.""" + mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found") - context = await retriever.get_context("Mike") + retriever = CompletionRetriever() - assert context == "Mike Broski", "Failed to get Mike Broski" + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(NoDataError, match="No data found"): + await retriever.get_context("test query") - @pytest.mark.asyncio - async def test_rag_completion_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_complex" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() +@pytest.mark.asyncio +async def test_get_context_empty_results(mock_vector_engine): + """Test that empty string is returned when no chunks are found.""" + mock_vector_engine.search.return_value = [] - document1 = TextDocument( - name="Employee List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) + retriever = CompletionRetriever() - document2 = TextDocument( - name="Car List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) + assert context == "" - chunk4 = DocumentChunk( - text="Range Rover", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk5 = DocumentChunk( - text="Hyundai", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk6 = DocumentChunk( - text="Chrysler", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6] +@pytest.mark.asyncio +async def test_get_context_top_k_limit(mock_vector_engine): + """Test that top_k parameter limits the number of results.""" + mock_results = [MagicMock() for _ in range(2)] + for i, result in enumerate(mock_results): + result.payload = {"text": f"Chunk {i}"} - await add_data_points(entities) + mock_vector_engine.search.return_value = mock_results - # TODO: top_k doesn't affect the output, it should be fixed. - retriever = CompletionRetriever(top_k=20) + retriever = CompletionRetriever(top_k=2) - context = await retriever.get_context("Christina") + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer" + assert context == "Chunk 0\nChunk 1" + mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2) - @pytest.mark.asyncio - async def test_get_rag_completion_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_rag_completion_context_on_empty_graph", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_rag_completion_context_on_empty_graph", - ) - cognee.config.data_root_directory(data_directory_path) - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) +@pytest.mark.asyncio +async def test_get_context_single_chunk(mock_vector_engine): + """Test get_context with single chunk result.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Single chunk text"} + mock_vector_engine.search.return_value = [mock_result] - retriever = CompletionRetriever() + retriever = CompletionRetriever() - with pytest.raises(NoDataError): - await retriever.get_context("Christina Mayer") + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - vector_engine = get_vector_engine() - await vector_engine.create_collection( - "DocumentChunk_text", payload_schema=DocumentChunkWithEntities - ) + assert context == "Single chunk text" - context = await retriever.get_context("Christina Mayer") - assert context == "", "Returned context should be empty on an empty graph" + +@pytest.mark.asyncio +async def test_get_completion_without_session(mock_vector_engine): + """Test get_completion without session caching.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_vector_engine): + """Test get_completion with provided context.""" + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context="Provided context") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_vector_engine): + """Test get_completion with session caching enabled.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.completion_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.completion_retriever.save_conversation_history", + ) as mock_save, + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.completion_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion("test query", session_id="test_session") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_vector_engine): + """Test get_completion with session config but no user ID.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.completion_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_vector_engine): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", response_model=TestModel) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test CompletionRetriever initialization with defaults.""" + retriever = CompletionRetriever() + + assert retriever.user_prompt_path == "context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + assert retriever.top_k == 1 + assert retriever.system_prompt is None + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test CompletionRetriever initialization with custom parameters.""" + retriever = CompletionRetriever( + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + top_k=10, + ) + + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.top_k == 10 + + +@pytest.mark.asyncio +async def test_get_context_missing_text_key(mock_vector_engine): + """Test get_context handles missing text key in payload.""" + mock_result = MagicMock() + mock_result.payload = {"other_key": "value"} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(KeyError): + await retriever.get_context("test query") diff --git a/cognee/tests/unit/modules/retrieval/structured_output_test.py b/cognee/tests/unit/modules/retrieval/structured_output_test.py deleted file mode 100644 index 4ad3019ff..000000000 --- a/cognee/tests/unit/modules/retrieval/structured_output_test.py +++ /dev/null @@ -1,204 +0,0 @@ -import asyncio - -import pytest -import cognee -import pathlib -import os - -from pydantic import BaseModel -from cognee.low_level import setup, DataPoint -from cognee.tasks.storage import add_data_points -from cognee.modules.chunking.models import DocumentChunk -from cognee.modules.data.processing.document_types import TextDocument -from cognee.modules.engine.models import Entity, EntityType -from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor -from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider -from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever -from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever -from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( - GraphCompletionContextExtensionRetriever, -) -from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever -from cognee.modules.retrieval.temporal_retriever import TemporalRetriever -from cognee.modules.retrieval.completion_retriever import CompletionRetriever - - -class TestAnswer(BaseModel): - answer: str - explanation: str - - -def _assert_string_answer(answer: list[str]): - assert isinstance(answer, list), f"Expected str, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), "Items should be strings" - assert all(item.strip() for item in answer), "Items should not be empty" - - -def _assert_structured_answer(answer: list[TestAnswer]): - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(x, TestAnswer) for x in answer), "Items should be TestAnswer" - assert all(x.answer.strip() for x in answer), "Answer text should not be empty" - assert all(x.explanation.strip() for x in answer), "Explanation should not be empty" - - -async def _test_get_structured_graph_completion_cot(): - retriever = GraphCompletionCotRetriever() - - # Test with string response model (default) - string_answer = await retriever.get_completion("Who works at Figma?") - _assert_string_answer(string_answer) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "Who works at Figma?", response_model=TestAnswer - ) - _assert_structured_answer(structured_answer) - - -async def _test_get_structured_graph_completion(): - retriever = GraphCompletionRetriever() - - # Test with string response model (default) - string_answer = await retriever.get_completion("Who works at Figma?") - _assert_string_answer(string_answer) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "Who works at Figma?", response_model=TestAnswer - ) - _assert_structured_answer(structured_answer) - - -async def _test_get_structured_graph_completion_temporal(): - retriever = TemporalRetriever() - - # Test with string response model (default) - string_answer = await retriever.get_completion("When did Steve start working at Figma?") - _assert_string_answer(string_answer) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "When did Steve start working at Figma??", response_model=TestAnswer - ) - _assert_structured_answer(structured_answer) - - -async def _test_get_structured_graph_completion_rag(): - retriever = CompletionRetriever() - - # Test with string response model (default) - string_answer = await retriever.get_completion("Where does Steve work?") - _assert_string_answer(string_answer) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "Where does Steve work?", response_model=TestAnswer - ) - _assert_structured_answer(structured_answer) - - -async def _test_get_structured_graph_completion_context_extension(): - retriever = GraphCompletionContextExtensionRetriever() - - # Test with string response model (default) - string_answer = await retriever.get_completion("Who works at Figma?") - _assert_string_answer(string_answer) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "Who works at Figma?", response_model=TestAnswer - ) - _assert_structured_answer(structured_answer) - - -async def _test_get_structured_entity_completion(): - retriever = EntityCompletionRetriever(DummyEntityExtractor(), DummyContextProvider()) - - # Test with string response model (default) - string_answer = await retriever.get_completion("Who is Albert Einstein?") - _assert_string_answer(string_answer) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "Who is Albert Einstein?", response_model=TestAnswer - ) - _assert_structured_answer(structured_answer) - - -class TestStructuredOutputCompletion: - @pytest.mark.asyncio - async def test_get_structured_completion(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - - class Person(DataPoint): - name: str - works_for: Company - works_since: int - - company1 = Company(name="Figma") - person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015) - - entities = [company1, person1] - await add_data_points(entities) - - document = TextDocument( - name="Steve Rodger's career", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - - entities = [chunk1, chunk2, chunk3] - await add_data_points(entities) - - entity_type = EntityType(name="Person", description="A human individual") - entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist") - - entities = [entity] - await add_data_points(entities) - - await _test_get_structured_graph_completion_cot() - await _test_get_structured_graph_completion() - await _test_get_structured_graph_completion_temporal() - await _test_get_structured_graph_completion_rag() - await _test_get_structured_graph_completion_context_extension() - await _test_get_structured_entity_completion() diff --git a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py index 5f4b93425..e552ac74a 100644 --- a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py @@ -1,159 +1,193 @@ -import os import pytest -import pathlib +from unittest.mock import AsyncMock, patch, MagicMock -import cognee -from cognee.low_level import setup -from cognee.tasks.storage import add_data_points -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.chunking.models import DocumentChunk -from cognee.tasks.summarization.models import TextSummary -from cognee.modules.data.processing.document_types import TextDocument -from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.summaries_retriever import SummariesRetriever +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError -class TestSummariesRetriever: - @pytest.mark.asyncio - async def test_chunk_context(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context" - ) - cognee.config.data_root_directory(data_directory_path) +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.search = AsyncMock() + return engine - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - document1 = TextDocument( - name="Employee List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) +@pytest.mark.asyncio +async def test_get_context_success(mock_vector_engine): + """Test successful retrieval of summary context.""" + mock_result1 = MagicMock() + mock_result1.payload = {"text": "S.R.", "made_from": "chunk1"} + mock_result2 = MagicMock() + mock_result2.payload = {"text": "M.B.", "made_from": "chunk2"} - document2 = TextDocument( - name="Car List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) + mock_vector_engine.search.return_value = [mock_result1, mock_result2] - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk1_summary = TextSummary( - text="S.R.", - made_from=chunk1, - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk2_summary = TextSummary( - text="M.B.", - made_from=chunk2, - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk3_summary = TextSummary( - text="C.M.", - made_from=chunk3, - ) - chunk4 = DocumentChunk( - text="Range Rover", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk4_summary = TextSummary( - text="R.R.", - made_from=chunk4, - ) - chunk5 = DocumentChunk( - text="Hyundai", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk5_summary = TextSummary( - text="H.Y.", - made_from=chunk5, - ) - chunk6 = DocumentChunk( - text="Chrysler", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk6_summary = TextSummary( - text="C.H.", - made_from=chunk6, - ) + retriever = SummariesRetriever(top_k=5) - entities = [ - chunk1_summary, - chunk2_summary, - chunk3_summary, - chunk4_summary, - chunk5_summary, - chunk6_summary, - ] + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - await add_data_points(entities) + assert len(context) == 2 + assert context[0]["text"] == "S.R." + assert context[1]["text"] == "M.B." + mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=5) - retriever = SummariesRetriever(top_k=20) - context = await retriever.get_context("Christina") +@pytest.mark.asyncio +async def test_get_context_collection_not_found_error(mock_vector_engine): + """Test that CollectionNotFoundError is converted to NoDataError.""" + mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found") - assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer" + retriever = SummariesRetriever() - @pytest.mark.asyncio - async def test_chunk_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph" - ) - cognee.config.data_root_directory(data_directory_path) + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(NoDataError, match="No data found"): + await retriever.get_context("test query") - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - retriever = SummariesRetriever() +@pytest.mark.asyncio +async def test_get_context_empty_results(mock_vector_engine): + """Test that empty list is returned when no summaries are found.""" + mock_vector_engine.search.return_value = [] - with pytest.raises(NoDataError): - await retriever.get_context("Christina Mayer") + retriever = SummariesRetriever() - vector_engine = get_vector_engine() - await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary) + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - context = await retriever.get_context("Christina Mayer") - assert context == [], "Returned context should be empty on an empty graph" + assert context == [] + + +@pytest.mark.asyncio +async def test_get_context_top_k_limit(mock_vector_engine): + """Test that top_k parameter limits the number of results.""" + mock_results = [MagicMock() for _ in range(3)] + for i, result in enumerate(mock_results): + result.payload = {"text": f"Summary {i}"} + + mock_vector_engine.search.return_value = mock_results + + retriever = SummariesRetriever(top_k=3) + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 3 + mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=3) + + +@pytest.mark.asyncio +async def test_get_completion_with_context(mock_vector_engine): + """Test get_completion returns provided context.""" + retriever = SummariesRetriever() + + provided_context = [{"text": "S.R."}, {"text": "M.B."}] + completion = await retriever.get_completion("test query", context=provided_context) + + assert completion == provided_context + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_vector_engine): + """Test get_completion retrieves context when not provided.""" + mock_result = MagicMock() + mock_result.payload = {"text": "S.R."} + mock_vector_engine.search.return_value = [mock_result] + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query") + + assert len(completion) == 1 + assert completion[0]["text"] == "S.R." + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test SummariesRetriever initialization with defaults.""" + retriever = SummariesRetriever() + + assert retriever.top_k == 5 + + +@pytest.mark.asyncio +async def test_init_custom_top_k(): + """Test SummariesRetriever initialization with custom top_k.""" + retriever = SummariesRetriever(top_k=10) + + assert retriever.top_k == 10 + + +@pytest.mark.asyncio +async def test_get_context_empty_payload(mock_vector_engine): + """Test get_context handles empty payload.""" + mock_result = MagicMock() + mock_result.payload = {} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 1 + assert context[0] == {} + + +@pytest.mark.asyncio +async def test_get_completion_with_session_id(mock_vector_engine): + """Test get_completion with session_id parameter.""" + mock_result = MagicMock() + mock_result.payload = {"text": "S.R."} + mock_vector_engine.search.return_value = [mock_result] + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query", session_id="test_session") + + assert len(completion) == 1 + assert completion[0]["text"] == "S.R." + + +@pytest.mark.asyncio +async def test_get_completion_with_kwargs(mock_vector_engine): + """Test get_completion accepts additional kwargs.""" + mock_result = MagicMock() + mock_result.payload = {"text": "S.R."} + mock_vector_engine.search.return_value = [mock_result] + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query", extra_param="value") + + assert len(completion) == 1 diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py index c3c6a47f6..1d2f4c84d 100644 --- a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py @@ -1,7 +1,12 @@ from types import SimpleNamespace import pytest +import os +from unittest.mock import AsyncMock, patch, MagicMock +from datetime import datetime from cognee.modules.retrieval.temporal_retriever import TemporalRetriever +from cognee.tasks.temporal_graph.models import QueryInterval, Timestamp +from cognee.infrastructure.llm import LLMGateway # Test TemporalRetriever initialization defaults and overrides @@ -140,85 +145,561 @@ async def test_filter_top_k_events_error_handling(): await tr.filter_top_k_events([{}], []) -class _FakeRetriever(TemporalRetriever): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._calls = [] +@pytest.fixture +def mock_graph_engine(): + """Create a mock graph engine.""" + engine = AsyncMock() + engine.collect_time_ids = AsyncMock() + engine.collect_events = AsyncMock() + return engine - async def extract_time_from_query(self, query: str): - if "both" in query: + +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.embedding_engine = AsyncMock() + engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + engine.search = AsyncMock() + return engine + + +@pytest.mark.asyncio +async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine): + """Test get_context when time range is extracted from query.""" + retriever = TemporalRetriever(top_k=5) + + mock_graph_engine.collect_time_ids.return_value = ["e1", "e2"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + {"id": "e2", "description": "Event 2"}, + ] + } + ] + + mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05) + mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10) + mock_vector_engine.search.return_value = [mock_result1, mock_result2] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + context = await retriever.get_context("What happened in 2024?") + + assert isinstance(context, str) + assert len(context) > 0 + assert "Event" in context + + +@pytest.mark.asyncio +async def test_get_context_fallback_to_triplets_no_time(mock_graph_engine): + """Test get_context falls back to triplets when no time is extracted.""" + retriever = TemporalRetriever() + + with ( + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}] + ) as mock_get_triplets, + patch.object( + retriever, "resolve_edges_to_text", return_value="triplet text" + ) as mock_resolve, + ): + + async def mock_extract_time(query): + return None, None + + retriever.extract_time_from_query = mock_extract_time + + context = await retriever.get_context("test query") + + assert context == "triplet text" + mock_get_triplets.assert_awaited_once_with("test query") + mock_resolve.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_context_no_events_found(mock_graph_engine): + """Test get_context falls back to triplets when no events are found.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = [] + + with ( + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}] + ) as mock_get_triplets, + patch.object( + retriever, "resolve_edges_to_text", return_value="triplet text" + ) as mock_resolve, + ): + + async def mock_extract_time(query): return "2024-01-01", "2024-12-31" - if "from_only" in query: - return "2024-01-01", None - if "to_only" in query: - return None, "2024-12-31" - return None, None - async def get_triplets(self, query: str): - self._calls.append(("get_triplets", query)) - return [{"s": "a", "p": "b", "o": "c"}] + retriever.extract_time_from_query = mock_extract_time - async def resolve_edges_to_text(self, triplets): - self._calls.append(("resolve_edges_to_text", len(triplets))) - return "edges->text" + context = await retriever.get_context("test query") - async def _fake_graph_collect_ids(self, **kwargs): - return ["e1", "e2"] + assert context == "triplet text" + mock_get_triplets.assert_awaited_once_with("test query") + mock_resolve.assert_awaited_once() - async def _fake_graph_collect_events(self, ids): - return [ + +@pytest.mark.asyncio +async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine): + """Test get_context with only time_from.""" + retriever = TemporalRetriever(top_k=5) + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object(retriever, "extract_time_from_query", return_value=("2024-01-01", None)), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + context = await retriever.get_context("What happened after 2024?") + + assert isinstance(context, str) + assert "Event 1" in context + + +@pytest.mark.asyncio +async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine): + """Test get_context with only time_to.""" + retriever = TemporalRetriever(top_k=5) + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object(retriever, "extract_time_from_query", return_value=(None, "2024-12-31")), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + context = await retriever.get_context("What happened before 2024?") + + assert isinstance(context, str) + assert "Event 1" in context + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_graph_engine, mock_vector_engine): + """Test get_completion retrieves context when not provided.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("What happened in 2024?") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(): + """Test get_completion uses provided context.""" + retriever = TemporalRetriever() + + with ( + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context="Provided context") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine): + """Test get_completion with session caching enabled.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.temporal_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.temporal_retriever.save_conversation_history", + ) as mock_save, + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion( + "What happened in 2024?", session_id="test_session" + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_vector_engine): + """Test get_completion with session config but no user ID.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("What happened in 2024?") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_context_retrieved_but_empty(mock_graph_engine): + """Test get_completion when get_context returns empty string.""" + retriever = TemporalRetriever() + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + ) as mock_get_vector, + patch.object(retriever, "filter_top_k_events", return_value=[]), + ): + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + mock_get_vector.return_value = mock_vector_engine + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ { "events": [ - {"id": "e1", "description": "E1"}, - {"id": "e2", "description": "E2"}, - {"id": "e3", "description": "E3"}, + {"id": "e1", "description": ""}, ] } ] - async def _fake_vector_embed(self, texts): - assert isinstance(texts, list) and texts - return [[0.0, 1.0, 2.0]] + with pytest.raises((UnboundLocalError, NameError)): + await retriever.get_completion("test query") - async def _fake_vector_search(self, **kwargs): - return [ - SimpleNamespace(payload={"id": "e2"}, score=0.05), - SimpleNamespace(payload={"id": "e1"}, score=0.10), - ] - async def get_context(self, query: str): - time_from, time_to = await self.extract_time_from_query(query) +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_graph_engine, mock_vector_engine): + """Test get_completion with custom response model.""" + from pydantic import BaseModel - if not (time_from or time_to): - triplets = await self.get_triplets(query) - return await self.resolve_edges_to_text(triplets) + class TestModel(BaseModel): + answer: str - ids = await self._fake_graph_collect_ids(time_from=time_from, time_to=time_to) - relevant_events = await self._fake_graph_collect_events(ids) + retriever = TemporalRetriever() - _ = await self._fake_vector_embed([query]) - vector_search_results = await self._fake_vector_search( - collection_name="Event_name", query_vector=[0.0], limit=0 + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "What happened in 2024?", response_model=TestModel ) - top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results) - return self.descriptions_to_string(top_k_events) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) -# Test get_context fallback to triplets when no time is extracted @pytest.mark.asyncio -async def test_fake_get_context_falls_back_to_triplets_when_no_time(): - tr = _FakeRetriever(top_k=2) - ctx = await tr.get_context("no_time") - assert ctx == "edges->text" - assert tr._calls[0][0] == "get_triplets" - assert tr._calls[1][0] == "resolve_edges_to_text" +async def test_extract_time_from_query_relative_path(): + """Test extract_time_from_query with relative prompt path.""" + retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt") + + mock_timestamp_from = Timestamp(year=2024, month=1, day=1) + mock_timestamp_to = Timestamp(year=2024, month=12, day=31) + mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to) + + with ( + patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False), + patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, + patch( + "cognee.modules.retrieval.temporal_retriever.render_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_interval, + ), + ): + mock_datetime.now.return_value.strftime.return_value = "11-12-2024" + + time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?") + + assert time_from == mock_timestamp_from + assert time_to == mock_timestamp_to -# Test get_context when time is extracted and vector ranking is applied @pytest.mark.asyncio -async def test_fake_get_context_with_time_filters_and_vector_ranking(): - tr = _FakeRetriever(top_k=2) - ctx = await tr.get_context("both time") - assert ctx.startswith("E2") - assert "#####################" in ctx - assert "E1" in ctx and "E3" not in ctx +async def test_extract_time_from_query_absolute_path(): + """Test extract_time_from_query with absolute prompt path.""" + retriever = TemporalRetriever( + time_extraction_prompt_path="/absolute/path/to/extract_query_time.txt" + ) + + mock_timestamp_from = Timestamp(year=2024, month=1, day=1) + mock_timestamp_to = Timestamp(year=2024, month=12, day=31) + mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to) + + with ( + patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=True), + patch( + "cognee.modules.retrieval.temporal_retriever.os.path.dirname", + return_value="/absolute/path/to", + ), + patch( + "cognee.modules.retrieval.temporal_retriever.os.path.basename", + return_value="extract_query_time.txt", + ), + patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, + patch( + "cognee.modules.retrieval.temporal_retriever.render_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_interval, + ), + ): + mock_datetime.now.return_value.strftime.return_value = "11-12-2024" + + time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?") + + assert time_from == mock_timestamp_from + assert time_to == mock_timestamp_to + + +@pytest.mark.asyncio +async def test_extract_time_from_query_with_none_values(): + """Test extract_time_from_query when interval has None values.""" + retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt") + + mock_interval = QueryInterval(starts_at=None, ends_at=None) + + with ( + patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False), + patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, + patch( + "cognee.modules.retrieval.temporal_retriever.render_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_interval, + ), + ): + mock_datetime.now.return_value.strftime.return_value = "11-12-2024" + + time_from, time_to = await retriever.extract_time_from_query("What happened?") + + assert time_from is None + assert time_to is None diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py index 3dc9f38d9..b7cbe08d7 100644 --- a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py @@ -1,12 +1,14 @@ import pytest -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, patch, MagicMock from cognee.modules.retrieval.utils.brute_force_triplet_search import ( brute_force_triplet_search, get_memory_fragment, + format_triplets, ) from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError +from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError class MockScoredResult: @@ -354,20 +356,30 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation @pytest.mark.asyncio async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found(): - """Test that get_memory_fragment returns empty graph when entity not found.""" + """Test that get_memory_fragment returns empty graph when entity not found (line 85).""" mock_graph_engine = AsyncMock() - mock_graph_engine.project_graph_from_db = AsyncMock( + + # Create a mock fragment that will raise EntityNotFoundError when project_graph_from_db is called + mock_fragment = MagicMock(spec=CogneeGraph) + mock_fragment.project_graph_from_db = AsyncMock( side_effect=EntityNotFoundError("Entity not found") ) - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", - return_value=mock_graph_engine, + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.CogneeGraph", + return_value=mock_fragment, + ), ): - fragment = await get_memory_fragment() + result = await get_memory_fragment() - assert isinstance(fragment, CogneeGraph) - assert len(fragment.nodes) == 0 + # Fragment should be returned even though EntityNotFoundError was raised (pass statement on line 85) + assert result == mock_fragment + mock_fragment.project_graph_from_db.assert_awaited_once() @pytest.mark.asyncio @@ -606,3 +618,200 @@ async def test_brute_force_triplet_search_mixed_empty_collections(): call_kwargs = mock_get_fragment_fn.call_args[1] assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} + + +def test_format_triplets(): + """Test format_triplets function.""" + mock_edge = MagicMock() + mock_node1 = MagicMock() + mock_node2 = MagicMock() + + mock_node1.attributes = {"name": "Node1", "type": "Entity", "id": "n1"} + mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": "n2"} + mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": "connects"} + + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + result = format_triplets([mock_edge]) + + assert isinstance(result, str) + assert "Node1" in result + assert "Node2" in result + assert "relates_to" in result + assert "connects" in result + + +def test_format_triplets_with_none_values(): + """Test format_triplets filters out None values.""" + mock_edge = MagicMock() + mock_node1 = MagicMock() + mock_node2 = MagicMock() + + mock_node1.attributes = {"name": "Node1", "type": None, "id": "n1"} + mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": None} + mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": None} + + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + result = format_triplets([mock_edge]) + + assert "Node1" in result + assert "Node2" in result + assert "relates_to" in result + assert "None" not in result or result.count("None") == 0 + + +def test_format_triplets_with_nested_dict(): + """Test format_triplets handles nested dict attributes (lines 23-35).""" + mock_edge = MagicMock() + mock_node1 = MagicMock() + mock_node2 = MagicMock() + + mock_node1.attributes = {"name": "Node1", "metadata": {"type": "Entity", "id": "n1"}} + mock_node2.attributes = {"name": "Node2", "metadata": {"type": "Entity", "id": "n2"}} + mock_edge.attributes = {"relationship_name": "relates_to"} + + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + result = format_triplets([mock_edge]) + + assert isinstance(result, str) + assert "Node1" in result + assert "Node2" in result + assert "relates_to" in result + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_vector_engine_init_error(): + """Test brute_force_triplet_search handles vector engine initialization error (lines 145-147).""" + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine" + ) as mock_get_vector_engine, + ): + mock_get_vector_engine.side_effect = Exception("Initialization error") + + with pytest.raises(RuntimeError, match="Initialization error"): + await brute_force_triplet_search(query="test query") + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_collection_not_found_error(): + """Test brute_force_triplet_search handles CollectionNotFoundError in search (lines 156-157).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_vector_engine.search = AsyncMock( + side_effect=[ + CollectionNotFoundError("Collection not found"), + [], + [], + ] + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=CogneeGraph(), + ), + ): + result = await brute_force_triplet_search( + query="test query", collections=["missing_collection", "existing_collection"] + ) + + assert result == [] + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_generic_exception(): + """Test brute_force_triplet_search handles generic exceptions (lines 209-217).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_vector_engine.search = AsyncMock(side_effect=Exception("Generic error")) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + with pytest.raises(Exception, match="Generic error"): + await brute_force_triplet_search(query="test query") + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_none(): + """Test brute_force_triplet_search sets relevant_ids_to_filter to None when node_name is provided (line 191).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"}) + mock_vector_engine.search = AsyncMock(return_value=[mock_result]) + + mock_fragment = AsyncMock() + mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock() + mock_fragment.map_vector_distances_to_graph_edges = AsyncMock() + mock_fragment.calculate_top_triplet_importances = AsyncMock(return_value=[]) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment, + ): + await brute_force_triplet_search(query="test query", node_name=["Node1"]) + + assert mock_get_fragment.called + call_kwargs = mock_get_fragment.call_args.kwargs if mock_get_fragment.call_args else {} + assert call_kwargs.get("relevant_ids_to_filter") is None + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_collection_not_found_at_top_level(): + """Test brute_force_triplet_search handles CollectionNotFoundError at top level (line 210).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"}) + mock_vector_engine.search = AsyncMock(return_value=[mock_result]) + + mock_fragment = AsyncMock() + mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock() + mock_fragment.map_vector_distances_to_graph_edges = AsyncMock() + mock_fragment.calculate_top_triplet_importances = AsyncMock( + side_effect=CollectionNotFoundError("Collection not found") + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ), + ): + result = await brute_force_triplet_search(query="test query") + + assert result == [] diff --git a/cognee/tests/unit/modules/retrieval/test_completion.py b/cognee/tests/unit/modules/retrieval/test_completion.py new file mode 100644 index 000000000..9a836c2cc --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_completion.py @@ -0,0 +1,343 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from typing import Type + + +class TestGenerateCompletion: + @pytest.mark.asyncio + async def test_generate_completion_with_system_prompt(self): + """Test generate_completion with provided system_prompt.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + system_prompt="Custom system prompt", + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt="Custom system prompt", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_generate_completion_without_system_prompt(self): + """Test generate_completion reads system_prompt from file when not provided.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt="System prompt from file", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_generate_completion_with_conversation_history(self): + """Test generate_completion includes conversation_history in system_prompt.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + conversation_history="Previous conversation:\nQ: What is ML?\nA: ML is machine learning", + ) + + assert result == mock_llm_response + expected_system_prompt = ( + "Previous conversation:\nQ: What is ML?\nA: ML is machine learning" + + "\nTASK:" + + "System prompt from file" + ) + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt=expected_system_prompt, + response_model=str, + ) + + @pytest.mark.asyncio + async def test_generate_completion_with_conversation_history_and_custom_system_prompt(self): + """Test generate_completion includes conversation_history with custom system_prompt.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + system_prompt="Custom system prompt", + conversation_history="Previous conversation:\nQ: What is ML?\nA: ML is machine learning", + ) + + assert result == mock_llm_response + expected_system_prompt = ( + "Previous conversation:\nQ: What is ML?\nA: ML is machine learning" + + "\nTASK:" + + "Custom system prompt" + ) + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt=expected_system_prompt, + response_model=str, + ) + + @pytest.mark.asyncio + async def test_generate_completion_with_response_model(self): + """Test generate_completion with custom response_model.""" + mock_response_model = MagicMock() + mock_llm_response = {"answer": "Generated answer"} + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + response_model=mock_response_model, + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt="System prompt from file", + response_model=mock_response_model, + ) + + @pytest.mark.asyncio + async def test_generate_completion_render_prompt_args(self): + """Test generate_completion passes correct args to render_prompt.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ) as mock_render, + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ), + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + ) + + mock_render.assert_called_once_with( + "user_prompt.txt", + {"question": "What is AI?", "context": "AI is artificial intelligence"}, + ) + + +class TestSummarizeText: + @pytest.mark.asyncio + async def test_summarize_text_with_system_prompt(self): + """Test summarize_text with provided system_prompt.""" + mock_llm_response = "Summary text" + + with patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm: + from cognee.modules.retrieval.utils.completion import summarize_text + + result = await summarize_text( + text="Long text to summarize", + system_prompt_path="summarize_search_results.txt", + system_prompt="Custom summary prompt", + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="Long text to summarize", + system_prompt="Custom summary prompt", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_summarize_text_without_system_prompt(self): + """Test summarize_text reads system_prompt from file when not provided.""" + mock_llm_response = "Summary text" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import summarize_text + + result = await summarize_text( + text="Long text to summarize", + system_prompt_path="summarize_search_results.txt", + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="Long text to summarize", + system_prompt="System prompt from file", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_summarize_text_default_prompt_path(self): + """Test summarize_text uses default prompt path when not provided.""" + mock_llm_response = "Summary text" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="Default system prompt", + ) as mock_read, + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import summarize_text + + result = await summarize_text(text="Long text to summarize") + + assert result == mock_llm_response + mock_read.assert_called_once_with("summarize_search_results.txt") + mock_llm.assert_awaited_once_with( + text_input="Long text to summarize", + system_prompt="Default system prompt", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_summarize_text_custom_prompt_path(self): + """Test summarize_text uses custom prompt path when provided.""" + mock_llm_response = "Summary text" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="Custom system prompt", + ) as mock_read, + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import summarize_text + + result = await summarize_text( + text="Long text to summarize", + system_prompt_path="custom_prompt.txt", + ) + + assert result == mock_llm_response + mock_read.assert_called_once_with("custom_prompt.txt") + mock_llm.assert_awaited_once_with( + text_input="Long text to summarize", + system_prompt="Custom system prompt", + response_model=str, + ) diff --git a/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py b/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py new file mode 100644 index 000000000..2af10da5e --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py @@ -0,0 +1,157 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from cognee.modules.retrieval.graph_summary_completion_retriever import ( + GraphSummaryCompletionRetriever, +) +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge + + +@pytest.fixture +def mock_edge(): + """Create a mock edge.""" + edge = MagicMock(spec=Edge) + return edge + + +class TestGraphSummaryCompletionRetriever: + @pytest.mark.asyncio + async def test_init_defaults(self): + """Test GraphSummaryCompletionRetriever initialization with defaults.""" + retriever = GraphSummaryCompletionRetriever() + + assert retriever.summarize_prompt_path == "summarize_search_results.txt" + assert retriever.user_prompt_path == "graph_context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + assert retriever.top_k == 5 + assert retriever.save_interaction is False + + @pytest.mark.asyncio + async def test_init_custom_params(self): + """Test GraphSummaryCompletionRetriever initialization with custom parameters.""" + retriever = GraphSummaryCompletionRetriever( + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + summarize_prompt_path="custom_summarize.txt", + system_prompt="Custom system prompt", + top_k=10, + save_interaction=True, + wide_search_top_k=200, + triplet_distance_penalty=2.5, + ) + + assert retriever.summarize_prompt_path == "custom_summarize.txt" + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.top_k == 10 + assert retriever.save_interaction is True + + @pytest.mark.asyncio + async def test_resolve_edges_to_text_calls_super_and_summarizes(self, mock_edge): + """Test resolve_edges_to_text calls super method and then summarizes.""" + retriever = GraphSummaryCompletionRetriever( + summarize_prompt_path="custom_summarize.txt", + system_prompt="Custom system prompt", + ) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", + new_callable=AsyncMock, + return_value="Resolved edges text", + ) as mock_super_resolve, + patch( + "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", + new_callable=AsyncMock, + return_value="Summarized text", + ) as mock_summarize, + ): + result = await retriever.resolve_edges_to_text([mock_edge]) + + assert result == "Summarized text" + mock_super_resolve.assert_awaited_once_with([mock_edge]) + mock_summarize.assert_awaited_once_with( + "Resolved edges text", + "custom_summarize.txt", + "Custom system prompt", + ) + + @pytest.mark.asyncio + async def test_resolve_edges_to_text_with_default_system_prompt(self, mock_edge): + """Test resolve_edges_to_text uses None for system_prompt when not provided.""" + retriever = GraphSummaryCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", + new_callable=AsyncMock, + return_value="Resolved edges text", + ), + patch( + "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", + new_callable=AsyncMock, + return_value="Summarized text", + ) as mock_summarize, + ): + await retriever.resolve_edges_to_text([mock_edge]) + + mock_summarize.assert_awaited_once_with( + "Resolved edges text", + "summarize_search_results.txt", + None, + ) + + @pytest.mark.asyncio + async def test_resolve_edges_to_text_with_empty_edges(self): + """Test resolve_edges_to_text handles empty edges list.""" + retriever = GraphSummaryCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", + new_callable=AsyncMock, + return_value="", + ), + patch( + "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", + new_callable=AsyncMock, + return_value="Empty summary", + ) as mock_summarize, + ): + result = await retriever.resolve_edges_to_text([]) + + assert result == "Empty summary" + mock_summarize.assert_awaited_once_with( + "", + "summarize_search_results.txt", + None, + ) + + @pytest.mark.asyncio + async def test_resolve_edges_to_text_with_multiple_edges(self, mock_edge): + """Test resolve_edges_to_text handles multiple edges.""" + retriever = GraphSummaryCompletionRetriever() + + mock_edge2 = MagicMock(spec=Edge) + mock_edge3 = MagicMock(spec=Edge) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", + new_callable=AsyncMock, + return_value="Multiple edges resolved text", + ), + patch( + "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", + new_callable=AsyncMock, + return_value="Multiple edges summarized", + ) as mock_summarize, + ): + result = await retriever.resolve_edges_to_text([mock_edge, mock_edge2, mock_edge3]) + + assert result == "Multiple edges summarized" + mock_summarize.assert_awaited_once_with( + "Multiple edges resolved text", + "summarize_search_results.txt", + None, + ) diff --git a/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py b/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py new file mode 100644 index 000000000..a1e746bb9 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py @@ -0,0 +1,312 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from uuid import UUID, NAMESPACE_OID, uuid5 + +from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback +from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation, UserFeedbackSentiment +from cognee.modules.engine.models import NodeSet + + +@pytest.fixture +def mock_feedback_evaluation(): + """Create a mock feedback evaluation.""" + evaluation = MagicMock(spec=UserFeedbackEvaluation) + evaluation.evaluation = MagicMock() + evaluation.evaluation.value = "positive" + evaluation.score = 4.5 + return evaluation + + +@pytest.fixture +def mock_graph_engine(): + """Create a mock graph engine.""" + engine = AsyncMock() + engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + engine.add_edges = AsyncMock() + engine.apply_feedback_weight = AsyncMock() + return engine + + +class TestUserQAFeedback: + @pytest.mark.asyncio + async def test_init_default(self): + """Test UserQAFeedback initialization with default last_k.""" + retriever = UserQAFeedback() + assert retriever.last_k == 1 + + @pytest.mark.asyncio + async def test_init_custom_last_k(self): + """Test UserQAFeedback initialization with custom last_k.""" + retriever = UserQAFeedback(last_k=5) + assert retriever.last_k == 5 + + @pytest.mark.asyncio + async def test_add_feedback_success_with_relationships( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback successfully creates feedback with relationships.""" + interaction_id_1 = str(UUID("550e8400-e29b-41d4-a716-446655440000")) + interaction_id_2 = str(UUID("550e8400-e29b-41d4-a716-446655440001")) + mock_graph_engine.get_last_user_interaction_ids = AsyncMock( + return_value=[interaction_id_1, interaction_id_2] + ) + + feedback_text = "This answer was helpful" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ) as mock_add_data, + patch( + "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", + new_callable=AsyncMock, + ) as mock_index_edges, + ): + retriever = UserQAFeedback(last_k=2) + result = await retriever.add_feedback(feedback_text) + + assert result == [feedback_text] + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_awaited_once() + mock_index_edges.assert_awaited_once() + mock_graph_engine.apply_feedback_weight.assert_awaited_once() + + # Verify add_edges was called with correct relationships + call_args = mock_graph_engine.add_edges.call_args[0][0] + assert len(call_args) == 2 + assert call_args[0][0] == uuid5(NAMESPACE_OID, name=feedback_text) + assert call_args[0][1] == UUID(interaction_id_1) + assert call_args[0][2] == "gives_feedback_to" + assert call_args[0][3]["relationship_name"] == "gives_feedback_to" + assert call_args[0][3]["ontology_valid"] is False + + # Verify apply_feedback_weight was called with correct node IDs + weight_call_args = mock_graph_engine.apply_feedback_weight.call_args[1]["node_ids"] + assert len(weight_call_args) == 2 + assert interaction_id_1 in weight_call_args + assert interaction_id_2 in weight_call_args + + @pytest.mark.asyncio + async def test_add_feedback_success_no_relationships( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback successfully creates feedback without relationships.""" + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + + feedback_text = "This answer was helpful" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ) as mock_add_data, + patch( + "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", + new_callable=AsyncMock, + ) as mock_index_edges, + ): + retriever = UserQAFeedback(last_k=1) + result = await retriever.add_feedback(feedback_text) + + assert result == [feedback_text] + mock_add_data.assert_awaited_once() + # Should not call add_edges or index_graph_edges when no relationships + mock_graph_engine.add_edges.assert_not_awaited() + mock_index_edges.assert_not_awaited() + mock_graph_engine.apply_feedback_weight.assert_not_awaited() + + @pytest.mark.asyncio + async def test_add_feedback_creates_correct_feedback_node( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback creates CogneeUserFeedback with correct attributes.""" + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + + feedback_text = "This was a negative experience" + mock_feedback_evaluation.evaluation.value = "negative" + mock_feedback_evaluation.score = -3.0 + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ) as mock_add_data, + ): + retriever = UserQAFeedback() + await retriever.add_feedback(feedback_text) + + # Verify add_data_points was called with correct CogneeUserFeedback + call_args = mock_add_data.call_args[1]["data_points"] + assert len(call_args) == 1 + feedback_node = call_args[0] + assert feedback_node.id == uuid5(NAMESPACE_OID, name=feedback_text) + assert feedback_node.feedback == feedback_text + assert feedback_node.sentiment == "negative" + assert feedback_node.score == -3.0 + assert isinstance(feedback_node.belongs_to_set, NodeSet) + assert feedback_node.belongs_to_set.name == "UserQAFeedbacks" + + @pytest.mark.asyncio + async def test_add_feedback_calls_llm_with_correct_prompt( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback calls LLM with correct sentiment analysis prompt.""" + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + + feedback_text = "Great answer!" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ) as mock_llm, + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ), + ): + retriever = UserQAFeedback() + await retriever.add_feedback(feedback_text) + + mock_llm.assert_awaited_once() + call_kwargs = mock_llm.call_args[1] + assert call_kwargs["text_input"] == feedback_text + assert "sentiment analysis assistant" in call_kwargs["system_prompt"] + assert call_kwargs["response_model"] == UserFeedbackEvaluation + + @pytest.mark.asyncio + async def test_add_feedback_uses_last_k_parameter( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback uses last_k parameter when getting interaction IDs.""" + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + + feedback_text = "Test feedback" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ), + ): + retriever = UserQAFeedback(last_k=5) + await retriever.add_feedback(feedback_text) + + mock_graph_engine.get_last_user_interaction_ids.assert_awaited_once_with(limit=5) + + @pytest.mark.asyncio + async def test_add_feedback_with_single_interaction( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback with single interaction ID.""" + interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000")) + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id]) + + feedback_text = "Test feedback" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", + new_callable=AsyncMock, + ), + ): + retriever = UserQAFeedback() + result = await retriever.add_feedback(feedback_text) + + assert result == [feedback_text] + # Should create relationship for the interaction + call_args = mock_graph_engine.add_edges.call_args[0][0] + assert len(call_args) == 1 + assert call_args[0][1] == UUID(interaction_id) + + @pytest.mark.asyncio + async def test_add_feedback_applies_weight_correctly( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback applies feedback weight with correct score.""" + interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000")) + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id]) + mock_feedback_evaluation.score = 4.5 + + feedback_text = "Positive feedback" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", + new_callable=AsyncMock, + ), + ): + retriever = UserQAFeedback() + await retriever.add_feedback(feedback_text) + + mock_graph_engine.apply_feedback_weight.assert_awaited_once_with( + node_ids=[interaction_id], weight=4.5 + ) diff --git a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py index d79aca428..83612c7aa 100644 --- a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py @@ -81,3 +81,249 @@ async def test_get_context_collection_not_found_error(mock_vector_engine): ): with pytest.raises(NoDataError, match="No data found"): await retriever.get_context("test query") + + +@pytest.mark.asyncio +async def test_get_context_empty_payload_text(mock_vector_engine): + """Test get_context handles missing text in payload.""" + mock_result = MagicMock() + mock_result.payload = {} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(KeyError): + await retriever.get_context("test query") + + +@pytest.mark.asyncio +async def test_get_context_single_triplet(mock_vector_engine): + """Test get_context with single triplet result.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Single triplet"} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == "Single triplet" + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test TripletRetriever initialization with defaults.""" + retriever = TripletRetriever() + + assert retriever.user_prompt_path == "context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + assert retriever.top_k == 5 # Default is 5 + assert retriever.system_prompt is None + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test TripletRetriever initialization with custom parameters.""" + retriever = TripletRetriever( + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + top_k=10, + ) + + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.top_k == 10 + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_vector_engine): + """Test get_completion retrieves context when not provided.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Test triplet"} + mock_vector_engine.has_collection.return_value = True + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_vector_engine): + """Test get_completion uses provided context.""" + retriever = TripletRetriever() + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context="Provided context") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_vector_engine): + """Test get_completion with session caching enabled.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Test triplet"} + mock_vector_engine.has_collection.return_value = True + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.triplet_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.triplet_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.triplet_retriever.save_conversation_history", + ) as mock_save, + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion("test query", session_id="test_session") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_vector_engine): + """Test get_completion with session config but no user ID.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Test triplet"} + mock_vector_engine.has_collection.return_value = True + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_vector_engine): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_result = MagicMock() + mock_result.payload = {"text": "Test triplet"} + mock_vector_engine.has_collection.return_value = True + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", response_model=TestModel) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_init_none_top_k(): + """Test TripletRetriever initialization with None top_k.""" + retriever = TripletRetriever(top_k=None) + + assert retriever.top_k == 5 From fa035f42f40715b616e6c926b00515dbb35c80da Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Fri, 12 Dec 2025 16:47:58 +0100 Subject: [PATCH 02/10] chore: adds back accidentally deleted structured output test --- .../retrieval/structured_output_test.py | 204 ++++++++++++++++++ 1 file changed, 204 insertions(+) create mode 100644 cognee/tests/unit/modules/retrieval/structured_output_test.py diff --git a/cognee/tests/unit/modules/retrieval/structured_output_test.py b/cognee/tests/unit/modules/retrieval/structured_output_test.py new file mode 100644 index 000000000..4ad3019ff --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/structured_output_test.py @@ -0,0 +1,204 @@ +import asyncio + +import pytest +import cognee +import pathlib +import os + +from pydantic import BaseModel +from cognee.low_level import setup, DataPoint +from cognee.tasks.storage import add_data_points +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.engine.models import Entity, EntityType +from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor +from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider +from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) +from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever +from cognee.modules.retrieval.temporal_retriever import TemporalRetriever +from cognee.modules.retrieval.completion_retriever import CompletionRetriever + + +class TestAnswer(BaseModel): + answer: str + explanation: str + + +def _assert_string_answer(answer: list[str]): + assert isinstance(answer, list), f"Expected str, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), "Items should be strings" + assert all(item.strip() for item in answer), "Items should not be empty" + + +def _assert_structured_answer(answer: list[TestAnswer]): + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(x, TestAnswer) for x in answer), "Items should be TestAnswer" + assert all(x.answer.strip() for x in answer), "Answer text should not be empty" + assert all(x.explanation.strip() for x in answer), "Explanation should not be empty" + + +async def _test_get_structured_graph_completion_cot(): + retriever = GraphCompletionCotRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion(): + retriever = GraphCompletionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion_temporal(): + retriever = TemporalRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("When did Steve start working at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "When did Steve start working at Figma??", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion_rag(): + retriever = CompletionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Where does Steve work?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Where does Steve work?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion_context_extension(): + retriever = GraphCompletionContextExtensionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_entity_completion(): + retriever = EntityCompletionRetriever(DummyEntityExtractor(), DummyContextProvider()) + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who is Albert Einstein?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who is Albert Einstein?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +class TestStructuredOutputCompletion: + @pytest.mark.asyncio + async def test_get_structured_completion(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + works_since: int + + company1 = Company(name="Figma") + person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015) + + entities = [company1, person1] + await add_data_points(entities) + + document = TextDocument( + name="Steve Rodger's career", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3] + await add_data_points(entities) + + entity_type = EntityType(name="Person", description="A human individual") + entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist") + + entities = [entity] + await add_data_points(entities) + + await _test_get_structured_graph_completion_cot() + await _test_get_structured_graph_completion() + await _test_get_structured_graph_completion_temporal() + await _test_get_structured_graph_completion_rag() + await _test_get_structured_graph_completion_context_extension() + await _test_get_structured_entity_completion() From c61ff60e40eedcf892b4ed6a08621aa2a7adfcc4 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 15:37:33 +0100 Subject: [PATCH 03/10] feat: add unit tests for get_search_type_tools --- .../search/test_get_search_type_tools.py | 221 ++++++++++++++++++ 1 file changed, 221 insertions(+) create mode 100644 cognee/tests/unit/modules/search/test_get_search_type_tools.py diff --git a/cognee/tests/unit/modules/search/test_get_search_type_tools.py b/cognee/tests/unit/modules/search/test_get_search_type_tools.py new file mode 100644 index 000000000..15b489bfa --- /dev/null +++ b/cognee/tests/unit/modules/search/test_get_search_type_tools.py @@ -0,0 +1,221 @@ +import pytest + +from cognee.modules.search.exceptions import UnsupportedSearchTypeError +from cognee.modules.search.types import SearchType + + +class _DummyCommunityRetriever: + def __init__(self, *args, **kwargs): + self.kwargs = kwargs + + def get_completion(self, *args, **kwargs): + return {"kind": "completion", "init": self.kwargs, "args": args, "kwargs": kwargs} + + def get_context(self, *args, **kwargs): + return {"kind": "context", "init": self.kwargs, "args": args, "kwargs": kwargs} + + +@pytest.mark.asyncio +async def test_feeling_lucky_delegates_to_select_search_type(monkeypatch): + import cognee.modules.search.methods.get_search_type_tools as mod + from cognee.modules.retrieval.chunks_retriever import ChunksRetriever + + async def _fake_select_search_type(query_text: str): + assert query_text == "hello" + return SearchType.CHUNKS + + monkeypatch.setattr(mod, "select_search_type", _fake_select_search_type) + + tools = await mod.get_search_type_tools(SearchType.FEELING_LUCKY, query_text="hello") + + assert len(tools) == 2 + assert all(callable(t) for t in tools) + assert tools[0].__name__ == "get_completion" + assert tools[1].__name__ == "get_context" + assert tools[0].__self__.__class__ is ChunksRetriever + assert tools[1].__self__.__class__ is ChunksRetriever + + +@pytest.mark.asyncio +async def test_disallowed_cypher_search_types_raise(monkeypatch): + import cognee.modules.search.methods.get_search_type_tools as mod + + monkeypatch.setenv("ALLOW_CYPHER_QUERY", "false") + + with pytest.raises(UnsupportedSearchTypeError, match="disabled"): + await mod.get_search_type_tools(SearchType.CYPHER, query_text="MATCH (n) RETURN n") + + with pytest.raises(UnsupportedSearchTypeError, match="disabled"): + await mod.get_search_type_tools(SearchType.NATURAL_LANGUAGE, query_text="Find nodes") + + +@pytest.mark.asyncio +async def test_allowed_cypher_search_types_return_tools(monkeypatch): + import cognee.modules.search.methods.get_search_type_tools as mod + from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever + + monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true") + + tools = await mod.get_search_type_tools(SearchType.CYPHER, query_text="q") + assert len(tools) == 2 + assert tools[0].__name__ == "get_completion" + assert tools[1].__name__ == "get_context" + assert tools[0].__self__.__class__ is CypherSearchRetriever + assert tools[1].__self__.__class__ is CypherSearchRetriever + + +@pytest.mark.asyncio +async def test_registered_community_retriever_is_used(monkeypatch): + """ + Integration point: community retrievers are loaded from the registry module and should + override the default mapping when present. + """ + import cognee.modules.search.methods.get_search_type_tools as mod + from cognee.modules.retrieval import registered_community_retrievers as registry + + monkeypatch.setattr( + registry, + "registered_community_retrievers", + {SearchType.SUMMARIES: _DummyCommunityRetriever}, + ) + + tools = await mod.get_search_type_tools(SearchType.SUMMARIES, query_text="q", top_k=7) + + assert len(tools) == 2 + assert tools[0].__self__.__class__ is _DummyCommunityRetriever + assert tools[0].__self__.kwargs["top_k"] == 7 + assert tools[1].__self__.__class__ is _DummyCommunityRetriever + assert tools[1].__self__.kwargs["top_k"] == 7 + + +@pytest.mark.asyncio +async def test_unknown_query_type_raises_unsupported(): + import cognee.modules.search.methods.get_search_type_tools as mod + + with pytest.raises(UnsupportedSearchTypeError, match="UNKNOWN_TYPE"): + await mod.get_search_type_tools("UNKNOWN_TYPE", query_text="q") # type: ignore[arg-type] + + +@pytest.mark.asyncio +async def test_default_mapping_passes_top_k_to_retrievers(): + import cognee.modules.search.methods.get_search_type_tools as mod + from cognee.modules.retrieval.summaries_retriever import SummariesRetriever + + tools = await mod.get_search_type_tools(SearchType.SUMMARIES, query_text="q", top_k=4) + assert len(tools) == 2 + assert tools[0].__self__.__class__ is SummariesRetriever + assert tools[1].__self__.__class__ is SummariesRetriever + assert tools[0].__self__.top_k == 4 + assert tools[1].__self__.top_k == 4 + + +@pytest.mark.asyncio +async def test_chunks_lexical_returns_jaccard_tools(): + import cognee.modules.search.methods.get_search_type_tools as mod + from cognee.modules.retrieval.jaccard_retrival import JaccardChunksRetriever + + tools = await mod.get_search_type_tools(SearchType.CHUNKS_LEXICAL, query_text="q", top_k=3) + assert len(tools) == 2 + assert tools[0].__self__.__class__ is JaccardChunksRetriever + assert tools[1].__self__.__class__ is JaccardChunksRetriever + assert tools[0].__self__ is tools[1].__self__ + + +@pytest.mark.asyncio +async def test_coding_rules_uses_node_name_as_rules_nodeset_name(): + import cognee.modules.search.methods.get_search_type_tools as mod + from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever + + tools = await mod.get_search_type_tools(SearchType.CODING_RULES, query_text="q", node_name=[]) + assert len(tools) == 1 + assert tools[0].__name__ == "get_existing_rules" + assert tools[0].__self__.__class__ is CodingRulesRetriever + # Empty list should default to ["coding_agent_rules"] + assert tools[0].__self__.rules_nodeset_name == ["coding_agent_rules"] + + +@pytest.mark.asyncio +async def test_feedback_uses_last_k(): + import cognee.modules.search.methods.get_search_type_tools as mod + from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback + + tools = await mod.get_search_type_tools(SearchType.FEEDBACK, query_text="q", last_k=11) + assert len(tools) == 1 + assert tools[0].__name__ == "add_feedback" + assert tools[0].__self__.__class__ is UserQAFeedback + assert tools[0].__self__.last_k == 11 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "query_type, expected_class_name, expected_method_names", + [ + (SearchType.CHUNKS, "ChunksRetriever", ("get_completion", "get_context")), + (SearchType.RAG_COMPLETION, "CompletionRetriever", ("get_completion", "get_context")), + (SearchType.TRIPLET_COMPLETION, "TripletRetriever", ("get_completion", "get_context")), + ( + SearchType.GRAPH_COMPLETION, + "GraphCompletionRetriever", + ("get_completion", "get_context"), + ), + ( + SearchType.GRAPH_COMPLETION_COT, + "GraphCompletionCotRetriever", + ("get_completion", "get_context"), + ), + ( + SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION, + "GraphCompletionContextExtensionRetriever", + ("get_completion", "get_context"), + ), + ( + SearchType.GRAPH_SUMMARY_COMPLETION, + "GraphSummaryCompletionRetriever", + ("get_completion", "get_context"), + ), + (SearchType.TEMPORAL, "TemporalRetriever", ("get_completion", "get_context")), + ( + SearchType.NATURAL_LANGUAGE, + "NaturalLanguageRetriever", + ("get_completion", "get_context"), + ), + ], +) +async def test_tool_construction_for_supported_search_types( + monkeypatch, query_type, expected_class_name, expected_method_names +): + import cognee.modules.search.methods.get_search_type_tools as mod + + # Natural language is guarded by ALLOW_CYPHER_QUERY too + monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true") + + tools = await mod.get_search_type_tools(query_type, query_text="q") + + assert len(tools) == 2 + assert tools[0].__name__ == expected_method_names[0] + assert tools[1].__name__ == expected_method_names[1] + assert tools[0].__self__.__class__.__name__ == expected_class_name + assert tools[1].__self__.__class__.__name__ == expected_class_name + + +@pytest.mark.asyncio +async def test_some_completion_tools_are_callable_without_backends(monkeypatch): + """ + "Making search tools" should include that the returned callables are usable. + For retrievers that accept an explicit `context`, we can call get_completion without touching + DB/LLM backends. + """ + import cognee.modules.search.methods.get_search_type_tools as mod + + monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true") + + for query_type in [ + SearchType.CHUNKS, + SearchType.SUMMARIES, + SearchType.CYPHER, + SearchType.NATURAL_LANGUAGE, + ]: + tools = await mod.get_search_type_tools(query_type, query_text="q") + completion = tools[0] + result = await completion("q", context=["ok"]) # type: ignore[call-arg] + assert result == ["ok"] From 89ef7d7d151cb18e81ad676ca1c205d6995e61d4 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 15:41:13 +0100 Subject: [PATCH 04/10] feat: adds integration test for community registered retriever case --- .../test_get_search_type_tools_integration.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 cognee/tests/integration/search/test_get_search_type_tools_integration.py diff --git a/cognee/tests/integration/search/test_get_search_type_tools_integration.py b/cognee/tests/integration/search/test_get_search_type_tools_integration.py new file mode 100644 index 000000000..4380a3bdb --- /dev/null +++ b/cognee/tests/integration/search/test_get_search_type_tools_integration.py @@ -0,0 +1,36 @@ +import pytest + +from cognee.modules.search.types import SearchType + + +class _DummyCompletionContextRetriever: + def __init__(self, *args, **kwargs): + self.kwargs = kwargs + + def get_completion(self, *args, **kwargs): + return None + + def get_context(self, *args, **kwargs): + return None + + +@pytest.mark.asyncio +async def test_community_registry_is_consulted(monkeypatch): + """ + This test covers the dynamic import + lookup of community retrievers in + cognee.modules.retrieval.registered_community_retrievers. + """ + import cognee.modules.search.methods.get_search_type_tools as mod + from cognee.modules.retrieval import registered_community_retrievers as registry + + monkeypatch.setattr( + registry, + "registered_community_retrievers", + {SearchType.NATURAL_LANGUAGE: _DummyCompletionContextRetriever}, + ) + + tools = await mod.get_search_type_tools(SearchType.NATURAL_LANGUAGE, query_text="q", top_k=9) + + assert len(tools) == 2 + assert tools[0].__self__.kwargs["top_k"] == 9 + assert tools[1].__self__.kwargs["top_k"] == 9 From 48c2040f3ddc3ba3dcdd0b7f3f7b9005aa3350c3 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 15:45:32 +0100 Subject: [PATCH 05/10] Delete test_get_search_type_tools_integration.py --- .../test_get_search_type_tools_integration.py | 36 ------------------- 1 file changed, 36 deletions(-) delete mode 100644 cognee/tests/integration/search/test_get_search_type_tools_integration.py diff --git a/cognee/tests/integration/search/test_get_search_type_tools_integration.py b/cognee/tests/integration/search/test_get_search_type_tools_integration.py deleted file mode 100644 index 4380a3bdb..000000000 --- a/cognee/tests/integration/search/test_get_search_type_tools_integration.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest - -from cognee.modules.search.types import SearchType - - -class _DummyCompletionContextRetriever: - def __init__(self, *args, **kwargs): - self.kwargs = kwargs - - def get_completion(self, *args, **kwargs): - return None - - def get_context(self, *args, **kwargs): - return None - - -@pytest.mark.asyncio -async def test_community_registry_is_consulted(monkeypatch): - """ - This test covers the dynamic import + lookup of community retrievers in - cognee.modules.retrieval.registered_community_retrievers. - """ - import cognee.modules.search.methods.get_search_type_tools as mod - from cognee.modules.retrieval import registered_community_retrievers as registry - - monkeypatch.setattr( - registry, - "registered_community_retrievers", - {SearchType.NATURAL_LANGUAGE: _DummyCompletionContextRetriever}, - ) - - tools = await mod.get_search_type_tools(SearchType.NATURAL_LANGUAGE, query_text="q", top_k=9) - - assert len(tools) == 2 - assert tools[0].__self__.kwargs["top_k"] == 9 - assert tools[1].__self__.kwargs["top_k"] == 9 From 7892b48afe08ea338a753cb6ddc45e48fb884ff7 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 15:59:15 +0100 Subject: [PATCH 06/10] Update test_get_search_type_tools.py --- .../unit/modules/search/test_get_search_type_tools.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cognee/tests/unit/modules/search/test_get_search_type_tools.py b/cognee/tests/unit/modules/search/test_get_search_type_tools.py index 15b489bfa..3748a4e4b 100644 --- a/cognee/tests/unit/modules/search/test_get_search_type_tools.py +++ b/cognee/tests/unit/modules/search/test_get_search_type_tools.py @@ -93,7 +93,7 @@ async def test_unknown_query_type_raises_unsupported(): import cognee.modules.search.methods.get_search_type_tools as mod with pytest.raises(UnsupportedSearchTypeError, match="UNKNOWN_TYPE"): - await mod.get_search_type_tools("UNKNOWN_TYPE", query_text="q") # type: ignore[arg-type] + await mod.get_search_type_tools("UNKNOWN_TYPE", query_text="q") @pytest.mark.asyncio @@ -130,7 +130,7 @@ async def test_coding_rules_uses_node_name_as_rules_nodeset_name(): assert len(tools) == 1 assert tools[0].__name__ == "get_existing_rules" assert tools[0].__self__.__class__ is CodingRulesRetriever - # Empty list should default to ["coding_agent_rules"] + assert tools[0].__self__.rules_nodeset_name == ["coding_agent_rules"] @@ -186,7 +186,6 @@ async def test_tool_construction_for_supported_search_types( ): import cognee.modules.search.methods.get_search_type_tools as mod - # Natural language is guarded by ALLOW_CYPHER_QUERY too monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true") tools = await mod.get_search_type_tools(query_type, query_text="q") @@ -217,5 +216,5 @@ async def test_some_completion_tools_are_callable_without_backends(monkeypatch): ]: tools = await mod.get_search_type_tools(query_type, query_text="q") completion = tools[0] - result = await completion("q", context=["ok"]) # type: ignore[call-arg] + result = await completion("q", context=["ok"]) assert result == ["ok"] From 789fa9079052b0b04274cd992a294c98b3668de2 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 16:39:31 +0100 Subject: [PATCH 07/10] chore: covering search.py behavior with unit tests --- .../tests/unit/modules/search/test_search.py | 465 ++++++++++++++++++ 1 file changed, 465 insertions(+) create mode 100644 cognee/tests/unit/modules/search/test_search.py diff --git a/cognee/tests/unit/modules/search/test_search.py b/cognee/tests/unit/modules/search/test_search.py new file mode 100644 index 000000000..9c36f07da --- /dev/null +++ b/cognee/tests/unit/modules/search/test_search.py @@ -0,0 +1,465 @@ +import types +from uuid import uuid4 + +import pytest + +from cognee.modules.search.types import SearchType + + +def _make_user(user_id: str = "u1", tenant_id=None): + return types.SimpleNamespace(id=user_id, tenant_id=tenant_id) + + +def _make_dataset(*, name="ds", tenant_id="t1", dataset_id=None, owner_id=None): + return types.SimpleNamespace( + id=dataset_id or uuid4(), + name=name, + tenant_id=tenant_id, + owner_id=owner_id or uuid4(), + ) + + +@pytest.fixture +def search_mod(): + import importlib + + return importlib.import_module("cognee.modules.search.methods.search") + + +@pytest.fixture(autouse=True) +def _patch_side_effect_boundaries(monkeypatch, search_mod): + """ + Keep production logic; patch only unavoidable side-effect boundaries. + """ + + async def dummy_log_query(_query_text, _query_type, _user_id): + return types.SimpleNamespace(id="qid-1") + + async def dummy_log_result(*_args, **_kwargs): + return None + + async def dummy_prepare_search_result(search_result): + # search() and helpers mostly exchange tuples: (result, context, datasets) + if isinstance(search_result, tuple) and len(search_result) == 3: + result, context, datasets = search_result + return {"result": result, "context": context, "graphs": {}, "datasets": datasets} + return {"result": None, "context": None, "graphs": {}, "datasets": []} + + monkeypatch.setattr(search_mod, "send_telemetry", lambda *a, **k: None) + monkeypatch.setattr(search_mod, "log_query", dummy_log_query) + monkeypatch.setattr(search_mod, "log_result", dummy_log_result) + monkeypatch.setattr(search_mod, "prepare_search_result", dummy_prepare_search_result) + + yield + + +@pytest.mark.asyncio +async def test_search_no_access_control_flattens_single_list_result(monkeypatch, search_mod): + user = _make_user() + + async def dummy_no_access_control_search(**_kwargs): + return (["r"], ["ctx"], []) + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: False) + monkeypatch.setattr(search_mod, "no_access_control_search", dummy_no_access_control_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=None, + user=user, + ) + + assert out == ["r"] + + +@pytest.mark.asyncio +async def test_search_no_access_control_non_list_result_returns_list(monkeypatch, search_mod): + """ + Covers the non-flattening back-compat branch in `search()`: if the single returned result is + not a list, `search()` returns a list of results instead of flattening. + """ + user = _make_user() + + async def dummy_no_access_control_search(**_kwargs): + return ("r", ["ctx"], []) + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: False) + monkeypatch.setattr(search_mod, "no_access_control_search", dummy_no_access_control_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=None, + user=user, + ) + + assert out == ["r"] + + +@pytest.mark.asyncio +async def test_search_no_access_control_only_context_returns_context(monkeypatch, search_mod): + user = _make_user() + + async def dummy_no_access_control_search(**_kwargs): + return (None, ["ctx"], []) + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: False) + monkeypatch.setattr(search_mod, "no_access_control_search", dummy_no_access_control_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=None, + user=user, + only_context=True, + ) + + assert out == ["ctx"] + + +@pytest.mark.asyncio +async def test_search_access_control_returns_dataset_shaped_dicts(monkeypatch, search_mod): + user = _make_user() + ds = _make_dataset(name="ds1", tenant_id="t1") + + async def dummy_authorized_search(**kwargs): + assert kwargs["dataset_ids"] == [ds.id] + return [("r", ["ctx"], [ds])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds.id], + user=user, + ) + + assert out == [ + { + "search_result": ["r"], + "dataset_id": ds.id, + "dataset_name": "ds1", + "dataset_tenant_id": "t1", + "graphs": {}, + } + ] + + +@pytest.mark.asyncio +async def test_search_access_control_only_context_returns_dataset_shaped_dicts( + monkeypatch, search_mod +): + user = _make_user() + ds = _make_dataset(name="ds1", tenant_id="t1") + + async def dummy_authorized_search(**_kwargs): + return [(None, "ctx", [ds])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds.id], + user=user, + only_context=True, + ) + + assert out == [ + { + "search_result": ["ctx"], + "dataset_id": ds.id, + "dataset_name": "ds1", + "dataset_tenant_id": "t1", + "graphs": {}, + } + ] + + +@pytest.mark.asyncio +async def test_search_access_control_use_combined_context_returns_combined_model( + monkeypatch, search_mod +): + user = _make_user() + ds1 = _make_dataset(name="ds1", tenant_id="t1") + ds2 = _make_dataset(name="ds2", tenant_id="t1") + + async def dummy_authorized_search(**_kwargs): + return ("answer", {"k": "v"}, [ds1, ds2]) + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds1.id, ds2.id], + user=user, + use_combined_context=True, + ) + + assert out.result == "answer" + assert out.context == {"k": "v"} + assert out.graphs == {} + assert [d.id for d in out.datasets] == [ds1.id, ds2.id] + + +@pytest.mark.asyncio +async def test_authorized_search_non_combined_delegates(monkeypatch, search_mod): + user = _make_user() + ds = _make_dataset(name="ds1") + + async def dummy_get_authorized_existing_datasets(*_args, **_kwargs): + return [ds] + + expected = [("r", ["ctx"], [ds])] + + async def dummy_search_in_datasets_context(**kwargs): + assert kwargs["use_combined_context"] is False if "use_combined_context" in kwargs else True + return expected + + monkeypatch.setattr( + search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets + ) + monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context) + + out = await search_mod.authorized_search( + query_type=SearchType.CHUNKS, + query_text="q", + user=user, + dataset_ids=[ds.id], + use_combined_context=False, + only_context=False, + ) + + assert out == expected + + +@pytest.mark.asyncio +async def test_authorized_search_use_combined_context_joins_string_context(monkeypatch, search_mod): + user = _make_user() + ds1 = _make_dataset(name="ds1") + ds2 = _make_dataset(name="ds2") + + async def dummy_get_authorized_existing_datasets(*_args, **_kwargs): + return [ds1, ds2] + + async def dummy_search_in_datasets_context(**kwargs): + assert kwargs["only_context"] is True + return [(None, ["a"], [ds1]), (None, ["b"], [ds2])] + + seen = {} + + async def dummy_get_completion(query_text, context, session_id=None): + seen["query_text"] = query_text + seen["context"] = context + seen["session_id"] = session_id + return ["answer"] + + async def dummy_get_search_type_tools(**_kwargs): + return [dummy_get_completion, lambda *_a, **_k: None] + + monkeypatch.setattr( + search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets + ) + monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context) + monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools) + + completion, combined_context, datasets = await search_mod.authorized_search( + query_type=SearchType.CHUNKS, + query_text="q", + user=user, + dataset_ids=[ds1.id, ds2.id], + use_combined_context=True, + session_id="s1", + ) + + assert combined_context == "a\nb" + assert completion == ["answer"] + assert datasets == [ds1, ds2] + assert seen == {"query_text": "q", "context": "a\nb", "session_id": "s1"} + + +@pytest.mark.asyncio +async def test_authorized_search_use_combined_context_keeps_non_string_context( + monkeypatch, search_mod +): + user = _make_user() + ds1 = _make_dataset(name="ds1") + ds2 = _make_dataset(name="ds2") + + class DummyEdge: + pass + + e1, e2 = DummyEdge(), DummyEdge() + + async def dummy_get_authorized_existing_datasets(*_args, **_kwargs): + return [ds1, ds2] + + async def dummy_search_in_datasets_context(**_kwargs): + return [(None, [e1], [ds1]), (None, [e2], [ds2])] + + async def dummy_get_completion(query_text, context, session_id=None): + assert query_text == "q" + assert context == [e1, e2] + return ["answer"] + + async def dummy_get_search_type_tools(**_kwargs): + return [dummy_get_completion] + + monkeypatch.setattr( + search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets + ) + monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context) + monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools) + + completion, combined_context, datasets = await search_mod.authorized_search( + query_type=SearchType.CHUNKS, + query_text="q", + user=user, + dataset_ids=[ds1.id, ds2.id], + use_combined_context=True, + ) + + assert combined_context == [e1, e2] + assert completion == ["answer"] + assert datasets == [ds1, ds2] + + +@pytest.mark.asyncio +async def test_search_in_datasets_context_two_tool_context_override_and_is_empty_branches( + monkeypatch, search_mod +): + ds1 = _make_dataset(name="ds1") + ds2 = _make_dataset(name="ds2") + + async def dummy_set_database_global_context_variables(*_args, **_kwargs): + return None + + class DummyGraphEngine: + async def is_empty(self): + return True + + async def dummy_get_graph_engine(): + return DummyGraphEngine() + + async def dummy_get_dataset_data(dataset_id): + return [1] if dataset_id == ds1.id else [] + + calls = {"completion": 0, "context": 0} + + async def dummy_get_context(_query_text: str): + calls["context"] += 1 + return ["ctx"] + + async def dummy_get_completion(_query_text: str, _context, session_id=None): + calls["completion"] += 1 + assert session_id == "s1" + return ["r"] + + async def dummy_get_search_type_tools(**_kwargs): + return [dummy_get_completion, dummy_get_context] + + monkeypatch.setattr( + search_mod, + "set_database_global_context_variables", + dummy_set_database_global_context_variables, + ) + monkeypatch.setattr(search_mod, "get_graph_engine", dummy_get_graph_engine) + monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools) + monkeypatch.setattr("cognee.modules.data.methods.get_dataset_data", dummy_get_dataset_data) + + out = await search_mod.search_in_datasets_context( + search_datasets=[ds1, ds2], + query_type=SearchType.CHUNKS, + query_text="q", + context=["pre_ctx"], + session_id="s1", + ) + + assert out == [(["r"], ["pre_ctx"], [ds1]), (["r"], ["pre_ctx"], [ds2])] + assert calls == {"completion": 2, "context": 0} + + +@pytest.mark.asyncio +async def test_search_in_datasets_context_two_tool_only_context_true(monkeypatch, search_mod): + ds = _make_dataset(name="ds1") + + async def dummy_set_database_global_context_variables(*_args, **_kwargs): + return None + + class DummyGraphEngine: + async def is_empty(self): + return False + + async def dummy_get_graph_engine(): + return DummyGraphEngine() + + async def dummy_get_context(query_text: str): + assert query_text == "q" + return ["ctx"] + + async def dummy_get_completion(*_args, **_kwargs): + raise AssertionError("Completion should not be called when only_context=True") + + async def dummy_get_search_type_tools(**_kwargs): + return [dummy_get_completion, dummy_get_context] + + monkeypatch.setattr( + search_mod, + "set_database_global_context_variables", + dummy_set_database_global_context_variables, + ) + monkeypatch.setattr(search_mod, "get_graph_engine", dummy_get_graph_engine) + monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools) + + out = await search_mod.search_in_datasets_context( + search_datasets=[ds], + query_type=SearchType.CHUNKS, + query_text="q", + only_context=True, + ) + + assert out == [(None, ["ctx"], [ds])] + + +@pytest.mark.asyncio +async def test_search_in_datasets_context_unknown_tool_path(monkeypatch, search_mod): + ds = _make_dataset(name="ds1") + + async def dummy_set_database_global_context_variables(*_args, **_kwargs): + return None + + class DummyGraphEngine: + async def is_empty(self): + return False + + async def dummy_get_graph_engine(): + return DummyGraphEngine() + + async def dummy_unknown_tool(query_text: str): + assert query_text == "q" + return ["u"] + + async def dummy_get_search_type_tools(**_kwargs): + return [dummy_unknown_tool] + + monkeypatch.setattr( + search_mod, + "set_database_global_context_variables", + dummy_set_database_global_context_variables, + ) + monkeypatch.setattr(search_mod, "get_graph_engine", dummy_get_graph_engine) + monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools) + + out = await search_mod.search_in_datasets_context( + search_datasets=[ds], + query_type=SearchType.CODING_RULES, + query_text="q", + ) + + assert out == [(["u"], "", [ds])] From 4ff2a35476a3707630c3fbc783a5357907183b26 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 17:33:20 +0100 Subject: [PATCH 08/10] chore: moves unit tests into their correct directory --- .../{integration => unit/modules}/retrieval/test_completion.py | 0 .../modules}/retrieval/test_graph_summary_completion_retriever.py | 0 .../modules}/retrieval/test_user_qa_feedback.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename cognee/tests/{integration => unit/modules}/retrieval/test_completion.py (100%) rename cognee/tests/{integration => unit/modules}/retrieval/test_graph_summary_completion_retriever.py (100%) rename cognee/tests/{integration => unit/modules}/retrieval/test_user_qa_feedback.py (100%) diff --git a/cognee/tests/integration/retrieval/test_completion.py b/cognee/tests/unit/modules/retrieval/test_completion.py similarity index 100% rename from cognee/tests/integration/retrieval/test_completion.py rename to cognee/tests/unit/modules/retrieval/test_completion.py diff --git a/cognee/tests/integration/retrieval/test_graph_summary_completion_retriever.py b/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py similarity index 100% rename from cognee/tests/integration/retrieval/test_graph_summary_completion_retriever.py rename to cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py diff --git a/cognee/tests/integration/retrieval/test_user_qa_feedback.py b/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py similarity index 100% rename from cognee/tests/integration/retrieval/test_user_qa_feedback.py rename to cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py From 18d0a418505dd885d6ab7e2f8efd7a33091471be Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 17:49:43 +0100 Subject: [PATCH 09/10] Update test_search.py --- cognee/tests/unit/modules/search/test_search.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cognee/tests/unit/modules/search/test_search.py b/cognee/tests/unit/modules/search/test_search.py index 9c36f07da..175fd9aa4 100644 --- a/cognee/tests/unit/modules/search/test_search.py +++ b/cognee/tests/unit/modules/search/test_search.py @@ -39,7 +39,6 @@ def _patch_side_effect_boundaries(monkeypatch, search_mod): return None async def dummy_prepare_search_result(search_result): - # search() and helpers mostly exchange tuples: (result, context, datasets) if isinstance(search_result, tuple) and len(search_result) == 3: result, context, datasets = search_result return {"result": result, "context": context, "graphs": {}, "datasets": datasets} From 94d5175570a2358b242feb76529577c5ee6024e2 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Wed, 17 Dec 2025 10:34:57 +0100 Subject: [PATCH 10/10] feat: adds unit test for the prepare search result - search contract --- ...t_search_prepare_search_result_contract.py | 296 ++++++++++++++++++ 1 file changed, 296 insertions(+) create mode 100644 cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py diff --git a/cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py b/cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py new file mode 100644 index 000000000..8700e6a1b --- /dev/null +++ b/cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py @@ -0,0 +1,296 @@ +## The Objective of these tests is to cover the search - prepare search results behavior (later to be removed) + +import types +from uuid import uuid4 + +import pytest +from pydantic import BaseModel + +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node +from cognee.modules.search.types import SearchType + + +class DummyDataset(BaseModel): + id: object + name: str + tenant_id: str | None = None + owner_id: object + + +def _ds(name="ds1", tenant_id="t1"): + return DummyDataset(id=uuid4(), name=name, tenant_id=tenant_id, owner_id=uuid4()) + + +def _edge(rel="rel", n1="A", n2="B"): + node1 = Node(str(uuid4()), attributes={"type": "Entity", "name": n1}) + node2 = Node(str(uuid4()), attributes={"type": "Entity", "name": n2}) + return Edge(node1, node2, attributes={"relationship_name": rel}) + + +@pytest.fixture +def search_mod(): + import importlib + + return importlib.import_module("cognee.modules.search.methods.search") + + +@pytest.fixture(autouse=True) +def _patch_search_side_effects(monkeypatch, search_mod): + """ + These tests validate prepare_search_result behavior *through* search.py. + We only patch unavoidable side effects (telemetry + query/result logging). + """ + + async def dummy_log_query(_query_text, _query_type, _user_id): + return types.SimpleNamespace(id="qid-1") + + async def dummy_log_result(*_args, **_kwargs): + return None + + monkeypatch.setattr(search_mod, "send_telemetry", lambda *a, **k: None) + monkeypatch.setattr(search_mod, "log_query", dummy_log_query) + monkeypatch.setattr(search_mod, "log_result", dummy_log_result) + + yield + + +@pytest.fixture(autouse=True) +def _patch_resolve_edges_to_text(monkeypatch): + """ + Keep graph-text conversion deterministic and lightweight. + """ + import importlib + + psr_mod = importlib.import_module("cognee.modules.search.utils.prepare_search_result") + + async def dummy_resolve_edges_to_text(_edges): + return "EDGE_TEXT" + + monkeypatch.setattr(psr_mod, "resolve_edges_to_text", dummy_resolve_edges_to_text) + + yield + + +@pytest.mark.asyncio +async def test_search_access_control_edges_context_produces_graphs_and_context_map( + monkeypatch, search_mod +): + user = types.SimpleNamespace(id="u1", tenant_id=None) + ds = _ds("ds1", "t1") + context = [_edge("likes")] + + async def dummy_authorized_search(**_kwargs): + return [(["answer"], context, [ds])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds.id], + user=user, + ) + + assert out[0]["dataset_name"] == "ds1" + assert out[0]["dataset_tenant_id"] == "t1" + assert out[0]["graphs"] is not None + assert "ds1" in out[0]["graphs"] + assert out[0]["graphs"]["ds1"]["nodes"] + assert out[0]["graphs"]["ds1"]["edges"] + assert out[0]["search_result"] == ["answer"] + + +@pytest.mark.asyncio +async def test_search_access_control_insights_context_produces_graphs_and_null_result( + monkeypatch, search_mod +): + user = types.SimpleNamespace(id="u1", tenant_id=None) + ds = _ds("ds1", "t1") + insights = [ + ( + {"id": "n1", "type": "Entity", "name": "Alice"}, + {"relationship_name": "knows"}, + {"id": "n2", "type": "Entity", "name": "Bob"}, + ) + ] + + async def dummy_authorized_search(**_kwargs): + return [(["something"], insights, [ds])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds.id], + user=user, + ) + + assert out[0]["graphs"] is not None + assert "ds1" in out[0]["graphs"] + assert out[0]["search_result"] is None + + +@pytest.mark.asyncio +async def test_search_access_control_only_context_returns_context_text_map(monkeypatch, search_mod): + user = types.SimpleNamespace(id="u1", tenant_id=None) + ds = _ds("ds1", "t1") + + async def dummy_authorized_search(**_kwargs): + return [(None, ["a", "b"], [ds])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds.id], + user=user, + only_context=True, + ) + + assert out[0]["search_result"] == [{"ds1": "a\nb"}] + + +@pytest.mark.asyncio +async def test_search_access_control_results_edges_become_graph_result(monkeypatch, search_mod): + user = types.SimpleNamespace(id="u1", tenant_id=None) + ds = _ds("ds1", "t1") + results = [_edge("connected_to")] + + async def dummy_authorized_search(**_kwargs): + return [(results, "ctx", [ds])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds.id], + user=user, + ) + + assert isinstance(out[0]["search_result"][0], dict) + assert "nodes" in out[0]["search_result"][0] + assert "edges" in out[0]["search_result"][0] + + +@pytest.mark.asyncio +async def test_search_use_combined_context_defaults_empty_datasets(monkeypatch, search_mod): + user = types.SimpleNamespace(id="u1", tenant_id=None) + + async def dummy_authorized_search(**_kwargs): + return ("answer", "ctx", []) + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=None, + user=user, + use_combined_context=True, + ) + + assert out.result == "answer" + assert out.context == {"all available datasets": "ctx"} + assert out.datasets[0].name == "all available datasets" + + +@pytest.mark.asyncio +async def test_search_access_control_context_str_branch(monkeypatch, search_mod): + """Covers prepare_search_result(context is str) through search().""" + user = types.SimpleNamespace(id="u1", tenant_id=None) + ds = _ds("ds1", "t1") + + async def dummy_authorized_search(**_kwargs): + return [(["answer"], "plain context", [ds])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds.id], + user=user, + ) + + assert out[0]["graphs"] is None + assert out[0]["search_result"] == ["answer"] + + +@pytest.mark.asyncio +async def test_search_access_control_context_empty_list_branch(monkeypatch, search_mod): + """Covers prepare_search_result(context is empty list) through search().""" + user = types.SimpleNamespace(id="u1", tenant_id=None) + ds = _ds("ds1", "t1") + + async def dummy_authorized_search(**_kwargs): + return [(["answer"], [], [ds])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds.id], + user=user, + ) + + assert out[0]["graphs"] is None + assert out[0]["search_result"] == ["answer"] + + +@pytest.mark.asyncio +async def test_search_access_control_multiple_results_list_branch(monkeypatch, search_mod): + """Covers prepare_search_result(result list length > 1) through search().""" + user = types.SimpleNamespace(id="u1", tenant_id=None) + ds = _ds("ds1", "t1") + + async def dummy_authorized_search(**_kwargs): + return [(["r1", "r2"], "ctx", [ds])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds.id], + user=user, + ) + + assert out[0]["search_result"] == [["r1", "r2"]] + + +@pytest.mark.asyncio +async def test_search_access_control_defaults_empty_datasets(monkeypatch, search_mod): + """ + Covers prepare_search_result(datasets empty list) through search(). + + Note: in access-control mode, search.py expects datasets[0] to have `tenant_id`, + but prepare_search_result defaults to SearchResultDataset which doesn't define it. + We assert the current behavior (it raises) so refactors don't silently change it. + """ + user = types.SimpleNamespace(id="u1", tenant_id=None) + + async def dummy_authorized_search(**_kwargs): + return [(["answer"], "ctx", [])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + with pytest.raises(AttributeError, match="tenant_id"): + await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=None, + user=user, + )