feat: adds user cognee interaction subgraph functionality

This commit is contained in:
hajdul88 2025-07-29 16:52:33 +02:00
parent f78af0cec3
commit e0221feb74
4 changed files with 105 additions and 4 deletions

View file

@ -114,4 +114,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
system_prompt_path=self.system_prompt_path,
)
if context and triplets:
await self.save_qa(question=query, answer=answer, context=context, triplets=triplets)
return [answer]

View file

@ -122,4 +122,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
)
if context and triplets:
await self.save_qa(question=query, answer=answer, context=context, triplets=triplets)
return [answer]

View file

@ -1,14 +1,20 @@
from typing import Any, Optional, Type, List
from typing import Any, Optional, Type, List, Tuple
from collections import Counter
import string
from uuid import NAMESPACE_OID, uuid5, UUID
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
from cognee.shared.logging_utils import get_logger
from cognee.tasks.storage import add_data_points
from cognee.modules.engine.models.node_set import NodeSet
logger = get_logger("GraphCompletionRetriever")
@ -82,6 +88,73 @@ class GraphCompletionRetriever(BaseRetriever):
)
return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
"""
Saves a question and answer pair for later analysis or storage.
Parameters:
-----------
- question (str): The question text.
- answer (str): The answer text.
- context (str): The context text.
- triplets (List): A list of triples retrieved from the graph.
"""
nodeset_name = "Interactions"
interactions_node_set = NodeSet(
id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name
)
source_id = uuid5(NAMESPACE_OID, name=(question + answer + context))
cognee_user_interaction = CogneeUserInteraction(
id=source_id,
question=question,
answer=answer,
context=context,
belongs_to_set=interactions_node_set,
)
await add_data_points(data_points=[cognee_user_interaction])
relationships = []
relationship_name = "used_graph_element_to_answer"
for triplet in triplets:
target_id_1 = UUID(triplet.node1.id)
target_id_2 = UUID(triplet.node2.id)
# Defined qa node to triplet node 1
relationships.append(
(
source_id,
target_id_1,
relationship_name,
{
"relationship_name": relationship_name,
"source_node_id": source_id,
"target_node_id": target_id_1,
"ontology_valid": False,
},
)
)
# Defined qa node to triplet node 2
relationships.append(
(
source_id,
target_id_2,
relationship_name,
{
"relationship_name": relationship_name,
"source_node_id": source_id,
"target_node_id": target_id_2,
"ontology_valid": False,
},
)
)
if len(relationships) > 0:
graph_engine = await get_graph_engine()
await graph_engine.add_edges(relationships)
async def get_triplets(self, query: str) -> list:
"""
Retrieves relevant graph triplets based on a query string.
@ -118,7 +191,7 @@ class GraphCompletionRetriever(BaseRetriever):
return found_triplets
async def get_context(self, query: str) -> str:
async def get_context(self, query: str) -> Tuple[str, List]:
"""
Retrieves and resolves graph triplets into context based on a query.
@ -139,7 +212,9 @@ class GraphCompletionRetriever(BaseRetriever):
logger.warning("Empty context was provided to the completion")
return ""
return await self.resolve_edges_to_text(triplets)
context = await self.resolve_edges_to_text(triplets)
return context, triplets
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""
@ -157,8 +232,10 @@ class GraphCompletionRetriever(BaseRetriever):
- Any: A generated completion based on the query and context provided.
"""
triplets = None
if context is None:
context = await self.get_context(query)
context, triplets = await self.get_context(query)
completion = await generate_completion(
query=query,
@ -166,6 +243,12 @@ class GraphCompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
if context and triplets:
await self.save_qa(
question=query, answer=completion, context=context, triplets=triplets
)
return [completion]
def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "):

View file

@ -0,0 +1,12 @@
from typing import Optional
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.modules.engine.models.node_set import NodeSet
class CogneeUserInteraction(DataPoint):
"""User - Cognee interaction"""
question: str
answer: str
context: str
belongs_to_set: Optional[NodeSet] = None