From e0221feb74a4950de5609314944fe43700adcf34 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 29 Jul 2025 16:52:33 +0200 Subject: [PATCH] feat: adds user cognee interaction subgraph functionality --- ..._completion_context_extension_retriever.py | 3 + .../graph_completion_cot_retriever.py | 3 + .../retrieval/graph_completion_retriever.py | 91 ++++++++++++++++++- cognee/modules/retrieval/utils/models.py | 12 +++ 4 files changed, 105 insertions(+), 4 deletions(-) create mode 100644 cognee/modules/retrieval/utils/models.py diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 2479a454d..ab90a5839 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -114,4 +114,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): system_prompt_path=self.system_prompt_path, ) + if context and triplets: + await self.save_qa(question=query, answer=answer, context=context, triplets=triplets) + return [answer] diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 95eb1a9b6..5cc81b49c 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -122,4 +122,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}" ) + if context and triplets: + await self.save_qa(question=query, answer=answer, context=context, triplets=triplets) + return [answer] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 9727f2c35..350b419c5 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -1,14 +1,20 @@ -from typing import Any, Optional, Type, List +from typing import Any, Optional, Type, List, Tuple from collections import Counter import string +from uuid import NAMESPACE_OID, uuid5, UUID + +from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.engine import DataPoint from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.retrieval.utils.completion import generate_completion from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS +from cognee.modules.retrieval.utils.models import CogneeUserInteraction from cognee.shared.logging_utils import get_logger +from cognee.tasks.storage import add_data_points +from cognee.modules.engine.models.node_set import NodeSet logger = get_logger("GraphCompletionRetriever") @@ -82,6 +88,73 @@ class GraphCompletionRetriever(BaseRetriever): ) return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}" + async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None: + """ + Saves a question and answer pair for later analysis or storage. + + Parameters: + ----------- + - question (str): The question text. + - answer (str): The answer text. + - context (str): The context text. + - triplets (List): A list of triples retrieved from the graph. + """ + nodeset_name = "Interactions" + interactions_node_set = NodeSet( + id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name + ) + source_id = uuid5(NAMESPACE_OID, name=(question + answer + context)) + + cognee_user_interaction = CogneeUserInteraction( + id=source_id, + question=question, + answer=answer, + context=context, + belongs_to_set=interactions_node_set, + ) + + await add_data_points(data_points=[cognee_user_interaction]) + + relationships = [] + relationship_name = "used_graph_element_to_answer" + for triplet in triplets: + target_id_1 = UUID(triplet.node1.id) + target_id_2 = UUID(triplet.node2.id) + + # Defined qa node to triplet node 1 + relationships.append( + ( + source_id, + target_id_1, + relationship_name, + { + "relationship_name": relationship_name, + "source_node_id": source_id, + "target_node_id": target_id_1, + "ontology_valid": False, + }, + ) + ) + + # Defined qa node to triplet node 2 + relationships.append( + ( + source_id, + target_id_2, + relationship_name, + { + "relationship_name": relationship_name, + "source_node_id": source_id, + "target_node_id": target_id_2, + "ontology_valid": False, + }, + ) + ) + + if len(relationships) > 0: + graph_engine = await get_graph_engine() + await graph_engine.add_edges(relationships) + async def get_triplets(self, query: str) -> list: """ Retrieves relevant graph triplets based on a query string. @@ -118,7 +191,7 @@ class GraphCompletionRetriever(BaseRetriever): return found_triplets - async def get_context(self, query: str) -> str: + async def get_context(self, query: str) -> Tuple[str, List]: """ Retrieves and resolves graph triplets into context based on a query. @@ -139,7 +212,9 @@ class GraphCompletionRetriever(BaseRetriever): logger.warning("Empty context was provided to the completion") return "" - return await self.resolve_edges_to_text(triplets) + context = await self.resolve_edges_to_text(triplets) + + return context, triplets async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: """ @@ -157,8 +232,10 @@ class GraphCompletionRetriever(BaseRetriever): - Any: A generated completion based on the query and context provided. """ + triplets = None + if context is None: - context = await self.get_context(query) + context, triplets = await self.get_context(query) completion = await generate_completion( query=query, @@ -166,6 +243,12 @@ class GraphCompletionRetriever(BaseRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, ) + + if context and triplets: + await self.save_qa( + question=query, answer=completion, context=context, triplets=triplets + ) + return [completion] def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "): diff --git a/cognee/modules/retrieval/utils/models.py b/cognee/modules/retrieval/utils/models.py new file mode 100644 index 000000000..895c280ae --- /dev/null +++ b/cognee/modules/retrieval/utils/models.py @@ -0,0 +1,12 @@ +from typing import Optional +from cognee.infrastructure.engine.models.DataPoint import DataPoint +from cognee.modules.engine.models.node_set import NodeSet + + +class CogneeUserInteraction(DataPoint): + """User - Cognee interaction""" + + question: str + answer: str + context: str + belongs_to_set: Optional[NodeSet] = None