fix: fix tests
This commit is contained in:
parent
b88e4242ad
commit
d258b1d7af
1 changed files with 9 additions and 3 deletions
|
|
@ -58,11 +58,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
async def get_completion(
|
async def get_completion(
|
||||||
self,
|
self,
|
||||||
query: Optional[str] = None,
|
query: Optional[str] = None,
|
||||||
query_batch: Optional[List[str]] = None,
|
|
||||||
context: Optional[List[Edge] | List[List[Edge]]] = None,
|
context: Optional[List[Edge] | List[List[Edge]]] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
context_extension_rounds=4,
|
context_extension_rounds=4,
|
||||||
response_model: Type = str,
|
response_model: Type = str,
|
||||||
|
query_batch: Optional[List[str]] = None,
|
||||||
) -> List[Any]:
|
) -> List[Any]:
|
||||||
"""
|
"""
|
||||||
Extends the context for a given query by retrieving related triplets and generating new
|
Extends the context for a given query by retrieving related triplets and generating new
|
||||||
|
|
@ -107,6 +107,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
if triplets_batch is None:
|
if triplets_batch is None:
|
||||||
triplets_batch = await self.get_context(query_batch=query_batch)
|
triplets_batch = await self.get_context(query_batch=query_batch)
|
||||||
|
|
||||||
|
if not triplets_batch:
|
||||||
|
return []
|
||||||
|
|
||||||
context_text = ""
|
context_text = ""
|
||||||
context_text_batch = await asyncio.gather(
|
context_text_batch = await asyncio.gather(
|
||||||
*[self.resolve_edges_to_text(triplets) for triplets in triplets_batch]
|
*[self.resolve_edges_to_text(triplets) for triplets in triplets_batch]
|
||||||
|
|
@ -119,8 +122,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
# Final state is stored in the finished_queries_data dict, and we populate it at the start as well.
|
# Final state is stored in the finished_queries_data dict, and we populate it at the start as well.
|
||||||
original_query_batch = query_batch
|
original_query_batch = query_batch
|
||||||
finished_queries_data = {}
|
finished_queries_data = {}
|
||||||
for i, query in enumerate(query_batch):
|
for i, batched_query in enumerate(query_batch):
|
||||||
finished_queries_data[query] = (triplets_batch[i], context_text_batch[i])
|
if not triplets_batch[i]:
|
||||||
|
query_batch[i] = ""
|
||||||
|
else:
|
||||||
|
finished_queries_data[batched_query] = (triplets_batch[i], context_text_batch[i])
|
||||||
|
|
||||||
while round_idx <= context_extension_rounds:
|
while round_idx <= context_extension_rounds:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue