feature: Introduces Cognee-user interactions feature and feedback search type (#1264)

<!-- .github/pull_request_template.md -->

## Description
Introduces Cognee-user interactions nodeset feature and feedback search
type

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
Vasilije 2025-08-19 18:21:27 +02:00 committed by GitHub
commit 6d9a100b7e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 599 additions and 38 deletions

View file

@ -19,6 +19,8 @@ async def search(
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
last_k: Optional[int] = None,
) -> list:
"""
Search and query the knowledge graph for insights, information, and connections.
@ -107,6 +109,8 @@ async def search(
node_name: Filter results to specific named entities (for targeted search).
save_interaction: Save interaction (query, context, answer connected to triplet endpoints) results into the graph or not
Returns:
list: Search results in format determined by query_type:
@ -182,6 +186,8 @@ async def search(
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
)
return filtered_search_results

View file

@ -1632,3 +1632,64 @@ class KuzuAdapter(GraphDBInterface):
"""
result = await self.query(query)
return [record[0] for record in result] if result else []
async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
    """
    Fetch IDs of the newest CogneeUserInteraction nodes in the graph.

    Parameters:
    -----------

        - limit (int): Maximum number of interaction IDs to return.

    Returns:
    --------

        - List[str]: Interaction IDs ordered newest-first (by created_at).
    """
    cypher = """
    MATCH (n)
    WHERE n.type = 'CogneeUserInteraction'
    RETURN n.id as id
    ORDER BY n.created_at DESC
    LIMIT $limit
    """
    # Kuzu rows come back as positional records; column 0 is the id.
    records = await self.query(cypher, {"limit": limit})
    return [record[0] for record in records]
async def apply_feedback_weight(
    self,
    node_ids: List[str],
    weight: float,
) -> None:
    """
    Increment `feedback_weight` inside r.properties JSON for edges where
    relationship_name = 'used_graph_element_to_answer'.

    Kuzu stores edge properties as a JSON string, so the update happens in
    three phases: read the matching edges, patch the JSON client-side, then
    write the serialized properties back per source node.
    """
    # Phase 1: read current properties of every matching edge.
    fetch_query = """
    MATCH (n:Node)-[r:EDGE]->()
    WHERE n.id IN $node_ids AND r.relationship_name = 'used_graph_element_to_answer'
    RETURN r.properties, n.id
    """
    rows = await self.query(fetch_query, {"node_ids": node_ids})

    # Phase 2: bump feedback_weight inside each JSON payload.
    pending = []
    for raw_props, origin_id in rows:
        try:
            parsed = json.loads(raw_props) if raw_props else {}
        except json.JSONDecodeError:
            # Unparseable payloads are treated as empty rather than failing.
            parsed = {}
        parsed["feedback_weight"] = parsed.get("feedback_weight", 0) + weight
        pending.append((origin_id, json.dumps(parsed)))

    # Phase 3: persist the patched JSON.
    # NOTE(review): the write matches by source node id, so all matching
    # edges of one node receive the same serialized properties — confirm
    # each interaction node carries at most homogeneous edge payloads.
    write_query = """
    MATCH (n:Node)-[r:EDGE]->()
    WHERE n.id = $node_id AND r.relationship_name = 'used_graph_element_to_answer'
    SET r.properties = $props
    """
    for origin_id, serialized in pending:
        await self.query(write_query, {"node_id": origin_id, "props": serialized})

View file

@ -1322,3 +1322,52 @@ class Neo4jAdapter(GraphDBInterface):
"""
result = await self.query(query)
return [record["n"] for record in result] if result else []
async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
    """
    Fetch IDs of the newest CogneeUserInteraction nodes in the graph.

    Parameters:
    -----------

        - limit (int): Maximum number of interaction IDs to return.

    Returns:
    --------

        - List[str]: Interaction IDs ordered newest-first (by created_at).
    """
    cypher = """
    MATCH (n)
    WHERE n.type = 'CogneeUserInteraction'
    RETURN n.id as id
    ORDER BY n.created_at DESC
    LIMIT $limit
    """
    records = await self.query(cypher, {"limit": limit})
    # Neo4j returns dict-like records; skip any row missing the "id" column.
    return [record["id"] for record in records if "id" in record]
async def apply_feedback_weight(
    self,
    node_ids: List[str],
    weight: float,
) -> None:
    """
    Increment `feedback_weight` on relationships `:used_graph_element_to_answer`
    outgoing from nodes whose `id` is in `node_ids`.

    Args:
        node_ids: List of node IDs to match.
        weight: Amount to add to `r.feedback_weight` (can be negative).

    Side effects:
        Updates relationship property `feedback_weight`, defaulting missing values to 0.
    """
    # Single server-side update: coalesce() supplies 0 for edges that have
    # never been weighted before.
    cypher = """
    MATCH (n)-[r]->()
    WHERE n.id IN $node_ids AND r.relationship_name = 'used_graph_element_to_answer'
    SET r.feedback_weight = coalesce(r.feedback_weight, 0) + $weight
    """
    parameters = {"weight": float(weight), "node_ids": list(node_ids)}
    await self.query(cypher, params=parameters)

View file

@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
from typing import Any
class BaseFeedback(ABC):
    """Base class for all user feedback operations.

    Concrete subclasses implement how a piece of user feedback is recorded
    (e.g. persisted to the knowledge graph).
    """

    @abstractmethod
    async def add_feedback(self, feedback_text: str) -> Any:
        """Add user feedback to the system.

        Parameters:
        -----------

            - feedback_text (str): Free-form feedback text supplied by the user.

        Returns:
        --------

            - Any: Implementation-defined result of recording the feedback.
        """
        pass

View file

@ -29,6 +29,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
):
super().__init__(
user_prompt_path=user_prompt_path,
@ -36,6 +37,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
)
async def get_completion(
@ -105,11 +107,16 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
round_idx += 1
answer = await generate_completion(
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
return [answer]
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=completion, context=context, triplets=triplets
)
return [completion]

View file

@ -35,6 +35,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
):
super().__init__(
user_prompt_path=user_prompt_path,
@ -42,6 +43,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
)
self.validation_system_prompt_path = validation_system_prompt_path
self.validation_user_prompt_path = validation_user_prompt_path
@ -75,7 +77,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
"""
followup_question = ""
triplets = []
answer = [""]
completion = [""]
for round_idx in range(max_iter + 1):
if round_idx == 0:
@ -85,15 +87,15 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
triplets += await self.get_triplets(followup_question)
context = await self.resolve_edges_to_text(list(set(triplets)))
answer = await generate_completion(
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
logger.info(f"Chain-of-thought: round {round_idx} - answer: {answer}")
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
if round_idx < max_iter:
valid_args = {"query": query, "answer": answer, "context": context}
valid_args = {"query": query, "answer": completion, "context": context}
valid_user_prompt = LLMGateway.render_prompt(
filename=self.validation_user_prompt_path, context=valid_args
)
@ -106,7 +108,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
system_prompt=valid_system_prompt,
response_model=str,
)
followup_args = {"query": query, "answer": answer, "reasoning": reasoning}
followup_args = {"query": query, "answer": completion, "reasoning": reasoning}
followup_prompt = LLMGateway.render_prompt(
filename=self.followup_user_prompt_path, context=followup_args
)
@ -121,4 +123,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
)
return [answer]
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=completion, context=context, triplets=triplets
)
return [completion]

View file

@ -1,14 +1,20 @@
from typing import Any, Optional, Type, List
from typing import Any, Optional, Type, List, Coroutine
from collections import Counter
from uuid import NAMESPACE_OID, uuid5
import string
from cognee.infrastructure.engine import DataPoint
from cognee.tasks.storage import add_data_points
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
from cognee.modules.engine.models.node_set import NodeSet
from cognee.infrastructure.databases.graph import get_graph_engine
logger = get_logger("GraphCompletionRetriever")
@ -33,8 +39,10 @@ class GraphCompletionRetriever(BaseRetriever):
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
):
"""Initialize retriever with prompt paths and search parameters."""
self.save_interaction = save_interaction
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 5
@ -118,7 +126,7 @@ class GraphCompletionRetriever(BaseRetriever):
return found_triplets
async def get_context(self, query: str) -> str:
async def get_context(self, query: str) -> str | tuple[str, list]:
"""
Retrieves and resolves graph triplets into context based on a query.
@ -137,9 +145,11 @@ class GraphCompletionRetriever(BaseRetriever):
if len(triplets) == 0:
logger.warning("Empty context was provided to the completion")
return ""
return "", triplets
return await self.resolve_edges_to_text(triplets)
context = await self.resolve_edges_to_text(triplets)
return context, triplets
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""
@ -157,8 +167,10 @@ class GraphCompletionRetriever(BaseRetriever):
- Any: A generated completion based on the query and context provided.
"""
triplets = None
if context is None:
context = await self.get_context(query)
context, triplets = await self.get_context(query)
completion = await generate_completion(
query=query,
@ -166,6 +178,12 @@ class GraphCompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=completion, context=context, triplets=triplets
)
return [completion]
def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "):
@ -187,3 +205,69 @@ class GraphCompletionRetriever(BaseRetriever):
first_n_words = text.split()[:first_n_words]
top_n_words = self._top_n_words(text, top_n=top_n_words)
return f"{' '.join(first_n_words)}... [{top_n_words}]"
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
    """
    Persist a question/answer interaction and link it to the triplets that produced it.

    Parameters:
    -----------

        - question (str): The question text.
        - answer (str): The answer text.
        - context (str): The context text.
        - triplets (List): A list of triples retrieved from the graph.
    """
    nodeset_name = "Interactions"
    interactions_node_set = NodeSet(
        id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name
    )

    # Deterministic id: identical question+answer+context maps to the same node.
    source_id = uuid5(NAMESPACE_OID, name=(question + answer + context))
    cognee_user_interaction = CogneeUserInteraction(
        id=source_id,
        question=question,
        answer=answer,
        context=context,
        belongs_to_set=interactions_node_set,
    )
    await add_data_points(data_points=[cognee_user_interaction], update_edge_collection=False)

    relationship_name = "used_graph_element_to_answer"
    relationships = []
    for triplet in triplets:
        endpoints = (
            extract_uuid_from_node(triplet.node1),
            extract_uuid_from_node(triplet.node2),
        )
        # Only link triplets where both endpoints yielded a usable UUID.
        if all(endpoints):
            for target_id in endpoints:
                relationships.append(
                    (
                        source_id,
                        target_id,
                        relationship_name,
                        {
                            "relationship_name": relationship_name,
                            "source_node_id": source_id,
                            "target_node_id": target_id,
                            "ontology_valid": False,
                            "feedback_weight": 0,
                        },
                    )
                )

    if relationships:
        graph_engine = await get_graph_engine()
        await graph_engine.add_edges(relationships)

View file

@ -24,6 +24,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
):
"""Initialize retriever with default prompt paths and search parameters."""
super().__init__(
@ -32,6 +33,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
)
self.summarize_prompt_path = summarize_prompt_path

View file

@ -0,0 +1,83 @@
from typing import Any, Optional, List
from uuid import NAMESPACE_OID, uuid5, UUID
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.llm import LLMGateway
from cognee.modules.engine.models import NodeSet
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.base_feedback import BaseFeedback
from cognee.modules.retrieval.utils.models import CogneeUserFeedback
from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation
from cognee.tasks.storage import add_data_points
logger = get_logger("CompletionRetriever")
class UserQAFeedback(BaseFeedback):
    """
    Handles user feedback on recent Cognee interactions.

    Classifies feedback sentiment via the LLM, stores the feedback node in
    the graph, links it to the last `last_k` CogneeUserInteraction nodes,
    and re-weights the edges those interactions used to answer.

    Public methods:
    - add_feedback(feedback_text: str) -> List[str]
    """

    def __init__(self, last_k: Optional[int] = 1) -> None:
        """Initialize the feedback handler.

        Parameters:
        -----------

            - last_k (Optional[int]): How many of the most recent interactions
              the feedback applies to. Falls back to 1 when None is given
              (the search API forwards last_k=None when the caller omits it).
        """
        # Guard against last_k=None: it would otherwise be passed straight to
        # get_last_user_interaction_ids(limit=None) and break the graph query.
        self.last_k = last_k if last_k is not None else 1

    async def add_feedback(self, feedback_text: str) -> List[str]:
        """Record a piece of user feedback and propagate its weight.

        Parameters:
        -----------

            - feedback_text (str): Free-form feedback text supplied by the user.

        Returns:
        --------

            - List[str]: The original feedback text wrapped in a list.
        """
        # Classify sentiment/score first; edge re-weighting below depends on it.
        feedback_sentiment = await LLMGateway.acreate_structured_output(
            text_input=feedback_text,
            system_prompt="You are a sentiment analysis assistant. For each piece of user feedback you receive, return exactly one of: Positive, Negative, or Neutral classification and a corresponding score from -5 (worst negative) to 5 (best positive)",
            response_model=UserFeedbackEvaluation,
        )

        graph_engine = await get_graph_engine()
        last_interaction_ids = await graph_engine.get_last_user_interaction_ids(limit=self.last_k)

        nodeset_name = "UserQAFeedbacks"
        feedbacks_node_set = NodeSet(
            id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name
        )

        # Deterministic id: identical feedback text maps to the same node.
        feedback_id = uuid5(NAMESPACE_OID, name=feedback_text)
        cognee_user_feedback = CogneeUserFeedback(
            id=feedback_id,
            feedback=feedback_text,
            sentiment=feedback_sentiment.evaluation.value,
            score=feedback_sentiment.score,
            belongs_to_set=feedbacks_node_set,
        )
        await add_data_points(data_points=[cognee_user_feedback], update_edge_collection=False)

        relationship_name = "gives_feedback_to"
        relationships = []
        to_node_ids = []
        for interaction_id in last_interaction_ids:
            interaction_uuid = UUID(interaction_id)
            relationships.append(
                (
                    feedback_id,
                    interaction_uuid,
                    relationship_name,
                    {
                        "relationship_name": relationship_name,
                        "source_node_id": feedback_id,
                        "target_node_id": interaction_uuid,
                        "ontology_valid": False,
                    },
                )
            )
            to_node_ids.append(str(interaction_uuid))

        if relationships:
            # Reuse the engine fetched above instead of resolving it again.
            await graph_engine.add_edges(relationships)
            await graph_engine.apply_feedback_weight(
                node_ids=to_node_ids, weight=feedback_sentiment.score
            )
        return [feedback_text]

View file

@ -0,0 +1,18 @@
from typing import Any, Optional
from uuid import UUID
def extract_uuid_from_node(node: Any) -> Optional[UUID]:
    """
    Extract a UUID from ``node.id`` or ``node.attributes['id']``.

    Accepts either a ready-made ``UUID`` instance or a UUID string; returns
    None when no usable id is found or the string is not a valid UUID.

    Parameters:
    -----------

        - node (Any): Graph node object exposing an ``id`` attribute and/or
          an ``attributes`` mapping.

    Returns:
    --------

        - Optional[UUID]: The node's UUID, or None if it cannot be determined.
    """
    raw = getattr(node, "id", None)
    # Fall back to the attributes mapping when `id` is missing or falsy.
    if not raw and hasattr(node, "attributes"):
        raw = node.attributes.get("id", None)
    if isinstance(raw, UUID):
        return raw
    if isinstance(raw, str):
        try:
            return UUID(raw)
        except ValueError:
            # Malformed id strings are treated as "no UUID" rather than raising.
            return None
    return None

View file

@ -0,0 +1,40 @@
from typing import Optional
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.modules.engine.models.node_set import NodeSet
from enum import Enum
from pydantic import BaseModel, Field, confloat
class CogneeUserInteraction(DataPoint):
    """User - Cognee interaction.

    Graph node recording one search interaction: the user's question, the
    generated answer, and the context text the answer was produced from.
    """

    # The question asked by the user.
    question: str
    # The generated answer text.
    answer: str
    # The context text that accompanied the answer.
    context: str
    # Node set grouping interaction nodes (e.g. the "Interactions" set).
    belongs_to_set: Optional[NodeSet] = None
class CogneeUserFeedback(DataPoint):
    """User - Cognee Feedback.

    Graph node recording a piece of user feedback together with its
    LLM-evaluated sentiment label and score.
    """

    # Free-form feedback text from the user.
    feedback: str
    # Sentiment label (value of UserFeedbackSentiment).
    sentiment: str
    # Sentiment score in [-5, 5] (bounds set by UserFeedbackEvaluation).
    score: float
    # Node set grouping feedback nodes (e.g. the "UserQAFeedbacks" set).
    belongs_to_set: Optional[NodeSet] = None
class UserFeedbackSentiment(str, Enum):
    """User - User feedback sentiment.

    Closed set of sentiment labels assigned to user feedback.
    """

    # NOTE(review): the sentiment prompt asks the LLM for "Positive",
    # "Negative", or "Neutral" (capitalized) while these values are
    # lowercase; confirm the structured-output layer normalizes case.
    positive = "positive"
    negative = "negative"
    neutral = "neutral"
class UserFeedbackEvaluation(BaseModel):
    """User - User feedback evaluation.

    Structured-output schema the LLM fills when classifying user feedback:
    a bounded numeric score plus a categorical sentiment label.
    """

    # Numeric sentiment strength; bounds enforced by pydantic's confloat.
    score: confloat(ge=-5, le=5) = Field(
        ..., description="Sentiment score from -5 (negative) to +5 (positive)"
    )
    # Categorical sentiment label.
    evaluation: UserFeedbackSentiment

View file

@ -3,6 +3,8 @@ import json
import asyncio
from uuid import UUID
from typing import Callable, List, Optional, Type, Union
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
from cognee.context_global_variables import set_database_global_context_variables
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
@ -38,6 +40,8 @@ async def search(
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None,
):
"""
@ -57,7 +61,14 @@ async def search(
# Use search function filtered by permissions if access control is enabled
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
return await authorized_search(
query_text, query_type, user, dataset_ids, system_prompt_path, top_k
query_text=query_text,
query_type=query_type,
user=user,
dataset_ids=dataset_ids,
system_prompt_path=system_prompt_path,
top_k=top_k,
save_interaction=save_interaction,
last_k=last_k,
)
query = await log_query(query_text, query_type.value, user.id)
@ -70,6 +81,8 @@ async def search(
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
)
await log_result(
@ -91,6 +104,8 @@ async def specific_search(
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None,
) -> list:
search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
@ -104,28 +119,33 @@ async def specific_search(
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
).get_completion,
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
).get_completion,
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
).get_completion,
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
).get_completion,
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
SearchType.CYPHER: CypherSearchRetriever().get_completion,
SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
SearchType.FEEDBACK: UserQAFeedback(last_k=last_k).add_feedback,
}
# If the query type is FEELING_LUCKY, select the search type intelligently
@ -153,6 +173,8 @@ async def authorized_search(
dataset_ids: Optional[list[UUID]] = None,
system_prompt_path: str = "answer_simple_question.txt",
top_k: int = 10,
save_interaction: bool = False,
last_k: Optional[int] = None,
) -> list:
"""
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
@ -166,7 +188,14 @@ async def authorized_search(
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
search_results = await specific_search_by_context(
search_datasets, query_text, query_type, user, system_prompt_path, top_k
search_datasets,
query_text,
query_type,
user,
system_prompt_path,
top_k,
save_interaction,
last_k=last_k,
)
await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id)
@ -181,17 +210,27 @@ async def specific_search_by_context(
user: User,
system_prompt_path: str,
top_k: int,
save_interaction: bool = False,
last_k: Optional[int] = None,
):
"""
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
Not to be used outside of active access control mode.
"""
async def _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k):
async def _search_by_context(
dataset, user, query_type, query_text, system_prompt_path, top_k, last_k
):
# Set database configuration in async context for each dataset user has access for
await set_database_global_context_variables(dataset.id, dataset.owner_id)
search_results = await specific_search(
query_type, query_text, user, system_prompt_path=system_prompt_path, top_k=top_k
query_type,
query_text,
user,
system_prompt_path=system_prompt_path,
top_k=top_k,
save_interaction=save_interaction,
last_k=last_k,
)
return {
"search_result": search_results,
@ -203,7 +242,9 @@ async def specific_search_by_context(
tasks = []
for dataset in search_datasets:
tasks.append(
_search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k)
_search_by_context(
dataset, user, query_type, query_text, system_prompt_path, top_k, last_k
)
)
return await asyncio.gather(*tasks)

View file

@ -14,3 +14,4 @@ class SearchType(Enum):
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION"
FEELING_LUCKY = "FEELING_LUCKY"
FEEDBACK = "FEEDBACK"

View file

@ -10,7 +10,37 @@ from cognee.tasks.storage.exceptions import (
)
async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
async def add_data_points(
data_points: List[DataPoint], update_edge_collection: bool = True
) -> List[DataPoint]:
"""
Add a batch of data points to the graph database by extracting nodes and edges,
deduplicating them, and indexing them for retrieval.
This function parallelizes the graph extraction for each data point,
merges the resulting nodes and edges, and ensures uniqueness before
committing them to the underlying graph engine. It also updates the
associated retrieval indices for nodes and (optionally) edges.
Args:
data_points (List[DataPoint]):
A list of data points to process and insert into the graph.
update_edge_collection (bool, optional):
Whether to update the edge index after adding edges.
Defaults to True.
Returns:
List[DataPoint]:
The original list of data points after processing and insertion.
Side Effects:
- Calls `get_graph_from_model` concurrently for each data point.
- Deduplicates nodes and edges across all results.
- Updates the node index via `index_data_points`.
- Inserts nodes and edges into the graph engine.
- Optionally updates the edge index via `index_graph_edges`.
"""
if not isinstance(data_points, list):
raise InvalidDataPointsInAddDataPointsError("data_points must be a list.")
if not all(isinstance(dp, DataPoint) for dp in data_points):
@ -48,7 +78,7 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
await graph_engine.add_nodes(nodes)
await graph_engine.add_edges(edges)
# This step has to happen after adding nodes and edges because we query the graph.
await index_graph_edges()
if update_edge_collection:
await index_graph_edges()
return data_points

View file

@ -94,12 +94,12 @@ async def main():
await cognee.cognify([dataset_name])
context_nonempty = await GraphCompletionRetriever(
context_nonempty, _ = await GraphCompletionRetriever(
node_type=NodeSet,
node_name=["first"],
).get_context("What is in the context?")
context_empty = await GraphCompletionRetriever(
context_empty, _ = await GraphCompletionRetriever(
node_type=NodeSet,
node_name=["nonexistent"],
).get_context("What is in the context?")

View file

@ -98,12 +98,12 @@ async def main():
await cognee.cognify([dataset_name])
context_nonempty = await GraphCompletionRetriever(
context_nonempty, _ = await GraphCompletionRetriever(
node_type=NodeSet,
node_name=["first"],
).get_context("What is in the context?")
context_empty = await GraphCompletionRetriever(
context_empty, _ = await GraphCompletionRetriever(
node_type=NodeSet,
node_name=["nonexistent"],
).get_context("What is in the context?")

View file

@ -4,6 +4,7 @@ import pathlib
from dns.e164 import query
import cognee
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
@ -18,6 +19,7 @@ from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.engine.models import NodeSet
from collections import Counter
logger = get_logger()
@ -44,16 +46,16 @@ async def main():
await cognee.cognify([dataset_name])
context_gk = await GraphCompletionRetriever().get_context(
context_gk, _ = await GraphCompletionRetriever().get_context(
query="Next to which country is Germany located?"
)
context_gk_cot = await GraphCompletionCotRetriever().get_context(
context_gk_cot, _ = await GraphCompletionCotRetriever().get_context(
query="Next to which country is Germany located?"
)
context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context(
context_gk_ext, _ = await GraphCompletionContextExtensionRetriever().get_context(
query="Next to which country is Germany located?"
)
context_gk_sum = await GraphSummaryCompletionRetriever().get_context(
context_gk_sum, _ = await GraphSummaryCompletionRetriever().get_context(
query="Next to which country is Germany located?"
)
@ -112,18 +114,33 @@ async def main():
completion_gk = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text="Next to which country is Germany located?",
save_interaction=True,
)
completion_cot = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION_COT,
query_text="Next to which country is Germany located?",
save_interaction=True,
)
completion_ext = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION,
query_text="Next to which country is Germany located?",
save_interaction=True,
)
await cognee.search(
query_type=SearchType.FEEDBACK, query_text="This was not the best answer", last_k=1
)
completion_sum = await cognee.search(
query_type=SearchType.GRAPH_SUMMARY_COMPLETION,
query_text="Next to which country is Germany located?",
save_interaction=True,
)
await cognee.search(
query_type=SearchType.FEEDBACK,
query_text="This answer was great",
last_k=1,
)
for name, completion in [
@ -141,6 +158,108 @@ async def main():
f"{name}: expected 'netherlands' in result, got: {text!r}"
)
graph_engine = await get_graph_engine()
graph = await graph_engine.get_graph_data()
type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0])
edge_type_counts = Counter(edge_type[2] for edge_type in graph[1])
# Assert there are exactly 4 CogneeUserInteraction nodes.
assert type_counts.get("CogneeUserInteraction", 0) == 4, (
f"Expected exactly four DCogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}"
)
# Assert there is exactly two CogneeUserFeedback nodes.
assert type_counts.get("CogneeUserFeedback", 0) == 2, (
f"Expected exactly two CogneeUserFeedback nodes, but found {type_counts.get('CogneeUserFeedback', 0)}"
)
# Assert there is exactly two NodeSet.
assert type_counts.get("NodeSet", 0) == 2, (
f"Expected exactly two NodeSet nodes, but found {type_counts.get('NodeSet', 0)}"
)
# Assert that there are at least 10 'used_graph_element_to_answer' edges.
assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10, (
f"Expected at least ten 'used_graph_element_to_answer' edges, but found {edge_type_counts.get('used_graph_element_to_answer', 0)}"
)
# Assert that there are exactly 2 'gives_feedback_to' edges.
assert edge_type_counts.get("gives_feedback_to", 0) == 2, (
f"Expected exactly two 'gives_feedback_to' edges, but found {edge_type_counts.get('gives_feedback_to', 0)}"
)
# Assert that there are at least 6 'belongs_to_set' edges.
assert edge_type_counts.get("belongs_to_set", 0) == 6, (
f"Expected at least six 'belongs_to_set' edges, but found {edge_type_counts.get('belongs_to_set', 0)}"
)
nodes = graph[0]
required_fields_user_interaction = {"question", "answer", "context"}
required_fields_feedback = {"feedback", "sentiment"}
for node_id, data in nodes:
if data.get("type") == "CogneeUserInteraction":
assert required_fields_user_interaction.issubset(data.keys()), (
f"Node {node_id} is missing fields: {required_fields_user_interaction - set(data.keys())}"
)
for field in required_fields_user_interaction:
value = data[field]
assert isinstance(value, str) and value.strip(), (
f"Node {node_id} has invalid value for '{field}': {value!r}"
)
if data.get("type") == "CogneeUserFeedback":
assert required_fields_feedback.issubset(data.keys()), (
f"Node {node_id} is missing fields: {required_fields_feedback - set(data.keys())}"
)
for field in required_fields_feedback:
value = data[field]
assert isinstance(value, str) and value.strip(), (
f"Node {node_id} has invalid value for '{field}': {value!r}"
)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await cognee.add(text_1, dataset_name)
await cognee.add([text], dataset_name)
await cognee.cognify([dataset_name])
await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text="Next to which country is Germany located?",
save_interaction=True,
)
await cognee.search(
query_type=SearchType.FEEDBACK,
query_text="This was the best answer I've ever seen",
last_k=1,
)
await cognee.search(
query_type=SearchType.FEEDBACK,
query_text="Wow the correctness of this answer blows my mind",
last_k=1,
)
graph = await graph_engine.get_graph_data()
edges = graph[1]
for from_node, to_node, relationship_name, properties in edges:
if relationship_name == "used_graph_element_to_answer":
assert properties["feedback_weight"] >= 6, (
"Feedback weight calculation is not correct, it should be more then 6."
)
if __name__ == "__main__":
import asyncio

View file

@ -51,7 +51,7 @@ class TestGraphCompletionWithContextExtensionRetriever:
retriever = GraphCompletionContextExtensionRetriever()
context = await retriever.get_context("Who works at Canva?")
context, _ = await retriever.get_context("Who works at Canva?")
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
@ -129,7 +129,7 @@ class TestGraphCompletionWithContextExtensionRetriever:
retriever = GraphCompletionContextExtensionRetriever(top_k=20)
context = await retriever.get_context("Who works at Figma?")
context, _ = await retriever.get_context("Who works at Figma?")
print(context)
@ -167,7 +167,7 @@ class TestGraphCompletionWithContextExtensionRetriever:
await setup()
context = await retriever.get_context("Who works at Figma?")
context, _ = await retriever.get_context("Who works at Figma?")
assert context == "", "Context should be empty on an empty graph"
answer = await retriever.get_completion("Who works at Figma?")

View file

@ -47,7 +47,7 @@ class TestGraphCompletionCoTRetriever:
retriever = GraphCompletionCotRetriever()
context = await retriever.get_context("Who works at Canva?")
context, _ = await retriever.get_context("Who works at Canva?")
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
@ -124,7 +124,7 @@ class TestGraphCompletionCoTRetriever:
retriever = GraphCompletionCotRetriever(top_k=20)
context = await retriever.get_context("Who works at Figma?")
context, _ = await retriever.get_context("Who works at Figma?")
print(context)
@ -162,7 +162,7 @@ class TestGraphCompletionCoTRetriever:
await setup()
context = await retriever.get_context("Who works at Figma?")
context, _ = await retriever.get_context("Who works at Figma?")
assert context == "", "Context should be empty on an empty graph"
answer = await retriever.get_completion("Who works at Figma?")

View file

@ -67,7 +67,7 @@ class TestGraphCompletionRetriever:
retriever = GraphCompletionRetriever()
context = await retriever.get_context("Who works at Canva?")
context, _ = await retriever.get_context("Who works at Canva?")
# Ensure the top-level sections are present
assert "Nodes:" in context, "Missing 'Nodes:' section in context"
@ -191,7 +191,7 @@ class TestGraphCompletionRetriever:
retriever = GraphCompletionRetriever(top_k=20)
context = await retriever.get_context("Who works at Figma?")
context, _ = await retriever.get_context("Who works at Figma?")
print(context)
@ -222,5 +222,5 @@ class TestGraphCompletionRetriever:
await setup()
context = await retriever.get_context("Who works at Figma?")
context, _ = await retriever.get_context("Who works at Figma?")
assert context == "", "Context should be empty on an empty graph"

View file

@ -65,6 +65,8 @@ async def test_search(
top_k=10,
node_type=None,
node_name=None,
save_interaction=False,
last_k=None,
)
# Verify result logging