feat: adds user feedback feature to retrievers and neo4j

This commit is contained in:
hajdul88 2025-07-30 16:24:14 +02:00
parent 9d46453a7a
commit 65fbfe61c1
8 changed files with 136 additions and 0 deletions

View file

@ -19,6 +19,7 @@ async def search(
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
last_k: Optional[int] = None,
) -> list:
"""
Search and query the knowledge graph for insights, information, and connections.
@ -101,6 +102,8 @@ async def search(
node_name: Filter results to specific named entities (for targeted search).
last_k: Defines the number of historical answers to give the feedback to.
Returns:
list: Search results in format determined by query_type:
@ -179,6 +182,7 @@ async def search(
top_k=top_k,
node_type=node_type,
node_name=node_name,
last_k=last_k,
)
return filtered_search_results

View file

@ -388,3 +388,14 @@ class GraphDBInterface(ABC):
- node_id (str): Unique identifier of the node for which to retrieve connections.
"""
raise NotImplementedError
async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
"""
Get the last n userintaraction node ids from the graph
Parameters:
-----------
- limit (str): Last n userintaraction node ids to retrieve.
"""
raise NotImplementedError

View file

@ -1287,3 +1287,28 @@ class Neo4jAdapter(GraphDBInterface):
"""
result = await self.query(query)
return [record["n"] for record in result] if result else []
async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
"""
Retrieve the IDs of the most recent CogneeUserInteraction nodes.
Parameters:
-----------
- limit (int): The maximum number of interaction IDs to return.
Returns:
--------
- List[str]: A list of interaction IDs, sorted by created_at descending.
"""
query = """
MATCH (n)
WHERE n.type = 'CogneeUserInteraction'
RETURN n.id as id
ORDER BY n.created_at DESC
LIMIT $limit
"""
rows = await self.query(query, {"limit": limit})
id_list = [row["id"] for row in rows if "id" in row]
return id_list

View file

@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Callable
class BaseFeedback(ABC):
"""Base class for all user feedback operations."""
@abstractmethod
async def add_feedback(self, feedback_text: str) -> Any:
"""Retrieves context based on the query."""
pass

View file

@ -0,0 +1,71 @@
from typing import Any, Optional
from uuid import NAMESPACE_OID, uuid5, UUID
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.engine.models import NodeSet
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.base_feedback import BaseFeedback
from cognee.modules.retrieval.utils.models import CogneeUserFeedback
from cognee.tasks.storage import add_data_points
logger = get_logger("CompletionRetriever")
class UserQAFeedback(BaseFeedback):
"""
Interface for handling user feedback queries.
Public methods:
- get_context(query: str) -> str
- get_completion(query: str, context: Optional[Any] = None) -> Any
"""
def __init__(self, last_k: Optional[int] = 5):
"""Initialize retriever with optional custom prompt paths."""
self.last_k = last_k
async def add_feedback(self, feedback_text: str) -> Any:
graph_engine = await get_graph_engine()
last_interaction_ids = await graph_engine.get_last_user_interaction_ids(limit=self.last_k)
print()
nodeset_name = "UserQAFeedbacks"
feedbacks_node_set = NodeSet(id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name)
feedback_id = uuid5(NAMESPACE_OID, name=feedback_text)
cognee_user_feedback = CogneeUserFeedback(
id=feedback_id,
feedback=feedback_text,
belongs_to_set=feedbacks_node_set,
)
await add_data_points(data_points=[cognee_user_feedback])
relationships = []
relationship_name = "gives_feedback_to"
for interaction_id in last_interaction_ids:
target_id_1 = feedback_id
target_id_2 = UUID(interaction_id)
if target_id_1 and target_id_2:
relationships.append(
(
target_id_1,
target_id_2,
relationship_name,
{
"relationship_name": relationship_name,
"source_node_id": target_id_1,
"target_node_id": target_id_2,
"ontology_valid": False,
},
)
)
if len(relationships) > 0:
graph_engine = await get_graph_engine()
await graph_engine.add_edges(relationships)
return [feedback_text]

View file

@ -10,3 +10,10 @@ class CogneeUserInteraction(DataPoint):
answer: str
context: str
belongs_to_set: Optional[NodeSet] = None
class CogneeUserFeedback(DataPoint):
"""User - Cognee Feedback"""
feedback: str
belongs_to_set: Optional[NodeSet] = None

View file

@ -21,6 +21,7 @@ from cognee.modules.retrieval.graph_completion_context_extension_retriever impor
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
from cognee.modules.search.types import SearchType
from cognee.modules.storage.utils import JSONEncoder
from cognee.modules.users.models import User
@ -39,6 +40,7 @@ async def search(
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
last_k: Optional[int] = None,
):
"""
@ -49,6 +51,7 @@ async def search(
user:
system_prompt_path:
top_k:
last_k:
Returns:
@ -71,6 +74,7 @@ async def search(
top_k=top_k,
node_type=node_type,
node_name=node_name,
last_k=last_k,
)
await log_result(
@ -92,6 +96,7 @@ async def specific_search(
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
last_k: Optional[int] = None,
) -> list:
search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
@ -127,6 +132,7 @@ async def specific_search(
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
SearchType.CYPHER: CypherSearchRetriever().get_completion,
SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
SearchType.FEEDBACK: UserQAFeedback(last_k=last_k).add_feedback,
}
search_task = search_tasks.get(query_type)

View file

@ -13,3 +13,4 @@ class SearchType(Enum):
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION"
FEEDBACK = "FEEDBACK"