feat: enable batch queries on graph completion retrievers
This commit is contained in:
parent
2cdbc02b35
commit
d180e2aeae
3 changed files with 161 additions and 52 deletions
|
|
@ -56,8 +56,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
|
|
||||||
async def get_completion(
|
async def get_completion(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: Optional[str] = None,
|
||||||
context: Optional[List[Edge]] = None,
|
query_batch: Optional[List[str]] = None,
|
||||||
|
context: Optional[List[Edge] | List[List[Edge]]] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
context_extension_rounds=4,
|
context_extension_rounds=4,
|
||||||
response_model: Type = str,
|
response_model: Type = str,
|
||||||
|
|
@ -91,46 +92,98 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
"""
|
"""
|
||||||
triplets = context
|
triplets = context
|
||||||
|
|
||||||
if triplets is None:
|
if query:
|
||||||
triplets = await self.get_context(query)
|
# This is done mostly to avoid duplicating a lot of code unnecessarily
|
||||||
|
query_batch = [query]
|
||||||
|
query = None
|
||||||
|
if triplets:
|
||||||
|
triplets = [triplets]
|
||||||
|
|
||||||
context_text = await self.resolve_edges_to_text(triplets)
|
if triplets is None:
|
||||||
|
triplets = await self.get_context(query, query_batch)
|
||||||
|
|
||||||
|
context_text = ""
|
||||||
|
context_texts = ""
|
||||||
|
if isinstance(triplets[0], list):
|
||||||
|
context_texts = await asyncio.gather(
|
||||||
|
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
context_text = await self.resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
round_idx = 1
|
round_idx = 1
|
||||||
|
|
||||||
|
# We will be removing queries, and their associated triplets and context, as we go
|
||||||
|
# through iterations, so we need to save their final states for the final generation.
|
||||||
|
original_query_batch = query_batch
|
||||||
|
saved_triplets = []
|
||||||
|
saved_context_texts = []
|
||||||
while round_idx <= context_extension_rounds:
|
while round_idx <= context_extension_rounds:
|
||||||
prev_size = len(triplets)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Context extension: round {round_idx} - generating next graph locational query."
|
f"Context extension: round {round_idx} - generating next graph locational query."
|
||||||
)
|
)
|
||||||
completion = await generate_completion(
|
|
||||||
query=query,
|
|
||||||
context=context_text,
|
|
||||||
user_prompt_path=self.user_prompt_path,
|
|
||||||
system_prompt_path=self.system_prompt_path,
|
|
||||||
system_prompt=self.system_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
triplets += await self.get_context(completion)
|
query_batch = [query for query in query_batch if query]
|
||||||
triplets = list(set(triplets))
|
triplets = [triplet_element for triplet_element in triplets if triplet_element]
|
||||||
context_text = await self.resolve_edges_to_text(triplets)
|
context_texts = [context_text for context_text in context_texts if context_text]
|
||||||
|
if len(query_batch) == 0:
|
||||||
num_triplets = len(triplets)
|
|
||||||
|
|
||||||
if num_triplets == prev_size:
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Context extension: round {round_idx} – no new triplets found; stopping early."
|
f"Context extension: round {round_idx} – no new triplets found; stopping early."
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
prev_sizes = [len(triplets_element) for triplets_element in triplets]
|
||||||
|
|
||||||
|
completions = await asyncio.gather(
|
||||||
|
*[
|
||||||
|
generate_completion(
|
||||||
|
query=query,
|
||||||
|
context=context,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
)
|
||||||
|
for query, context in zip(query_batch, context_texts)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
new_triplets = await self.get_context(query_batch=completions)
|
||||||
|
for i, (triplets_element, new_triplets_element) in enumerate(
|
||||||
|
zip(triplets, new_triplets)
|
||||||
|
):
|
||||||
|
triplets_element += new_triplets_element
|
||||||
|
triplets[i] = list(dict.fromkeys(triplets_element))
|
||||||
|
|
||||||
|
context_texts = await asyncio.gather(
|
||||||
|
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
||||||
|
)
|
||||||
|
|
||||||
|
new_sizes = [len(triplets_element) for triplets_element in triplets]
|
||||||
|
|
||||||
|
for i, (query, prev_size, new_size, triplet_element, context_text) in enumerate(
|
||||||
|
zip(query_batch, prev_sizes, new_sizes, triplets, context_texts)
|
||||||
|
):
|
||||||
|
if prev_size == new_size:
|
||||||
|
# In this case, we can stop trying to extend the context of this query
|
||||||
|
query_batch[i] = ""
|
||||||
|
saved_triplets.append(triplet_element)
|
||||||
|
triplets[i] = []
|
||||||
|
saved_context_texts.append(context_text)
|
||||||
|
context_texts[i] = ""
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Context extension: round {round_idx} - "
|
f"Context extension: round {round_idx} - "
|
||||||
f"number of unique retrieved triplets: {num_triplets}"
|
f"number of unique retrieved triplets for each query : {new_sizes}"
|
||||||
)
|
)
|
||||||
|
|
||||||
round_idx += 1
|
round_idx += 1
|
||||||
|
|
||||||
|
# Reset variables for the final generations. They contain the final state
|
||||||
|
# of triplets and contexts for each query, after all extension iterations.
|
||||||
|
query_batch = original_query_batch
|
||||||
|
context_texts = saved_context_texts
|
||||||
|
triplets = saved_triplets
|
||||||
|
|
||||||
# Check if we need to generate context summary for caching
|
# Check if we need to generate context summary for caching
|
||||||
cache_config = CacheConfig()
|
cache_config = CacheConfig()
|
||||||
user = session_user.get()
|
user = session_user.get()
|
||||||
|
|
@ -153,13 +206,18 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
completion = await generate_completion(
|
completion = await asyncio.gather(
|
||||||
query=query,
|
*[
|
||||||
context=context_text,
|
generate_completion(
|
||||||
user_prompt_path=self.user_prompt_path,
|
query=query,
|
||||||
system_prompt_path=self.system_prompt_path,
|
context=context_text,
|
||||||
system_prompt=self.system_prompt,
|
user_prompt_path=self.user_prompt_path,
|
||||||
response_model=response_model,
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
response_model=response_model,
|
||||||
|
)
|
||||||
|
for query, context_text in zip(query_batch, context_texts)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.save_interaction and context_text and triplets and completion:
|
if self.save_interaction and context_text and triplets and completion:
|
||||||
|
|
|
||||||
|
|
@ -170,7 +170,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
|
|
||||||
async def get_completion(
|
async def get_completion(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: Optional[str] = None,
|
||||||
|
query_batch: Optional[List[str]] = None,
|
||||||
context: Optional[List[Edge]] = None,
|
context: Optional[List[Edge]] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
max_iter=4,
|
max_iter=4,
|
||||||
|
|
@ -213,13 +214,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
if session_save:
|
if session_save:
|
||||||
conversation_history = await get_conversation_history(session_id=session_id)
|
conversation_history = await get_conversation_history(session_id=session_id)
|
||||||
|
|
||||||
completion, context_text, triplets = await self._run_cot_completion(
|
completion_results = []
|
||||||
query=query,
|
if query_batch and len(query_batch) > 0:
|
||||||
context=context,
|
completion_results = await asyncio.gather(
|
||||||
conversation_history=conversation_history,
|
*[
|
||||||
max_iter=max_iter,
|
self._run_cot_completion(
|
||||||
response_model=response_model,
|
query=query,
|
||||||
)
|
context=context,
|
||||||
|
conversation_history=conversation_history,
|
||||||
|
max_iter=max_iter,
|
||||||
|
response_model=response_model,
|
||||||
|
)
|
||||||
|
for query in query_batch
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
completion, context_text, triplets = await self._run_cot_completion(
|
||||||
|
query=query,
|
||||||
|
context=context,
|
||||||
|
conversation_history=conversation_history,
|
||||||
|
max_iter=max_iter,
|
||||||
|
response_model=response_model,
|
||||||
|
)
|
||||||
|
|
||||||
if self.save_interaction and context and triplets and completion:
|
if self.save_interaction and context and triplets and completion:
|
||||||
await self.save_qa(
|
await self.save_qa(
|
||||||
|
|
@ -236,4 +252,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if completion_results:
|
||||||
|
return [completion for completion, _, _ in completion_results]
|
||||||
|
|
||||||
return [completion]
|
return [completion]
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,11 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
"""
|
"""
|
||||||
return await resolve_edges_to_text(retrieved_edges)
|
return await resolve_edges_to_text(retrieved_edges)
|
||||||
|
|
||||||
async def get_triplets(self, query: str) -> List[Edge]:
|
async def get_triplets(
|
||||||
|
self,
|
||||||
|
query: Optional[str] = None,
|
||||||
|
query_batch: Optional[List[str]] = None,
|
||||||
|
) -> List[Edge] | List[List[Edge]]:
|
||||||
"""
|
"""
|
||||||
Retrieves relevant graph triplets based on a query string.
|
Retrieves relevant graph triplets based on a query string.
|
||||||
|
|
||||||
|
|
@ -107,6 +111,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
|
|
||||||
found_triplets = await brute_force_triplet_search(
|
found_triplets = await brute_force_triplet_search(
|
||||||
query,
|
query,
|
||||||
|
query_batch,
|
||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
collections=vector_index_collections or None,
|
collections=vector_index_collections or None,
|
||||||
node_type=self.node_type,
|
node_type=self.node_type,
|
||||||
|
|
@ -117,7 +122,11 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
|
|
||||||
return found_triplets
|
return found_triplets
|
||||||
|
|
||||||
async def get_context(self, query: str) -> List[Edge]:
|
async def get_context(
|
||||||
|
self,
|
||||||
|
query: Optional[str] = None,
|
||||||
|
query_batch: Optional[List[str]] = None,
|
||||||
|
) -> List[Edge] | List[List[Edge]]:
|
||||||
"""
|
"""
|
||||||
Retrieves and resolves graph triplets into context based on a query.
|
Retrieves and resolves graph triplets into context based on a query.
|
||||||
|
|
||||||
|
|
@ -139,7 +148,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
logger.warning("Search attempt on an empty knowledge graph")
|
logger.warning("Search attempt on an empty knowledge graph")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
triplets = await self.get_triplets(query)
|
triplets = await self.get_triplets(query, query_batch)
|
||||||
|
|
||||||
if len(triplets) == 0:
|
if len(triplets) == 0:
|
||||||
logger.warning("Empty context was provided to the completion")
|
logger.warning("Empty context was provided to the completion")
|
||||||
|
|
@ -158,8 +167,9 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
|
|
||||||
async def get_completion(
|
async def get_completion(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: Optional[str] = None,
|
||||||
context: Optional[List[Edge]] = None,
|
query_batch: Optional[List[str]] = None,
|
||||||
|
context: Optional[List[Edge] | List[List[Edge]]] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
response_model: Type = str,
|
response_model: Type = str,
|
||||||
) -> List[Any]:
|
) -> List[Any]:
|
||||||
|
|
@ -183,9 +193,16 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
triplets = context
|
triplets = context
|
||||||
|
|
||||||
if triplets is None:
|
if triplets is None:
|
||||||
triplets = await self.get_context(query)
|
triplets = await self.get_context(query, query_batch)
|
||||||
|
|
||||||
context_text = await resolve_edges_to_text(triplets)
|
context_text = ""
|
||||||
|
context_texts = ""
|
||||||
|
if isinstance(triplets[0], list):
|
||||||
|
context_texts = await asyncio.gather(
|
||||||
|
*[resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
context_text = await resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
cache_config = CacheConfig()
|
cache_config = CacheConfig()
|
||||||
user = session_user.get()
|
user = session_user.get()
|
||||||
|
|
@ -208,14 +225,29 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
completion = await generate_completion(
|
if query_batch and len(query_batch) > 0:
|
||||||
query=query,
|
completion = await asyncio.gather(
|
||||||
context=context_text,
|
*[
|
||||||
user_prompt_path=self.user_prompt_path,
|
generate_completion(
|
||||||
system_prompt_path=self.system_prompt_path,
|
query=query,
|
||||||
system_prompt=self.system_prompt,
|
context=context,
|
||||||
response_model=response_model,
|
user_prompt_path=self.user_prompt_path,
|
||||||
)
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
response_model=response_model,
|
||||||
|
)
|
||||||
|
for query, context in zip(query_batch, context_texts)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
completion = await generate_completion(
|
||||||
|
query=query,
|
||||||
|
context=context_text if context_text else context_texts,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
response_model=response_model,
|
||||||
|
)
|
||||||
|
|
||||||
if self.save_interaction and context and triplets and completion:
|
if self.save_interaction and context and triplets and completion:
|
||||||
await self.save_qa(
|
await self.save_qa(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue