Compare commits

...
Sign in to create a new pull request.

4 commits

11 changed files with 277 additions and 4 deletions

View file

@ -19,6 +19,7 @@ async def search(
top_k: int = 10, top_k: int = 10,
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
last_k: Optional[int] = None,
) -> list: ) -> list:
""" """
Search and query the knowledge graph for insights, information, and connections. Search and query the knowledge graph for insights, information, and connections.
@ -101,6 +102,8 @@ async def search(
node_name: Filter results to specific named entities (for targeted search). node_name: Filter results to specific named entities (for targeted search).
last_k: Defines the number of historical answers to give the feedback to.
Returns: Returns:
list: Search results in format determined by query_type: list: Search results in format determined by query_type:
@ -179,6 +182,7 @@ async def search(
top_k=top_k, top_k=top_k,
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
last_k=last_k,
) )
return filtered_search_results return filtered_search_results

View file

@ -388,3 +388,14 @@ class GraphDBInterface(ABC):
- node_id (str): Unique identifier of the node for which to retrieve connections. - node_id (str): Unique identifier of the node for which to retrieve connections.
""" """
raise NotImplementedError raise NotImplementedError
async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
"""
Get the last n userintaraction node ids from the graph
Parameters:
-----------
- limit (str): Last n userintaraction node ids to retrieve.
"""
raise NotImplementedError

View file

@ -1287,3 +1287,28 @@ class Neo4jAdapter(GraphDBInterface):
""" """
result = await self.query(query) result = await self.query(query)
return [record["n"] for record in result] if result else [] return [record["n"] for record in result] if result else []
async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
"""
Retrieve the IDs of the most recent CogneeUserInteraction nodes.
Parameters:
-----------
- limit (int): The maximum number of interaction IDs to return.
Returns:
--------
- List[str]: A list of interaction IDs, sorted by created_at descending.
"""
query = """
MATCH (n)
WHERE n.type = 'CogneeUserInteraction'
RETURN n.id as id
ORDER BY n.created_at DESC
LIMIT $limit
"""
rows = await self.query(query, {"limit": limit})
id_list = [row["id"] for row in rows if "id" in row]
return id_list

View file

@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Callable
class BaseFeedback(ABC):
"""Base class for all user feedback operations."""
@abstractmethod
async def add_feedback(self, feedback_text: str) -> Any:
"""Retrieves context based on the query."""
pass

View file

@ -114,4 +114,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
) )
if context and triplets:
await self.save_qa(question=query, answer=answer, context=context, triplets=triplets)
return [answer] return [answer]

View file

@ -122,4 +122,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}" f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
) )
if context and triplets:
await self.save_qa(question=query, answer=answer, context=context, triplets=triplets)
return [answer] return [answer]

View file

@ -1,14 +1,20 @@
from typing import Any, Optional, Type, List from typing import Any, Optional, Type, List, Tuple
from collections import Counter from collections import Counter
import string import string
from uuid import NAMESPACE_OID, uuid5, UUID
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.tasks.storage import add_data_points
from cognee.modules.engine.models.node_set import NodeSet
logger = get_logger("GraphCompletionRetriever") logger = get_logger("GraphCompletionRetriever")
@ -82,6 +88,88 @@ class GraphCompletionRetriever(BaseRetriever):
) )
return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}" return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
def extract_uuid_from_node(self, node: Any) -> Optional[UUID]:
"""
Try to pull a UUID string out of node.id or node.properties['id'],
then return a UUID instance (or None if neither exists).
"""
id_str = None
if not id_str:
id_str = getattr(node, "id", None)
if hasattr(node, "attributes") and not id_str:
id_str = node.attributes.get("id", None)
id = UUID(id_str) if isinstance(id_str, str) else None
return id
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
"""
Saves a question and answer pair for later analysis or storage.
Parameters:
-----------
- question (str): The question text.
- answer (str): The answer text.
- context (str): The context text.
- triplets (List): A list of triples retrieved from the graph.
"""
nodeset_name = "Interactions"
interactions_node_set = NodeSet(
id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name
)
source_id = uuid5(NAMESPACE_OID, name=(question + answer + context))
cognee_user_interaction = CogneeUserInteraction(
id=source_id,
question=question,
answer=answer,
context=context,
belongs_to_set=interactions_node_set,
)
await add_data_points(data_points=[cognee_user_interaction])
relationships = []
relationship_name = "used_graph_element_to_answer"
for triplet in triplets:
target_id_1 = self.extract_uuid_from_node(triplet.node1)
target_id_2 = self.extract_uuid_from_node(triplet.node2)
if target_id_1 and target_id_2:
# Defined qa node to triplet node 1
relationships.append(
(
source_id,
target_id_1,
relationship_name,
{
"relationship_name": relationship_name,
"source_node_id": source_id,
"target_node_id": target_id_1,
"ontology_valid": False,
},
)
)
# Defined qa node to triplet node 2
relationships.append(
(
source_id,
target_id_2,
relationship_name,
{
"relationship_name": relationship_name,
"source_node_id": source_id,
"target_node_id": target_id_2,
"ontology_valid": False,
},
)
)
if len(relationships) > 0:
graph_engine = await get_graph_engine()
await graph_engine.add_edges(relationships)
async def get_triplets(self, query: str) -> list: async def get_triplets(self, query: str) -> list:
""" """
Retrieves relevant graph triplets based on a query string. Retrieves relevant graph triplets based on a query string.
@ -118,7 +206,7 @@ class GraphCompletionRetriever(BaseRetriever):
return found_triplets return found_triplets
async def get_context(self, query: str) -> str: async def get_context(self, query: str) -> Tuple[str, List]:
""" """
Retrieves and resolves graph triplets into context based on a query. Retrieves and resolves graph triplets into context based on a query.
@ -139,7 +227,9 @@ class GraphCompletionRetriever(BaseRetriever):
logger.warning("Empty context was provided to the completion") logger.warning("Empty context was provided to the completion")
return "" return ""
return await self.resolve_edges_to_text(triplets) context = await self.resolve_edges_to_text(triplets)
return context, triplets
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
""" """
@ -157,8 +247,10 @@ class GraphCompletionRetriever(BaseRetriever):
- Any: A generated completion based on the query and context provided. - Any: A generated completion based on the query and context provided.
""" """
triplets = None
if context is None: if context is None:
context = await self.get_context(query) context, triplets = await self.get_context(query)
completion = await generate_completion( completion = await generate_completion(
query=query, query=query,
@ -166,6 +258,12 @@ class GraphCompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
) )
if context and triplets:
await self.save_qa(
question=query, answer=completion, context=context, triplets=triplets
)
return [completion] return [completion]
def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "): def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "):

View file

@ -0,0 +1,79 @@
from typing import Any, Optional
from uuid import NAMESPACE_OID, uuid5, UUID
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.engine.models import NodeSet
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.base_feedback import BaseFeedback
from cognee.modules.retrieval.utils.models import CogneeUserFeedback
from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation
from cognee.tasks.storage import add_data_points
logger = get_logger("CompletionRetriever")
class UserQAFeedback(BaseFeedback):
"""
Interface for handling user feedback queries.
Public methods:
- get_context(query: str) -> str
- get_completion(query: str, context: Optional[Any] = None) -> Any
"""
def __init__(self, last_k: Optional[int] = 5):
"""Initialize retriever with optional custom prompt paths."""
self.last_k = last_k
async def add_feedback(self, feedback_text: str) -> Any:
llm_client = get_llm_client()
feedback_sentiment = await llm_client.acreate_structured_output(
feedback_text,
"You are a sentiment analysis assistant. For each piece of user feedback you receive, return exactly one of: Positive, Negative, or Neutral classification",
UserFeedbackEvaluation,
)
graph_engine = await get_graph_engine()
last_interaction_ids = await graph_engine.get_last_user_interaction_ids(limit=self.last_k)
nodeset_name = "UserQAFeedbacks"
feedbacks_node_set = NodeSet(id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name)
feedback_id = uuid5(NAMESPACE_OID, name=feedback_text)
cognee_user_feedback = CogneeUserFeedback(
id=feedback_id,
feedback=feedback_text,
sentiment=feedback_sentiment.evaluation.value,
belongs_to_set=feedbacks_node_set,
)
await add_data_points(data_points=[cognee_user_feedback])
relationships = []
relationship_name = "gives_feedback_to"
for interaction_id in last_interaction_ids:
target_id_1 = feedback_id
target_id_2 = UUID(interaction_id)
if target_id_1 and target_id_2:
relationships.append(
(
target_id_1,
target_id_2,
relationship_name,
{
"relationship_name": relationship_name,
"source_node_id": target_id_1,
"target_node_id": target_id_2,
"ontology_valid": False,
},
)
)
if len(relationships) > 0:
graph_engine = await get_graph_engine()
await graph_engine.add_edges(relationships)
return [feedback_text]

View file

@ -0,0 +1,32 @@
from typing import Optional
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.modules.engine.models.node_set import NodeSet
from enum import Enum
from pydantic import BaseModel, ValidationError
class CogneeUserInteraction(DataPoint):
"""User - Cognee interaction"""
question: str
answer: str
context: str
belongs_to_set: Optional[NodeSet] = None
class CogneeUserFeedback(DataPoint):
"""User - Cognee Feedback"""
feedback: str
sentiment: str
belongs_to_set: Optional[NodeSet] = None
class UserFeedbackSentiment(str, Enum):
positive = "positive"
negative = "negative"
neutral = "neutral"
class UserFeedbackEvaluation(BaseModel):
evaluation: UserFeedbackSentiment

View file

@ -21,6 +21,7 @@ from cognee.modules.retrieval.graph_completion_context_extension_retriever impor
from cognee.modules.retrieval.code_retriever import CodeRetriever from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
from cognee.modules.search.types import SearchType from cognee.modules.search.types import SearchType
from cognee.modules.storage.utils import JSONEncoder from cognee.modules.storage.utils import JSONEncoder
from cognee.modules.users.models import User from cognee.modules.users.models import User
@ -39,6 +40,7 @@ async def search(
top_k: int = 10, top_k: int = 10,
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
last_k: Optional[int] = None,
): ):
""" """
@ -49,6 +51,7 @@ async def search(
user: user:
system_prompt_path: system_prompt_path:
top_k: top_k:
last_k:
Returns: Returns:
@ -71,6 +74,7 @@ async def search(
top_k=top_k, top_k=top_k,
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
last_k=last_k,
) )
await log_result( await log_result(
@ -92,6 +96,7 @@ async def specific_search(
top_k: int = 10, top_k: int = 10,
node_type: Optional[Type] = None, node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
last_k: Optional[int] = None,
) -> list: ) -> list:
search_tasks: dict[SearchType, Callable] = { search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion, SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
@ -127,6 +132,7 @@ async def specific_search(
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion, SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
SearchType.CYPHER: CypherSearchRetriever().get_completion, SearchType.CYPHER: CypherSearchRetriever().get_completion,
SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion, SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
SearchType.FEEDBACK: UserQAFeedback(last_k=last_k).add_feedback,
} }
search_task = search_tasks.get(query_type) search_task = search_tasks.get(query_type)

View file

@ -13,3 +13,4 @@ class SearchType(Enum):
NATURAL_LANGUAGE = "NATURAL_LANGUAGE" NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT" GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION" GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION"
FEEDBACK = "FEEDBACK"