feat: adds conversation history to context if caching is enabled
This commit is contained in:
parent
48c832bf5f
commit
9e9489c858
3 changed files with 94 additions and 3 deletions
|
|
@ -4,14 +4,16 @@ from uuid import NAMESPACE_OID, uuid5
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
from cognee.modules.users.methods import get_default_user
|
|
||||||
from cognee.tasks.storage import add_data_points
|
from cognee.tasks.storage import add_data_points
|
||||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||||||
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
|
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
|
||||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
||||||
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||||
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
|
from cognee.modules.retrieval.utils.session_cache import (
|
||||||
|
save_to_session_cache,
|
||||||
|
get_conversation_history,
|
||||||
|
)
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
|
from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
|
||||||
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
|
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
|
||||||
|
|
@ -168,6 +170,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
session_save = user_id and cache_config.caching
|
session_save = user_id and cache_config.caching
|
||||||
|
|
||||||
if session_save:
|
if session_save:
|
||||||
|
conversation_history = await get_conversation_history(session_id=session_id)
|
||||||
|
|
||||||
context_summary, completion = await asyncio.gather(
|
context_summary, completion = await asyncio.gather(
|
||||||
summarize_text(context_text),
|
summarize_text(context_text),
|
||||||
generate_completion(
|
generate_completion(
|
||||||
|
|
@ -176,6 +180,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_prompt_path=self.user_prompt_path,
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
system_prompt=self.system_prompt,
|
system_prompt=self.system_prompt,
|
||||||
|
conversation_history=conversation_history,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -9,12 +9,17 @@ async def generate_completion(
|
||||||
user_prompt_path: str,
|
user_prompt_path: str,
|
||||||
system_prompt_path: str,
|
system_prompt_path: str,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
|
conversation_history: Optional[str] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generates a completion using LLM with given context and prompts."""
|
"""Generates a completion using LLM with given context and prompts."""
|
||||||
args = {"question": query, "context": context}
|
args = {"question": query, "context": context}
|
||||||
user_prompt = render_prompt(user_prompt_path, args)
|
user_prompt = render_prompt(user_prompt_path, args)
|
||||||
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
|
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
|
||||||
|
|
||||||
|
if conversation_history:
|
||||||
|
#:TODO: I would separate the history and put it into the system prompt but we have to test what works best with longer convos
|
||||||
|
system_prompt = conversation_history + "\nTASK:" + system_prompt
|
||||||
|
|
||||||
return await LLMGateway.acreate_structured_output(
|
return await LLMGateway.acreate_structured_output(
|
||||||
text_input=user_prompt,
|
text_input=user_prompt,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Optional
|
from typing import Optional, List, Dict, Any
|
||||||
from cognee.context_global_variables import session_user
|
from cognee.context_global_variables import session_user
|
||||||
from cognee.infrastructure.databases.cache.config import CacheConfig
|
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||||
from cognee.infrastructure.databases.exceptions import CacheConnectionError
|
from cognee.infrastructure.databases.exceptions import CacheConnectionError
|
||||||
|
|
@ -73,3 +73,84 @@ async def save_to_session_cache(
|
||||||
f"Unexpected error saving to session cache: {type(e).__name__}: {str(e)}. Continuing without caching."
|
f"Unexpected error saving to session cache: {type(e).__name__}: {str(e)}. Continuing without caching."
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def get_conversation_history(
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Retrieves conversation history from cache and formats it as text.
|
||||||
|
|
||||||
|
Returns formatted conversation history with time, question, context, and answer
|
||||||
|
for the last N Q&A pairs (N is determined by cache engine default).
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- session_id (Optional[str]): Session identifier. Defaults to 'default_session' if None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- str: Formatted conversation history string, or empty string if no history or error.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
-------
|
||||||
|
|
||||||
|
Previous conversation:
|
||||||
|
|
||||||
|
[2024-01-15 10:30:45]
|
||||||
|
QUESTION: What is X?
|
||||||
|
CONTEXT: X is a concept...
|
||||||
|
ANSWER: X is...
|
||||||
|
|
||||||
|
[2024-01-15 10:31:20]
|
||||||
|
QUESTION: How does Y work?
|
||||||
|
CONTEXT: Y is related to...
|
||||||
|
ANSWER: Y works by...
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cache_config = CacheConfig()
|
||||||
|
user = session_user.get()
|
||||||
|
user_id = getattr(user, "id", None)
|
||||||
|
|
||||||
|
if not (user_id and cache_config.caching):
|
||||||
|
logger.debug("Session caching disabled or user not authenticated")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if session_id is None:
|
||||||
|
session_id = "default_session"
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
|
||||||
|
|
||||||
|
cache_engine = get_cache_engine()
|
||||||
|
|
||||||
|
if cache_engine is None:
|
||||||
|
logger.warning("Cache engine not available, skipping conversation history retrieval")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
history_entries = await cache_engine.get_latest_qa(str(user_id), session_id)
|
||||||
|
|
||||||
|
if not history_entries:
|
||||||
|
logger.debug("No conversation history found")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
history_text = "Previous conversation:\n\n"
|
||||||
|
for entry in history_entries:
|
||||||
|
history_text += f"[{entry.get('time', 'Unknown time')}]\n"
|
||||||
|
history_text += f"QUESTION: {entry.get('question', '')}\n"
|
||||||
|
history_text += f"CONTEXT: {entry.get('context', '')}\n"
|
||||||
|
history_text += f"ANSWER: {entry.get('answer', '')}\n\n"
|
||||||
|
|
||||||
|
logger.debug(f"Retrieved {len(history_entries)} conversation history entries")
|
||||||
|
return history_text
|
||||||
|
|
||||||
|
except CacheConnectionError as e:
|
||||||
|
logger.warning(f"Cache unavailable, continuing without conversation history: {e.message}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Unexpected error retrieving conversation history: {type(e).__name__}: {str(e)}"
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue