feat: enable batch queries on graph completion retrievers
This commit is contained in:
parent
2cdbc02b35
commit
d180e2aeae
3 changed files with 161 additions and 52 deletions
|
|
@ -56,8 +56,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
|
||||
async def get_completion(
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[List[Edge]] = None,
|
||||
query: Optional[str] = None,
|
||||
query_batch: Optional[List[str]] = None,
|
||||
context: Optional[List[Edge] | List[List[Edge]]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
context_extension_rounds=4,
|
||||
response_model: Type = str,
|
||||
|
|
@ -91,46 +92,98 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
"""
|
||||
triplets = context
|
||||
|
||||
if triplets is None:
|
||||
triplets = await self.get_context(query)
|
||||
if query:
|
||||
# This is done mostly to avoid duplicating a lot of code unnecessarily
|
||||
query_batch = [query]
|
||||
query = None
|
||||
if triplets:
|
||||
triplets = [triplets]
|
||||
|
||||
context_text = await self.resolve_edges_to_text(triplets)
|
||||
if triplets is None:
|
||||
triplets = await self.get_context(query, query_batch)
|
||||
|
||||
context_text = ""
|
||||
context_texts = ""
|
||||
if isinstance(triplets[0], list):
|
||||
context_texts = await asyncio.gather(
|
||||
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
||||
)
|
||||
else:
|
||||
context_text = await self.resolve_edges_to_text(triplets)
|
||||
|
||||
round_idx = 1
|
||||
|
||||
# We will be removing queries, and their associated triplets and context, as we go
|
||||
# through iterations, so we need to save their final states for the final generation.
|
||||
original_query_batch = query_batch
|
||||
saved_triplets = []
|
||||
saved_context_texts = []
|
||||
while round_idx <= context_extension_rounds:
|
||||
prev_size = len(triplets)
|
||||
|
||||
logger.info(
|
||||
f"Context extension: round {round_idx} - generating next graph locational query."
|
||||
)
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
|
||||
triplets += await self.get_context(completion)
|
||||
triplets = list(set(triplets))
|
||||
context_text = await self.resolve_edges_to_text(triplets)
|
||||
|
||||
num_triplets = len(triplets)
|
||||
|
||||
if num_triplets == prev_size:
|
||||
query_batch = [query for query in query_batch if query]
|
||||
triplets = [triplet_element for triplet_element in triplets if triplet_element]
|
||||
context_texts = [context_text for context_text in context_texts if context_text]
|
||||
if len(query_batch) == 0:
|
||||
logger.info(
|
||||
f"Context extension: round {round_idx} – no new triplets found; stopping early."
|
||||
)
|
||||
break
|
||||
|
||||
prev_sizes = [len(triplets_element) for triplets_element in triplets]
|
||||
|
||||
completions = await asyncio.gather(
|
||||
*[
|
||||
generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
for query, context in zip(query_batch, context_texts)
|
||||
],
|
||||
)
|
||||
|
||||
new_triplets = await self.get_context(query_batch=completions)
|
||||
for i, (triplets_element, new_triplets_element) in enumerate(
|
||||
zip(triplets, new_triplets)
|
||||
):
|
||||
triplets_element += new_triplets_element
|
||||
triplets[i] = list(dict.fromkeys(triplets_element))
|
||||
|
||||
context_texts = await asyncio.gather(
|
||||
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
||||
)
|
||||
|
||||
new_sizes = [len(triplets_element) for triplets_element in triplets]
|
||||
|
||||
for i, (query, prev_size, new_size, triplet_element, context_text) in enumerate(
|
||||
zip(query_batch, prev_sizes, new_sizes, triplets, context_texts)
|
||||
):
|
||||
if prev_size == new_size:
|
||||
# In this case, we can stop trying to extend the context of this query
|
||||
query_batch[i] = ""
|
||||
saved_triplets.append(triplet_element)
|
||||
triplets[i] = []
|
||||
saved_context_texts.append(context_text)
|
||||
context_texts[i] = ""
|
||||
|
||||
logger.info(
|
||||
f"Context extension: round {round_idx} - "
|
||||
f"number of unique retrieved triplets: {num_triplets}"
|
||||
f"number of unique retrieved triplets for each query : {new_sizes}"
|
||||
)
|
||||
|
||||
round_idx += 1
|
||||
|
||||
# Reset variables for the final generations. They contain the final state
|
||||
# of triplets and contexts for each query, after all extension iterations.
|
||||
query_batch = original_query_batch
|
||||
context_texts = saved_context_texts
|
||||
triplets = saved_triplets
|
||||
|
||||
# Check if we need to generate context summary for caching
|
||||
cache_config = CacheConfig()
|
||||
user = session_user.get()
|
||||
|
|
@ -153,13 +206,18 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
),
|
||||
)
|
||||
else:
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
response_model=response_model,
|
||||
completion = await asyncio.gather(
|
||||
*[
|
||||
generate_completion(
|
||||
query=query,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
response_model=response_model,
|
||||
)
|
||||
for query, context_text in zip(query_batch, context_texts)
|
||||
],
|
||||
)
|
||||
|
||||
if self.save_interaction and context_text and triplets and completion:
|
||||
|
|
|
|||
|
|
@ -170,7 +170,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
|
||||
async def get_completion(
|
||||
self,
|
||||
query: str,
|
||||
query: Optional[str] = None,
|
||||
query_batch: Optional[List[str]] = None,
|
||||
context: Optional[List[Edge]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
max_iter=4,
|
||||
|
|
@ -213,13 +214,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
if session_save:
|
||||
conversation_history = await get_conversation_history(session_id=session_id)
|
||||
|
||||
completion, context_text, triplets = await self._run_cot_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
conversation_history=conversation_history,
|
||||
max_iter=max_iter,
|
||||
response_model=response_model,
|
||||
)
|
||||
completion_results = []
|
||||
if query_batch and len(query_batch) > 0:
|
||||
completion_results = await asyncio.gather(
|
||||
*[
|
||||
self._run_cot_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
conversation_history=conversation_history,
|
||||
max_iter=max_iter,
|
||||
response_model=response_model,
|
||||
)
|
||||
for query in query_batch
|
||||
]
|
||||
)
|
||||
else:
|
||||
completion, context_text, triplets = await self._run_cot_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
conversation_history=conversation_history,
|
||||
max_iter=max_iter,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
if self.save_interaction and context and triplets and completion:
|
||||
await self.save_qa(
|
||||
|
|
@ -236,4 +252,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
if completion_results:
|
||||
return [completion for completion, _, _ in completion_results]
|
||||
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -79,7 +79,11 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
"""
|
||||
return await resolve_edges_to_text(retrieved_edges)
|
||||
|
||||
async def get_triplets(self, query: str) -> List[Edge]:
|
||||
async def get_triplets(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
query_batch: Optional[List[str]] = None,
|
||||
) -> List[Edge] | List[List[Edge]]:
|
||||
"""
|
||||
Retrieves relevant graph triplets based on a query string.
|
||||
|
||||
|
|
@ -107,6 +111,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
|
||||
found_triplets = await brute_force_triplet_search(
|
||||
query,
|
||||
query_batch,
|
||||
top_k=self.top_k,
|
||||
collections=vector_index_collections or None,
|
||||
node_type=self.node_type,
|
||||
|
|
@ -117,7 +122,11 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
|
||||
return found_triplets
|
||||
|
||||
async def get_context(self, query: str) -> List[Edge]:
|
||||
async def get_context(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
query_batch: Optional[List[str]] = None,
|
||||
) -> List[Edge] | List[List[Edge]]:
|
||||
"""
|
||||
Retrieves and resolves graph triplets into context based on a query.
|
||||
|
||||
|
|
@ -139,7 +148,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
logger.warning("Search attempt on an empty knowledge graph")
|
||||
return []
|
||||
|
||||
triplets = await self.get_triplets(query)
|
||||
triplets = await self.get_triplets(query, query_batch)
|
||||
|
||||
if len(triplets) == 0:
|
||||
logger.warning("Empty context was provided to the completion")
|
||||
|
|
@ -158,8 +167,9 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
|
||||
async def get_completion(
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[List[Edge]] = None,
|
||||
query: Optional[str] = None,
|
||||
query_batch: Optional[List[str]] = None,
|
||||
context: Optional[List[Edge] | List[List[Edge]]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
response_model: Type = str,
|
||||
) -> List[Any]:
|
||||
|
|
@ -183,9 +193,16 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
triplets = context
|
||||
|
||||
if triplets is None:
|
||||
triplets = await self.get_context(query)
|
||||
triplets = await self.get_context(query, query_batch)
|
||||
|
||||
context_text = await resolve_edges_to_text(triplets)
|
||||
context_text = ""
|
||||
context_texts = ""
|
||||
if isinstance(triplets[0], list):
|
||||
context_texts = await asyncio.gather(
|
||||
*[resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
||||
)
|
||||
else:
|
||||
context_text = await resolve_edges_to_text(triplets)
|
||||
|
||||
cache_config = CacheConfig()
|
||||
user = session_user.get()
|
||||
|
|
@ -208,14 +225,29 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
),
|
||||
)
|
||||
else:
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
response_model=response_model,
|
||||
)
|
||||
if query_batch and len(query_batch) > 0:
|
||||
completion = await asyncio.gather(
|
||||
*[
|
||||
generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
response_model=response_model,
|
||||
)
|
||||
for query, context in zip(query_batch, context_texts)
|
||||
],
|
||||
)
|
||||
else:
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context_text if context_text else context_texts,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
if self.save_interaction and context and triplets and completion:
|
||||
await self.save_qa(
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue