feat: adds conversation history to most of the retrievers
This commit is contained in:
parent
9e9489c858
commit
cdbdbfd755
5 changed files with 29 additions and 10 deletions
|
|
@ -6,7 +6,7 @@ from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtract
|
|||
from cognee.infrastructure.context.BaseContextProvider import BaseContextProvider
|
||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
|
||||
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache, get_conversation_history
|
||||
from cognee.context_global_variables import session_user
|
||||
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||
|
||||
|
|
@ -120,6 +120,8 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
session_save = user_id and cache_config.caching
|
||||
|
||||
if session_save:
|
||||
conversation_history = await get_conversation_history(session_id=session_id)
|
||||
|
||||
context_summary, completion = await asyncio.gather(
|
||||
summarize_text(str(context)),
|
||||
generate_completion(
|
||||
|
|
@ -127,6 +129,7 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
conversation_history=conversation_history,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import Any, Optional
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
|
||||
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache, get_conversation_history
|
||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
|
|
@ -104,6 +104,8 @@ class CompletionRetriever(BaseRetriever):
|
|||
session_save = user_id and cache_config.caching
|
||||
|
||||
if session_save:
|
||||
conversation_history = await get_conversation_history(session_id=session_id)
|
||||
|
||||
context_summary, completion = await asyncio.gather(
|
||||
summarize_text(context),
|
||||
generate_completion(
|
||||
|
|
@ -112,6 +114,7 @@ class CompletionRetriever(BaseRetriever):
|
|||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
conversation_history=conversation_history,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
|
||||
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache, get_conversation_history
|
||||
from cognee.context_global_variables import session_user
|
||||
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||
|
||||
|
|
@ -129,6 +129,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
session_save = user_id and cache_config.caching
|
||||
|
||||
if session_save:
|
||||
conversation_history = await get_conversation_history(session_id=session_id)
|
||||
|
||||
context_summary, completion = await asyncio.gather(
|
||||
summarize_text(context_text),
|
||||
generate_completion(
|
||||
|
|
@ -137,6 +139,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
conversation_history=conversation_history,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from cognee.shared.logging_utils import get_logger
|
|||
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
|
||||
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache, get_conversation_history
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
||||
from cognee.context_global_variables import session_user
|
||||
|
|
@ -92,6 +92,16 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
followup_question = ""
|
||||
triplets = []
|
||||
completion = ""
|
||||
|
||||
# Retrieve conversation history if session saving is enabled
|
||||
cache_config = CacheConfig()
|
||||
user = session_user.get()
|
||||
user_id = getattr(user, "id", None)
|
||||
session_save = user_id and cache_config.caching
|
||||
|
||||
conversation_history = ""
|
||||
if session_save:
|
||||
conversation_history = await get_conversation_history(session_id=session_id)
|
||||
|
||||
for round_idx in range(max_iter + 1):
|
||||
if round_idx == 0:
|
||||
|
|
@ -110,6 +120,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
conversation_history=conversation_history if session_save else None,
|
||||
)
|
||||
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
||||
if round_idx < max_iter:
|
||||
|
|
@ -147,11 +158,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
)
|
||||
|
||||
# Save to session cache
|
||||
cache_config = CacheConfig()
|
||||
user = session_user.get()
|
||||
user_id = getattr(user, "id", None)
|
||||
|
||||
if user_id and cache_config.caching:
|
||||
if session_save:
|
||||
context_summary = await summarize_text(context_text)
|
||||
await save_to_session_cache(
|
||||
query=query,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Any, Optional, List, Type
|
|||
from operator import itemgetter
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
|
||||
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache, get_conversation_history
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.llm.prompts import render_prompt
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
|
|
@ -171,6 +171,8 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
session_save = user_id and cache_config.caching
|
||||
|
||||
if session_save:
|
||||
conversation_history = await get_conversation_history(session_id=session_id)
|
||||
|
||||
context_summary, completion = await asyncio.gather(
|
||||
summarize_text(context),
|
||||
generate_completion(
|
||||
|
|
@ -178,6 +180,7 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
conversation_history=conversation_history,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue