This commit is contained in:
Andrej Milićević 2026-01-20 19:42:28 +01:00 committed by GitHub
commit b63c4aef34
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 1229 additions and 167 deletions

View file

@ -40,3 +40,13 @@ class CollectionDistancesNotFoundError(CogneeValidationError):
status_code: int = status.HTTP_404_NOT_FOUND, status_code: int = status.HTTP_404_NOT_FOUND,
): ):
super().__init__(message, name, status_code) super().__init__(message, name, status_code)
class QueryValidationError(CogneeValidationError):
    """Raised when `query` / `query_batch` arguments are supplied in an invalid combination or format.

    Maps to HTTP 422 (Unprocessable Content) so API callers get a validation-level
    status code rather than a generic server error.
    """

    def __init__(
        self,
        message: str = "Queries not supplied in the correct format.",
        name: str = "QueryValidationError",
        status_code: int = status.HTTP_422_UNPROCESSABLE_CONTENT,
    ):
        # Delegate to the shared validation-error base so formatting/serialization stays consistent.
        super().__init__(message, name, status_code)

View file

@ -1,6 +1,9 @@
import asyncio import asyncio
from typing import Optional, List, Type, Any from typing import Optional, List, Type, Any
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError
from cognee.modules.retrieval.utils.query_state import QueryState
from cognee.modules.retrieval.utils.validate_queries import validate_queries
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
@ -56,11 +59,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
async def get_completion( async def get_completion(
self, self,
query: str, query: Optional[str] = None,
context: Optional[List[Edge]] = None, context: Optional[List[Edge] | List[List[Edge]]] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
context_extension_rounds=4, context_extension_rounds=4,
response_model: Type = str, response_model: Type = str,
query_batch: Optional[List[str]] = None,
) -> List[Any]: ) -> List[Any]:
""" """
Extends the context for a given query by retrieving related triplets and generating new Extends the context for a given query by retrieving related triplets and generating new
@ -89,47 +93,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
- List[str]: A list containing the generated answer based on the query and the - List[str]: A list containing the generated answer based on the query and the
extended context. extended context.
""" """
triplets = context
if triplets is None:
triplets = await self.get_context(query)
context_text = await self.resolve_edges_to_text(triplets)
round_idx = 1
while round_idx <= context_extension_rounds:
prev_size = len(triplets)
logger.info(
f"Context extension: round {round_idx} - generating next graph locational query."
)
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
triplets += await self.get_context(completion)
triplets = list(set(triplets))
context_text = await self.resolve_edges_to_text(triplets)
num_triplets = len(triplets)
if num_triplets == prev_size:
logger.info(
f"Context extension: round {round_idx} no new triplets found; stopping early."
)
break
logger.info(
f"Context extension: round {round_idx} - "
f"number of unique retrieved triplets: {num_triplets}"
)
round_idx += 1
# Check if we need to generate context summary for caching # Check if we need to generate context summary for caching
cache_config = CacheConfig() cache_config = CacheConfig()
@ -137,6 +100,131 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
user_id = getattr(user, "id", None) user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching session_save = user_id and cache_config.caching
if query_batch and session_save:
raise QueryValidationError(
message="You cannot use batch queries with session saving currently."
)
if query_batch and self.save_interaction:
raise QueryValidationError(
message="Cannot use batch queries with interaction saving currently."
)
is_query_valid, msg = validate_queries(query, query_batch)
if not is_query_valid:
raise QueryValidationError(message=msg)
triplets_batch = context
if query:
# This is done mostly to avoid duplicating a lot of code unnecessarily
query_batch = [query]
if triplets_batch:
triplets_batch = [triplets_batch]
if triplets_batch is None:
triplets_batch = await self.get_context(query_batch=query_batch)
if not triplets_batch:
return []
context_text = ""
context_text_batch = await asyncio.gather(
*[self.resolve_edges_to_text(triplets) for triplets in triplets_batch]
)
round_idx = 1
# We store queries as keys and their associated states in this dict.
# The state is a 3-item object QueryState, which holds triplets, context text,
# and a boolean marking whether we should continue extending the context for that query.
finished_queries_states = {}
for batched_query, batched_triplets, batched_context_text in zip(
query_batch, triplets_batch, context_text_batch
):
# Populating the dict at the start with initial information.
finished_queries_states[batched_query] = QueryState(
batched_triplets, batched_context_text, False
)
while round_idx <= context_extension_rounds:
logger.info(
f"Context extension: round {round_idx} - generating next graph locational query."
)
if all(
batched_query_state.finished_extending_context
for batched_query_state in finished_queries_states.values()
):
# We stop early only if all queries in the batch have reached their final state
logger.info(
f"Context extension: round {round_idx} no new triplets found; stopping early."
)
break
relevant_queries = [
rel_query
for rel_query in finished_queries_states.keys()
if not finished_queries_states[rel_query].finished_extending_context
]
prev_sizes = [
len(finished_queries_states[rel_query].triplets) for rel_query in relevant_queries
]
completions = await asyncio.gather(
*[
generate_completion(
query=rel_query,
context=finished_queries_states[rel_query].context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
for rel_query in relevant_queries
],
)
# Get new triplets, and merge them with existing ones, filtering out duplicates
new_triplets_batch = await self.get_context(query_batch=completions)
for rel_query, batched_new_triplets in zip(relevant_queries, new_triplets_batch):
finished_queries_states[rel_query].triplets = list(
dict.fromkeys(
finished_queries_states[rel_query].triplets + batched_new_triplets
)
)
# Resolve new triplets to text
context_text_batch = await asyncio.gather(
*[
self.resolve_edges_to_text(finished_queries_states[rel_query].triplets)
for rel_query in relevant_queries
]
)
# Update context_texts in query states
for rel_query, batched_context_text in zip(relevant_queries, context_text_batch):
finished_queries_states[rel_query].context_text = batched_context_text
new_sizes = [
len(finished_queries_states[rel_query].triplets) for rel_query in relevant_queries
]
for rel_query, prev_size, new_size in zip(relevant_queries, prev_sizes, new_sizes):
# Mark done queries accordingly
if prev_size == new_size:
finished_queries_states[rel_query].finished_extending_context = True
logger.info(
f"Context extension: round {round_idx} - "
f"number of unique retrieved triplets for each query : {new_sizes}"
)
round_idx += 1
completion_batch = []
result_completion_batch = []
if session_save: if session_save:
conversation_history = await get_conversation_history(session_id=session_id) conversation_history = await get_conversation_history(session_id=session_id)
@ -153,18 +241,36 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
), ),
) )
else: else:
completion = await generate_completion( completion_batch = await asyncio.gather(
query=query, *[
context=context_text, generate_completion(
user_prompt_path=self.user_prompt_path, query=batched_query,
system_prompt_path=self.system_prompt_path, context=batched_query_state.context_text,
system_prompt=self.system_prompt, user_prompt_path=self.user_prompt_path,
response_model=response_model, system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
for batched_query, batched_query_state in finished_queries_states.items()
],
) )
if self.save_interaction and context_text and triplets and completion: # Make sure answers are returned for duplicate queries, in the order they were asked.
for batched_query, batched_completion in zip(
finished_queries_states.keys(), completion_batch
):
finished_queries_states[batched_query].completion = batched_completion
for batched_query in query_batch:
result_completion_batch.append(finished_queries_states[batched_query].completion)
# TODO: Do batch queries for save interaction
if self.save_interaction and context_text_batch and triplets_batch and completion_batch:
await self.save_qa( await self.save_qa(
question=query, answer=completion, context=context_text, triplets=triplets question=query,
answer=completion_batch[0],
context=context_text_batch[0],
triplets=triplets_batch[0],
) )
if session_save: if session_save:
@ -175,4 +281,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
session_id=session_id, session_id=session_id,
) )
return [completion] return result_completion_batch if result_completion_batch else [completion]

View file

@ -3,6 +3,9 @@ import json
from typing import Optional, List, Type, Any from typing import Optional, List, Type, Any
from pydantic import BaseModel from pydantic import BaseModel
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError
from cognee.modules.retrieval.utils.query_state import QueryState
from cognee.modules.retrieval.utils.validate_queries import validate_queries
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@ -86,12 +89,13 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
async def _run_cot_completion( async def _run_cot_completion(
self, self,
query: str, query: Optional[str] = None,
context: Optional[List[Edge]] = None, query_batch: Optional[List[str]] = None,
context: Optional[List[Edge] | List[List[Edge]]] = None,
conversation_history: str = "", conversation_history: str = "",
max_iter: int = 4, max_iter: int = 4,
response_model: Type = str, response_model: Type = str,
) -> tuple[Any, str, List[Edge]]: ) -> tuple[List[Any], List[str], List[List[Edge]]]:
""" """
Run chain-of-thought completion with optional structured output. Run chain-of-thought completion with optional structured output.
@ -109,72 +113,187 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
- context_text: The resolved context text - context_text: The resolved context text
- triplets: The list of triplets used - triplets: The list of triplets used
""" """
followup_question = "" followup_question_batch = []
triplets = [] completion_batch = []
completion = "" context_text_batch = []
if query:
# Treat a single query as a batch of queries, mainly avoiding massive code duplication
query_batch = [query]
if context:
context = [context]
triplets_batch = context
# dict containing query -> QueryState key-value pairs
# For every query, we save necessary data so we can execute requests in parallel
query_state_tracker = {}
for batched_query in query_batch:
query_state_tracker[batched_query] = QueryState()
for round_idx in range(max_iter + 1): for round_idx in range(max_iter + 1):
if round_idx == 0: if round_idx == 0:
if context is None: if context is None:
triplets = await self.get_context(query) # Get context, resolve to text, and store info in the query state
context_text = await self.resolve_edges_to_text(triplets) triplets_batch = await self.get_context(
query_batch=list(query_state_tracker.keys())
)
context_text_batch = await asyncio.gather(
*[
self.resolve_edges_to_text(batched_triplets)
for batched_triplets in triplets_batch
]
)
for batched_query, batched_triplets, batched_context_text in zip(
query_state_tracker.keys(), triplets_batch, context_text_batch
):
query_state_tracker[batched_query].triplets = batched_triplets
query_state_tracker[batched_query].context_text = batched_context_text
else: else:
context_text = await self.resolve_edges_to_text(context) # In this case just resolve to text and save to the query state
context_text_batch = await asyncio.gather(
*[
self.resolve_edges_to_text(batched_context)
for batched_context in context
]
)
for batched_query, batched_triplets, batched_context_text in zip(
query_state_tracker.keys(), context, context_text_batch
):
query_state_tracker[batched_query].triplets = batched_triplets
query_state_tracker[batched_query].context_text = batched_context_text
else: else:
triplets += await self.get_context(followup_question) # Find new triplets, and update existing query states
context_text = await self.resolve_edges_to_text(list(set(triplets))) triplets_batch = await self.get_context(query_batch=followup_question_batch)
completion = await generate_completion( for batched_query, batched_followup_triplets in zip(
query=query, query_state_tracker.keys(), triplets_batch
context=context_text, ):
user_prompt_path=self.user_prompt_path, query_state_tracker[batched_query].triplets = list(
system_prompt_path=self.system_prompt_path, dict.fromkeys(
system_prompt=self.system_prompt, query_state_tracker[batched_query].triplets + batched_followup_triplets
conversation_history=conversation_history if conversation_history else None, )
response_model=response_model, )
context_text_batch = await asyncio.gather(
*[
self.resolve_edges_to_text(batched_query_state.triplets)
for batched_query_state in query_state_tracker.values()
]
)
for batched_query, batched_context_text in zip(
query_state_tracker.keys(), context_text_batch
):
query_state_tracker[batched_query].context_text = batched_context_text
completion_batch = await asyncio.gather(
*[
generate_completion(
query=batched_query,
context=batched_query_state.context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history if conversation_history else None,
response_model=response_model,
)
for batched_query, batched_query_state in query_state_tracker.items()
]
) )
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") for batched_query, batched_completion in zip(
query_state_tracker.keys(), completion_batch
):
query_state_tracker[batched_query].completion = batched_completion
if round_idx == max_iter:
# When we finish all iterations:
# Make sure answers are returned for duplicate queries, in the order they were asked.
completion_batch = []
for batched_query in query_batch:
completion_batch.append(query_state_tracker[batched_query].completion)
logger.info(f"Chain-of-thought: round {round_idx} - answers: {completion_batch}")
if round_idx < max_iter: if round_idx < max_iter:
answer_text = _as_answer_text(completion) for batched_query, batched_query_state in query_state_tracker.items():
valid_args = {"query": query, "answer": answer_text, "context": context_text} batched_query_state.answer_text = _as_answer_text(
valid_user_prompt = render_prompt( batched_query_state.completion
filename=self.validation_user_prompt_path, context=valid_args )
) valid_args = {
valid_system_prompt = read_query_prompt( "query": batched_query,
prompt_file_name=self.validation_system_prompt_path "answer": batched_query_state.answer_text,
"context": batched_query_state.context_text,
}
batched_query_state.valid_user_prompt = render_prompt(
filename=self.validation_user_prompt_path,
context=valid_args,
)
batched_query_state.valid_system_prompt = read_query_prompt(
prompt_file_name=self.validation_system_prompt_path
)
reasoning_batch = await asyncio.gather(
*[
LLMGateway.acreate_structured_output(
text_input=batched_query_state.valid_user_prompt,
system_prompt=batched_query_state.valid_system_prompt,
response_model=str,
)
for batched_query_state in query_state_tracker.values()
]
) )
reasoning = await LLMGateway.acreate_structured_output( for batched_query, batched_reasoning in zip(
text_input=valid_user_prompt, query_state_tracker.keys(), reasoning_batch
system_prompt=valid_system_prompt, ):
response_model=str, query_state_tracker[batched_query].reasoning = batched_reasoning
)
followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning} for batched_query, batched_query_state in query_state_tracker.items():
followup_prompt = render_prompt( followup_args = {
filename=self.followup_user_prompt_path, context=followup_args "query": batched_query,
) "answer": batched_query_state.answer_text,
followup_system = read_query_prompt( "reasoning": batched_query_state.reasoning,
prompt_file_name=self.followup_system_prompt_path }
batched_query_state.followup_prompt = render_prompt(
filename=self.followup_user_prompt_path,
context=followup_args,
)
batched_query_state.followup_system = read_query_prompt(
prompt_file_name=self.followup_system_prompt_path
)
followup_question_batch = await asyncio.gather(
*[
LLMGateway.acreate_structured_output(
text_input=batched_query_state.followup_prompt,
system_prompt=batched_query_state.followup_system,
response_model=str,
)
for batched_query_state in query_state_tracker.values()
]
) )
followup_question = await LLMGateway.acreate_structured_output( for batched_query, batched_followup_question in zip(
text_input=followup_prompt, system_prompt=followup_system, response_model=str query_state_tracker.keys(), followup_question_batch
) ):
query_state_tracker[batched_query].followup_question = batched_followup_question
logger.info( logger.info(
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}" f"Chain-of-thought: round {round_idx} - follow-up questions: {followup_question_batch}"
) )
return completion, context_text, triplets return completion_batch, context_text_batch, triplets_batch
async def get_completion( async def get_completion(
self, self,
query: str, query: Optional[str] = None,
context: Optional[List[Edge]] = None, context: Optional[List[Edge] | List[List[Edge]]] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
max_iter=4, max_iter=4,
response_model: Type = str, response_model: Type = str,
query_batch: Optional[List[str]] = None,
) -> List[Any]: ) -> List[Any]:
""" """
Generate completion responses based on a user query and contextual information. Generate completion responses based on a user query and contextual information.
@ -202,12 +321,26 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
- List[str]: A list containing the generated answer to the user's query. - List[str]: A list containing the generated answer to the user's query.
""" """
# Check if session saving is enabled # Check if session saving is enabled
cache_config = CacheConfig() cache_config = CacheConfig()
user = session_user.get() user = session_user.get()
user_id = getattr(user, "id", None) user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching session_save = user_id and cache_config.caching
if query_batch and session_save:
raise QueryValidationError(
message="You cannot use batch queries with session saving currently."
)
if query_batch and self.save_interaction:
raise QueryValidationError(
message="Cannot use batch queries with interaction saving currently."
)
is_query_valid, msg = validate_queries(query, query_batch)
if not is_query_valid:
raise QueryValidationError(message=msg)
# Load conversation history if enabled # Load conversation history if enabled
conversation_history = "" conversation_history = ""
if session_save: if session_save:
@ -215,17 +348,23 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
completion, context_text, triplets = await self._run_cot_completion( completion, context_text, triplets = await self._run_cot_completion(
query=query, query=query,
query_batch=query_batch,
context=context, context=context,
conversation_history=conversation_history, conversation_history=conversation_history,
max_iter=max_iter, max_iter=max_iter,
response_model=response_model, response_model=response_model,
) )
# TODO: Handle save interaction for batch queries
if self.save_interaction and context and triplets and completion: if self.save_interaction and context and triplets and completion:
await self.save_qa( await self.save_qa(
question=query, answer=str(completion), context=context_text, triplets=triplets question=query,
answer=str(completion[0]),
context=context_text[0],
triplets=triplets[0],
) )
# TODO: Handle session save interaction for batch queries
# Save to session cache if enabled # Save to session cache if enabled
if session_save: if session_save:
context_summary = await summarize_text(context_text) context_summary = await summarize_text(context_text)
@ -236,4 +375,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
session_id=session_id, session_id=session_id,
) )
return [completion] return completion

View file

@ -4,6 +4,8 @@ from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError
from cognee.modules.retrieval.utils.validate_queries import validate_queries
from cognee.tasks.storage import add_data_points from cognee.tasks.storage import add_data_points
from cognee.modules.graph.utils import resolve_edges_to_text from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
@ -79,7 +81,11 @@ class GraphCompletionRetriever(BaseGraphRetriever):
""" """
return await resolve_edges_to_text(retrieved_edges) return await resolve_edges_to_text(retrieved_edges)
async def get_triplets(self, query: str) -> List[Edge]: async def get_triplets(
self,
query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
) -> List[Edge] | List[List[Edge]]:
""" """
Retrieves relevant graph triplets based on a query string. Retrieves relevant graph triplets based on a query string.
@ -107,6 +113,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
found_triplets = await brute_force_triplet_search( found_triplets = await brute_force_triplet_search(
query, query,
query_batch,
top_k=self.top_k, top_k=self.top_k,
collections=vector_index_collections or None, collections=vector_index_collections or None,
node_type=self.node_type, node_type=self.node_type,
@ -117,7 +124,11 @@ class GraphCompletionRetriever(BaseGraphRetriever):
return found_triplets return found_triplets
async def get_context(self, query: str) -> List[Edge]: async def get_context(
self,
query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
) -> List[Edge] | List[List[Edge]]:
""" """
Retrieves and resolves graph triplets into context based on a query. Retrieves and resolves graph triplets into context based on a query.
@ -139,17 +150,36 @@ class GraphCompletionRetriever(BaseGraphRetriever):
logger.warning("Search attempt on an empty knowledge graph") logger.warning("Search attempt on an empty knowledge graph")
return [] return []
triplets = await self.get_triplets(query) triplets = await self.get_triplets(query, query_batch)
if len(triplets) == 0: if query_batch:
logger.warning("Empty context was provided to the completion") for batched_triplets, batched_query in zip(triplets, query_batch):
return [] if len(batched_triplets) == 0:
logger.warning(
f"Empty context was provided to the completion for the query: {batched_query}"
)
entity_nodes_batch = []
# context = await self.resolve_edges_to_text(triplets) for batched_triplets in triplets:
entity_nodes_batch.append(get_entity_nodes_from_triplets(batched_triplets))
entity_nodes = get_entity_nodes_from_triplets(triplets) # Remove duplicates and update node access, if it is enabled
import os
if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() == "true":
for batched_entity_nodes in entity_nodes_batch:
await update_node_access_timestamps(batched_entity_nodes)
else:
if len(triplets) == 0:
logger.warning("Empty context was provided to the completion")
return []
# context = await self.resolve_edges_to_text(triplets)
entity_nodes = get_entity_nodes_from_triplets(triplets)
await update_node_access_timestamps(entity_nodes)
await update_node_access_timestamps(entity_nodes)
return triplets return triplets
async def convert_retrieved_objects_to_context(self, triplets: List[Edge]): async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):
@ -158,10 +188,11 @@ class GraphCompletionRetriever(BaseGraphRetriever):
async def get_completion( async def get_completion(
self, self,
query: str, query: Optional[str] = None,
context: Optional[List[Edge]] = None, context: Optional[List[Edge] | List[List[Edge]]] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
response_model: Type = str, response_model: Type = str,
query_batch: Optional[List[str]] = None,
) -> List[Any]: ) -> List[Any]:
""" """
Generates a completion using graph connections context based on a query. Generates a completion using graph connections context based on a query.
@ -180,18 +211,38 @@ class GraphCompletionRetriever(BaseGraphRetriever):
- Any: A generated completion based on the query and context provided. - Any: A generated completion based on the query and context provided.
""" """
triplets = context
if triplets is None:
triplets = await self.get_context(query)
context_text = await resolve_edges_to_text(triplets)
cache_config = CacheConfig() cache_config = CacheConfig()
user = session_user.get() user = session_user.get()
user_id = getattr(user, "id", None) user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching session_save = user_id and cache_config.caching
if query_batch and session_save:
raise QueryValidationError(
message="You cannot use batch queries with session saving currently."
)
if query_batch and self.save_interaction:
raise QueryValidationError(
message="Cannot use batch queries with interaction saving currently."
)
is_query_valid, msg = validate_queries(query, query_batch)
if not is_query_valid:
raise QueryValidationError(message=msg)
triplets = context
if triplets is None:
triplets = await self.get_context(query, query_batch)
context_text = ""
context_text_batch = []
if triplets and isinstance(triplets[0], list):
context_text_batch = await asyncio.gather(
*[resolve_edges_to_text(triplets_element) for triplets_element in triplets]
)
else:
context_text = await resolve_edges_to_text(triplets)
if session_save: if session_save:
conversation_history = await get_conversation_history(session_id=session_id) conversation_history = await get_conversation_history(session_id=session_id)
@ -208,14 +259,29 @@ class GraphCompletionRetriever(BaseGraphRetriever):
), ),
) )
else: else:
completion = await generate_completion( if query_batch and len(query_batch) > 0:
query=query, completion = await asyncio.gather(
context=context_text, *[
user_prompt_path=self.user_prompt_path, generate_completion(
system_prompt_path=self.system_prompt_path, query=query,
system_prompt=self.system_prompt, context=context,
response_model=response_model, user_prompt_path=self.user_prompt_path,
) system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
for query, context in zip(query_batch, context_text_batch)
],
)
else:
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if self.save_interaction and context and triplets and completion: if self.save_interaction and context and triplets and completion:
await self.save_qa( await self.save_qa(
@ -230,7 +296,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
session_id=session_id, session_id=session_id,
) )
return [completion] return completion if isinstance(completion, list) else [completion]
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None: async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
""" """

View file

@ -1,5 +1,6 @@
from typing import List, Optional, Type, Union from typing import List, Optional, Type, Union
from cognee.modules.retrieval.utils.validate_queries import validate_queries
from cognee.shared.logging_utils import get_logger, ERROR from cognee.shared.logging_utils import get_logger, ERROR
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
@ -146,17 +147,10 @@ async def brute_force_triplet_search(
In single-query mode, node_distances and edge_distances are stored as flat lists. In single-query mode, node_distances and edge_distances are stored as flat lists.
In batch mode, they are stored as list-of-lists (one list per query). In batch mode, they are stored as list-of-lists (one list per query).
""" """
if query is not None and query_batch is not None: is_query_valid, msg = validate_queries(query, query_batch)
raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.") if not is_query_valid:
if query is None and query_batch is None: raise ValueError(msg)
raise ValueError("Must provide either 'query' or 'query_batch'.")
if query is not None and (not query or not isinstance(query, str)):
raise ValueError("The query must be a non-empty string.")
if query_batch is not None:
if not isinstance(query_batch, list) or not query_batch:
raise ValueError("query_batch must be a non-empty list of strings.")
if not all(isinstance(q, str) and q for q in query_batch):
raise ValueError("All items in query_batch must be non-empty strings.")
if top_k <= 0: if top_k <= 0:
raise ValueError("top_k must be a positive integer.") raise ValueError("top_k must be a positive integer.")

View file

@ -0,0 +1,34 @@
from typing import List, Optional

from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
class QueryState:
    """
    Helper class bundling all per-query state for batched retrieval.

    Used (for now) in the COT and Context Extension retrievers to keep track of
    important information in a more readable way, and to enable as many parallel
    LLM calls as possible.

    Parameters:
    - triplets: Initial graph edges for the query; defaults to an empty list
      (a fresh list per instance — never a shared mutable default).
    - context_text: Textual rendering of the triplets.
    - finished_extending_context: Whether context extension should stop for this query.
    """

    def __init__(
        self,
        # Forward-reference "Edge" so the annotation is not evaluated at def time,
        # and Optional[...] correctly reflects the None default.
        triplets: Optional[List["Edge"]] = None,
        context_text: str = "",
        finished_extending_context: bool = False,
    ):
        # Mutual fields for COT and Context Extension
        self.triplets = triplets if triplets else []
        self.context_text = context_text
        self.completion = ""
        # Context Extension specific
        self.finished_extending_context = finished_extending_context
        # COT specific
        self.answer_text: str = ""
        self.valid_user_prompt: str = ""
        self.valid_system_prompt: str = ""
        self.reasoning: str = ""
        self.followup_question: str = ""
        self.followup_prompt: str = ""
        self.followup_system: str = ""

View file

@ -0,0 +1,14 @@
def validate_queries(query, query_batch) -> tuple[bool, str]:
    """
    Validate that exactly one of ``query`` / ``query_batch`` was supplied, correctly.

    Args:
        query: A single query; must be a non-empty string when provided.
        query_batch: A batch of queries; must be a non-empty list of non-empty
            strings when provided.

    Returns:
        A ``(is_valid, message)`` tuple: ``(True, "")`` when input is valid,
        otherwise ``(False, <reason>)``.
    """
    # Exactly one of the two inputs must be supplied.
    if query is not None and query_batch is not None:
        return False, "Cannot provide both 'query' and 'query_batch'; use exactly one."
    if query is None and query_batch is None:
        return False, "Must provide either 'query' or 'query_batch'."
    # Check the type before truthiness so the intent reads clearly.
    if query is not None and (not isinstance(query, str) or not query):
        return False, "The query must be a non-empty string."
    if query_batch is not None:
        if not isinstance(query_batch, list) or not query_batch:
            return False, "query_batch must be a non-empty list of strings."
        if not all(isinstance(q, str) and q for q in query_batch):
            return False, "All items in query_batch must be non-empty strings."
    return True, ""

View file

@ -1,4 +1,3 @@
import os
import pytest import pytest
import pathlib import pathlib
import pytest_asyncio import pytest_asyncio

View file

@ -1,4 +1,5 @@
import pytest import pytest
from itertools import cycle
from unittest.mock import AsyncMock, patch, MagicMock from unittest.mock import AsyncMock, patch, MagicMock
from uuid import UUID from uuid import UUID
@ -81,7 +82,7 @@ async def test_get_completion_without_context(mock_edge):
), ),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge], return_value=[[mock_edge]],
), ),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -157,7 +158,7 @@ async def test_get_completion_context_extension_rounds(mock_edge):
retriever, retriever,
"get_context", "get_context",
new_callable=AsyncMock, new_callable=AsyncMock,
side_effect=[[mock_edge], [mock_edge2]], side_effect=[[[mock_edge]], [[mock_edge2]]],
), ),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -194,7 +195,7 @@ async def test_get_completion_context_extension_stops_early(mock_edge):
retriever = GraphCompletionContextExtensionRetriever() retriever = GraphCompletionContextExtensionRetriever()
with ( with (
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context", return_value="Resolved context",
@ -240,7 +241,7 @@ async def test_get_completion_with_session(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine, return_value=mock_graph_engine,
), ),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context", return_value="Resolved context",
@ -304,7 +305,7 @@ async def test_get_completion_with_save_interaction(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine, return_value=mock_graph_engine,
), ),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context", return_value="Resolved context",
@ -361,7 +362,7 @@ async def test_get_completion_with_response_model(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine, return_value=mock_graph_engine,
), ),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context", return_value="Resolved context",
@ -403,7 +404,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine, return_value=mock_graph_engine,
), ),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context", return_value="Resolved context",
@ -446,7 +447,7 @@ async def test_get_completion_zero_extension_rounds(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine, return_value=mock_graph_engine,
), ),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context", return_value="Resolved context",
@ -467,3 +468,339 @@ async def test_get_completion_zero_extension_rounds(mock_edge):
assert isinstance(completion, list) assert isinstance(completion, list)
assert len(completion) == 1 assert len(completion) == 1
@pytest.mark.asyncio
async def test_get_completion_batch_queries_without_context(mock_edge):
    """Test get_completion batch queries retrieves context when not provided.

    brute_force_triplet_search is patched to return one triplet list per query
    (batch mode uses list-of-lists), so one answer per query is expected.
    """
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionContextExtensionRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            # One inner triplet list per query in the batch.
            return_value=[[mock_edge], [mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Disable caching so the retriever goes through the full completion path.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"], context_extension_rounds=1
        )
        assert isinstance(completion, list)
        assert len(completion) == 2
        assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_provided_context(mock_edge):
    """Test get_completion batch queries uses provided context.

    No graph-engine/search patches are needed here: supplying `context`
    (list-of-lists, one inner list per query) should skip retrieval entirely.
    """
    retriever = GraphCompletionContextExtensionRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Disable caching so the completion path is exercised every time.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"],
            context=[[mock_edge], [mock_edge]],
            context_extension_rounds=1,
        )
        assert isinstance(completion, list)
        assert len(completion) == 2
        assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_context_extension_rounds(mock_edge):
    """Test get_completion batch queries with multiple context extension rounds.

    get_context yields different triplets on the second call so the extension
    round actually adds context; side_effect sequences are sized for two queries
    over two rounds (4 calls each for resolve/generate).
    """
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionContextExtensionRetriever()
    # Create a second edge for extension rounds
    mock_edge2 = MagicMock(spec=Edge)
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch.object(
            retriever,
            "get_context",
            new_callable=AsyncMock,
            # First round returns mock_edge per query, second round mock_edge2.
            side_effect=[[[mock_edge], [mock_edge]], [[mock_edge2], [mock_edge2]]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            side_effect=[
                "Resolved context",
                "Resolved context",
                "Extended context",
                "Extended context",
            ],  # Different contexts
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            # First pair of calls produce extension queries, second pair final answers.
            side_effect=[
                "Extension query",
                "Extension query",
                "Generated answer",
                "Generated answer",
            ],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Disable caching so the full extension loop runs.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"], context_extension_rounds=1
        )
        assert isinstance(completion, list)
        assert len(completion) == 2
        assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_context_extension_stops_early(mock_edge):
    """Test get_completion batch queries stops early when no new triplets found.

    context_extension_rounds=4 is requested, but get_context always returns the
    same triplets as the provided context, so the loop should terminate after
    the first round (the generate_completion side_effect only covers one round).
    """
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionContextExtensionRetriever()
    with (
        patch.object(
            retriever,
            "get_context",
            new_callable=AsyncMock,
            return_value=[[mock_edge], [mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            side_effect=[
                "Extension query",
                "Extension query",
                "Generated answer",
                "Generated answer",
            ],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Disable caching so the extension logic is actually exercised.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        # When get_context returns same triplets, the loop should stop early
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"],
            context=[[mock_edge], [mock_edge]],
            context_extension_rounds=4,
        )
        assert isinstance(completion, list)
        assert len(completion) == 2
        assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_zero_extension_rounds(mock_edge):
    """Test get_completion batch queries with zero context extension rounds.

    With context_extension_rounds=0 only the initial retrieval/completion pass
    runs; one answer per query is still produced.
    """
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionContextExtensionRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch.object(
            retriever,
            "get_context",
            new_callable=AsyncMock,
            return_value=[[mock_edge], [mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Disable caching so the completion path runs for each query.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"], context_extension_rounds=0
        )
        assert isinstance(completion, list)
        assert len(completion) == 2
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_response_model(mock_edge):
    """Test get_completion batch queries with custom response model.

    The extension-round calls return plain strings (extension queries) while the
    final calls return TestModel instances, which should be passed through as-is.
    """
    from pydantic import BaseModel

    # Minimal structured-output model for the final answers.
    class TestModel(BaseModel):
        answer: str

    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionContextExtensionRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch.object(
            retriever,
            "get_context",
            new_callable=AsyncMock,
            return_value=[[mock_edge], [mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            side_effect=[
                "Extension query",
                "Extension query",
                TestModel(answer="Test answer"),
                TestModel(answer="Test answer"),
            ],  # Extension query, then final answer
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Disable caching so the structured-output path is exercised.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"],
            response_model=TestModel,
            context_extension_rounds=1,
        )
        assert isinstance(completion, list)
        assert len(completion) == 2
        assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel)
@pytest.mark.asyncio
async def test_get_completion_batch_queries_duplicate_queries(mock_edge):
    """Test get_completion batch queries with duplicate queries."""
    # NOTE(review): despite the test name, the batch below uses two DISTINCT
    # queries ("test query 1", "test query 2") — consider using a true duplicate
    # batch ["test query 1", "test query 1"] as the CoT variant does; confirm
    # the mock side_effect sequences still line up before changing.
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionContextExtensionRetriever()
    # Create a second edge for extension rounds
    mock_edge2 = MagicMock(spec=Edge)
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch.object(
            retriever,
            "get_context",
            new_callable=AsyncMock,
            # First round returns mock_edge per query, second round mock_edge2.
            side_effect=[[[mock_edge], [mock_edge]], [[mock_edge2], [mock_edge2]]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            side_effect=[
                "Resolved context",
                "Resolved context",
                "Extended context",
                "Extended context",
            ],  # Different contexts
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            side_effect=[
                "Extension query",
                "Extension query",
                "Generated answer",
                "Generated answer",
            ],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Disable caching so the full extension loop runs.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"], context_extension_rounds=1
        )
        assert isinstance(completion, list)
        assert len(completion) == 2
        assert completion[0] == "Generated answer" and completion[1] == "Generated answer"

View file

@ -1,6 +1,9 @@
import os
import pytest import pytest
from unittest.mock import AsyncMock, patch, MagicMock from unittest.mock import AsyncMock, patch, MagicMock
from uuid import UUID from uuid import UUID
from itertools import cycle
from cognee.modules.retrieval.graph_completion_cot_retriever import ( from cognee.modules.retrieval.graph_completion_cot_retriever import (
GraphCompletionCotRetriever, GraphCompletionCotRetriever,
@ -79,7 +82,7 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge):
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer", return_value="Generated answer",
), ),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch( patch(
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
return_value="Generated answer", return_value="Generated answer",
@ -105,8 +108,8 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge):
max_iter=1, max_iter=1,
) )
assert completion == "Generated answer" assert completion == ["Generated answer"]
assert context_text == "Resolved context" assert context_text == ["Resolved context"]
assert len(triplets) >= 1 assert len(triplets) >= 1
@ -125,7 +128,7 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge):
), ),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge], return_value=[[mock_edge]],
), ),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -142,8 +145,8 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge):
max_iter=1, max_iter=1,
) )
assert completion == "Generated answer" assert completion == ["Generated answer"]
assert context_text == "Resolved context" assert context_text == ["Resolved context"]
assert len(triplets) >= 1 assert len(triplets) >= 1
@ -167,7 +170,7 @@ async def test_run_cot_completion_multiple_rounds(mock_edge):
retriever, retriever,
"get_context", "get_context",
new_callable=AsyncMock, new_callable=AsyncMock,
side_effect=[[mock_edge], [mock_edge2]], side_effect=[[[mock_edge]], [[mock_edge2]]],
), ),
patch( patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
@ -199,8 +202,8 @@ async def test_run_cot_completion_multiple_rounds(mock_edge):
max_iter=2, max_iter=2,
) )
assert completion == "Generated answer" assert completion == ["Generated answer"]
assert context_text == "Resolved context" assert context_text == ["Resolved context"]
assert len(triplets) >= 1 assert len(triplets) >= 1
@ -226,7 +229,7 @@ async def test_run_cot_completion_with_conversation_history(mock_edge):
max_iter=1, max_iter=1,
) )
assert completion == "Generated answer" assert completion == ["Generated answer"]
call_kwargs = mock_generate.call_args[1] call_kwargs = mock_generate.call_args[1]
assert call_kwargs.get("conversation_history") == "Previous conversation" assert call_kwargs.get("conversation_history") == "Previous conversation"
@ -258,8 +261,9 @@ async def test_run_cot_completion_with_response_model(mock_edge):
max_iter=1, max_iter=1,
) )
assert isinstance(completion, TestModel) assert isinstance(completion, list)
assert completion.answer == "Test answer" assert isinstance(completion[0], TestModel)
assert completion[0].answer == "Test answer"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -284,7 +288,7 @@ async def test_run_cot_completion_empty_conversation_history(mock_edge):
max_iter=1, max_iter=1,
) )
assert completion == "Generated answer" assert completion == ["Generated answer"]
# Verify conversation_history was passed as None when empty # Verify conversation_history was passed as None when empty
call_kwargs = mock_generate.call_args[1] call_kwargs = mock_generate.call_args[1]
assert call_kwargs.get("conversation_history") is None assert call_kwargs.get("conversation_history") is None
@ -305,7 +309,7 @@ async def test_get_completion_without_context(mock_edge):
), ),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge], return_value=[[mock_edge]],
), ),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -315,7 +319,7 @@ async def test_get_completion_without_context(mock_edge):
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer", return_value="Generated answer",
), ),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch( patch(
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
return_value="Generated answer", return_value="Generated answer",
@ -396,7 +400,7 @@ async def test_get_completion_with_session(mock_edge):
), ),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge], return_value=[[mock_edge]],
), ),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -462,7 +466,7 @@ async def test_get_completion_with_save_interaction(mock_edge):
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer", return_value="Generated answer",
), ),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch( patch(
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
return_value="Generated answer", return_value="Generated answer",
@ -527,7 +531,7 @@ async def test_get_completion_with_response_model(mock_edge):
), ),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge], return_value=[[mock_edge]],
), ),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -569,7 +573,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
), ),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge], return_value=[[mock_edge]],
), ),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -611,7 +615,7 @@ async def test_get_completion_with_save_interaction_no_context(mock_edge):
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer", return_value="Generated answer",
), ),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch( patch(
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
return_value="Generated answer", return_value="Generated answer",
@ -686,3 +690,166 @@ async def test_as_answer_text_with_basemodel():
assert isinstance(result, str) assert isinstance(result, str)
assert "[Structured Response]" in result assert "[Structured Response]" in result
assert "test answer" in result assert "test answer" in result
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_context(mock_edge):
    """Test get_completion batch queries with provided context.

    The CoT validation/follow-up LLM calls are stubbed via a cycling
    side_effect so each query's iteration gets a validation result and a
    follow-up question regardless of batch size.
    """
    retriever = GraphCompletionCotRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
            return_value="Rendered prompt",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
            return_value="System prompt",
        ),
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            # cycle() keeps alternating validation/follow-up responses available
            # for as many calls as the batch makes.
            side_effect=cycle(["validation_result", "followup_question"]),
        ),
    ):
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"],
            context=[[mock_edge], [mock_edge]],
            max_iter=1,
        )
        assert isinstance(completion, list)
        assert len(completion) == 2
        assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_without_context(mock_edge):
    """Test get_completion batch queries without provided context.

    Context is retrieved through the patched brute_force_triplet_search; one
    answer per query in the batch is expected.
    """
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionCotRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[[mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
    ):
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"], max_iter=1
        )
        assert isinstance(completion, list)
        assert len(completion) == 2
        assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_response_model(mock_edge):
    """Test get_completion of batch queries with custom response model.

    generate_completion is patched to return TestModel instances, which should
    be returned unchanged, one per query.
    """
    from pydantic import BaseModel

    # Minimal structured-output model for the answers.
    class TestModel(BaseModel):
        answer: str

    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionCotRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[[mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value=TestModel(answer="Test answer"),
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Disable caching so the structured-output path is exercised.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"], response_model=TestModel, max_iter=1
        )
        assert isinstance(completion, list)
        assert len(completion) == 2
        assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel)
