fix: raise error for query batch plus sessions and save interaction
This commit is contained in:
parent
05b5add480
commit
77fd18a60c
3 changed files with 40 additions and 22 deletions
|
|
@ -104,6 +104,18 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
- List[str]: A list containing the generated answer based on the query and the
|
||||
extended context.
|
||||
"""
|
||||
|
||||
# Check if we need to generate context summary for caching
|
||||
cache_config = CacheConfig()
|
||||
user = session_user.get()
|
||||
user_id = getattr(user, "id", None)
|
||||
session_save = user_id and cache_config.caching
|
||||
|
||||
if query_batch and session_save:
|
||||
raise ValueError("You cannot use batch queries with session saving currently.")
|
||||
if query_batch and self.save_interaction:
|
||||
raise ValueError("Cannot use batch queries with interaction saving currently.")
|
||||
|
||||
is_query_valid, msg = validate_queries(query, query_batch)
|
||||
if not is_query_valid:
|
||||
raise ValueError(msg)
|
||||
|
|
@ -160,7 +172,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
prev_sizes = [
|
||||
len(batched_query_state.triplets)
|
||||
for batched_query_state in finished_queries_states.values()
|
||||
if not batched_query_state.finished_extending_context
|
||||
]
|
||||
|
||||
completions = await asyncio.gather(
|
||||
|
|
@ -203,7 +214,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
new_sizes = [
|
||||
len(batched_query_state.triplets)
|
||||
for batched_query_state in finished_queries_states.values()
|
||||
if not batched_query_state.finished_extending_context
|
||||
]
|
||||
|
||||
for batched_query, prev_size, new_size in zip(query_batch, prev_sizes, new_sizes):
|
||||
|
|
@ -218,12 +228,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
|
||||
round_idx += 1
|
||||
|
||||
# Check if we need to generate context summary for caching
|
||||
cache_config = CacheConfig()
|
||||
user = session_user.get()
|
||||
user_id = getattr(user, "id", None)
|
||||
session_save = user_id and cache_config.caching
|
||||
|
||||
completion_batch = []
|
||||
|
||||
if session_save:
|
||||
|
|
|
|||
|
|
@ -318,9 +318,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
|
||||
- List[str]: A list containing the generated answer to the user's query.
|
||||
"""
|
||||
is_query_valid, msg = validate_queries(query, query_batch)
|
||||
if not is_query_valid:
|
||||
raise ValueError(msg)
|
||||
|
||||
# Check if session saving is enabled
|
||||
cache_config = CacheConfig()
|
||||
|
|
@ -328,6 +325,15 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
user_id = getattr(user, "id", None)
|
||||
session_save = user_id and cache_config.caching
|
||||
|
||||
if query_batch and session_save:
|
||||
raise ValueError("You cannot use batch queries with session saving currently.")
|
||||
if query_batch and self.save_interaction:
|
||||
raise ValueError("Cannot use batch queries with interaction saving currently.")
|
||||
|
||||
is_query_valid, msg = validate_queries(query, query_batch)
|
||||
if not is_query_valid:
|
||||
raise ValueError(msg)
|
||||
|
||||
# Load conversation history if enabled
|
||||
conversation_history = ""
|
||||
if session_save:
|
||||
|
|
|
|||
|
|
@ -163,13 +163,16 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
entity_nodes_batch.append(get_entity_nodes_from_triplets(batched_triplets))
|
||||
|
||||
# Remove duplicates and update node access, if it is enabled
|
||||
for batched_entity_nodes in entity_nodes_batch:
|
||||
# from itertools import chain
|
||||
#
|
||||
# flattened_entity_nodes = list(chain.from_iterable(entity_nodes_batch))
|
||||
# entity_nodes = list(set(flattened_entity_nodes))
|
||||
import os
|
||||
|
||||
await update_node_access_timestamps(batched_entity_nodes)
|
||||
if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() == "true":
|
||||
for batched_entity_nodes in entity_nodes_batch:
|
||||
# from itertools import chain
|
||||
#
|
||||
# flattened_entity_nodes = list(chain.from_iterable(entity_nodes_batch))
|
||||
# entity_nodes = list(set(flattened_entity_nodes))
|
||||
|
||||
await update_node_access_timestamps(batched_entity_nodes)
|
||||
else:
|
||||
if len(triplets) == 0:
|
||||
logger.warning("Empty context was provided to the completion")
|
||||
|
|
@ -212,6 +215,16 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
|
||||
- Any: A generated completion based on the query and context provided.
|
||||
"""
|
||||
cache_config = CacheConfig()
|
||||
user = session_user.get()
|
||||
user_id = getattr(user, "id", None)
|
||||
session_save = user_id and cache_config.caching
|
||||
|
||||
if query_batch and session_save:
|
||||
raise ValueError("You cannot use batch queries with session saving currently.")
|
||||
if query_batch and self.save_interaction:
|
||||
raise ValueError("Cannot use batch queries with interaction saving currently.")
|
||||
|
||||
is_query_valid, msg = validate_queries(query, query_batch)
|
||||
if not is_query_valid:
|
||||
raise ValueError(msg)
|
||||
|
|
@ -230,11 +243,6 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
else:
|
||||
context_text = await resolve_edges_to_text(triplets)
|
||||
|
||||
cache_config = CacheConfig()
|
||||
user = session_user.get()
|
||||
user_id = getattr(user, "id", None)
|
||||
session_save = user_id and cache_config.caching
|
||||
|
||||
if session_save:
|
||||
conversation_history = await get_conversation_history(session_id=session_id)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue