fix: fix tests

This commit is contained in:
Andrej Milicevic 2026-01-19 00:03:35 +01:00
parent b88e4242ad
commit d258b1d7af

View file

@ -58,11 +58,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
async def get_completion(
self,
query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
context: Optional[List[Edge] | List[List[Edge]]] = None,
session_id: Optional[str] = None,
context_extension_rounds=4,
response_model: Type = str,
query_batch: Optional[List[str]] = None,
) -> List[Any]:
"""
Extends the context for a given query by retrieving related triplets and generating new
@ -107,6 +107,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
if triplets_batch is None:
triplets_batch = await self.get_context(query_batch=query_batch)
if not triplets_batch:
return []
context_text = ""
context_text_batch = await asyncio.gather(
*[self.resolve_edges_to_text(triplets) for triplets in triplets_batch]
@ -119,8 +122,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
# Final state is stored in the finished_queries_data dict, and we populate it at the start as well.
original_query_batch = query_batch
finished_queries_data = {}
for i, query in enumerate(query_batch):
finished_queries_data[query] = (triplets_batch[i], context_text_batch[i])
for i, batched_query in enumerate(query_batch):
if not triplets_batch[i]:
query_batch[i] = ""
else:
finished_queries_data[batched_query] = (triplets_batch[i], context_text_batch[i])
while round_idx <= context_extension_rounds:
logger.info(