fix: fix logic issue that coderabbit flagged regarding ordering of lists
This commit is contained in:
parent
98e8d226eb
commit
17554466ba
2 changed files with 25 additions and 23 deletions
|
|
@ -102,21 +102,20 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
triplets = await self.get_context(query_batch=query_batch)
|
triplets = await self.get_context(query_batch=query_batch)
|
||||||
|
|
||||||
context_text = ""
|
context_text = ""
|
||||||
context_texts = ""
|
context_texts = await asyncio.gather(
|
||||||
if triplets and isinstance(triplets[0], list):
|
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
||||||
context_texts = await asyncio.gather(
|
)
|
||||||
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
context_text = await self.resolve_edges_to_text(triplets)
|
|
||||||
|
|
||||||
round_idx = 1
|
round_idx = 1
|
||||||
|
|
||||||
# We will be removing queries, and their associated triplets and context, as we go
|
# We will be removing queries, and their associated triplets and context, as we go
|
||||||
# through iterations, so we need to save their final states for the final generation.
|
# through iterations, so we need to save their final states for the final generation.
|
||||||
|
# Final state is stored in the finished_queries_data dict, and we populate it at the start as well.
|
||||||
original_query_batch = query_batch
|
original_query_batch = query_batch
|
||||||
saved_triplets = []
|
finished_queries_data = {}
|
||||||
saved_context_texts = []
|
for i, query in enumerate(query_batch):
|
||||||
|
finished_queries_data[query] = (triplets[i], context_texts[i])
|
||||||
|
|
||||||
while round_idx <= context_extension_rounds:
|
while round_idx <= context_extension_rounds:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Context extension: round {round_idx} - generating next graph locational query."
|
f"Context extension: round {round_idx} - generating next graph locational query."
|
||||||
|
|
@ -147,6 +146,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Get new triplets, and merge them with existing ones, filtering out duplicates
|
||||||
new_triplets = await self.get_context(query_batch=completions)
|
new_triplets = await self.get_context(query_batch=completions)
|
||||||
for i, (triplets_element, new_triplets_element) in enumerate(
|
for i, (triplets_element, new_triplets_element) in enumerate(
|
||||||
zip(triplets, new_triplets)
|
zip(triplets, new_triplets)
|
||||||
|
|
@ -160,15 +160,14 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
|
|
||||||
new_sizes = [len(triplets_element) for triplets_element in triplets]
|
new_sizes = [len(triplets_element) for triplets_element in triplets]
|
||||||
|
|
||||||
for i, (query, prev_size, new_size, triplet_element, context_text) in enumerate(
|
for i, (query, prev_size, new_size, triplets_element, context_text) in enumerate(
|
||||||
zip(query_batch, prev_sizes, new_sizes, triplets, context_texts)
|
zip(query_batch, prev_sizes, new_sizes, triplets, context_texts)
|
||||||
):
|
):
|
||||||
|
finished_queries_data[query] = (triplets_element, context_text)
|
||||||
if prev_size == new_size:
|
if prev_size == new_size:
|
||||||
# In this case, we can stop trying to extend the context of this query
|
# In this case, we can stop trying to extend the context of this query
|
||||||
query_batch[i] = ""
|
query_batch[i] = ""
|
||||||
saved_triplets.append(triplet_element)
|
|
||||||
triplets[i] = []
|
triplets[i] = []
|
||||||
saved_context_texts.append(context_text)
|
|
||||||
context_texts[i] = ""
|
context_texts[i] = ""
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -181,12 +180,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
# Reset variables for the final generations. They contain the final state
|
# Reset variables for the final generations. They contain the final state
|
||||||
# of triplets and contexts for each query, after all extension iterations.
|
# of triplets and contexts for each query, after all extension iterations.
|
||||||
query_batch = original_query_batch
|
query_batch = original_query_batch
|
||||||
context_texts = saved_context_texts if len(saved_context_texts) > 0 else context_texts
|
triplets = []
|
||||||
triplets = saved_triplets if len(saved_triplets) > 0 else triplets
|
context_texts = []
|
||||||
|
for query in query_batch:
|
||||||
if len(query_batch) == 1:
|
triplets.append(finished_queries_data[query][0])
|
||||||
triplets = [] if not triplets else triplets[0]
|
context_texts.append(finished_queries_data[query][1])
|
||||||
context_text = "" if not context_texts else context_texts[0]
|
|
||||||
|
|
||||||
# Check if we need to generate context summary for caching
|
# Check if we need to generate context summary for caching
|
||||||
cache_config = CacheConfig()
|
cache_config = CacheConfig()
|
||||||
|
|
@ -224,11 +222,10 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.save_interaction and context_text and triplets and completion:
|
# TODO: Do batch queries for save interaction
|
||||||
if isinstance(completion, list):
|
if self.save_interaction and context_texts and triplets and completion:
|
||||||
completion = completion[0]
|
|
||||||
await self.save_qa(
|
await self.save_qa(
|
||||||
question=query, answer=completion[0], context=context_text, triplets=triplets
|
question=query, answer=completion[0], context=context_texts[0], triplets=triplets[0]
|
||||||
)
|
)
|
||||||
|
|
||||||
if session_save:
|
if session_save:
|
||||||
|
|
|
||||||
|
|
@ -569,7 +569,12 @@ async def test_get_completion_batch_queries_context_extension_rounds(mock_edge):
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
side_effect=cycle(["Resolved context", "Extended context"]), # Different contexts
|
side_effect=[
|
||||||
|
"Resolved context",
|
||||||
|
"Resolved context",
|
||||||
|
"Extended context",
|
||||||
|
"Extended context",
|
||||||
|
], # Different contexts
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue