feat: enable batch queries on graph completion retrievers

This commit is contained in:
Andrej Milicevic 2026-01-15 13:41:19 +01:00
parent 2cdbc02b35
commit d180e2aeae
3 changed files with 161 additions and 52 deletions

View file

@ -56,8 +56,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
async def get_completion( async def get_completion(
self, self,
query: str, query: Optional[str] = None,
context: Optional[List[Edge]] = None, query_batch: Optional[List[str]] = None,
context: Optional[List[Edge] | List[List[Edge]]] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
context_extension_rounds=4, context_extension_rounds=4,
response_model: Type = str, response_model: Type = str,
@ -91,46 +92,98 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
""" """
triplets = context triplets = context
if triplets is None: if query:
triplets = await self.get_context(query) # This is done mostly to avoid duplicating a lot of code unnecessarily
query_batch = [query]
query = None
if triplets:
triplets = [triplets]
context_text = await self.resolve_edges_to_text(triplets) if triplets is None:
triplets = await self.get_context(query, query_batch)
context_text = ""
context_texts = ""
if isinstance(triplets[0], list):
context_texts = await asyncio.gather(
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
)
else:
context_text = await self.resolve_edges_to_text(triplets)
round_idx = 1 round_idx = 1
# We will be removing queries, and their associated triplets and context, as we go
# through iterations, so we need to save their final states for the final generation.
original_query_batch = query_batch
saved_triplets = []
saved_context_texts = []
while round_idx <= context_extension_rounds: while round_idx <= context_extension_rounds:
prev_size = len(triplets)
logger.info( logger.info(
f"Context extension: round {round_idx} - generating next graph locational query." f"Context extension: round {round_idx} - generating next graph locational query."
) )
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
triplets += await self.get_context(completion) query_batch = [query for query in query_batch if query]
triplets = list(set(triplets)) triplets = [triplet_element for triplet_element in triplets if triplet_element]
context_text = await self.resolve_edges_to_text(triplets) context_texts = [context_text for context_text in context_texts if context_text]
if len(query_batch) == 0:
num_triplets = len(triplets)
if num_triplets == prev_size:
logger.info( logger.info(
f"Context extension: round {round_idx} no new triplets found; stopping early." f"Context extension: round {round_idx} no new triplets found; stopping early."
) )
break break
prev_sizes = [len(triplets_element) for triplets_element in triplets]
completions = await asyncio.gather(
*[
generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
for query, context in zip(query_batch, context_texts)
],
)
new_triplets = await self.get_context(query_batch=completions)
for i, (triplets_element, new_triplets_element) in enumerate(
zip(triplets, new_triplets)
):
triplets_element += new_triplets_element
triplets[i] = list(dict.fromkeys(triplets_element))
context_texts = await asyncio.gather(
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
)
new_sizes = [len(triplets_element) for triplets_element in triplets]
for i, (query, prev_size, new_size, triplet_element, context_text) in enumerate(
zip(query_batch, prev_sizes, new_sizes, triplets, context_texts)
):
if prev_size == new_size:
# In this case, we can stop trying to extend the context of this query
query_batch[i] = ""
saved_triplets.append(triplet_element)
triplets[i] = []
saved_context_texts.append(context_text)
context_texts[i] = ""
logger.info( logger.info(
f"Context extension: round {round_idx} - " f"Context extension: round {round_idx} - "
f"number of unique retrieved triplets: {num_triplets}" f"number of unique retrieved triplets for each query : {new_sizes}"
) )
round_idx += 1 round_idx += 1
# Reset variables for the final generations. They contain the final state
# of triplets and contexts for each query, after all extension iterations.
query_batch = original_query_batch
context_texts = saved_context_texts
triplets = saved_triplets
# Check if we need to generate context summary for caching # Check if we need to generate context summary for caching
cache_config = CacheConfig() cache_config = CacheConfig()
user = session_user.get() user = session_user.get()
@ -153,13 +206,18 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
), ),
) )
else: else:
completion = await generate_completion( completion = await asyncio.gather(
query=query, *[
context=context_text, generate_completion(
user_prompt_path=self.user_prompt_path, query=query,
system_prompt_path=self.system_prompt_path, context=context_text,
system_prompt=self.system_prompt, user_prompt_path=self.user_prompt_path,
response_model=response_model, system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
for query, context_text in zip(query_batch, context_texts)
],
) )
if self.save_interaction and context_text and triplets and completion: if self.save_interaction and context_text and triplets and completion:

View file

@ -170,7 +170,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
async def get_completion( async def get_completion(
self, self,
query: str, query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
context: Optional[List[Edge]] = None, context: Optional[List[Edge]] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
max_iter=4, max_iter=4,
@ -213,13 +214,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
if session_save: if session_save:
conversation_history = await get_conversation_history(session_id=session_id) conversation_history = await get_conversation_history(session_id=session_id)
completion, context_text, triplets = await self._run_cot_completion( completion_results = []
query=query, if query_batch and len(query_batch) > 0:
context=context, completion_results = await asyncio.gather(
conversation_history=conversation_history, *[
max_iter=max_iter, self._run_cot_completion(
response_model=response_model, query=query,
) context=context,
conversation_history=conversation_history,
max_iter=max_iter,
response_model=response_model,
)
for query in query_batch
]
)
else:
completion, context_text, triplets = await self._run_cot_completion(
query=query,
context=context,
conversation_history=conversation_history,
max_iter=max_iter,
response_model=response_model,
)
if self.save_interaction and context and triplets and completion: if self.save_interaction and context and triplets and completion:
await self.save_qa( await self.save_qa(
@ -236,4 +252,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
session_id=session_id, session_id=session_id,
) )
if completion_results:
return [completion for completion, _, _ in completion_results]
return [completion] return [completion]

View file

@ -79,7 +79,11 @@ class GraphCompletionRetriever(BaseGraphRetriever):
""" """
return await resolve_edges_to_text(retrieved_edges) return await resolve_edges_to_text(retrieved_edges)
async def get_triplets(self, query: str) -> List[Edge]: async def get_triplets(
self,
query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
) -> List[Edge] | List[List[Edge]]:
""" """
Retrieves relevant graph triplets based on a query string. Retrieves relevant graph triplets based on a query string.
@ -107,6 +111,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
found_triplets = await brute_force_triplet_search( found_triplets = await brute_force_triplet_search(
query, query,
query_batch,
top_k=self.top_k, top_k=self.top_k,
collections=vector_index_collections or None, collections=vector_index_collections or None,
node_type=self.node_type, node_type=self.node_type,
@ -117,7 +122,11 @@ class GraphCompletionRetriever(BaseGraphRetriever):
return found_triplets return found_triplets
async def get_context(self, query: str) -> List[Edge]: async def get_context(
self,
query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
) -> List[Edge] | List[List[Edge]]:
""" """
Retrieves and resolves graph triplets into context based on a query. Retrieves and resolves graph triplets into context based on a query.
@ -139,7 +148,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
logger.warning("Search attempt on an empty knowledge graph") logger.warning("Search attempt on an empty knowledge graph")
return [] return []
triplets = await self.get_triplets(query) triplets = await self.get_triplets(query, query_batch)
if len(triplets) == 0: if len(triplets) == 0:
logger.warning("Empty context was provided to the completion") logger.warning("Empty context was provided to the completion")
@ -158,8 +167,9 @@ class GraphCompletionRetriever(BaseGraphRetriever):
async def get_completion( async def get_completion(
self, self,
query: str, query: Optional[str] = None,
context: Optional[List[Edge]] = None, query_batch: Optional[List[str]] = None,
context: Optional[List[Edge] | List[List[Edge]]] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
response_model: Type = str, response_model: Type = str,
) -> List[Any]: ) -> List[Any]:
@ -183,9 +193,16 @@ class GraphCompletionRetriever(BaseGraphRetriever):
triplets = context triplets = context
if triplets is None: if triplets is None:
triplets = await self.get_context(query) triplets = await self.get_context(query, query_batch)
context_text = await resolve_edges_to_text(triplets) context_text = ""
context_texts = ""
if isinstance(triplets[0], list):
context_texts = await asyncio.gather(
*[resolve_edges_to_text(triplets_element) for triplets_element in triplets]
)
else:
context_text = await resolve_edges_to_text(triplets)
cache_config = CacheConfig() cache_config = CacheConfig()
user = session_user.get() user = session_user.get()
@ -208,14 +225,29 @@ class GraphCompletionRetriever(BaseGraphRetriever):
), ),
) )
else: else:
completion = await generate_completion( if query_batch and len(query_batch) > 0:
query=query, completion = await asyncio.gather(
context=context_text, *[
user_prompt_path=self.user_prompt_path, generate_completion(
system_prompt_path=self.system_prompt_path, query=query,
system_prompt=self.system_prompt, context=context,
response_model=response_model, user_prompt_path=self.user_prompt_path,
) system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
for query, context in zip(query_batch, context_texts)
],
)
else:
completion = await generate_completion(
query=query,
context=context_text if context_text else context_texts,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if self.save_interaction and context and triplets and completion: if self.save_interaction and context and triplets and completion:
await self.save_qa( await self.save_qa(