feat: adds session save to retrievers where actual completion is used
This commit is contained in:
parent
36fd44dab2
commit
abe4dfa69a
5 changed files with 169 additions and 32 deletions
|
|
@ -1,10 +1,14 @@
|
||||||
|
import asyncio
|
||||||
from typing import Any, Optional, List
|
from typing import Any, Optional, List
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
|
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
|
||||||
from cognee.infrastructure.context.BaseContextProvider import BaseContextProvider
|
from cognee.infrastructure.context.BaseContextProvider import BaseContextProvider
|
||||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||||
|
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
|
||||||
|
from cognee.context_global_variables import session_user
|
||||||
|
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("entity_completion_retriever")
|
logger = get_logger("entity_completion_retriever")
|
||||||
|
|
@ -109,12 +113,38 @@ class EntityCompletionRetriever(BaseRetriever):
|
||||||
if context is None:
|
if context is None:
|
||||||
return ["No relevant entities found for the query."]
|
return ["No relevant entities found for the query."]
|
||||||
|
|
||||||
completion = await generate_completion(
|
# Check if we need to generate context summary for caching
|
||||||
query=query,
|
cache_config = CacheConfig()
|
||||||
context=context,
|
user = session_user.get()
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_id = getattr(user, "id", None)
|
||||||
system_prompt_path=self.system_prompt_path,
|
session_save = user_id and cache_config.caching
|
||||||
)
|
|
||||||
|
if session_save:
|
||||||
|
context_summary, completion = await asyncio.gather(
|
||||||
|
summarize_text(str(context)),
|
||||||
|
generate_completion(
|
||||||
|
query=query,
|
||||||
|
context=context,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
completion = await generate_completion(
|
||||||
|
query=query,
|
||||||
|
context=context,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
if session_save:
|
||||||
|
await save_to_session_cache(
|
||||||
|
query=query,
|
||||||
|
context_summary=context_summary,
|
||||||
|
answer=completion,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
return [completion]
|
return [completion]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,15 @@
|
||||||
|
import asyncio
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||||
|
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
|
||||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||||
|
from cognee.context_global_variables import session_user
|
||||||
|
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||||
|
|
||||||
logger = get_logger("CompletionRetriever")
|
logger = get_logger("CompletionRetriever")
|
||||||
|
|
||||||
|
|
@ -93,11 +97,38 @@ class CompletionRetriever(BaseRetriever):
|
||||||
if context is None:
|
if context is None:
|
||||||
context = await self.get_context(query)
|
context = await self.get_context(query)
|
||||||
|
|
||||||
completion = await generate_completion(
|
# Check if we need to generate context summary for caching
|
||||||
query=query,
|
cache_config = CacheConfig()
|
||||||
context=context,
|
user = session_user.get()
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_id = getattr(user, "id", None)
|
||||||
system_prompt_path=self.system_prompt_path,
|
session_save = user_id and cache_config.caching
|
||||||
system_prompt=self.system_prompt,
|
|
||||||
)
|
if session_save:
|
||||||
|
context_summary, completion = await asyncio.gather(
|
||||||
|
summarize_text(context),
|
||||||
|
generate_completion(
|
||||||
|
query=query,
|
||||||
|
context=context,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
completion = await generate_completion(
|
||||||
|
query=query,
|
||||||
|
context=context,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
if session_save:
|
||||||
|
await save_to_session_cache(
|
||||||
|
query=query,
|
||||||
|
context_summary=context_summary,
|
||||||
|
answer=completion,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
return completion
|
return completion
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,12 @@
|
||||||
|
import asyncio
|
||||||
from typing import Optional, List, Type
|
from typing import Optional, List, Type
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||||
|
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
|
||||||
|
from cognee.context_global_variables import session_user
|
||||||
|
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
@ -118,17 +122,43 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
|
|
||||||
round_idx += 1
|
round_idx += 1
|
||||||
|
|
||||||
completion = await generate_completion(
|
# Check if we need to generate context summary for caching
|
||||||
query=query,
|
cache_config = CacheConfig()
|
||||||
context=context_text,
|
user = session_user.get()
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_id = getattr(user, "id", None)
|
||||||
system_prompt_path=self.system_prompt_path,
|
session_save = user_id and cache_config.caching
|
||||||
system_prompt=self.system_prompt,
|
|
||||||
)
|
if session_save:
|
||||||
|
context_summary, completion = await asyncio.gather(
|
||||||
|
summarize_text(context_text),
|
||||||
|
generate_completion(
|
||||||
|
query=query,
|
||||||
|
context=context_text,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
completion = await generate_completion(
|
||||||
|
query=query,
|
||||||
|
context=context_text,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
if self.save_interaction and context_text and triplets and completion:
|
if self.save_interaction and context_text and triplets and completion:
|
||||||
await self.save_qa(
|
await self.save_qa(
|
||||||
question=query, answer=completion, context=context_text, triplets=triplets
|
question=query, answer=completion, context=context_text, triplets=triplets
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if session_save:
|
||||||
|
await save_to_session_cache(
|
||||||
|
query=query,
|
||||||
|
context_summary=context_summary,
|
||||||
|
answer=completion,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
return [completion]
|
return [completion]
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,15 @@
|
||||||
|
import asyncio
|
||||||
from typing import Optional, List, Type, Any
|
from typing import Optional, List, Type, Any
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||||
|
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
|
||||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||||
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
||||||
|
from cognee.context_global_variables import session_user
|
||||||
|
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
@ -142,4 +146,18 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
question=query, answer=completion, context=context_text, triplets=triplets
|
question=query, answer=completion, context=context_text, triplets=triplets
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Save to session cache
|
||||||
|
cache_config = CacheConfig()
|
||||||
|
user = session_user.get()
|
||||||
|
user_id = getattr(user, "id", None)
|
||||||
|
|
||||||
|
if user_id and cache_config.caching:
|
||||||
|
context_summary = await summarize_text(context_text)
|
||||||
|
await save_to_session_cache(
|
||||||
|
query=query,
|
||||||
|
context_summary=context_summary,
|
||||||
|
answer=completion,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
return [completion]
|
return [completion]
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,19 @@
|
||||||
import os
|
import os
|
||||||
|
import asyncio
|
||||||
from typing import Any, Optional, List, Type
|
from typing import Any, Optional, List, Type
|
||||||
|
|
||||||
|
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||||
|
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.llm.prompts import render_prompt
|
from cognee.infrastructure.llm.prompts import render_prompt
|
||||||
from cognee.infrastructure.llm import LLMGateway
|
from cognee.infrastructure.llm import LLMGateway
|
||||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.context_global_variables import session_user
|
||||||
|
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||||
|
|
||||||
from cognee.tasks.temporal_graph.models import QueryInterval
|
from cognee.tasks.temporal_graph.models import QueryInterval
|
||||||
|
|
||||||
|
|
@ -161,11 +164,36 @@ class TemporalRetriever(GraphCompletionRetriever):
|
||||||
context = await self.get_context(query=query)
|
context = await self.get_context(query=query)
|
||||||
|
|
||||||
if context:
|
if context:
|
||||||
completion = await generate_completion(
|
# Check if we need to generate context summary for caching
|
||||||
query=query,
|
cache_config = CacheConfig()
|
||||||
context=context,
|
user = session_user.get()
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_id = getattr(user, "id", None)
|
||||||
system_prompt_path=self.system_prompt_path,
|
session_save = user_id and cache_config.caching
|
||||||
)
|
|
||||||
|
if session_save:
|
||||||
|
context_summary, completion = await asyncio.gather(
|
||||||
|
summarize_text(context),
|
||||||
|
generate_completion(
|
||||||
|
query=query,
|
||||||
|
context=context,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
completion = await generate_completion(
|
||||||
|
query=query,
|
||||||
|
context=context,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
if session_save:
|
||||||
|
await save_to_session_cache(
|
||||||
|
query=query,
|
||||||
|
context_summary=context_summary,
|
||||||
|
answer=completion,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
return [completion]
|
return [completion]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue