diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index d3df83516..b581f14d7 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -1,8 +1,10 @@ import pytest from unittest.mock import AsyncMock, patch, MagicMock +from uuid import UUID from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.infrastructure.llm.LLMGateway import LLMGateway @pytest.fixture @@ -27,16 +29,6 @@ async def test_get_triplets_inherited(mock_edge): assert triplets[0] == mock_edge -@pytest.mark.asyncio -async def test_init_defaults(): - """Test GraphCompletionCotRetriever initialization with defaults.""" - retriever = GraphCompletionCotRetriever() - - assert retriever.top_k == 5 - assert retriever.user_prompt_path == "graph_context_for_question.txt" - assert retriever.system_prompt_path == "answer_simple_question.txt" - - @pytest.mark.asyncio async def test_init_custom_params(): """Test GraphCompletionCotRetriever initialization with custom parameters.""" @@ -44,8 +36,606 @@ async def test_init_custom_params(): top_k=10, user_prompt_path="custom_user.txt", system_prompt_path="custom_system.txt", + validation_user_prompt_path="custom_validation_user.txt", + validation_system_prompt_path="custom_validation_system.txt", + followup_system_prompt_path="custom_followup_system.txt", + followup_user_prompt_path="custom_followup_user.txt", ) assert retriever.top_k == 10 assert retriever.user_prompt_path == "custom_user.txt" assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.validation_user_prompt_path == "custom_validation_user.txt" + assert retriever.validation_system_prompt_path == "custom_validation_system.txt" + assert retriever.followup_system_prompt_path == "custom_followup_system.txt" + assert retriever.followup_user_prompt_path == "custom_followup_user.txt" + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test GraphCompletionCotRetriever initialization with defaults.""" + retriever = GraphCompletionCotRetriever() + + assert retriever.validation_user_prompt_path == "cot_validation_user_prompt.txt" + assert retriever.validation_system_prompt_path == "cot_validation_system_prompt.txt" + assert retriever.followup_system_prompt_path == "cot_followup_system_prompt.txt" + assert retriever.followup_user_prompt_path == "cot_followup_user_prompt.txt" + + +@pytest.mark.asyncio +async def test_run_cot_completion_round_zero_with_context(mock_edge): + """Test _run_cot_completion round 0 with provided context.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + max_iter=1, + ) + + assert completion == "Generated answer" + assert context_text == "Resolved context" + assert len(triplets) >= 1 + + +@pytest.mark.asyncio +async def test_run_cot_completion_round_zero_without_context(mock_edge): + """Test _run_cot_completion round 0 without provided context.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=None, + max_iter=1, + ) + + assert completion == "Generated answer" + assert context_text == "Resolved context" + assert len(triplets) >= 1 + + +@pytest.mark.asyncio +async def test_run_cot_completion_multiple_rounds(mock_edge): + """Test _run_cot_completion with multiple rounds.""" + retriever = GraphCompletionCotRetriever() + + mock_edge2 = MagicMock(spec=Edge) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + side_effect=[[mock_edge], [mock_edge2]], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=[ + "validation_result", + "followup_question", + "validation_result2", + "followup_question2", + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + max_iter=2, + ) + + assert completion == "Generated answer" + assert context_text == "Resolved context" + assert len(triplets) >= 1 + + +@pytest.mark.asyncio +async def test_run_cot_completion_with_conversation_history(mock_edge): + """Test _run_cot_completion with conversation history.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ) as mock_generate, + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + conversation_history="Previous conversation", + max_iter=1, + ) + + assert completion == "Generated answer" + call_kwargs = mock_generate.call_args[1] + assert call_kwargs.get("conversation_history") == "Previous conversation" + + +@pytest.mark.asyncio +async def test_run_cot_completion_with_response_model(mock_edge): + """Test _run_cot_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + response_model=TestModel, + max_iter=1, + ) + + assert isinstance(completion, TestModel) + assert completion.answer == "Test answer" + + +@pytest.mark.asyncio +async def test_run_cot_completion_empty_conversation_history(mock_edge): + """Test _run_cot_completion with empty conversation history.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ) as mock_generate, + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + conversation_history="", + max_iter=1, + ) + + assert completion == "Generated answer" + # Verify conversation_history was passed as None when empty + call_kwargs = mock_generate.call_args[1] + assert call_kwargs.get("conversation_history") is None + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_edge): + """Test get_completion retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_edge): + """Test get_completion uses provided context.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_edge): + """Test get_completion with session caching enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.save_conversation_history", + ) as mock_save, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion( + "test query", session_id="test_session", max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction(mock_edge): + """Test get_completion with save_interaction enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionCotRetriever(save_interaction=True) + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + # Pass context so save_interaction condition is met + completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + mock_add_data.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_edge): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "test query", response_model=TestModel, max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_edge): + """Test get_completion with session config but no user ID.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("test query", max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_no_context(mock_edge): + """Test get_completion with save_interaction but no context provided.""" + retriever = GraphCompletionCotRetriever(save_interaction=True) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=None, max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1