diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 4c9122f85..89e8b80a2 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -15,6 +15,19 @@ from cognee.infrastructure.databases.cache.config import CacheConfig logger = get_logger() +class QueryState: + """ + Helper class containing all necessary information about the query state: + the triplets and context associated with it, and also a check whether + it has fully extended the context. + """ + + def __init__(self, triplets: List[Edge], context_text: str, finished_extending_context: bool): + self.triplets = triplets + self.context_text = context_text + self.finished_extending_context = finished_extending_context + + class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): """ Handles graph context completion for question answering tasks, extending context based @@ -91,10 +104,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer based on the query and the extended context. """ - # TODO: This may be unnecessary in this retriever, will check later - query_validation = validate_queries(query, query_batch) - if not query_validation[0]: - raise ValueError(query_validation[1]) + is_query_valid, msg = validate_queries(query, query_batch) + if not is_query_valid: + raise ValueError(msg) triplets_batch = context @@ -117,70 +129,87 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): round_idx = 1 - # We will be removing queries, and their associated triplets and context, as we go - # through iterations, so we need to save their final states for the final generation. - # Final state is stored in the finished_queries_data dict, and we populate it at the start as well. - original_query_batch = query_batch - finished_queries_data = {} - for i, batched_query in enumerate(query_batch): - if not triplets_batch[i]: - query_batch[i] = "" - else: - finished_queries_data[batched_query] = (triplets_batch[i], context_text_batch[i]) + # We store queries as keys and their associated states in this dict. + # The state is a 3-item object QueryState, which holds triplets, context text, + # and a boolean marking whether we should continue extending the context for that query. + finished_queries_states = {} + + for batched_query, batched_triplets, batched_context_text in zip( + query_batch, triplets_batch, context_text_batch + ): + # Populating the dict at the start with initial information. + finished_queries_states[batched_query] = QueryState( + batched_triplets, batched_context_text, False + ) while round_idx <= context_extension_rounds: logger.info( f"Context extension: round {round_idx} - generating next graph locational query." ) - # Filter out the queries that cannot be extended further, and their associated contexts - query_batch = [query for query in query_batch if query] - triplets_batch = [triplets for triplets in triplets_batch if triplets] - context_text_batch = [ - context_text for context_text in context_text_batch if context_text - ] - if len(query_batch) == 0: + if all( + batched_query_state.finished_extending_context + for batched_query_state in finished_queries_states.values() + ): + # We stop early only if all queries in the batch have reached their final state logger.info( f"Context extension: round {round_idx} – no new triplets found; stopping early." ) break - prev_sizes = [len(triplets) for triplets in triplets_batch] + prev_sizes = [ + len(batched_query_state.triplets) + for batched_query_state in finished_queries_states.values() + if not batched_query_state.finished_extending_context + ] completions = await asyncio.gather( *[ generate_completion( - query=query, - context=context, + query=batched_query, + context=batched_query_state.context_text, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, ) - for query, context in zip(query_batch, context_text_batch) + for batched_query, batched_query_state in finished_queries_states.items() + if not batched_query_state.finished_extending_context ], ) # Get new triplets, and merge them with existing ones, filtering out duplicates new_triplets_batch = await self.get_context(query_batch=completions) - for i, (triplets, new_triplets) in enumerate(zip(triplets_batch, new_triplets_batch)): - triplets += new_triplets - triplets_batch[i] = list(dict.fromkeys(triplets)) + for batched_query, batched_new_triplets in zip(query_batch, new_triplets_batch): + finished_queries_states[batched_query].triplets = list( + dict.fromkeys( + finished_queries_states[batched_query].triplets + batched_new_triplets + ) + ) + # Resolve new triplets to text context_text_batch = await asyncio.gather( - *[self.resolve_edges_to_text(triplets) for triplets in triplets_batch] + *[ + self.resolve_edges_to_text(batched_query_state.triplets) + for batched_query_state in finished_queries_states.values() + if not batched_query_state.finished_extending_context + ] ) - new_sizes = [len(triplets) for triplets in triplets_batch] + # Update context_texts in query states + for batched_query, batched_context_text in zip(query_batch, context_text_batch): + if not finished_queries_states[batched_query].finished_extending_context: + finished_queries_states[batched_query].context_text = batched_context_text - for i, (batched_query, prev_size, new_size, triplets, context_text) in enumerate( - zip(query_batch, prev_sizes, new_sizes, triplets_batch, context_text_batch) - ): - finished_queries_data[query] = (triplets, context_text) + new_sizes = [ + len(batched_query_state.triplets) + for batched_query_state in finished_queries_states.values() + if not batched_query_state.finished_extending_context + ] + + for batched_query, prev_size, new_size in zip(query_batch, prev_sizes, new_sizes): + # Mark done queries accordingly if prev_size == new_size: - # In this case, we can stop trying to extend the context of this query - query_batch[i] = "" - triplets_batch[i] = [] - context_text_batch[i] = "" + finished_queries_states[batched_query].finished_extending_context = True logger.info( f"Context extension: round {round_idx} - " @@ -189,15 +218,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): round_idx += 1 - # Reset variables for the final generations. They contain the final state - # of triplets and contexts for each query, after all extension iterations. - query_batch = original_query_batch - triplets_batch = [] - context_text_batch = [] - for batched_query in query_batch: - triplets_batch.append(finished_queries_data[batched_query][0]) - context_text_batch.append(finished_queries_data[batched_query][1]) - # Check if we need to generate context summary for caching cache_config = CacheConfig() user = session_user.get() @@ -226,13 +246,13 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): *[ generate_completion( query=batched_query, - context=batched_context_text, + context=batched_query_state.context_text, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, response_model=response_model, ) - for batched_query, batched_context_text in zip(query_batch, context_text_batch) + for batched_query, batched_query_state in finished_queries_states.items() ], ) diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 4ecdc910a..58c9bad15 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -23,6 +23,27 @@ from cognee.infrastructure.databases.cache.config import CacheConfig logger = get_logger() +class QueryState: + """ + Helper class containing all necessary information about the query state. + Used to keep track of important information in a more readable way, and + enable as many parallel calls to llms as possible. + """ + + completion: str = "" + triplets: List[Edge] = [] + context_text: str = "" + + answer_text: str = "" + valid_user_prompt: str = "" + valid_system_prompt: str = "" + reasoning: str = "" + + followup_question: str = "" + followup_prompt: str = "" + followup_system: str = "" + + def _as_answer_text(completion: Any) -> str: """Convert completion to human-readable text for validation and follow-up prompts.""" if isinstance(completion, str): @@ -87,12 +108,13 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): async def _run_cot_completion( self, - query: str, - context: Optional[List[Edge]] = None, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, + context: Optional[List[Edge] | List[List[Edge]]] = None, conversation_history: str = "", max_iter: int = 4, response_model: Type = str, - ) -> tuple[Any, str, List[Edge]]: + ) -> tuple[List[Any], List[str], List[List[Edge]]]: """ Run chain-of-thought completion with optional structured output. @@ -110,64 +132,158 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): - context_text: The resolved context text - triplets: The list of triplets used """ - followup_question = "" - triplets = [] - completion = "" + followup_question_batch = [] + completion_batch = [] + context_text_batch = [] + + if query: + # Treat a single query as a batch of queries, mainly avoiding massive code duplication + query_batch = [query] + if context: + context = [context] + + triplets_batch = context + + # dict containing query -> QueryState key-value pairs + # For every query, we save necessary data so we can execute requests in parallel + query_state_tracker = {} + for batched_query in query_batch: + query_state_tracker[batched_query] = QueryState() for round_idx in range(max_iter + 1): if round_idx == 0: if context is None: - triplets = await self.get_context(query) - context_text = await self.resolve_edges_to_text(triplets) + # Get context, resolve to text, and store info in the query state + triplets_batch = await self.get_context(query_batch=query_batch) + context_text_batch = await asyncio.gather( + *[ + self.resolve_edges_to_text(batched_triplets) + for batched_triplets in triplets_batch + ] + ) + for batched_query, batched_triplets, batched_context_text in zip( + query_batch, triplets_batch, context_text_batch + ): + query_state_tracker[batched_query].triplets = batched_triplets + query_state_tracker[batched_query].context_text = batched_context_text else: - context_text = await self.resolve_edges_to_text(context) + # In this case just resolve to text and save to the query state + context_text_batch = await asyncio.gather( + *[ + self.resolve_edges_to_text(batched_context) + for batched_context in context + ] + ) + for batched_query, batched_triplets, batched_context_text in zip( + query_batch, context, context_text_batch + ): + query_state_tracker[batched_query].triplets = batched_triplets + query_state_tracker[batched_query].context_text = batched_context_text else: - triplets += await self.get_context(followup_question) - context_text = await self.resolve_edges_to_text(list(set(triplets))) + # Find new triplets, and update existing query states + followup_triplets_batch = await self.get_context( + query_batch=followup_question_batch + ) + for batched_query, batched_followup_triplets in zip( + query_batch, followup_triplets_batch + ): + query_state_tracker[batched_query].triplets = list( + dict.fromkeys( + query_state_tracker[batched_query].triplets + batched_followup_triplets + ) + ) - completion = await generate_completion( - query=query, - context=context_text, - user_prompt_path=self.user_prompt_path, - system_prompt_path=self.system_prompt_path, - system_prompt=self.system_prompt, - conversation_history=conversation_history if conversation_history else None, - response_model=response_model, + context_text_batch = await asyncio.gather( + *[ + self.resolve_edges_to_text(batched_query_state.triplets) + for batched_query_state in query_state_tracker.values() + ] + ) + + for batched_query, batched_context_text in zip(query_batch, context_text_batch): + query_state_tracker[batched_query].context_text = batched_context_text + + completion_batch = await asyncio.gather( + *[ + generate_completion( + query=batched_query, + context=batched_query_state.context_text, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + conversation_history=conversation_history if conversation_history else None, + response_model=response_model, + ) + for batched_query, batched_query_state in query_state_tracker.items() + ] ) - logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") + for batched_query, batched_completion in zip(query_batch, completion_batch): + query_state_tracker[batched_query].completion = batched_completion + + logger.info(f"Chain-of-thought: round {round_idx} - answers: {completion_batch}") if round_idx < max_iter: - answer_text = _as_answer_text(completion) - valid_args = {"query": query, "answer": answer_text, "context": context_text} - valid_user_prompt = render_prompt( - filename=self.validation_user_prompt_path, context=valid_args - ) - valid_system_prompt = read_query_prompt( - prompt_file_name=self.validation_system_prompt_path + for batched_query, batched_query_state in query_state_tracker.items(): + batched_query_state.answer_text = _as_answer_text( + batched_query_state.completion + ) + valid_args = { + "query": batched_query, + "answer": batched_query_state.answer_text, + "context": batched_query_state.context_text, + } + batched_query_state.valid_user_prompt = render_prompt( + filename=self.validation_user_prompt_path, + context=valid_args, + ) + batched_query_state.valid_system_prompt = read_query_prompt( + prompt_file_name=self.validation_system_prompt_path + ) + + reasoning_batch = await asyncio.gather( + *[ + LLMGateway.acreate_structured_output( + text_input=batched_query_state.valid_user_prompt, + system_prompt=batched_query_state.valid_system_prompt, + response_model=str, + ) + for batched_query_state in query_state_tracker.values() + ] ) - reasoning = await LLMGateway.acreate_structured_output( - text_input=valid_user_prompt, - system_prompt=valid_system_prompt, - response_model=str, - ) - followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning} - followup_prompt = render_prompt( - filename=self.followup_user_prompt_path, context=followup_args - ) - followup_system = read_query_prompt( - prompt_file_name=self.followup_system_prompt_path - ) + for batched_query, batched_reasoning in zip(query_batch, reasoning_batch): + query_state_tracker[batched_query].reasoning = batched_reasoning - followup_question = await LLMGateway.acreate_structured_output( - text_input=followup_prompt, system_prompt=followup_system, response_model=str + for batched_query, batched_query_state in query_state_tracker.items(): + followup_args = { + "query": query, + "answer": batched_query_state.answer_text, + "reasoning": batched_query_state.reasoning, + } + batched_query_state.followup_prompt = render_prompt( + filename=self.followup_user_prompt_path, + context=followup_args, + ) + batched_query_state.followup_system = read_query_prompt( + prompt_file_name=self.followup_system_prompt_path + ) + + followup_question_batch = await asyncio.gather( + *[ + LLMGateway.acreate_structured_output( + text_input=batched_query_state.followup_prompt, + system_prompt=batched_query_state.followup_system, + response_model=str, + ) + for batched_query_state in query_state_tracker.values() + ] ) logger.info( - f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}" + f"Chain-of-thought: round {round_idx} - follow-up questions: {followup_question_batch}" ) - return completion, context_text, triplets + return completion_batch, context_text_batch, triplets_batch async def get_completion( self, @@ -204,9 +320,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer to the user's query. """ - query_validation = validate_queries(query, query_batch) - if not query_validation[0]: - raise ValueError(query_validation[1]) + is_query_valid, msg = validate_queries(query, query_batch) + if not is_query_valid: + raise ValueError(msg) # Check if session saving is enabled cache_config = CacheConfig() @@ -219,40 +335,22 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): if session_save: conversation_history = await get_conversation_history(session_id=session_id) - context_batch = context - completion_batch = [] - if query_batch and len(query_batch) > 0: - if not context_batch: - # Having a list is necessary to zip through it - context_batch = [] - for _ in query_batch: - context_batch.append(None) - - completion_batch = await asyncio.gather( - *[ - self._run_cot_completion( - query=query, - context=context, - conversation_history=conversation_history, - max_iter=max_iter, - response_model=response_model, - ) - for batched_query, context in zip(query_batch, context_batch) - ] - ) - else: - completion, context_text, triplets = await self._run_cot_completion( - query=query, - context=context, - conversation_history=conversation_history, - max_iter=max_iter, - response_model=response_model, - ) + completion, context_text, triplets = await self._run_cot_completion( + query=query, + query_batch=query_batch, + context=context, + conversation_history=conversation_history, + max_iter=max_iter, + response_model=response_model, + ) # TODO: Handle save interaction for batch queries if self.save_interaction and context and triplets and completion: await self.save_qa( - question=query, answer=str(completion), context=context_text, triplets=triplets + question=query, + answer=str(completion[0]), + context=context_text[0], + triplets=triplets[0], ) # TODO: Handle session save interaction for batch queries @@ -266,7 +364,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): session_id=session_id, ) - if completion_batch: - return [completion for completion, _, _ in completion_batch] - - return [completion] + return completion diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index a1b4c3833..711c7ab40 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -158,15 +158,18 @@ class GraphCompletionRetriever(BaseGraphRetriever): f"Empty context was provided to the completion for the query: {batched_query}" ) entity_nodes_batch = [] + for batched_triplets in triplets: entity_nodes_batch.append(get_entity_nodes_from_triplets(batched_triplets)) - await asyncio.gather( - *[ - update_node_access_timestamps(batched_entity_nodes) - for batched_entity_nodes in entity_nodes_batch - ] - ) + # Remove duplicates and update node access, if it is enabled + for batched_entity_nodes in entity_nodes_batch: + # from itertools import chain + # + # flattened_entity_nodes = list(chain.from_iterable(entity_nodes_batch)) + # entity_nodes = list(set(flattened_entity_nodes)) + + await update_node_access_timestamps(batched_entity_nodes) else: if len(triplets) == 0: logger.warning("Empty context was provided to the completion") @@ -209,9 +212,9 @@ class GraphCompletionRetriever(BaseGraphRetriever): - Any: A generated completion based on the query and context provided. """ - query_validation = validate_queries(query, query_batch) - if not query_validation[0]: - raise ValueError(query_validation[1]) + is_query_valid, msg = validate_queries(query, query_batch) + if not is_query_valid: + raise ValueError(msg) triplets = context diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index c0a0e7fab..764f8d77d 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -147,9 +147,9 @@ async def brute_force_triplet_search( In single-query mode, node_distances and edge_distances are stored as flat lists. In batch mode, they are stored as list-of-lists (one list per query). """ - query_validation = validate_queries(query, query_batch) - if not query_validation[0]: - raise ValueError(query_validation[1]) + is_query_valid, msg = validate_queries(query, query_batch) + if not is_query_valid: + raise ValueError(msg) if top_k <= 0: raise ValueError("top_k must be a positive integer.") diff --git a/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py b/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py index 0db035e03..54ca09b8a 100644 --- a/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +++ b/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py @@ -1,4 +1,3 @@ -import os import pytest import pathlib import pytest_asyncio diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 5ae901594..aa71bd9c7 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -1,3 +1,5 @@ +import os + import pytest from unittest.mock import AsyncMock, patch, MagicMock from uuid import UUID @@ -80,7 +82,7 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge): "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", return_value="Generated answer", ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", return_value="Generated answer", @@ -106,8 +108,8 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge): max_iter=1, ) - assert completion == "Generated answer" - assert context_text == "Resolved context" + assert completion == ["Generated answer"] + assert context_text == ["Resolved context"] assert len(triplets) >= 1 @@ -126,7 +128,7 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -143,8 +145,8 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge): max_iter=1, ) - assert completion == "Generated answer" - assert context_text == "Resolved context" + assert completion == ["Generated answer"] + assert context_text == ["Resolved context"] assert len(triplets) >= 1 @@ -168,7 +170,7 @@ async def test_run_cot_completion_multiple_rounds(mock_edge): retriever, "get_context", new_callable=AsyncMock, - side_effect=[[mock_edge], [mock_edge2]], + side_effect=[[[mock_edge]], [[mock_edge2]]], ), patch( "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", @@ -200,8 +202,8 @@ async def test_run_cot_completion_multiple_rounds(mock_edge): max_iter=2, ) - assert completion == "Generated answer" - assert context_text == "Resolved context" + assert completion == ["Generated answer"] + assert context_text == ["Resolved context"] assert len(triplets) >= 1 @@ -227,7 +229,7 @@ async def test_run_cot_completion_with_conversation_history(mock_edge): max_iter=1, ) - assert completion == "Generated answer" + assert completion == ["Generated answer"] call_kwargs = mock_generate.call_args[1] assert call_kwargs.get("conversation_history") == "Previous conversation" @@ -259,8 +261,9 @@ async def test_run_cot_completion_with_response_model(mock_edge): max_iter=1, ) - assert isinstance(completion, TestModel) - assert completion.answer == "Test answer" + assert isinstance(completion, list) + assert isinstance(completion[0], TestModel) + assert completion[0].answer == "Test answer" @pytest.mark.asyncio @@ -285,7 +288,7 @@ async def test_run_cot_completion_empty_conversation_history(mock_edge): max_iter=1, ) - assert completion == "Generated answer" + assert completion == ["Generated answer"] # Verify conversation_history was passed as None when empty call_kwargs = mock_generate.call_args[1] assert call_kwargs.get("conversation_history") is None @@ -306,7 +309,7 @@ async def test_get_completion_without_context(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -316,7 +319,7 @@ async def test_get_completion_without_context(mock_edge): "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", return_value="Generated answer", ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", return_value="Generated answer", @@ -372,7 +375,7 @@ async def test_get_completion_with_provided_context(mock_edge): mock_config.caching = False mock_cache_config.return_value = mock_config - completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1) + completion = await retriever.get_completion("test query", context=[[mock_edge]], max_iter=1) assert isinstance(completion, list) assert len(completion) == 1 @@ -397,7 +400,7 @@ async def test_get_completion_with_session(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -463,7 +466,7 @@ async def test_get_completion_with_save_interaction(mock_edge): "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", return_value="Generated answer", ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", return_value="Generated answer", @@ -528,7 +531,7 @@ async def test_get_completion_with_response_model(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -570,7 +573,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -612,7 +615,7 @@ async def test_get_completion_with_save_interaction_no_context(mock_edge): "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", return_value="Generated answer", ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", return_value="Generated answer", @@ -703,7 +706,7 @@ async def test_get_completion_batch_queries_with_context(mock_edge): "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", return_value="Generated answer", ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", return_value="Generated answer", @@ -749,7 +752,7 @@ async def test_get_completion_batch_queries_without_context(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -790,7 +793,7 @@ async def test_get_completion_batch_queries_with_response_model(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",