fix: fix logic issue that coderabbit flagged regarding ordering of lists
This commit is contained in:
parent
98e8d226eb
commit
17554466ba
2 changed files with 25 additions and 23 deletions
|
|
@ -102,21 +102,20 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
triplets = await self.get_context(query_batch=query_batch)
|
||||
|
||||
context_text = ""
|
||||
context_texts = ""
|
||||
if triplets and isinstance(triplets[0], list):
|
||||
context_texts = await asyncio.gather(
|
||||
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
||||
)
|
||||
else:
|
||||
context_text = await self.resolve_edges_to_text(triplets)
|
||||
context_texts = await asyncio.gather(
|
||||
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
||||
)
|
||||
|
||||
round_idx = 1
|
||||
|
||||
# We will be removing queries, and their associated triplets and context, as we go
|
||||
# through iterations, so we need to save their final states for the final generation.
|
||||
# Final state is stored in the finished_queries_data dict, and we populate it at the start as well.
|
||||
original_query_batch = query_batch
|
||||
saved_triplets = []
|
||||
saved_context_texts = []
|
||||
finished_queries_data = {}
|
||||
for i, query in enumerate(query_batch):
|
||||
finished_queries_data[query] = (triplets[i], context_texts[i])
|
||||
|
||||
while round_idx <= context_extension_rounds:
|
||||
logger.info(
|
||||
f"Context extension: round {round_idx} - generating next graph locational query."
|
||||
|
|
@ -147,6 +146,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
],
|
||||
)
|
||||
|
||||
# Get new triplets, and merge them with existing ones, filtering out duplicates
|
||||
new_triplets = await self.get_context(query_batch=completions)
|
||||
for i, (triplets_element, new_triplets_element) in enumerate(
|
||||
zip(triplets, new_triplets)
|
||||
|
|
@ -160,15 +160,14 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
|
||||
new_sizes = [len(triplets_element) for triplets_element in triplets]
|
||||
|
||||
for i, (query, prev_size, new_size, triplet_element, context_text) in enumerate(
|
||||
for i, (query, prev_size, new_size, triplets_element, context_text) in enumerate(
|
||||
zip(query_batch, prev_sizes, new_sizes, triplets, context_texts)
|
||||
):
|
||||
finished_queries_data[query] = (triplets_element, context_text)
|
||||
if prev_size == new_size:
|
||||
# In this case, we can stop trying to extend the context of this query
|
||||
query_batch[i] = ""
|
||||
saved_triplets.append(triplet_element)
|
||||
triplets[i] = []
|
||||
saved_context_texts.append(context_text)
|
||||
context_texts[i] = ""
|
||||
|
||||
logger.info(
|
||||
|
|
@ -181,12 +180,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
# Reset variables for the final generations. They contain the final state
|
||||
# of triplets and contexts for each query, after all extension iterations.
|
||||
query_batch = original_query_batch
|
||||
context_texts = saved_context_texts if len(saved_context_texts) > 0 else context_texts
|
||||
triplets = saved_triplets if len(saved_triplets) > 0 else triplets
|
||||
|
||||
if len(query_batch) == 1:
|
||||
triplets = [] if not triplets else triplets[0]
|
||||
context_text = "" if not context_texts else context_texts[0]
|
||||
triplets = []
|
||||
context_texts = []
|
||||
for query in query_batch:
|
||||
triplets.append(finished_queries_data[query][0])
|
||||
context_texts.append(finished_queries_data[query][1])
|
||||
|
||||
# Check if we need to generate context summary for caching
|
||||
cache_config = CacheConfig()
|
||||
|
|
@ -224,11 +222,10 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
],
|
||||
)
|
||||
|
||||
if self.save_interaction and context_text and triplets and completion:
|
||||
if isinstance(completion, list):
|
||||
completion = completion[0]
|
||||
# TODO: Do batch queries for save interaction
|
||||
if self.save_interaction and context_texts and triplets and completion:
|
||||
await self.save_qa(
|
||||
question=query, answer=completion[0], context=context_text, triplets=triplets
|
||||
question=query, answer=completion[0], context=context_texts[0], triplets=triplets[0]
|
||||
)
|
||||
|
||||
if session_save:
|
||||
|
|
|
|||
|
|
@ -569,7 +569,12 @@ async def test_get_completion_batch_queries_context_extension_rounds(mock_edge):
|
|||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
side_effect=cycle(["Resolved context", "Extended context"]), # Different contexts
|
||||
side_effect=[
|
||||
"Resolved context",
|
||||
"Resolved context",
|
||||
"Extended context",
|
||||
"Extended context",
|
||||
], # Different contexts
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue