test: test retrievers [cog-1433] (#635)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Chores** - Removed unused code to streamline internal processes. - **Tests** - Added a comprehensive suite of tests to validate core retrieval and search functionalities. - Improved validation of response generation, context handling, and error scenarios to ensure consistent and reliable performance. These improvements enhance overall system stability and maintainability, contributing to a smoother experience for end-users. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: vasilije <vas.markovic@gmail.com>
This commit is contained in:
parent
ede344be5d
commit
164cb581ec
8 changed files with 833 additions and 2 deletions
|
|
@ -1,5 +1,3 @@
|
|||
from typing import Optional
|
||||
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
|
||||
|
||||
|
|
|
|||
120
cognee/tests/unit/modules/retrieval/chunks_retriever_test.py
Normal file
120
cognee/tests/unit/modules/retrieval/chunks_retriever_test.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
||||
|
||||
|
||||
class TestChunksRetriever:
|
||||
@pytest.fixture
|
||||
def mock_retriever(self):
|
||||
return ChunksRetriever()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine")
|
||||
async def test_get_completion(self, mock_get_vector_engine, mock_retriever):
|
||||
# Setup
|
||||
query = "test query"
|
||||
doc_id1 = str(uuid.uuid4())
|
||||
doc_id2 = str(uuid.uuid4())
|
||||
|
||||
# Mock search results
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.payload = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"text": "This is the first chunk result.",
|
||||
"document_id": doc_id1,
|
||||
"metadata": {"title": "Document 1"},
|
||||
}
|
||||
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.payload = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"text": "This is the second chunk result.",
|
||||
"document_id": doc_id2,
|
||||
"metadata": {"title": "Document 2"},
|
||||
}
|
||||
|
||||
mock_search_results = [mock_result_1, mock_result_2]
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.search.return_value = mock_search_results
|
||||
mock_get_vector_engine.return_value = mock_vector_engine
|
||||
|
||||
# Execute
|
||||
results = await mock_retriever.get_completion(query)
|
||||
|
||||
# Verify
|
||||
assert len(results) == 2
|
||||
|
||||
# Check first result
|
||||
assert results[0]["text"] == "This is the first chunk result."
|
||||
assert results[0]["document_id"] == doc_id1
|
||||
assert results[0]["metadata"]["title"] == "Document 1"
|
||||
|
||||
# Check second result
|
||||
assert results[1]["text"] == "This is the second chunk result."
|
||||
assert results[1]["document_id"] == doc_id2
|
||||
assert results[1]["metadata"]["title"] == "Document 2"
|
||||
|
||||
# Verify search was called correctly
|
||||
mock_vector_engine.search.assert_called_once_with("DocumentChunk_text", query, limit=5)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine")
|
||||
async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever):
|
||||
# Setup
|
||||
query = "test query with no results"
|
||||
mock_search_results = []
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.search.return_value = mock_search_results
|
||||
mock_get_vector_engine.return_value = mock_vector_engine
|
||||
|
||||
# Execute
|
||||
results = await mock_retriever.get_completion(query)
|
||||
|
||||
# Verify
|
||||
assert len(results) == 0
|
||||
mock_vector_engine.search.assert_called_once_with("DocumentChunk_text", query, limit=5)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine")
|
||||
async def test_get_completion_with_missing_fields(self, mock_get_vector_engine, mock_retriever):
|
||||
# Setup
|
||||
query = "test query with incomplete data"
|
||||
|
||||
# Mock search results
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.payload = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"text": "This chunk has no document_id.",
|
||||
# Missing document_id and metadata
|
||||
}
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.payload = {
|
||||
"id": str(uuid.uuid4()),
|
||||
# Missing text
|
||||
"document_id": str(uuid.uuid4()),
|
||||
"metadata": {"title": "Document with missing text"},
|
||||
}
|
||||
|
||||
mock_search_results = [mock_result_1, mock_result_2]
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.search.return_value = mock_search_results
|
||||
mock_get_vector_engine.return_value = mock_vector_engine
|
||||
|
||||
# Execute
|
||||
results = await mock_retriever.get_completion(query)
|
||||
|
||||
# Verify
|
||||
assert len(results) == 2
|
||||
|
||||
# First result should have content but no document_id
|
||||
assert results[0]["text"] == "This chunk has no document_id."
|
||||
assert "document_id" not in results[0]
|
||||
assert "metadata" not in results[0]
|
||||
|
||||
# Second result should have document_id and metadata but no content
|
||||
assert "text" not in results[1]
|
||||
assert "document_id" in results[1]
|
||||
assert results[1]["metadata"]["title"] == "Document with missing text"
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
||||
|
||||
|
||||
class TestCompletionRetriever:
|
||||
@pytest.fixture
|
||||
def mock_retriever(self):
|
||||
return CompletionRetriever(system_prompt_path="test_prompt.txt")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.utils.completion.get_llm_client")
|
||||
@patch("cognee.modules.retrieval.utils.completion.render_prompt")
|
||||
@patch("cognee.modules.retrieval.completion_retriever.get_vector_engine")
|
||||
async def test_get_completion(
|
||||
self, mock_get_vector_engine, mock_render_prompt, mock_get_llm_client, mock_retriever
|
||||
):
|
||||
# Setup
|
||||
query = "test query"
|
||||
|
||||
# Mock render_prompt
|
||||
mock_render_prompt.return_value = "Rendered prompt with context"
|
||||
|
||||
mock_search_results = [MagicMock()]
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.search.return_value = mock_search_results
|
||||
mock_get_vector_engine.return_value = mock_vector_engine
|
||||
|
||||
# Mock LLM client
|
||||
mock_llm_client = MagicMock()
|
||||
mock_llm_client.acreate_structured_output = AsyncMock()
|
||||
mock_llm_client.acreate_structured_output.return_value = "Generated completion response"
|
||||
mock_get_llm_client.return_value = mock_llm_client
|
||||
|
||||
# Execute
|
||||
results = await mock_retriever.get_completion(query)
|
||||
|
||||
# Verify
|
||||
assert len(results) == 1
|
||||
assert results[0] == "Generated completion response"
|
||||
|
||||
# Verify prompt was rendered
|
||||
mock_render_prompt.assert_called_once()
|
||||
|
||||
# Verify LLM client was called
|
||||
mock_llm_client.acreate_structured_output.assert_called_once_with(
|
||||
text_input="Rendered prompt with context", system_prompt=None, response_model=str
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.completion_retriever.generate_completion")
|
||||
@patch("cognee.modules.retrieval.completion_retriever.get_vector_engine")
|
||||
async def test_get_completion_with_custom_prompt(
|
||||
self, mock_get_vector_engine, mock_generate_completion, mock_retriever
|
||||
):
|
||||
# Setup
|
||||
query = "test query with custom prompt"
|
||||
|
||||
mock_search_results = [MagicMock()]
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.search.return_value = mock_search_results
|
||||
mock_get_vector_engine.return_value = mock_vector_engine
|
||||
|
||||
mock_retriever.user_prompt_path = "custom_user_prompt.txt"
|
||||
mock_retriever.system_prompt_path = "custom_system_prompt.txt"
|
||||
|
||||
mock_generate_completion.return_value = "Custom prompt completion response"
|
||||
|
||||
# Execute
|
||||
results = await mock_retriever.get_completion(query)
|
||||
|
||||
# Verify
|
||||
assert len(results) == 1
|
||||
assert results[0] == "Custom prompt completion response"
|
||||
|
||||
assert mock_generate_completion.call_args[1]["user_prompt_path"] == "custom_user_prompt.txt"
|
||||
assert (
|
||||
mock_generate_completion.call_args[1]["system_prompt_path"]
|
||||
== "custom_system_prompt.txt"
|
||||
)
|
||||
|
|
@ -0,0 +1,149 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.graph.exceptions import EntityNotFoundError
|
||||
from cognee.tasks.completion.exceptions import NoRelevantDataFound
|
||||
|
||||
|
||||
class TestGraphCompletionRetriever:
|
||||
@pytest.fixture
|
||||
def mock_retriever(self):
|
||||
return GraphCompletionRetriever(system_prompt_path="test_prompt.txt")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search")
|
||||
async def test_get_triplets_success(self, mock_brute_force_triplet_search, mock_retriever):
|
||||
mock_brute_force_triplet_search.return_value = [
|
||||
AsyncMock(
|
||||
node1=AsyncMock(attributes={"text": "Node A"}),
|
||||
attributes={"relationship_type": "connects"},
|
||||
node2=AsyncMock(attributes={"text": "Node B"}),
|
||||
)
|
||||
]
|
||||
|
||||
result = await mock_retriever.get_triplets("test query")
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) > 0
|
||||
assert result[0].attributes["relationship_type"] == "connects"
|
||||
mock_brute_force_triplet_search.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search")
|
||||
async def test_get_triplets_no_results(self, mock_brute_force_triplet_search, mock_retriever):
|
||||
mock_brute_force_triplet_search.return_value = []
|
||||
|
||||
with pytest.raises(NoRelevantDataFound):
|
||||
await mock_retriever.get_triplets("test query")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_edges_to_text(self, mock_retriever):
|
||||
triplets = [
|
||||
AsyncMock(
|
||||
node1=AsyncMock(attributes={"text": "Node A"}),
|
||||
attributes={"relationship_type": "connects"},
|
||||
node2=AsyncMock(attributes={"text": "Node B"}),
|
||||
),
|
||||
AsyncMock(
|
||||
node1=AsyncMock(attributes={"text": "Node X"}),
|
||||
attributes={"relationship_type": "links"},
|
||||
node2=AsyncMock(attributes={"text": "Node Y"}),
|
||||
),
|
||||
]
|
||||
|
||||
result = await mock_retriever.resolve_edges_to_text(triplets)
|
||||
|
||||
expected_output = "Node A -- connects -- Node B\n---\nNode X -- links -- Node Y"
|
||||
assert result == expected_output
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_triplets",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
@patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
async def test_get_context(self, mock_resolve_edges_to_text, mock_get_triplets, mock_retriever):
|
||||
"""Test get_context calls get_triplets and resolve_edges_to_text."""
|
||||
mock_get_triplets.return_value = ["mock_triplet"]
|
||||
mock_resolve_edges_to_text.return_value = "Mock Context"
|
||||
|
||||
result = await mock_retriever.get_context("test query")
|
||||
|
||||
assert result == "Mock Context"
|
||||
mock_get_triplets.assert_called_once_with("test query")
|
||||
mock_resolve_edges_to_text.assert_called_once_with(["mock_triplet"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_context"
|
||||
)
|
||||
@patch("cognee.modules.retrieval.graph_completion_retriever.generate_completion")
|
||||
async def test_get_completion_without_context(
|
||||
self, mock_generate_completion, mock_get_context, mock_retriever
|
||||
):
|
||||
"""Test get_completion when no context is provided (calls get_context)."""
|
||||
mock_get_context.return_value = "Mock Context"
|
||||
mock_generate_completion.return_value = "Generated Completion"
|
||||
|
||||
result = await mock_retriever.get_completion("test query")
|
||||
|
||||
assert result == ["Generated Completion"]
|
||||
mock_get_context.assert_called_once_with("test query")
|
||||
mock_generate_completion.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_context"
|
||||
)
|
||||
@patch("cognee.modules.retrieval.graph_completion_retriever.generate_completion")
|
||||
async def test_get_completion_with_context(
|
||||
self, mock_generate_completion, mock_get_context, mock_retriever
|
||||
):
|
||||
"""Test get_completion when context is provided (does not call get_context)."""
|
||||
mock_generate_completion.return_value = "Generated Completion"
|
||||
|
||||
result = await mock_retriever.get_completion("test query", context="Provided Context")
|
||||
|
||||
assert result == ["Generated Completion"]
|
||||
mock_get_context.assert_not_called()
|
||||
mock_generate_completion.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.utils.completion.get_llm_client")
|
||||
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine")
|
||||
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_default_user")
|
||||
async def test_get_completion_with_empty_graph(
|
||||
self,
|
||||
mock_get_default_user,
|
||||
mock_get_graph_engine,
|
||||
mock_get_llm_client,
|
||||
mock_retriever,
|
||||
):
|
||||
# Setup
|
||||
query = "test query with empty graph"
|
||||
|
||||
# Mock graph engine with empty graph
|
||||
mock_graph_engine = MagicMock()
|
||||
mock_graph_engine.get_graph_data = AsyncMock()
|
||||
mock_graph_engine.get_graph_data.return_value = ([], [])
|
||||
mock_get_graph_engine.return_value = mock_graph_engine
|
||||
|
||||
# Mock LLM client
|
||||
mock_llm_client = MagicMock()
|
||||
mock_llm_client.acreate_structured_output = AsyncMock()
|
||||
mock_llm_client.acreate_structured_output.return_value = (
|
||||
"Generated graph completion response"
|
||||
)
|
||||
mock_get_llm_client.return_value = mock_llm_client
|
||||
|
||||
# Execute
|
||||
with pytest.raises(EntityNotFoundError):
|
||||
await mock_retriever.get_completion(query)
|
||||
|
||||
# Verify graph engine was called
|
||||
mock_graph_engine.get_graph_data.assert_called_once()
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||
GraphSummaryCompletionRetriever,
|
||||
)
|
||||
|
||||
|
||||
class TestGraphSummaryCompletionRetriever:
|
||||
@pytest.fixture
|
||||
def mock_retriever(self):
|
||||
return GraphSummaryCompletionRetriever(system_prompt_path="test_prompt.txt")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.utils.completion.get_llm_client")
|
||||
@patch("cognee.modules.retrieval.utils.completion.read_query_prompt")
|
||||
@patch("cognee.modules.retrieval.utils.completion.render_prompt")
|
||||
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_default_user")
|
||||
async def test_get_completion_with_custom_system_prompt(
|
||||
self,
|
||||
mock_get_default_user,
|
||||
mock_render_prompt,
|
||||
mock_read_query_prompt,
|
||||
mock_get_llm_client,
|
||||
mock_retriever,
|
||||
):
|
||||
# Setup
|
||||
query = "test query with custom prompt"
|
||||
|
||||
# Set custom system prompt
|
||||
mock_retriever.user_prompt_path = "custom_user_prompt.txt"
|
||||
mock_retriever.system_prompt_path = "custom_system_prompt.txt"
|
||||
|
||||
mock_llm_client = MagicMock()
|
||||
mock_llm_client.acreate_structured_output = AsyncMock()
|
||||
mock_llm_client.acreate_structured_output.return_value = (
|
||||
"Generated graph summary completion response"
|
||||
)
|
||||
mock_get_llm_client.return_value = mock_llm_client
|
||||
|
||||
# Execute
|
||||
results = await mock_retriever.get_completion(query, context="test context")
|
||||
|
||||
# Verify
|
||||
assert len(results) == 1
|
||||
|
||||
# Verify render_prompt was called with custom prompt path
|
||||
mock_render_prompt.assert_called_once()
|
||||
assert mock_render_prompt.call_args[0][0] == "custom_user_prompt.txt"
|
||||
|
||||
mock_read_query_prompt.assert_called_once()
|
||||
assert mock_read_query_prompt.call_args[0][0] == "custom_system_prompt.txt"
|
||||
|
||||
mock_llm_client.acreate_structured_output.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text"
|
||||
)
|
||||
@patch(
|
||||
"cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
async def test_resolve_edges_to_text_calls_super_and_summarizes(
|
||||
self, mock_summarize_text, mock_resolve_edges_to_text, mock_retriever
|
||||
):
|
||||
"""Test resolve_edges_to_text calls the parent method and summarizes the result."""
|
||||
|
||||
mock_resolve_edges_to_text.return_value = "Raw graph edges text"
|
||||
mock_summarize_text.return_value = "Summarized graph text"
|
||||
|
||||
result = await mock_retriever.resolve_edges_to_text(["mock_edge"])
|
||||
|
||||
mock_resolve_edges_to_text.assert_called_once_with(["mock_edge"])
|
||||
mock_summarize_text.assert_called_once_with(
|
||||
"Raw graph edges text", mock_retriever.summarize_prompt_path
|
||||
)
|
||||
|
||||
assert result == "Summarized graph text"
|
||||
103
cognee/tests/unit/modules/retrieval/insights_retriever_test.py
Normal file
103
cognee/tests/unit/modules/retrieval/insights_retriever_test.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
|
||||
from cognee.tests.tasks.descriptive_metrics.metrics_test_utils import create_connected_test_graph
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
|
||||
import unittest
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
|
||||
class TestInsightsRetriever:
|
||||
@pytest.fixture
|
||||
def mock_retriever(self):
|
||||
return InsightsRetriever()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.insights_retriever.get_graph_engine")
|
||||
async def test_get_context_with_existing_node(self, mock_get_graph_engine, mock_retriever):
|
||||
"""Test get_context when node exists in graph."""
|
||||
mock_graph = AsyncMock()
|
||||
mock_get_graph_engine.return_value = mock_graph
|
||||
|
||||
# Mock graph response
|
||||
mock_graph.extract_node.return_value = {"id": "123"}
|
||||
mock_graph.get_connections.return_value = [
|
||||
({"id": "123"}, {"relationship_name": "linked_to"}, {"id": "456"})
|
||||
]
|
||||
|
||||
result = await mock_retriever.get_context("123")
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
assert result[0][0]["id"] == "123"
|
||||
assert result[0][1]["relationship_name"] == "linked_to"
|
||||
assert result[0][2]["id"] == "456"
|
||||
mock_graph.extract_node.assert_called_once_with("123")
|
||||
mock_graph.get_connections.assert_called_once_with("123")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.insights_retriever.get_vector_engine")
|
||||
async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever):
|
||||
# Setup
|
||||
query = "test query with no results"
|
||||
mock_search_results = []
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.search.return_value = mock_search_results
|
||||
mock_get_vector_engine.return_value = mock_vector_engine
|
||||
|
||||
# Execute
|
||||
results = await mock_retriever.get_completion(query)
|
||||
|
||||
# Verify
|
||||
assert len(results) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.insights_retriever.get_graph_engine")
|
||||
@patch("cognee.modules.retrieval.insights_retriever.get_vector_engine")
|
||||
async def test_get_context_with_no_exact_node(
|
||||
self, mock_get_vector_engine, mock_get_graph_engine, mock_retriever
|
||||
):
|
||||
"""Test get_context when node does not exist in the graph and vector search is used."""
|
||||
mock_graph = AsyncMock()
|
||||
mock_get_graph_engine.return_value = mock_graph
|
||||
mock_graph.extract_node.return_value = None # Node does not exist
|
||||
|
||||
mock_vector = AsyncMock()
|
||||
mock_get_vector_engine.return_value = mock_vector
|
||||
|
||||
mock_vector.search.side_effect = [
|
||||
[AsyncMock(id="vec_1", score=0.4)], # Entity_name search
|
||||
[AsyncMock(id="vec_2", score=0.3)], # EntityType_name search
|
||||
]
|
||||
|
||||
mock_graph.get_connections.side_effect = lambda node_id: [
|
||||
({"id": node_id}, {"relationship_name": "related_to"}, {"id": "456"})
|
||||
]
|
||||
|
||||
result = await mock_retriever.get_context("non_existing_query")
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert result[0][0]["id"] == "vec_1"
|
||||
assert result[0][1]["relationship_name"] == "related_to"
|
||||
assert result[0][2]["id"] == "456"
|
||||
|
||||
assert result[1][0]["id"] == "vec_2"
|
||||
assert result[1][1]["relationship_name"] == "related_to"
|
||||
assert result[1][2]["id"] == "456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_context_with_none_query(self, mock_retriever):
|
||||
"""Test get_context with a None query (should return empty list)."""
|
||||
result = await mock_retriever.get_context(None)
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_with_context(self, mock_retriever):
|
||||
"""Test get_completion when context is already provided."""
|
||||
test_context = [({"id": "123"}, {"relationship_name": "linked_to"}, {"id": "456"})]
|
||||
result = await mock_retriever.get_completion("test_query", context=test_context)
|
||||
assert result == test_context
|
||||
122
cognee/tests/unit/modules/retrieval/summaries_retriever_test.py
Normal file
122
cognee/tests/unit/modules/retrieval/summaries_retriever_test.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
||||
|
||||
|
||||
class TestSummariesRetriever:
|
||||
@pytest.fixture
|
||||
def mock_retriever(self):
|
||||
return SummariesRetriever()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine")
|
||||
async def test_get_completion(self, mock_get_vector_engine, mock_retriever):
|
||||
# Setup
|
||||
query = "test query"
|
||||
doc_id1 = str(uuid.uuid4())
|
||||
doc_id2 = str(uuid.uuid4())
|
||||
|
||||
# Mock search results
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.payload = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"score": 0.95,
|
||||
"payload": {
|
||||
"text": "This is the first summary.",
|
||||
"document_id": doc_id1,
|
||||
"metadata": {"title": "Document 1"},
|
||||
},
|
||||
}
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.payload = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"score": 0.85,
|
||||
"payload": {
|
||||
"text": "This is the second summary.",
|
||||
"document_id": doc_id2,
|
||||
"metadata": {"title": "Document 2"},
|
||||
},
|
||||
}
|
||||
|
||||
mock_search_results = [mock_result_1, mock_result_2]
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.search.return_value = mock_search_results
|
||||
mock_get_vector_engine.return_value = mock_vector_engine
|
||||
|
||||
# Execute
|
||||
results = await mock_retriever.get_completion(query)
|
||||
|
||||
# Verify
|
||||
assert len(results) == 2
|
||||
|
||||
# Check first result
|
||||
assert results[0]["payload"]["text"] == "This is the first summary."
|
||||
assert results[0]["payload"]["document_id"] == doc_id1
|
||||
assert results[0]["payload"]["metadata"]["title"] == "Document 1"
|
||||
assert results[0]["score"] == 0.95
|
||||
|
||||
# Check second result
|
||||
assert results[1]["payload"]["text"] == "This is the second summary."
|
||||
assert results[1]["payload"]["document_id"] == doc_id2
|
||||
assert results[1]["payload"]["metadata"]["title"] == "Document 2"
|
||||
assert results[1]["score"] == 0.85
|
||||
|
||||
# Verify search was called correctly
|
||||
mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=5)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine")
|
||||
async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever):
|
||||
# Setup
|
||||
query = "test query with no results"
|
||||
mock_search_results = []
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.search.return_value = mock_search_results
|
||||
mock_get_vector_engine.return_value = mock_vector_engine
|
||||
|
||||
# Execute
|
||||
results = await mock_retriever.get_completion(query)
|
||||
|
||||
# Verify
|
||||
assert len(results) == 0
|
||||
mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=5)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine")
|
||||
async def test_get_completion_with_custom_limit(self, mock_get_vector_engine, mock_retriever):
|
||||
# Setup
|
||||
query = "test query with custom limit"
|
||||
doc_id = str(uuid.uuid4())
|
||||
|
||||
# Mock search results
|
||||
mock_result = MagicMock()
|
||||
mock_result.payload = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"score": 0.95,
|
||||
"payload": {
|
||||
"text": "This is a summary.",
|
||||
"document_id": doc_id,
|
||||
"metadata": {"title": "Document 1"},
|
||||
},
|
||||
}
|
||||
|
||||
mock_search_results = [mock_result]
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.search.return_value = mock_search_results
|
||||
mock_get_vector_engine.return_value = mock_vector_engine
|
||||
|
||||
# Set custom limit
|
||||
mock_retriever.limit = 10
|
||||
|
||||
# Execute
|
||||
results = await mock_retriever.get_completion(query)
|
||||
|
||||
# Verify
|
||||
assert len(results) == 1
|
||||
assert results[0]["payload"]["text"] == "This is a summary."
|
||||
|
||||
# Verify search was called with custom limit
|
||||
mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=10)
|
||||
177
cognee/tests/unit/modules/search/search_methods_test.py
Normal file
177
cognee/tests/unit/modules/search/search_methods_test.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
import json
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.modules.search.methods.search import search, specific_search
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.users.models import User
|
||||
import sys
|
||||
|
||||
search_module = sys.modules.get("cognee.modules.search.methods.search")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
user = MagicMock(spec=User)
|
||||
user.id = uuid.uuid4()
|
||||
return user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(search_module, "log_query")
|
||||
@patch.object(search_module, "log_result")
|
||||
@patch.object(search_module, "get_document_ids_for_user")
|
||||
@patch.object(search_module, "specific_search")
|
||||
@patch.object(search_module, "parse_id")
|
||||
async def test_search(
|
||||
mock_parse_id,
|
||||
mock_specific_search,
|
||||
mock_get_document_ids,
|
||||
mock_log_result,
|
||||
mock_log_query,
|
||||
mock_user,
|
||||
):
|
||||
# Setup
|
||||
query_text = "test query"
|
||||
query_type = SearchType.CHUNKS
|
||||
datasets = ["dataset1", "dataset2"]
|
||||
|
||||
# Mock the query logging
|
||||
mock_query = MagicMock()
|
||||
mock_query.id = uuid.uuid4()
|
||||
mock_log_query.return_value = mock_query
|
||||
|
||||
# Mock document IDs
|
||||
doc_id1 = uuid.uuid4()
|
||||
doc_id2 = uuid.uuid4()
|
||||
doc_id3 = uuid.uuid4() # This one will be filtered out
|
||||
mock_get_document_ids.return_value = [doc_id1, doc_id2]
|
||||
|
||||
# Mock search results
|
||||
search_results = [
|
||||
{"document_id": str(doc_id1), "content": "Result 1"},
|
||||
{"document_id": str(doc_id2), "content": "Result 2"},
|
||||
{"document_id": str(doc_id3), "content": "Result 3"}, # Should be filtered out
|
||||
]
|
||||
mock_specific_search.return_value = search_results
|
||||
|
||||
# Mock parse_id to return the same UUID
|
||||
mock_parse_id.side_effect = lambda x: uuid.UUID(x) if x else None
|
||||
|
||||
# Execute
|
||||
results = await search(query_text, query_type, datasets, mock_user)
|
||||
|
||||
# Verify
|
||||
mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
|
||||
mock_get_document_ids.assert_called_once_with(mock_user.id, datasets)
|
||||
mock_specific_search.assert_called_once_with(
|
||||
query_type, query_text, mock_user, system_prompt_path="answer_simple_question.txt"
|
||||
)
|
||||
|
||||
# Only the first two results should be included (doc_id3 is filtered out)
|
||||
assert len(results) == 2
|
||||
assert results[0]["document_id"] == str(doc_id1)
|
||||
assert results[1]["document_id"] == str(doc_id2)
|
||||
|
||||
# Verify result logging
|
||||
mock_log_result.assert_called_once()
|
||||
# Check that the first argument is the query ID
|
||||
assert mock_log_result.call_args[0][0] == mock_query.id
|
||||
# The second argument should be the JSON string of the filtered results
|
||||
# We can't directly compare the JSON strings due to potential ordering differences
|
||||
# So we parse the JSON and compare the objects
|
||||
logged_results = json.loads(mock_log_result.call_args[0][1])
|
||||
assert len(logged_results) == 2
|
||||
assert logged_results[0]["document_id"] == str(doc_id1)
|
||||
assert logged_results[1]["document_id"] == str(doc_id2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(search_module, "SummariesRetriever")
|
||||
@patch.object(search_module, "send_telemetry")
|
||||
async def test_specific_search_summaries(mock_send_telemetry, mock_summaries_retriever, mock_user):
|
||||
# Setup
|
||||
query = "test query"
|
||||
query_type = SearchType.SUMMARIES
|
||||
|
||||
# Mock the retriever
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_completion = AsyncMock()
|
||||
mock_retriever.get_completion.return_value = [{"content": "Summary result"}]
|
||||
mock_summaries_retriever.return_value = mock_retriever
|
||||
|
||||
# Execute
|
||||
results = await specific_search(query_type, query, mock_user)
|
||||
|
||||
# Verify
|
||||
mock_summaries_retriever.assert_called_once()
|
||||
mock_retriever.get_completion.assert_called_once_with(query)
|
||||
mock_send_telemetry.assert_called()
|
||||
assert len(results) == 1
|
||||
assert results[0]["content"] == "Summary result"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(search_module, "InsightsRetriever")
|
||||
@patch.object(search_module, "send_telemetry")
|
||||
async def test_specific_search_insights(mock_send_telemetry, mock_insights_retriever, mock_user):
|
||||
# Setup
|
||||
query = "test query"
|
||||
query_type = SearchType.INSIGHTS
|
||||
|
||||
# Mock the retriever
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_completion = AsyncMock()
|
||||
mock_retriever.get_completion.return_value = [{"content": "Insight result"}]
|
||||
mock_insights_retriever.return_value = mock_retriever
|
||||
|
||||
# Execute
|
||||
results = await specific_search(query_type, query, mock_user)
|
||||
|
||||
# Verify
|
||||
mock_insights_retriever.assert_called_once()
|
||||
mock_retriever.get_completion.assert_called_once_with(query)
|
||||
mock_send_telemetry.assert_called()
|
||||
assert len(results) == 1
|
||||
assert results[0]["content"] == "Insight result"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(search_module, "ChunksRetriever")
|
||||
@patch.object(search_module, "send_telemetry")
|
||||
async def test_specific_search_chunks(mock_send_telemetry, mock_chunks_retriever, mock_user):
|
||||
# Setup
|
||||
query = "test query"
|
||||
query_type = SearchType.CHUNKS
|
||||
|
||||
# Mock the retriever
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_completion = AsyncMock()
|
||||
mock_retriever.get_completion.return_value = [{"content": "Chunk result"}]
|
||||
mock_chunks_retriever.return_value = mock_retriever
|
||||
|
||||
# Execute
|
||||
results = await specific_search(query_type, query, mock_user)
|
||||
|
||||
# Verify
|
||||
mock_chunks_retriever.assert_called_once()
|
||||
mock_retriever.get_completion.assert_called_once_with(query)
|
||||
mock_send_telemetry.assert_called()
|
||||
assert len(results) == 1
|
||||
assert results[0]["content"] == "Chunk result"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_specific_search_invalid_type(mock_user):
|
||||
# Setup
|
||||
query = "test query"
|
||||
query_type = "INVALID_TYPE" # Not a valid SearchType
|
||||
|
||||
# Execute and verify
|
||||
with pytest.raises(InvalidValueError) as excinfo:
|
||||
await specific_search(query_type, query, mock_user)
|
||||
|
||||
assert "Unsupported search type" in str(excinfo.value)
|
||||
Loading…
Add table
Reference in a new issue