feat: adds user feedback search type
This commit is contained in:
parent
9a46d145bb
commit
fc43ac7a01
7 changed files with 160 additions and 6 deletions
|
|
@ -20,6 +20,7 @@ async def search(
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = None,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: bool = False,
|
save_interaction: bool = False,
|
||||||
|
last_k: Optional[int] = None,
|
||||||
) -> list:
|
) -> list:
|
||||||
"""
|
"""
|
||||||
Search and query the knowledge graph for insights, information, and connections.
|
Search and query the knowledge graph for insights, information, and connections.
|
||||||
|
|
@ -186,6 +187,7 @@ async def search(
|
||||||
node_type=node_type,
|
node_type=node_type,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
|
last_k=last_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
return filtered_search_results
|
return filtered_search_results
|
||||||
|
|
|
||||||
|
|
@ -1631,3 +1631,26 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
"""
|
"""
|
||||||
result = await self.query(query)
|
result = await self.query(query)
|
||||||
return [record[0] for record in result] if result else []
|
return [record[0] for record in result] if result else []
|
||||||
|
|
||||||
|
async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
    """
    Retrieve the IDs of the most recent CogneeUserInteraction nodes.

    Parameters:
    -----------

        - limit (int): The maximum number of interaction IDs to return. ``None`` or a
          non-positive value yields an empty list without querying the database.

    Returns:
    --------

        - List[str]: A list of interaction IDs, sorted by created_at descending.
          Empty when nothing matches or the query returns no result.
    """
    # LIMIT needs a non-negative integer parameter; callers may forward an unset
    # (None) value, so short-circuit instead of sending an invalid query.
    if limit is None or limit <= 0:
        return []

    query = """
    MATCH (n)
    WHERE n.type = 'CogneeUserInteraction'
    RETURN n.id as id
    ORDER BY n.created_at DESC
    LIMIT $limit
    """
    rows = await self.query(query, {"limit": limit})

    # Same guard as the other query helpers in this adapter: a falsy result
    # (None or an empty list) means there is nothing to return.
    return [row[0] for row in rows] if rows else []
|
||||||
|
|
|
||||||
|
|
@ -1322,3 +1322,26 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
"""
|
"""
|
||||||
result = await self.query(query)
|
result = await self.query(query)
|
||||||
return [record["n"] for record in result] if result else []
|
return [record["n"] for record in result] if result else []
|
||||||
|
|
||||||
|
async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
    """
    Retrieve the IDs of the most recent CogneeUserInteraction nodes.

    Parameters:
    -----------

        - limit (int): The maximum number of interaction IDs to return. ``None`` or a
          non-positive value yields an empty list without querying the database.

    Returns:
    --------

        - List[str]: A list of interaction IDs, sorted by created_at descending.
          Empty when nothing matches or the query returns no result.
    """
    # LIMIT needs a non-negative integer parameter; callers may forward an unset
    # (None) value, so short-circuit instead of sending an invalid query.
    if limit is None or limit <= 0:
        return []

    query = """
    MATCH (n)
    WHERE n.type = 'CogneeUserInteraction'
    RETURN n.id as id
    ORDER BY n.created_at DESC
    LIMIT $limit
    """
    rows = await self.query(query, {"limit": limit})

    # Same guard as the other query helpers in this adapter: a falsy result
    # (None or an empty list) means there is nothing to return.
    if not rows:
        return []

    # Records without an "id" column are skipped rather than raising KeyError.
    return [row["id"] for row in rows if "id" in row]
|
||||||
|
|
|
||||||
11
cognee/modules/retrieval/base_feedback.py
Normal file
11
cognee/modules/retrieval/base_feedback.py
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFeedback(ABC):
    """Abstract contract that every user-feedback handler implements."""

    @abstractmethod
    async def add_feedback(self, feedback_text: str) -> Any:
        """
        Record a piece of user feedback.

        Parameters:
        -----------

            - feedback_text (str): The feedback text supplied by the user.

        Returns:
        --------

            - Any: Defined by the concrete implementation; this abstract stub
              itself returns None.
        """
|
||||||
78
cognee/modules/retrieval/user_qa_feedback.py
Normal file
78
cognee/modules/retrieval/user_qa_feedback.py
Normal file
|
|
@ -0,0 +1,78 @@
|
||||||
|
from typing import Any, Optional, List
|
||||||
|
|
||||||
|
from uuid import NAMESPACE_OID, uuid5, UUID
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.infrastructure.llm import LLMGateway
|
||||||
|
from cognee.modules.engine.models import NodeSet
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.modules.retrieval.base_feedback import BaseFeedback
|
||||||
|
from cognee.modules.retrieval.utils.models import CogneeUserFeedback
|
||||||
|
from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
|
||||||
|
# NOTE(review): the logger name "CompletionRetriever" looks copy-pasted from a
# retriever module — confirm whether this module should log under its own name.
logger = get_logger("CompletionRetriever")
|
||||||
|
|
||||||
|
|
||||||
|
class UserQAFeedback(BaseFeedback):
    """
    Stores user feedback in the graph and links it to recent user interactions.

    Public methods:

    - add_feedback(feedback_text: str) -> List[str]
    """

    def __init__(self, last_k: Optional[int] = 1) -> None:
        """
        Initialize the feedback handler.

        Parameters:
        -----------

            - last_k (Optional[int]): How many of the most recent interactions the
              feedback is attached to. ``None`` (which callers may forward when the
              value was not specified) falls back to 1.
        """
        # A None limit cannot be used as the LIMIT of the interaction query, so
        # normalise it to the declared default here.
        self.last_k = 1 if last_k is None else last_k

    async def add_feedback(self, feedback_text: str) -> List[str]:
        """
        Classify the feedback, store it as a node and connect it to recent interactions.

        The sentiment is obtained from the LLM as a UserFeedbackEvaluation, the feedback is
        saved as a CogneeUserFeedback data point in the "UserQAFeedbacks" node set, and a
        "gives_feedback_to" edge is added from the feedback node to each of the last
        ``last_k`` interaction nodes reported by the graph engine.

        Parameters:
        -----------

            - feedback_text (str): The feedback text supplied by the user.

        Returns:
        --------

            - List[str]: A single-element list containing the original feedback text.
        """
        feedback_sentiment = await LLMGateway.acreate_structured_output(
            text_input=feedback_text,
            system_prompt="You are a sentiment analysis assistant. For each piece of user feedback you receive, return exactly one of: Positive, Negative, or Neutral classification",
            response_model=UserFeedbackEvaluation,
        )

        graph_engine = await get_graph_engine()
        last_interaction_ids = await graph_engine.get_last_user_interaction_ids(limit=self.last_k)

        nodeset_name = "UserQAFeedbacks"
        feedbacks_node_set = NodeSet(id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name)
        # The node id is derived from the text alone, so identical feedback texts
        # map to the same id.
        feedback_id = uuid5(NAMESPACE_OID, name=feedback_text)

        cognee_user_feedback = CogneeUserFeedback(
            id=feedback_id,
            feedback=feedback_text,
            sentiment=feedback_sentiment.evaluation.value,
            belongs_to_set=feedbacks_node_set,
        )

        await add_data_points(data_points=[cognee_user_feedback], update_edge_collection=False)

        relationship_name = "gives_feedback_to"
        relationships = []

        for interaction_id in last_interaction_ids or []:
            # Accept ids that are already UUID objects as well as their string form.
            if isinstance(interaction_id, UUID):
                interaction_uuid = interaction_id
            else:
                interaction_uuid = UUID(str(interaction_id))

            relationships.append(
                (
                    feedback_id,
                    interaction_uuid,
                    relationship_name,
                    {
                        "relationship_name": relationship_name,
                        "source_node_id": feedback_id,
                        "target_node_id": interaction_uuid,
                        "ontology_valid": False,
                    },
                )
            )

        if relationships:
            # Reuse the engine obtained above instead of fetching it a second time.
            await graph_engine.add_edges(relationships)

        return [feedback_text]
|
||||||
|
|
@ -3,6 +3,8 @@ import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Callable, List, Optional, Type, Union
|
from typing import Callable, List, Optional, Type, Union
|
||||||
|
|
||||||
|
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
|
||||||
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
||||||
from cognee.context_global_variables import set_database_global_context_variables
|
from cognee.context_global_variables import set_database_global_context_variables
|
||||||
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
||||||
|
|
@ -38,7 +40,8 @@ async def search(
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = None,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: bool = False,
|
save_interaction: Optional[bool] = False,
|
||||||
|
last_k: Optional[int] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -58,7 +61,14 @@ async def search(
|
||||||
# Use search function filtered by permissions if access control is enabled
|
# Use search function filtered by permissions if access control is enabled
|
||||||
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||||
return await authorized_search(
|
return await authorized_search(
|
||||||
query_text, query_type, user, dataset_ids, system_prompt_path, top_k, save_interaction
|
query_text=query_text,
|
||||||
|
query_type=query_type,
|
||||||
|
user=user,
|
||||||
|
dataset_ids=dataset_ids,
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
|
save_interaction=save_interaction,
|
||||||
|
last_k=last_k
|
||||||
)
|
)
|
||||||
|
|
||||||
query = await log_query(query_text, query_type.value, user.id)
|
query = await log_query(query_text, query_type.value, user.id)
|
||||||
|
|
@ -72,6 +82,7 @@ async def search(
|
||||||
node_type=node_type,
|
node_type=node_type,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
|
last_k=last_k
|
||||||
)
|
)
|
||||||
|
|
||||||
await log_result(
|
await log_result(
|
||||||
|
|
@ -93,7 +104,8 @@ async def specific_search(
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = None,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: bool = False,
|
save_interaction: Optional[bool] = False,
|
||||||
|
last_k: Optional[int] = None,
|
||||||
) -> list:
|
) -> list:
|
||||||
search_tasks: dict[SearchType, Callable] = {
|
search_tasks: dict[SearchType, Callable] = {
|
||||||
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
|
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
|
||||||
|
|
@ -133,6 +145,7 @@ async def specific_search(
|
||||||
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
|
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
|
||||||
SearchType.CYPHER: CypherSearchRetriever().get_completion,
|
SearchType.CYPHER: CypherSearchRetriever().get_completion,
|
||||||
SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
|
SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
|
||||||
|
SearchType.FEEDBACK: UserQAFeedback(last_k=last_k).add_feedback,
|
||||||
}
|
}
|
||||||
|
|
||||||
# If the query type is FEELING_LUCKY, select the search type intelligently
|
# If the query type is FEELING_LUCKY, select the search type intelligently
|
||||||
|
|
@ -161,6 +174,7 @@ async def authorized_search(
|
||||||
system_prompt_path: str = "answer_simple_question.txt",
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
save_interaction: bool = False,
|
save_interaction: bool = False,
|
||||||
|
last_k: Optional[int] = None,
|
||||||
) -> list:
|
) -> list:
|
||||||
"""
|
"""
|
||||||
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
|
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
|
||||||
|
|
@ -174,7 +188,7 @@ async def authorized_search(
|
||||||
|
|
||||||
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
|
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
|
||||||
search_results = await specific_search_by_context(
|
search_results = await specific_search_by_context(
|
||||||
search_datasets, query_text, query_type, user, system_prompt_path, top_k, save_interaction
|
search_datasets, query_text, query_type, user, system_prompt_path, top_k, save_interaction, last_k=last_k
|
||||||
)
|
)
|
||||||
|
|
||||||
await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id)
|
await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id)
|
||||||
|
|
@ -190,13 +204,14 @@ async def specific_search_by_context(
|
||||||
system_prompt_path: str,
|
system_prompt_path: str,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
save_interaction: bool = False,
|
save_interaction: bool = False,
|
||||||
|
last_k: Optional[int] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
|
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
|
||||||
Not to be used outside of active access control mode.
|
Not to be used outside of active access control mode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k):
|
async def _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k, last_k):
|
||||||
# Set database configuration in async context for each dataset user has access for
|
# Set database configuration in async context for each dataset user has access for
|
||||||
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
||||||
search_results = await specific_search(
|
search_results = await specific_search(
|
||||||
|
|
@ -206,6 +221,7 @@ async def specific_search_by_context(
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
|
last_k=last_k,
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"search_result": search_results,
|
"search_result": search_results,
|
||||||
|
|
@ -217,7 +233,7 @@ async def specific_search_by_context(
|
||||||
tasks = []
|
tasks = []
|
||||||
for dataset in search_datasets:
|
for dataset in search_datasets:
|
||||||
tasks.append(
|
tasks.append(
|
||||||
_search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k)
|
_search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k, last_k)
|
||||||
)
|
)
|
||||||
|
|
||||||
return await asyncio.gather(*tasks)
|
return await asyncio.gather(*tasks)
|
||||||
|
|
|
||||||
|
|
@ -14,3 +14,4 @@ class SearchType(Enum):
|
||||||
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
|
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
|
||||||
GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION"
|
GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION"
|
||||||
FEELING_LUCKY = "FEELING_LUCKY"
|
FEELING_LUCKY = "FEELING_LUCKY"
|
||||||
|
FEEDBACK = "FEEDBACK"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue