Compare commits
4 commits
main
...
feature/co
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7e103a7488 | ||
|
|
65fbfe61c1 | ||
|
|
9d46453a7a | ||
|
|
e0221feb74 |
11 changed files with 277 additions and 4 deletions
|
|
@ -19,6 +19,7 @@ async def search(
|
|||
top_k: int = 10,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
last_k: Optional[int] = None,
|
||||
) -> list:
|
||||
"""
|
||||
Search and query the knowledge graph for insights, information, and connections.
|
||||
|
|
@ -101,6 +102,8 @@ async def search(
|
|||
|
||||
node_name: Filter results to specific named entities (for targeted search).
|
||||
|
||||
last_k: Defines the number of historical answers to give the feedback to.
|
||||
|
||||
Returns:
|
||||
list: Search results in format determined by query_type:
|
||||
|
||||
|
|
@ -179,6 +182,7 @@ async def search(
|
|||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
last_k=last_k,
|
||||
)
|
||||
|
||||
return filtered_search_results
|
||||
|
|
|
|||
|
|
@ -388,3 +388,14 @@ class GraphDBInterface(ABC):
|
|||
- node_id (str): Unique identifier of the node for which to retrieve connections.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
|
||||
"""
|
||||
Get the last n userintaraction node ids from the graph
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- limit (str): Last n userintaraction node ids to retrieve.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -1287,3 +1287,28 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
"""
|
||||
result = await self.query(query)
|
||||
return [record["n"] for record in result] if result else []
|
||||
|
||||
async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
|
||||
"""
|
||||
Retrieve the IDs of the most recent CogneeUserInteraction nodes.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
- limit (int): The maximum number of interaction IDs to return.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
- List[str]: A list of interaction IDs, sorted by created_at descending.
|
||||
"""
|
||||
|
||||
query = """
|
||||
MATCH (n)
|
||||
WHERE n.type = 'CogneeUserInteraction'
|
||||
RETURN n.id as id
|
||||
ORDER BY n.created_at DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
rows = await self.query(query, {"limit": limit})
|
||||
|
||||
id_list = [row["id"] for row in rows if "id" in row]
|
||||
return id_list
|
||||
|
|
|
|||
11
cognee/modules/retrieval/base_feedback.py
Normal file
11
cognee/modules/retrieval/base_feedback.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Callable
|
||||
|
||||
|
||||
class BaseFeedback(ABC):
|
||||
"""Base class for all user feedback operations."""
|
||||
|
||||
@abstractmethod
|
||||
async def add_feedback(self, feedback_text: str) -> Any:
|
||||
"""Retrieves context based on the query."""
|
||||
pass
|
||||
|
|
@ -114,4 +114,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
system_prompt_path=self.system_prompt_path,
|
||||
)
|
||||
|
||||
if context and triplets:
|
||||
await self.save_qa(question=query, answer=answer, context=context, triplets=triplets)
|
||||
|
||||
return [answer]
|
||||
|
|
|
|||
|
|
@ -122,4 +122,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
|
||||
)
|
||||
|
||||
if context and triplets:
|
||||
await self.save_qa(question=query, answer=answer, context=context, triplets=triplets)
|
||||
|
||||
return [answer]
|
||||
|
|
|
|||
|
|
@ -1,14 +1,20 @@
|
|||
from typing import Any, Optional, Type, List
|
||||
from typing import Any, Optional, Type, List, Tuple
|
||||
from collections import Counter
|
||||
import string
|
||||
|
||||
from uuid import NAMESPACE_OID, uuid5, UUID
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
||||
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
|
||||
logger = get_logger("GraphCompletionRetriever")
|
||||
|
||||
|
|
@ -82,6 +88,88 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
)
|
||||
return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
|
||||
|
||||
def extract_uuid_from_node(self, node: Any) -> Optional[UUID]:
|
||||
"""
|
||||
Try to pull a UUID string out of node.id or node.properties['id'],
|
||||
then return a UUID instance (or None if neither exists).
|
||||
"""
|
||||
id_str = None
|
||||
if not id_str:
|
||||
id_str = getattr(node, "id", None)
|
||||
|
||||
if hasattr(node, "attributes") and not id_str:
|
||||
id_str = node.attributes.get("id", None)
|
||||
|
||||
id = UUID(id_str) if isinstance(id_str, str) else None
|
||||
return id
|
||||
|
||||
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
|
||||
"""
|
||||
Saves a question and answer pair for later analysis or storage.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
- question (str): The question text.
|
||||
- answer (str): The answer text.
|
||||
- context (str): The context text.
|
||||
- triplets (List): A list of triples retrieved from the graph.
|
||||
"""
|
||||
nodeset_name = "Interactions"
|
||||
interactions_node_set = NodeSet(
|
||||
id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name
|
||||
)
|
||||
source_id = uuid5(NAMESPACE_OID, name=(question + answer + context))
|
||||
|
||||
cognee_user_interaction = CogneeUserInteraction(
|
||||
id=source_id,
|
||||
question=question,
|
||||
answer=answer,
|
||||
context=context,
|
||||
belongs_to_set=interactions_node_set,
|
||||
)
|
||||
|
||||
await add_data_points(data_points=[cognee_user_interaction])
|
||||
|
||||
relationships = []
|
||||
relationship_name = "used_graph_element_to_answer"
|
||||
for triplet in triplets:
|
||||
target_id_1 = self.extract_uuid_from_node(triplet.node1)
|
||||
target_id_2 = self.extract_uuid_from_node(triplet.node2)
|
||||
if target_id_1 and target_id_2:
|
||||
# Defined qa node to triplet node 1
|
||||
relationships.append(
|
||||
(
|
||||
source_id,
|
||||
target_id_1,
|
||||
relationship_name,
|
||||
{
|
||||
"relationship_name": relationship_name,
|
||||
"source_node_id": source_id,
|
||||
"target_node_id": target_id_1,
|
||||
"ontology_valid": False,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Defined qa node to triplet node 2
|
||||
relationships.append(
|
||||
(
|
||||
source_id,
|
||||
target_id_2,
|
||||
relationship_name,
|
||||
{
|
||||
"relationship_name": relationship_name,
|
||||
"source_node_id": source_id,
|
||||
"target_node_id": target_id_2,
|
||||
"ontology_valid": False,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if len(relationships) > 0:
|
||||
graph_engine = await get_graph_engine()
|
||||
await graph_engine.add_edges(relationships)
|
||||
|
||||
async def get_triplets(self, query: str) -> list:
|
||||
"""
|
||||
Retrieves relevant graph triplets based on a query string.
|
||||
|
|
@ -118,7 +206,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
|
||||
return found_triplets
|
||||
|
||||
async def get_context(self, query: str) -> str:
|
||||
async def get_context(self, query: str) -> Tuple[str, List]:
|
||||
"""
|
||||
Retrieves and resolves graph triplets into context based on a query.
|
||||
|
||||
|
|
@ -139,7 +227,9 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
logger.warning("Empty context was provided to the completion")
|
||||
return ""
|
||||
|
||||
return await self.resolve_edges_to_text(triplets)
|
||||
context = await self.resolve_edges_to_text(triplets)
|
||||
|
||||
return context, triplets
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
"""
|
||||
|
|
@ -157,8 +247,10 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
|
||||
- Any: A generated completion based on the query and context provided.
|
||||
"""
|
||||
triplets = None
|
||||
|
||||
if context is None:
|
||||
context = await self.get_context(query)
|
||||
context, triplets = await self.get_context(query)
|
||||
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
|
|
@ -166,6 +258,12 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
)
|
||||
|
||||
if context and triplets:
|
||||
await self.save_qa(
|
||||
question=query, answer=completion, context=context, triplets=triplets
|
||||
)
|
||||
|
||||
return [completion]
|
||||
|
||||
def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "):
|
||||
|
|
|
|||
79
cognee/modules/retrieval/user_qa_feedback.py
Normal file
79
cognee/modules/retrieval/user_qa_feedback.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from uuid import NAMESPACE_OID, uuid5, UUID
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.modules.engine.models import NodeSet
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.retrieval.base_feedback import BaseFeedback
|
||||
from cognee.modules.retrieval.utils.models import CogneeUserFeedback
|
||||
from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation
|
||||
from cognee.tasks.storage import add_data_points
|
||||
|
||||
logger = get_logger("CompletionRetriever")
|
||||
|
||||
|
||||
class UserQAFeedback(BaseFeedback):
|
||||
"""
|
||||
Interface for handling user feedback queries.
|
||||
|
||||
Public methods:
|
||||
- get_context(query: str) -> str
|
||||
- get_completion(query: str, context: Optional[Any] = None) -> Any
|
||||
"""
|
||||
|
||||
def __init__(self, last_k: Optional[int] = 5):
|
||||
"""Initialize retriever with optional custom prompt paths."""
|
||||
self.last_k = last_k
|
||||
|
||||
async def add_feedback(self, feedback_text: str) -> Any:
|
||||
llm_client = get_llm_client()
|
||||
feedback_sentiment = await llm_client.acreate_structured_output(
|
||||
feedback_text,
|
||||
"You are a sentiment analysis assistant. For each piece of user feedback you receive, return exactly one of: Positive, Negative, or Neutral classification",
|
||||
UserFeedbackEvaluation,
|
||||
)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
last_interaction_ids = await graph_engine.get_last_user_interaction_ids(limit=self.last_k)
|
||||
|
||||
nodeset_name = "UserQAFeedbacks"
|
||||
feedbacks_node_set = NodeSet(id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name)
|
||||
feedback_id = uuid5(NAMESPACE_OID, name=feedback_text)
|
||||
|
||||
cognee_user_feedback = CogneeUserFeedback(
|
||||
id=feedback_id,
|
||||
feedback=feedback_text,
|
||||
sentiment=feedback_sentiment.evaluation.value,
|
||||
belongs_to_set=feedbacks_node_set,
|
||||
)
|
||||
|
||||
await add_data_points(data_points=[cognee_user_feedback])
|
||||
|
||||
relationships = []
|
||||
relationship_name = "gives_feedback_to"
|
||||
|
||||
for interaction_id in last_interaction_ids:
|
||||
target_id_1 = feedback_id
|
||||
target_id_2 = UUID(interaction_id)
|
||||
|
||||
if target_id_1 and target_id_2:
|
||||
relationships.append(
|
||||
(
|
||||
target_id_1,
|
||||
target_id_2,
|
||||
relationship_name,
|
||||
{
|
||||
"relationship_name": relationship_name,
|
||||
"source_node_id": target_id_1,
|
||||
"target_node_id": target_id_2,
|
||||
"ontology_valid": False,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if len(relationships) > 0:
|
||||
graph_engine = await get_graph_engine()
|
||||
await graph_engine.add_edges(relationships)
|
||||
|
||||
return [feedback_text]
|
||||
32
cognee/modules/retrieval/utils/models.py
Normal file
32
cognee/modules/retrieval/utils/models.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
from typing import Optional
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
|
||||
class CogneeUserInteraction(DataPoint):
|
||||
"""User - Cognee interaction"""
|
||||
|
||||
question: str
|
||||
answer: str
|
||||
context: str
|
||||
belongs_to_set: Optional[NodeSet] = None
|
||||
|
||||
|
||||
class CogneeUserFeedback(DataPoint):
|
||||
"""User - Cognee Feedback"""
|
||||
|
||||
feedback: str
|
||||
sentiment: str
|
||||
belongs_to_set: Optional[NodeSet] = None
|
||||
|
||||
|
||||
class UserFeedbackSentiment(str, Enum):
|
||||
positive = "positive"
|
||||
negative = "negative"
|
||||
neutral = "neutral"
|
||||
|
||||
|
||||
class UserFeedbackEvaluation(BaseModel):
|
||||
evaluation: UserFeedbackSentiment
|
||||
|
|
@ -21,6 +21,7 @@ from cognee.modules.retrieval.graph_completion_context_extension_retriever impor
|
|||
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
||||
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
||||
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
|
||||
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
from cognee.modules.users.models import User
|
||||
|
|
@ -39,6 +40,7 @@ async def search(
|
|||
top_k: int = 10,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
last_k: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
|
||||
|
|
@ -49,6 +51,7 @@ async def search(
|
|||
user:
|
||||
system_prompt_path:
|
||||
top_k:
|
||||
last_k:
|
||||
|
||||
Returns:
|
||||
|
||||
|
|
@ -71,6 +74,7 @@ async def search(
|
|||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
last_k=last_k,
|
||||
)
|
||||
|
||||
await log_result(
|
||||
|
|
@ -92,6 +96,7 @@ async def specific_search(
|
|||
top_k: int = 10,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
last_k: Optional[int] = None,
|
||||
) -> list:
|
||||
search_tasks: dict[SearchType, Callable] = {
|
||||
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
|
||||
|
|
@ -127,6 +132,7 @@ async def specific_search(
|
|||
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
|
||||
SearchType.CYPHER: CypherSearchRetriever().get_completion,
|
||||
SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
|
||||
SearchType.FEEDBACK: UserQAFeedback(last_k=last_k).add_feedback,
|
||||
}
|
||||
|
||||
search_task = search_tasks.get(query_type)
|
||||
|
|
|
|||
|
|
@ -13,3 +13,4 @@ class SearchType(Enum):
|
|||
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
|
||||
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
|
||||
GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION"
|
||||
FEEDBACK = "FEEDBACK"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue