feat: add session save to retrievers where actual completion is used

This commit is contained in:
hajdul88 2025-10-16 15:07:15 +02:00
parent 36fd44dab2
commit abe4dfa69a
5 changed files with 169 additions and 32 deletions

View file

@@ -1,10 +1,14 @@
import asyncio
from typing import Any, Optional, List from typing import Any, Optional, List
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
from cognee.infrastructure.context.BaseContextProvider import BaseContextProvider from cognee.infrastructure.context.BaseContextProvider import BaseContextProvider
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.completion import generate_completion from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
from cognee.context_global_variables import session_user
from cognee.infrastructure.databases.cache.config import CacheConfig
logger = get_logger("entity_completion_retriever") logger = get_logger("entity_completion_retriever")
@@ -109,12 +113,38 @@ class EntityCompletionRetriever(BaseRetriever):
if context is None: if context is None:
return ["No relevant entities found for the query."] return ["No relevant entities found for the query."]
completion = await generate_completion( # Check if we need to generate context summary for caching
query=query, cache_config = CacheConfig()
context=context, user = session_user.get()
user_prompt_path=self.user_prompt_path, user_id = getattr(user, "id", None)
system_prompt_path=self.system_prompt_path, session_save = user_id and cache_config.caching
)
if session_save:
context_summary, completion = await asyncio.gather(
summarize_text(str(context)),
generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
),
)
else:
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
if session_save:
await save_to_session_cache(
query=query,
context_summary=context_summary,
answer=completion,
session_id=session_id,
)
return [completion] return [completion]
except Exception as e: except Exception as e:

View file

@@ -1,11 +1,15 @@
import asyncio
from typing import Any, Optional from typing import Any, Optional
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.utils.completion import generate_completion from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.context_global_variables import session_user
from cognee.infrastructure.databases.cache.config import CacheConfig
logger = get_logger("CompletionRetriever") logger = get_logger("CompletionRetriever")
@@ -93,11 +97,38 @@ class CompletionRetriever(BaseRetriever):
if context is None: if context is None:
context = await self.get_context(query) context = await self.get_context(query)
completion = await generate_completion( # Check if we need to generate context summary for caching
query=query, cache_config = CacheConfig()
context=context, user = session_user.get()
user_prompt_path=self.user_prompt_path, user_id = getattr(user, "id", None)
system_prompt_path=self.system_prompt_path, session_save = user_id and cache_config.caching
system_prompt=self.system_prompt,
) if session_save:
context_summary, completion = await asyncio.gather(
summarize_text(context),
generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
),
)
else:
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
if session_save:
await save_to_session_cache(
query=query,
context_summary=context_summary,
answer=completion,
session_id=session_id,
)
return completion return completion

View file

@@ -1,8 +1,12 @@
import asyncio
from typing import Optional, List, Type from typing import Optional, List, Type
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import generate_completion from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
from cognee.context_global_variables import session_user
from cognee.infrastructure.databases.cache.config import CacheConfig
logger = get_logger() logger = get_logger()
@@ -118,17 +122,43 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
round_idx += 1 round_idx += 1
completion = await generate_completion( # Check if we need to generate context summary for caching
query=query, cache_config = CacheConfig()
context=context_text, user = session_user.get()
user_prompt_path=self.user_prompt_path, user_id = getattr(user, "id", None)
system_prompt_path=self.system_prompt_path, session_save = user_id and cache_config.caching
system_prompt=self.system_prompt,
) if session_save:
context_summary, completion = await asyncio.gather(
summarize_text(context_text),
generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
),
)
else:
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
if self.save_interaction and context_text and triplets and completion: if self.save_interaction and context_text and triplets and completion:
await self.save_qa( await self.save_qa(
question=query, answer=completion, context=context_text, triplets=triplets question=query, answer=completion, context=context_text, triplets=triplets
) )
if session_save:
await save_to_session_cache(
query=query,
context_summary=context_summary,
answer=completion,
session_id=session_id,
)
return [completion] return [completion]

View file

@@ -1,11 +1,15 @@
import asyncio
from typing import Optional, List, Type, Any from typing import Optional, List, Type, Any
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import generate_completion from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
from cognee.context_global_variables import session_user
from cognee.infrastructure.databases.cache.config import CacheConfig
logger = get_logger() logger = get_logger()
@@ -142,4 +146,18 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context_text, triplets=triplets question=query, answer=completion, context=context_text, triplets=triplets
) )
# Save to session cache
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
if user_id and cache_config.caching:
context_summary = await summarize_text(context_text)
await save_to_session_cache(
query=query,
context_summary=context_summary,
answer=completion,
session_id=session_id,
)
return [completion] return [completion]

View file

@@ -1,16 +1,19 @@
import os import os
import asyncio
from typing import Any, Optional, List, Type from typing import Any, Optional, List, Type
from operator import itemgetter from operator import itemgetter
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.utils.completion import generate_completion from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.llm.prompts import render_prompt from cognee.infrastructure.llm.prompts import render_prompt
from cognee.infrastructure.llm import LLMGateway from cognee.infrastructure.llm import LLMGateway
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.context_global_variables import session_user
from cognee.infrastructure.databases.cache.config import CacheConfig
from cognee.tasks.temporal_graph.models import QueryInterval from cognee.tasks.temporal_graph.models import QueryInterval
@@ -161,11 +164,36 @@ class TemporalRetriever(GraphCompletionRetriever):
context = await self.get_context(query=query) context = await self.get_context(query=query)
if context: if context:
completion = await generate_completion( # Check if we need to generate context summary for caching
query=query, cache_config = CacheConfig()
context=context, user = session_user.get()
user_prompt_path=self.user_prompt_path, user_id = getattr(user, "id", None)
system_prompt_path=self.system_prompt_path, session_save = user_id and cache_config.caching
)
if session_save:
context_summary, completion = await asyncio.gather(
summarize_text(context),
generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
),
)
else:
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
if session_save:
await save_to_session_cache(
query=query,
context_summary=context_summary,
answer=completion,
session_id=session_id,
)
return [completion] return [completion]