fix: raise error when batch queries are combined with session saving or interaction saving

This commit is contained in:
Andrej Milicevic 2026-01-20 11:36:01 +01:00
parent 05b5add480
commit 77fd18a60c
3 changed files with 40 additions and 22 deletions

View file

@ -104,6 +104,18 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
- List[str]: A list containing the generated answer based on the query and the
extended context.
"""
# Check if we need to generate context summary for caching
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
if query_batch and session_save:
raise ValueError("You cannot use batch queries with session saving currently.")
if query_batch and self.save_interaction:
raise ValueError("Cannot use batch queries with interaction saving currently.")
is_query_valid, msg = validate_queries(query, query_batch)
if not is_query_valid:
raise ValueError(msg)
@ -160,7 +172,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
prev_sizes = [
len(batched_query_state.triplets)
for batched_query_state in finished_queries_states.values()
if not batched_query_state.finished_extending_context
]
completions = await asyncio.gather(
@ -203,7 +214,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
new_sizes = [
len(batched_query_state.triplets)
for batched_query_state in finished_queries_states.values()
if not batched_query_state.finished_extending_context
]
for batched_query, prev_size, new_size in zip(query_batch, prev_sizes, new_sizes):
@ -218,12 +228,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
round_idx += 1
# Check if we need to generate context summary for caching
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
completion_batch = []
if session_save:

View file

@ -318,9 +318,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
- List[str]: A list containing the generated answer to the user's query.
"""
is_query_valid, msg = validate_queries(query, query_batch)
if not is_query_valid:
raise ValueError(msg)
# Check if session saving is enabled
cache_config = CacheConfig()
@ -328,6 +325,15 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
if query_batch and session_save:
raise ValueError("You cannot use batch queries with session saving currently.")
if query_batch and self.save_interaction:
raise ValueError("Cannot use batch queries with interaction saving currently.")
is_query_valid, msg = validate_queries(query, query_batch)
if not is_query_valid:
raise ValueError(msg)
# Load conversation history if enabled
conversation_history = ""
if session_save:

View file

@ -163,13 +163,16 @@ class GraphCompletionRetriever(BaseGraphRetriever):
entity_nodes_batch.append(get_entity_nodes_from_triplets(batched_triplets))
# Remove duplicates and update node access, if it is enabled
for batched_entity_nodes in entity_nodes_batch:
# from itertools import chain
#
# flattened_entity_nodes = list(chain.from_iterable(entity_nodes_batch))
# entity_nodes = list(set(flattened_entity_nodes))
import os
await update_node_access_timestamps(batched_entity_nodes)
if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() == "true":
for batched_entity_nodes in entity_nodes_batch:
# from itertools import chain
#
# flattened_entity_nodes = list(chain.from_iterable(entity_nodes_batch))
# entity_nodes = list(set(flattened_entity_nodes))
await update_node_access_timestamps(batched_entity_nodes)
else:
if len(triplets) == 0:
logger.warning("Empty context was provided to the completion")
@ -212,6 +215,16 @@ class GraphCompletionRetriever(BaseGraphRetriever):
- Any: A generated completion based on the query and context provided.
"""
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
if query_batch and session_save:
raise ValueError("You cannot use batch queries with session saving currently.")
if query_batch and self.save_interaction:
raise ValueError("Cannot use batch queries with interaction saving currently.")
is_query_valid, msg = validate_queries(query, query_batch)
if not is_query_valid:
raise ValueError(msg)
@ -230,11 +243,6 @@ class GraphCompletionRetriever(BaseGraphRetriever):
else:
context_text = await resolve_edges_to_text(triplets)
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)