increases coverage in context extension retriever
This commit is contained in:
parent
670a0fbb69
commit
88307ce382
1 changed files with 416 additions and 0 deletions
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||
GraphCompletionContextExtensionRetriever,
|
||||
|
|
@ -46,8 +47,423 @@ async def test_init_custom_params():
|
|||
top_k=10,
|
||||
user_prompt_path="custom_user.txt",
|
||||
system_prompt_path="custom_system.txt",
|
||||
system_prompt="Custom prompt",
|
||||
node_type=str,
|
||||
node_name=["node1"],
|
||||
save_interaction=True,
|
||||
wide_search_top_k=200,
|
||||
triplet_distance_penalty=5.0,
|
||||
)
|
||||
|
||||
assert retriever.top_k == 10
|
||||
assert retriever.user_prompt_path == "custom_user.txt"
|
||||
assert retriever.system_prompt_path == "custom_system.txt"
|
||||
assert retriever.system_prompt == "Custom prompt"
|
||||
assert retriever.node_type == str
|
||||
assert retriever.node_name == ["node1"]
|
||||
assert retriever.save_interaction is True
|
||||
assert retriever.wide_search_top_k == 200
|
||||
assert retriever.triplet_distance_penalty == 5.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_without_context(mock_edge):
|
||||
"""Test get_completion retrieves context when not provided."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
retriever = GraphCompletionContextExtensionRetriever()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||
return_value=[mock_edge],
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
||||
return_value="Generated answer",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion("test query", context_extension_rounds=1)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 1
|
||||
assert completion[0] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_with_provided_context(mock_edge):
|
||||
"""Test get_completion uses provided context."""
|
||||
retriever = GraphCompletionContextExtensionRetriever()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
||||
return_value="Generated answer",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion(
|
||||
"test query", context=[mock_edge], context_extension_rounds=1
|
||||
)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 1
|
||||
assert completion[0] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_context_extension_rounds(mock_edge):
|
||||
"""Test get_completion with multiple context extension rounds."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
retriever = GraphCompletionContextExtensionRetriever()
|
||||
|
||||
# Create a second edge for extension rounds
|
||||
mock_edge2 = MagicMock(spec=Edge)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch.object(
|
||||
retriever,
|
||||
"get_context",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[[mock_edge], [mock_edge2]],
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
side_effect=["Resolved context", "Extended context"], # Different contexts
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
||||
side_effect=[
|
||||
"Extension query",
|
||||
"Generated answer",
|
||||
], # Query for extension, then final answer
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion("test query", context_extension_rounds=1)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 1
|
||||
assert completion[0] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_context_extension_stops_early(mock_edge):
|
||||
"""Test get_completion stops early when no new triplets found."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
retriever = GraphCompletionContextExtensionRetriever()
|
||||
|
||||
with (
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
||||
side_effect=[
|
||||
"Extension query",
|
||||
"Generated answer",
|
||||
],
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
# When get_context returns same triplets, the loop should stop early
|
||||
completion = await retriever.get_completion(
|
||||
"test query", context=[mock_edge], context_extension_rounds=4
|
||||
)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 1
|
||||
assert completion[0] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_with_session(mock_edge):
|
||||
"""Test get_completion with session caching enabled."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
retriever = GraphCompletionContextExtensionRetriever()
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "test-user-id"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.get_conversation_history",
|
||||
return_value="Previous conversation",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.summarize_text",
|
||||
return_value="Context summary",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
||||
side_effect=[
|
||||
"Extension query",
|
||||
"Generated answer",
|
||||
], # Extension query, then final answer
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.save_conversation_history",
|
||||
) as mock_save,
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user"
|
||||
) as mock_session_user,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = True
|
||||
mock_cache_config.return_value = mock_config
|
||||
mock_session_user.get.return_value = mock_user
|
||||
|
||||
completion = await retriever.get_completion(
|
||||
"test query", session_id="test_session", context_extension_rounds=1
|
||||
)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 1
|
||||
assert completion[0] == "Generated answer"
|
||||
mock_save.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_with_save_interaction(mock_edge):
|
||||
"""Test get_completion with save_interaction enabled."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
mock_graph_engine.add_edges = AsyncMock()
|
||||
|
||||
retriever = GraphCompletionContextExtensionRetriever(save_interaction=True)
|
||||
|
||||
mock_node1 = MagicMock()
|
||||
mock_node2 = MagicMock()
|
||||
mock_edge.node1 = mock_node1
|
||||
mock_edge.node2 = mock_node2
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
||||
side_effect=[
|
||||
"Extension query",
|
||||
"Generated answer",
|
||||
], # Extension query, then final answer
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
|
||||
side_effect=[
|
||||
UUID("550e8400-e29b-41d4-a716-446655440000"),
|
||||
UUID("550e8400-e29b-41d4-a716-446655440001"),
|
||||
],
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.add_data_points",
|
||||
) as mock_add_data,
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion(
|
||||
"test query", context=[mock_edge], context_extension_rounds=1
|
||||
)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 1
|
||||
mock_add_data.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_with_response_model(mock_edge):
|
||||
"""Test get_completion with custom response model."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
class TestModel(BaseModel):
|
||||
answer: str
|
||||
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
retriever = GraphCompletionContextExtensionRetriever()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
||||
side_effect=[
|
||||
"Extension query",
|
||||
TestModel(answer="Test answer"),
|
||||
], # Extension query, then final answer
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion(
|
||||
"test query", response_model=TestModel, context_extension_rounds=1
|
||||
)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 1
|
||||
assert isinstance(completion[0], TestModel)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_with_session_no_user_id(mock_edge):
|
||||
"""Test get_completion with session config but no user ID."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
retriever = GraphCompletionContextExtensionRetriever()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
||||
side_effect=[
|
||||
"Extension query",
|
||||
"Generated answer",
|
||||
], # Extension query, then final answer
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user"
|
||||
) as mock_session_user,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = True
|
||||
mock_cache_config.return_value = mock_config
|
||||
mock_session_user.get.return_value = None # No user
|
||||
|
||||
completion = await retriever.get_completion("test query", context_extension_rounds=1)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_zero_extension_rounds(mock_edge):
|
||||
"""Test get_completion with zero context extension rounds."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
retriever = GraphCompletionContextExtensionRetriever()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
||||
return_value="Generated answer",
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
||||
) as mock_cache_config,
|
||||
):
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion("test query", context_extension_rounds=0)
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 1
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue