fix: fix context extension and cot retrievers
This commit is contained in:
parent
8e0d112439
commit
8c7b309199
6 changed files with 286 additions and 166 deletions
|
|
@ -15,6 +15,19 @@ from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class QueryState:
|
||||||
|
"""
|
||||||
|
Helper class containing all necessary information about the query state:
|
||||||
|
the triplets and context associated with it, and also a check whether
|
||||||
|
it has fully extended the context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, triplets: List[Edge], context_text: str, finished_extending_context: bool):
|
||||||
|
self.triplets = triplets
|
||||||
|
self.context_text = context_text
|
||||||
|
self.finished_extending_context = finished_extending_context
|
||||||
|
|
||||||
|
|
||||||
class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
"""
|
"""
|
||||||
Handles graph context completion for question answering tasks, extending context based
|
Handles graph context completion for question answering tasks, extending context based
|
||||||
|
|
@ -91,10 +104,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
- List[str]: A list containing the generated answer based on the query and the
|
- List[str]: A list containing the generated answer based on the query and the
|
||||||
extended context.
|
extended context.
|
||||||
"""
|
"""
|
||||||
# TODO: This may be unnecessary in this retriever, will check later
|
is_query_valid, msg = validate_queries(query, query_batch)
|
||||||
query_validation = validate_queries(query, query_batch)
|
if not is_query_valid:
|
||||||
if not query_validation[0]:
|
raise ValueError(msg)
|
||||||
raise ValueError(query_validation[1])
|
|
||||||
|
|
||||||
triplets_batch = context
|
triplets_batch = context
|
||||||
|
|
||||||
|
|
@ -117,70 +129,87 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
|
|
||||||
round_idx = 1
|
round_idx = 1
|
||||||
|
|
||||||
# We will be removing queries, and their associated triplets and context, as we go
|
# We store queries as keys and their associated states in this dict.
|
||||||
# through iterations, so we need to save their final states for the final generation.
|
# The state is a 3-item object QueryState, which holds triplets, context text,
|
||||||
# Final state is stored in the finished_queries_data dict, and we populate it at the start as well.
|
# and a boolean marking whether we should continue extending the context for that query.
|
||||||
original_query_batch = query_batch
|
finished_queries_states = {}
|
||||||
finished_queries_data = {}
|
|
||||||
for i, batched_query in enumerate(query_batch):
|
for batched_query, batched_triplets, batched_context_text in zip(
|
||||||
if not triplets_batch[i]:
|
query_batch, triplets_batch, context_text_batch
|
||||||
query_batch[i] = ""
|
):
|
||||||
else:
|
# Populating the dict at the start with initial information.
|
||||||
finished_queries_data[batched_query] = (triplets_batch[i], context_text_batch[i])
|
finished_queries_states[batched_query] = QueryState(
|
||||||
|
batched_triplets, batched_context_text, False
|
||||||
|
)
|
||||||
|
|
||||||
while round_idx <= context_extension_rounds:
|
while round_idx <= context_extension_rounds:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Context extension: round {round_idx} - generating next graph locational query."
|
f"Context extension: round {round_idx} - generating next graph locational query."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Filter out the queries that cannot be extended further, and their associated contexts
|
if all(
|
||||||
query_batch = [query for query in query_batch if query]
|
batched_query_state.finished_extending_context
|
||||||
triplets_batch = [triplets for triplets in triplets_batch if triplets]
|
for batched_query_state in finished_queries_states.values()
|
||||||
context_text_batch = [
|
):
|
||||||
context_text for context_text in context_text_batch if context_text
|
# We stop early only if all queries in the batch have reached their final state
|
||||||
]
|
|
||||||
if len(query_batch) == 0:
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Context extension: round {round_idx} – no new triplets found; stopping early."
|
f"Context extension: round {round_idx} – no new triplets found; stopping early."
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
prev_sizes = [len(triplets) for triplets in triplets_batch]
|
prev_sizes = [
|
||||||
|
len(batched_query_state.triplets)
|
||||||
|
for batched_query_state in finished_queries_states.values()
|
||||||
|
if not batched_query_state.finished_extending_context
|
||||||
|
]
|
||||||
|
|
||||||
completions = await asyncio.gather(
|
completions = await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
generate_completion(
|
generate_completion(
|
||||||
query=query,
|
query=batched_query,
|
||||||
context=context,
|
context=batched_query_state.context_text,
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_prompt_path=self.user_prompt_path,
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
system_prompt=self.system_prompt,
|
system_prompt=self.system_prompt,
|
||||||
)
|
)
|
||||||
for query, context in zip(query_batch, context_text_batch)
|
for batched_query, batched_query_state in finished_queries_states.items()
|
||||||
|
if not batched_query_state.finished_extending_context
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get new triplets, and merge them with existing ones, filtering out duplicates
|
# Get new triplets, and merge them with existing ones, filtering out duplicates
|
||||||
new_triplets_batch = await self.get_context(query_batch=completions)
|
new_triplets_batch = await self.get_context(query_batch=completions)
|
||||||
for i, (triplets, new_triplets) in enumerate(zip(triplets_batch, new_triplets_batch)):
|
for batched_query, batched_new_triplets in zip(query_batch, new_triplets_batch):
|
||||||
triplets += new_triplets
|
finished_queries_states[batched_query].triplets = list(
|
||||||
triplets_batch[i] = list(dict.fromkeys(triplets))
|
dict.fromkeys(
|
||||||
|
finished_queries_states[batched_query].triplets + batched_new_triplets
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resolve new triplets to text
|
||||||
context_text_batch = await asyncio.gather(
|
context_text_batch = await asyncio.gather(
|
||||||
*[self.resolve_edges_to_text(triplets) for triplets in triplets_batch]
|
*[
|
||||||
|
self.resolve_edges_to_text(batched_query_state.triplets)
|
||||||
|
for batched_query_state in finished_queries_states.values()
|
||||||
|
if not batched_query_state.finished_extending_context
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
new_sizes = [len(triplets) for triplets in triplets_batch]
|
# Update context_texts in query states
|
||||||
|
for batched_query, batched_context_text in zip(query_batch, context_text_batch):
|
||||||
|
if not finished_queries_states[batched_query].finished_extending_context:
|
||||||
|
finished_queries_states[batched_query].context_text = batched_context_text
|
||||||
|
|
||||||
for i, (batched_query, prev_size, new_size, triplets, context_text) in enumerate(
|
new_sizes = [
|
||||||
zip(query_batch, prev_sizes, new_sizes, triplets_batch, context_text_batch)
|
len(batched_query_state.triplets)
|
||||||
):
|
for batched_query_state in finished_queries_states.values()
|
||||||
finished_queries_data[query] = (triplets, context_text)
|
if not batched_query_state.finished_extending_context
|
||||||
|
]
|
||||||
|
|
||||||
|
for batched_query, prev_size, new_size in zip(query_batch, prev_sizes, new_sizes):
|
||||||
|
# Mark done queries accordingly
|
||||||
if prev_size == new_size:
|
if prev_size == new_size:
|
||||||
# In this case, we can stop trying to extend the context of this query
|
finished_queries_states[batched_query].finished_extending_context = True
|
||||||
query_batch[i] = ""
|
|
||||||
triplets_batch[i] = []
|
|
||||||
context_text_batch[i] = ""
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Context extension: round {round_idx} - "
|
f"Context extension: round {round_idx} - "
|
||||||
|
|
@ -189,15 +218,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
|
|
||||||
round_idx += 1
|
round_idx += 1
|
||||||
|
|
||||||
# Reset variables for the final generations. They contain the final state
|
|
||||||
# of triplets and contexts for each query, after all extension iterations.
|
|
||||||
query_batch = original_query_batch
|
|
||||||
triplets_batch = []
|
|
||||||
context_text_batch = []
|
|
||||||
for batched_query in query_batch:
|
|
||||||
triplets_batch.append(finished_queries_data[batched_query][0])
|
|
||||||
context_text_batch.append(finished_queries_data[batched_query][1])
|
|
||||||
|
|
||||||
# Check if we need to generate context summary for caching
|
# Check if we need to generate context summary for caching
|
||||||
cache_config = CacheConfig()
|
cache_config = CacheConfig()
|
||||||
user = session_user.get()
|
user = session_user.get()
|
||||||
|
|
@ -226,13 +246,13 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
*[
|
*[
|
||||||
generate_completion(
|
generate_completion(
|
||||||
query=batched_query,
|
query=batched_query,
|
||||||
context=batched_context_text,
|
context=batched_query_state.context_text,
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_prompt_path=self.user_prompt_path,
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
system_prompt=self.system_prompt,
|
system_prompt=self.system_prompt,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
)
|
)
|
||||||
for batched_query, batched_context_text in zip(query_batch, context_text_batch)
|
for batched_query, batched_query_state in finished_queries_states.items()
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,27 @@ from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class QueryState:
|
||||||
|
"""
|
||||||
|
Helper class containing all necessary information about the query state.
|
||||||
|
Used to keep track of important information in a more readable way, and
|
||||||
|
enable as many parallel calls to llms as possible.
|
||||||
|
"""
|
||||||
|
|
||||||
|
completion: str = ""
|
||||||
|
triplets: List[Edge] = []
|
||||||
|
context_text: str = ""
|
||||||
|
|
||||||
|
answer_text: str = ""
|
||||||
|
valid_user_prompt: str = ""
|
||||||
|
valid_system_prompt: str = ""
|
||||||
|
reasoning: str = ""
|
||||||
|
|
||||||
|
followup_question: str = ""
|
||||||
|
followup_prompt: str = ""
|
||||||
|
followup_system: str = ""
|
||||||
|
|
||||||
|
|
||||||
def _as_answer_text(completion: Any) -> str:
|
def _as_answer_text(completion: Any) -> str:
|
||||||
"""Convert completion to human-readable text for validation and follow-up prompts."""
|
"""Convert completion to human-readable text for validation and follow-up prompts."""
|
||||||
if isinstance(completion, str):
|
if isinstance(completion, str):
|
||||||
|
|
@ -87,12 +108,13 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
|
|
||||||
async def _run_cot_completion(
|
async def _run_cot_completion(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: Optional[str] = None,
|
||||||
context: Optional[List[Edge]] = None,
|
query_batch: Optional[List[str]] = None,
|
||||||
|
context: Optional[List[Edge] | List[List[Edge]]] = None,
|
||||||
conversation_history: str = "",
|
conversation_history: str = "",
|
||||||
max_iter: int = 4,
|
max_iter: int = 4,
|
||||||
response_model: Type = str,
|
response_model: Type = str,
|
||||||
) -> tuple[Any, str, List[Edge]]:
|
) -> tuple[List[Any], List[str], List[List[Edge]]]:
|
||||||
"""
|
"""
|
||||||
Run chain-of-thought completion with optional structured output.
|
Run chain-of-thought completion with optional structured output.
|
||||||
|
|
||||||
|
|
@ -110,64 +132,158 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
- context_text: The resolved context text
|
- context_text: The resolved context text
|
||||||
- triplets: The list of triplets used
|
- triplets: The list of triplets used
|
||||||
"""
|
"""
|
||||||
followup_question = ""
|
followup_question_batch = []
|
||||||
triplets = []
|
completion_batch = []
|
||||||
completion = ""
|
context_text_batch = []
|
||||||
|
|
||||||
|
if query:
|
||||||
|
# Treat a single query as a batch of queries, mainly avoiding massive code duplication
|
||||||
|
query_batch = [query]
|
||||||
|
if context:
|
||||||
|
context = [context]
|
||||||
|
|
||||||
|
triplets_batch = context
|
||||||
|
|
||||||
|
# dict containing query -> QueryState key-value pairs
|
||||||
|
# For every query, we save necessary data so we can execute requests in parallel
|
||||||
|
query_state_tracker = {}
|
||||||
|
for batched_query in query_batch:
|
||||||
|
query_state_tracker[batched_query] = QueryState()
|
||||||
|
|
||||||
for round_idx in range(max_iter + 1):
|
for round_idx in range(max_iter + 1):
|
||||||
if round_idx == 0:
|
if round_idx == 0:
|
||||||
if context is None:
|
if context is None:
|
||||||
triplets = await self.get_context(query)
|
# Get context, resolve to text, and store info in the query state
|
||||||
context_text = await self.resolve_edges_to_text(triplets)
|
triplets_batch = await self.get_context(query_batch=query_batch)
|
||||||
|
context_text_batch = await asyncio.gather(
|
||||||
|
*[
|
||||||
|
self.resolve_edges_to_text(batched_triplets)
|
||||||
|
for batched_triplets in triplets_batch
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for batched_query, batched_triplets, batched_context_text in zip(
|
||||||
|
query_batch, triplets_batch, context_text_batch
|
||||||
|
):
|
||||||
|
query_state_tracker[batched_query].triplets = batched_triplets
|
||||||
|
query_state_tracker[batched_query].context_text = batched_context_text
|
||||||
else:
|
else:
|
||||||
context_text = await self.resolve_edges_to_text(context)
|
# In this case just resolve to text and save to the query state
|
||||||
|
context_text_batch = await asyncio.gather(
|
||||||
|
*[
|
||||||
|
self.resolve_edges_to_text(batched_context)
|
||||||
|
for batched_context in context
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for batched_query, batched_triplets, batched_context_text in zip(
|
||||||
|
query_batch, context, context_text_batch
|
||||||
|
):
|
||||||
|
query_state_tracker[batched_query].triplets = batched_triplets
|
||||||
|
query_state_tracker[batched_query].context_text = batched_context_text
|
||||||
else:
|
else:
|
||||||
triplets += await self.get_context(followup_question)
|
# Find new triplets, and update existing query states
|
||||||
context_text = await self.resolve_edges_to_text(list(set(triplets)))
|
followup_triplets_batch = await self.get_context(
|
||||||
|
query_batch=followup_question_batch
|
||||||
|
)
|
||||||
|
for batched_query, batched_followup_triplets in zip(
|
||||||
|
query_batch, followup_triplets_batch
|
||||||
|
):
|
||||||
|
query_state_tracker[batched_query].triplets = list(
|
||||||
|
dict.fromkeys(
|
||||||
|
query_state_tracker[batched_query].triplets + batched_followup_triplets
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
completion = await generate_completion(
|
context_text_batch = await asyncio.gather(
|
||||||
query=query,
|
*[
|
||||||
context=context_text,
|
self.resolve_edges_to_text(batched_query_state.triplets)
|
||||||
user_prompt_path=self.user_prompt_path,
|
for batched_query_state in query_state_tracker.values()
|
||||||
system_prompt_path=self.system_prompt_path,
|
]
|
||||||
system_prompt=self.system_prompt,
|
)
|
||||||
conversation_history=conversation_history if conversation_history else None,
|
|
||||||
response_model=response_model,
|
for batched_query, batched_context_text in zip(query_batch, context_text_batch):
|
||||||
|
query_state_tracker[batched_query].context_text = batched_context_text
|
||||||
|
|
||||||
|
completion_batch = await asyncio.gather(
|
||||||
|
*[
|
||||||
|
generate_completion(
|
||||||
|
query=batched_query,
|
||||||
|
context=batched_query_state.context_text,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
conversation_history=conversation_history if conversation_history else None,
|
||||||
|
response_model=response_model,
|
||||||
|
)
|
||||||
|
for batched_query, batched_query_state in query_state_tracker.items()
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
for batched_query, batched_completion in zip(query_batch, completion_batch):
|
||||||
|
query_state_tracker[batched_query].completion = batched_completion
|
||||||
|
|
||||||
|
logger.info(f"Chain-of-thought: round {round_idx} - answers: {completion_batch}")
|
||||||
|
|
||||||
if round_idx < max_iter:
|
if round_idx < max_iter:
|
||||||
answer_text = _as_answer_text(completion)
|
for batched_query, batched_query_state in query_state_tracker.items():
|
||||||
valid_args = {"query": query, "answer": answer_text, "context": context_text}
|
batched_query_state.answer_text = _as_answer_text(
|
||||||
valid_user_prompt = render_prompt(
|
batched_query_state.completion
|
||||||
filename=self.validation_user_prompt_path, context=valid_args
|
)
|
||||||
)
|
valid_args = {
|
||||||
valid_system_prompt = read_query_prompt(
|
"query": batched_query,
|
||||||
prompt_file_name=self.validation_system_prompt_path
|
"answer": batched_query_state.answer_text,
|
||||||
|
"context": batched_query_state.context_text,
|
||||||
|
}
|
||||||
|
batched_query_state.valid_user_prompt = render_prompt(
|
||||||
|
filename=self.validation_user_prompt_path,
|
||||||
|
context=valid_args,
|
||||||
|
)
|
||||||
|
batched_query_state.valid_system_prompt = read_query_prompt(
|
||||||
|
prompt_file_name=self.validation_system_prompt_path
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning_batch = await asyncio.gather(
|
||||||
|
*[
|
||||||
|
LLMGateway.acreate_structured_output(
|
||||||
|
text_input=batched_query_state.valid_user_prompt,
|
||||||
|
system_prompt=batched_query_state.valid_system_prompt,
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
|
for batched_query_state in query_state_tracker.values()
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
reasoning = await LLMGateway.acreate_structured_output(
|
for batched_query, batched_reasoning in zip(query_batch, reasoning_batch):
|
||||||
text_input=valid_user_prompt,
|
query_state_tracker[batched_query].reasoning = batched_reasoning
|
||||||
system_prompt=valid_system_prompt,
|
|
||||||
response_model=str,
|
|
||||||
)
|
|
||||||
followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning}
|
|
||||||
followup_prompt = render_prompt(
|
|
||||||
filename=self.followup_user_prompt_path, context=followup_args
|
|
||||||
)
|
|
||||||
followup_system = read_query_prompt(
|
|
||||||
prompt_file_name=self.followup_system_prompt_path
|
|
||||||
)
|
|
||||||
|
|
||||||
followup_question = await LLMGateway.acreate_structured_output(
|
for batched_query, batched_query_state in query_state_tracker.items():
|
||||||
text_input=followup_prompt, system_prompt=followup_system, response_model=str
|
followup_args = {
|
||||||
|
"query": query,
|
||||||
|
"answer": batched_query_state.answer_text,
|
||||||
|
"reasoning": batched_query_state.reasoning,
|
||||||
|
}
|
||||||
|
batched_query_state.followup_prompt = render_prompt(
|
||||||
|
filename=self.followup_user_prompt_path,
|
||||||
|
context=followup_args,
|
||||||
|
)
|
||||||
|
batched_query_state.followup_system = read_query_prompt(
|
||||||
|
prompt_file_name=self.followup_system_prompt_path
|
||||||
|
)
|
||||||
|
|
||||||
|
followup_question_batch = await asyncio.gather(
|
||||||
|
*[
|
||||||
|
LLMGateway.acreate_structured_output(
|
||||||
|
text_input=batched_query_state.followup_prompt,
|
||||||
|
system_prompt=batched_query_state.followup_system,
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
|
for batched_query_state in query_state_tracker.values()
|
||||||
|
]
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
|
f"Chain-of-thought: round {round_idx} - follow-up questions: {followup_question_batch}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return completion, context_text, triplets
|
return completion_batch, context_text_batch, triplets_batch
|
||||||
|
|
||||||
async def get_completion(
|
async def get_completion(
|
||||||
self,
|
self,
|
||||||
|
|
@ -204,9 +320,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
|
|
||||||
- List[str]: A list containing the generated answer to the user's query.
|
- List[str]: A list containing the generated answer to the user's query.
|
||||||
"""
|
"""
|
||||||
query_validation = validate_queries(query, query_batch)
|
is_query_valid, msg = validate_queries(query, query_batch)
|
||||||
if not query_validation[0]:
|
if not is_query_valid:
|
||||||
raise ValueError(query_validation[1])
|
raise ValueError(msg)
|
||||||
|
|
||||||
# Check if session saving is enabled
|
# Check if session saving is enabled
|
||||||
cache_config = CacheConfig()
|
cache_config = CacheConfig()
|
||||||
|
|
@ -219,40 +335,22 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
if session_save:
|
if session_save:
|
||||||
conversation_history = await get_conversation_history(session_id=session_id)
|
conversation_history = await get_conversation_history(session_id=session_id)
|
||||||
|
|
||||||
context_batch = context
|
completion, context_text, triplets = await self._run_cot_completion(
|
||||||
completion_batch = []
|
query=query,
|
||||||
if query_batch and len(query_batch) > 0:
|
query_batch=query_batch,
|
||||||
if not context_batch:
|
context=context,
|
||||||
# Having a list is necessary to zip through it
|
conversation_history=conversation_history,
|
||||||
context_batch = []
|
max_iter=max_iter,
|
||||||
for _ in query_batch:
|
response_model=response_model,
|
||||||
context_batch.append(None)
|
)
|
||||||
|
|
||||||
completion_batch = await asyncio.gather(
|
|
||||||
*[
|
|
||||||
self._run_cot_completion(
|
|
||||||
query=query,
|
|
||||||
context=context,
|
|
||||||
conversation_history=conversation_history,
|
|
||||||
max_iter=max_iter,
|
|
||||||
response_model=response_model,
|
|
||||||
)
|
|
||||||
for batched_query, context in zip(query_batch, context_batch)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
completion, context_text, triplets = await self._run_cot_completion(
|
|
||||||
query=query,
|
|
||||||
context=context,
|
|
||||||
conversation_history=conversation_history,
|
|
||||||
max_iter=max_iter,
|
|
||||||
response_model=response_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: Handle save interaction for batch queries
|
# TODO: Handle save interaction for batch queries
|
||||||
if self.save_interaction and context and triplets and completion:
|
if self.save_interaction and context and triplets and completion:
|
||||||
await self.save_qa(
|
await self.save_qa(
|
||||||
question=query, answer=str(completion), context=context_text, triplets=triplets
|
question=query,
|
||||||
|
answer=str(completion[0]),
|
||||||
|
context=context_text[0],
|
||||||
|
triplets=triplets[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Handle session save interaction for batch queries
|
# TODO: Handle session save interaction for batch queries
|
||||||
|
|
@ -266,7 +364,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if completion_batch:
|
return completion
|
||||||
return [completion for completion, _, _ in completion_batch]
|
|
||||||
|
|
||||||
return [completion]
|
|
||||||
|
|
|
||||||
|
|
@ -158,15 +158,18 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
f"Empty context was provided to the completion for the query: {batched_query}"
|
f"Empty context was provided to the completion for the query: {batched_query}"
|
||||||
)
|
)
|
||||||
entity_nodes_batch = []
|
entity_nodes_batch = []
|
||||||
|
|
||||||
for batched_triplets in triplets:
|
for batched_triplets in triplets:
|
||||||
entity_nodes_batch.append(get_entity_nodes_from_triplets(batched_triplets))
|
entity_nodes_batch.append(get_entity_nodes_from_triplets(batched_triplets))
|
||||||
|
|
||||||
await asyncio.gather(
|
# Remove duplicates and update node access, if it is enabled
|
||||||
*[
|
for batched_entity_nodes in entity_nodes_batch:
|
||||||
update_node_access_timestamps(batched_entity_nodes)
|
# from itertools import chain
|
||||||
for batched_entity_nodes in entity_nodes_batch
|
#
|
||||||
]
|
# flattened_entity_nodes = list(chain.from_iterable(entity_nodes_batch))
|
||||||
)
|
# entity_nodes = list(set(flattened_entity_nodes))
|
||||||
|
|
||||||
|
await update_node_access_timestamps(batched_entity_nodes)
|
||||||
else:
|
else:
|
||||||
if len(triplets) == 0:
|
if len(triplets) == 0:
|
||||||
logger.warning("Empty context was provided to the completion")
|
logger.warning("Empty context was provided to the completion")
|
||||||
|
|
@ -209,9 +212,9 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
|
|
||||||
- Any: A generated completion based on the query and context provided.
|
- Any: A generated completion based on the query and context provided.
|
||||||
"""
|
"""
|
||||||
query_validation = validate_queries(query, query_batch)
|
is_query_valid, msg = validate_queries(query, query_batch)
|
||||||
if not query_validation[0]:
|
if not is_query_valid:
|
||||||
raise ValueError(query_validation[1])
|
raise ValueError(msg)
|
||||||
|
|
||||||
triplets = context
|
triplets = context
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -147,9 +147,9 @@ async def brute_force_triplet_search(
|
||||||
In single-query mode, node_distances and edge_distances are stored as flat lists.
|
In single-query mode, node_distances and edge_distances are stored as flat lists.
|
||||||
In batch mode, they are stored as list-of-lists (one list per query).
|
In batch mode, they are stored as list-of-lists (one list per query).
|
||||||
"""
|
"""
|
||||||
query_validation = validate_queries(query, query_batch)
|
is_query_valid, msg = validate_queries(query, query_batch)
|
||||||
if not query_validation[0]:
|
if not is_query_valid:
|
||||||
raise ValueError(query_validation[1])
|
raise ValueError(msg)
|
||||||
|
|
||||||
if top_k <= 0:
|
if top_k <= 0:
|
||||||
raise ValueError("top_k must be a positive integer.")
|
raise ValueError("top_k must be a positive integer.")
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import os
|
|
||||||
import pytest
|
import pytest
|
||||||
import pathlib
|
import pathlib
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock, patch, MagicMock
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
@ -80,7 +82,7 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge):
|
||||||
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
return_value="Generated answer",
|
return_value="Generated answer",
|
||||||
),
|
),
|
||||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
||||||
return_value="Generated answer",
|
return_value="Generated answer",
|
||||||
|
|
@ -106,8 +108,8 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge):
|
||||||
max_iter=1,
|
max_iter=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert completion == "Generated answer"
|
assert completion == ["Generated answer"]
|
||||||
assert context_text == "Resolved context"
|
assert context_text == ["Resolved context"]
|
||||||
assert len(triplets) >= 1
|
assert len(triplets) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -126,7 +128,7 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge):
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||||
return_value=[mock_edge],
|
return_value=[[mock_edge]],
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
|
@ -143,8 +145,8 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge):
|
||||||
max_iter=1,
|
max_iter=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert completion == "Generated answer"
|
assert completion == ["Generated answer"]
|
||||||
assert context_text == "Resolved context"
|
assert context_text == ["Resolved context"]
|
||||||
assert len(triplets) >= 1
|
assert len(triplets) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -168,7 +170,7 @@ async def test_run_cot_completion_multiple_rounds(mock_edge):
|
||||||
retriever,
|
retriever,
|
||||||
"get_context",
|
"get_context",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
side_effect=[[mock_edge], [mock_edge2]],
|
side_effect=[[[mock_edge]], [[mock_edge2]]],
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
|
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
|
||||||
|
|
@ -200,8 +202,8 @@ async def test_run_cot_completion_multiple_rounds(mock_edge):
|
||||||
max_iter=2,
|
max_iter=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert completion == "Generated answer"
|
assert completion == ["Generated answer"]
|
||||||
assert context_text == "Resolved context"
|
assert context_text == ["Resolved context"]
|
||||||
assert len(triplets) >= 1
|
assert len(triplets) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -227,7 +229,7 @@ async def test_run_cot_completion_with_conversation_history(mock_edge):
|
||||||
max_iter=1,
|
max_iter=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert completion == "Generated answer"
|
assert completion == ["Generated answer"]
|
||||||
call_kwargs = mock_generate.call_args[1]
|
call_kwargs = mock_generate.call_args[1]
|
||||||
assert call_kwargs.get("conversation_history") == "Previous conversation"
|
assert call_kwargs.get("conversation_history") == "Previous conversation"
|
||||||
|
|
||||||
|
|
@ -259,8 +261,9 @@ async def test_run_cot_completion_with_response_model(mock_edge):
|
||||||
max_iter=1,
|
max_iter=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(completion, TestModel)
|
assert isinstance(completion, list)
|
||||||
assert completion.answer == "Test answer"
|
assert isinstance(completion[0], TestModel)
|
||||||
|
assert completion[0].answer == "Test answer"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -285,7 +288,7 @@ async def test_run_cot_completion_empty_conversation_history(mock_edge):
|
||||||
max_iter=1,
|
max_iter=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert completion == "Generated answer"
|
assert completion == ["Generated answer"]
|
||||||
# Verify conversation_history was passed as None when empty
|
# Verify conversation_history was passed as None when empty
|
||||||
call_kwargs = mock_generate.call_args[1]
|
call_kwargs = mock_generate.call_args[1]
|
||||||
assert call_kwargs.get("conversation_history") is None
|
assert call_kwargs.get("conversation_history") is None
|
||||||
|
|
@ -306,7 +309,7 @@ async def test_get_completion_without_context(mock_edge):
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||||
return_value=[mock_edge],
|
return_value=[[mock_edge]],
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
|
@ -316,7 +319,7 @@ async def test_get_completion_without_context(mock_edge):
|
||||||
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
return_value="Generated answer",
|
return_value="Generated answer",
|
||||||
),
|
),
|
||||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
||||||
return_value="Generated answer",
|
return_value="Generated answer",
|
||||||
|
|
@ -372,7 +375,7 @@ async def test_get_completion_with_provided_context(mock_edge):
|
||||||
mock_config.caching = False
|
mock_config.caching = False
|
||||||
mock_cache_config.return_value = mock_config
|
mock_cache_config.return_value = mock_config
|
||||||
|
|
||||||
completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1)
|
completion = await retriever.get_completion("test query", context=[[mock_edge]], max_iter=1)
|
||||||
|
|
||||||
assert isinstance(completion, list)
|
assert isinstance(completion, list)
|
||||||
assert len(completion) == 1
|
assert len(completion) == 1
|
||||||
|
|
@ -397,7 +400,7 @@ async def test_get_completion_with_session(mock_edge):
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||||
return_value=[mock_edge],
|
return_value=[[mock_edge]],
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
|
@ -463,7 +466,7 @@ async def test_get_completion_with_save_interaction(mock_edge):
|
||||||
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
return_value="Generated answer",
|
return_value="Generated answer",
|
||||||
),
|
),
|
||||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
||||||
return_value="Generated answer",
|
return_value="Generated answer",
|
||||||
|
|
@ -528,7 +531,7 @@ async def test_get_completion_with_response_model(mock_edge):
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||||
return_value=[mock_edge],
|
return_value=[[mock_edge]],
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
|
@ -570,7 +573,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||||
return_value=[mock_edge],
|
return_value=[[mock_edge]],
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
|
@ -612,7 +615,7 @@ async def test_get_completion_with_save_interaction_no_context(mock_edge):
|
||||||
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
return_value="Generated answer",
|
return_value="Generated answer",
|
||||||
),
|
),
|
||||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
||||||
return_value="Generated answer",
|
return_value="Generated answer",
|
||||||
|
|
@ -703,7 +706,7 @@ async def test_get_completion_batch_queries_with_context(mock_edge):
|
||||||
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
||||||
return_value="Generated answer",
|
return_value="Generated answer",
|
||||||
),
|
),
|
||||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
||||||
return_value="Generated answer",
|
return_value="Generated answer",
|
||||||
|
|
@ -749,7 +752,7 @@ async def test_get_completion_batch_queries_without_context(mock_edge):
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||||
return_value=[mock_edge],
|
return_value=[[mock_edge]],
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
|
@ -790,7 +793,7 @@ async def test_get_completion_batch_queries_with_response_model(mock_edge):
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||||
return_value=[mock_edge],
|
return_value=[[mock_edge]],
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue