fix: PR comments fixes
This commit is contained in:
parent
ab7b5d5445
commit
b88e4242ad
5 changed files with 116 additions and 68 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
from typing import Optional, List, Type, Any
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.retrieval.utils.validate_queries import validate_queries
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||
|
|
@ -90,20 +91,25 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
- List[str]: A list containing the generated answer based on the query and the
|
||||
extended context.
|
||||
"""
|
||||
triplets = context
|
||||
# TODO: This may be unnecessary in this retriever, will check later
|
||||
query_validation = validate_queries(query, query_batch)
|
||||
if not query_validation[0]:
|
||||
raise ValueError(query_validation[1])
|
||||
|
||||
triplets_batch = context
|
||||
|
||||
if query:
|
||||
# This is done mostly to avoid duplicating a lot of code unnecessarily
|
||||
query_batch = [query]
|
||||
if triplets:
|
||||
triplets = [triplets]
|
||||
if triplets_batch:
|
||||
triplets_batch = [triplets_batch]
|
||||
|
||||
if triplets is None:
|
||||
triplets = await self.get_context(query_batch=query_batch)
|
||||
if triplets_batch is None:
|
||||
triplets_batch = await self.get_context(query_batch=query_batch)
|
||||
|
||||
context_text = ""
|
||||
context_texts = await asyncio.gather(
|
||||
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
||||
context_text_batch = await asyncio.gather(
|
||||
*[self.resolve_edges_to_text(triplets) for triplets in triplets_batch]
|
||||
)
|
||||
|
||||
round_idx = 1
|
||||
|
|
@ -114,7 +120,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
original_query_batch = query_batch
|
||||
finished_queries_data = {}
|
||||
for i, query in enumerate(query_batch):
|
||||
finished_queries_data[query] = (triplets[i], context_texts[i])
|
||||
finished_queries_data[query] = (triplets_batch[i], context_text_batch[i])
|
||||
|
||||
while round_idx <= context_extension_rounds:
|
||||
logger.info(
|
||||
|
|
@ -123,15 +129,17 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
|
||||
# Filter out the queries that cannot be extended further, and their associated contexts
|
||||
query_batch = [query for query in query_batch if query]
|
||||
triplets = [triplet_element for triplet_element in triplets if triplet_element]
|
||||
context_texts = [context_text for context_text in context_texts if context_text]
|
||||
triplets_batch = [triplets for triplets in triplets_batch if triplets]
|
||||
context_text_batch = [
|
||||
context_text for context_text in context_text_batch if context_text
|
||||
]
|
||||
if len(query_batch) == 0:
|
||||
logger.info(
|
||||
f"Context extension: round {round_idx} – no new triplets found; stopping early."
|
||||
)
|
||||
break
|
||||
|
||||
prev_sizes = [len(triplets_element) for triplets_element in triplets]
|
||||
prev_sizes = [len(triplets) for triplets in triplets_batch]
|
||||
|
||||
completions = await asyncio.gather(
|
||||
*[
|
||||
|
|
@ -142,33 +150,31 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
for query, context in zip(query_batch, context_texts)
|
||||
for query, context in zip(query_batch, context_text_batch)
|
||||
],
|
||||
)
|
||||
|
||||
# Get new triplets, and merge them with existing ones, filtering out duplicates
|
||||
new_triplets = await self.get_context(query_batch=completions)
|
||||
for i, (triplets_element, new_triplets_element) in enumerate(
|
||||
zip(triplets, new_triplets)
|
||||
):
|
||||
triplets_element += new_triplets_element
|
||||
triplets[i] = list(dict.fromkeys(triplets_element))
|
||||
new_triplets_batch = await self.get_context(query_batch=completions)
|
||||
for i, (triplets, new_triplets) in enumerate(zip(triplets_batch, new_triplets_batch)):
|
||||
triplets += new_triplets
|
||||
triplets_batch[i] = list(dict.fromkeys(triplets))
|
||||
|
||||
context_texts = await asyncio.gather(
|
||||
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
||||
context_text_batch = await asyncio.gather(
|
||||
*[self.resolve_edges_to_text(triplets) for triplets in triplets_batch]
|
||||
)
|
||||
|
||||
new_sizes = [len(triplets_element) for triplets_element in triplets]
|
||||
new_sizes = [len(triplets) for triplets in triplets_batch]
|
||||
|
||||
for i, (query, prev_size, new_size, triplets_element, context_text) in enumerate(
|
||||
zip(query_batch, prev_sizes, new_sizes, triplets, context_texts)
|
||||
for i, (batched_query, prev_size, new_size, triplets, context_text) in enumerate(
|
||||
zip(query_batch, prev_sizes, new_sizes, triplets_batch, context_text_batch)
|
||||
):
|
||||
finished_queries_data[query] = (triplets_element, context_text)
|
||||
finished_queries_data[query] = (triplets, context_text)
|
||||
if prev_size == new_size:
|
||||
# In this case, we can stop trying to extend the context of this query
|
||||
query_batch[i] = ""
|
||||
triplets[i] = []
|
||||
context_texts[i] = ""
|
||||
triplets_batch[i] = []
|
||||
context_text_batch[i] = ""
|
||||
|
||||
logger.info(
|
||||
f"Context extension: round {round_idx} - "
|
||||
|
|
@ -180,11 +186,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
# Reset variables for the final generations. They contain the final state
|
||||
# of triplets and contexts for each query, after all extension iterations.
|
||||
query_batch = original_query_batch
|
||||
triplets = []
|
||||
context_texts = []
|
||||
for query in query_batch:
|
||||
triplets.append(finished_queries_data[query][0])
|
||||
context_texts.append(finished_queries_data[query][1])
|
||||
triplets_batch = []
|
||||
context_text_batch = []
|
||||
for batched_query in query_batch:
|
||||
triplets_batch.append(finished_queries_data[batched_query][0])
|
||||
context_text_batch.append(finished_queries_data[batched_query][1])
|
||||
|
||||
# Check if we need to generate context summary for caching
|
||||
cache_config = CacheConfig()
|
||||
|
|
@ -192,6 +198,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
user_id = getattr(user, "id", None)
|
||||
session_save = user_id and cache_config.caching
|
||||
|
||||
completion_batch = []
|
||||
|
||||
if session_save:
|
||||
conversation_history = await get_conversation_history(session_id=session_id)
|
||||
|
||||
|
|
@ -208,24 +216,27 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
),
|
||||
)
|
||||
else:
|
||||
completion = await asyncio.gather(
|
||||
completion_batch = await asyncio.gather(
|
||||
*[
|
||||
generate_completion(
|
||||
query=query,
|
||||
context=context_text,
|
||||
query=batched_query,
|
||||
context=batched_context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
response_model=response_model,
|
||||
)
|
||||
for query, context_text in zip(query_batch, context_texts)
|
||||
for batched_query, batched_context_text in zip(query_batch, context_text_batch)
|
||||
],
|
||||
)
|
||||
|
||||
# TODO: Do batch queries for save interaction
|
||||
if self.save_interaction and context_texts and triplets and completion:
|
||||
if self.save_interaction and context_text_batch and triplets_batch and completion_batch:
|
||||
await self.save_qa(
|
||||
question=query, answer=completion[0], context=context_texts[0], triplets=triplets[0]
|
||||
question=query,
|
||||
answer=completion_batch[0],
|
||||
context=context_text_batch[0],
|
||||
triplets=triplets_batch[0],
|
||||
)
|
||||
|
||||
if session_save:
|
||||
|
|
@ -236,4 +247,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
return completion if isinstance(completion, list) else [completion]
|
||||
return completion_batch if completion_batch else [completion]
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import json
|
|||
from typing import Optional, List, Type, Any
|
||||
from pydantic import BaseModel
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.retrieval.utils.validate_queries import validate_queries
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
|
|
@ -203,6 +204,10 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
|
||||
- List[str]: A list containing the generated answer to the user's query.
|
||||
"""
|
||||
query_validation = validate_queries(query, query_batch)
|
||||
if not query_validation[0]:
|
||||
raise ValueError(query_validation[1])
|
||||
|
||||
# Check if session saving is enabled
|
||||
cache_config = CacheConfig()
|
||||
user = session_user.get()
|
||||
|
|
@ -214,24 +219,25 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
if session_save:
|
||||
conversation_history = await get_conversation_history(session_id=session_id)
|
||||
|
||||
completion_results = []
|
||||
context_batch = context
|
||||
completion_batch = []
|
||||
if query_batch and len(query_batch) > 0:
|
||||
if not context:
|
||||
if not context_batch:
|
||||
# Having a list is necessary to zip through it
|
||||
context = []
|
||||
for query in query_batch:
|
||||
context.append(None)
|
||||
context_batch = []
|
||||
for _ in query_batch:
|
||||
context_batch.append(None)
|
||||
|
||||
completion_results = await asyncio.gather(
|
||||
completion_batch = await asyncio.gather(
|
||||
*[
|
||||
self._run_cot_completion(
|
||||
query=query,
|
||||
context=context_el,
|
||||
context=context,
|
||||
conversation_history=conversation_history,
|
||||
max_iter=max_iter,
|
||||
response_model=response_model,
|
||||
)
|
||||
for query, context_el in zip(query_batch, context)
|
||||
for batched_query, context in zip(query_batch, context_batch)
|
||||
]
|
||||
)
|
||||
else:
|
||||
|
|
@ -260,7 +266,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
if completion_results:
|
||||
return [completion for completion, _, _ in completion_results]
|
||||
if completion_batch:
|
||||
return [completion for completion, _, _ in completion_batch]
|
||||
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from uuid import NAMESPACE_OID, uuid5
|
|||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.retrieval.utils.validate_queries import validate_queries
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||||
|
|
@ -150,15 +151,33 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
|
||||
triplets = await self.get_triplets(query, query_batch)
|
||||
|
||||
if len(triplets) == 0:
|
||||
logger.warning("Empty context was provided to the completion")
|
||||
return []
|
||||
if query_batch:
|
||||
for batched_triplets, batched_query in zip(triplets, query_batch):
|
||||
if len(batched_triplets) == 0:
|
||||
logger.warning(
|
||||
f"Empty context was provided to the completion for the query: {batched_query}"
|
||||
)
|
||||
entity_nodes_batch = []
|
||||
for batched_triplets in triplets:
|
||||
entity_nodes_batch.append(get_entity_nodes_from_triplets(batched_triplets))
|
||||
|
||||
# context = await self.resolve_edges_to_text(triplets)
|
||||
await asyncio.gather(
|
||||
*[
|
||||
update_node_access_timestamps(batched_entity_nodes)
|
||||
for batched_entity_nodes in entity_nodes_batch
|
||||
]
|
||||
)
|
||||
else:
|
||||
if len(triplets) == 0:
|
||||
logger.warning("Empty context was provided to the completion")
|
||||
return []
|
||||
|
||||
entity_nodes = get_entity_nodes_from_triplets(triplets)
|
||||
# context = await self.resolve_edges_to_text(triplets)
|
||||
|
||||
entity_nodes = get_entity_nodes_from_triplets(triplets)
|
||||
|
||||
await update_node_access_timestamps(entity_nodes)
|
||||
|
||||
await update_node_access_timestamps(entity_nodes)
|
||||
return triplets
|
||||
|
||||
async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):
|
||||
|
|
@ -190,15 +209,19 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
|
||||
- Any: A generated completion based on the query and context provided.
|
||||
"""
|
||||
query_validation = validate_queries(query, query_batch)
|
||||
if not query_validation[0]:
|
||||
raise ValueError(query_validation[1])
|
||||
|
||||
triplets = context
|
||||
|
||||
if triplets is None:
|
||||
triplets = await self.get_context(query, query_batch)
|
||||
|
||||
context_text = ""
|
||||
context_texts = ""
|
||||
context_text_batch = []
|
||||
if triplets and isinstance(triplets[0], list):
|
||||
context_texts = await asyncio.gather(
|
||||
context_text_batch = await asyncio.gather(
|
||||
*[resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
||||
)
|
||||
else:
|
||||
|
|
@ -236,7 +259,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
system_prompt=self.system_prompt,
|
||||
response_model=response_model,
|
||||
)
|
||||
for query, context in zip(query_batch, context_texts)
|
||||
for query, context in zip(query_batch, context_text_batch)
|
||||
],
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from typing import List, Optional, Type, Union
|
||||
|
||||
from cognee.modules.retrieval.utils.validate_queries import validate_queries
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
|
@ -146,17 +147,10 @@ async def brute_force_triplet_search(
|
|||
In single-query mode, node_distances and edge_distances are stored as flat lists.
|
||||
In batch mode, they are stored as list-of-lists (one list per query).
|
||||
"""
|
||||
if query is not None and query_batch is not None:
|
||||
raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.")
|
||||
if query is None and query_batch is None:
|
||||
raise ValueError("Must provide either 'query' or 'query_batch'.")
|
||||
if query is not None and (not query or not isinstance(query, str)):
|
||||
raise ValueError("The query must be a non-empty string.")
|
||||
if query_batch is not None:
|
||||
if not isinstance(query_batch, list) or not query_batch:
|
||||
raise ValueError("query_batch must be a non-empty list of strings.")
|
||||
if not all(isinstance(q, str) and q for q in query_batch):
|
||||
raise ValueError("All items in query_batch must be non-empty strings.")
|
||||
query_validation = validate_queries(query, query_batch)
|
||||
if not query_validation[0]:
|
||||
raise ValueError(query_validation[1])
|
||||
|
||||
if top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer.")
|
||||
|
||||
|
|
|
|||
14
cognee/modules/retrieval/utils/validate_queries.py
Normal file
14
cognee/modules/retrieval/utils/validate_queries.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
def validate_queries(query, query_batch) -> tuple[bool, str]:
|
||||
if query is not None and query_batch is not None:
|
||||
return False, "Cannot provide both 'query' and 'query_batch'; use exactly one."
|
||||
if query is None and query_batch is None:
|
||||
return False, "Must provide either 'query' or 'query_batch'."
|
||||
if query is not None and (not query or not isinstance(query, str)):
|
||||
return False, "The query must be a non-empty string."
|
||||
if query_batch is not None:
|
||||
if not isinstance(query_batch, list) or not query_batch:
|
||||
return False, "query_batch must be a non-empty list of strings."
|
||||
if not all(isinstance(q, str) and q for q in query_batch):
|
||||
return False, "All items in query_batch must be non-empty strings."
|
||||
|
||||
return True, ""
|
||||
Loading…
Add table
Reference in a new issue