fix: fix tests
This commit is contained in:
parent
b88e4242ad
commit
d258b1d7af
1 changed files with 9 additions and 3 deletions
|
|
@ -58,11 +58,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
async def get_completion(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
query_batch: Optional[List[str]] = None,
|
||||
context: Optional[List[Edge] | List[List[Edge]]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
context_extension_rounds=4,
|
||||
response_model: Type = str,
|
||||
query_batch: Optional[List[str]] = None,
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Extends the context for a given query by retrieving related triplets and generating new
|
||||
|
|
@ -107,6 +107,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
if triplets_batch is None:
|
||||
triplets_batch = await self.get_context(query_batch=query_batch)
|
||||
|
||||
if not triplets_batch:
|
||||
return []
|
||||
|
||||
context_text = ""
|
||||
context_text_batch = await asyncio.gather(
|
||||
*[self.resolve_edges_to_text(triplets) for triplets in triplets_batch]
|
||||
|
|
@ -119,8 +122,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
# Final state is stored in the finished_queries_data dict, and we populate it at the start as well.
|
||||
original_query_batch = query_batch
|
||||
finished_queries_data = {}
|
||||
for i, query in enumerate(query_batch):
|
||||
finished_queries_data[query] = (triplets_batch[i], context_text_batch[i])
|
||||
for i, batched_query in enumerate(query_batch):
|
||||
if not triplets_batch[i]:
|
||||
query_batch[i] = ""
|
||||
else:
|
||||
finished_queries_data[batched_query] = (triplets_batch[i], context_text_batch[i])
|
||||
|
||||
while round_idx <= context_extension_rounds:
|
||||
logger.info(
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue