From cdbdbfd7554f17c46cec0551ca981b845022e922 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 16 Oct 2025 18:24:57 +0200 Subject: [PATCH] feat: adds conversation history to most of the retrievers --- .../retrieval/EntityCompletionRetriever.py | 5 ++++- .../modules/retrieval/completion_retriever.py | 5 ++++- ..._completion_context_extension_retriever.py | 5 ++++- .../graph_completion_cot_retriever.py | 19 +++++++++++++------ .../modules/retrieval/temporal_retriever.py | 5 ++++- 5 files changed, 29 insertions(+), 10 deletions(-) diff --git a/cognee/modules/retrieval/EntityCompletionRetriever.py b/cognee/modules/retrieval/EntityCompletionRetriever.py index 759f18762..d9804647a 100644 --- a/cognee/modules/retrieval/EntityCompletionRetriever.py +++ b/cognee/modules/retrieval/EntityCompletionRetriever.py @@ -6,7 +6,7 @@ from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtract from cognee.infrastructure.context.BaseContextProvider import BaseContextProvider from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text -from cognee.modules.retrieval.utils.session_cache import save_to_session_cache +from cognee.modules.retrieval.utils.session_cache import save_to_session_cache, get_conversation_history from cognee.context_global_variables import session_user from cognee.infrastructure.databases.cache.config import CacheConfig @@ -120,6 +120,8 @@ class EntityCompletionRetriever(BaseRetriever): session_save = user_id and cache_config.caching if session_save: + conversation_history = await get_conversation_history(session_id=session_id) + context_summary, completion = await asyncio.gather( summarize_text(str(context)), generate_completion( @@ -127,6 +129,7 @@ class EntityCompletionRetriever(BaseRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + conversation_history=conversation_history, ), ) else: diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index 4263a6f64..b6fc8a48d 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -4,7 +4,7 @@ from typing import Any, Optional from cognee.shared.logging_utils import get_logger from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text -from cognee.modules.retrieval.utils.session_cache import save_to_session_cache +from cognee.modules.retrieval.utils.session_cache import save_to_session_cache, get_conversation_history from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError @@ -104,6 +104,8 @@ class CompletionRetriever(BaseRetriever): session_save = user_id and cache_config.caching if session_save: + conversation_history = await get_conversation_history(session_id=session_id) + context_summary, completion = await asyncio.gather( summarize_text(context), generate_completion( @@ -112,6 +114,7 @@ class CompletionRetriever(BaseRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, + conversation_history=conversation_history, ), ) else: diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 79f6acbf3..a1f7821bf 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -4,7 +4,7 @@ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text -from cognee.modules.retrieval.utils.session_cache import save_to_session_cache +from cognee.modules.retrieval.utils.session_cache import save_to_session_cache, get_conversation_history from cognee.context_global_variables import session_user from cognee.infrastructure.databases.cache.config import CacheConfig @@ -129,6 +129,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): session_save = user_id and cache_config.caching if session_save: + conversation_history = await get_conversation_history(session_id=session_id) + context_summary, completion = await asyncio.gather( summarize_text(context_text), generate_completion( @@ -137,6 +139,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, + conversation_history=conversation_history, ), ) else: diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 2a44e3b15..e95cd717d 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -5,7 +5,7 @@ from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text -from cognee.modules.retrieval.utils.session_cache import save_to_session_cache +from cognee.modules.retrieval.utils.session_cache import save_to_session_cache, get_conversation_history from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt from cognee.context_global_variables import session_user @@ -92,6 +92,16 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): followup_question = "" triplets = [] completion = "" + + # Retrieve conversation history if session saving is enabled + cache_config = CacheConfig() + user = session_user.get() + user_id = getattr(user, "id", None) + session_save = user_id and cache_config.caching + + conversation_history = "" + if session_save: + conversation_history = await get_conversation_history(session_id=session_id) for round_idx in range(max_iter + 1): if round_idx == 0: @@ -110,6 +120,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, + conversation_history=conversation_history if session_save else None, ) logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") if round_idx < max_iter: @@ -147,11 +158,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): ) # Save to session cache - cache_config = CacheConfig() - user = session_user.get() - user_id = getattr(user, "id", None) - - if user_id and cache_config.caching: + if session_save: context_summary = await summarize_text(context_text) await save_to_session_cache( query=query, diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py index 1367076e7..efd4b4749 100644 --- a/cognee/modules/retrieval/temporal_retriever.py +++ b/cognee/modules/retrieval/temporal_retriever.py @@ -6,7 +6,7 @@ from typing import Any, Optional, List, Type from operator import itemgetter from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text -from cognee.modules.retrieval.utils.session_cache import save_to_session_cache +from cognee.modules.retrieval.utils.session_cache import save_to_session_cache, get_conversation_history from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.llm.prompts import render_prompt from cognee.infrastructure.llm import LLMGateway @@ -171,6 +171,8 @@ class TemporalRetriever(GraphCompletionRetriever): session_save = user_id and cache_config.caching if session_save: + conversation_history = await get_conversation_history(session_id=session_id) + context_summary, completion = await asyncio.gather( summarize_text(context), generate_completion( @@ -178,6 +180,7 @@ class TemporalRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + conversation_history=conversation_history, ), ) else: