test: add tests for batch query graph completions

This commit is contained in:
Andrej Milicevic 2026-01-16 11:57:03 +01:00
parent b05c93bf5f
commit 98e8d226eb
5 changed files with 567 additions and 4 deletions

View file

@ -171,7 +171,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
async def get_completion(
self,
query: Optional[str] = None,
context: Optional[List[Edge]] = None,
context: Optional[List[Edge] | List[List[Edge]]] = None,
session_id: Optional[str] = None,
max_iter=4,
response_model: Type = str,
@ -216,16 +216,22 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
completion_results = []
if query_batch and len(query_batch) > 0:
if not context:
# Having a list is necessary to zip through it
context = []
for query in query_batch:
context.append(None)
completion_results = await asyncio.gather(
*[
self._run_cot_completion(
query=query,
context=context,
context=context_el,
conversation_history=conversation_history,
max_iter=max_iter,
response_model=response_model,
)
for query in query_batch
for query, context_el in zip(query_batch, context)
]
)
else:
@ -237,11 +243,13 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
response_model=response_model,
)
# TODO: Handle save interaction for batch queries
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=str(completion), context=context_text, triplets=triplets
)
# TODO: Handle session save interaction for batch queries
# Save to session cache if enabled
if session_save:
context_summary = await summarize_text(context_text)

View file

@ -262,7 +262,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
session_id=session_id,
)
return [completion]
return completion if isinstance(completion, list) else [completion]
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
"""

View file

@ -1,4 +1,5 @@
import pytest
from itertools import cycle
from unittest.mock import AsyncMock, patch, MagicMock
from uuid import UUID
@ -467,3 +468,272 @@ async def test_get_completion_zero_extension_rounds(mock_edge):
assert isinstance(completion, list)
assert len(completion) == 1
@pytest.mark.asyncio
async def test_get_completion_batch_queries_without_context(mock_edge):
"""Test get_completion batch queries retrieves context when not provided."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[[mock_edge], [mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"], context_extension_rounds=1
)
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_provided_context(mock_edge):
"""Test get_completion batch queries uses provided context."""
retriever = GraphCompletionContextExtensionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"],
context=[[mock_edge], [mock_edge]],
context_extension_rounds=1,
)
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_context_extension_rounds(mock_edge):
"""Test get_completion batch queries with multiple context extension rounds."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
# Create a second edge for extension rounds
mock_edge2 = MagicMock(spec=Edge)
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(
retriever,
"get_context",
new_callable=AsyncMock,
side_effect=[[[mock_edge], [mock_edge]], [[mock_edge2], [mock_edge2]]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
side_effect=cycle(["Resolved context", "Extended context"]), # Different contexts
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
"Extension query",
"Generated answer",
"Generated answer",
],
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"], context_extension_rounds=1
)
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_context_extension_stops_early(mock_edge):
"""Test get_completion batch queries stops early when no new triplets found."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
with (
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
"Extension query",
"Generated answer",
"Generated answer",
],
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
# When get_context returns same triplets, the loop should stop early
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"],
context=[[mock_edge], [mock_edge]],
context_extension_rounds=4,
)
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_zero_extension_rounds(mock_edge):
"""Test get_completion batch queries with zero context extension rounds."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(
retriever,
"get_context",
new_callable=AsyncMock,
return_value=[[mock_edge], [mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"], context_extension_rounds=0
)
assert isinstance(completion, list)
assert len(completion) == 2
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_response_model(mock_edge):
"""Test get_completion batch queries with custom response model."""
from pydantic import BaseModel
class TestModel(BaseModel):
answer: str
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(
retriever,
"get_context",
new_callable=AsyncMock,
return_value=[[mock_edge], [mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
"Extension query",
TestModel(answer="Test answer"),
TestModel(answer="Test answer"),
], # Extension query, then final answer
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"],
response_model=TestModel,
context_extension_rounds=1,
)
assert isinstance(completion, list)
assert len(completion) == 2
assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel)

View file

@ -1,6 +1,7 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from uuid import UUID
from itertools import cycle
from cognee.modules.retrieval.graph_completion_cot_retriever import (
GraphCompletionCotRetriever,
@ -686,3 +687,131 @@ async def test_as_answer_text_with_basemodel():
assert isinstance(result, str)
assert "[Structured Response]" in result
assert "test answer" in result
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_context(mock_edge):
"""Test get_completion batch queries with provided context."""
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
return_value="Rendered prompt",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
return_value="System prompt",
),
patch.object(
LLMGateway,
"acreate_structured_output",
new_callable=AsyncMock,
side_effect=cycle(["validation_result", "followup_question"]),
),
):
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"],
context=[[mock_edge], [mock_edge]],
max_iter=1,
)
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_without_context(mock_edge):
"""Test get_completion batch queries without provided context."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
):
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"], max_iter=1
)
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
#
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_response_model(mock_edge):
"""Test get_completion of batch queries with custom response model."""
from pydantic import BaseModel
class TestModel(BaseModel):
answer: str
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value=TestModel(answer="Test answer"),
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"], response_model=TestModel, max_iter=1
)
assert isinstance(completion, list)
assert len(completion) == 2
assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel)

View file

@ -646,3 +646,159 @@ async def test_get_completion_with_save_interaction_all_conditions_met(mock_edge
assert len(completion) == 1
assert completion[0] == "Generated answer"
mock_add_data.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_context(mock_edge):
"""Test get_completion correctly handles batch queries."""
retriever = GraphCompletionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"], context=[[mock_edge], [mock_edge]]
)
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_without_context(mock_edge):
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[[mock_edge], [mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(query_batch=["test query 1", "test query 2"])
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_response_model(mock_edge):
"""Test get_completion of batch queries with custom response model."""
from pydantic import BaseModel
class TestModel(BaseModel):
answer: str
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[[mock_edge], [mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value=TestModel(answer="Test answer"),
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"], response_model=TestModel
)
assert isinstance(completion, list)
assert len(completion) == 2
assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel)
@pytest.mark.asyncio
async def test_get_completion_batch_queries_empty_context(mock_edge):
"""Test get_completion with empty context."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[[], []],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(query_batch=["test query 1", "test query 2"])
assert isinstance(completion, list)
assert len(completion) == 2