test: add tests for batch query graph completions
This commit is contained in:
parent
b05c93bf5f
commit
98e8d226eb
5 changed files with 567 additions and 4 deletions
|
|
@ -171,7 +171,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
async def get_completion(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
context: Optional[List[Edge]] = None,
|
||||
context: Optional[List[Edge] | List[List[Edge]]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
max_iter=4,
|
||||
response_model: Type = str,
|
||||
|
|
@ -216,16 +216,22 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
|
||||
completion_results = []
|
||||
if query_batch and len(query_batch) > 0:
|
||||
if not context:
|
||||
# Having a list is necessary to zip through it
|
||||
context = []
|
||||
for query in query_batch:
|
||||
context.append(None)
|
||||
|
||||
completion_results = await asyncio.gather(
|
||||
*[
|
||||
self._run_cot_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
context=context_el,
|
||||
conversation_history=conversation_history,
|
||||
max_iter=max_iter,
|
||||
response_model=response_model,
|
||||
)
|
||||
for query in query_batch
|
||||
for query, context_el in zip(query_batch, context)
|
||||
]
|
||||
)
|
||||
else:
|
||||
|
|
@ -237,11 +243,13 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
response_model=response_model,
|
||||
)
|
||||
|
||||
# TODO: Handle save interaction for batch queries
|
||||
if self.save_interaction and context and triplets and completion:
|
||||
await self.save_qa(
|
||||
question=query, answer=str(completion), context=context_text, triplets=triplets
|
||||
)
|
||||
|
||||
# TODO: Handle session save interaction for batch queries
|
||||
# Save to session cache if enabled
|
||||
if session_save:
|
||||
context_summary = await summarize_text(context_text)
|
||||
|
|
|
|||
|
|
@ -262,7 +262,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
return [completion]
|
||||
return completion if isinstance(completion, list) else [completion]
|
||||
|
||||
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
from itertools import cycle
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from uuid import UUID
|
||||
|
||||
|
|
@ -467,3 +468,272 @@ async def test_get_completion_zero_extension_rounds(mock_edge):
|
|||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_batch_queries_without_context(mock_edge):
|
||||
"""Test get_completion batch queries retrieves context when not provided."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
retriever = GraphCompletionContextExtensionRetriever()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||
return_value=[[mock_edge], [mock_edge]],
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
||||
return_value="Generated answer",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion(
|
||||
query_batch=["test query 1", "test query 2"], context_extension_rounds=1
|
||||
)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 2
|
||||
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_batch_queries_with_provided_context(mock_edge):
|
||||
"""Test get_completion batch queries uses provided context."""
|
||||
retriever = GraphCompletionContextExtensionRetriever()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
||||
return_value="Generated answer",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion(
|
||||
query_batch=["test query 1", "test query 2"],
|
||||
context=[[mock_edge], [mock_edge]],
|
||||
context_extension_rounds=1,
|
||||
)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 2
|
||||
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_batch_queries_context_extension_rounds(mock_edge):
|
||||
"""Test get_completion batch queries with multiple context extension rounds."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
retriever = GraphCompletionContextExtensionRetriever()
|
||||
|
||||
# Create a second edge for extension rounds
|
||||
mock_edge2 = MagicMock(spec=Edge)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch.object(
|
||||
retriever,
|
||||
"get_context",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[[[mock_edge], [mock_edge]], [[mock_edge2], [mock_edge2]]],
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
side_effect=cycle(["Resolved context", "Extended context"]), # Different contexts
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
||||
side_effect=[
|
||||
"Extension query",
|
||||
"Extension query",
|
||||
"Generated answer",
|
||||
"Generated answer",
|
||||
],
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion(
|
||||
query_batch=["test query 1", "test query 2"], context_extension_rounds=1
|
||||
)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 2
|
||||
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_batch_queries_context_extension_stops_early(mock_edge):
|
||||
"""Test get_completion batch queries stops early when no new triplets found."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
retriever = GraphCompletionContextExtensionRetriever()
|
||||
|
||||
with (
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
||||
side_effect=[
|
||||
"Extension query",
|
||||
"Extension query",
|
||||
"Generated answer",
|
||||
"Generated answer",
|
||||
],
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
# When get_context returns same triplets, the loop should stop early
|
||||
completion = await retriever.get_completion(
|
||||
query_batch=["test query 1", "test query 2"],
|
||||
context=[[mock_edge], [mock_edge]],
|
||||
context_extension_rounds=4,
|
||||
)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 2
|
||||
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_batch_queries_zero_extension_rounds(mock_edge):
|
||||
"""Test get_completion batch queries with zero context extension rounds."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
retriever = GraphCompletionContextExtensionRetriever()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch.object(
|
||||
retriever,
|
||||
"get_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[[mock_edge], [mock_edge]],
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
||||
return_value="Generated answer",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion(
|
||||
query_batch=["test query 1", "test query 2"], context_extension_rounds=0
|
||||
)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_batch_queries_with_response_model(mock_edge):
|
||||
"""Test get_completion batch queries with custom response model."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
class TestModel(BaseModel):
|
||||
answer: str
|
||||
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
retriever = GraphCompletionContextExtensionRetriever()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch.object(
|
||||
retriever,
|
||||
"get_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[[mock_edge], [mock_edge]],
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
||||
side_effect=[
|
||||
"Extension query",
|
||||
"Extension query",
|
||||
TestModel(answer="Test answer"),
|
||||
TestModel(answer="Test answer"),
|
||||
], # Extension query, then final answer
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion(
|
||||
query_batch=["test query 1", "test query 2"],
|
||||
response_model=TestModel,
|
||||
context_extension_rounds=1,
|
||||
)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 2
|
||||
assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from uuid import UUID
|
||||
from itertools import cycle
|
||||
|
||||
from cognee.modules.retrieval.graph_completion_cot_retriever import (
|
||||
GraphCompletionCotRetriever,
|
||||
|
|
@ -686,3 +687,131 @@ async def test_as_answer_text_with_basemodel():
|
|||
assert isinstance(result, str)
|
||||
assert "[Structured Response]" in result
|
||||
assert "test answer" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_batch_queries_with_context(mock_edge):
|
||||
"""Test get_completion batch queries with provided context."""
|
||||
retriever = GraphCompletionCotRetriever()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||
return_value="Generated answer",
|
||||
),
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
||||
return_value="Generated answer",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
|
||||
return_value="Rendered prompt",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
|
||||
return_value="System prompt",
|
||||
),
|
||||
patch.object(
|
||||
LLMGateway,
|
||||
"acreate_structured_output",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=cycle(["validation_result", "followup_question"]),
|
||||
),
|
||||
):
|
||||
completion = await retriever.get_completion(
|
||||
query_batch=["test query 1", "test query 2"],
|
||||
context=[[mock_edge], [mock_edge]],
|
||||
max_iter=1,
|
||||
)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 2
|
||||
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_batch_queries_without_context(mock_edge):
|
||||
"""Test get_completion batch queries without provided context."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
retriever = GraphCompletionCotRetriever()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||
return_value=[mock_edge],
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||
return_value="Generated answer",
|
||||
),
|
||||
):
|
||||
completion = await retriever.get_completion(
|
||||
query_batch=["test query 1", "test query 2"], max_iter=1
|
||||
)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 2
|
||||
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
|
||||
|
||||
|
||||
#
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_batch_queries_with_response_model(mock_edge):
|
||||
"""Test get_completion of batch queries with custom response model."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
class TestModel(BaseModel):
|
||||
answer: str
|
||||
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
retriever = GraphCompletionCotRetriever()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||
return_value=[mock_edge],
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||
return_value=TestModel(answer="Test answer"),
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion(
|
||||
query_batch=["test query 1", "test query 2"], response_model=TestModel, max_iter=1
|
||||
)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 2
|
||||
assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel)
|
||||
|
|
|
|||
|
|
@ -646,3 +646,159 @@ async def test_get_completion_with_save_interaction_all_conditions_met(mock_edge
|
|||
assert len(completion) == 1
|
||||
assert completion[0] == "Generated answer"
|
||||
mock_add_data.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_batch_queries_with_context(mock_edge):
|
||||
"""Test get_completion correctly handles batch queries."""
|
||||
retriever = GraphCompletionRetriever()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
|
||||
return_value="Generated answer",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion(
|
||||
query_batch=["test query 1", "test query 2"], context=[[mock_edge], [mock_edge]]
|
||||
)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 2
|
||||
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_batch_queries_without_context(mock_edge):
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
retriever = GraphCompletionRetriever()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||
return_value=[[mock_edge], [mock_edge]],
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
|
||||
return_value="Generated answer",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion(query_batch=["test query 1", "test query 2"])
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 2
|
||||
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_batch_queries_with_response_model(mock_edge):
|
||||
"""Test get_completion of batch queries with custom response model."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
class TestModel(BaseModel):
|
||||
answer: str
|
||||
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
retriever = GraphCompletionRetriever()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||
return_value=[[mock_edge], [mock_edge]],
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
|
||||
return_value=TestModel(answer="Test answer"),
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion(
|
||||
query_batch=["test query 1", "test query 2"], response_model=TestModel
|
||||
)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 2
|
||||
assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_batch_queries_empty_context(mock_edge):
|
||||
"""Test get_completion with empty context."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
retriever = GraphCompletionRetriever()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||
return_value=[[], []],
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
|
||||
return_value="Generated answer",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion(query_batch=["test query 1", "test query 2"])
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 2
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue