This commit is contained in:
Andrej Milićević 2026-01-20 19:42:28 +01:00 committed by GitHub
commit b63c4aef34
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 1229 additions and 167 deletions

View file

@ -40,3 +40,13 @@ class CollectionDistancesNotFoundError(CogneeValidationError):
status_code: int = status.HTTP_404_NOT_FOUND,
):
super().__init__(message, name, status_code)
class QueryValidationError(CogneeValidationError):
def __init__(
self,
message: str = "Queries not supplied in the correct format.",
name: str = "QueryValidationError",
status_code: int = status.HTTP_422_UNPROCESSABLE_CONTENT,
):
super().__init__(message, name, status_code)

View file

@ -1,6 +1,9 @@
import asyncio
from typing import Optional, List, Type, Any
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError
from cognee.modules.retrieval.utils.query_state import QueryState
from cognee.modules.retrieval.utils.validate_queries import validate_queries
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
@ -56,11 +59,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
query: Optional[str] = None,
context: Optional[List[Edge] | List[List[Edge]]] = None,
session_id: Optional[str] = None,
context_extension_rounds=4,
response_model: Type = str,
query_batch: Optional[List[str]] = None,
) -> List[Any]:
"""
Extends the context for a given query by retrieving related triplets and generating new
@ -89,47 +93,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
- List[str]: A list containing the generated answer based on the query and the
extended context.
"""
triplets = context
if triplets is None:
triplets = await self.get_context(query)
context_text = await self.resolve_edges_to_text(triplets)
round_idx = 1
while round_idx <= context_extension_rounds:
prev_size = len(triplets)
logger.info(
f"Context extension: round {round_idx} - generating next graph locational query."
)
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
triplets += await self.get_context(completion)
triplets = list(set(triplets))
context_text = await self.resolve_edges_to_text(triplets)
num_triplets = len(triplets)
if num_triplets == prev_size:
logger.info(
f"Context extension: round {round_idx} no new triplets found; stopping early."
)
break
logger.info(
f"Context extension: round {round_idx} - "
f"number of unique retrieved triplets: {num_triplets}"
)
round_idx += 1
# Check if we need to generate context summary for caching
cache_config = CacheConfig()
@ -137,6 +100,131 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
if query_batch and session_save:
raise QueryValidationError(
message="You cannot use batch queries with session saving currently."
)
if query_batch and self.save_interaction:
raise QueryValidationError(
message="Cannot use batch queries with interaction saving currently."
)
is_query_valid, msg = validate_queries(query, query_batch)
if not is_query_valid:
raise QueryValidationError(message=msg)
triplets_batch = context
if query:
# This is done mostly to avoid duplicating a lot of code unnecessarily
query_batch = [query]
if triplets_batch:
triplets_batch = [triplets_batch]
if triplets_batch is None:
triplets_batch = await self.get_context(query_batch=query_batch)
if not triplets_batch:
return []
context_text = ""
context_text_batch = await asyncio.gather(
*[self.resolve_edges_to_text(triplets) for triplets in triplets_batch]
)
round_idx = 1
# We store queries as keys and their associated states in this dict.
# The state is a 3-item object QueryState, which holds triplets, context text,
# and a boolean marking whether we should continue extending the context for that query.
finished_queries_states = {}
for batched_query, batched_triplets, batched_context_text in zip(
query_batch, triplets_batch, context_text_batch
):
# Populating the dict at the start with initial information.
finished_queries_states[batched_query] = QueryState(
batched_triplets, batched_context_text, False
)
while round_idx <= context_extension_rounds:
logger.info(
f"Context extension: round {round_idx} - generating next graph locational query."
)
if all(
batched_query_state.finished_extending_context
for batched_query_state in finished_queries_states.values()
):
# We stop early only if all queries in the batch have reached their final state
logger.info(
f"Context extension: round {round_idx} no new triplets found; stopping early."
)
break
relevant_queries = [
rel_query
for rel_query in finished_queries_states.keys()
if not finished_queries_states[rel_query].finished_extending_context
]
prev_sizes = [
len(finished_queries_states[rel_query].triplets) for rel_query in relevant_queries
]
completions = await asyncio.gather(
*[
generate_completion(
query=rel_query,
context=finished_queries_states[rel_query].context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
for rel_query in relevant_queries
],
)
# Get new triplets, and merge them with existing ones, filtering out duplicates
new_triplets_batch = await self.get_context(query_batch=completions)
for rel_query, batched_new_triplets in zip(relevant_queries, new_triplets_batch):
finished_queries_states[rel_query].triplets = list(
dict.fromkeys(
finished_queries_states[rel_query].triplets + batched_new_triplets
)
)
# Resolve new triplets to text
context_text_batch = await asyncio.gather(
*[
self.resolve_edges_to_text(finished_queries_states[rel_query].triplets)
for rel_query in relevant_queries
]
)
# Update context_texts in query states
for rel_query, batched_context_text in zip(relevant_queries, context_text_batch):
finished_queries_states[rel_query].context_text = batched_context_text
new_sizes = [
len(finished_queries_states[rel_query].triplets) for rel_query in relevant_queries
]
for rel_query, prev_size, new_size in zip(relevant_queries, prev_sizes, new_sizes):
# Mark done queries accordingly
if prev_size == new_size:
finished_queries_states[rel_query].finished_extending_context = True
logger.info(
f"Context extension: round {round_idx} - "
f"number of unique retrieved triplets for each query : {new_sizes}"
)
round_idx += 1
completion_batch = []
result_completion_batch = []
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)
@ -153,18 +241,36 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
),
)
else:
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
completion_batch = await asyncio.gather(
*[
generate_completion(
query=batched_query,
context=batched_query_state.context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
for batched_query, batched_query_state in finished_queries_states.items()
],
)
if self.save_interaction and context_text and triplets and completion:
# Make sure answers are returned for duplicate queries, in the order they were asked.
for batched_query, batched_completion in zip(
finished_queries_states.keys(), completion_batch
):
finished_queries_states[batched_query].completion = batched_completion
for batched_query in query_batch:
result_completion_batch.append(finished_queries_states[batched_query].completion)
# TODO: Do batch queries for save interaction
if self.save_interaction and context_text_batch and triplets_batch and completion_batch:
await self.save_qa(
question=query, answer=completion, context=context_text, triplets=triplets
question=query,
answer=completion_batch[0],
context=context_text_batch[0],
triplets=triplets_batch[0],
)
if session_save:
@ -175,4 +281,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
session_id=session_id,
)
return [completion]
return result_completion_batch if result_completion_batch else [completion]

View file

@ -3,6 +3,9 @@ import json
from typing import Optional, List, Type, Any
from pydantic import BaseModel
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError
from cognee.modules.retrieval.utils.query_state import QueryState
from cognee.modules.retrieval.utils.validate_queries import validate_queries
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@ -86,12 +89,13 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
async def _run_cot_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
context: Optional[List[Edge] | List[List[Edge]]] = None,
conversation_history: str = "",
max_iter: int = 4,
response_model: Type = str,
) -> tuple[Any, str, List[Edge]]:
) -> tuple[List[Any], List[str], List[List[Edge]]]:
"""
Run chain-of-thought completion with optional structured output.
@ -109,72 +113,187 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
- context_text: The resolved context text
- triplets: The list of triplets used
"""
followup_question = ""
triplets = []
completion = ""
followup_question_batch = []
completion_batch = []
context_text_batch = []
if query:
# Treat a single query as a batch of queries, mainly avoiding massive code duplication
query_batch = [query]
if context:
context = [context]
triplets_batch = context
# dict containing query -> QueryState key-value pairs
# For every query, we save necessary data so we can execute requests in parallel
query_state_tracker = {}
for batched_query in query_batch:
query_state_tracker[batched_query] = QueryState()
for round_idx in range(max_iter + 1):
if round_idx == 0:
if context is None:
triplets = await self.get_context(query)
context_text = await self.resolve_edges_to_text(triplets)
# Get context, resolve to text, and store info in the query state
triplets_batch = await self.get_context(
query_batch=list(query_state_tracker.keys())
)
context_text_batch = await asyncio.gather(
*[
self.resolve_edges_to_text(batched_triplets)
for batched_triplets in triplets_batch
]
)
for batched_query, batched_triplets, batched_context_text in zip(
query_state_tracker.keys(), triplets_batch, context_text_batch
):
query_state_tracker[batched_query].triplets = batched_triplets
query_state_tracker[batched_query].context_text = batched_context_text
else:
context_text = await self.resolve_edges_to_text(context)
# In this case just resolve to text and save to the query state
context_text_batch = await asyncio.gather(
*[
self.resolve_edges_to_text(batched_context)
for batched_context in context
]
)
for batched_query, batched_triplets, batched_context_text in zip(
query_state_tracker.keys(), context, context_text_batch
):
query_state_tracker[batched_query].triplets = batched_triplets
query_state_tracker[batched_query].context_text = batched_context_text
else:
triplets += await self.get_context(followup_question)
context_text = await self.resolve_edges_to_text(list(set(triplets)))
# Find new triplets, and update existing query states
triplets_batch = await self.get_context(query_batch=followup_question_batch)
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history if conversation_history else None,
response_model=response_model,
for batched_query, batched_followup_triplets in zip(
query_state_tracker.keys(), triplets_batch
):
query_state_tracker[batched_query].triplets = list(
dict.fromkeys(
query_state_tracker[batched_query].triplets + batched_followup_triplets
)
)
context_text_batch = await asyncio.gather(
*[
self.resolve_edges_to_text(batched_query_state.triplets)
for batched_query_state in query_state_tracker.values()
]
)
for batched_query, batched_context_text in zip(
query_state_tracker.keys(), context_text_batch
):
query_state_tracker[batched_query].context_text = batched_context_text
completion_batch = await asyncio.gather(
*[
generate_completion(
query=batched_query,
context=batched_query_state.context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history if conversation_history else None,
response_model=response_model,
)
for batched_query, batched_query_state in query_state_tracker.items()
]
)
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
for batched_query, batched_completion in zip(
query_state_tracker.keys(), completion_batch
):
query_state_tracker[batched_query].completion = batched_completion
if round_idx == max_iter:
# When we finish all iterations:
# Make sure answers are returned for duplicate queries, in the order they were asked.
completion_batch = []
for batched_query in query_batch:
completion_batch.append(query_state_tracker[batched_query].completion)
logger.info(f"Chain-of-thought: round {round_idx} - answers: {completion_batch}")
if round_idx < max_iter:
answer_text = _as_answer_text(completion)
valid_args = {"query": query, "answer": answer_text, "context": context_text}
valid_user_prompt = render_prompt(
filename=self.validation_user_prompt_path, context=valid_args
)
valid_system_prompt = read_query_prompt(
prompt_file_name=self.validation_system_prompt_path
for batched_query, batched_query_state in query_state_tracker.items():
batched_query_state.answer_text = _as_answer_text(
batched_query_state.completion
)
valid_args = {
"query": batched_query,
"answer": batched_query_state.answer_text,
"context": batched_query_state.context_text,
}
batched_query_state.valid_user_prompt = render_prompt(
filename=self.validation_user_prompt_path,
context=valid_args,
)
batched_query_state.valid_system_prompt = read_query_prompt(
prompt_file_name=self.validation_system_prompt_path
)
reasoning_batch = await asyncio.gather(
*[
LLMGateway.acreate_structured_output(
text_input=batched_query_state.valid_user_prompt,
system_prompt=batched_query_state.valid_system_prompt,
response_model=str,
)
for batched_query_state in query_state_tracker.values()
]
)
reasoning = await LLMGateway.acreate_structured_output(
text_input=valid_user_prompt,
system_prompt=valid_system_prompt,
response_model=str,
)
followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning}
followup_prompt = render_prompt(
filename=self.followup_user_prompt_path, context=followup_args
)
followup_system = read_query_prompt(
prompt_file_name=self.followup_system_prompt_path
for batched_query, batched_reasoning in zip(
query_state_tracker.keys(), reasoning_batch
):
query_state_tracker[batched_query].reasoning = batched_reasoning
for batched_query, batched_query_state in query_state_tracker.items():
followup_args = {
"query": batched_query,
"answer": batched_query_state.answer_text,
"reasoning": batched_query_state.reasoning,
}
batched_query_state.followup_prompt = render_prompt(
filename=self.followup_user_prompt_path,
context=followup_args,
)
batched_query_state.followup_system = read_query_prompt(
prompt_file_name=self.followup_system_prompt_path
)
followup_question_batch = await asyncio.gather(
*[
LLMGateway.acreate_structured_output(
text_input=batched_query_state.followup_prompt,
system_prompt=batched_query_state.followup_system,
response_model=str,
)
for batched_query_state in query_state_tracker.values()
]
)
followup_question = await LLMGateway.acreate_structured_output(
text_input=followup_prompt, system_prompt=followup_system, response_model=str
)
for batched_query, batched_followup_question in zip(
query_state_tracker.keys(), followup_question_batch
):
query_state_tracker[batched_query].followup_question = batched_followup_question
logger.info(
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
f"Chain-of-thought: round {round_idx} - follow-up questions: {followup_question_batch}"
)
return completion, context_text, triplets
return completion_batch, context_text_batch, triplets_batch
async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
query: Optional[str] = None,
context: Optional[List[Edge] | List[List[Edge]]] = None,
session_id: Optional[str] = None,
max_iter=4,
response_model: Type = str,
query_batch: Optional[List[str]] = None,
) -> List[Any]:
"""
Generate completion responses based on a user query and contextual information.
@ -202,12 +321,26 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
- List[str]: A list containing the generated answer to the user's query.
"""
# Check if session saving is enabled
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
if query_batch and session_save:
raise QueryValidationError(
message="You cannot use batch queries with session saving currently."
)
if query_batch and self.save_interaction:
raise QueryValidationError(
message="Cannot use batch queries with interaction saving currently."
)
is_query_valid, msg = validate_queries(query, query_batch)
if not is_query_valid:
raise QueryValidationError(message=msg)
# Load conversation history if enabled
conversation_history = ""
if session_save:
@ -215,17 +348,23 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
completion, context_text, triplets = await self._run_cot_completion(
query=query,
query_batch=query_batch,
context=context,
conversation_history=conversation_history,
max_iter=max_iter,
response_model=response_model,
)
# TODO: Handle save interaction for batch queries
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=str(completion), context=context_text, triplets=triplets
question=query,
answer=str(completion[0]),
context=context_text[0],
triplets=triplets[0],
)
# TODO: Handle session save interaction for batch queries
# Save to session cache if enabled
if session_save:
context_summary = await summarize_text(context_text)
@ -236,4 +375,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
session_id=session_id,
)
return [completion]
return completion

View file

@ -4,6 +4,8 @@ from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError
from cognee.modules.retrieval.utils.validate_queries import validate_queries
from cognee.tasks.storage import add_data_points
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
@ -79,7 +81,11 @@ class GraphCompletionRetriever(BaseGraphRetriever):
"""
return await resolve_edges_to_text(retrieved_edges)
async def get_triplets(self, query: str) -> List[Edge]:
async def get_triplets(
self,
query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
) -> List[Edge] | List[List[Edge]]:
"""
Retrieves relevant graph triplets based on a query string.
@ -107,6 +113,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
found_triplets = await brute_force_triplet_search(
query,
query_batch,
top_k=self.top_k,
collections=vector_index_collections or None,
node_type=self.node_type,
@ -117,7 +124,11 @@ class GraphCompletionRetriever(BaseGraphRetriever):
return found_triplets
async def get_context(self, query: str) -> List[Edge]:
async def get_context(
self,
query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
) -> List[Edge] | List[List[Edge]]:
"""
Retrieves and resolves graph triplets into context based on a query.
@ -139,17 +150,36 @@ class GraphCompletionRetriever(BaseGraphRetriever):
logger.warning("Search attempt on an empty knowledge graph")
return []
triplets = await self.get_triplets(query)
triplets = await self.get_triplets(query, query_batch)
if len(triplets) == 0:
logger.warning("Empty context was provided to the completion")
return []
if query_batch:
for batched_triplets, batched_query in zip(triplets, query_batch):
if len(batched_triplets) == 0:
logger.warning(
f"Empty context was provided to the completion for the query: {batched_query}"
)
entity_nodes_batch = []
# context = await self.resolve_edges_to_text(triplets)
for batched_triplets in triplets:
entity_nodes_batch.append(get_entity_nodes_from_triplets(batched_triplets))
entity_nodes = get_entity_nodes_from_triplets(triplets)
# Remove duplicates and update node access, if it is enabled
import os
if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() == "true":
for batched_entity_nodes in entity_nodes_batch:
await update_node_access_timestamps(batched_entity_nodes)
else:
if len(triplets) == 0:
logger.warning("Empty context was provided to the completion")
return []
# context = await self.resolve_edges_to_text(triplets)
entity_nodes = get_entity_nodes_from_triplets(triplets)
await update_node_access_timestamps(entity_nodes)
await update_node_access_timestamps(entity_nodes)
return triplets
async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):
@ -158,10 +188,11 @@ class GraphCompletionRetriever(BaseGraphRetriever):
async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
query: Optional[str] = None,
context: Optional[List[Edge] | List[List[Edge]]] = None,
session_id: Optional[str] = None,
response_model: Type = str,
query_batch: Optional[List[str]] = None,
) -> List[Any]:
"""
Generates a completion using graph connections context based on a query.
@ -180,18 +211,38 @@ class GraphCompletionRetriever(BaseGraphRetriever):
- Any: A generated completion based on the query and context provided.
"""
triplets = context
if triplets is None:
triplets = await self.get_context(query)
context_text = await resolve_edges_to_text(triplets)
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
if query_batch and session_save:
raise QueryValidationError(
message="You cannot use batch queries with session saving currently."
)
if query_batch and self.save_interaction:
raise QueryValidationError(
message="Cannot use batch queries with interaction saving currently."
)
is_query_valid, msg = validate_queries(query, query_batch)
if not is_query_valid:
raise QueryValidationError(message=msg)
triplets = context
if triplets is None:
triplets = await self.get_context(query, query_batch)
context_text = ""
context_text_batch = []
if triplets and isinstance(triplets[0], list):
context_text_batch = await asyncio.gather(
*[resolve_edges_to_text(triplets_element) for triplets_element in triplets]
)
else:
context_text = await resolve_edges_to_text(triplets)
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)
@ -208,14 +259,29 @@ class GraphCompletionRetriever(BaseGraphRetriever):
),
)
else:
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if query_batch and len(query_batch) > 0:
completion = await asyncio.gather(
*[
generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
for query, context in zip(query_batch, context_text_batch)
],
)
else:
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if self.save_interaction and context and triplets and completion:
await self.save_qa(
@ -230,7 +296,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
session_id=session_id,
)
return [completion]
return completion if isinstance(completion, list) else [completion]
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
"""

View file

@ -1,5 +1,6 @@
from typing import List, Optional, Type, Union
from cognee.modules.retrieval.utils.validate_queries import validate_queries
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
from cognee.infrastructure.databases.graph import get_graph_engine
@ -146,17 +147,10 @@ async def brute_force_triplet_search(
In single-query mode, node_distances and edge_distances are stored as flat lists.
In batch mode, they are stored as list-of-lists (one list per query).
"""
if query is not None and query_batch is not None:
raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.")
if query is None and query_batch is None:
raise ValueError("Must provide either 'query' or 'query_batch'.")
if query is not None and (not query or not isinstance(query, str)):
raise ValueError("The query must be a non-empty string.")
if query_batch is not None:
if not isinstance(query_batch, list) or not query_batch:
raise ValueError("query_batch must be a non-empty list of strings.")
if not all(isinstance(q, str) and q for q in query_batch):
raise ValueError("All items in query_batch must be non-empty strings.")
is_query_valid, msg = validate_queries(query, query_batch)
if not is_query_valid:
raise ValueError(msg)
if top_k <= 0:
raise ValueError("top_k must be a positive integer.")

View file

@ -0,0 +1,34 @@
from typing import List
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
class QueryState:
"""
Helper class containing all necessary information about the query state.
Used (for now) in COT and Context Extension Retrievers to keep track of important information
in a more readable way, and enable as many parallel calls to llms as possible.
"""
def __init__(
self,
triplets: List[Edge] = None,
context_text: str = "",
finished_extending_context: bool = False,
):
# Mutual fields for COT and Context Extension
self.triplets = triplets if triplets else []
self.context_text = context_text
self.completion = ""
# Context Extension specific
self.finished_extending_context = finished_extending_context
# COT specific
self.answer_text: str = ""
self.valid_user_prompt: str = ""
self.valid_system_prompt: str = ""
self.reasoning: str = ""
self.followup_question: str = ""
self.followup_prompt: str = ""
self.followup_system: str = ""

View file

@ -0,0 +1,14 @@
def validate_queries(query, query_batch) -> tuple[bool, str]:
if query is not None and query_batch is not None:
return False, "Cannot provide both 'query' and 'query_batch'; use exactly one."
if query is None and query_batch is None:
return False, "Must provide either 'query' or 'query_batch'."
if query is not None and (not query or not isinstance(query, str)):
return False, "The query must be a non-empty string."
if query_batch is not None:
if not isinstance(query_batch, list) or not query_batch:
return False, "query_batch must be a non-empty list of strings."
if not all(isinstance(q, str) and q for q in query_batch):
return False, "All items in query_batch must be non-empty strings."
return True, ""

View file

@ -1,4 +1,3 @@
import os
import pytest
import pathlib
import pytest_asyncio

View file

@ -1,4 +1,5 @@
import pytest
from itertools import cycle
from unittest.mock import AsyncMock, patch, MagicMock
from uuid import UUID
@ -81,7 +82,7 @@ async def test_get_completion_without_context(mock_edge):
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
return_value=[[mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -157,7 +158,7 @@ async def test_get_completion_context_extension_rounds(mock_edge):
retriever,
"get_context",
new_callable=AsyncMock,
side_effect=[[mock_edge], [mock_edge2]],
side_effect=[[[mock_edge]], [[mock_edge2]]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -194,7 +195,7 @@ async def test_get_completion_context_extension_stops_early(mock_edge):
retriever = GraphCompletionContextExtensionRetriever()
with (
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
@ -240,7 +241,7 @@ async def test_get_completion_with_session(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
@ -304,7 +305,7 @@ async def test_get_completion_with_save_interaction(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
@ -361,7 +362,7 @@ async def test_get_completion_with_response_model(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
@ -403,7 +404,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
@ -446,7 +447,7 @@ async def test_get_completion_zero_extension_rounds(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
@ -467,3 +468,339 @@ async def test_get_completion_zero_extension_rounds(mock_edge):
assert isinstance(completion, list)
assert len(completion) == 1
@pytest.mark.asyncio
async def test_get_completion_batch_queries_without_context(mock_edge):
"""Test get_completion batch queries retrieves context when not provided."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[[mock_edge], [mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"], context_extension_rounds=1
)
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_provided_context(mock_edge):
"""Test get_completion batch queries uses provided context."""
retriever = GraphCompletionContextExtensionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"],
context=[[mock_edge], [mock_edge]],
context_extension_rounds=1,
)
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_context_extension_rounds(mock_edge):
"""Test get_completion batch queries with multiple context extension rounds."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
# Create a second edge for extension rounds
mock_edge2 = MagicMock(spec=Edge)
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(
retriever,
"get_context",
new_callable=AsyncMock,
side_effect=[[[mock_edge], [mock_edge]], [[mock_edge2], [mock_edge2]]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
side_effect=[
"Resolved context",
"Resolved context",
"Extended context",
"Extended context",
], # Different contexts
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
"Extension query",
"Generated answer",
"Generated answer",
],
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"], context_extension_rounds=1
)
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_context_extension_stops_early(mock_edge):
"""Test get_completion batch queries stops early when no new triplets found."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
with (
patch.object(
retriever,
"get_context",
new_callable=AsyncMock,
return_value=[[mock_edge], [mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
"Extension query",
"Generated answer",
"Generated answer",
],
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
# When get_context returns same triplets, the loop should stop early
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"],
context=[[mock_edge], [mock_edge]],
context_extension_rounds=4,
)
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_zero_extension_rounds(mock_edge):
"""Test get_completion batch queries with zero context extension rounds."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(
retriever,
"get_context",
new_callable=AsyncMock,
return_value=[[mock_edge], [mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"], context_extension_rounds=0
)
assert isinstance(completion, list)
assert len(completion) == 2
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_response_model(mock_edge):
"""Test get_completion batch queries with custom response model."""
from pydantic import BaseModel
class TestModel(BaseModel):
answer: str
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(
retriever,
"get_context",
new_callable=AsyncMock,
return_value=[[mock_edge], [mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
"Extension query",
TestModel(answer="Test answer"),
TestModel(answer="Test answer"),
], # Extension query, then final answer
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"],
response_model=TestModel,
context_extension_rounds=1,
)
assert isinstance(completion, list)
assert len(completion) == 2
assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel)
@pytest.mark.asyncio
async def test_get_completion_batch_queries_duplicate_queries(mock_edge):
"""Test get_completion batch queries with duplicate queries."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionContextExtensionRetriever()
# Create a second edge for extension rounds
mock_edge2 = MagicMock(spec=Edge)
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(
retriever,
"get_context",
new_callable=AsyncMock,
side_effect=[[[mock_edge], [mock_edge]], [[mock_edge2], [mock_edge2]]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
side_effect=[
"Resolved context",
"Resolved context",
"Extended context",
"Extended context",
], # Different contexts
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
side_effect=[
"Extension query",
"Extension query",
"Generated answer",
"Generated answer",
],
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"], context_extension_rounds=1
)
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"

View file

@ -1,6 +1,9 @@
import os
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from uuid import UUID
from itertools import cycle
from cognee.modules.retrieval.graph_completion_cot_retriever import (
GraphCompletionCotRetriever,
@ -79,7 +82,7 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge):
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
return_value="Generated answer",
@ -105,8 +108,8 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge):
max_iter=1,
)
assert completion == "Generated answer"
assert context_text == "Resolved context"
assert completion == ["Generated answer"]
assert context_text == ["Resolved context"]
assert len(triplets) >= 1
@ -125,7 +128,7 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge):
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
return_value=[[mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -142,8 +145,8 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge):
max_iter=1,
)
assert completion == "Generated answer"
assert context_text == "Resolved context"
assert completion == ["Generated answer"]
assert context_text == ["Resolved context"]
assert len(triplets) >= 1
@ -167,7 +170,7 @@ async def test_run_cot_completion_multiple_rounds(mock_edge):
retriever,
"get_context",
new_callable=AsyncMock,
side_effect=[[mock_edge], [mock_edge2]],
side_effect=[[[mock_edge]], [[mock_edge2]]],
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
@ -199,8 +202,8 @@ async def test_run_cot_completion_multiple_rounds(mock_edge):
max_iter=2,
)
assert completion == "Generated answer"
assert context_text == "Resolved context"
assert completion == ["Generated answer"]
assert context_text == ["Resolved context"]
assert len(triplets) >= 1
@ -226,7 +229,7 @@ async def test_run_cot_completion_with_conversation_history(mock_edge):
max_iter=1,
)
assert completion == "Generated answer"
assert completion == ["Generated answer"]
call_kwargs = mock_generate.call_args[1]
assert call_kwargs.get("conversation_history") == "Previous conversation"
@ -258,8 +261,9 @@ async def test_run_cot_completion_with_response_model(mock_edge):
max_iter=1,
)
assert isinstance(completion, TestModel)
assert completion.answer == "Test answer"
assert isinstance(completion, list)
assert isinstance(completion[0], TestModel)
assert completion[0].answer == "Test answer"
@pytest.mark.asyncio
@ -284,7 +288,7 @@ async def test_run_cot_completion_empty_conversation_history(mock_edge):
max_iter=1,
)
assert completion == "Generated answer"
assert completion == ["Generated answer"]
# Verify conversation_history was passed as None when empty
call_kwargs = mock_generate.call_args[1]
assert call_kwargs.get("conversation_history") is None
@ -305,7 +309,7 @@ async def test_get_completion_without_context(mock_edge):
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
return_value=[[mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -315,7 +319,7 @@ async def test_get_completion_without_context(mock_edge):
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
return_value="Generated answer",
@ -396,7 +400,7 @@ async def test_get_completion_with_session(mock_edge):
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
return_value=[[mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -462,7 +466,7 @@ async def test_get_completion_with_save_interaction(mock_edge):
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
return_value="Generated answer",
@ -527,7 +531,7 @@ async def test_get_completion_with_response_model(mock_edge):
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
return_value=[[mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -569,7 +573,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
return_value=[[mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -611,7 +615,7 @@ async def test_get_completion_with_save_interaction_no_context(mock_edge):
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
return_value="Generated answer",
@ -686,3 +690,166 @@ async def test_as_answer_text_with_basemodel():
assert isinstance(result, str)
assert "[Structured Response]" in result
assert "test answer" in result
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_context(mock_edge):
"""Test get_completion batch queries with provided context."""
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
return_value="Rendered prompt",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
return_value="System prompt",
),
patch.object(
LLMGateway,
"acreate_structured_output",
new_callable=AsyncMock,
side_effect=cycle(["validation_result", "followup_question"]),
),
):
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"],
context=[[mock_edge], [mock_edge]],
max_iter=1,
)
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_without_context(mock_edge):
"""Test get_completion batch queries without provided context."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[[mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
):
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"], max_iter=1
)
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
#
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_response_model(mock_edge):
"""Test get_completion of batch queries with custom response model."""
from pydantic import BaseModel
class TestModel(BaseModel):
answer: str
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[[mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value=TestModel(answer="Test answer"),
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"], response_model=TestModel, max_iter=1
)
assert isinstance(completion, list)
assert len(completion) == 2
assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel)
@pytest.mark.asyncio
async def test_get_completion_batch_queries_duplicate_queries(mock_edge):
"""Test get_completion batch queries without provided context."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionCotRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[[mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
return_value="Generated answer",
),
):
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 1"], max_iter=1
)
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"

View file

@ -646,3 +646,199 @@ async def test_get_completion_with_save_interaction_all_conditions_met(mock_edge
assert len(completion) == 1
assert completion[0] == "Generated answer"
mock_add_data.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_context(mock_edge):
"""Test get_completion correctly handles batch queries."""
retriever = GraphCompletionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"], context=[[mock_edge], [mock_edge]]
)
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_without_context(mock_edge):
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[[mock_edge], [mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(query_batch=["test query 1", "test query 2"])
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_response_model(mock_edge):
"""Test get_completion of batch queries with custom response model."""
from pydantic import BaseModel
class TestModel(BaseModel):
answer: str
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[[mock_edge], [mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value=TestModel(answer="Test answer"),
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
query_batch=["test query 1", "test query 2"], response_model=TestModel
)
assert isinstance(completion, list)
assert len(completion) == 2
assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel)
@pytest.mark.asyncio
async def test_get_completion_batch_queries_empty_context(mock_edge):
"""Test get_completion with empty context."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[[], []],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(query_batch=["test query 1", "test query 2"])
assert isinstance(completion, list)
assert len(completion) == 2
@pytest.mark.asyncio
async def test_get_completion_batch_queries_duplicate_queries(mock_edge):
"""Test get_completion batch queries with duplicate queries."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)
retriever = GraphCompletionRetriever()
with (
patch(
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[[mock_edge], [mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
) as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(query_batch=["test query 1", "test query 1"])
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"