feat: centralizes session caching in util function
This commit is contained in:
parent
8454389a7d
commit
0e4c4907e9
2 changed files with 54 additions and 14 deletions
|
|
@ -11,6 +11,7 @@ from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subcla
|
||||||
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
|
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
|
||||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
||||||
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||||
|
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
|
from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
|
||||||
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
|
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
|
||||||
|
|
@ -18,10 +19,7 @@ from cognee.modules.engine.models.node_set import NodeSet
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.context_global_variables import session_user
|
from cognee.context_global_variables import session_user
|
||||||
from cognee.infrastructure.databases.cache.config import CacheConfig
|
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("GraphCompletionRetriever")
|
logger = get_logger("GraphCompletionRetriever")
|
||||||
cache_config = CacheConfig()
|
|
||||||
|
|
||||||
|
|
||||||
class GraphCompletionRetriever(BaseGraphRetriever):
|
class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
|
|
@ -163,7 +161,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
|
|
||||||
context_text = await resolve_edges_to_text(triplets)
|
context_text = await resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
# Check if we need to generate context summary for caching
|
cache_config = CacheConfig()
|
||||||
user = session_user.get()
|
user = session_user.get()
|
||||||
user_id = getattr(user, "id", None)
|
user_id = getattr(user, "id", None)
|
||||||
session_save = user_id and cache_config.caching
|
session_save = user_id and cache_config.caching
|
||||||
|
|
@ -194,17 +192,11 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
)
|
)
|
||||||
|
|
||||||
if session_save:
|
if session_save:
|
||||||
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
|
await save_to_session_cache(
|
||||||
|
query=query,
|
||||||
cache_engine = get_cache_engine()
|
context_summary=context_summary,
|
||||||
if session_id is None:
|
|
||||||
session_id = "default_session"
|
|
||||||
await cache_engine.add_qa(
|
|
||||||
str(user_id),
|
|
||||||
session_id=session_id,
|
|
||||||
question=query,
|
|
||||||
context=context_summary,
|
|
||||||
answer=completion,
|
answer=completion,
|
||||||
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [completion]
|
return [completion]
|
||||||
|
|
|
||||||
48
cognee/modules/retrieval/utils/session_cache.py
Normal file
48
cognee/modules/retrieval/utils/session_cache.py
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
from typing import Optional
|
||||||
|
from cognee.context_global_variables import session_user
|
||||||
|
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def save_to_session_cache(
|
||||||
|
query: str,
|
||||||
|
context_summary: str,
|
||||||
|
answer: str,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Saves Q&A interaction to the session cache if user is authenticated and caching is enabled.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The user's query/question.
|
||||||
|
- context_summary (str): Summarized context used for generating the answer.
|
||||||
|
- answer (str): The generated answer/completion.
|
||||||
|
- session_id (Optional[str]): Session identifier. Defaults to 'default_session' if None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- None: This function performs a side effect (saving to cache) and returns nothing.
|
||||||
|
"""
|
||||||
|
cache_config = CacheConfig()
|
||||||
|
user = session_user.get()
|
||||||
|
user_id = getattr(user, "id", None)
|
||||||
|
|
||||||
|
if not (user_id and cache_config.caching):
|
||||||
|
return
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
|
||||||
|
|
||||||
|
cache_engine = get_cache_engine()
|
||||||
|
if session_id is None:
|
||||||
|
session_id = "default_session"
|
||||||
|
|
||||||
|
await cache_engine.add_qa(
|
||||||
|
str(user_id),
|
||||||
|
session_id=session_id,
|
||||||
|
question=query,
|
||||||
|
context=context_summary,
|
||||||
|
answer=answer,
|
||||||
|
)
|
||||||
|
|
||||||
Loading…
Add table
Reference in a new issue