@pytest.mark.asyncio
async def test_get_completion_batch_queries_duplicate_queries(mock_edge):
    """Test get_completion batch queries where the batch contains duplicates.

    The same query appears twice; both entries should still each get an answer.
    """
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionCotRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[[mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
    ):
        # Deliberately duplicated query in the batch.
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 1"], max_iter=1
        )
        assert isinstance(completion, list)
        assert len(completion) == 2
        assert completion[0] == "Generated answer" and completion[1] == "Generated answer"

View file

@ -646,3 +646,199 @@ async def test_get_completion_with_save_interaction_all_conditions_met(mock_edge
assert len(completion) == 1 assert len(completion) == 1
assert completion[0] == "Generated answer" assert completion[0] == "Generated answer"
mock_add_data.assert_awaited_once() mock_add_data.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_context(mock_edge):
    """Test get_completion correctly handles batch queries.

    Context is supplied as list-of-lists (one inner triplet list per query),
    so no graph retrieval patches are required.
    """
    retriever = GraphCompletionRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Disable caching so the completion path runs for each query.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"], context=[[mock_edge], [mock_edge]]
        )
        assert isinstance(completion, list)
        assert len(completion) == 2
        assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_without_context(mock_edge):
    """Test get_completion batch queries retrieves context via triplet search.

    brute_force_triplet_search is patched to return one triplet list per query
    (batch mode is list-of-lists); one answer per query is expected.
    """
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[[mock_edge], [mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Disable caching so the completion path runs for each query.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion(query_batch=["test query 1", "test query 2"])
        assert isinstance(completion, list)
        assert len(completion) == 2
        assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_response_model(mock_edge):
    """Test get_completion of batch queries with custom response model.

    generate_completion is patched to return TestModel instances, which should
    be passed through unchanged, one per query.
    """
    from pydantic import BaseModel

    # Minimal structured-output model for the answers.
    class TestModel(BaseModel):
        answer: str

    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[[mock_edge], [mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
            return_value=TestModel(answer="Test answer"),
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Disable caching so the structured-output path is exercised.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"], response_model=TestModel
        )
        assert isinstance(completion, list)
        assert len(completion) == 2
        assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel)
@pytest.mark.asyncio
async def test_get_completion_batch_queries_empty_context():
    """Test get_completion with empty context.

    The triplet search is patched to return an empty edge list per query, so the
    retriever must still produce one completion slot per query without raising.
    Note: the ``mock_edge`` fixture parameter was removed — it was never used in
    this test, since the patched search returns no edges at all.
    """
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            # One empty edge list per query: simulates no graph matches.
            return_value=[[], []],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Caching disabled so the retriever takes the non-cached code path.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion(query_batch=["test query 1", "test query 2"])
        # Only shape is asserted here: with empty context the exact completion
        # values depend on the retriever's empty-context handling.
        assert isinstance(completion, list)
        assert len(completion) == 2
@pytest.mark.asyncio
async def test_get_completion_batch_queries_duplicate_queries(mock_edge):
    """Test get_completion batch queries with duplicate queries."""
    engine_stub = AsyncMock()
    engine_stub.is_empty = AsyncMock(return_value=False)

    # Caching disabled so the retriever takes the non-cached code path.
    cache_settings = MagicMock()
    cache_settings.caching = False

    retriever = GraphCompletionRetriever()
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=engine_stub,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[[mock_edge], [mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.CacheConfig",
            return_value=cache_settings,
        ),
    ):
        # The same query appears twice: both occurrences must still be answered.
        answers = await retriever.get_completion(query_batch=["test query 1", "test query 1"])

    assert isinstance(answers, list)
    assert len(answers) == 2
    for answer in answers:
        assert answer == "Generated answer"