test: test retrievers [cog-1433] (#635)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Chores**
	- Removed unused code to streamline internal processes.
  
- **Tests**
- Added a comprehensive suite of tests to validate core retrieval and
search functionalities.
- Improved validation of response generation, context handling, and
error scenarios to ensure consistent and reliable performance.

These improvements enhance overall system stability and maintainability,
contributing to a smoother experience for end-users.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: vasilije <vas.markovic@gmail.com>
This commit is contained in:
alekszievr 2025-03-20 10:18:21 +01:00 committed by GitHub
parent ede344be5d
commit 164cb581ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 833 additions and 2 deletions

View file

@ -1,5 +1,3 @@
from typing import Optional
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt

View file

@ -0,0 +1,120 @@
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
class TestChunksRetriever:
@pytest.fixture
def mock_retriever(self):
return ChunksRetriever()
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine")
async def test_get_completion(self, mock_get_vector_engine, mock_retriever):
# Setup
query = "test query"
doc_id1 = str(uuid.uuid4())
doc_id2 = str(uuid.uuid4())
# Mock search results
mock_result_1 = MagicMock()
mock_result_1.payload = {
"id": str(uuid.uuid4()),
"text": "This is the first chunk result.",
"document_id": doc_id1,
"metadata": {"title": "Document 1"},
}
mock_result_2 = MagicMock()
mock_result_2.payload = {
"id": str(uuid.uuid4()),
"text": "This is the second chunk result.",
"document_id": doc_id2,
"metadata": {"title": "Document 2"},
}
mock_search_results = [mock_result_1, mock_result_2]
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
# Execute
results = await mock_retriever.get_completion(query)
# Verify
assert len(results) == 2
# Check first result
assert results[0]["text"] == "This is the first chunk result."
assert results[0]["document_id"] == doc_id1
assert results[0]["metadata"]["title"] == "Document 1"
# Check second result
assert results[1]["text"] == "This is the second chunk result."
assert results[1]["document_id"] == doc_id2
assert results[1]["metadata"]["title"] == "Document 2"
# Verify search was called correctly
mock_vector_engine.search.assert_called_once_with("DocumentChunk_text", query, limit=5)
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine")
async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever):
# Setup
query = "test query with no results"
mock_search_results = []
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
# Execute
results = await mock_retriever.get_completion(query)
# Verify
assert len(results) == 0
mock_vector_engine.search.assert_called_once_with("DocumentChunk_text", query, limit=5)
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine")
async def test_get_completion_with_missing_fields(self, mock_get_vector_engine, mock_retriever):
# Setup
query = "test query with incomplete data"
# Mock search results
mock_result_1 = MagicMock()
mock_result_1.payload = {
"id": str(uuid.uuid4()),
"text": "This chunk has no document_id.",
# Missing document_id and metadata
}
mock_result_2 = MagicMock()
mock_result_2.payload = {
"id": str(uuid.uuid4()),
# Missing text
"document_id": str(uuid.uuid4()),
"metadata": {"title": "Document with missing text"},
}
mock_search_results = [mock_result_1, mock_result_2]
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
# Execute
results = await mock_retriever.get_completion(query)
# Verify
assert len(results) == 2
# First result should have content but no document_id
assert results[0]["text"] == "This chunk has no document_id."
assert "document_id" not in results[0]
assert "metadata" not in results[0]
# Second result should have document_id and metadata but no content
assert "text" not in results[1]
assert "document_id" in results[1]
assert results[1]["metadata"]["title"] == "Document with missing text"

View file

@ -0,0 +1,82 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
class TestCompletionRetriever:
@pytest.fixture
def mock_retriever(self):
return CompletionRetriever(system_prompt_path="test_prompt.txt")
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.utils.completion.get_llm_client")
@patch("cognee.modules.retrieval.utils.completion.render_prompt")
@patch("cognee.modules.retrieval.completion_retriever.get_vector_engine")
async def test_get_completion(
self, mock_get_vector_engine, mock_render_prompt, mock_get_llm_client, mock_retriever
):
# Setup
query = "test query"
# Mock render_prompt
mock_render_prompt.return_value = "Rendered prompt with context"
mock_search_results = [MagicMock()]
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
# Mock LLM client
mock_llm_client = MagicMock()
mock_llm_client.acreate_structured_output = AsyncMock()
mock_llm_client.acreate_structured_output.return_value = "Generated completion response"
mock_get_llm_client.return_value = mock_llm_client
# Execute
results = await mock_retriever.get_completion(query)
# Verify
assert len(results) == 1
assert results[0] == "Generated completion response"
# Verify prompt was rendered
mock_render_prompt.assert_called_once()
# Verify LLM client was called
mock_llm_client.acreate_structured_output.assert_called_once_with(
text_input="Rendered prompt with context", system_prompt=None, response_model=str
)
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.completion_retriever.generate_completion")
@patch("cognee.modules.retrieval.completion_retriever.get_vector_engine")
async def test_get_completion_with_custom_prompt(
self, mock_get_vector_engine, mock_generate_completion, mock_retriever
):
# Setup
query = "test query with custom prompt"
mock_search_results = [MagicMock()]
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
mock_retriever.user_prompt_path = "custom_user_prompt.txt"
mock_retriever.system_prompt_path = "custom_system_prompt.txt"
mock_generate_completion.return_value = "Custom prompt completion response"
# Execute
results = await mock_retriever.get_completion(query)
# Verify
assert len(results) == 1
assert results[0] == "Custom prompt completion response"
assert mock_generate_completion.call_args[1]["user_prompt_path"] == "custom_user_prompt.txt"
assert (
mock_generate_completion.call_args[1]["system_prompt_path"]
== "custom_system_prompt.txt"
)

View file

@ -0,0 +1,149 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.graph.exceptions import EntityNotFoundError
from cognee.tasks.completion.exceptions import NoRelevantDataFound
class TestGraphCompletionRetriever:
@pytest.fixture
def mock_retriever(self):
return GraphCompletionRetriever(system_prompt_path="test_prompt.txt")
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search")
async def test_get_triplets_success(self, mock_brute_force_triplet_search, mock_retriever):
mock_brute_force_triplet_search.return_value = [
AsyncMock(
node1=AsyncMock(attributes={"text": "Node A"}),
attributes={"relationship_type": "connects"},
node2=AsyncMock(attributes={"text": "Node B"}),
)
]
result = await mock_retriever.get_triplets("test query")
assert isinstance(result, list)
assert len(result) > 0
assert result[0].attributes["relationship_type"] == "connects"
mock_brute_force_triplet_search.assert_called_once()
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search")
async def test_get_triplets_no_results(self, mock_brute_force_triplet_search, mock_retriever):
mock_brute_force_triplet_search.return_value = []
with pytest.raises(NoRelevantDataFound):
await mock_retriever.get_triplets("test query")
@pytest.mark.asyncio
async def test_resolve_edges_to_text(self, mock_retriever):
triplets = [
AsyncMock(
node1=AsyncMock(attributes={"text": "Node A"}),
attributes={"relationship_type": "connects"},
node2=AsyncMock(attributes={"text": "Node B"}),
),
AsyncMock(
node1=AsyncMock(attributes={"text": "Node X"}),
attributes={"relationship_type": "links"},
node2=AsyncMock(attributes={"text": "Node Y"}),
),
]
result = await mock_retriever.resolve_edges_to_text(triplets)
expected_output = "Node A -- connects -- Node B\n---\nNode X -- links -- Node Y"
assert result == expected_output
@pytest.mark.asyncio
@patch(
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_triplets",
new_callable=AsyncMock,
)
@patch(
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text",
new_callable=AsyncMock,
)
async def test_get_context(self, mock_resolve_edges_to_text, mock_get_triplets, mock_retriever):
"""Test get_context calls get_triplets and resolve_edges_to_text."""
mock_get_triplets.return_value = ["mock_triplet"]
mock_resolve_edges_to_text.return_value = "Mock Context"
result = await mock_retriever.get_context("test query")
assert result == "Mock Context"
mock_get_triplets.assert_called_once_with("test query")
mock_resolve_edges_to_text.assert_called_once_with(["mock_triplet"])
@pytest.mark.asyncio
@patch(
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_context"
)
@patch("cognee.modules.retrieval.graph_completion_retriever.generate_completion")
async def test_get_completion_without_context(
self, mock_generate_completion, mock_get_context, mock_retriever
):
"""Test get_completion when no context is provided (calls get_context)."""
mock_get_context.return_value = "Mock Context"
mock_generate_completion.return_value = "Generated Completion"
result = await mock_retriever.get_completion("test query")
assert result == ["Generated Completion"]
mock_get_context.assert_called_once_with("test query")
mock_generate_completion.assert_called_once()
@pytest.mark.asyncio
@patch(
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_context"
)
@patch("cognee.modules.retrieval.graph_completion_retriever.generate_completion")
async def test_get_completion_with_context(
self, mock_generate_completion, mock_get_context, mock_retriever
):
"""Test get_completion when context is provided (does not call get_context)."""
mock_generate_completion.return_value = "Generated Completion"
result = await mock_retriever.get_completion("test query", context="Provided Context")
assert result == ["Generated Completion"]
mock_get_context.assert_not_called()
mock_generate_completion.assert_called_once()
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.utils.completion.get_llm_client")
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine")
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_default_user")
async def test_get_completion_with_empty_graph(
self,
mock_get_default_user,
mock_get_graph_engine,
mock_get_llm_client,
mock_retriever,
):
# Setup
query = "test query with empty graph"
# Mock graph engine with empty graph
mock_graph_engine = MagicMock()
mock_graph_engine.get_graph_data = AsyncMock()
mock_graph_engine.get_graph_data.return_value = ([], [])
mock_get_graph_engine.return_value = mock_graph_engine
# Mock LLM client
mock_llm_client = MagicMock()
mock_llm_client.acreate_structured_output = AsyncMock()
mock_llm_client.acreate_structured_output.return_value = (
"Generated graph completion response"
)
mock_get_llm_client.return_value = mock_llm_client
# Execute
with pytest.raises(EntityNotFoundError):
await mock_retriever.get_completion(query)
# Verify graph engine was called
mock_graph_engine.get_graph_data.assert_called_once()

View file

@ -0,0 +1,80 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
class TestGraphSummaryCompletionRetriever:
@pytest.fixture
def mock_retriever(self):
return GraphSummaryCompletionRetriever(system_prompt_path="test_prompt.txt")
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.utils.completion.get_llm_client")
@patch("cognee.modules.retrieval.utils.completion.read_query_prompt")
@patch("cognee.modules.retrieval.utils.completion.render_prompt")
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_default_user")
async def test_get_completion_with_custom_system_prompt(
self,
mock_get_default_user,
mock_render_prompt,
mock_read_query_prompt,
mock_get_llm_client,
mock_retriever,
):
# Setup
query = "test query with custom prompt"
# Set custom system prompt
mock_retriever.user_prompt_path = "custom_user_prompt.txt"
mock_retriever.system_prompt_path = "custom_system_prompt.txt"
mock_llm_client = MagicMock()
mock_llm_client.acreate_structured_output = AsyncMock()
mock_llm_client.acreate_structured_output.return_value = (
"Generated graph summary completion response"
)
mock_get_llm_client.return_value = mock_llm_client
# Execute
results = await mock_retriever.get_completion(query, context="test context")
# Verify
assert len(results) == 1
# Verify render_prompt was called with custom prompt path
mock_render_prompt.assert_called_once()
assert mock_render_prompt.call_args[0][0] == "custom_user_prompt.txt"
mock_read_query_prompt.assert_called_once()
assert mock_read_query_prompt.call_args[0][0] == "custom_system_prompt.txt"
mock_llm_client.acreate_structured_output.assert_called_once()
@pytest.mark.asyncio
@patch(
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text"
)
@patch(
"cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text",
new_callable=AsyncMock,
)
async def test_resolve_edges_to_text_calls_super_and_summarizes(
self, mock_summarize_text, mock_resolve_edges_to_text, mock_retriever
):
"""Test resolve_edges_to_text calls the parent method and summarizes the result."""
mock_resolve_edges_to_text.return_value = "Raw graph edges text"
mock_summarize_text.return_value = "Summarized graph text"
result = await mock_retriever.resolve_edges_to_text(["mock_edge"])
mock_resolve_edges_to_text.assert_called_once_with(["mock_edge"])
mock_summarize_text.assert_called_once_with(
"Raw graph edges text", mock_retriever.summarize_prompt_path
)
assert result == "Summarized graph text"

View file

@ -0,0 +1,103 @@
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
from cognee.tests.tasks.descriptive_metrics.metrics_test_utils import create_connected_test_graph
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
import unittest
from cognee.infrastructure.databases.graph import get_graph_engine
class TestInsightsRetriever:
@pytest.fixture
def mock_retriever(self):
return InsightsRetriever()
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.insights_retriever.get_graph_engine")
async def test_get_context_with_existing_node(self, mock_get_graph_engine, mock_retriever):
"""Test get_context when node exists in graph."""
mock_graph = AsyncMock()
mock_get_graph_engine.return_value = mock_graph
# Mock graph response
mock_graph.extract_node.return_value = {"id": "123"}
mock_graph.get_connections.return_value = [
({"id": "123"}, {"relationship_name": "linked_to"}, {"id": "456"})
]
result = await mock_retriever.get_context("123")
assert isinstance(result, list)
assert len(result) == 1
assert result[0][0]["id"] == "123"
assert result[0][1]["relationship_name"] == "linked_to"
assert result[0][2]["id"] == "456"
mock_graph.extract_node.assert_called_once_with("123")
mock_graph.get_connections.assert_called_once_with("123")
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.insights_retriever.get_vector_engine")
async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever):
# Setup
query = "test query with no results"
mock_search_results = []
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
# Execute
results = await mock_retriever.get_completion(query)
# Verify
assert len(results) == 0
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.insights_retriever.get_graph_engine")
@patch("cognee.modules.retrieval.insights_retriever.get_vector_engine")
async def test_get_context_with_no_exact_node(
self, mock_get_vector_engine, mock_get_graph_engine, mock_retriever
):
"""Test get_context when node does not exist in the graph and vector search is used."""
mock_graph = AsyncMock()
mock_get_graph_engine.return_value = mock_graph
mock_graph.extract_node.return_value = None # Node does not exist
mock_vector = AsyncMock()
mock_get_vector_engine.return_value = mock_vector
mock_vector.search.side_effect = [
[AsyncMock(id="vec_1", score=0.4)], # Entity_name search
[AsyncMock(id="vec_2", score=0.3)], # EntityType_name search
]
mock_graph.get_connections.side_effect = lambda node_id: [
({"id": node_id}, {"relationship_name": "related_to"}, {"id": "456"})
]
result = await mock_retriever.get_context("non_existing_query")
assert isinstance(result, list)
assert len(result) == 2
assert result[0][0]["id"] == "vec_1"
assert result[0][1]["relationship_name"] == "related_to"
assert result[0][2]["id"] == "456"
assert result[1][0]["id"] == "vec_2"
assert result[1][1]["relationship_name"] == "related_to"
assert result[1][2]["id"] == "456"
@pytest.mark.asyncio
async def test_get_context_with_none_query(self, mock_retriever):
"""Test get_context with a None query (should return empty list)."""
result = await mock_retriever.get_context(None)
assert result == []
@pytest.mark.asyncio
async def test_get_completion_with_context(self, mock_retriever):
"""Test get_completion when context is already provided."""
test_context = [({"id": "123"}, {"relationship_name": "linked_to"}, {"id": "456"})]
result = await mock_retriever.get_completion("test_query", context=test_context)
assert result == test_context

View file

@ -0,0 +1,122 @@
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
class TestSummariesRetriever:
@pytest.fixture
def mock_retriever(self):
return SummariesRetriever()
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine")
async def test_get_completion(self, mock_get_vector_engine, mock_retriever):
# Setup
query = "test query"
doc_id1 = str(uuid.uuid4())
doc_id2 = str(uuid.uuid4())
# Mock search results
mock_result_1 = MagicMock()
mock_result_1.payload = {
"id": str(uuid.uuid4()),
"score": 0.95,
"payload": {
"text": "This is the first summary.",
"document_id": doc_id1,
"metadata": {"title": "Document 1"},
},
}
mock_result_2 = MagicMock()
mock_result_2.payload = {
"id": str(uuid.uuid4()),
"score": 0.85,
"payload": {
"text": "This is the second summary.",
"document_id": doc_id2,
"metadata": {"title": "Document 2"},
},
}
mock_search_results = [mock_result_1, mock_result_2]
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
# Execute
results = await mock_retriever.get_completion(query)
# Verify
assert len(results) == 2
# Check first result
assert results[0]["payload"]["text"] == "This is the first summary."
assert results[0]["payload"]["document_id"] == doc_id1
assert results[0]["payload"]["metadata"]["title"] == "Document 1"
assert results[0]["score"] == 0.95
# Check second result
assert results[1]["payload"]["text"] == "This is the second summary."
assert results[1]["payload"]["document_id"] == doc_id2
assert results[1]["payload"]["metadata"]["title"] == "Document 2"
assert results[1]["score"] == 0.85
# Verify search was called correctly
mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=5)
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine")
async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever):
# Setup
query = "test query with no results"
mock_search_results = []
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
# Execute
results = await mock_retriever.get_completion(query)
# Verify
assert len(results) == 0
mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=5)
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine")
async def test_get_completion_with_custom_limit(self, mock_get_vector_engine, mock_retriever):
# Setup
query = "test query with custom limit"
doc_id = str(uuid.uuid4())
# Mock search results
mock_result = MagicMock()
mock_result.payload = {
"id": str(uuid.uuid4()),
"score": 0.95,
"payload": {
"text": "This is a summary.",
"document_id": doc_id,
"metadata": {"title": "Document 1"},
},
}
mock_search_results = [mock_result]
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
# Set custom limit
mock_retriever.limit = 10
# Execute
results = await mock_retriever.get_completion(query)
# Verify
assert len(results) == 1
assert results[0]["payload"]["text"] == "This is a summary."
# Verify search was called with custom limit
mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=10)

View file

@ -0,0 +1,177 @@
import json
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from cognee.exceptions import InvalidValueError
from cognee.modules.search.methods.search import search, specific_search
from cognee.modules.search.types import SearchType
from cognee.modules.users.models import User
import sys
search_module = sys.modules.get("cognee.modules.search.methods.search")
@pytest.fixture
def mock_user():
user = MagicMock(spec=User)
user.id = uuid.uuid4()
return user
@pytest.mark.asyncio
@patch.object(search_module, "log_query")
@patch.object(search_module, "log_result")
@patch.object(search_module, "get_document_ids_for_user")
@patch.object(search_module, "specific_search")
@patch.object(search_module, "parse_id")
async def test_search(
mock_parse_id,
mock_specific_search,
mock_get_document_ids,
mock_log_result,
mock_log_query,
mock_user,
):
# Setup
query_text = "test query"
query_type = SearchType.CHUNKS
datasets = ["dataset1", "dataset2"]
# Mock the query logging
mock_query = MagicMock()
mock_query.id = uuid.uuid4()
mock_log_query.return_value = mock_query
# Mock document IDs
doc_id1 = uuid.uuid4()
doc_id2 = uuid.uuid4()
doc_id3 = uuid.uuid4() # This one will be filtered out
mock_get_document_ids.return_value = [doc_id1, doc_id2]
# Mock search results
search_results = [
{"document_id": str(doc_id1), "content": "Result 1"},
{"document_id": str(doc_id2), "content": "Result 2"},
{"document_id": str(doc_id3), "content": "Result 3"}, # Should be filtered out
]
mock_specific_search.return_value = search_results
# Mock parse_id to return the same UUID
mock_parse_id.side_effect = lambda x: uuid.UUID(x) if x else None
# Execute
results = await search(query_text, query_type, datasets, mock_user)
# Verify
mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
mock_get_document_ids.assert_called_once_with(mock_user.id, datasets)
mock_specific_search.assert_called_once_with(
query_type, query_text, mock_user, system_prompt_path="answer_simple_question.txt"
)
# Only the first two results should be included (doc_id3 is filtered out)
assert len(results) == 2
assert results[0]["document_id"] == str(doc_id1)
assert results[1]["document_id"] == str(doc_id2)
# Verify result logging
mock_log_result.assert_called_once()
# Check that the first argument is the query ID
assert mock_log_result.call_args[0][0] == mock_query.id
# The second argument should be the JSON string of the filtered results
# We can't directly compare the JSON strings due to potential ordering differences
# So we parse the JSON and compare the objects
logged_results = json.loads(mock_log_result.call_args[0][1])
assert len(logged_results) == 2
assert logged_results[0]["document_id"] == str(doc_id1)
assert logged_results[1]["document_id"] == str(doc_id2)
@pytest.mark.asyncio
@patch.object(search_module, "SummariesRetriever")
@patch.object(search_module, "send_telemetry")
async def test_specific_search_summaries(mock_send_telemetry, mock_summaries_retriever, mock_user):
# Setup
query = "test query"
query_type = SearchType.SUMMARIES
# Mock the retriever
mock_retriever = MagicMock()
mock_retriever.get_completion = AsyncMock()
mock_retriever.get_completion.return_value = [{"content": "Summary result"}]
mock_summaries_retriever.return_value = mock_retriever
# Execute
results = await specific_search(query_type, query, mock_user)
# Verify
mock_summaries_retriever.assert_called_once()
mock_retriever.get_completion.assert_called_once_with(query)
mock_send_telemetry.assert_called()
assert len(results) == 1
assert results[0]["content"] == "Summary result"
@pytest.mark.asyncio
@patch.object(search_module, "InsightsRetriever")
@patch.object(search_module, "send_telemetry")
async def test_specific_search_insights(mock_send_telemetry, mock_insights_retriever, mock_user):
# Setup
query = "test query"
query_type = SearchType.INSIGHTS
# Mock the retriever
mock_retriever = MagicMock()
mock_retriever.get_completion = AsyncMock()
mock_retriever.get_completion.return_value = [{"content": "Insight result"}]
mock_insights_retriever.return_value = mock_retriever
# Execute
results = await specific_search(query_type, query, mock_user)
# Verify
mock_insights_retriever.assert_called_once()
mock_retriever.get_completion.assert_called_once_with(query)
mock_send_telemetry.assert_called()
assert len(results) == 1
assert results[0]["content"] == "Insight result"
@pytest.mark.asyncio
@patch.object(search_module, "ChunksRetriever")
@patch.object(search_module, "send_telemetry")
async def test_specific_search_chunks(mock_send_telemetry, mock_chunks_retriever, mock_user):
# Setup
query = "test query"
query_type = SearchType.CHUNKS
# Mock the retriever
mock_retriever = MagicMock()
mock_retriever.get_completion = AsyncMock()
mock_retriever.get_completion.return_value = [{"content": "Chunk result"}]
mock_chunks_retriever.return_value = mock_retriever
# Execute
results = await specific_search(query_type, query, mock_user)
# Verify
mock_chunks_retriever.assert_called_once()
mock_retriever.get_completion.assert_called_once_with(query)
mock_send_telemetry.assert_called()
assert len(results) == 1
assert results[0]["content"] == "Chunk result"
@pytest.mark.asyncio
async def test_specific_search_invalid_type(mock_user):
# Setup
query = "test query"
query_type = "INVALID_TYPE" # Not a valid SearchType
# Execute and verify
with pytest.raises(InvalidValueError) as excinfo:
await specific_search(query_type, query, mock_user)
assert "Unsupported search type" in str(excinfo.value)