feat: enable batch queries on graph completion retrievers

This commit is contained in:
Andrej Milicevic 2026-01-15 13:41:19 +01:00
parent 2cdbc02b35
commit d180e2aeae
3 changed files with 161 additions and 52 deletions

View file

@ -56,8 +56,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
context: Optional[List[Edge] | List[List[Edge]]] = None,
session_id: Optional[str] = None,
context_extension_rounds=4,
response_model: Type = str,
@ -91,46 +92,98 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
"""
triplets = context
if triplets is None:
triplets = await self.get_context(query)
if query:
# This is done mostly to avoid duplicating a lot of code unnecessarily
query_batch = [query]
query = None
if triplets:
triplets = [triplets]
context_text = await self.resolve_edges_to_text(triplets)
if triplets is None:
triplets = await self.get_context(query, query_batch)
context_text = ""
context_texts = ""
if isinstance(triplets[0], list):
context_texts = await asyncio.gather(
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
)
else:
context_text = await self.resolve_edges_to_text(triplets)
round_idx = 1
# We will be removing queries, and their associated triplets and context, as we go
# through iterations, so we need to save their final states for the final generation.
original_query_batch = query_batch
saved_triplets = []
saved_context_texts = []
while round_idx <= context_extension_rounds:
prev_size = len(triplets)
logger.info(
f"Context extension: round {round_idx} - generating next graph locational query."
)
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
triplets += await self.get_context(completion)
triplets = list(set(triplets))
context_text = await self.resolve_edges_to_text(triplets)
num_triplets = len(triplets)
if num_triplets == prev_size:
query_batch = [query for query in query_batch if query]
triplets = [triplet_element for triplet_element in triplets if triplet_element]
context_texts = [context_text for context_text in context_texts if context_text]
if len(query_batch) == 0:
logger.info(
f"Context extension: round {round_idx} no new triplets found; stopping early."
)
break
prev_sizes = [len(triplets_element) for triplets_element in triplets]
completions = await asyncio.gather(
*[
generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
for query, context in zip(query_batch, context_texts)
],
)
new_triplets = await self.get_context(query_batch=completions)
for i, (triplets_element, new_triplets_element) in enumerate(
zip(triplets, new_triplets)
):
triplets_element += new_triplets_element
triplets[i] = list(dict.fromkeys(triplets_element))
context_texts = await asyncio.gather(
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
)
new_sizes = [len(triplets_element) for triplets_element in triplets]
for i, (query, prev_size, new_size, triplet_element, context_text) in enumerate(
zip(query_batch, prev_sizes, new_sizes, triplets, context_texts)
):
if prev_size == new_size:
# In this case, we can stop trying to extend the context of this query
query_batch[i] = ""
saved_triplets.append(triplet_element)
triplets[i] = []
saved_context_texts.append(context_text)
context_texts[i] = ""
logger.info(
f"Context extension: round {round_idx} - "
f"number of unique retrieved triplets: {num_triplets}"
f"number of unique retrieved triplets for each query : {new_sizes}"
)
round_idx += 1
# Reset variables for the final generations. They contain the final state
# of triplets and contexts for each query, after all extension iterations.
query_batch = original_query_batch
context_texts = saved_context_texts
triplets = saved_triplets
# Check if we need to generate context summary for caching
cache_config = CacheConfig()
user = session_user.get()
@ -153,13 +206,18 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
),
)
else:
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
completion = await asyncio.gather(
*[
generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
for query, context_text in zip(query_batch, context_texts)
],
)
if self.save_interaction and context_text and triplets and completion:

View file

@ -170,7 +170,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
async def get_completion(
self,
query: str,
query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
max_iter=4,
@ -213,13 +214,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)
completion, context_text, triplets = await self._run_cot_completion(
query=query,
context=context,
conversation_history=conversation_history,
max_iter=max_iter,
response_model=response_model,
)
completion_results = []
if query_batch and len(query_batch) > 0:
completion_results = await asyncio.gather(
*[
self._run_cot_completion(
query=query,
context=context,
conversation_history=conversation_history,
max_iter=max_iter,
response_model=response_model,
)
for query in query_batch
]
)
else:
completion, context_text, triplets = await self._run_cot_completion(
query=query,
context=context,
conversation_history=conversation_history,
max_iter=max_iter,
response_model=response_model,
)
if self.save_interaction and context and triplets and completion:
await self.save_qa(
@ -236,4 +252,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
session_id=session_id,
)
if completion_results:
return [completion for completion, _, _ in completion_results]
return [completion]

View file

@ -79,7 +79,11 @@ class GraphCompletionRetriever(BaseGraphRetriever):
"""
return await resolve_edges_to_text(retrieved_edges)
async def get_triplets(self, query: str) -> List[Edge]:
async def get_triplets(
self,
query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
) -> List[Edge] | List[List[Edge]]:
"""
Retrieves relevant graph triplets based on a query string.
@ -107,6 +111,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
found_triplets = await brute_force_triplet_search(
query,
query_batch,
top_k=self.top_k,
collections=vector_index_collections or None,
node_type=self.node_type,
@ -117,7 +122,11 @@ class GraphCompletionRetriever(BaseGraphRetriever):
return found_triplets
async def get_context(self, query: str) -> List[Edge]:
async def get_context(
self,
query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
) -> List[Edge] | List[List[Edge]]:
"""
Retrieves and resolves graph triplets into context based on a query.
@ -139,7 +148,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
logger.warning("Search attempt on an empty knowledge graph")
return []
triplets = await self.get_triplets(query)
triplets = await self.get_triplets(query, query_batch)
if len(triplets) == 0:
logger.warning("Empty context was provided to the completion")
@ -158,8 +167,9 @@ class GraphCompletionRetriever(BaseGraphRetriever):
async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
context: Optional[List[Edge] | List[List[Edge]]] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
@ -183,9 +193,16 @@ class GraphCompletionRetriever(BaseGraphRetriever):
triplets = context
if triplets is None:
triplets = await self.get_context(query)
triplets = await self.get_context(query, query_batch)
context_text = await resolve_edges_to_text(triplets)
context_text = ""
context_texts = ""
if isinstance(triplets[0], list):
context_texts = await asyncio.gather(
*[resolve_edges_to_text(triplets_element) for triplets_element in triplets]
)
else:
context_text = await resolve_edges_to_text(triplets)
cache_config = CacheConfig()
user = session_user.get()
@ -208,14 +225,29 @@ class GraphCompletionRetriever(BaseGraphRetriever):
),
)
else:
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if query_batch and len(query_batch) > 0:
completion = await asyncio.gather(
*[
generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
for query, context in zip(query_batch, context_texts)
],
)
else:
completion = await generate_completion(
query=query,
context=context_text if context_text else context_texts,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if self.save_interaction and context and triplets and completion:
await self.save_qa(