feat: basic session behavior (only graph completion now just to save)
This commit is contained in:
parent
df6de7b246
commit
0aa64403c5
1 changed files with 39 additions and 8 deletions
|
|
@ -1,3 +1,4 @@
|
||||||
|
import asyncio
|
||||||
from typing import Any, Optional, Type, List
|
from typing import Any, Optional, Type, List
|
||||||
from uuid import NAMESPACE_OID, uuid5
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
|
|
||||||
|
|
@ -9,14 +10,18 @@ from cognee.modules.graph.utils import resolve_edges_to_text
|
||||||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||||||
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
|
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
|
||||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
||||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
|
from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
|
||||||
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
|
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
|
||||||
from cognee.modules.engine.models.node_set import NodeSet
|
from cognee.modules.engine.models.node_set import NodeSet
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.context_global_variables import session_user
|
||||||
|
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("GraphCompletionRetriever")
|
logger = get_logger("GraphCompletionRetriever")
|
||||||
|
cache_config = CacheConfig()
|
||||||
|
|
||||||
|
|
||||||
class GraphCompletionRetriever(BaseGraphRetriever):
|
class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
|
|
@ -132,6 +137,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
context: Optional[List[Edge]] = None,
|
context: Optional[List[Edge]] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Generates a completion using graph connections context based on a query.
|
Generates a completion using graph connections context based on a query.
|
||||||
|
|
@ -155,19 +161,44 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
|
|
||||||
context_text = await resolve_edges_to_text(triplets)
|
context_text = await resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
completion = await generate_completion(
|
# Check if we need to generate context summary for caching
|
||||||
query=query,
|
user = session_user.get()
|
||||||
context=context_text,
|
user_id = getattr(user, "id", None)
|
||||||
user_prompt_path=self.user_prompt_path,
|
session_save = user_id and cache_config.caching
|
||||||
system_prompt_path=self.system_prompt_path,
|
|
||||||
system_prompt=self.system_prompt,
|
if session_save:
|
||||||
)
|
context_summary, completion = await asyncio.gather(
|
||||||
|
summarize_text(context_text),
|
||||||
|
generate_completion(
|
||||||
|
query=query,
|
||||||
|
context=context_text,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
completion = await generate_completion(
|
||||||
|
query=query,
|
||||||
|
context=context_text,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
if self.save_interaction and context and triplets and completion:
|
if self.save_interaction and context and triplets and completion:
|
||||||
await self.save_qa(
|
await self.save_qa(
|
||||||
question=query, answer=completion, context=context_text, triplets=triplets
|
question=query, answer=completion, context=context_text, triplets=triplets
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if session_save:
|
||||||
|
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
|
||||||
|
cache_engine = get_cache_engine()
|
||||||
|
if session_id is None:
|
||||||
|
session_id = 'default_session'
|
||||||
|
await cache_engine.add_qa(str(user_id), session_id=session_id, question=query, context=context_summary, answer=completion)
|
||||||
|
|
||||||
|
|
||||||
return [completion]
|
return [completion]
|
||||||
|
|
||||||
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
|
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue