From 711c805c83865d0586d711492b6ea461ed48fddb Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 18 Aug 2025 13:14:06 +0200 Subject: [PATCH 01/17] feat: adds cognee-user interactions to search --- cognee/api/v1/search/search.py | 4 + ..._completion_context_extension_retriever.py | 11 ++- .../graph_completion_cot_retriever.py | 19 ++-- .../retrieval/graph_completion_retriever.py | 90 ++++++++++++++++++- .../graph_summary_completion_retriever.py | 2 + .../retrieval/utils/extract_uuid_from_node.py | 18 ++++ cognee/modules/retrieval/utils/models.py | 36 ++++++++ cognee/modules/search/methods/search.py | 20 ++++- cognee/tasks/storage/add_data_points.py | 35 +++++++- 9 files changed, 217 insertions(+), 18 deletions(-) create mode 100644 cognee/modules/retrieval/utils/extract_uuid_from_node.py create mode 100644 cognee/modules/retrieval/utils/models.py diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index 66ce48cc2..118412566 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -19,6 +19,7 @@ async def search( top_k: int = 10, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + save_interaction: bool = False, ) -> list: """ Search and query the knowledge graph for insights, information, and connections. @@ -107,6 +108,8 @@ async def search( node_name: Filter results to specific named entities (for targeted search). + save_interaction: Save interaction (query, context, answer connected to triplet endpoints) results into the graph or not + Returns: list: Search results in format determined by query_type: @@ -189,6 +192,7 @@ async def search( top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ) return filtered_search_results diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 4027646c1..d05e6b4fa 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -29,6 +29,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + save_interaction: bool = False, ): super().__init__( user_prompt_path=user_prompt_path, @@ -36,6 +37,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ) async def get_completion( @@ -105,11 +107,16 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): round_idx += 1 - answer = await generate_completion( + completion = await generate_completion( query=query, context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, ) - return [answer] + if self.save_interaction and context and triplets and completion: + await self.save_qa( + question=query, answer=completion, context=context, triplets=triplets + ) + + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index b3e3bfbd4..032dccf9e 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -35,6 +35,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + save_interaction: bool = False, ): super().__init__( user_prompt_path=user_prompt_path, @@ -42,6 +43,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ) self.validation_system_prompt_path = validation_system_prompt_path self.validation_user_prompt_path = validation_user_prompt_path @@ -75,7 +77,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): """ followup_question = "" triplets = [] - answer = [""] + completion = [""] for round_idx in range(max_iter + 1): if round_idx == 0: @@ -85,15 +87,15 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): triplets += await self.get_triplets(followup_question) context = await self.resolve_edges_to_text(list(set(triplets))) - answer = await generate_completion( + completion = await generate_completion( query=query, context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, ) - logger.info(f"Chain-of-thought: round {round_idx} - answer: {answer}") + logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") if round_idx < max_iter: - valid_args = {"query": query, "answer": answer, "context": context} + valid_args = {"query": query, "answer": completion, "context": context} valid_user_prompt = LLMGateway.render_prompt( filename=self.validation_user_prompt_path, context=valid_args ) @@ -106,7 +108,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): system_prompt=valid_system_prompt, response_model=str, ) - followup_args = {"query": query, "answer": answer, "reasoning": reasoning} + followup_args = {"query": query, "answer": completion, "reasoning": reasoning} followup_prompt = LLMGateway.render_prompt( filename=self.followup_user_prompt_path, context=followup_args ) @@ -121,4 +123,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}" ) - return [answer] + if self.save_interaction and context and triplets and completion: + await self.save_qa( + question=query, answer=completion, context=context, triplets=triplets + ) + + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 9727f2c35..a8cdee3ab 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -1,14 +1,20 @@ -from typing import Any, Optional, Type, List +from typing import Any, Optional, Type, List, Coroutine from collections import Counter +from uuid import NAMESPACE_OID, uuid5 import string from cognee.infrastructure.engine import DataPoint +from cognee.tasks.storage import add_data_points from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.retrieval.utils.completion import generate_completion from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS from cognee.shared.logging_utils import get_logger +from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node +from cognee.modules.retrieval.utils.models import CogneeUserInteraction +from cognee.modules.engine.models.node_set import NodeSet +from cognee.infrastructure.databases.graph import get_graph_engine logger = get_logger("GraphCompletionRetriever") @@ -33,8 +39,10 @@ class GraphCompletionRetriever(BaseRetriever): top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + save_interaction: bool = False, ): """Initialize retriever with prompt paths and search parameters.""" + self.save_interaction = save_interaction self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path self.top_k = top_k if top_k is not None else 5 @@ -118,7 +126,7 @@ class GraphCompletionRetriever(BaseRetriever): return found_triplets - async def get_context(self, query: str) -> str: + async def get_context(self, query: str) -> str | tuple[str, list]: """ Retrieves and resolves graph triplets into context based on a query. @@ -139,7 +147,9 @@ class GraphCompletionRetriever(BaseRetriever): logger.warning("Empty context was provided to the completion") return "" - return await self.resolve_edges_to_text(triplets) + context = await self.resolve_edges_to_text(triplets) + + return context, triplets async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: """ @@ -157,8 +167,10 @@ class GraphCompletionRetriever(BaseRetriever): - Any: A generated completion based on the query and context provided. """ + triplets = None + if context is None: - context = await self.get_context(query) + context, triplets = await self.get_context(query) completion = await generate_completion( query=query, @@ -166,6 +178,12 @@ class GraphCompletionRetriever(BaseRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, ) + + if self.save_interaction and context and triplets and completion: + await self.save_qa( + question=query, answer=completion, context=context, triplets=triplets + ) + return [completion] def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "): @@ -187,3 +205,67 @@ class GraphCompletionRetriever(BaseRetriever): first_n_words = text.split()[:first_n_words] top_n_words = self._top_n_words(text, top_n=top_n_words) return f"{' '.join(first_n_words)}... [{top_n_words}]" + + async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None: + """ + Saves a question and answer pair for later analysis or storage. + Parameters: + ----------- + - question (str): The question text. + - answer (str): The answer text. + - context (str): The context text. + - triplets (List): A list of triples retrieved from the graph. + """ + nodeset_name = "Interactions" + interactions_node_set = NodeSet( + id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name + ) + source_id = uuid5(NAMESPACE_OID, name=(question + answer + context)) + + cognee_user_interaction = CogneeUserInteraction( + id=source_id, + question=question, + answer=answer, + context=context, + belongs_to_set=interactions_node_set, + ) + + await add_data_points(data_points=[cognee_user_interaction], update_edge_collection=False) + + relationships = [] + relationship_name = "used_graph_element_to_answer" + for triplet in triplets: + target_id_1 = extract_uuid_from_node(triplet.node1) + target_id_2 = extract_uuid_from_node(triplet.node2) + if target_id_1 and target_id_2: + relationships.append( + ( + source_id, + target_id_1, + relationship_name, + { + "relationship_name": relationship_name, + "source_node_id": source_id, + "target_node_id": target_id_1, + "ontology_valid": False, + }, + ) + ) + + relationships.append( + ( + source_id, + target_id_2, + relationship_name, + { + "relationship_name": relationship_name, + "source_node_id": source_id, + "target_node_id": target_id_2, + "ontology_valid": False, + }, + ) + ) + + if len(relationships) > 0: + graph_engine = await get_graph_engine() + await graph_engine.add_edges(relationships) diff --git a/cognee/modules/retrieval/graph_summary_completion_retriever.py b/cognee/modules/retrieval/graph_summary_completion_retriever.py index 803fb5993..d344ebd26 100644 --- a/cognee/modules/retrieval/graph_summary_completion_retriever.py +++ b/cognee/modules/retrieval/graph_summary_completion_retriever.py @@ -24,6 +24,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + save_interaction: bool = False, ): """Initialize retriever with default prompt paths and search parameters.""" super().__init__( @@ -32,6 +33,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ) self.summarize_prompt_path = summarize_prompt_path diff --git a/cognee/modules/retrieval/utils/extract_uuid_from_node.py b/cognee/modules/retrieval/utils/extract_uuid_from_node.py new file mode 100644 index 000000000..23a519970 --- /dev/null +++ b/cognee/modules/retrieval/utils/extract_uuid_from_node.py @@ -0,0 +1,18 @@ +from typing import Any, Optional +from uuid import UUID + + +def extract_uuid_from_node(node: Any) -> Optional[UUID]: + """ + Try to pull a UUID string out of node.id or node.properties['id'], + then return a UUID instance (or None if neither exists). + """ + id_str = None + if not id_str: + id_str = getattr(node, "id", None) + + if hasattr(node, "attributes") and not id_str: + id_str = node.attributes.get("id", None) + + id = UUID(id_str) if isinstance(id_str, str) else None + return id diff --git a/cognee/modules/retrieval/utils/models.py b/cognee/modules/retrieval/utils/models.py new file mode 100644 index 000000000..69ffa9a5f --- /dev/null +++ b/cognee/modules/retrieval/utils/models.py @@ -0,0 +1,36 @@ +from typing import Optional +from cognee.infrastructure.engine.models.DataPoint import DataPoint +from cognee.modules.engine.models.node_set import NodeSet +from enum import Enum +from pydantic import BaseModel, ValidationError + + +class CogneeUserInteraction(DataPoint): + """User - Cognee interaction""" + + question: str + answer: str + context: str + belongs_to_set: Optional[NodeSet] = None + + +class CogneeUserFeedback(DataPoint): + """User - Cognee Feedback""" + + feedback: str + sentiment: str + belongs_to_set: Optional[NodeSet] = None + + +class UserFeedbackSentiment(str, Enum): + """User - User feedback sentiment""" + + positive = "positive" + negative = "negative" + neutral = "neutral" + + +class UserFeedbackEvaluation(BaseModel): + """User - User feedback evaluation""" + + evaluation: UserFeedbackSentiment diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 365920019..2e66a2461 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -39,6 +39,7 @@ async def search( top_k: int = 10, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + save_interaction: bool = False, ): """ @@ -58,7 +59,7 @@ async def search( # Use search function filtered by permissions if access control is enabled if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": return await authorized_search( - query_text, query_type, user, dataset_ids, system_prompt_path, top_k + query_text, query_type, user, dataset_ids, system_prompt_path, top_k, save_interaction ) query = await log_query(query_text, query_type.value, user.id) @@ -71,6 +72,7 @@ async def search( top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ) await log_result( @@ -92,6 +94,7 @@ async def specific_search( top_k: int = 10, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + save_interaction: bool = False, ) -> list: search_tasks: dict[SearchType, Callable] = { SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion, @@ -105,24 +108,28 @@ async def specific_search( top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ).get_completion, SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever( system_prompt_path=system_prompt_path, top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ).get_completion, SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever( system_prompt_path=system_prompt_path, top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ).get_completion, SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever( system_prompt_path=system_prompt_path, top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, ).get_completion, SearchType.CODE: CodeRetriever(top_k=top_k).get_completion, SearchType.CYPHER: CypherSearchRetriever().get_completion, @@ -154,6 +161,7 @@ async def authorized_search( dataset_ids: Optional[list[UUID]] = None, system_prompt_path: str = "answer_simple_question.txt", top_k: int = 10, + save_interaction: bool = False, ) -> list: """ Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset. @@ -167,7 +175,7 @@ async def authorized_search( # Searches all provided datasets and handles setting up of appropriate database context based on permissions search_results = await specific_search_by_context( - search_datasets, query_text, query_type, user, system_prompt_path, top_k + search_datasets, query_text, query_type, user, system_prompt_path, top_k, save_interaction ) await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id) @@ -182,6 +190,7 @@ async def specific_search_by_context( user: User, system_prompt_path: str, top_k: int, + save_interaction: bool = False, ): """ Searches all provided datasets and handles setting up of appropriate database context based on permissions. @@ -192,7 +201,12 @@ async def specific_search_by_context( # Set database configuration in async context for each dataset user has access for await set_database_global_context_variables(dataset.id, dataset.owner_id) search_results = await specific_search( - query_type, query_text, user, system_prompt_path=system_prompt_path, top_k=top_k + query_type, + query_text, + user, + system_prompt_path=system_prompt_path, + top_k=top_k, + save_interaction=save_interaction, ) return { "search_result": search_results, diff --git a/cognee/tasks/storage/add_data_points.py b/cognee/tasks/storage/add_data_points.py index 9b5c36c37..28ec28a30 100644 --- a/cognee/tasks/storage/add_data_points.py +++ b/cognee/tasks/storage/add_data_points.py @@ -7,7 +7,36 @@ from .index_data_points import index_data_points from .index_graph_edges import index_graph_edges -async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]: +async def add_data_points( + data_points: List[DataPoint], update_edge_collection=True +) -> List[DataPoint]: + """ + Add a batch of data points to the graph database by extracting nodes and edges, + deduplicating them, and indexing them for retrieval. + + This function parallelizes the graph extraction for each data point, + merges the resulting nodes and edges, and ensures uniqueness before + committing them to the underlying graph engine. It also updates the + associated retrieval indices for nodes and (optionally) edges. + + Args: + data_points (List[DataPoint]): + A list of data points to process and insert into the graph. + update_edge_collection (bool, optional): + Whether to update the edge index after adding edges. + Defaults to True. + + Returns: + List[DataPoint]: + The original list of data points after processing and insertion. + + Side Effects: + - Calls `get_graph_from_model` concurrently for each data point. + - Deduplicates nodes and edges across all results. + - Updates the node index via `index_data_points`. + - Inserts nodes and edges into the graph engine. + - Optionally updates the edge index via `index_graph_edges`. + """ nodes = [] edges = [] @@ -40,7 +69,7 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]: await graph_engine.add_nodes(nodes) await graph_engine.add_edges(edges) - # This step has to happen after adding nodes and edges because we query the graph. - await index_graph_edges() + if update_edge_collection: + await index_graph_edges() return data_points From dc637f70b0a826a50b5d5a4c2e3c3783f877a40a Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 18 Aug 2025 13:23:57 +0200 Subject: [PATCH 02/17] fix: fixes add datapoints params --- cognee/tasks/storage/add_data_points.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognee/tasks/storage/add_data_points.py b/cognee/tasks/storage/add_data_points.py index 27054ca66..a29f5a5f7 100644 --- a/cognee/tasks/storage/add_data_points.py +++ b/cognee/tasks/storage/add_data_points.py @@ -42,7 +42,7 @@ async def add_data_points( - Optionally updates the edge index via `index_graph_edges`. """ -async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]: +async def add_data_points(data_points: List[DataPoint], update_edge_collection: bool = True) -> List[DataPoint]: if not isinstance(data_points, list): raise InvalidDataPointsInAddDataPointsError("data_points must be a list.") if not all(isinstance(dp, DataPoint) for dp in data_points): From fbb7d72461b798adfd19c07667d9f71a84e9b5b8 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 18 Aug 2025 13:24:14 +0200 Subject: [PATCH 03/17] fix: ruff formatting --- cognee/tasks/storage/add_data_points.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cognee/tasks/storage/add_data_points.py b/cognee/tasks/storage/add_data_points.py index a29f5a5f7..68e6404f8 100644 --- a/cognee/tasks/storage/add_data_points.py +++ b/cognee/tasks/storage/add_data_points.py @@ -10,7 +10,6 @@ from cognee.tasks.storage.exceptions import ( ) - async def add_data_points( data_points: List[DataPoint], update_edge_collection=True ) -> List[DataPoint]: @@ -41,8 +40,11 @@ async def add_data_points( - Inserts nodes and edges into the graph engine. - Optionally updates the edge index via `index_graph_edges`. """ - -async def add_data_points(data_points: List[DataPoint], update_edge_collection: bool = True) -> List[DataPoint]: + + +async def add_data_points( + data_points: List[DataPoint], update_edge_collection: bool = True +) -> List[DataPoint]: if not isinstance(data_points, list): raise InvalidDataPointsInAddDataPointsError("data_points must be a list.") if not all(isinstance(dp, DataPoint) for dp in data_points): From 1d63da79232b1a782297567784a87dbdb31e157e Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 18 Aug 2025 13:26:45 +0200 Subject: [PATCH 04/17] chore: removes duplicated func def --- cognee/tasks/storage/add_data_points.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/cognee/tasks/storage/add_data_points.py b/cognee/tasks/storage/add_data_points.py index 68e6404f8..1c7a4c2bc 100644 --- a/cognee/tasks/storage/add_data_points.py +++ b/cognee/tasks/storage/add_data_points.py @@ -11,7 +11,7 @@ from cognee.tasks.storage.exceptions import ( async def add_data_points( - data_points: List[DataPoint], update_edge_collection=True + data_points: List[DataPoint], update_edge_collection: bool = True ) -> List[DataPoint]: """ Add a batch of data points to the graph database by extracting nodes and edges, @@ -41,10 +41,6 @@ async def add_data_points( - Optionally updates the edge index via `index_graph_edges`. """ - -async def add_data_points( - data_points: List[DataPoint], update_edge_collection: bool = True -) -> List[DataPoint]: if not isinstance(data_points, list): raise InvalidDataPointsInAddDataPointsError("data_points must be a list.") if not all(isinstance(dp, DataPoint) for dp in data_points): From 78fb4158927e9c5cd6b7a0fa29bb7dafd23a1a93 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 18 Aug 2025 13:40:33 +0200 Subject: [PATCH 05/17] chore: changes context return value in tests --- .../graph_completion_retriever_context_extension_test.py | 6 +++--- .../retrieval/graph_completion_retriever_cot_test.py | 6 +++--- .../modules/retrieval/graph_completion_retriever_test.py | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 3e7f6626f..26ae2f883 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -51,7 +51,7 @@ class TestGraphCompletionWithContextExtensionRetriever: retriever = GraphCompletionContextExtensionRetriever() - context = await retriever.get_context("Who works at Canva?") + context, _ = await retriever.get_context("Who works at Canva?") assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" @@ -129,7 +129,7 @@ class TestGraphCompletionWithContextExtensionRetriever: retriever = GraphCompletionContextExtensionRetriever(top_k=20) - context = await retriever.get_context("Who works at Figma?") + context, _ = await retriever.get_context("Who works at Figma?") print(context) @@ -167,7 +167,7 @@ class TestGraphCompletionWithContextExtensionRetriever: await setup() - context = await retriever.get_context("Who works at Figma?") + context, _ = await retriever.get_context("Who works at Figma?") assert context == "", "Context should be empty on an empty graph" answer = await retriever.get_completion("Who works at Figma?") diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index ff92dfd8f..be25299aa 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -47,7 +47,7 @@ class TestGraphCompletionCoTRetriever: retriever = GraphCompletionCotRetriever() - context = await retriever.get_context("Who works at Canva?") + context, _ = await retriever.get_context("Who works at Canva?") assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" @@ -124,7 +124,7 @@ class TestGraphCompletionCoTRetriever: retriever = GraphCompletionCotRetriever(top_k=20) - context = await retriever.get_context("Who works at Figma?") + context, _ = await retriever.get_context("Who works at Figma?") print(context) @@ -162,7 +162,7 @@ class TestGraphCompletionCoTRetriever: await setup() - context = await retriever.get_context("Who works at Figma?") + context, _ = await retriever.get_context("Who works at Figma?") assert context == "", "Context should be empty on an empty graph" answer = await retriever.get_completion("Who works at Figma?") diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index 976b69e69..50784f94a 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -67,7 +67,7 @@ class TestGraphCompletionRetriever: retriever = GraphCompletionRetriever() - context = await retriever.get_context("Who works at Canva?") + context, _ = await retriever.get_context("Who works at Canva?") # Ensure the top-level sections are present assert "Nodes:" in context, "Missing 'Nodes:' section in context" @@ -191,7 +191,7 @@ class TestGraphCompletionRetriever: retriever = GraphCompletionRetriever(top_k=20) - context = await retriever.get_context("Who works at Figma?") + context, _ = await retriever.get_context("Who works at Figma?") print(context) @@ -222,5 +222,5 @@ class TestGraphCompletionRetriever: await setup() - context = await retriever.get_context("Who works at Figma?") + context, _ = await retriever.get_context("Who works at Figma?") assert context == "", "Context should be empty on an empty graph" From b6be61776a377e022ef683a7a6cbbff3767008d4 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 18 Aug 2025 13:50:21 +0200 Subject: [PATCH 06/17] fix: fixes tests --- cognee/modules/retrieval/graph_completion_retriever.py | 2 +- cognee/tests/unit/modules/search/search_methods_test.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index a8cdee3ab..c831d8550 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -145,7 +145,7 @@ class GraphCompletionRetriever(BaseRetriever): if len(triplets) == 0: logger.warning("Empty context was provided to the completion") - return "" + return "", triplets context = await self.resolve_edges_to_text(triplets) diff --git a/cognee/tests/unit/modules/search/search_methods_test.py b/cognee/tests/unit/modules/search/search_methods_test.py index 8e9afff1c..8645d965a 100644 --- a/cognee/tests/unit/modules/search/search_methods_test.py +++ b/cognee/tests/unit/modules/search/search_methods_test.py @@ -65,6 +65,7 @@ async def test_search( top_k=10, node_type=None, node_name=None, + save_interaction=False, ) # Verify result logging From 0529d4b87f90e02b0fd04a9420f7bfefd93579c4 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 18 Aug 2025 14:14:12 +0200 Subject: [PATCH 07/17] fix: fixes kuzu and neo4j tests --- cognee/tests/test_kuzu.py | 4 ++-- cognee/tests/test_neo4j.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cognee/tests/test_kuzu.py b/cognee/tests/test_kuzu.py index d60c6085e..16c7b9cf6 100644 --- a/cognee/tests/test_kuzu.py +++ b/cognee/tests/test_kuzu.py @@ -94,12 +94,12 @@ async def main(): await cognee.cognify([dataset_name]) - context_nonempty = await GraphCompletionRetriever( + context_nonempty, _ = await GraphCompletionRetriever( node_type=NodeSet, node_name=["first"], ).get_context("What is in the context?") - context_empty = await GraphCompletionRetriever( + context_empty, _ = await GraphCompletionRetriever( node_type=NodeSet, node_name=["nonexistent"], ).get_context("What is in the context?") diff --git a/cognee/tests/test_neo4j.py b/cognee/tests/test_neo4j.py index dcbb38963..d5ccbc19e 100644 --- a/cognee/tests/test_neo4j.py +++ b/cognee/tests/test_neo4j.py @@ -98,12 +98,12 @@ async def main(): await cognee.cognify([dataset_name]) - context_nonempty = await GraphCompletionRetriever( + context_nonempty, _ = await GraphCompletionRetriever( node_type=NodeSet, node_name=["first"], ).get_context("What is in the context?") - context_empty = await GraphCompletionRetriever( + context_empty, _ = await GraphCompletionRetriever( node_type=NodeSet, node_name=["nonexistent"], ).get_context("What is in the context?") From 9a46d145bbf40c407e66b2de776adf88f70a4556 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 18 Aug 2025 14:53:47 +0200 Subject: [PATCH 08/17] chore: fix search db tests --- cognee/tests/test_search_db.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index 9eafb5c0c..8a99ebd8f 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -44,16 +44,16 @@ async def main(): await cognee.cognify([dataset_name]) - context_gk = await GraphCompletionRetriever().get_context( + context_gk, _ = await GraphCompletionRetriever().get_context( query="Next to which country is Germany located?" ) - context_gk_cot = await GraphCompletionCotRetriever().get_context( + context_gk_cot, _ = await GraphCompletionCotRetriever().get_context( query="Next to which country is Germany located?" ) - context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context( + context_gk_ext, _ = await GraphCompletionContextExtensionRetriever().get_context( query="Next to which country is Germany located?" ) - context_gk_sum = await GraphSummaryCompletionRetriever().get_context( + context_gk_sum, _ = await GraphSummaryCompletionRetriever().get_context( query="Next to which country is Germany located?" ) From fc43ac7a015226767bfb6d33d8b639435659565f Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 18 Aug 2025 17:54:49 +0200 Subject: [PATCH 09/17] feat: adds user feedback search type --- cognee/api/v1/search/search.py | 2 + .../databases/graph/kuzu/adapter.py | 23 ++++++ .../databases/graph/neo4j_driver/adapter.py | 23 ++++++ cognee/modules/retrieval/base_feedback.py | 11 +++ cognee/modules/retrieval/user_qa_feedback.py | 78 +++++++++++++++++++ cognee/modules/search/methods/search.py | 28 +++++-- cognee/modules/search/types/SearchType.py | 1 + 7 files changed, 160 insertions(+), 6 deletions(-) create mode 100644 cognee/modules/retrieval/base_feedback.py create mode 100644 cognee/modules/retrieval/user_qa_feedback.py diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index b4499192d..f37f8ba6d 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -20,6 +20,7 @@ async def search( node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + last_k: Optional[int] = None, ) -> list: """ Search and query the knowledge graph for insights, information, and connections. @@ -186,6 +187,7 @@ async def search( node_type=node_type, node_name=node_name, save_interaction=save_interaction, + last_k=last_k, ) return filtered_search_results diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py index 1bafb3754..12c15fb81 100644 --- a/cognee/infrastructure/databases/graph/kuzu/adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py @@ -1631,3 +1631,26 @@ class KuzuAdapter(GraphDBInterface): """ result = await self.query(query) return [record[0] for record in result] if result else [] + + async def get_last_user_interaction_ids(self, limit: int) -> List[str]: + """ + Retrieve the IDs of the most recent CogneeUserInteraction nodes. + Parameters: + ----------- + - limit (int): The maximum number of interaction IDs to return. + Returns: + -------- + - List[str]: A list of interaction IDs, sorted by created_at descending. + """ + + query = """ + MATCH (n) + WHERE n.type = 'CogneeUserInteraction' + RETURN n.id as id + ORDER BY n.created_at DESC + LIMIT $limit + """ + rows = await self.query(query, {"limit": limit}) + + id_list = [row[0] for row in rows] + return id_list diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index ea8072554..589848dc9 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -1322,3 +1322,26 @@ class Neo4jAdapter(GraphDBInterface): """ result = await self.query(query) return [record["n"] for record in result] if result else [] + + async def get_last_user_interaction_ids(self, limit: int) -> List[str]: + """ + Retrieve the IDs of the most recent CogneeUserInteraction nodes. + Parameters: + ----------- + - limit (int): The maximum number of interaction IDs to return. + Returns: + -------- + - List[str]: A list of interaction IDs, sorted by created_at descending. + """ + + query = """ + MATCH (n) + WHERE n.type = 'CogneeUserInteraction' + RETURN n.id as id + ORDER BY n.created_at DESC + LIMIT $limit + """ + rows = await self.query(query, {"limit": limit}) + + id_list = [row["id"] for row in rows if "id" in row] + return id_list diff --git a/cognee/modules/retrieval/base_feedback.py b/cognee/modules/retrieval/base_feedback.py new file mode 100644 index 000000000..62ad443ee --- /dev/null +++ b/cognee/modules/retrieval/base_feedback.py @@ -0,0 +1,11 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class BaseFeedback(ABC): + """Base class for all user feedback operations.""" + + @abstractmethod + async def add_feedback(self, feedback_text: str) -> Any: + """Add user feedback to the system.""" + pass \ No newline at end of file diff --git a/cognee/modules/retrieval/user_qa_feedback.py b/cognee/modules/retrieval/user_qa_feedback.py new file mode 100644 index 000000000..39f8c25f5 --- /dev/null +++ b/cognee/modules/retrieval/user_qa_feedback.py @@ -0,0 +1,78 @@ +from typing import Any, Optional, List + +from uuid import NAMESPACE_OID, uuid5, UUID +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.llm import LLMGateway +from cognee.modules.engine.models import NodeSet +from cognee.shared.logging_utils import get_logger +from cognee.modules.retrieval.base_feedback import BaseFeedback +from cognee.modules.retrieval.utils.models import CogneeUserFeedback +from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation +from cognee.tasks.storage import add_data_points + +logger = get_logger("CompletionRetriever") + + +class UserQAFeedback(BaseFeedback): + """ + Interface for handling user feedback queries. + Public methods: + - get_context(query: str) -> str + - get_completion(query: str, context: Optional[Any] = None) -> Any + """ + + def __init__(self, last_k: Optional[int] = 1) -> None: + """Initialize retriever with optional custom prompt paths.""" + self.last_k = last_k + + async def add_feedback(self, feedback_text: str) -> List[str]: + + feedback_sentiment = await LLMGateway.acreate_structured_output( + text_input=feedback_text, + system_prompt="You are a sentiment analysis assistant. For each piece of user feedback you receive, return exactly one of: Positive, Negative, or Neutral classification", + response_model=UserFeedbackEvaluation, + ) + + graph_engine = await get_graph_engine() + last_interaction_ids = await graph_engine.get_last_user_interaction_ids(limit=self.last_k) + + nodeset_name = "UserQAFeedbacks" + feedbacks_node_set = NodeSet(id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name) + feedback_id = uuid5(NAMESPACE_OID, name=feedback_text) + + cognee_user_feedback = CogneeUserFeedback( + id=feedback_id, + feedback=feedback_text, + sentiment=feedback_sentiment.evaluation.value, + belongs_to_set=feedbacks_node_set, + ) + + await add_data_points(data_points=[cognee_user_feedback], update_edge_collection=False) + + relationships = [] + relationship_name = "gives_feedback_to" + + for interaction_id in last_interaction_ids: + target_id_1 = feedback_id + target_id_2 = UUID(interaction_id) + + if target_id_1 and target_id_2: + relationships.append( + ( + target_id_1, + target_id_2, + relationship_name, + { + "relationship_name": relationship_name, + "source_node_id": target_id_1, + "target_node_id": target_id_2, + "ontology_valid": False, + }, + ) + ) + + if len(relationships) > 0: + graph_engine = await get_graph_engine() + await graph_engine.add_edges(relationships) + + return [feedback_text] \ No newline at end of file diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index ba11d7f72..5f5371af7 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -3,6 +3,8 @@ import json import asyncio from uuid import UUID from typing import Callable, List, Optional, Type, Union + +from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback from cognee.modules.search.exceptions import UnsupportedSearchTypeError from cognee.context_global_variables import set_database_global_context_variables from cognee.modules.retrieval.chunks_retriever import ChunksRetriever @@ -38,7 +40,8 @@ async def search( top_k: int = 10, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, - save_interaction: bool = False, + save_interaction: Optional[bool] = False, + last_k: Optional[int] = None, ): """ @@ -58,7 +61,14 @@ async def search( # Use search function filtered by permissions if access control is enabled if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": return await authorized_search( - query_text, query_type, user, dataset_ids, system_prompt_path, top_k, save_interaction + query_text=query_text, + query_type=query_type, + user=user, + dataset_ids=dataset_ids, + system_prompt_path=system_prompt_path, + top_k=top_k, + save_interaction=save_interaction, + last_k=last_k ) query = await log_query(query_text, query_type.value, user.id) @@ -72,6 +82,7 @@ async def search( node_type=node_type, node_name=node_name, save_interaction=save_interaction, + last_k=last_k ) await log_result( @@ -93,7 +104,8 @@ async def specific_search( top_k: int = 10, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, - save_interaction: bool = False, + save_interaction: Optional[bool] = False, + last_k: Optional[int] = None, ) -> list: search_tasks: dict[SearchType, Callable] = { SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion, @@ -133,6 +145,7 @@ async def specific_search( SearchType.CODE: CodeRetriever(top_k=top_k).get_completion, SearchType.CYPHER: CypherSearchRetriever().get_completion, SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion, + SearchType.FEEDBACK: UserQAFeedback(last_k=last_k).add_feedback, } # If the query type is FEELING_LUCKY, select the search type intelligently @@ -161,6 +174,7 @@ async def authorized_search( system_prompt_path: str = "answer_simple_question.txt", top_k: int = 10, save_interaction: bool = False, + last_k: Optional[int] = None, ) -> list: """ Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset. @@ -174,7 +188,7 @@ async def authorized_search( # Searches all provided datasets and handles setting up of appropriate database context based on permissions search_results = await specific_search_by_context( - search_datasets, query_text, query_type, user, system_prompt_path, top_k, save_interaction + search_datasets, query_text, query_type, user, system_prompt_path, top_k, save_interaction, last_k=last_k ) await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id) @@ -190,13 +204,14 @@ async def specific_search_by_context( system_prompt_path: str, top_k: int, save_interaction: bool = False, + last_k: Optional[int] = None, ): """ Searches all provided datasets and handles setting up of appropriate database context based on permissions. Not to be used outside of active access control mode. """ - async def _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k): + async def _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k, last_k): # Set database configuration in async context for each dataset user has access for await set_database_global_context_variables(dataset.id, dataset.owner_id) search_results = await specific_search( @@ -206,6 +221,7 @@ async def specific_search_by_context( system_prompt_path=system_prompt_path, top_k=top_k, save_interaction=save_interaction, + last_k=last_k, ) return { "search_result": search_results, @@ -217,7 +233,7 @@ async def specific_search_by_context( tasks = [] for dataset in search_datasets: tasks.append( - _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k) + _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k, last_k) ) return await asyncio.gather(*tasks) diff --git a/cognee/modules/search/types/SearchType.py b/cognee/modules/search/types/SearchType.py index 8248117e7..c1f0521b2 100644 --- a/cognee/modules/search/types/SearchType.py +++ b/cognee/modules/search/types/SearchType.py @@ -14,3 +14,4 @@ class SearchType(Enum): GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT" GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION" FEELING_LUCKY = "FEELING_LUCKY" + FEEDBACK = "FEEDBACK" From 0fbe218eefb51349b9f583c480fce1417e2c42ed Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 18 Aug 2025 18:36:04 +0200 Subject: [PATCH 10/17] chore: fixes ruff --- cognee/modules/retrieval/base_feedback.py | 2 +- cognee/modules/retrieval/user_qa_feedback.py | 3 +-- cognee/modules/search/methods/search.py | 21 +++++++++++++++----- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/cognee/modules/retrieval/base_feedback.py b/cognee/modules/retrieval/base_feedback.py index 62ad443ee..7da55d374 100644 --- a/cognee/modules/retrieval/base_feedback.py +++ b/cognee/modules/retrieval/base_feedback.py @@ -8,4 +8,4 @@ class BaseFeedback(ABC): @abstractmethod async def add_feedback(self, feedback_text: str) -> Any: """Add user feedback to the system.""" - pass \ No newline at end of file + pass diff --git a/cognee/modules/retrieval/user_qa_feedback.py b/cognee/modules/retrieval/user_qa_feedback.py index 39f8c25f5..bca59e7f8 100644 --- a/cognee/modules/retrieval/user_qa_feedback.py +++ b/cognee/modules/retrieval/user_qa_feedback.py @@ -26,7 +26,6 @@ class UserQAFeedback(BaseFeedback): self.last_k = last_k async def add_feedback(self, feedback_text: str) -> List[str]: - feedback_sentiment = await LLMGateway.acreate_structured_output( text_input=feedback_text, system_prompt="You are a sentiment analysis assistant. For each piece of user feedback you receive, return exactly one of: Positive, Negative, or Neutral classification", @@ -75,4 +74,4 @@ class UserQAFeedback(BaseFeedback): graph_engine = await get_graph_engine() await graph_engine.add_edges(relationships) - return [feedback_text] \ No newline at end of file + return [feedback_text] diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 5f5371af7..f5f2a793a 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -68,7 +68,7 @@ async def search( system_prompt_path=system_prompt_path, top_k=top_k, save_interaction=save_interaction, - last_k=last_k + last_k=last_k, ) query = await log_query(query_text, query_type.value, user.id) @@ -82,7 +82,7 @@ async def search( node_type=node_type, node_name=node_name, save_interaction=save_interaction, - last_k=last_k + last_k=last_k, ) await log_result( @@ -188,7 +188,14 @@ async def authorized_search( # Searches all provided datasets and handles setting up of appropriate database context based on permissions search_results = await specific_search_by_context( - search_datasets, query_text, query_type, user, system_prompt_path, top_k, save_interaction, last_k=last_k + search_datasets, + query_text, + query_type, + user, + system_prompt_path, + top_k, + save_interaction, + last_k=last_k, ) await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id) @@ -211,7 +218,9 @@ async def specific_search_by_context( Not to be used outside of active access control mode. """ - async def _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k, last_k): + async def _search_by_context( + dataset, user, query_type, query_text, system_prompt_path, top_k, last_k + ): # Set database configuration in async context for each dataset user has access for await set_database_global_context_variables(dataset.id, dataset.owner_id) search_results = await specific_search( @@ -233,7 +242,9 @@ async def specific_search_by_context( tasks = [] for dataset in search_datasets: tasks.append( - _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k, last_k) + _search_by_context( + dataset, user, query_type, query_text, system_prompt_path, top_k, last_k + ) ) return await asyncio.gather(*tasks) From 372181d8c1ca34915f81e49aa012d93d2793a4ac Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 19 Aug 2025 09:43:34 +0200 Subject: [PATCH 11/17] fix: fixes unit test --- cognee/tests/unit/modules/search/search_methods_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cognee/tests/unit/modules/search/search_methods_test.py b/cognee/tests/unit/modules/search/search_methods_test.py index 8645d965a..46995d087 100644 --- a/cognee/tests/unit/modules/search/search_methods_test.py +++ b/cognee/tests/unit/modules/search/search_methods_test.py @@ -66,6 +66,7 @@ async def test_search( node_type=None, node_name=None, save_interaction=False, + last_k=None, ) # Verify result logging From fcdee16f69e152bef56d8c020dee2a64d0080ef3 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 19 Aug 2025 10:49:01 +0200 Subject: [PATCH 12/17] feat: adds kuzu and neo4j tests for feedback and interaction features --- cognee/tests/test_search_db.py | 82 ++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index 8a99ebd8f..21dc1d3bf 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -4,6 +4,7 @@ import pathlib from dns.e164 import query import cognee +from cognee.infrastructure.databases.graph import get_graph_engine from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( @@ -18,6 +19,7 @@ from cognee.modules.users.methods import get_default_user from cognee.shared.logging_utils import get_logger from cognee.modules.search.types import SearchType from cognee.modules.engine.models import NodeSet +from collections import Counter logger = get_logger() @@ -112,18 +114,33 @@ async def main(): completion_gk = await cognee.search( query_type=SearchType.GRAPH_COMPLETION, query_text="Next to which country is Germany located?", + save_interaction=True, ) completion_cot = await cognee.search( query_type=SearchType.GRAPH_COMPLETION_COT, query_text="Next to which country is Germany located?", + save_interaction=True, ) completion_ext = await cognee.search( query_type=SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION, query_text="Next to which country is Germany located?", + save_interaction=True, ) + + feedback_sum_1 = await cognee.search( + query_type=SearchType.FEEDBACK, query_text="This was not the best answer", last_k=1 + ) + completion_sum = await cognee.search( query_type=SearchType.GRAPH_SUMMARY_COMPLETION, query_text="Next to which country is Germany located?", + save_interaction=True, + ) + + feedback_sum_2 = await cognee.search( + query_type=SearchType.FEEDBACK, + query_text="This answer was great", + last_k=1, ) for name, completion in [ @@ -141,6 +158,71 @@ async def main(): f"{name}: expected 'netherlands' in result, got: {text!r}" ) + graph_engine = await get_graph_engine() + graph = await graph_engine.get_graph_data() + + type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0]) + + edge_type_counts = Counter(edge_type[2] for edge_type in graph[1]) + + # Assert there are exactly 4 CogneeUserInteraction nodes. + assert type_counts.get("CogneeUserInteraction", 0) == 4, ( + f"Expected exactly four DCogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}" + ) + + # Assert there is exactly two CogneeUserFeedback nodes. + assert type_counts.get("CogneeUserFeedback", 0) == 2, ( + f"Expected exactly two CogneeUserFeedback nodes, but found {type_counts.get('CogneeUserFeedback', 0)}" + ) + + # Assert there is exactly two NodeSet. + assert type_counts.get("NodeSet", 0) == 2, ( + f"Expected exactly two NodeSet nodes, but found {type_counts.get('NodeSet', 0)}" + ) + + # Assert that there are at least 10 'used_graph_element_to_answer' edges. + assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10, ( + f"Expected at least ten 'used_graph_element_to_answer' edges, but found {edge_type_counts.get('used_graph_element_to_answer', 0)}" + ) + + # Assert that there are exactly 2 'gives_feedback_to' edges. + assert edge_type_counts.get("gives_feedback_to", 0) == 2, ( + f"Expected exactly two 'gives_feedback_to' edges, but found {edge_type_counts.get('gives_feedback_to', 0)}" + ) + + # Assert that there are at least 6 'belongs_to_set' edges. + assert edge_type_counts.get("belongs_to_set", 0) == 6, ( + f"Expected at least six 'belongs_to_set' edges, but found {edge_type_counts.get('belongs_to_set', 0)}" + ) + + nodes = graph[0] + + required_fields_user_interaction = {"question", "answer", "context"} + required_fields_feedback = {"feedback", "sentiment"} + + for node_id, data in nodes: # nodes = your list + if data.get("type") == "CogneeUserInteraction": + assert required_fields_user_interaction.issubset(data.keys()), ( + f"Node {node_id} is missing fields: {required_fields_user_interaction - set(data.keys())}" + ) + + for field in required_fields_user_interaction: + value = data[field] + assert isinstance(value, str) and value.strip(), ( + f"Node {node_id} has invalid value for '{field}': {value!r}" + ) + + if data.get("type") == "CogneeUserFeedback": + assert required_fields_feedback.issubset(data.keys()), ( + f"Node {node_id} is missing fields: {required_fields_feedback - set(data.keys())}" + ) + + for field in required_fields_feedback: + value = data[field] + assert isinstance(value, str) and value.strip(), ( + f"Node {node_id} has invalid value for '{field}': {value!r}" + ) + if __name__ == "__main__": import asyncio From 4e31ae7ffce67a32f8cbed47ba9b593dfd0a4256 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 19 Aug 2025 10:50:23 +0200 Subject: [PATCH 13/17] chore: deletes unused var from search test --- cognee/tests/test_search_db.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index 21dc1d3bf..f31c0076f 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -127,7 +127,7 @@ async def main(): save_interaction=True, ) - feedback_sum_1 = await cognee.search( + await cognee.search( query_type=SearchType.FEEDBACK, query_text="This was not the best answer", last_k=1 ) @@ -137,7 +137,7 @@ async def main(): save_interaction=True, ) - feedback_sum_2 = await cognee.search( + await cognee.search( query_type=SearchType.FEEDBACK, query_text="This answer was great", last_k=1, From c6ec22a5a0d70960a068a8ec95cfd8fd5083e7e0 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 19 Aug 2025 13:36:22 +0200 Subject: [PATCH 14/17] feat: adds scores to Feedback node --- cognee/modules/retrieval/user_qa_feedback.py | 3 ++- cognee/modules/retrieval/utils/models.py | 9 ++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/cognee/modules/retrieval/user_qa_feedback.py b/cognee/modules/retrieval/user_qa_feedback.py index bca59e7f8..45b5b0ec0 100644 --- a/cognee/modules/retrieval/user_qa_feedback.py +++ b/cognee/modules/retrieval/user_qa_feedback.py @@ -28,7 +28,7 @@ class UserQAFeedback(BaseFeedback): async def add_feedback(self, feedback_text: str) -> List[str]: feedback_sentiment = await LLMGateway.acreate_structured_output( text_input=feedback_text, - system_prompt="You are a sentiment analysis assistant. For each piece of user feedback you receive, return exactly one of: Positive, Negative, or Neutral classification", + system_prompt="You are a sentiment analysis assistant. For each piece of user feedback you receive, return exactly one of: Positive, Negative, or Neutral classification and a corresponding score from -5 (worst negative) to 5 (best positive)", response_model=UserFeedbackEvaluation, ) @@ -43,6 +43,7 @@ class UserQAFeedback(BaseFeedback): id=feedback_id, feedback=feedback_text, sentiment=feedback_sentiment.evaluation.value, + score=feedback_sentiment.score, belongs_to_set=feedbacks_node_set, ) diff --git a/cognee/modules/retrieval/utils/models.py b/cognee/modules/retrieval/utils/models.py index 69ffa9a5f..a71e881a9 100644 --- a/cognee/modules/retrieval/utils/models.py +++ b/cognee/modules/retrieval/utils/models.py @@ -2,8 +2,7 @@ from typing import Optional from cognee.infrastructure.engine.models.DataPoint import DataPoint from cognee.modules.engine.models.node_set import NodeSet from enum import Enum -from pydantic import BaseModel, ValidationError - +from pydantic import BaseModel, Field, confloat class CogneeUserInteraction(DataPoint): """User - Cognee interaction""" @@ -19,6 +18,7 @@ class CogneeUserFeedback(DataPoint): feedback: str sentiment: str + score: float belongs_to_set: Optional[NodeSet] = None @@ -32,5 +32,8 @@ class UserFeedbackSentiment(str, Enum): class UserFeedbackEvaluation(BaseModel): """User - User feedback evaluation""" - + score: confloat(ge=-5, le=5) = Field( + ..., + description="Sentiment score from -5 (negative) to +5 (positive)" + ) evaluation: UserFeedbackSentiment From 4a5d5f70d040fb4b6c214f0ff2da50b248f0ebfb Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 19 Aug 2025 16:50:21 +0200 Subject: [PATCH 15/17] feat: adds feedback weights to edges --- .../databases/graph/kuzu/adapter.py | 38 +++++++++++++++++++ .../databases/graph/neo4j_driver/adapter.py | 26 +++++++++++++ .../retrieval/graph_completion_retriever.py | 2 + cognee/modules/retrieval/user_qa_feedback.py | 8 ++++ 4 files changed, 74 insertions(+) diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py index 12c15fb81..95174ec0e 100644 --- a/cognee/infrastructure/databases/graph/kuzu/adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py @@ -1654,3 +1654,41 @@ class KuzuAdapter(GraphDBInterface): id_list = [row[0] for row in rows] return id_list + + async def apply_feedback_weight( + self, + node_ids: List[str], + weight: float, + ) -> None: + """ + Increment `feedback_weight` inside r.properties JSON for edges where + relationship_name = 'used_graph_element_to_answer'. + + """ + # Step 1: fetch matching edges + query = """ + MATCH (n:Node)-[r:EDGE]->() + WHERE n.id IN $node_ids AND r.relationship_name = 'used_graph_element_to_answer' + RETURN r.properties, n.id + """ + results = await self.query(query, {"node_ids": node_ids}) + + # Step 2: update JSON client-side + updates = [] + for props_json, source_id in results: + try: + props = json.loads(props_json) if props_json else {} + except json.JSONDecodeError: + props = {} + + props["feedback_weight"] = props.get("feedback_weight", 0) + weight + updates.append((source_id, json.dumps(props))) + + # Step 3: write back + for node_id, new_props in updates: + update_query = """ + MATCH (n:Node)-[r:EDGE]->() + WHERE n.id = $node_id AND r.relationship_name = 'used_graph_element_to_answer' + SET r.properties = $props + """ + await self.query(update_query, {"node_id": node_id, "props": new_props}) diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 589848dc9..3ff9cb5be 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -1345,3 +1345,29 @@ class Neo4jAdapter(GraphDBInterface): id_list = [row["id"] for row in rows if "id" in row] return id_list + + async def apply_feedback_weight( + self, + node_ids: List[str], + weight: float, + ) -> None: + """ + Increment `feedback_weight` on relationships `:used_graph_element_to_answer` + outgoing from nodes whose `id` is in `node_ids`. + + Args: + node_ids: List of node IDs to match. + weight: Amount to add to `r.feedback_weight` (can be negative). + + Side effects: + Updates relationship property `feedback_weight`, defaulting missing values to 0. + """ + query = """ + MATCH (n)-[r]->() + WHERE n.id IN $node_ids AND r.relationship_name = 'used_graph_element_to_answer' + SET r.feedback_weight = coalesce(r.feedback_weight, 0) + $weight + """ + await self.query( + query, + params={"weight": float(weight), "node_ids": list(node_ids)}, + ) diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index c831d8550..fb3cf4885 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -248,6 +248,7 @@ class GraphCompletionRetriever(BaseRetriever): "source_node_id": source_id, "target_node_id": target_id_1, "ontology_valid": False, + "feedback_weight": 0, }, ) ) @@ -262,6 +263,7 @@ class GraphCompletionRetriever(BaseRetriever): "source_node_id": source_id, "target_node_id": target_id_2, "ontology_valid": False, + "feedback_weight": 0, }, ) ) diff --git a/cognee/modules/retrieval/user_qa_feedback.py b/cognee/modules/retrieval/user_qa_feedback.py index 45b5b0ec0..55c59518a 100644 --- a/cognee/modules/retrieval/user_qa_feedback.py +++ b/cognee/modules/retrieval/user_qa_feedback.py @@ -51,6 +51,7 @@ class UserQAFeedback(BaseFeedback): relationships = [] relationship_name = "gives_feedback_to" + to_node_ids = [] for interaction_id in last_interaction_ids: target_id_1 = feedback_id @@ -70,9 +71,16 @@ class UserQAFeedback(BaseFeedback): }, ) ) + to_node_ids.append(str(target_id_2)) + if len(relationships) > 0: graph_engine = await get_graph_engine() await graph_engine.add_edges(relationships) + await graph_engine.apply_feedback_weight( + node_ids=to_node_ids, + weight=feedback_sentiment.score + ) + return [feedback_text] From f5d8fc6e81a8d972ac4095b18eec64d49dfa4f98 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 19 Aug 2025 16:50:50 +0200 Subject: [PATCH 16/17] chore: ruff ruff --- .../infrastructure/databases/graph/kuzu/adapter.py | 12 ++++++------ .../databases/graph/neo4j_driver/adapter.py | 6 +++--- cognee/modules/retrieval/user_qa_feedback.py | 5 +---- cognee/modules/retrieval/utils/models.py | 5 +++-- 4 files changed, 13 insertions(+), 15 deletions(-) diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py index 95174ec0e..dfe407a1d 100644 --- a/cognee/infrastructure/databases/graph/kuzu/adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py @@ -1656,15 +1656,15 @@ class KuzuAdapter(GraphDBInterface): return id_list async def apply_feedback_weight( - self, - node_ids: List[str], - weight: float, + self, + node_ids: List[str], + weight: float, ) -> None: """ - Increment `feedback_weight` inside r.properties JSON for edges where - relationship_name = 'used_graph_element_to_answer'. + Increment `feedback_weight` inside r.properties JSON for edges where + relationship_name = 'used_graph_element_to_answer'. - """ + """ # Step 1: fetch matching edges query = """ MATCH (n:Node)-[r:EDGE]->() diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 3ff9cb5be..f36296970 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -1347,9 +1347,9 @@ class Neo4jAdapter(GraphDBInterface): return id_list async def apply_feedback_weight( - self, - node_ids: List[str], - weight: float, + self, + node_ids: List[str], + weight: float, ) -> None: """ Increment `feedback_weight` on relationships `:used_graph_element_to_answer` diff --git a/cognee/modules/retrieval/user_qa_feedback.py b/cognee/modules/retrieval/user_qa_feedback.py index 55c59518a..f667f785f 100644 --- a/cognee/modules/retrieval/user_qa_feedback.py +++ b/cognee/modules/retrieval/user_qa_feedback.py @@ -73,14 +73,11 @@ class UserQAFeedback(BaseFeedback): ) to_node_ids.append(str(target_id_2)) - if len(relationships) > 0: graph_engine = await get_graph_engine() await graph_engine.add_edges(relationships) await graph_engine.apply_feedback_weight( - node_ids=to_node_ids, - weight=feedback_sentiment.score + node_ids=to_node_ids, weight=feedback_sentiment.score ) - return [feedback_text] diff --git a/cognee/modules/retrieval/utils/models.py b/cognee/modules/retrieval/utils/models.py index a71e881a9..58cea29a4 100644 --- a/cognee/modules/retrieval/utils/models.py +++ b/cognee/modules/retrieval/utils/models.py @@ -4,6 +4,7 @@ from cognee.modules.engine.models.node_set import NodeSet from enum import Enum from pydantic import BaseModel, Field, confloat + class CogneeUserInteraction(DataPoint): """User - Cognee interaction""" @@ -32,8 +33,8 @@ class UserFeedbackSentiment(str, Enum): class UserFeedbackEvaluation(BaseModel): """User - User feedback evaluation""" + score: confloat(ge=-5, le=5) = Field( - ..., - description="Sentiment score from -5 (negative) to +5 (positive)" + ..., description="Sentiment score from -5 (negative) to +5 (positive)" ) evaluation: UserFeedbackSentiment From b8cac4c29f9ad70edf413e977e7febf4b2a71013 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 19 Aug 2025 17:23:47 +0200 Subject: [PATCH 17/17] feat: adds weight test at the end of test_search_db --- cognee/tests/test_search_db.py | 39 +++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index f31c0076f..e7e11637f 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -200,7 +200,7 @@ async def main(): required_fields_user_interaction = {"question", "answer", "context"} required_fields_feedback = {"feedback", "sentiment"} - for node_id, data in nodes: # nodes = your list + for node_id, data in nodes: if data.get("type") == "CogneeUserInteraction": assert required_fields_user_interaction.issubset(data.keys()), ( f"Node {node_id} is missing fields: {required_fields_user_interaction - set(data.keys())}" @@ -223,6 +223,43 @@ async def main(): f"Node {node_id} has invalid value for '{field}': {value!r}" ) + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + await cognee.add(text_1, dataset_name) + + await cognee.add([text], dataset_name) + + await cognee.cognify([dataset_name]) + + await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text="Next to which country is Germany located?", + save_interaction=True, + ) + + await cognee.search( + query_type=SearchType.FEEDBACK, + query_text="This was the best answer I've ever seen", + last_k=1, + ) + + await cognee.search( + query_type=SearchType.FEEDBACK, + query_text="Wow the correctness of this answer blows my mind", + last_k=1, + ) + + graph = await graph_engine.get_graph_data() + + edges = graph[1] + + for from_node, to_node, relationship_name, properties in edges: + if relationship_name == "used_graph_element_to_answer": + assert properties["feedback_weight"] >= 6, ( + "Feedback weight calculation is not correct, it should be more then 6." + ) + if __name__ == "__main__": import asyncio