diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py index b77ff0f0e..ed4bdcefc 100644 --- a/cognee/modules/retrieval/utils/completion.py +++ b/cognee/modules/retrieval/utils/completion.py @@ -1,5 +1,3 @@ -from typing import Optional - from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt diff --git a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py new file mode 100644 index 000000000..85b33060d --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py @@ -0,0 +1,120 @@ +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from cognee.modules.retrieval.chunks_retriever import ChunksRetriever + + +class TestChunksRetriever: + @pytest.fixture + def mock_retriever(self): + return ChunksRetriever() + + @pytest.mark.asyncio + @patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine") + async def test_get_completion(self, mock_get_vector_engine, mock_retriever): + # Setup + query = "test query" + doc_id1 = str(uuid.uuid4()) + doc_id2 = str(uuid.uuid4()) + + # Mock search results + mock_result_1 = MagicMock() + mock_result_1.payload = { + "id": str(uuid.uuid4()), + "text": "This is the first chunk result.", + "document_id": doc_id1, + "metadata": {"title": "Document 1"}, + } + + mock_result_2 = MagicMock() + mock_result_2.payload = { + "id": str(uuid.uuid4()), + "text": "This is the second chunk result.", + "document_id": doc_id2, + "metadata": {"title": "Document 2"}, + } + + mock_search_results = [mock_result_1, mock_result_2] + mock_vector_engine = AsyncMock() + mock_vector_engine.search.return_value = mock_search_results + mock_get_vector_engine.return_value = mock_vector_engine + + # Execute + results = await mock_retriever.get_completion(query) + + # Verify + assert len(results) == 2 + + # Check first result + assert results[0]["text"] == "This is the first chunk result." + assert results[0]["document_id"] == doc_id1 + assert results[0]["metadata"]["title"] == "Document 1" + + # Check second result + assert results[1]["text"] == "This is the second chunk result." + assert results[1]["document_id"] == doc_id2 + assert results[1]["metadata"]["title"] == "Document 2" + + # Verify search was called correctly + mock_vector_engine.search.assert_called_once_with("DocumentChunk_text", query, limit=5) + + @pytest.mark.asyncio + @patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine") + async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever): + # Setup + query = "test query with no results" + mock_search_results = [] + mock_vector_engine = AsyncMock() + mock_vector_engine.search.return_value = mock_search_results + mock_get_vector_engine.return_value = mock_vector_engine + + # Execute + results = await mock_retriever.get_completion(query) + + # Verify + assert len(results) == 0 + mock_vector_engine.search.assert_called_once_with("DocumentChunk_text", query, limit=5) + + @pytest.mark.asyncio + @patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine") + async def test_get_completion_with_missing_fields(self, mock_get_vector_engine, mock_retriever): + # Setup + query = "test query with incomplete data" + + # Mock search results + mock_result_1 = MagicMock() + mock_result_1.payload = { + "id": str(uuid.uuid4()), + "text": "This chunk has no document_id.", + # Missing document_id and metadata + } + mock_result_2 = MagicMock() + mock_result_2.payload = { + "id": str(uuid.uuid4()), + # Missing text + "document_id": str(uuid.uuid4()), + "metadata": {"title": "Document with missing text"}, + } + + mock_search_results = [mock_result_1, mock_result_2] + mock_vector_engine = AsyncMock() + mock_vector_engine.search.return_value = mock_search_results + mock_get_vector_engine.return_value = mock_vector_engine + + # Execute + results = await mock_retriever.get_completion(query) + + # Verify + assert len(results) == 2 + + # First result should have content but no document_id + assert results[0]["text"] == "This chunk has no document_id." + assert "document_id" not in results[0] + assert "metadata" not in results[0] + + # Second result should have document_id and metadata but no content + assert "text" not in results[1] + assert "document_id" in results[1] + assert results[1]["metadata"]["title"] == "Document with missing text" diff --git a/cognee/tests/unit/modules/retrieval/completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/completion_retriever_test.py new file mode 100644 index 000000000..1eace3cf1 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/completion_retriever_test.py @@ -0,0 +1,82 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from cognee.modules.retrieval.completion_retriever import CompletionRetriever + + +class TestCompletionRetriever: + @pytest.fixture + def mock_retriever(self): + return CompletionRetriever(system_prompt_path="test_prompt.txt") + + @pytest.mark.asyncio + @patch("cognee.modules.retrieval.utils.completion.get_llm_client") + @patch("cognee.modules.retrieval.utils.completion.render_prompt") + @patch("cognee.modules.retrieval.completion_retriever.get_vector_engine") + async def test_get_completion( + self, mock_get_vector_engine, mock_render_prompt, mock_get_llm_client, mock_retriever + ): + # Setup + query = "test query" + + # Mock render_prompt + mock_render_prompt.return_value = "Rendered prompt with context" + + mock_search_results = [MagicMock()] + mock_vector_engine = AsyncMock() + mock_vector_engine.search.return_value = mock_search_results + mock_get_vector_engine.return_value = mock_vector_engine + + # Mock LLM client + mock_llm_client = MagicMock() + mock_llm_client.acreate_structured_output = AsyncMock() + mock_llm_client.acreate_structured_output.return_value = "Generated completion response" + mock_get_llm_client.return_value = mock_llm_client + + # Execute + results = await mock_retriever.get_completion(query) + + # Verify + assert len(results) == 1 + assert results[0] == "Generated completion response" + + # Verify prompt was rendered + mock_render_prompt.assert_called_once() + + # Verify LLM client was called + mock_llm_client.acreate_structured_output.assert_called_once_with( + text_input="Rendered prompt with context", system_prompt=None, response_model=str + ) + + @pytest.mark.asyncio + @patch("cognee.modules.retrieval.completion_retriever.generate_completion") + @patch("cognee.modules.retrieval.completion_retriever.get_vector_engine") + async def test_get_completion_with_custom_prompt( + self, mock_get_vector_engine, mock_generate_completion, mock_retriever + ): + # Setup + query = "test query with custom prompt" + + mock_search_results = [MagicMock()] + mock_vector_engine = AsyncMock() + mock_vector_engine.search.return_value = mock_search_results + mock_get_vector_engine.return_value = mock_vector_engine + + mock_retriever.user_prompt_path = "custom_user_prompt.txt" + mock_retriever.system_prompt_path = "custom_system_prompt.txt" + + mock_generate_completion.return_value = "Custom prompt completion response" + + # Execute + results = await mock_retriever.get_completion(query) + + # Verify + assert len(results) == 1 + assert results[0] == "Custom prompt completion response" + + assert mock_generate_completion.call_args[1]["user_prompt_path"] == "custom_user_prompt.txt" + assert ( + mock_generate_completion.call_args[1]["system_prompt_path"] + == "custom_system_prompt.txt" + ) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py new file mode 100644 index 000000000..acf6f4ece --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -0,0 +1,149 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.graph.exceptions import EntityNotFoundError +from cognee.tasks.completion.exceptions import NoRelevantDataFound + + +class TestGraphCompletionRetriever: + @pytest.fixture + def mock_retriever(self): + return GraphCompletionRetriever(system_prompt_path="test_prompt.txt") + + @pytest.mark.asyncio + @patch("cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search") + async def test_get_triplets_success(self, mock_brute_force_triplet_search, mock_retriever): + mock_brute_force_triplet_search.return_value = [ + AsyncMock( + node1=AsyncMock(attributes={"text": "Node A"}), + attributes={"relationship_type": "connects"}, + node2=AsyncMock(attributes={"text": "Node B"}), + ) + ] + + result = await mock_retriever.get_triplets("test query") + + assert isinstance(result, list) + assert len(result) > 0 + assert result[0].attributes["relationship_type"] == "connects" + mock_brute_force_triplet_search.assert_called_once() + + @pytest.mark.asyncio + @patch("cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search") + async def test_get_triplets_no_results(self, mock_brute_force_triplet_search, mock_retriever): + mock_brute_force_triplet_search.return_value = [] + + with pytest.raises(NoRelevantDataFound): + await mock_retriever.get_triplets("test query") + + @pytest.mark.asyncio + async def test_resolve_edges_to_text(self, mock_retriever): + triplets = [ + AsyncMock( + node1=AsyncMock(attributes={"text": "Node A"}), + attributes={"relationship_type": "connects"}, + node2=AsyncMock(attributes={"text": "Node B"}), + ), + AsyncMock( + node1=AsyncMock(attributes={"text": "Node X"}), + attributes={"relationship_type": "links"}, + node2=AsyncMock(attributes={"text": "Node Y"}), + ), + ] + + result = await mock_retriever.resolve_edges_to_text(triplets) + + expected_output = "Node A -- connects -- Node B\n---\nNode X -- links -- Node Y" + assert result == expected_output + + @pytest.mark.asyncio + @patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_triplets", + new_callable=AsyncMock, + ) + @patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", + new_callable=AsyncMock, + ) + async def test_get_context(self, mock_resolve_edges_to_text, mock_get_triplets, mock_retriever): + """Test get_context calls get_triplets and resolve_edges_to_text.""" + mock_get_triplets.return_value = ["mock_triplet"] + mock_resolve_edges_to_text.return_value = "Mock Context" + + result = await mock_retriever.get_context("test query") + + assert result == "Mock Context" + mock_get_triplets.assert_called_once_with("test query") + mock_resolve_edges_to_text.assert_called_once_with(["mock_triplet"]) + + @pytest.mark.asyncio + @patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_context" + ) + @patch("cognee.modules.retrieval.graph_completion_retriever.generate_completion") + async def test_get_completion_without_context( + self, mock_generate_completion, mock_get_context, mock_retriever + ): + """Test get_completion when no context is provided (calls get_context).""" + mock_get_context.return_value = "Mock Context" + mock_generate_completion.return_value = "Generated Completion" + + result = await mock_retriever.get_completion("test query") + + assert result == ["Generated Completion"] + mock_get_context.assert_called_once_with("test query") + mock_generate_completion.assert_called_once() + + @pytest.mark.asyncio + @patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_context" + ) + @patch("cognee.modules.retrieval.graph_completion_retriever.generate_completion") + async def test_get_completion_with_context( + self, mock_generate_completion, mock_get_context, mock_retriever + ): + """Test get_completion when context is provided (does not call get_context).""" + mock_generate_completion.return_value = "Generated Completion" + + result = await mock_retriever.get_completion("test query", context="Provided Context") + + assert result == ["Generated Completion"] + mock_get_context.assert_not_called() + mock_generate_completion.assert_called_once() + + @pytest.mark.asyncio + @patch("cognee.modules.retrieval.utils.completion.get_llm_client") + @patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine") + @patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_default_user") + async def test_get_completion_with_empty_graph( + self, + mock_get_default_user, + mock_get_graph_engine, + mock_get_llm_client, + mock_retriever, + ): + # Setup + query = "test query with empty graph" + + # Mock graph engine with empty graph + mock_graph_engine = MagicMock() + mock_graph_engine.get_graph_data = AsyncMock() + mock_graph_engine.get_graph_data.return_value = ([], []) + mock_get_graph_engine.return_value = mock_graph_engine + + # Mock LLM client + mock_llm_client = MagicMock() + mock_llm_client.acreate_structured_output = AsyncMock() + mock_llm_client.acreate_structured_output.return_value = ( + "Generated graph completion response" + ) + mock_get_llm_client.return_value = mock_llm_client + + # Execute + with pytest.raises(EntityNotFoundError): + await mock_retriever.get_completion(query) + + # Verify graph engine was called + mock_graph_engine.get_graph_data.assert_called_once() diff --git a/cognee/tests/unit/modules/retrieval/graph_summary_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_summary_completion_retriever_test.py new file mode 100644 index 000000000..e35842d86 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/graph_summary_completion_retriever_test.py @@ -0,0 +1,80 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from cognee.modules.retrieval.graph_summary_completion_retriever import ( + GraphSummaryCompletionRetriever, +) + + +class TestGraphSummaryCompletionRetriever: + @pytest.fixture + def mock_retriever(self): + return GraphSummaryCompletionRetriever(system_prompt_path="test_prompt.txt") + + @pytest.mark.asyncio + @patch("cognee.modules.retrieval.utils.completion.get_llm_client") + @patch("cognee.modules.retrieval.utils.completion.read_query_prompt") + @patch("cognee.modules.retrieval.utils.completion.render_prompt") + @patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_default_user") + async def test_get_completion_with_custom_system_prompt( + self, + mock_get_default_user, + mock_render_prompt, + mock_read_query_prompt, + mock_get_llm_client, + mock_retriever, + ): + # Setup + query = "test query with custom prompt" + + # Set custom system prompt + mock_retriever.user_prompt_path = "custom_user_prompt.txt" + mock_retriever.system_prompt_path = "custom_system_prompt.txt" + + mock_llm_client = MagicMock() + mock_llm_client.acreate_structured_output = AsyncMock() + mock_llm_client.acreate_structured_output.return_value = ( + "Generated graph summary completion response" + ) + mock_get_llm_client.return_value = mock_llm_client + + # Execute + results = await mock_retriever.get_completion(query, context="test context") + + # Verify + assert len(results) == 1 + + # Verify render_prompt was called with custom prompt path + mock_render_prompt.assert_called_once() + assert mock_render_prompt.call_args[0][0] == "custom_user_prompt.txt" + + mock_read_query_prompt.assert_called_once() + assert mock_read_query_prompt.call_args[0][0] == "custom_system_prompt.txt" + + mock_llm_client.acreate_structured_output.assert_called_once() + + @pytest.mark.asyncio + @patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text" + ) + @patch( + "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", + new_callable=AsyncMock, + ) + async def test_resolve_edges_to_text_calls_super_and_summarizes( + self, mock_summarize_text, mock_resolve_edges_to_text, mock_retriever + ): + """Test resolve_edges_to_text calls the parent method and summarizes the result.""" + + mock_resolve_edges_to_text.return_value = "Raw graph edges text" + mock_summarize_text.return_value = "Summarized graph text" + + result = await mock_retriever.resolve_edges_to_text(["mock_edge"]) + + mock_resolve_edges_to_text.assert_called_once_with(["mock_edge"]) + mock_summarize_text.assert_called_once_with( + "Raw graph edges text", mock_retriever.summarize_prompt_path + ) + + assert result == "Summarized graph text" diff --git a/cognee/tests/unit/modules/retrieval/insights_retriever_test.py b/cognee/tests/unit/modules/retrieval/insights_retriever_test.py new file mode 100644 index 000000000..76a3506ad --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/insights_retriever_test.py @@ -0,0 +1,103 @@ +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from cognee.modules.retrieval.insights_retriever import InsightsRetriever +from cognee.tests.tasks.descriptive_metrics.metrics_test_utils import create_connected_test_graph +from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine +import unittest +from cognee.infrastructure.databases.graph import get_graph_engine + + +class TestInsightsRetriever: + @pytest.fixture + def mock_retriever(self): + return InsightsRetriever() + + @pytest.mark.asyncio + @patch("cognee.modules.retrieval.insights_retriever.get_graph_engine") + async def test_get_context_with_existing_node(self, mock_get_graph_engine, mock_retriever): + """Test get_context when node exists in graph.""" + mock_graph = AsyncMock() + mock_get_graph_engine.return_value = mock_graph + + # Mock graph response + mock_graph.extract_node.return_value = {"id": "123"} + mock_graph.get_connections.return_value = [ + ({"id": "123"}, {"relationship_name": "linked_to"}, {"id": "456"}) + ] + + result = await mock_retriever.get_context("123") + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0][0]["id"] == "123" + assert result[0][1]["relationship_name"] == "linked_to" + assert result[0][2]["id"] == "456" + mock_graph.extract_node.assert_called_once_with("123") + mock_graph.get_connections.assert_called_once_with("123") + + @pytest.mark.asyncio + @patch("cognee.modules.retrieval.insights_retriever.get_vector_engine") + async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever): + # Setup + query = "test query with no results" + mock_search_results = [] + mock_vector_engine = AsyncMock() + mock_vector_engine.search.return_value = mock_search_results + mock_get_vector_engine.return_value = mock_vector_engine + + # Execute + results = await mock_retriever.get_completion(query) + + # Verify + assert len(results) == 0 + + @pytest.mark.asyncio + @patch("cognee.modules.retrieval.insights_retriever.get_graph_engine") + @patch("cognee.modules.retrieval.insights_retriever.get_vector_engine") + async def test_get_context_with_no_exact_node( + self, mock_get_vector_engine, mock_get_graph_engine, mock_retriever + ): + """Test get_context when node does not exist in the graph and vector search is used.""" + mock_graph = AsyncMock() + mock_get_graph_engine.return_value = mock_graph + mock_graph.extract_node.return_value = None # Node does not exist + + mock_vector = AsyncMock() + mock_get_vector_engine.return_value = mock_vector + + mock_vector.search.side_effect = [ + [AsyncMock(id="vec_1", score=0.4)], # Entity_name search + [AsyncMock(id="vec_2", score=0.3)], # EntityType_name search + ] + + mock_graph.get_connections.side_effect = lambda node_id: [ + ({"id": node_id}, {"relationship_name": "related_to"}, {"id": "456"}) + ] + + result = await mock_retriever.get_context("non_existing_query") + + assert isinstance(result, list) + assert len(result) == 2 + assert result[0][0]["id"] == "vec_1" + assert result[0][1]["relationship_name"] == "related_to" + assert result[0][2]["id"] == "456" + + assert result[1][0]["id"] == "vec_2" + assert result[1][1]["relationship_name"] == "related_to" + assert result[1][2]["id"] == "456" + + @pytest.mark.asyncio + async def test_get_context_with_none_query(self, mock_retriever): + """Test get_context with a None query (should return empty list).""" + result = await mock_retriever.get_context(None) + assert result == [] + + @pytest.mark.asyncio + async def test_get_completion_with_context(self, mock_retriever): + """Test get_completion when context is already provided.""" + test_context = [({"id": "123"}, {"relationship_name": "linked_to"}, {"id": "456"})] + result = await mock_retriever.get_completion("test_query", context=test_context) + assert result == test_context diff --git a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py new file mode 100644 index 000000000..f62d81292 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py @@ -0,0 +1,122 @@ +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from cognee.modules.retrieval.summaries_retriever import SummariesRetriever + + +class TestSummariesRetriever: + @pytest.fixture + def mock_retriever(self): + return SummariesRetriever() + + @pytest.mark.asyncio + @patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine") + async def test_get_completion(self, mock_get_vector_engine, mock_retriever): + # Setup + query = "test query" + doc_id1 = str(uuid.uuid4()) + doc_id2 = str(uuid.uuid4()) + + # Mock search results + mock_result_1 = MagicMock() + mock_result_1.payload = { + "id": str(uuid.uuid4()), + "score": 0.95, + "payload": { + "text": "This is the first summary.", + "document_id": doc_id1, + "metadata": {"title": "Document 1"}, + }, + } + mock_result_2 = MagicMock() + mock_result_2.payload = { + "id": str(uuid.uuid4()), + "score": 0.85, + "payload": { + "text": "This is the second summary.", + "document_id": doc_id2, + "metadata": {"title": "Document 2"}, + }, + } + + mock_search_results = [mock_result_1, mock_result_2] + mock_vector_engine = AsyncMock() + mock_vector_engine.search.return_value = mock_search_results + mock_get_vector_engine.return_value = mock_vector_engine + + # Execute + results = await mock_retriever.get_completion(query) + + # Verify + assert len(results) == 2 + + # Check first result + assert results[0]["payload"]["text"] == "This is the first summary." + assert results[0]["payload"]["document_id"] == doc_id1 + assert results[0]["payload"]["metadata"]["title"] == "Document 1" + assert results[0]["score"] == 0.95 + + # Check second result + assert results[1]["payload"]["text"] == "This is the second summary." + assert results[1]["payload"]["document_id"] == doc_id2 + assert results[1]["payload"]["metadata"]["title"] == "Document 2" + assert results[1]["score"] == 0.85 + + # Verify search was called correctly + mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=5) + + @pytest.mark.asyncio + @patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine") + async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever): + # Setup + query = "test query with no results" + mock_search_results = [] + mock_vector_engine = AsyncMock() + mock_vector_engine.search.return_value = mock_search_results + mock_get_vector_engine.return_value = mock_vector_engine + + # Execute + results = await mock_retriever.get_completion(query) + + # Verify + assert len(results) == 0 + mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=5) + + @pytest.mark.asyncio + @patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine") + async def test_get_completion_with_custom_limit(self, mock_get_vector_engine, mock_retriever): + # Setup + query = "test query with custom limit" + doc_id = str(uuid.uuid4()) + + # Mock search results + mock_result = MagicMock() + mock_result.payload = { + "id": str(uuid.uuid4()), + "score": 0.95, + "payload": { + "text": "This is a summary.", + "document_id": doc_id, + "metadata": {"title": "Document 1"}, + }, + } + + mock_search_results = [mock_result] + mock_vector_engine = AsyncMock() + mock_vector_engine.search.return_value = mock_search_results + mock_get_vector_engine.return_value = mock_vector_engine + + # Set custom limit + mock_retriever.limit = 10 + + # Execute + results = await mock_retriever.get_completion(query) + + # Verify + assert len(results) == 1 + assert results[0]["payload"]["text"] == "This is a summary." + + # Verify search was called with custom limit + mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=10) diff --git a/cognee/tests/unit/modules/search/search_methods_test.py b/cognee/tests/unit/modules/search/search_methods_test.py new file mode 100644 index 000000000..a077178f8 --- /dev/null +++ b/cognee/tests/unit/modules/search/search_methods_test.py @@ -0,0 +1,177 @@ +import json +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from cognee.exceptions import InvalidValueError +from cognee.modules.search.methods.search import search, specific_search +from cognee.modules.search.types import SearchType +from cognee.modules.users.models import User +import sys + +search_module = sys.modules.get("cognee.modules.search.methods.search") + + +@pytest.fixture +def mock_user(): + user = MagicMock(spec=User) + user.id = uuid.uuid4() + return user + + +@pytest.mark.asyncio +@patch.object(search_module, "log_query") +@patch.object(search_module, "log_result") +@patch.object(search_module, "get_document_ids_for_user") +@patch.object(search_module, "specific_search") +@patch.object(search_module, "parse_id") +async def test_search( + mock_parse_id, + mock_specific_search, + mock_get_document_ids, + mock_log_result, + mock_log_query, + mock_user, +): + # Setup + query_text = "test query" + query_type = SearchType.CHUNKS + datasets = ["dataset1", "dataset2"] + + # Mock the query logging + mock_query = MagicMock() + mock_query.id = uuid.uuid4() + mock_log_query.return_value = mock_query + + # Mock document IDs + doc_id1 = uuid.uuid4() + doc_id2 = uuid.uuid4() + doc_id3 = uuid.uuid4() # This one will be filtered out + mock_get_document_ids.return_value = [doc_id1, doc_id2] + + # Mock search results + search_results = [ + {"document_id": str(doc_id1), "content": "Result 1"}, + {"document_id": str(doc_id2), "content": "Result 2"}, + {"document_id": str(doc_id3), "content": "Result 3"}, # Should be filtered out + ] + mock_specific_search.return_value = search_results + + # Mock parse_id to return the same UUID + mock_parse_id.side_effect = lambda x: uuid.UUID(x) if x else None + + # Execute + results = await search(query_text, query_type, datasets, mock_user) + + # Verify + mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id) + mock_get_document_ids.assert_called_once_with(mock_user.id, datasets) + mock_specific_search.assert_called_once_with( + query_type, query_text, mock_user, system_prompt_path="answer_simple_question.txt" + ) + + # Only the first two results should be included (doc_id3 is filtered out) + assert len(results) == 2 + assert results[0]["document_id"] == str(doc_id1) + assert results[1]["document_id"] == str(doc_id2) + + # Verify result logging + mock_log_result.assert_called_once() + # Check that the first argument is the query ID + assert mock_log_result.call_args[0][0] == mock_query.id + # The second argument should be the JSON string of the filtered results + # We can't directly compare the JSON strings due to potential ordering differences + # So we parse the JSON and compare the objects + logged_results = json.loads(mock_log_result.call_args[0][1]) + assert len(logged_results) == 2 + assert logged_results[0]["document_id"] == str(doc_id1) + assert logged_results[1]["document_id"] == str(doc_id2) + + +@pytest.mark.asyncio +@patch.object(search_module, "SummariesRetriever") +@patch.object(search_module, "send_telemetry") +async def test_specific_search_summaries(mock_send_telemetry, mock_summaries_retriever, mock_user): + # Setup + query = "test query" + query_type = SearchType.SUMMARIES + + # Mock the retriever + mock_retriever = MagicMock() + mock_retriever.get_completion = AsyncMock() + mock_retriever.get_completion.return_value = [{"content": "Summary result"}] + mock_summaries_retriever.return_value = mock_retriever + + # Execute + results = await specific_search(query_type, query, mock_user) + + # Verify + mock_summaries_retriever.assert_called_once() + mock_retriever.get_completion.assert_called_once_with(query) + mock_send_telemetry.assert_called() + assert len(results) == 1 + assert results[0]["content"] == "Summary result" + + +@pytest.mark.asyncio +@patch.object(search_module, "InsightsRetriever") +@patch.object(search_module, "send_telemetry") +async def test_specific_search_insights(mock_send_telemetry, mock_insights_retriever, mock_user): + # Setup + query = "test query" + query_type = SearchType.INSIGHTS + + # Mock the retriever + mock_retriever = MagicMock() + mock_retriever.get_completion = AsyncMock() + mock_retriever.get_completion.return_value = [{"content": "Insight result"}] + mock_insights_retriever.return_value = mock_retriever + + # Execute + results = await specific_search(query_type, query, mock_user) + + # Verify + mock_insights_retriever.assert_called_once() + mock_retriever.get_completion.assert_called_once_with(query) + mock_send_telemetry.assert_called() + assert len(results) == 1 + assert results[0]["content"] == "Insight result" + + +@pytest.mark.asyncio +@patch.object(search_module, "ChunksRetriever") +@patch.object(search_module, "send_telemetry") +async def test_specific_search_chunks(mock_send_telemetry, mock_chunks_retriever, mock_user): + # Setup + query = "test query" + query_type = SearchType.CHUNKS + + # Mock the retriever + mock_retriever = MagicMock() + mock_retriever.get_completion = AsyncMock() + mock_retriever.get_completion.return_value = [{"content": "Chunk result"}] + mock_chunks_retriever.return_value = mock_retriever + + # Execute + results = await specific_search(query_type, query, mock_user) + + # Verify + mock_chunks_retriever.assert_called_once() + mock_retriever.get_completion.assert_called_once_with(query) + mock_send_telemetry.assert_called() + assert len(results) == 1 + assert results[0]["content"] == "Chunk result" + + +@pytest.mark.asyncio +async def test_specific_search_invalid_type(mock_user): + # Setup + query = "test query" + query_type = "INVALID_TYPE" # Not a valid SearchType + + # Execute and verify + with pytest.raises(InvalidValueError) as excinfo: + await specific_search(query_type, query, mock_user) + + assert "Unsupported search type" in str(excinfo.value)