From fc43ac7a015226767bfb6d33d8b639435659565f Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 18 Aug 2025 17:54:49 +0200 Subject: [PATCH] feat: adds user feedback search type --- cognee/api/v1/search/search.py | 2 + .../databases/graph/kuzu/adapter.py | 23 ++++++ .../databases/graph/neo4j_driver/adapter.py | 23 ++++++ cognee/modules/retrieval/base_feedback.py | 11 +++ cognee/modules/retrieval/user_qa_feedback.py | 78 +++++++++++++++++++ cognee/modules/search/methods/search.py | 28 +++++-- cognee/modules/search/types/SearchType.py | 1 + 7 files changed, 160 insertions(+), 6 deletions(-) create mode 100644 cognee/modules/retrieval/base_feedback.py create mode 100644 cognee/modules/retrieval/user_qa_feedback.py diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index b4499192d..f37f8ba6d 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -20,6 +20,7 @@ async def search( node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + last_k: Optional[int] = None, ) -> list: """ Search and query the knowledge graph for insights, information, and connections. @@ -186,6 +187,7 @@ async def search( node_type=node_type, node_name=node_name, save_interaction=save_interaction, + last_k=last_k, ) return filtered_search_results diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py index 1bafb3754..12c15fb81 100644 --- a/cognee/infrastructure/databases/graph/kuzu/adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py @@ -1631,3 +1631,26 @@ class KuzuAdapter(GraphDBInterface): """ result = await self.query(query) return [record[0] for record in result] if result else [] + + async def get_last_user_interaction_ids(self, limit: int) -> List[str]: + """ + Retrieve the IDs of the most recent CogneeUserInteraction nodes. + Parameters: + ----------- + - limit (int): The maximum number of interaction IDs to return. + Returns: + -------- + - List[str]: A list of interaction IDs, sorted by created_at descending. + """ + + query = """ + MATCH (n) + WHERE n.type = 'CogneeUserInteraction' + RETURN n.id as id + ORDER BY n.created_at DESC + LIMIT $limit + """ + rows = await self.query(query, {"limit": limit}) + + id_list = [row[0] for row in rows] + return id_list diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index ea8072554..589848dc9 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -1322,3 +1322,26 @@ class Neo4jAdapter(GraphDBInterface): """ result = await self.query(query) return [record["n"] for record in result] if result else [] + + async def get_last_user_interaction_ids(self, limit: int) -> List[str]: + """ + Retrieve the IDs of the most recent CogneeUserInteraction nodes. + Parameters: + ----------- + - limit (int): The maximum number of interaction IDs to return. + Returns: + -------- + - List[str]: A list of interaction IDs, sorted by created_at descending. + """ + + query = """ + MATCH (n) + WHERE n.type = 'CogneeUserInteraction' + RETURN n.id as id + ORDER BY n.created_at DESC + LIMIT $limit + """ + rows = await self.query(query, {"limit": limit}) + + id_list = [row["id"] for row in rows if "id" in row] + return id_list diff --git a/cognee/modules/retrieval/base_feedback.py b/cognee/modules/retrieval/base_feedback.py new file mode 100644 index 000000000..62ad443ee --- /dev/null +++ b/cognee/modules/retrieval/base_feedback.py @@ -0,0 +1,11 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class BaseFeedback(ABC): + """Base class for all user feedback operations.""" + + @abstractmethod + async def add_feedback(self, feedback_text: str) -> Any: + """Add user feedback to the system.""" + pass \ No newline at end of file diff --git a/cognee/modules/retrieval/user_qa_feedback.py b/cognee/modules/retrieval/user_qa_feedback.py new file mode 100644 index 000000000..39f8c25f5 --- /dev/null +++ b/cognee/modules/retrieval/user_qa_feedback.py @@ -0,0 +1,78 @@ +from typing import Any, Optional, List + +from uuid import NAMESPACE_OID, uuid5, UUID +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.llm import LLMGateway +from cognee.modules.engine.models import NodeSet +from cognee.shared.logging_utils import get_logger +from cognee.modules.retrieval.base_feedback import BaseFeedback +from cognee.modules.retrieval.utils.models import CogneeUserFeedback +from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation +from cognee.tasks.storage import add_data_points + +logger = get_logger("CompletionRetriever") + + +class UserQAFeedback(BaseFeedback): + """ + Interface for handling user feedback queries. + Public methods: + - get_context(query: str) -> str + - get_completion(query: str, context: Optional[Any] = None) -> Any + """ + + def __init__(self, last_k: Optional[int] = 1) -> None: + """Initialize retriever with optional custom prompt paths.""" + self.last_k = last_k + + async def add_feedback(self, feedback_text: str) -> List[str]: + + feedback_sentiment = await LLMGateway.acreate_structured_output( + text_input=feedback_text, + system_prompt="You are a sentiment analysis assistant. For each piece of user feedback you receive, return exactly one of: Positive, Negative, or Neutral classification", + response_model=UserFeedbackEvaluation, + ) + + graph_engine = await get_graph_engine() + last_interaction_ids = await graph_engine.get_last_user_interaction_ids(limit=self.last_k) + + nodeset_name = "UserQAFeedbacks" + feedbacks_node_set = NodeSet(id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name) + feedback_id = uuid5(NAMESPACE_OID, name=feedback_text) + + cognee_user_feedback = CogneeUserFeedback( + id=feedback_id, + feedback=feedback_text, + sentiment=feedback_sentiment.evaluation.value, + belongs_to_set=feedbacks_node_set, + ) + + await add_data_points(data_points=[cognee_user_feedback], update_edge_collection=False) + + relationships = [] + relationship_name = "gives_feedback_to" + + for interaction_id in last_interaction_ids: + target_id_1 = feedback_id + target_id_2 = UUID(interaction_id) + + if target_id_1 and target_id_2: + relationships.append( + ( + target_id_1, + target_id_2, + relationship_name, + { + "relationship_name": relationship_name, + "source_node_id": target_id_1, + "target_node_id": target_id_2, + "ontology_valid": False, + }, + ) + ) + + if len(relationships) > 0: + graph_engine = await get_graph_engine() + await graph_engine.add_edges(relationships) + + return [feedback_text] \ No newline at end of file diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index ba11d7f72..5f5371af7 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -3,6 +3,8 @@ import json import asyncio from uuid import UUID from typing import Callable, List, Optional, Type, Union + +from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback from cognee.modules.search.exceptions import UnsupportedSearchTypeError from cognee.context_global_variables import set_database_global_context_variables from cognee.modules.retrieval.chunks_retriever import ChunksRetriever @@ -38,7 +40,8 @@ async def search( top_k: int = 10, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, - save_interaction: bool = False, + save_interaction: Optional[bool] = False, + last_k: Optional[int] = None, ): """ @@ -58,7 +61,14 @@ async def search( # Use search function filtered by permissions if access control is enabled if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": return await authorized_search( - query_text, query_type, user, dataset_ids, system_prompt_path, top_k, save_interaction + query_text=query_text, + query_type=query_type, + user=user, + dataset_ids=dataset_ids, + system_prompt_path=system_prompt_path, + top_k=top_k, + save_interaction=save_interaction, + last_k=last_k ) query = await log_query(query_text, query_type.value, user.id) @@ -72,6 +82,7 @@ async def search( node_type=node_type, node_name=node_name, save_interaction=save_interaction, + last_k=last_k ) await log_result( @@ -93,7 +104,8 @@ async def specific_search( top_k: int = 10, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, - save_interaction: bool = False, + save_interaction: Optional[bool] = False, + last_k: Optional[int] = None, ) -> list: search_tasks: dict[SearchType, Callable] = { SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion, @@ -133,6 +145,7 @@ async def specific_search( SearchType.CODE: CodeRetriever(top_k=top_k).get_completion, SearchType.CYPHER: CypherSearchRetriever().get_completion, SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion, + SearchType.FEEDBACK: UserQAFeedback(last_k=last_k).add_feedback, } # If the query type is FEELING_LUCKY, select the search type intelligently @@ -161,6 +174,7 @@ async def authorized_search( system_prompt_path: str = "answer_simple_question.txt", top_k: int = 10, save_interaction: bool = False, + last_k: Optional[int] = None, ) -> list: """ Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset. @@ -174,7 +188,7 @@ async def authorized_search( # Searches all provided datasets and handles setting up of appropriate database context based on permissions search_results = await specific_search_by_context( - search_datasets, query_text, query_type, user, system_prompt_path, top_k, save_interaction + search_datasets, query_text, query_type, user, system_prompt_path, top_k, save_interaction, last_k=last_k ) await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id) @@ -190,13 +204,14 @@ async def specific_search_by_context( system_prompt_path: str, top_k: int, save_interaction: bool = False, + last_k: Optional[int] = None, ): """ Searches all provided datasets and handles setting up of appropriate database context based on permissions. Not to be used outside of active access control mode. """ - async def _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k): + async def _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k, last_k): # Set database configuration in async context for each dataset user has access for await set_database_global_context_variables(dataset.id, dataset.owner_id) search_results = await specific_search( @@ -206,6 +221,7 @@ async def specific_search_by_context( system_prompt_path=system_prompt_path, top_k=top_k, save_interaction=save_interaction, + last_k=last_k, ) return { "search_result": search_results, @@ -217,7 +233,7 @@ async def specific_search_by_context( tasks = [] for dataset in search_datasets: tasks.append( - _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k) + _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k, last_k) ) return await asyncio.gather(*tasks) diff --git a/cognee/modules/search/types/SearchType.py b/cognee/modules/search/types/SearchType.py index 8248117e7..c1f0521b2 100644 --- a/cognee/modules/search/types/SearchType.py +++ b/cognee/modules/search/types/SearchType.py @@ -14,3 +14,4 @@ class SearchType(Enum): GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT" GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION" FEELING_LUCKY = "FEELING_LUCKY" + FEEDBACK = "FEEDBACK"