increases coverage for cot completion retriever
This commit is contained in:
parent
36e82909dc
commit
670a0fbb69
1 changed files with 600 additions and 10 deletions
|
|
@ -1,8 +1,10 @@
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock, patch, MagicMock
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
|
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -27,16 +29,6 @@ async def test_get_triplets_inherited(mock_edge):
|
||||||
assert triplets[0] == mock_edge
|
assert triplets[0] == mock_edge
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_init_defaults():
|
|
||||||
"""Test GraphCompletionCotRetriever initialization with defaults."""
|
|
||||||
retriever = GraphCompletionCotRetriever()
|
|
||||||
|
|
||||||
assert retriever.top_k == 5
|
|
||||||
assert retriever.user_prompt_path == "graph_context_for_question.txt"
|
|
||||||
assert retriever.system_prompt_path == "answer_simple_question.txt"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_init_custom_params():
|
async def test_init_custom_params():
|
||||||
"""Test GraphCompletionCotRetriever initialization with custom parameters."""
|
"""Test GraphCompletionCotRetriever initialization with custom parameters."""
|
||||||
|
|
@ -44,8 +36,606 @@ async def test_init_custom_params():
|
||||||
top_k=10,
|
top_k=10,
|
||||||
user_prompt_path="custom_user.txt",
|
user_prompt_path="custom_user.txt",
|
||||||
system_prompt_path="custom_system.txt",
|
system_prompt_path="custom_system.txt",
|
||||||
|
validation_user_prompt_path="custom_validation_user.txt",
|
||||||
|
validation_system_prompt_path="custom_validation_system.txt",
|
||||||
|
followup_system_prompt_path="custom_followup_system.txt",
|
||||||
|
followup_user_prompt_path="custom_followup_user.txt",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert retriever.top_k == 10
|
assert retriever.top_k == 10
|
||||||
assert retriever.user_prompt_path == "custom_user.txt"
|
assert retriever.user_prompt_path == "custom_user.txt"
|
||||||
assert retriever.system_prompt_path == "custom_system.txt"
|
assert retriever.system_prompt_path == "custom_system.txt"
|
||||||
|
assert retriever.validation_user_prompt_path == "custom_validation_user.txt"
|
||||||
|
assert retriever.validation_system_prompt_path == "custom_validation_system.txt"
|
||||||
|
assert retriever.followup_system_prompt_path == "custom_followup_system.txt"
|
||||||
|
assert retriever.followup_user_prompt_path == "custom_followup_user.txt"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_init_defaults():
|
||||||
|
"""Test GraphCompletionCotRetriever initialization with defaults."""
|
||||||
|
retriever = GraphCompletionCotRetriever()
|
||||||
|
|
||||||
|
assert retriever.validation_user_prompt_path == "cot_validation_user_prompt.txt"
|
||||||
|
assert retriever.validation_system_prompt_path == "cot_validation_system_prompt.txt"
|
||||||
|
assert retriever.followup_system_prompt_path == "cot_followup_system_prompt.txt"
|
||||||
|
assert retriever.followup_user_prompt_path == "cot_followup_user_prompt.txt"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cot_completion_round_zero_with_context(mock_edge):
|
||||||
|
"""Test _run_cot_completion round 0 with provided context."""
|
||||||
|
retriever = GraphCompletionCotRetriever()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
return_value="Resolved context",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
|
return_value="Generated answer",
|
||||||
|
),
|
||||||
|
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
||||||
|
return_value="Generated answer",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
|
||||||
|
return_value="Rendered prompt",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
|
||||||
|
return_value="System prompt",
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
LLMGateway,
|
||||||
|
"acreate_structured_output",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=["validation_result", "followup_question"],
|
||||||
|
),
|
||||||
|
):
|
||||||
|
completion, context_text, triplets = await retriever._run_cot_completion(
|
||||||
|
query="test query",
|
||||||
|
context=[mock_edge],
|
||||||
|
max_iter=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert completion == "Generated answer"
|
||||||
|
assert context_text == "Resolved context"
|
||||||
|
assert len(triplets) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cot_completion_round_zero_without_context(mock_edge):
|
||||||
|
"""Test _run_cot_completion round 0 without provided context."""
|
||||||
|
mock_graph_engine = AsyncMock()
|
||||||
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
retriever = GraphCompletionCotRetriever()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||||
|
return_value=mock_graph_engine,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||||
|
return_value=[mock_edge],
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
return_value="Resolved context",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
|
return_value="Generated answer",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
completion, context_text, triplets = await retriever._run_cot_completion(
|
||||||
|
query="test query",
|
||||||
|
context=None,
|
||||||
|
max_iter=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert completion == "Generated answer"
|
||||||
|
assert context_text == "Resolved context"
|
||||||
|
assert len(triplets) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cot_completion_multiple_rounds(mock_edge):
|
||||||
|
"""Test _run_cot_completion with multiple rounds."""
|
||||||
|
retriever = GraphCompletionCotRetriever()
|
||||||
|
|
||||||
|
mock_edge2 = MagicMock(spec=Edge)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
return_value="Resolved context",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
|
return_value="Generated answer",
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
retriever,
|
||||||
|
"get_context",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=[[mock_edge], [mock_edge2]],
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
|
||||||
|
return_value="Rendered prompt",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
|
||||||
|
return_value="System prompt",
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
LLMGateway,
|
||||||
|
"acreate_structured_output",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=[
|
||||||
|
"validation_result",
|
||||||
|
"followup_question",
|
||||||
|
"validation_result2",
|
||||||
|
"followup_question2",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
||||||
|
return_value="Generated answer",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
completion, context_text, triplets = await retriever._run_cot_completion(
|
||||||
|
query="test query",
|
||||||
|
context=[mock_edge],
|
||||||
|
max_iter=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert completion == "Generated answer"
|
||||||
|
assert context_text == "Resolved context"
|
||||||
|
assert len(triplets) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cot_completion_with_conversation_history(mock_edge):
|
||||||
|
"""Test _run_cot_completion with conversation history."""
|
||||||
|
retriever = GraphCompletionCotRetriever()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
return_value="Resolved context",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
|
return_value="Generated answer",
|
||||||
|
) as mock_generate,
|
||||||
|
):
|
||||||
|
completion, context_text, triplets = await retriever._run_cot_completion(
|
||||||
|
query="test query",
|
||||||
|
context=[mock_edge],
|
||||||
|
conversation_history="Previous conversation",
|
||||||
|
max_iter=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert completion == "Generated answer"
|
||||||
|
call_kwargs = mock_generate.call_args[1]
|
||||||
|
assert call_kwargs.get("conversation_history") == "Previous conversation"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cot_completion_with_response_model(mock_edge):
|
||||||
|
"""Test _run_cot_completion with custom response model."""
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class TestModel(BaseModel):
|
||||||
|
answer: str
|
||||||
|
|
||||||
|
retriever = GraphCompletionCotRetriever()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
return_value="Resolved context",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
|
return_value=TestModel(answer="Test answer"),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
completion, context_text, triplets = await retriever._run_cot_completion(
|
||||||
|
query="test query",
|
||||||
|
context=[mock_edge],
|
||||||
|
response_model=TestModel,
|
||||||
|
max_iter=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(completion, TestModel)
|
||||||
|
assert completion.answer == "Test answer"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cot_completion_empty_conversation_history(mock_edge):
|
||||||
|
"""Test _run_cot_completion with empty conversation history."""
|
||||||
|
retriever = GraphCompletionCotRetriever()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
return_value="Resolved context",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
|
return_value="Generated answer",
|
||||||
|
) as mock_generate,
|
||||||
|
):
|
||||||
|
completion, context_text, triplets = await retriever._run_cot_completion(
|
||||||
|
query="test query",
|
||||||
|
context=[mock_edge],
|
||||||
|
conversation_history="",
|
||||||
|
max_iter=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert completion == "Generated answer"
|
||||||
|
# Verify conversation_history was passed as None when empty
|
||||||
|
call_kwargs = mock_generate.call_args[1]
|
||||||
|
assert call_kwargs.get("conversation_history") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_completion_without_context(mock_edge):
|
||||||
|
"""Test get_completion retrieves context when not provided."""
|
||||||
|
mock_graph_engine = AsyncMock()
|
||||||
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
retriever = GraphCompletionCotRetriever()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||||
|
return_value=mock_graph_engine,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||||
|
return_value=[mock_edge],
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
return_value="Resolved context",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
|
return_value="Generated answer",
|
||||||
|
),
|
||||||
|
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
||||||
|
return_value="Generated answer",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
|
||||||
|
return_value="Rendered prompt",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
|
||||||
|
return_value="System prompt",
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
LLMGateway,
|
||||||
|
"acreate_structured_output",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=["validation_result", "followup_question"],
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
|
||||||
|
) as mock_cache_config,
|
||||||
|
):
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.caching = False
|
||||||
|
mock_cache_config.return_value = mock_config
|
||||||
|
|
||||||
|
completion = await retriever.get_completion("test query", max_iter=1)
|
||||||
|
|
||||||
|
assert isinstance(completion, list)
|
||||||
|
assert len(completion) == 1
|
||||||
|
assert completion[0] == "Generated answer"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_completion_with_provided_context(mock_edge):
|
||||||
|
"""Test get_completion uses provided context."""
|
||||||
|
retriever = GraphCompletionCotRetriever()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
return_value="Resolved context",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
|
return_value="Generated answer",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
|
||||||
|
) as mock_cache_config,
|
||||||
|
):
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.caching = False
|
||||||
|
mock_cache_config.return_value = mock_config
|
||||||
|
|
||||||
|
completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1)
|
||||||
|
|
||||||
|
assert isinstance(completion, list)
|
||||||
|
assert len(completion) == 1
|
||||||
|
assert completion[0] == "Generated answer"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_completion_with_session(mock_edge):
|
||||||
|
"""Test get_completion with session caching enabled."""
|
||||||
|
mock_graph_engine = AsyncMock()
|
||||||
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
retriever = GraphCompletionCotRetriever()
|
||||||
|
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.id = "test-user-id"
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||||
|
return_value=mock_graph_engine,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||||
|
return_value=[mock_edge],
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
return_value="Resolved context",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.get_conversation_history",
|
||||||
|
return_value="Previous conversation",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.summarize_text",
|
||||||
|
return_value="Context summary",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
|
return_value="Generated answer",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.save_conversation_history",
|
||||||
|
) as mock_save,
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
|
||||||
|
) as mock_cache_config,
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.session_user"
|
||||||
|
) as mock_session_user,
|
||||||
|
):
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.caching = True
|
||||||
|
mock_cache_config.return_value = mock_config
|
||||||
|
mock_session_user.get.return_value = mock_user
|
||||||
|
|
||||||
|
completion = await retriever.get_completion(
|
||||||
|
"test query", session_id="test_session", max_iter=1
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(completion, list)
|
||||||
|
assert len(completion) == 1
|
||||||
|
assert completion[0] == "Generated answer"
|
||||||
|
mock_save.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_completion_with_save_interaction(mock_edge):
|
||||||
|
"""Test get_completion with save_interaction enabled."""
|
||||||
|
mock_graph_engine = AsyncMock()
|
||||||
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||||
|
mock_graph_engine.add_edges = AsyncMock()
|
||||||
|
|
||||||
|
retriever = GraphCompletionCotRetriever(save_interaction=True)
|
||||||
|
|
||||||
|
mock_node1 = MagicMock()
|
||||||
|
mock_node2 = MagicMock()
|
||||||
|
mock_edge.node1 = mock_node1
|
||||||
|
mock_edge.node2 = mock_node2
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
return_value="Resolved context",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
|
return_value="Generated answer",
|
||||||
|
),
|
||||||
|
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
||||||
|
return_value="Generated answer",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
|
||||||
|
return_value="Rendered prompt",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
|
||||||
|
return_value="System prompt",
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
LLMGateway,
|
||||||
|
"acreate_structured_output",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=["validation_result", "followup_question"],
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
|
||||||
|
side_effect=[
|
||||||
|
UUID("550e8400-e29b-41d4-a716-446655440000"),
|
||||||
|
UUID("550e8400-e29b-41d4-a716-446655440001"),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.add_data_points",
|
||||||
|
) as mock_add_data,
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
|
||||||
|
) as mock_cache_config,
|
||||||
|
):
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.caching = False
|
||||||
|
mock_cache_config.return_value = mock_config
|
||||||
|
|
||||||
|
# Pass context so save_interaction condition is met
|
||||||
|
completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1)
|
||||||
|
|
||||||
|
assert isinstance(completion, list)
|
||||||
|
assert len(completion) == 1
|
||||||
|
mock_add_data.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_completion_with_response_model(mock_edge):
|
||||||
|
"""Test get_completion with custom response model."""
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class TestModel(BaseModel):
|
||||||
|
answer: str
|
||||||
|
|
||||||
|
mock_graph_engine = AsyncMock()
|
||||||
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
retriever = GraphCompletionCotRetriever()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||||
|
return_value=mock_graph_engine,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||||
|
return_value=[mock_edge],
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
return_value="Resolved context",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
|
return_value=TestModel(answer="Test answer"),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
|
||||||
|
) as mock_cache_config,
|
||||||
|
):
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.caching = False
|
||||||
|
mock_cache_config.return_value = mock_config
|
||||||
|
|
||||||
|
completion = await retriever.get_completion(
|
||||||
|
"test query", response_model=TestModel, max_iter=1
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(completion, list)
|
||||||
|
assert len(completion) == 1
|
||||||
|
assert isinstance(completion[0], TestModel)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_completion_with_session_no_user_id(mock_edge):
|
||||||
|
"""Test get_completion with session config but no user ID."""
|
||||||
|
mock_graph_engine = AsyncMock()
|
||||||
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
retriever = GraphCompletionCotRetriever()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||||
|
return_value=mock_graph_engine,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||||
|
return_value=[mock_edge],
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
return_value="Resolved context",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
|
return_value="Generated answer",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
|
||||||
|
) as mock_cache_config,
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.session_user"
|
||||||
|
) as mock_session_user,
|
||||||
|
):
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.caching = True
|
||||||
|
mock_cache_config.return_value = mock_config
|
||||||
|
mock_session_user.get.return_value = None # No user
|
||||||
|
|
||||||
|
completion = await retriever.get_completion("test query", max_iter=1)
|
||||||
|
|
||||||
|
assert isinstance(completion, list)
|
||||||
|
assert len(completion) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_completion_with_save_interaction_no_context(mock_edge):
|
||||||
|
"""Test get_completion with save_interaction but no context provided."""
|
||||||
|
retriever = GraphCompletionCotRetriever(save_interaction=True)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
return_value="Resolved context",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
|
return_value="Generated answer",
|
||||||
|
),
|
||||||
|
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
||||||
|
return_value="Generated answer",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
|
||||||
|
return_value="Rendered prompt",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
|
||||||
|
return_value="System prompt",
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
LLMGateway,
|
||||||
|
"acreate_structured_output",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=["validation_result", "followup_question"],
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
|
||||||
|
) as mock_cache_config,
|
||||||
|
):
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.caching = False
|
||||||
|
mock_cache_config.return_value = mock_config
|
||||||
|
|
||||||
|
completion = await retriever.get_completion("test query", context=None, max_iter=1)
|
||||||
|
|
||||||
|
assert isinstance(completion, list)
|
||||||
|
assert len(completion) == 1
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue