From 711c805c83865d0586d711492b6ea461ed48fddb Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 18 Aug 2025 13:14:06 +0200 Subject: [PATCH] feat: adds cognee-user interactions to search --- cognee/api/v1/search/search.py | 4 + ..._completion_context_extension_retriever.py | 11 ++- .../graph_completion_cot_retriever.py | 19 ++-- .../retrieval/graph_completion_retriever.py | 90 ++++++++++++++++++- .../graph_summary_completion_retriever.py | 2 + .../retrieval/utils/extract_uuid_from_node.py | 18 ++++ cognee/modules/retrieval/utils/models.py | 36 ++++++++ cognee/modules/search/methods/search.py | 20 ++++- cognee/tasks/storage/add_data_points.py | 35 +++++++- 9 files changed, 217 insertions(+), 18 deletions(-) create mode 100644 cognee/modules/retrieval/utils/extract_uuid_from_node.py create mode 100644 cognee/modules/retrieval/utils/models.py diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index 66ce48cc2..118412566 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -19,6 +19,7 @@ async def search( top_k: int = 10, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + save_interaction: bool = False, ) -> list: """ Search and query the knowledge graph for insights, information, and connections. @@ -107,6 +108,8 @@ async def search( node_name: Filter results to specific named entities (for targeted search). 
+ save_interaction: Save interaction (query, context, answer connected to triplet endpoints) results into the graph or not + Returns: list: Search results in format determined by query_type: @@ -189,6 +192,7 @@ async def search( top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ) return filtered_search_results diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 4027646c1..d05e6b4fa 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -29,6 +29,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + save_interaction: bool = False, ): super().__init__( user_prompt_path=user_prompt_path, @@ -36,6 +37,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ) async def get_completion( @@ -105,11 +107,16 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): round_idx += 1 - answer = await generate_completion( + completion = await generate_completion( query=query, context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, ) - return [answer] + if self.save_interaction and context and triplets and completion: + await self.save_qa( + question=query, answer=completion, context=context, triplets=triplets + ) + + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index b3e3bfbd4..032dccf9e 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ 
b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -35,6 +35,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + save_interaction: bool = False, ): super().__init__( user_prompt_path=user_prompt_path, @@ -42,6 +43,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ) self.validation_system_prompt_path = validation_system_prompt_path self.validation_user_prompt_path = validation_user_prompt_path @@ -75,7 +77,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): """ followup_question = "" triplets = [] - answer = [""] + completion = [""] for round_idx in range(max_iter + 1): if round_idx == 0: @@ -85,15 +87,15 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): triplets += await self.get_triplets(followup_question) context = await self.resolve_edges_to_text(list(set(triplets))) - answer = await generate_completion( + completion = await generate_completion( query=query, context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, ) - logger.info(f"Chain-of-thought: round {round_idx} - answer: {answer}") + logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") if round_idx < max_iter: - valid_args = {"query": query, "answer": answer, "context": context} + valid_args = {"query": query, "answer": completion, "context": context} valid_user_prompt = LLMGateway.render_prompt( filename=self.validation_user_prompt_path, context=valid_args ) @@ -106,7 +108,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): system_prompt=valid_system_prompt, response_model=str, ) - followup_args = {"query": query, "answer": answer, "reasoning": reasoning} + followup_args = {"query": query, "answer": completion, "reasoning": reasoning} followup_prompt = 
LLMGateway.render_prompt( filename=self.followup_user_prompt_path, context=followup_args ) @@ -121,4 +123,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}" ) - return [answer] + if self.save_interaction and context and triplets and completion: + await self.save_qa( + question=query, answer=completion, context=context, triplets=triplets + ) + + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 9727f2c35..a8cdee3ab 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -1,14 +1,20 @@ -from typing import Any, Optional, Type, List +from typing import Any, Optional, Type, List, Coroutine from collections import Counter +from uuid import NAMESPACE_OID, uuid5 import string from cognee.infrastructure.engine import DataPoint +from cognee.tasks.storage import add_data_points from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.retrieval.utils.completion import generate_completion from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS from cognee.shared.logging_utils import get_logger +from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node +from cognee.modules.retrieval.utils.models import CogneeUserInteraction +from cognee.modules.engine.models.node_set import NodeSet +from cognee.infrastructure.databases.graph import get_graph_engine logger = get_logger("GraphCompletionRetriever") @@ -33,8 +39,10 @@ class GraphCompletionRetriever(BaseRetriever): top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, 
+ save_interaction: bool = False, ): """Initialize retriever with prompt paths and search parameters.""" + self.save_interaction = save_interaction self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path self.top_k = top_k if top_k is not None else 5 @@ -118,7 +126,7 @@ class GraphCompletionRetriever(BaseRetriever): return found_triplets - async def get_context(self, query: str) -> str: + async def get_context(self, query: str) -> str | tuple[str, list]: """ Retrieves and resolves graph triplets into context based on a query. @@ -139,7 +147,9 @@ class GraphCompletionRetriever(BaseRetriever): logger.warning("Empty context was provided to the completion") return "" - return await self.resolve_edges_to_text(triplets) + context = await self.resolve_edges_to_text(triplets) + + return context, triplets async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: """ @@ -157,8 +167,10 @@ class GraphCompletionRetriever(BaseRetriever): - Any: A generated completion based on the query and context provided. """ + triplets = None + if context is None: - context = await self.get_context(query) + context, triplets = await self.get_context(query) completion = await generate_completion( query=query, @@ -166,6 +178,12 @@ class GraphCompletionRetriever(BaseRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, ) + + if self.save_interaction and context and triplets and completion: + await self.save_qa( + question=query, answer=completion, context=context, triplets=triplets + ) + return [completion] def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "): @@ -187,3 +205,67 @@ class GraphCompletionRetriever(BaseRetriever): first_n_words = text.split()[:first_n_words] top_n_words = self._top_n_words(text, top_n=top_n_words) return f"{' '.join(first_n_words)}... 
[{top_n_words}]" + + async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None: + """ + Saves a question and answer pair for later analysis or storage. + Parameters: + ----------- + - question (str): The question text. + - answer (str): The answer text. + - context (str): The context text. + - triplets (List): A list of triples retrieved from the graph. + """ + nodeset_name = "Interactions" + interactions_node_set = NodeSet( + id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name + ) + source_id = uuid5(NAMESPACE_OID, name=(question + answer + context)) + + cognee_user_interaction = CogneeUserInteraction( + id=source_id, + question=question, + answer=answer, + context=context, + belongs_to_set=interactions_node_set, + ) + + await add_data_points(data_points=[cognee_user_interaction], update_edge_collection=False) + + relationships = [] + relationship_name = "used_graph_element_to_answer" + for triplet in triplets: + target_id_1 = extract_uuid_from_node(triplet.node1) + target_id_2 = extract_uuid_from_node(triplet.node2) + if target_id_1 and target_id_2: + relationships.append( + ( + source_id, + target_id_1, + relationship_name, + { + "relationship_name": relationship_name, + "source_node_id": source_id, + "target_node_id": target_id_1, + "ontology_valid": False, + }, + ) + ) + + relationships.append( + ( + source_id, + target_id_2, + relationship_name, + { + "relationship_name": relationship_name, + "source_node_id": source_id, + "target_node_id": target_id_2, + "ontology_valid": False, + }, + ) + ) + + if len(relationships) > 0: + graph_engine = await get_graph_engine() + await graph_engine.add_edges(relationships) diff --git a/cognee/modules/retrieval/graph_summary_completion_retriever.py b/cognee/modules/retrieval/graph_summary_completion_retriever.py index 803fb5993..d344ebd26 100644 --- a/cognee/modules/retrieval/graph_summary_completion_retriever.py +++ b/cognee/modules/retrieval/graph_summary_completion_retriever.py 
@@ -24,6 +24,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + save_interaction: bool = False, ): """Initialize retriever with default prompt paths and search parameters.""" super().__init__( @@ -32,6 +33,7 @@ top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ) self.summarize_prompt_path = summarize_prompt_path diff --git a/cognee/modules/retrieval/utils/extract_uuid_from_node.py b/cognee/modules/retrieval/utils/extract_uuid_from_node.py new file mode 100644 index 000000000..23a519970 --- /dev/null +++ b/cognee/modules/retrieval/utils/extract_uuid_from_node.py @@ -0,0 +1,18 @@ +from typing import Any, Optional +from uuid import UUID + + +def extract_uuid_from_node(node: Any) -> Optional[UUID]: + """ + Try to pull a UUID string out of node.id or node.attributes['id'], + then return a UUID instance (or None if neither exists).
+ """ + id_str = None + if not id_str: + id_str = getattr(node, "id", None) + + if hasattr(node, "attributes") and not id_str: + id_str = node.attributes.get("id", None) + + id = UUID(id_str) if isinstance(id_str, str) else None + return id diff --git a/cognee/modules/retrieval/utils/models.py b/cognee/modules/retrieval/utils/models.py new file mode 100644 index 000000000..69ffa9a5f --- /dev/null +++ b/cognee/modules/retrieval/utils/models.py @@ -0,0 +1,36 @@ +from typing import Optional +from cognee.infrastructure.engine.models.DataPoint import DataPoint +from cognee.modules.engine.models.node_set import NodeSet +from enum import Enum +from pydantic import BaseModel, ValidationError + + +class CogneeUserInteraction(DataPoint): + """User - Cognee interaction""" + + question: str + answer: str + context: str + belongs_to_set: Optional[NodeSet] = None + + +class CogneeUserFeedback(DataPoint): + """User - Cognee Feedback""" + + feedback: str + sentiment: str + belongs_to_set: Optional[NodeSet] = None + + +class UserFeedbackSentiment(str, Enum): + """User - User feedback sentiment""" + + positive = "positive" + negative = "negative" + neutral = "neutral" + + +class UserFeedbackEvaluation(BaseModel): + """User - User feedback evaluation""" + + evaluation: UserFeedbackSentiment diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 365920019..2e66a2461 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -39,6 +39,7 @@ async def search( top_k: int = 10, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + save_interaction: bool = False, ): """ @@ -58,7 +59,7 @@ async def search( # Use search function filtered by permissions if access control is enabled if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": return await authorized_search( - query_text, query_type, user, dataset_ids, system_prompt_path, top_k + query_text, query_type, 
user, dataset_ids, system_prompt_path, top_k, save_interaction ) query = await log_query(query_text, query_type.value, user.id) @@ -71,6 +72,7 @@ async def search( top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ) await log_result( @@ -92,6 +94,7 @@ async def specific_search( top_k: int = 10, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + save_interaction: bool = False, ) -> list: search_tasks: dict[SearchType, Callable] = { SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion, @@ -105,24 +108,28 @@ async def specific_search( top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ).get_completion, SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever( system_prompt_path=system_prompt_path, top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ).get_completion, SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever( system_prompt_path=system_prompt_path, top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ).get_completion, SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever( system_prompt_path=system_prompt_path, top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ).get_completion, SearchType.CODE: CodeRetriever(top_k=top_k).get_completion, SearchType.CYPHER: CypherSearchRetriever().get_completion, @@ -154,6 +161,7 @@ async def authorized_search( dataset_ids: Optional[list[UUID]] = None, system_prompt_path: str = "answer_simple_question.txt", top_k: int = 10, + save_interaction: bool = False, ) -> list: """ Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset. 
@@ -167,7 +175,7 @@ async def authorized_search( # Searches all provided datasets and handles setting up of appropriate database context based on permissions search_results = await specific_search_by_context( - search_datasets, query_text, query_type, user, system_prompt_path, top_k + search_datasets, query_text, query_type, user, system_prompt_path, top_k, save_interaction ) await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id) @@ -182,6 +190,7 @@ async def specific_search_by_context( user: User, system_prompt_path: str, top_k: int, + save_interaction: bool = False, ): """ Searches all provided datasets and handles setting up of appropriate database context based on permissions. @@ -192,7 +201,12 @@ async def specific_search_by_context( # Set database configuration in async context for each dataset user has access for await set_database_global_context_variables(dataset.id, dataset.owner_id) search_results = await specific_search( - query_type, query_text, user, system_prompt_path=system_prompt_path, top_k=top_k + query_type, + query_text, + user, + system_prompt_path=system_prompt_path, + top_k=top_k, + save_interaction=save_interaction, ) return { "search_result": search_results, diff --git a/cognee/tasks/storage/add_data_points.py b/cognee/tasks/storage/add_data_points.py index 9b5c36c37..28ec28a30 100644 --- a/cognee/tasks/storage/add_data_points.py +++ b/cognee/tasks/storage/add_data_points.py @@ -7,7 +7,36 @@ from .index_data_points import index_data_points from .index_graph_edges import index_graph_edges -async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]: +async def add_data_points( + data_points: List[DataPoint], update_edge_collection=True +) -> List[DataPoint]: + """ + Add a batch of data points to the graph database by extracting nodes and edges, + deduplicating them, and indexing them for retrieval. 
+ + This function parallelizes the graph extraction for each data point, + merges the resulting nodes and edges, and ensures uniqueness before + committing them to the underlying graph engine. It also updates the + associated retrieval indices for nodes and (optionally) edges. + + Args: + data_points (List[DataPoint]): + A list of data points to process and insert into the graph. + update_edge_collection (bool, optional): + Whether to update the edge index after adding edges. + Defaults to True. + + Returns: + List[DataPoint]: + The original list of data points after processing and insertion. + + Side Effects: + - Calls `get_graph_from_model` concurrently for each data point. + - Deduplicates nodes and edges across all results. + - Updates the node index via `index_data_points`. + - Inserts nodes and edges into the graph engine. + - Optionally updates the edge index via `index_graph_edges`. + """ nodes = [] edges = [] @@ -40,7 +69,7 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]: await graph_engine.add_nodes(nodes) await graph_engine.add_edges(edges) - # This step has to happen after adding nodes and edges because we query the graph. - await index_graph_edges() + if update_edge_collection: + await index_graph_edges() return data_points