From 0aa64403c5b82cacef3bbdf60b8574a0e19683cb Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Wed, 15 Oct 2025 17:51:47 +0200 Subject: [PATCH] feat: basic session behavior (only graph completion now just to save) --- .../retrieval/graph_completion_retriever.py | 47 +++++++++++++++---- 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 80815a42a..60609ba56 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, Optional, Type, List from uuid import NAMESPACE_OID, uuid5 @@ -9,14 +10,18 @@ from cognee.modules.graph.utils import resolve_edges_to_text from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search -from cognee.modules.retrieval.utils.completion import generate_completion +from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node from cognee.modules.retrieval.utils.models import CogneeUserInteraction from cognee.modules.engine.models.node_set import NodeSet from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.context_global_variables import session_user +from cognee.infrastructure.databases.cache.config import CacheConfig + logger = get_logger("GraphCompletionRetriever") +cache_config = CacheConfig() class GraphCompletionRetriever(BaseGraphRetriever): @@ -132,6 +137,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): self, query: str, context: Optional[List[Edge]] = None, + session_id: Optional[str] = None, ) -> List[str]: """ Generates a completion using graph connections context based on a query. @@ -155,19 +161,44 @@ class GraphCompletionRetriever(BaseGraphRetriever): context_text = await resolve_edges_to_text(triplets) - completion = await generate_completion( - query=query, - context=context_text, - user_prompt_path=self.user_prompt_path, - system_prompt_path=self.system_prompt_path, - system_prompt=self.system_prompt, - ) + # Check if we need to generate context summary for caching + user = session_user.get() + user_id = getattr(user, "id", None) + session_save = user_id and cache_config.caching + + if session_save: + context_summary, completion = await asyncio.gather( + summarize_text(context_text), + generate_completion( + query=query, + context=context_text, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + ) + ) + else: + completion = await generate_completion( + query=query, + context=context_text, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + ) if self.save_interaction and context and triplets and completion: await self.save_qa( question=query, answer=completion, context=context_text, triplets=triplets ) + if session_save: + from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine + cache_engine = get_cache_engine() + if session_id is None: + session_id = 'default_session' + await cache_engine.add_qa(str(user_id), session_id=session_id, question=query, context=context_summary, answer=completion) + + return [completion] async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None: