feat: adds user cognee interaction subgraph functionality

This commit is contained in:
hajdul88 2025-07-29 16:52:33 +02:00
parent f78af0cec3
commit e0221feb74
4 changed files with 105 additions and 4 deletions

View file

@ -114,4 +114,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
) )
if context and triplets:
await self.save_qa(question=query, answer=answer, context=context, triplets=triplets)
return [answer] return [answer]

View file

@ -122,4 +122,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}" f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
) )
if context and triplets:
await self.save_qa(question=query, answer=answer, context=context, triplets=triplets)
return [answer] return [answer]

View file

@ -1,14 +1,20 @@
from typing import Any, Optional, Type, List from typing import Any, Optional, Type, List, Tuple
from collections import Counter from collections import Counter
import string import string
from uuid import NAMESPACE_OID, uuid5, UUID
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.tasks.storage import add_data_points
from cognee.modules.engine.models.node_set import NodeSet
logger = get_logger("GraphCompletionRetriever") logger = get_logger("GraphCompletionRetriever")
@ -82,6 +88,73 @@ class GraphCompletionRetriever(BaseRetriever):
) )
return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}" return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
"""
Saves a question and answer pair for later analysis or storage.
Parameters:
-----------
- question (str): The question text.
- answer (str): The answer text.
- context (str): The context text.
- triplets (List): A list of triples retrieved from the graph.
"""
nodeset_name = "Interactions"
interactions_node_set = NodeSet(
id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name
)
source_id = uuid5(NAMESPACE_OID, name=(question + answer + context))
cognee_user_interaction = CogneeUserInteraction(
id=source_id,
question=question,
answer=answer,
context=context,
belongs_to_set=interactions_node_set,
)
await add_data_points(data_points=[cognee_user_interaction])
relationships = []
relationship_name = "used_graph_element_to_answer"
for triplet in triplets:
target_id_1 = UUID(triplet.node1.id)
target_id_2 = UUID(triplet.node2.id)
# Defined qa node to triplet node 1
relationships.append(
(
source_id,
target_id_1,
relationship_name,
{
"relationship_name": relationship_name,
"source_node_id": source_id,
"target_node_id": target_id_1,
"ontology_valid": False,
},
)
)
# Defined qa node to triplet node 2
relationships.append(
(
source_id,
target_id_2,
relationship_name,
{
"relationship_name": relationship_name,
"source_node_id": source_id,
"target_node_id": target_id_2,
"ontology_valid": False,
},
)
)
if len(relationships) > 0:
graph_engine = await get_graph_engine()
await graph_engine.add_edges(relationships)
async def get_triplets(self, query: str) -> list: async def get_triplets(self, query: str) -> list:
""" """
Retrieves relevant graph triplets based on a query string. Retrieves relevant graph triplets based on a query string.
@ -118,7 +191,7 @@ class GraphCompletionRetriever(BaseRetriever):
return found_triplets return found_triplets
async def get_context(self, query: str) -> str: async def get_context(self, query: str) -> Tuple[str, List]:
""" """
Retrieves and resolves graph triplets into context based on a query. Retrieves and resolves graph triplets into context based on a query.
@ -139,7 +212,9 @@ class GraphCompletionRetriever(BaseRetriever):
logger.warning("Empty context was provided to the completion") logger.warning("Empty context was provided to the completion")
return "" return ""
return await self.resolve_edges_to_text(triplets) context = await self.resolve_edges_to_text(triplets)
return context, triplets
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
""" """
@ -157,8 +232,10 @@ class GraphCompletionRetriever(BaseRetriever):
- Any: A generated completion based on the query and context provided. - Any: A generated completion based on the query and context provided.
""" """
triplets = None
if context is None: if context is None:
context = await self.get_context(query) context, triplets = await self.get_context(query)
completion = await generate_completion( completion = await generate_completion(
query=query, query=query,
@ -166,6 +243,12 @@ class GraphCompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
) )
if context and triplets:
await self.save_qa(
question=query, answer=completion, context=context, triplets=triplets
)
return [completion] return [completion]
def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "): def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "):

View file

@ -0,0 +1,12 @@
from typing import Optional
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.modules.engine.models.node_set import NodeSet
class CogneeUserInteraction(DataPoint):
"""User - Cognee interaction"""
question: str
answer: str
context: str
belongs_to_set: Optional[NodeSet] = None