fix: raise error for query batch plus sessions and save interaction
This commit is contained in:
parent
05b5add480
commit
77fd18a60c
3 changed files with 40 additions and 22 deletions
|
|
@ -104,6 +104,18 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
- List[str]: A list containing the generated answer based on the query and the
|
- List[str]: A list containing the generated answer based on the query and the
|
||||||
extended context.
|
extended context.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Check if we need to generate context summary for caching
|
||||||
|
cache_config = CacheConfig()
|
||||||
|
user = session_user.get()
|
||||||
|
user_id = getattr(user, "id", None)
|
||||||
|
session_save = user_id and cache_config.caching
|
||||||
|
|
||||||
|
if query_batch and session_save:
|
||||||
|
raise ValueError("You cannot use batch queries with session saving currently.")
|
||||||
|
if query_batch and self.save_interaction:
|
||||||
|
raise ValueError("Cannot use batch queries with interaction saving currently.")
|
||||||
|
|
||||||
is_query_valid, msg = validate_queries(query, query_batch)
|
is_query_valid, msg = validate_queries(query, query_batch)
|
||||||
if not is_query_valid:
|
if not is_query_valid:
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
@ -160,7 +172,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
prev_sizes = [
|
prev_sizes = [
|
||||||
len(batched_query_state.triplets)
|
len(batched_query_state.triplets)
|
||||||
for batched_query_state in finished_queries_states.values()
|
for batched_query_state in finished_queries_states.values()
|
||||||
if not batched_query_state.finished_extending_context
|
|
||||||
]
|
]
|
||||||
|
|
||||||
completions = await asyncio.gather(
|
completions = await asyncio.gather(
|
||||||
|
|
@ -203,7 +214,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
new_sizes = [
|
new_sizes = [
|
||||||
len(batched_query_state.triplets)
|
len(batched_query_state.triplets)
|
||||||
for batched_query_state in finished_queries_states.values()
|
for batched_query_state in finished_queries_states.values()
|
||||||
if not batched_query_state.finished_extending_context
|
|
||||||
]
|
]
|
||||||
|
|
||||||
for batched_query, prev_size, new_size in zip(query_batch, prev_sizes, new_sizes):
|
for batched_query, prev_size, new_size in zip(query_batch, prev_sizes, new_sizes):
|
||||||
|
|
@ -218,12 +228,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
|
|
||||||
round_idx += 1
|
round_idx += 1
|
||||||
|
|
||||||
# Check if we need to generate context summary for caching
|
|
||||||
cache_config = CacheConfig()
|
|
||||||
user = session_user.get()
|
|
||||||
user_id = getattr(user, "id", None)
|
|
||||||
session_save = user_id and cache_config.caching
|
|
||||||
|
|
||||||
completion_batch = []
|
completion_batch = []
|
||||||
|
|
||||||
if session_save:
|
if session_save:
|
||||||
|
|
|
||||||
|
|
@ -318,9 +318,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
|
|
||||||
- List[str]: A list containing the generated answer to the user's query.
|
- List[str]: A list containing the generated answer to the user's query.
|
||||||
"""
|
"""
|
||||||
is_query_valid, msg = validate_queries(query, query_batch)
|
|
||||||
if not is_query_valid:
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
# Check if session saving is enabled
|
# Check if session saving is enabled
|
||||||
cache_config = CacheConfig()
|
cache_config = CacheConfig()
|
||||||
|
|
@ -328,6 +325,15 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
user_id = getattr(user, "id", None)
|
user_id = getattr(user, "id", None)
|
||||||
session_save = user_id and cache_config.caching
|
session_save = user_id and cache_config.caching
|
||||||
|
|
||||||
|
if query_batch and session_save:
|
||||||
|
raise ValueError("You cannot use batch queries with session saving currently.")
|
||||||
|
if query_batch and self.save_interaction:
|
||||||
|
raise ValueError("Cannot use batch queries with interaction saving currently.")
|
||||||
|
|
||||||
|
is_query_valid, msg = validate_queries(query, query_batch)
|
||||||
|
if not is_query_valid:
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
# Load conversation history if enabled
|
# Load conversation history if enabled
|
||||||
conversation_history = ""
|
conversation_history = ""
|
||||||
if session_save:
|
if session_save:
|
||||||
|
|
|
||||||
|
|
@ -163,13 +163,16 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
entity_nodes_batch.append(get_entity_nodes_from_triplets(batched_triplets))
|
entity_nodes_batch.append(get_entity_nodes_from_triplets(batched_triplets))
|
||||||
|
|
||||||
# Remove duplicates and update node access, if it is enabled
|
# Remove duplicates and update node access, if it is enabled
|
||||||
for batched_entity_nodes in entity_nodes_batch:
|
import os
|
||||||
# from itertools import chain
|
|
||||||
#
|
|
||||||
# flattened_entity_nodes = list(chain.from_iterable(entity_nodes_batch))
|
|
||||||
# entity_nodes = list(set(flattened_entity_nodes))
|
|
||||||
|
|
||||||
await update_node_access_timestamps(batched_entity_nodes)
|
if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() == "true":
|
||||||
|
for batched_entity_nodes in entity_nodes_batch:
|
||||||
|
# from itertools import chain
|
||||||
|
#
|
||||||
|
# flattened_entity_nodes = list(chain.from_iterable(entity_nodes_batch))
|
||||||
|
# entity_nodes = list(set(flattened_entity_nodes))
|
||||||
|
|
||||||
|
await update_node_access_timestamps(batched_entity_nodes)
|
||||||
else:
|
else:
|
||||||
if len(triplets) == 0:
|
if len(triplets) == 0:
|
||||||
logger.warning("Empty context was provided to the completion")
|
logger.warning("Empty context was provided to the completion")
|
||||||
|
|
@ -212,6 +215,16 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
|
|
||||||
- Any: A generated completion based on the query and context provided.
|
- Any: A generated completion based on the query and context provided.
|
||||||
"""
|
"""
|
||||||
|
cache_config = CacheConfig()
|
||||||
|
user = session_user.get()
|
||||||
|
user_id = getattr(user, "id", None)
|
||||||
|
session_save = user_id and cache_config.caching
|
||||||
|
|
||||||
|
if query_batch and session_save:
|
||||||
|
raise ValueError("You cannot use batch queries with session saving currently.")
|
||||||
|
if query_batch and self.save_interaction:
|
||||||
|
raise ValueError("Cannot use batch queries with interaction saving currently.")
|
||||||
|
|
||||||
is_query_valid, msg = validate_queries(query, query_batch)
|
is_query_valid, msg = validate_queries(query, query_batch)
|
||||||
if not is_query_valid:
|
if not is_query_valid:
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
@ -230,11 +243,6 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
else:
|
else:
|
||||||
context_text = await resolve_edges_to_text(triplets)
|
context_text = await resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
cache_config = CacheConfig()
|
|
||||||
user = session_user.get()
|
|
||||||
user_id = getattr(user, "id", None)
|
|
||||||
session_save = user_id and cache_config.caching
|
|
||||||
|
|
||||||
if session_save:
|
if session_save:
|
||||||
conversation_history = await get_conversation_history(session_id=session_id)
|
conversation_history = await get_conversation_history(session_id=session_id)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue