feat: adds kuzu and neo4j tests for feedback and interaction features

This commit is contained in:
hajdul88 2025-08-19 10:49:01 +02:00
parent 372181d8c1
commit fcdee16f69

View file

@ -4,6 +4,7 @@ import pathlib
from dns.e164 import query
import cognee
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
@ -18,6 +19,7 @@ from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.engine.models import NodeSet
from collections import Counter
logger = get_logger()
@ -112,18 +114,33 @@ async def main():
completion_gk = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text="Next to which country is Germany located?",
save_interaction=True,
)
completion_cot = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION_COT,
query_text="Next to which country is Germany located?",
save_interaction=True,
)
completion_ext = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION,
query_text="Next to which country is Germany located?",
save_interaction=True,
)
feedback_sum_1 = await cognee.search(
query_type=SearchType.FEEDBACK, query_text="This was not the best answer", last_k=1
)
completion_sum = await cognee.search(
query_type=SearchType.GRAPH_SUMMARY_COMPLETION,
query_text="Next to which country is Germany located?",
save_interaction=True,
)
feedback_sum_2 = await cognee.search(
query_type=SearchType.FEEDBACK,
query_text="This answer was great",
last_k=1,
)
for name, completion in [
@ -141,6 +158,71 @@ async def main():
f"{name}: expected 'netherlands' in result, got: {text!r}"
)
graph_engine = await get_graph_engine()
graph = await graph_engine.get_graph_data()
type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0])
edge_type_counts = Counter(edge_type[2] for edge_type in graph[1])
# Assert there are exactly 4 CogneeUserInteraction nodes.
assert type_counts.get("CogneeUserInteraction", 0) == 4, (
f"Expected exactly four DCogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}"
)
# Assert there is exactly two CogneeUserFeedback nodes.
assert type_counts.get("CogneeUserFeedback", 0) == 2, (
f"Expected exactly two CogneeUserFeedback nodes, but found {type_counts.get('CogneeUserFeedback', 0)}"
)
# Assert there is exactly two NodeSet.
assert type_counts.get("NodeSet", 0) == 2, (
f"Expected exactly two NodeSet nodes, but found {type_counts.get('NodeSet', 0)}"
)
# Assert that there are at least 10 'used_graph_element_to_answer' edges.
assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10, (
f"Expected at least ten 'used_graph_element_to_answer' edges, but found {edge_type_counts.get('used_graph_element_to_answer', 0)}"
)
# Assert that there are exactly 2 'gives_feedback_to' edges.
assert edge_type_counts.get("gives_feedback_to", 0) == 2, (
f"Expected exactly two 'gives_feedback_to' edges, but found {edge_type_counts.get('gives_feedback_to', 0)}"
)
# Assert that there are at least 6 'belongs_to_set' edges.
assert edge_type_counts.get("belongs_to_set", 0) == 6, (
f"Expected at least six 'belongs_to_set' edges, but found {edge_type_counts.get('belongs_to_set', 0)}"
)
nodes = graph[0]
required_fields_user_interaction = {"question", "answer", "context"}
required_fields_feedback = {"feedback", "sentiment"}
for node_id, data in nodes: # nodes = your list
if data.get("type") == "CogneeUserInteraction":
assert required_fields_user_interaction.issubset(data.keys()), (
f"Node {node_id} is missing fields: {required_fields_user_interaction - set(data.keys())}"
)
for field in required_fields_user_interaction:
value = data[field]
assert isinstance(value, str) and value.strip(), (
f"Node {node_id} has invalid value for '{field}': {value!r}"
)
if data.get("type") == "CogneeUserFeedback":
assert required_fields_feedback.issubset(data.keys()), (
f"Node {node_id} is missing fields: {required_fields_feedback - set(data.keys())}"
)
for field in required_fields_feedback:
value = data[field]
assert isinstance(value, str) and value.strip(), (
f"Node {node_id} has invalid value for '{field}': {value!r}"
)
if __name__ == "__main__":
import asyncio