diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py
index eb245f545..61c3861b6 100644
--- a/cognee/api/v1/search/search.py
+++ b/cognee/api/v1/search/search.py
@@ -19,6 +19,7 @@ async def search(
     top_k: int = 10,
     node_type: Optional[Type] = None,
     node_name: Optional[List[str]] = None,
+    last_k: Optional[int] = None,
 ) -> list:
     """
     Search and query the knowledge graph for insights, information, and connections.
@@ -101,6 +102,8 @@ async def search(
 
         node_name: Filter results to specific named entities (for targeted search).
 
+        last_k: Number of most recent user interactions that FEEDBACK is linked to.
+
     Returns:
         list: Search results in format determined by query_type:
 
@@ -179,6 +182,7 @@ async def search(
         top_k=top_k,
         node_type=node_type,
         node_name=node_name,
+        last_k=last_k,
     )
 
     return filtered_search_results
diff --git a/cognee/infrastructure/databases/graph/graph_db_interface.py b/cognee/infrastructure/databases/graph/graph_db_interface.py
index 16600b386..973568cea 100644
--- a/cognee/infrastructure/databases/graph/graph_db_interface.py
+++ b/cognee/infrastructure/databases/graph/graph_db_interface.py
@@ -388,3 +388,14 @@ class GraphDBInterface(ABC):
         - node_id (str): Unique identifier of the node for which to retrieve connections.
         """
         raise NotImplementedError
+
+    async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
+        """
+        Get the ids of the last n user interaction nodes from the graph.
+
+        Parameters:
+        -----------
+
+        - limit (int): Number of most recent user interaction node ids to retrieve.
+        """
+        raise NotImplementedError
diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
index b23bf8e00..f9b89559e 100644
--- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
+++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
@@ -1287,3 +1287,28 @@ class Neo4jAdapter(GraphDBInterface):
         """
         result = await self.query(query)
         return [record["n"] for record in result] if result else []
+
+    async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
+        """
+        Retrieve the IDs of the most recent CogneeUserInteraction nodes.
+
+        Parameters:
+        -----------
+        - limit (int): The maximum number of interaction IDs to return.
+
+        Returns:
+        --------
+        - List[str]: A list of interaction IDs, sorted by created_at descending.
+        """
+
+        query = """
+        MATCH (n)
+        WHERE n.type = 'CogneeUserInteraction'
+        RETURN n.id as id
+        ORDER BY n.created_at DESC
+        LIMIT $limit
+        """
+        rows = await self.query(query, {"limit": limit})
+
+        # Guard against an empty/None result, as the method above does.
+        return [row["id"] for row in (rows or []) if "id" in row]
diff --git a/cognee/modules/retrieval/base_feedback.py b/cognee/modules/retrieval/base_feedback.py
new file mode 100644
index 000000000..932dfbf9b
--- /dev/null
+++ b/cognee/modules/retrieval/base_feedback.py
@@ -0,0 +1,11 @@
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Callable
+
+
+class BaseFeedback(ABC):
+    """Base class for all user feedback operations."""
+
+    @abstractmethod
+    async def add_feedback(self, feedback_text: str) -> Any:
+        """Stores the given user feedback text."""
+        pass
diff --git a/cognee/modules/retrieval/user_qa_feedback.py b/cognee/modules/retrieval/user_qa_feedback.py
new file mode 100644
index 000000000..5897d80ab
--- /dev/null
+++ b/cognee/modules/retrieval/user_qa_feedback.py
@@ -0,0 +1,85 @@
+from typing import Any, Optional
+
+from uuid import NAMESPACE_OID, uuid5, UUID
+from cognee.infrastructure.databases.graph import get_graph_engine
+from cognee.modules.engine.models import NodeSet
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.retrieval.base_feedback import BaseFeedback
+from cognee.modules.retrieval.utils.models import CogneeUserFeedback
+from cognee.tasks.storage import add_data_points
+
+logger = get_logger("UserQAFeedback")
+
+
+class UserQAFeedback(BaseFeedback):
+    """
+    Stores user feedback and links it to the most recent user interactions.
+
+    Public methods:
+    - add_feedback(feedback_text: str) -> Any
+    """
+
+    def __init__(self, last_k: Optional[int] = 5):
+        """
+        Parameters:
+        -----------
+        - last_k (Optional[int]): Number of most recent interactions to attach the
+          feedback to. None falls back to the default of 5.
+        """
+        # The search entry points pass last_k=None when the caller omits it; a None
+        # limit is not usable in the graph query, so restore the default here.
+        self.last_k = 5 if last_k is None else last_k
+
+    async def add_feedback(self, feedback_text: str) -> Any:
+        """
+        Save the feedback as a CogneeUserFeedback node and connect it to the most
+        recent interaction nodes with "gives_feedback_to" edges.
+
+        Parameters:
+        -----------
+        - feedback_text (str): The feedback to store.
+
+        Returns:
+        --------
+        - list: A single-element list containing feedback_text.
+        """
+        graph_engine = await get_graph_engine()
+
+        last_interaction_ids = await graph_engine.get_last_user_interaction_ids(limit=self.last_k)
+
+        nodeset_name = "UserQAFeedbacks"
+        feedbacks_node_set = NodeSet(id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name)
+        # uuid5 is deterministic, so identical feedback text yields the same node id.
+        feedback_id = uuid5(NAMESPACE_OID, name=feedback_text)
+
+        cognee_user_feedback = CogneeUserFeedback(
+            id=feedback_id,
+            feedback=feedback_text,
+            belongs_to_set=feedbacks_node_set,
+        )
+
+        await add_data_points(data_points=[cognee_user_feedback])
+
+        relationship_name = "gives_feedback_to"
+        relationships = []
+
+        for interaction_id in last_interaction_ids:
+            interaction_uuid = UUID(interaction_id)
+            relationships.append(
+                (
+                    feedback_id,
+                    interaction_uuid,
+                    relationship_name,
+                    {
+                        "relationship_name": relationship_name,
+                        "source_node_id": feedback_id,
+                        "target_node_id": interaction_uuid,
+                        "ontology_valid": False,
+                    },
+                )
+            )
+
+        if relationships:
+            await graph_engine.add_edges(relationships)
+
+        return [feedback_text]
diff --git a/cognee/modules/retrieval/utils/models.py b/cognee/modules/retrieval/utils/models.py
index 895c280ae..42c49b799 100644
--- a/cognee/modules/retrieval/utils/models.py
+++ b/cognee/modules/retrieval/utils/models.py
@@ -10,3 +10,10 @@ class CogneeUserInteraction(DataPoint):
     answer: str
     context: str
     belongs_to_set: Optional[NodeSet] = None
+
+
+class CogneeUserFeedback(DataPoint):
+    """User - Cognee Feedback"""
+
+    feedback: str
+    belongs_to_set: Optional[NodeSet] = None
diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py
index 1eff23c4a..83c5b5a3a 100644
--- a/cognee/modules/search/methods/search.py
+++ b/cognee/modules/search/methods/search.py
@@ -21,6 +21,7 @@ from cognee.modules.retrieval.graph_completion_context_extension_retriever impor
 from cognee.modules.retrieval.code_retriever import CodeRetriever
 from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
 from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
+from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
 from cognee.modules.search.types import SearchType
 from cognee.modules.storage.utils import JSONEncoder
 from cognee.modules.users.models import User
@@ -39,6 +40,7 @@ async def search(
     top_k: int = 10,
     node_type: Optional[Type] = None,
     node_name: Optional[List[str]] = None,
+    last_k: Optional[int] = None,
 ):
     """
 
@@ -49,6 +51,7 @@ async def search(
         user:
         system_prompt_path:
         top_k:
+        last_k:
 
     Returns:
 
@@ -71,6 +74,7 @@ async def search(
         top_k=top_k,
         node_type=node_type,
         node_name=node_name,
+        last_k=last_k,
     )
 
     await log_result(
@@ -92,6 +96,7 @@ async def specific_search(
     top_k: int = 10,
     node_type: Optional[Type] = None,
     node_name: Optional[List[str]] = None,
+    last_k: Optional[int] = None,
 ) -> list:
     search_tasks: dict[SearchType, Callable] = {
         SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
@@ -127,6 +132,7 @@ async def specific_search(
         SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
         SearchType.CYPHER: CypherSearchRetriever().get_completion,
         SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
+        SearchType.FEEDBACK: UserQAFeedback(last_k=last_k).add_feedback,
     }
 
     search_task = search_tasks.get(query_type)
diff --git a/cognee/modules/search/types/SearchType.py b/cognee/modules/search/types/SearchType.py
index 1c672f0f0..dbcb4dd1d 100644
--- a/cognee/modules/search/types/SearchType.py
+++ b/cognee/modules/search/types/SearchType.py
@@ -13,3 +13,4 @@ class SearchType(Enum):
     NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
     GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
     GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION"
+    FEEDBACK = "FEEDBACK"