diff --git a/cognee/modules/retrieval/exceptions/exceptions.py b/cognee/modules/retrieval/exceptions/exceptions.py index 3e934909b..0efaf7351 100644 --- a/cognee/modules/retrieval/exceptions/exceptions.py +++ b/cognee/modules/retrieval/exceptions/exceptions.py @@ -40,3 +40,13 @@ class CollectionDistancesNotFoundError(CogneeValidationError): status_code: int = status.HTTP_404_NOT_FOUND, ): super().__init__(message, name, status_code) + + +class QueryValidationError(CogneeValidationError): + def __init__( + self, + message: str = "Queries not supplied in the correct format.", + name: str = "QueryValidationError", + status_code: int = status.HTTP_422_UNPROCESSABLE_CONTENT, + ): + super().__init__(message, name, status_code) diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index fc49a139b..0dc3a8bf6 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -1,6 +1,9 @@ import asyncio from typing import Optional, List, Type, Any from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError +from cognee.modules.retrieval.utils.query_state import QueryState +from cognee.modules.retrieval.utils.validate_queries import validate_queries from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text @@ -56,11 +59,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): async def get_completion( self, - query: str, - context: Optional[List[Edge]] = None, + query: Optional[str] = None, + context: Optional[List[Edge] | List[List[Edge]]] = None, session_id: Optional[str] = None, 
context_extension_rounds=4, response_model: Type = str, + query_batch: Optional[List[str]] = None, ) -> List[Any]: """ Extends the context for a given query by retrieving related triplets and generating new @@ -89,47 +93,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer based on the query and the extended context. """ - triplets = context - - if triplets is None: - triplets = await self.get_context(query) - - context_text = await self.resolve_edges_to_text(triplets) - - round_idx = 1 - - while round_idx <= context_extension_rounds: - prev_size = len(triplets) - - logger.info( - f"Context extension: round {round_idx} - generating next graph locational query." - ) - completion = await generate_completion( - query=query, - context=context_text, - user_prompt_path=self.user_prompt_path, - system_prompt_path=self.system_prompt_path, - system_prompt=self.system_prompt, - ) - - triplets += await self.get_context(completion) - triplets = list(set(triplets)) - context_text = await self.resolve_edges_to_text(triplets) - - num_triplets = len(triplets) - - if num_triplets == prev_size: - logger.info( - f"Context extension: round {round_idx} – no new triplets found; stopping early." - ) - break - - logger.info( - f"Context extension: round {round_idx} - " - f"number of unique retrieved triplets: {num_triplets}" - ) - - round_idx += 1 # Check if we need to generate context summary for caching cache_config = CacheConfig() @@ -137,6 +100,131 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): user_id = getattr(user, "id", None) session_save = user_id and cache_config.caching + if query_batch and session_save: + raise QueryValidationError( + message="You cannot use batch queries with session saving currently." + ) + if query_batch and self.save_interaction: + raise QueryValidationError( + message="Cannot use batch queries with interaction saving currently." 
+ ) + + is_query_valid, msg = validate_queries(query, query_batch) + if not is_query_valid: + raise QueryValidationError(message=msg) + + triplets_batch = context + + if query: + # This is done mostly to avoid duplicating a lot of code unnecessarily + query_batch = [query] + if triplets_batch: + triplets_batch = [triplets_batch] + + if triplets_batch is None: + triplets_batch = await self.get_context(query_batch=query_batch) + + if not triplets_batch: + return [] + + context_text = "" + context_text_batch = await asyncio.gather( + *[self.resolve_edges_to_text(triplets) for triplets in triplets_batch] + ) + + round_idx = 1 + + # We store queries as keys and their associated states in this dict. + # The state is a 3-item object QueryState, which holds triplets, context text, + # and a boolean marking whether we should continue extending the context for that query. + finished_queries_states = {} + + for batched_query, batched_triplets, batched_context_text in zip( + query_batch, triplets_batch, context_text_batch + ): + # Populating the dict at the start with initial information. + finished_queries_states[batched_query] = QueryState( + batched_triplets, batched_context_text, False + ) + + while round_idx <= context_extension_rounds: + logger.info( + f"Context extension: round {round_idx} - generating next graph locational query." + ) + + if all( + batched_query_state.finished_extending_context + for batched_query_state in finished_queries_states.values() + ): + # We stop early only if all queries in the batch have reached their final state + logger.info( + f"Context extension: round {round_idx} – no new triplets found; stopping early." 
+ ) + break + + relevant_queries = [ + rel_query + for rel_query in finished_queries_states.keys() + if not finished_queries_states[rel_query].finished_extending_context + ] + + prev_sizes = [ + len(finished_queries_states[rel_query].triplets) for rel_query in relevant_queries + ] + + completions = await asyncio.gather( + *[ + generate_completion( + query=rel_query, + context=finished_queries_states[rel_query].context_text, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + ) + for rel_query in relevant_queries + ], + ) + + # Get new triplets, and merge them with existing ones, filtering out duplicates + new_triplets_batch = await self.get_context(query_batch=completions) + for rel_query, batched_new_triplets in zip(relevant_queries, new_triplets_batch): + finished_queries_states[rel_query].triplets = list( + dict.fromkeys( + finished_queries_states[rel_query].triplets + batched_new_triplets + ) + ) + + # Resolve new triplets to text + context_text_batch = await asyncio.gather( + *[ + self.resolve_edges_to_text(finished_queries_states[rel_query].triplets) + for rel_query in relevant_queries + ] + ) + + # Update context_texts in query states + for rel_query, batched_context_text in zip(relevant_queries, context_text_batch): + finished_queries_states[rel_query].context_text = batched_context_text + + new_sizes = [ + len(finished_queries_states[rel_query].triplets) for rel_query in relevant_queries + ] + + for rel_query, prev_size, new_size in zip(relevant_queries, prev_sizes, new_sizes): + # Mark done queries accordingly + if prev_size == new_size: + finished_queries_states[rel_query].finished_extending_context = True + + logger.info( + f"Context extension: round {round_idx} - " + f"number of unique retrieved triplets for each query : {new_sizes}" + ) + + round_idx += 1 + + completion_batch = [] + result_completion_batch = [] + if session_save: conversation_history = await 
get_conversation_history(session_id=session_id) @@ -153,18 +241,36 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ), ) else: - completion = await generate_completion( - query=query, - context=context_text, - user_prompt_path=self.user_prompt_path, - system_prompt_path=self.system_prompt_path, - system_prompt=self.system_prompt, - response_model=response_model, + completion_batch = await asyncio.gather( + *[ + generate_completion( + query=batched_query, + context=batched_query_state.context_text, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + response_model=response_model, + ) + for batched_query, batched_query_state in finished_queries_states.items() + ], ) - if self.save_interaction and context_text and triplets and completion: + # Make sure answers are returned for duplicate queries, in the order they were asked. + for batched_query, batched_completion in zip( + finished_queries_states.keys(), completion_batch + ): + finished_queries_states[batched_query].completion = batched_completion + + for batched_query in query_batch: + result_completion_batch.append(finished_queries_states[batched_query].completion) + + # TODO: Do batch queries for save interaction + if self.save_interaction and context_text_batch and triplets_batch and completion_batch: await self.save_qa( - question=query, answer=completion, context=context_text, triplets=triplets + question=query, + answer=completion_batch[0], + context=context_text_batch[0], + triplets=triplets_batch[0], ) if session_save: @@ -175,4 +281,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): session_id=session_id, ) - return [completion] + return result_completion_batch if result_completion_batch else [completion] diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 70fcb6cdb..114578aa9 100644 --- 
a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -3,6 +3,9 @@ import json from typing import Optional, List, Type, Any from pydantic import BaseModel from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError +from cognee.modules.retrieval.utils.query_state import QueryState +from cognee.modules.retrieval.utils.validate_queries import validate_queries from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever @@ -86,12 +89,13 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): async def _run_cot_completion( self, - query: str, - context: Optional[List[Edge]] = None, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, + context: Optional[List[Edge] | List[List[Edge]]] = None, conversation_history: str = "", max_iter: int = 4, response_model: Type = str, - ) -> tuple[Any, str, List[Edge]]: + ) -> tuple[List[Any], List[str], List[List[Edge]]]: """ Run chain-of-thought completion with optional structured output. 
@@ -109,72 +113,187 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): - context_text: The resolved context text - triplets: The list of triplets used """ - followup_question = "" - triplets = [] - completion = "" + followup_question_batch = [] + completion_batch = [] + context_text_batch = [] + + if query: + # Treat a single query as a batch of queries, mainly avoiding massive code duplication + query_batch = [query] + if context: + context = [context] + + triplets_batch = context + + # dict containing query -> QueryState key-value pairs + # For every query, we save necessary data so we can execute requests in parallel + query_state_tracker = {} + for batched_query in query_batch: + query_state_tracker[batched_query] = QueryState() for round_idx in range(max_iter + 1): if round_idx == 0: if context is None: - triplets = await self.get_context(query) - context_text = await self.resolve_edges_to_text(triplets) + # Get context, resolve to text, and store info in the query state + triplets_batch = await self.get_context( + query_batch=list(query_state_tracker.keys()) + ) + context_text_batch = await asyncio.gather( + *[ + self.resolve_edges_to_text(batched_triplets) + for batched_triplets in triplets_batch + ] + ) + for batched_query, batched_triplets, batched_context_text in zip( + query_state_tracker.keys(), triplets_batch, context_text_batch + ): + query_state_tracker[batched_query].triplets = batched_triplets + query_state_tracker[batched_query].context_text = batched_context_text else: - context_text = await self.resolve_edges_to_text(context) + # In this case just resolve to text and save to the query state + context_text_batch = await asyncio.gather( + *[ + self.resolve_edges_to_text(batched_context) + for batched_context in context + ] + ) + for batched_query, batched_triplets, batched_context_text in zip( + query_state_tracker.keys(), context, context_text_batch + ): + query_state_tracker[batched_query].triplets = batched_triplets + 
query_state_tracker[batched_query].context_text = batched_context_text else: - triplets += await self.get_context(followup_question) - context_text = await self.resolve_edges_to_text(list(set(triplets))) + # Find new triplets, and update existing query states + triplets_batch = await self.get_context(query_batch=followup_question_batch) - completion = await generate_completion( - query=query, - context=context_text, - user_prompt_path=self.user_prompt_path, - system_prompt_path=self.system_prompt_path, - system_prompt=self.system_prompt, - conversation_history=conversation_history if conversation_history else None, - response_model=response_model, + for batched_query, batched_followup_triplets in zip( + query_state_tracker.keys(), triplets_batch + ): + query_state_tracker[batched_query].triplets = list( + dict.fromkeys( + query_state_tracker[batched_query].triplets + batched_followup_triplets + ) + ) + + context_text_batch = await asyncio.gather( + *[ + self.resolve_edges_to_text(batched_query_state.triplets) + for batched_query_state in query_state_tracker.values() + ] + ) + + for batched_query, batched_context_text in zip( + query_state_tracker.keys(), context_text_batch + ): + query_state_tracker[batched_query].context_text = batched_context_text + + completion_batch = await asyncio.gather( + *[ + generate_completion( + query=batched_query, + context=batched_query_state.context_text, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + conversation_history=conversation_history if conversation_history else None, + response_model=response_model, + ) + for batched_query, batched_query_state in query_state_tracker.items() + ] ) - logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") + for batched_query, batched_completion in zip( + query_state_tracker.keys(), completion_batch + ): + query_state_tracker[batched_query].completion = batched_completion + + if round_idx == 
max_iter: + # When we finish all iterations: + # Make sure answers are returned for duplicate queries, in the order they were asked. + completion_batch = [] + for batched_query in query_batch: + completion_batch.append(query_state_tracker[batched_query].completion) + + logger.info(f"Chain-of-thought: round {round_idx} - answers: {completion_batch}") if round_idx < max_iter: - answer_text = _as_answer_text(completion) - valid_args = {"query": query, "answer": answer_text, "context": context_text} - valid_user_prompt = render_prompt( - filename=self.validation_user_prompt_path, context=valid_args - ) - valid_system_prompt = read_query_prompt( - prompt_file_name=self.validation_system_prompt_path + for batched_query, batched_query_state in query_state_tracker.items(): + batched_query_state.answer_text = _as_answer_text( + batched_query_state.completion + ) + valid_args = { + "query": batched_query, + "answer": batched_query_state.answer_text, + "context": batched_query_state.context_text, + } + batched_query_state.valid_user_prompt = render_prompt( + filename=self.validation_user_prompt_path, + context=valid_args, + ) + batched_query_state.valid_system_prompt = read_query_prompt( + prompt_file_name=self.validation_system_prompt_path + ) + + reasoning_batch = await asyncio.gather( + *[ + LLMGateway.acreate_structured_output( + text_input=batched_query_state.valid_user_prompt, + system_prompt=batched_query_state.valid_system_prompt, + response_model=str, + ) + for batched_query_state in query_state_tracker.values() + ] ) - reasoning = await LLMGateway.acreate_structured_output( - text_input=valid_user_prompt, - system_prompt=valid_system_prompt, - response_model=str, - ) - followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning} - followup_prompt = render_prompt( - filename=self.followup_user_prompt_path, context=followup_args - ) - followup_system = read_query_prompt( - prompt_file_name=self.followup_system_prompt_path + for batched_query, 
batched_reasoning in zip( + query_state_tracker.keys(), reasoning_batch + ): + query_state_tracker[batched_query].reasoning = batched_reasoning + + for batched_query, batched_query_state in query_state_tracker.items(): + followup_args = { + "query": batched_query, + "answer": batched_query_state.answer_text, + "reasoning": batched_query_state.reasoning, + } + batched_query_state.followup_prompt = render_prompt( + filename=self.followup_user_prompt_path, + context=followup_args, + ) + batched_query_state.followup_system = read_query_prompt( + prompt_file_name=self.followup_system_prompt_path + ) + + followup_question_batch = await asyncio.gather( + *[ + LLMGateway.acreate_structured_output( + text_input=batched_query_state.followup_prompt, + system_prompt=batched_query_state.followup_system, + response_model=str, + ) + for batched_query_state in query_state_tracker.values() + ] ) - followup_question = await LLMGateway.acreate_structured_output( - text_input=followup_prompt, system_prompt=followup_system, response_model=str - ) + for batched_query, batched_followup_question in zip( + query_state_tracker.keys(), followup_question_batch + ): + query_state_tracker[batched_query].followup_question = batched_followup_question + logger.info( - f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}" + f"Chain-of-thought: round {round_idx} - follow-up questions: {followup_question_batch}" ) - return completion, context_text, triplets + return completion_batch, context_text_batch, triplets_batch async def get_completion( self, - query: str, - context: Optional[List[Edge]] = None, + query: Optional[str] = None, + context: Optional[List[Edge] | List[List[Edge]]] = None, session_id: Optional[str] = None, max_iter=4, response_model: Type = str, + query_batch: Optional[List[str]] = None, ) -> List[Any]: """ Generate completion responses based on a user query and contextual information. 
@@ -202,12 +321,26 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer to the user's query. """ + # Check if session saving is enabled cache_config = CacheConfig() user = session_user.get() user_id = getattr(user, "id", None) session_save = user_id and cache_config.caching + if query_batch and session_save: + raise QueryValidationError( + message="You cannot use batch queries with session saving currently." + ) + if query_batch and self.save_interaction: + raise QueryValidationError( + message="Cannot use batch queries with interaction saving currently." + ) + + is_query_valid, msg = validate_queries(query, query_batch) + if not is_query_valid: + raise QueryValidationError(message=msg) + # Load conversation history if enabled conversation_history = "" if session_save: @@ -215,17 +348,23 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): completion, context_text, triplets = await self._run_cot_completion( query=query, + query_batch=query_batch, context=context, conversation_history=conversation_history, max_iter=max_iter, response_model=response_model, ) + # TODO: Handle save interaction for batch queries if self.save_interaction and context and triplets and completion: await self.save_qa( - question=query, answer=str(completion), context=context_text, triplets=triplets + question=query, + answer=str(completion[0]), + context=context_text[0], + triplets=triplets[0], ) + # TODO: Handle session save interaction for batch queries # Save to session cache if enabled if session_save: context_summary = await summarize_text(context_text) @@ -236,4 +375,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): session_id=session_id, ) - return [completion] + return completion diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index bb8b34327..d9667a669 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py 
+++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -4,6 +4,8 @@ from uuid import NAMESPACE_OID, uuid5 from cognee.infrastructure.engine import DataPoint from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError +from cognee.modules.retrieval.utils.validate_queries import validate_queries from cognee.tasks.storage import add_data_points from cognee.modules.graph.utils import resolve_edges_to_text from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses @@ -79,7 +81,11 @@ class GraphCompletionRetriever(BaseGraphRetriever): """ return await resolve_edges_to_text(retrieved_edges) - async def get_triplets(self, query: str) -> List[Edge]: + async def get_triplets( + self, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, + ) -> List[Edge] | List[List[Edge]]: """ Retrieves relevant graph triplets based on a query string. @@ -107,6 +113,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): found_triplets = await brute_force_triplet_search( query, + query_batch, top_k=self.top_k, collections=vector_index_collections or None, node_type=self.node_type, @@ -117,7 +124,11 @@ class GraphCompletionRetriever(BaseGraphRetriever): return found_triplets - async def get_context(self, query: str) -> List[Edge]: + async def get_context( + self, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, + ) -> List[Edge] | List[List[Edge]]: """ Retrieves and resolves graph triplets into context based on a query. 
@@ -139,17 +150,36 @@ class GraphCompletionRetriever(BaseGraphRetriever): logger.warning("Search attempt on an empty knowledge graph") return [] - triplets = await self.get_triplets(query) + triplets = await self.get_triplets(query, query_batch) - if len(triplets) == 0: - logger.warning("Empty context was provided to the completion") - return [] + if query_batch: + for batched_triplets, batched_query in zip(triplets, query_batch): + if len(batched_triplets) == 0: + logger.warning( + f"Empty context was provided to the completion for the query: {batched_query}" + ) + entity_nodes_batch = [] - # context = await self.resolve_edges_to_text(triplets) + for batched_triplets in triplets: + entity_nodes_batch.append(get_entity_nodes_from_triplets(batched_triplets)) - entity_nodes = get_entity_nodes_from_triplets(triplets) + # Remove duplicates and update node access, if it is enabled + import os + + if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() == "true": + for batched_entity_nodes in entity_nodes_batch: + await update_node_access_timestamps(batched_entity_nodes) + else: + if len(triplets) == 0: + logger.warning("Empty context was provided to the completion") + return [] + + # context = await self.resolve_edges_to_text(triplets) + + entity_nodes = get_entity_nodes_from_triplets(triplets) + + await update_node_access_timestamps(entity_nodes) - await update_node_access_timestamps(entity_nodes) return triplets async def convert_retrieved_objects_to_context(self, triplets: List[Edge]): @@ -158,10 +188,11 @@ class GraphCompletionRetriever(BaseGraphRetriever): async def get_completion( self, - query: str, - context: Optional[List[Edge]] = None, + query: Optional[str] = None, + context: Optional[List[Edge] | List[List[Edge]]] = None, session_id: Optional[str] = None, response_model: Type = str, + query_batch: Optional[List[str]] = None, ) -> List[Any]: """ Generates a completion using graph connections context based on a query. 
@@ -180,18 +211,38 @@ class GraphCompletionRetriever(BaseGraphRetriever): - Any: A generated completion based on the query and context provided. """ - triplets = context - - if triplets is None: - triplets = await self.get_context(query) - - context_text = await resolve_edges_to_text(triplets) - cache_config = CacheConfig() user = session_user.get() user_id = getattr(user, "id", None) session_save = user_id and cache_config.caching + if query_batch and session_save: + raise QueryValidationError( + message="You cannot use batch queries with session saving currently." + ) + if query_batch and self.save_interaction: + raise QueryValidationError( + message="Cannot use batch queries with interaction saving currently." + ) + + is_query_valid, msg = validate_queries(query, query_batch) + if not is_query_valid: + raise QueryValidationError(message=msg) + + triplets = context + + if triplets is None: + triplets = await self.get_context(query, query_batch) + + context_text = "" + context_text_batch = [] + if triplets and isinstance(triplets[0], list): + context_text_batch = await asyncio.gather( + *[resolve_edges_to_text(triplets_element) for triplets_element in triplets] + ) + else: + context_text = await resolve_edges_to_text(triplets) + if session_save: conversation_history = await get_conversation_history(session_id=session_id) @@ -208,14 +259,29 @@ class GraphCompletionRetriever(BaseGraphRetriever): ), ) else: - completion = await generate_completion( - query=query, - context=context_text, - user_prompt_path=self.user_prompt_path, - system_prompt_path=self.system_prompt_path, - system_prompt=self.system_prompt, - response_model=response_model, - ) + if query_batch and len(query_batch) > 0: + completion = await asyncio.gather( + *[ + generate_completion( + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + response_model=response_model, + ) + for query, context in 
zip(query_batch, context_text_batch) + ], + ) + else: + completion = await generate_completion( + query=query, + context=context_text, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + response_model=response_model, + ) if self.save_interaction and context and triplets and completion: await self.save_qa( @@ -230,7 +296,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): session_id=session_id, ) - return [completion] + return completion if isinstance(completion, list) else [completion] async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None: """ diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index ce84c1423..764f8d77d 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -1,5 +1,6 @@ from typing import List, Optional, Type, Union +from cognee.modules.retrieval.utils.validate_queries import validate_queries from cognee.shared.logging_utils import get_logger, ERROR from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError from cognee.infrastructure.databases.graph import get_graph_engine @@ -146,17 +147,10 @@ async def brute_force_triplet_search( In single-query mode, node_distances and edge_distances are stored as flat lists. In batch mode, they are stored as list-of-lists (one list per query). 
""" - if query is not None and query_batch is not None: - raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.") - if query is None and query_batch is None: - raise ValueError("Must provide either 'query' or 'query_batch'.") - if query is not None and (not query or not isinstance(query, str)): - raise ValueError("The query must be a non-empty string.") - if query_batch is not None: - if not isinstance(query_batch, list) or not query_batch: - raise ValueError("query_batch must be a non-empty list of strings.") - if not all(isinstance(q, str) and q for q in query_batch): - raise ValueError("All items in query_batch must be non-empty strings.") + is_query_valid, msg = validate_queries(query, query_batch) + if not is_query_valid: + raise ValueError(msg) + if top_k <= 0: raise ValueError("top_k must be a positive integer.") diff --git a/cognee/modules/retrieval/utils/query_state.py b/cognee/modules/retrieval/utils/query_state.py new file mode 100644 index 000000000..a926a952e --- /dev/null +++ b/cognee/modules/retrieval/utils/query_state.py @@ -0,0 +1,34 @@ +from typing import List +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge + + +class QueryState: + """ + Helper class containing all necessary information about the query state. + Used (for now) in COT and Context Extension Retrievers to keep track of important information + in a more readable way, and enable as many parallel calls to llms as possible. 
+ """ + + def __init__( + self, + triplets: List[Edge] = None, + context_text: str = "", + finished_extending_context: bool = False, + ): + # Mutual fields for COT and Context Extension + self.triplets = triplets if triplets else [] + self.context_text = context_text + self.completion = "" + + # Context Extension specific + self.finished_extending_context = finished_extending_context + + # COT specific + self.answer_text: str = "" + self.valid_user_prompt: str = "" + self.valid_system_prompt: str = "" + self.reasoning: str = "" + + self.followup_question: str = "" + self.followup_prompt: str = "" + self.followup_system: str = "" diff --git a/cognee/modules/retrieval/utils/validate_queries.py b/cognee/modules/retrieval/utils/validate_queries.py new file mode 100644 index 000000000..913b0d665 --- /dev/null +++ b/cognee/modules/retrieval/utils/validate_queries.py @@ -0,0 +1,14 @@ +def validate_queries(query, query_batch) -> tuple[bool, str]: + if query is not None and query_batch is not None: + return False, "Cannot provide both 'query' and 'query_batch'; use exactly one." + if query is None and query_batch is None: + return False, "Must provide either 'query' or 'query_batch'." + if query is not None and (not query or not isinstance(query, str)): + return False, "The query must be a non-empty string." + if query_batch is not None: + if not isinstance(query_batch, list) or not query_batch: + return False, "query_batch must be a non-empty list of strings." + if not all(isinstance(q, str) and q for q in query_batch): + return False, "All items in query_batch must be non-empty strings." 
+ + return True, "" diff --git a/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py b/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py index 0db035e03..54ca09b8a 100644 --- a/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +++ b/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py @@ -1,4 +1,3 @@ -import os import pytest import pathlib import pytest_asyncio diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 6a9b07d38..9ceca96e2 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -1,4 +1,5 @@ import pytest +from itertools import cycle from unittest.mock import AsyncMock, patch, MagicMock from uuid import UUID @@ -81,7 +82,7 @@ async def test_get_completion_without_context(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -157,7 +158,7 @@ async def test_get_completion_context_extension_rounds(mock_edge): retriever, "get_context", new_callable=AsyncMock, - side_effect=[[mock_edge], [mock_edge2]], + side_effect=[[[mock_edge]], [[mock_edge2]]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -194,7 +195,7 @@ async def test_get_completion_context_extension_stops_early(mock_edge): retriever = GraphCompletionContextExtensionRetriever() with ( - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( 
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", @@ -240,7 +241,7 @@ async def test_get_completion_with_session(mock_edge): "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", return_value=mock_graph_engine, ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", @@ -304,7 +305,7 @@ async def test_get_completion_with_save_interaction(mock_edge): "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", return_value=mock_graph_engine, ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", @@ -361,7 +362,7 @@ async def test_get_completion_with_response_model(mock_edge): "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", return_value=mock_graph_engine, ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", @@ -403,7 +404,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge): "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", return_value=mock_graph_engine, ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( 
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", @@ -446,7 +447,7 @@ async def test_get_completion_zero_extension_rounds(mock_edge): "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", return_value=mock_graph_engine, ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", @@ -467,3 +468,339 @@ async def test_get_completion_zero_extension_rounds(mock_edge): assert isinstance(completion, list) assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_without_context(mock_edge): + """Test get_completion batch queries retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], context_extension_rounds=1 + ) + + 
assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_provided_context(mock_edge): + """Test get_completion batch queries uses provided context.""" + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], + context=[[mock_edge], [mock_edge]], + context_extension_rounds=1, + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_context_extension_rounds(mock_edge): + """Test get_completion batch queries with multiple context extension rounds.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + # Create a second edge for extension rounds + mock_edge2 = MagicMock(spec=Edge) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + side_effect=[[[mock_edge], [mock_edge]], [[mock_edge2], [mock_edge2]]], + ), + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + side_effect=[ + "Resolved context", + "Resolved context", + "Extended context", + "Extended context", + ], # Different contexts + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Extension query", + "Generated answer", + "Generated answer", + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], context_extension_rounds=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_context_extension_stops_early(mock_edge): + """Test get_completion batch queries stops early when no new triplets found.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Extension query", + "Generated answer", + "Generated answer", + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + 
+ # When get_context returns same triplets, the loop should stop early + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], + context=[[mock_edge], [mock_edge]], + context_extension_rounds=4, + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_zero_extension_rounds(mock_edge): + """Test get_completion batch queries with zero context extension rounds.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], context_extension_rounds=0 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_response_model(mock_edge): + """Test get_completion batch queries with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + 
mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Extension query", + TestModel(answer="Test answer"), + TestModel(answer="Test answer"), + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], + response_model=TestModel, + context_extension_rounds=1, + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_duplicate_queries(mock_edge): + """Test get_completion batch queries with duplicate queries.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + # Create a second edge for extension rounds + mock_edge2 = MagicMock(spec=Edge) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + side_effect=[[[mock_edge], [mock_edge]], [[mock_edge2], 
[mock_edge2]]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + side_effect=[ + "Resolved context", + "Resolved context", + "Extended context", + "Extended context", + ], # Different contexts + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Extension query", + "Generated answer", + "Generated answer", + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 1"], context_extension_rounds=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 9f3147512..1a6155c4f 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -1,6 +1,9 @@ +import os + import pytest from unittest.mock import AsyncMock, patch, MagicMock from uuid import UUID +from itertools import cycle from cognee.modules.retrieval.graph_completion_cot_retriever import ( GraphCompletionCotRetriever, @@ -79,7 +82,7 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge): "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", return_value="Generated answer", ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( 
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", return_value="Generated answer", @@ -105,8 +108,8 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge): max_iter=1, ) - assert completion == "Generated answer" - assert context_text == "Resolved context" + assert completion == ["Generated answer"] + assert context_text == ["Resolved context"] assert len(triplets) >= 1 @@ -125,7 +128,7 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -142,8 +145,8 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge): max_iter=1, ) - assert completion == "Generated answer" - assert context_text == "Resolved context" + assert completion == ["Generated answer"] + assert context_text == ["Resolved context"] assert len(triplets) >= 1 @@ -167,7 +170,7 @@ async def test_run_cot_completion_multiple_rounds(mock_edge): retriever, "get_context", new_callable=AsyncMock, - side_effect=[[mock_edge], [mock_edge2]], + side_effect=[[[mock_edge]], [[mock_edge2]]], ), patch( "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", @@ -199,8 +202,8 @@ async def test_run_cot_completion_multiple_rounds(mock_edge): max_iter=2, ) - assert completion == "Generated answer" - assert context_text == "Resolved context" + assert completion == ["Generated answer"] + assert context_text == ["Resolved context"] assert len(triplets) >= 1 @@ -226,7 +229,7 @@ async def test_run_cot_completion_with_conversation_history(mock_edge): max_iter=1, ) - assert completion == "Generated answer" + assert completion == ["Generated answer"] call_kwargs = mock_generate.call_args[1] assert call_kwargs.get("conversation_history") == "Previous conversation" @@ -258,8 +261,9 @@ async def 
test_run_cot_completion_with_response_model(mock_edge): max_iter=1, ) - assert isinstance(completion, TestModel) - assert completion.answer == "Test answer" + assert isinstance(completion, list) + assert isinstance(completion[0], TestModel) + assert completion[0].answer == "Test answer" @pytest.mark.asyncio @@ -284,7 +288,7 @@ async def test_run_cot_completion_empty_conversation_history(mock_edge): max_iter=1, ) - assert completion == "Generated answer" + assert completion == ["Generated answer"] # Verify conversation_history was passed as None when empty call_kwargs = mock_generate.call_args[1] assert call_kwargs.get("conversation_history") is None @@ -305,7 +309,7 @@ async def test_get_completion_without_context(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -315,7 +319,7 @@ async def test_get_completion_without_context(mock_edge): "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", return_value="Generated answer", ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", return_value="Generated answer", @@ -396,7 +400,7 @@ async def test_get_completion_with_session(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -462,7 +466,7 @@ async def test_get_completion_with_save_interaction(mock_edge): "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", return_value="Generated answer", ), - 
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", return_value="Generated answer", @@ -527,7 +531,7 @@ async def test_get_completion_with_response_model(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -569,7 +573,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -611,7 +615,7 @@ async def test_get_completion_with_save_interaction_no_context(mock_edge): "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", return_value="Generated answer", ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", return_value="Generated answer", @@ -686,3 +690,166 @@ async def test_as_answer_text_with_basemodel(): assert isinstance(result, str) assert "[Structured Response]" in result assert "test answer" in result + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_context(mock_edge): + """Test get_completion batch queries with provided context.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + 
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=cycle(["validation_result", "followup_question"]), + ), + ): + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], + context=[[mock_edge], [mock_edge]], + max_iter=1, + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_without_context(mock_edge): + """Test get_completion batch queries without provided context.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + ): + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], max_iter=1 + ) + 
+ assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +# +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_response_model(mock_edge): + """Test get_completion of batch queries with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], response_model=TestModel, max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_duplicate_queries(mock_edge): + """Test get_completion batch queries with duplicate queries.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + ): + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 1"], max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index c22f30fd0..a6fb05270 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -646,3 +646,199 @@ async def test_get_completion_with_save_interaction_all_conditions_met(mock_edge assert len(completion) == 1 assert completion[0] == "Generated answer" mock_add_data.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_context(mock_edge): + """Test get_completion correctly handles batch queries.""" + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + 
+ completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], context=[[mock_edge], [mock_edge]] + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_without_context(mock_edge): + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion(query_batch=["test query 1", "test query 2"]) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_response_model(mock_edge): + """Test get_completion of batch queries with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], response_model=TestModel + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_empty_context(mock_edge): + """Test get_completion with empty context.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[], []], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = 
mock_config + + completion = await retriever.get_completion(query_batch=["test query 1", "test query 2"]) + + assert isinstance(completion, list) + assert len(completion) == 2 + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_duplicate_queries(mock_edge): + """Test get_completion batch queries with duplicate queries.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion(query_batch=["test query 1", "test query 1"]) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer"