diff --git a/cognee/modules/retrieval/EntityCompletionRetriever.py b/cognee/modules/retrieval/EntityCompletionRetriever.py
index ed66038c2..759f18762 100644
--- a/cognee/modules/retrieval/EntityCompletionRetriever.py
+++ b/cognee/modules/retrieval/EntityCompletionRetriever.py
@@ -1,10 +1,14 @@
+import asyncio
 from typing import Any, Optional, List
 
 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
 from cognee.infrastructure.context.BaseContextProvider import BaseContextProvider
 from cognee.modules.retrieval.base_retriever import BaseRetriever
-from cognee.modules.retrieval.utils.completion import generate_completion
+from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
+from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
+from cognee.context_global_variables import session_user
+from cognee.infrastructure.databases.cache.config import CacheConfig
 
 logger = get_logger("entity_completion_retriever")
 
@@ -109,12 +113,38 @@ class EntityCompletionRetriever(BaseRetriever):
             if context is None:
                 return ["No relevant entities found for the query."]
 
-            completion = await generate_completion(
-                query=query,
-                context=context,
-                user_prompt_path=self.user_prompt_path,
-                system_prompt_path=self.system_prompt_path,
-            )
+            # Check if we need to generate context summary for caching
+            cache_config = CacheConfig()
+            user = session_user.get()
+            user_id = getattr(user, "id", None)
+            session_save = user_id and cache_config.caching
+
+            if session_save:
+                context_summary, completion = await asyncio.gather(
+                    summarize_text(str(context)),
+                    generate_completion(
+                        query=query,
+                        context=context,
+                        user_prompt_path=self.user_prompt_path,
+                        system_prompt_path=self.system_prompt_path,
+                    ),
+                )
+            else:
+                completion = await generate_completion(
+                    query=query,
+                    context=context,
+                    user_prompt_path=self.user_prompt_path,
+                    system_prompt_path=self.system_prompt_path,
+                )
+
+            if session_save:
+                await save_to_session_cache(
+                    query=query,
+                    context_summary=context_summary,
+                    answer=completion,
+                    session_id=session_id,
+                )
+
             return [completion]
 
         except Exception as e:
diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py
index dcddccef8..4263a6f64 100644
--- a/cognee/modules/retrieval/completion_retriever.py
+++ b/cognee/modules/retrieval/completion_retriever.py
@@ -1,11 +1,15 @@
+import asyncio
 from typing import Any, Optional
 
 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.vector import get_vector_engine
-from cognee.modules.retrieval.utils.completion import generate_completion
+from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
+from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions.exceptions import NoDataError
 from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
+from cognee.context_global_variables import session_user
+from cognee.infrastructure.databases.cache.config import CacheConfig
 
 logger = get_logger("CompletionRetriever")
 
@@ -93,11 +97,38 @@ class CompletionRetriever(BaseRetriever):
         if context is None:
             context = await self.get_context(query)
 
-        completion = await generate_completion(
-            query=query,
-            context=context,
-            user_prompt_path=self.user_prompt_path,
-            system_prompt_path=self.system_prompt_path,
-            system_prompt=self.system_prompt,
-        )
+        # Check if we need to generate context summary for caching
+        cache_config = CacheConfig()
+        user = session_user.get()
+        user_id = getattr(user, "id", None)
+        session_save = user_id and cache_config.caching
+
+        if session_save:
+            context_summary, completion = await asyncio.gather(
+                summarize_text(context),
+                generate_completion(
+                    query=query,
+                    context=context,
+                    user_prompt_path=self.user_prompt_path,
+                    system_prompt_path=self.system_prompt_path,
+                    system_prompt=self.system_prompt,
+                ),
+            )
+        else:
+            completion = await generate_completion(
+                query=query,
+                context=context,
+                user_prompt_path=self.user_prompt_path,
+                system_prompt_path=self.system_prompt_path,
+                system_prompt=self.system_prompt,
+            )
+
+        if session_save:
+            await save_to_session_cache(
+                query=query,
+                context_summary=context_summary,
+                answer=completion,
+                session_id=session_id,
+            )
+
         return completion
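Both hunks above gate session caching on the same two conditions: a user bound to the current request through the `session_user` context variable, and the `caching` flag on `CacheConfig`. The sketch below is a minimal, self-contained illustration of that gate using stand-in names; the `ContextVar` and config class here are assumptions for illustration, not cognee's actual `session_user` or `CacheConfig` implementation.

```python
# Illustration only: stand-in ContextVar and config mirroring the gate in the diff.
from contextvars import ContextVar
from dataclasses import dataclass
from typing import Optional

session_user: ContextVar[Optional[object]] = ContextVar("session_user", default=None)


@dataclass
class CacheConfig:
    caching: bool = False  # assumed default; the real flag comes from cognee's cache config


def session_cache_enabled() -> bool:
    """True only when a session user is bound and caching is switched on."""
    user = session_user.get()
    user_id = getattr(user, "id", None)  # tolerates requests with no user bound
    return bool(user_id and CacheConfig().caching)
```

Because the context variable defaults to `None` and the check uses `getattr(user, "id", None)`, anonymous or background calls fall through to the non-caching path without raising.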
diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py
index 3398d3fc2..79f6acbf3 100644
--- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py
@@ -1,8 +1,12 @@
+import asyncio
 from typing import Optional, List, Type
 from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
-from cognee.modules.retrieval.utils.completion import generate_completion
+from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
+from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
+from cognee.context_global_variables import session_user
+from cognee.infrastructure.databases.cache.config import CacheConfig
 
 logger = get_logger()
 
@@ -118,17 +122,43 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
 
             round_idx += 1
 
-        completion = await generate_completion(
-            query=query,
-            context=context_text,
-            user_prompt_path=self.user_prompt_path,
-            system_prompt_path=self.system_prompt_path,
-            system_prompt=self.system_prompt,
-        )
+        # Check if we need to generate context summary for caching
+        cache_config = CacheConfig()
+        user = session_user.get()
+        user_id = getattr(user, "id", None)
+        session_save = user_id and cache_config.caching
+
+        if session_save:
+            context_summary, completion = await asyncio.gather(
+                summarize_text(context_text),
+                generate_completion(
+                    query=query,
+                    context=context_text,
+                    user_prompt_path=self.user_prompt_path,
+                    system_prompt_path=self.system_prompt_path,
+                    system_prompt=self.system_prompt,
+                ),
+            )
+        else:
+            completion = await generate_completion(
+                query=query,
+                context=context_text,
+                user_prompt_path=self.user_prompt_path,
+                system_prompt_path=self.system_prompt_path,
+                system_prompt=self.system_prompt,
+            )
 
         if self.save_interaction and context_text and triplets and completion:
             await self.save_qa(
                 question=query, answer=completion, context=context_text, triplets=triplets
             )
 
+        if session_save:
+            await save_to_session_cache(
+                query=query,
+                context_summary=context_summary,
+                answer=completion,
+                session_id=session_id,
+            )
+
         return [completion]
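The summarize/gather/save block above is repeated almost verbatim in each retriever. A possible follow-up, not part of this change, would be to factor it into a shared coroutine; a rough sketch under that assumption, reusing the call signatures shown in the diff and a hypothetical helper name:

```python
# Hypothetical helper, not part of this diff. Callers would still perform the
# session_save gate check before using it.
import asyncio

from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache


async def complete_and_cache(query: str, context: str, session_id, **completion_kwargs) -> str:
    """Run summarization concurrently with the completion, then cache the pair."""
    # asyncio.gather overlaps the summary's LLM round trip with the completion call.
    context_summary, completion = await asyncio.gather(
        summarize_text(context),
        generate_completion(query=query, context=context, **completion_kwargs),
    )
    await save_to_session_cache(
        query=query,
        context_summary=context_summary,
        answer=completion,
        session_id=session_id,
    )
    return completion
```

With such a helper, each `if session_save:` branch would reduce to a single `completion = await complete_and_cache(...)` call.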
diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py
index 715a1406c..aa221a9b0 100644
--- a/cognee/modules/retrieval/graph_completion_cot_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py
@@ -1,11 +1,15 @@
+import asyncio
 from typing import Optional, List, Type, Any
 
 from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
-from cognee.modules.retrieval.utils.completion import generate_completion
+from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
+from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
 from cognee.infrastructure.llm.LLMGateway import LLMGateway
 from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
+from cognee.context_global_variables import session_user
+from cognee.infrastructure.databases.cache.config import CacheConfig
 
 logger = get_logger()
 
@@ -142,4 +146,18 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
                 question=query, answer=completion, context=context_text, triplets=triplets
             )
 
+        # Save to session cache
+        cache_config = CacheConfig()
+        user = session_user.get()
+        user_id = getattr(user, "id", None)
+
+        if user_id and cache_config.caching:
+            context_summary = await summarize_text(context_text)
+            await save_to_session_cache(
+                query=query,
+                context_summary=context_summary,
+                answer=completion,
+                session_id=session_id,
+            )
+
         return [completion]
diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py
index 81f0deaee..1367076e7 100644
--- a/cognee/modules/retrieval/temporal_retriever.py
+++ b/cognee/modules/retrieval/temporal_retriever.py
@@ -1,16 +1,19 @@
 import os
+import asyncio
 
 from typing import Any, Optional, List, Type
 from operator import itemgetter
 
 from cognee.infrastructure.databases.vector import get_vector_engine
-from cognee.modules.retrieval.utils.completion import generate_completion
+from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
+from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.llm.prompts import render_prompt
 from cognee.infrastructure.llm import LLMGateway
 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
 from cognee.shared.logging_utils import get_logger
-
+from cognee.context_global_variables import session_user
+from cognee.infrastructure.databases.cache.config import CacheConfig
 from cognee.tasks.temporal_graph.models import QueryInterval
 
 
@@ -161,11 +164,36 @@ class TemporalRetriever(GraphCompletionRetriever):
         context = await self.get_context(query=query)
 
         if context:
-            completion = await generate_completion(
-                query=query,
-                context=context,
-                user_prompt_path=self.user_prompt_path,
-                system_prompt_path=self.system_prompt_path,
-            )
+            # Check if we need to generate context summary for caching
+            cache_config = CacheConfig()
+            user = session_user.get()
+            user_id = getattr(user, "id", None)
+            session_save = user_id and cache_config.caching
+
+            if session_save:
+                context_summary, completion = await asyncio.gather(
+                    summarize_text(context),
+                    generate_completion(
+                        query=query,
+                        context=context,
+                        user_prompt_path=self.user_prompt_path,
+                        system_prompt_path=self.system_prompt_path,
+                    ),
+                )
+            else:
+                completion = await generate_completion(
+                    query=query,
+                    context=context,
+                    user_prompt_path=self.user_prompt_path,
+                    system_prompt_path=self.system_prompt_path,
+                )
+
+            if session_save:
+                await save_to_session_cache(
+                    query=query,
+                    context_summary=context_summary,
+                    answer=completion,
+                    session_id=session_id,
+                )
 
             return [completion]
 
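One behavioural note on the concurrent branches above: `asyncio.gather` with its default `return_exceptions=False` propagates the first exception, so a failed `summarize_text` call would fail the whole request even though the summary only feeds the session cache. A more forgiving variant is sketched below; it is an alternative under stated assumptions with illustrative names, not what this change implements.

```python
# Alternative sketch: treat the cache summary as best-effort so a summarization
# failure cannot take down the user-facing answer.
import asyncio
import logging

logger = logging.getLogger(__name__)


async def gather_with_optional_summary(summary_coro, completion_coro):
    """Return (summary_or_None, completion); only completion errors propagate."""
    context_summary, completion = await asyncio.gather(
        summary_coro, completion_coro, return_exceptions=True
    )
    if isinstance(completion, BaseException):
        raise completion  # the answer itself failed; surface that error
    if isinstance(context_summary, BaseException):
        logger.warning("Context summarization failed, skipping session cache: %s", context_summary)
        context_summary = None
    return context_summary, completion
```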