From a5494513d76420856d2fab2a571ec379f3c21288 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Tue, 20 Jan 2026 15:11:58 +0100 Subject: [PATCH] fix: retrievers work with duplicates, added tests also, plus PR change requests done --- .../retrieval/exceptions/exceptions.py | 10 +++ ..._completion_context_extension_retriever.py | 49 +++++++++------ .../graph_completion_cot_retriever.py | 61 +++++++++---------- .../retrieval/graph_completion_retriever.py | 11 +++- cognee/modules/retrieval/utils/query_state.py | 34 +++++++++++ ...letion_retriever_context_extension_test.py | 57 +++++++++++++++++ .../graph_completion_retriever_cot_test.py | 37 ++++++++++- .../graph_completion_retriever_test.py | 40 ++++++++++++ 8 files changed, 244 insertions(+), 55 deletions(-) create mode 100644 cognee/modules/retrieval/utils/query_state.py diff --git a/cognee/modules/retrieval/exceptions/exceptions.py b/cognee/modules/retrieval/exceptions/exceptions.py index 3e934909b..0efaf7351 100644 --- a/cognee/modules/retrieval/exceptions/exceptions.py +++ b/cognee/modules/retrieval/exceptions/exceptions.py @@ -40,3 +40,13 @@ class CollectionDistancesNotFoundError(CogneeValidationError): status_code: int = status.HTTP_404_NOT_FOUND, ): super().__init__(message, name, status_code) + + +class QueryValidationError(CogneeValidationError): + def __init__( + self, + message: str = "Queries not supplied in the correct format.", + name: str = "QueryValidationError", + status_code: int = status.HTTP_422_UNPROCESSABLE_CONTENT, + ): + super().__init__(message, name, status_code) diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index d005759c9..ef7708dce 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -1,6 +1,8 @@ import asyncio from typing import Optional, List, Type, Any 
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError +from cognee.modules.retrieval.utils.query_state import QueryState from cognee.modules.retrieval.utils.validate_queries import validate_queries from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever @@ -15,19 +17,6 @@ from cognee.infrastructure.databases.cache.config import CacheConfig logger = get_logger() -class QueryState: - """ - Helper class containing all necessary information about the query state: - the triplets and context associated with it, and also a check whether - it has fully extended the context. - """ - - def __init__(self, triplets: List[Edge], context_text: str, finished_extending_context: bool): - self.triplets = triplets - self.context_text = context_text - self.finished_extending_context = finished_extending_context - - class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): """ Handles graph context completion for question answering tasks, extending context based @@ -112,13 +101,17 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): session_save = user_id and cache_config.caching if query_batch and session_save: - raise ValueError("You cannot use batch queries with session saving currently.") + raise QueryValidationError( + message="You cannot use batch queries with session saving currently." + ) if query_batch and self.save_interaction: - raise ValueError("Cannot use batch queries with interaction saving currently.") + raise QueryValidationError( + message="Cannot use batch queries with interaction saving currently." 
+ ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: - raise ValueError(msg) + raise QueryValidationError(message=msg) triplets_batch = context @@ -190,7 +183,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): # Get new triplets, and merge them with existing ones, filtering out duplicates new_triplets_batch = await self.get_context(query_batch=completions) - for batched_query, batched_new_triplets in zip(query_batch, new_triplets_batch): + for batched_query, batched_new_triplets in zip( + finished_queries_states.keys(), new_triplets_batch + ): finished_queries_states[batched_query].triplets = list( dict.fromkeys( finished_queries_states[batched_query].triplets + batched_new_triplets @@ -207,7 +202,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ) # Update context_texts in query states - for batched_query, batched_context_text in zip(query_batch, context_text_batch): + for batched_query, batched_context_text in zip( + finished_queries_states.keys(), context_text_batch + ): if not finished_queries_states[batched_query].finished_extending_context: finished_queries_states[batched_query].context_text = batched_context_text @@ -216,7 +213,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): for batched_query_state in finished_queries_states.values() ] - for batched_query, prev_size, new_size in zip(query_batch, prev_sizes, new_sizes): + for batched_query, prev_size, new_size in zip( + finished_queries_states.keys(), prev_sizes, new_sizes + ): # Mark done queries accordingly if prev_size == new_size: finished_queries_states[batched_query].finished_extending_context = True @@ -229,6 +228,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): round_idx += 1 completion_batch = [] + result_completion_batch = [] if session_save: conversation_history = await get_conversation_history(session_id=session_id) @@ -260,6 +260,15 @@ class 
GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ], ) + # Make sure answers are returned for duplicate queries, in the order they were asked. + for batched_query, batched_completion in zip( + finished_queries_states.keys(), completion_batch + ): + finished_queries_states[batched_query].completion = batched_completion + + for batched_query in query_batch: + result_completion_batch.append(finished_queries_states[batched_query].completion) + # TODO: Do batch queries for save interaction if self.save_interaction and context_text_batch and triplets_batch and completion_batch: await self.save_qa( @@ -277,4 +286,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): session_id=session_id, ) - return completion_batch if completion_batch else [completion] + return result_completion_batch if result_completion_batch else [completion] diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index be28cc97c..5675c8c70 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -3,6 +3,8 @@ import json from typing import Optional, List, Type, Any from pydantic import BaseModel from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError +from cognee.modules.retrieval.utils.query_state import QueryState from cognee.modules.retrieval.utils.validate_queries import validate_queries from cognee.shared.logging_utils import get_logger @@ -23,28 +25,6 @@ from cognee.infrastructure.databases.cache.config import CacheConfig logger = get_logger() -class QueryState: - """ - Helper class containing all necessary information about the query state. - Used to keep track of important information in a more readable way, and - enable as many parallel calls to llms as possible. 
- """ - - def __init__(self): - self.completion: str = "" - self.triplets: List[Edge] = [] - self.context_text: str = "" - - self.answer_text: str = "" - self.valid_user_prompt: str = "" - self.valid_system_prompt: str = "" - self.reasoning: str = "" - - self.followup_question: str = "" - self.followup_prompt: str = "" - self.followup_system: str = "" - - def _as_answer_text(completion: Any) -> str: """Convert completion to human-readable text for validation and follow-up prompts.""" if isinstance(completion, str): @@ -155,7 +135,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): if round_idx == 0: if context is None: # Get context, resolve to text, and store info in the query state - triplets_batch = await self.get_context(query_batch=query_batch) + triplets_batch = await self.get_context( + query_batch=list(query_state_tracker.keys()) + ) context_text_batch = await asyncio.gather( *[ self.resolve_edges_to_text(batched_triplets) @@ -163,7 +145,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): ] ) for batched_query, batched_triplets, batched_context_text in zip( - query_batch, triplets_batch, context_text_batch + query_state_tracker.keys(), triplets_batch, context_text_batch ): query_state_tracker[batched_query].triplets = batched_triplets query_state_tracker[batched_query].context_text = batched_context_text @@ -176,7 +158,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): ] ) for batched_query, batched_triplets, batched_context_text in zip( - query_batch, context, context_text_batch + query_state_tracker.keys(), context, context_text_batch ): query_state_tracker[batched_query].triplets = batched_triplets query_state_tracker[batched_query].context_text = batched_context_text @@ -184,7 +166,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): # Find new triplets, and update existing query states triplets_batch = await self.get_context(query_batch=followup_question_batch) - for batched_query, 
batched_followup_triplets in zip(query_batch, triplets_batch): + for batched_query, batched_followup_triplets in zip( + query_state_tracker.keys(), triplets_batch + ): query_state_tracker[batched_query].triplets = list( dict.fromkeys( query_state_tracker[batched_query].triplets + batched_followup_triplets @@ -198,7 +182,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): ] ) - for batched_query, batched_context_text in zip(query_batch, context_text_batch): + for batched_query, batched_context_text in zip( + query_state_tracker.keys(), context_text_batch + ): query_state_tracker[batched_query].context_text = batched_context_text completion_batch = await asyncio.gather( @@ -216,9 +202,18 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): ] ) - for batched_query, batched_completion in zip(query_batch, completion_batch): + for batched_query, batched_completion in zip( + query_state_tracker.keys(), completion_batch + ): query_state_tracker[batched_query].completion = batched_completion + if round_idx == max_iter: + # When we finish all iterations: + # Make sure answers are returned for duplicate queries, in the order they were asked. + completion_batch = [] + for batched_query in query_batch: + completion_batch.append(query_state_tracker[batched_query].completion) + logger.info(f"Chain-of-thought: round {round_idx} - answers: {completion_batch}") if round_idx < max_iter: @@ -332,13 +327,17 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): session_save = user_id and cache_config.caching if query_batch and session_save: - raise ValueError("You cannot use batch queries with session saving currently.") + raise QueryValidationError( + message="You cannot use batch queries with session saving currently." + ) if query_batch and self.save_interaction: - raise ValueError("Cannot use batch queries with interaction saving currently.") + raise QueryValidationError( + message="Cannot use batch queries with interaction saving currently." 
+ ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: - raise ValueError(msg) + raise QueryValidationError(message=msg) # Load conversation history if enabled conversation_history = "" diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index eae23bb0d..d9667a669 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -4,6 +4,7 @@ from uuid import NAMESPACE_OID, uuid5 from cognee.infrastructure.engine import DataPoint from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError from cognee.modules.retrieval.utils.validate_queries import validate_queries from cognee.tasks.storage import add_data_points from cognee.modules.graph.utils import resolve_edges_to_text @@ -216,13 +217,17 @@ class GraphCompletionRetriever(BaseGraphRetriever): session_save = user_id and cache_config.caching if query_batch and session_save: - raise ValueError("You cannot use batch queries with session saving currently.") + raise QueryValidationError( + message="You cannot use batch queries with session saving currently." + ) if query_batch and self.save_interaction: - raise ValueError("Cannot use batch queries with interaction saving currently.") + raise QueryValidationError( + message="Cannot use batch queries with interaction saving currently." 
+ ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: - raise ValueError(msg) + raise QueryValidationError(message=msg) triplets = context diff --git a/cognee/modules/retrieval/utils/query_state.py b/cognee/modules/retrieval/utils/query_state.py new file mode 100644 index 000000000..a926a952e --- /dev/null +++ b/cognee/modules/retrieval/utils/query_state.py @@ -0,0 +1,34 @@ +from typing import List +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge + + +class QueryState: + """ + Helper class containing all necessary information about the query state. + Used (for now) in COT and Context Extension Retrievers to keep track of important information + in a more readable way, and enable as many parallel calls to llms as possible. + """ + + def __init__( + self, + triplets: List[Edge] = None, + context_text: str = "", + finished_extending_context: bool = False, + ): + # Mutual fields for COT and Context Extension + self.triplets = triplets if triplets else [] + self.context_text = context_text + self.completion = "" + + # Context Extension specific + self.finished_extending_context = finished_extending_context + + # COT specific + self.answer_text: str = "" + self.valid_user_prompt: str = "" + self.valid_system_prompt: str = "" + self.reasoning: str = "" + + self.followup_question: str = "" + self.followup_prompt: str = "" + self.followup_system: str = "" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 7ce0d3a57..9ceca96e2 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -747,3 +747,60 @@ async def test_get_completion_batch_queries_with_response_model(mock_edge): assert isinstance(completion, list) assert len(completion) == 
2 assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_duplicate_queries(mock_edge): + """Test get_completion batch queries with duplicate queries.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + # Create a second edge for extension rounds + mock_edge2 = MagicMock(spec=Edge) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + side_effect=[[[mock_edge], [mock_edge]], [[mock_edge2], [mock_edge2]]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + side_effect=[ + "Resolved context", + "Resolved context", + "Extended context", + "Extended context", + ], # Different contexts + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Extension query", + "Generated answer", + "Generated answer", + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], context_extension_rounds=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index aa71bd9c7..1a6155c4f 100644 --- 
a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -375,7 +375,7 @@ async def test_get_completion_with_provided_context(mock_edge): mock_config.caching = False mock_cache_config.return_value = mock_config - completion = await retriever.get_completion("test query", context=[[mock_edge]], max_iter=1) + completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1) assert isinstance(completion, list) assert len(completion) == 1 @@ -818,3 +818,38 @@ async def test_get_completion_batch_queries_with_response_model(mock_edge): assert isinstance(completion, list) assert len(completion) == 2 assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_duplicate_queries(mock_edge): + """Test get_completion batch queries with duplicate queries.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + ): + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 1"], max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py 
b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index ae0adb729..159fa2df4 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -802,3 +802,43 @@ async def test_get_completion_batch_queries_empty_context(mock_edge): assert isinstance(completion, list) assert len(completion) == 2 + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_duplicate_queries(mock_edge): + """Test get_completion batch queries with duplicate queries.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion(query_batch=["test query 1", "test query 1"]) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer"