feat: adds user feedback search type

This commit is contained in:
hajdul88 2025-08-18 17:54:49 +02:00
parent 9a46d145bb
commit fc43ac7a01
7 changed files with 160 additions and 6 deletions

View file

@ -20,6 +20,7 @@ async def search(
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
save_interaction: bool = False, save_interaction: bool = False,
last_k: Optional[int] = None,
) -> list: ) -> list:
""" """
Search and query the knowledge graph for insights, information, and connections. Search and query the knowledge graph for insights, information, and connections.
@ -186,6 +187,7 @@ async def search(
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
last_k=last_k,
) )
return filtered_search_results return filtered_search_results

View file

@ -1631,3 +1631,26 @@ class KuzuAdapter(GraphDBInterface):
""" """
result = await self.query(query) result = await self.query(query)
return [record[0] for record in result] if result else [] return [record[0] for record in result] if result else []
async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
"""
Retrieve the IDs of the most recent CogneeUserInteraction nodes.
Parameters:
-----------
- limit (int): The maximum number of interaction IDs to return.
Returns:
--------
- List[str]: A list of interaction IDs, sorted by created_at descending.
"""
query = """
MATCH (n)
WHERE n.type = 'CogneeUserInteraction'
RETURN n.id as id
ORDER BY n.created_at DESC
LIMIT $limit
"""
rows = await self.query(query, {"limit": limit})
id_list = [row[0] for row in rows]
return id_list

View file

@ -1322,3 +1322,26 @@ class Neo4jAdapter(GraphDBInterface):
""" """
result = await self.query(query) result = await self.query(query)
return [record["n"] for record in result] if result else [] return [record["n"] for record in result] if result else []
async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
"""
Retrieve the IDs of the most recent CogneeUserInteraction nodes.
Parameters:
-----------
- limit (int): The maximum number of interaction IDs to return.
Returns:
--------
- List[str]: A list of interaction IDs, sorted by created_at descending.
"""
query = """
MATCH (n)
WHERE n.type = 'CogneeUserInteraction'
RETURN n.id as id
ORDER BY n.created_at DESC
LIMIT $limit
"""
rows = await self.query(query, {"limit": limit})
id_list = [row["id"] for row in rows if "id" in row]
return id_list

View file

@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
from typing import Any
class BaseFeedback(ABC):
"""Base class for all user feedback operations."""
@abstractmethod
async def add_feedback(self, feedback_text: str) -> Any:
"""Add user feedback to the system."""
pass

View file

@ -0,0 +1,78 @@
from typing import Any, Optional, List
from uuid import NAMESPACE_OID, uuid5, UUID
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.llm import LLMGateway
from cognee.modules.engine.models import NodeSet
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.base_feedback import BaseFeedback
from cognee.modules.retrieval.utils.models import CogneeUserFeedback
from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation
from cognee.tasks.storage import add_data_points
logger = get_logger("CompletionRetriever")
class UserQAFeedback(BaseFeedback):
"""
Interface for handling user feedback queries.
Public methods:
- get_context(query: str) -> str
- get_completion(query: str, context: Optional[Any] = None) -> Any
"""
def __init__(self, last_k: Optional[int] = 1) -> None:
"""Initialize retriever with optional custom prompt paths."""
self.last_k = last_k
async def add_feedback(self, feedback_text: str) -> List[str]:
feedback_sentiment = await LLMGateway.acreate_structured_output(
text_input=feedback_text,
system_prompt="You are a sentiment analysis assistant. For each piece of user feedback you receive, return exactly one of: Positive, Negative, or Neutral classification",
response_model=UserFeedbackEvaluation,
)
graph_engine = await get_graph_engine()
last_interaction_ids = await graph_engine.get_last_user_interaction_ids(limit=self.last_k)
nodeset_name = "UserQAFeedbacks"
feedbacks_node_set = NodeSet(id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name)
feedback_id = uuid5(NAMESPACE_OID, name=feedback_text)
cognee_user_feedback = CogneeUserFeedback(
id=feedback_id,
feedback=feedback_text,
sentiment=feedback_sentiment.evaluation.value,
belongs_to_set=feedbacks_node_set,
)
await add_data_points(data_points=[cognee_user_feedback], update_edge_collection=False)
relationships = []
relationship_name = "gives_feedback_to"
for interaction_id in last_interaction_ids:
target_id_1 = feedback_id
target_id_2 = UUID(interaction_id)
if target_id_1 and target_id_2:
relationships.append(
(
target_id_1,
target_id_2,
relationship_name,
{
"relationship_name": relationship_name,
"source_node_id": target_id_1,
"target_node_id": target_id_2,
"ontology_valid": False,
},
)
)
if len(relationships) > 0:
graph_engine = await get_graph_engine()
await graph_engine.add_edges(relationships)
return [feedback_text]

View file

@ -3,6 +3,8 @@ import json
import asyncio import asyncio
from uuid import UUID from uuid import UUID
from typing import Callable, List, Optional, Type, Union from typing import Callable, List, Optional, Type, Union
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
from cognee.modules.search.exceptions import UnsupportedSearchTypeError from cognee.modules.search.exceptions import UnsupportedSearchTypeError
from cognee.context_global_variables import set_database_global_context_variables from cognee.context_global_variables import set_database_global_context_variables
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
@ -38,7 +40,8 @@ async def search(
top_k: int = 10, top_k: int = 10,
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
save_interaction: bool = False, save_interaction: Optional[bool] = False,
last_k: Optional[int] = None,
): ):
""" """
@ -58,7 +61,14 @@ async def search(
# Use search function filtered by permissions if access control is enabled # Use search function filtered by permissions if access control is enabled
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
return await authorized_search( return await authorized_search(
query_text, query_type, user, dataset_ids, system_prompt_path, top_k, save_interaction query_text=query_text,
query_type=query_type,
user=user,
dataset_ids=dataset_ids,
system_prompt_path=system_prompt_path,
top_k=top_k,
save_interaction=save_interaction,
last_k=last_k
) )
query = await log_query(query_text, query_type.value, user.id) query = await log_query(query_text, query_type.value, user.id)
@ -72,6 +82,7 @@ async def search(
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
save_interaction=save_interaction, save_interaction=save_interaction,
last_k=last_k
) )
await log_result( await log_result(
@ -93,7 +104,8 @@ async def specific_search(
top_k: int = 10, top_k: int = 10,
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
save_interaction: bool = False, save_interaction: Optional[bool] = False,
last_k: Optional[int] = None,
) -> list: ) -> list:
search_tasks: dict[SearchType, Callable] = { search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion, SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
@ -133,6 +145,7 @@ async def specific_search(
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion, SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
SearchType.CYPHER: CypherSearchRetriever().get_completion, SearchType.CYPHER: CypherSearchRetriever().get_completion,
SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion, SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
SearchType.FEEDBACK: UserQAFeedback(last_k=last_k).add_feedback,
} }
# If the query type is FEELING_LUCKY, select the search type intelligently # If the query type is FEELING_LUCKY, select the search type intelligently
@ -161,6 +174,7 @@ async def authorized_search(
system_prompt_path: str = "answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
top_k: int = 10, top_k: int = 10,
save_interaction: bool = False, save_interaction: bool = False,
last_k: Optional[int] = None,
) -> list: ) -> list:
""" """
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset. Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
@ -174,7 +188,7 @@ async def authorized_search(
# Searches all provided datasets and handles setting up of appropriate database context based on permissions # Searches all provided datasets and handles setting up of appropriate database context based on permissions
search_results = await specific_search_by_context( search_results = await specific_search_by_context(
search_datasets, query_text, query_type, user, system_prompt_path, top_k, save_interaction search_datasets, query_text, query_type, user, system_prompt_path, top_k, save_interaction, last_k=last_k
) )
await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id) await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id)
@ -190,13 +204,14 @@ async def specific_search_by_context(
system_prompt_path: str, system_prompt_path: str,
top_k: int, top_k: int,
save_interaction: bool = False, save_interaction: bool = False,
last_k: Optional[int] = None,
): ):
""" """
Searches all provided datasets and handles setting up of appropriate database context based on permissions. Searches all provided datasets and handles setting up of appropriate database context based on permissions.
Not to be used outside of active access control mode. Not to be used outside of active access control mode.
""" """
async def _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k): async def _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k, last_k):
# Set database configuration in async context for each dataset user has access for # Set database configuration in async context for each dataset user has access for
await set_database_global_context_variables(dataset.id, dataset.owner_id) await set_database_global_context_variables(dataset.id, dataset.owner_id)
search_results = await specific_search( search_results = await specific_search(
@ -206,6 +221,7 @@ async def specific_search_by_context(
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
top_k=top_k, top_k=top_k,
save_interaction=save_interaction, save_interaction=save_interaction,
last_k=last_k,
) )
return { return {
"search_result": search_results, "search_result": search_results,
@ -217,7 +233,7 @@ async def specific_search_by_context(
tasks = [] tasks = []
for dataset in search_datasets: for dataset in search_datasets:
tasks.append( tasks.append(
_search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k) _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k, last_k)
) )
return await asyncio.gather(*tasks) return await asyncio.gather(*tasks)

View file

@ -14,3 +14,4 @@ class SearchType(Enum):
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT" GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION" GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION"
FEELING_LUCKY = "FEELING_LUCKY" FEELING_LUCKY = "FEELING_LUCKY"
FEEDBACK = "FEEDBACK"