diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 89e8b80a2..d005759c9 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -104,6 +104,18 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer based on the query and the extended context. """ + + # Check if we need to generate context summary for caching + cache_config = CacheConfig() + user = session_user.get() + user_id = getattr(user, "id", None) + session_save = user_id and cache_config.caching + + if query_batch and session_save: + raise ValueError("You cannot use batch queries with session saving currently.") + if query_batch and self.save_interaction: + raise ValueError("Cannot use batch queries with interaction saving currently.") + is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: raise ValueError(msg) @@ -160,7 +172,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): prev_sizes = [ len(batched_query_state.triplets) for batched_query_state in finished_queries_states.values() - if not batched_query_state.finished_extending_context ] completions = await asyncio.gather( @@ -203,7 +214,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): new_sizes = [ len(batched_query_state.triplets) for batched_query_state in finished_queries_states.values() - if not batched_query_state.finished_extending_context ] for batched_query, prev_size, new_size in zip(query_batch, prev_sizes, new_sizes): @@ -218,12 +228,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): round_idx += 1 - # Check if we need to generate context summary for caching - cache_config = CacheConfig() - user = session_user.get() - user_id = getattr(user, "id", None) - session_save = user_id and cache_config.caching - completion_batch = [] if session_save: diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 5f61d26d0..d42dbda0f 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -318,9 +318,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer to the user's query. """ - is_query_valid, msg = validate_queries(query, query_batch) - if not is_query_valid: - raise ValueError(msg) # Check if session saving is enabled cache_config = CacheConfig() @@ -328,6 +325,15 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): user_id = getattr(user, "id", None) session_save = user_id and cache_config.caching + if query_batch and session_save: + raise ValueError("You cannot use batch queries with session saving currently.") + if query_batch and self.save_interaction: + raise ValueError("Cannot use batch queries with interaction saving currently.") + + is_query_valid, msg = validate_queries(query, query_batch) + if not is_query_valid: + raise ValueError(msg) + # Load conversation history if enabled conversation_history = "" if session_save: diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 711c7ab40..d64267c10 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -163,13 +163,16 @@ class GraphCompletionRetriever(BaseGraphRetriever): entity_nodes_batch.append(get_entity_nodes_from_triplets(batched_triplets)) # Remove duplicates and update node access, if it is enabled - for batched_entity_nodes in entity_nodes_batch: - # from itertools import chain - # - # flattened_entity_nodes = list(chain.from_iterable(entity_nodes_batch)) - # entity_nodes = list(set(flattened_entity_nodes)) + import os - await update_node_access_timestamps(batched_entity_nodes) + if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() == "true": + for batched_entity_nodes in entity_nodes_batch: + # from itertools import chain + # + # flattened_entity_nodes = list(chain.from_iterable(entity_nodes_batch)) + # entity_nodes = list(set(flattened_entity_nodes)) + + await update_node_access_timestamps(batched_entity_nodes) else: if len(triplets) == 0: logger.warning("Empty context was provided to the completion") @@ -212,6 +215,16 @@ class GraphCompletionRetriever(BaseGraphRetriever): - Any: A generated completion based on the query and context provided. """ + cache_config = CacheConfig() + user = session_user.get() + user_id = getattr(user, "id", None) + session_save = user_id and cache_config.caching + + if query_batch and session_save: + raise ValueError("You cannot use batch queries with session saving currently.") + if query_batch and self.save_interaction: + raise ValueError("Cannot use batch queries with interaction saving currently.") + is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: raise ValueError(msg) @@ -230,11 +243,6 @@ class GraphCompletionRetriever(BaseGraphRetriever): else: context_text = await resolve_edges_to_text(triplets) - cache_config = CacheConfig() - user = session_user.get() - user_id = getattr(user, "id", None) - session_save = user_id and cache_config.caching - if session_save: conversation_history = await get_conversation_history(session_id=session_id)