feat: adds kuzu and neo4j tests for feedback and interaction features
This commit is contained in:
parent
372181d8c1
commit
fcdee16f69
1 changed files with 82 additions and 0 deletions
|
|
@ -4,6 +4,7 @@ import pathlib
|
|||
from dns.e164 import query
|
||||
|
||||
import cognee
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||
|
|
@ -18,6 +19,7 @@ from cognee.modules.users.methods import get_default_user
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.engine.models import NodeSet
|
||||
from collections import Counter
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -112,18 +114,33 @@ async def main():
|
|||
completion_gk = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text="Next to which country is Germany located?",
|
||||
save_interaction=True,
|
||||
)
|
||||
completion_cot = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION_COT,
|
||||
query_text="Next to which country is Germany located?",
|
||||
save_interaction=True,
|
||||
)
|
||||
completion_ext = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION,
|
||||
query_text="Next to which country is Germany located?",
|
||||
save_interaction=True,
|
||||
)
|
||||
|
||||
feedback_sum_1 = await cognee.search(
|
||||
query_type=SearchType.FEEDBACK, query_text="This was not the best answer", last_k=1
|
||||
)
|
||||
|
||||
completion_sum = await cognee.search(
|
||||
query_type=SearchType.GRAPH_SUMMARY_COMPLETION,
|
||||
query_text="Next to which country is Germany located?",
|
||||
save_interaction=True,
|
||||
)
|
||||
|
||||
feedback_sum_2 = await cognee.search(
|
||||
query_type=SearchType.FEEDBACK,
|
||||
query_text="This answer was great",
|
||||
last_k=1,
|
||||
)
|
||||
|
||||
for name, completion in [
|
||||
|
|
@ -141,6 +158,71 @@ async def main():
|
|||
f"{name}: expected 'netherlands' in result, got: {text!r}"
|
||||
)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
graph = await graph_engine.get_graph_data()
|
||||
|
||||
type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0])
|
||||
|
||||
edge_type_counts = Counter(edge_type[2] for edge_type in graph[1])
|
||||
|
||||
# Assert there are exactly 4 CogneeUserInteraction nodes.
|
||||
assert type_counts.get("CogneeUserInteraction", 0) == 4, (
|
||||
f"Expected exactly four DCogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}"
|
||||
)
|
||||
|
||||
# Assert there is exactly two CogneeUserFeedback nodes.
|
||||
assert type_counts.get("CogneeUserFeedback", 0) == 2, (
|
||||
f"Expected exactly two CogneeUserFeedback nodes, but found {type_counts.get('CogneeUserFeedback', 0)}"
|
||||
)
|
||||
|
||||
# Assert there is exactly two NodeSet.
|
||||
assert type_counts.get("NodeSet", 0) == 2, (
|
||||
f"Expected exactly two NodeSet nodes, but found {type_counts.get('NodeSet', 0)}"
|
||||
)
|
||||
|
||||
# Assert that there are at least 10 'used_graph_element_to_answer' edges.
|
||||
assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10, (
|
||||
f"Expected at least ten 'used_graph_element_to_answer' edges, but found {edge_type_counts.get('used_graph_element_to_answer', 0)}"
|
||||
)
|
||||
|
||||
# Assert that there are exactly 2 'gives_feedback_to' edges.
|
||||
assert edge_type_counts.get("gives_feedback_to", 0) == 2, (
|
||||
f"Expected exactly two 'gives_feedback_to' edges, but found {edge_type_counts.get('gives_feedback_to', 0)}"
|
||||
)
|
||||
|
||||
# Assert that there are at least 6 'belongs_to_set' edges.
|
||||
assert edge_type_counts.get("belongs_to_set", 0) == 6, (
|
||||
f"Expected at least six 'belongs_to_set' edges, but found {edge_type_counts.get('belongs_to_set', 0)}"
|
||||
)
|
||||
|
||||
nodes = graph[0]
|
||||
|
||||
required_fields_user_interaction = {"question", "answer", "context"}
|
||||
required_fields_feedback = {"feedback", "sentiment"}
|
||||
|
||||
for node_id, data in nodes: # nodes = your list
|
||||
if data.get("type") == "CogneeUserInteraction":
|
||||
assert required_fields_user_interaction.issubset(data.keys()), (
|
||||
f"Node {node_id} is missing fields: {required_fields_user_interaction - set(data.keys())}"
|
||||
)
|
||||
|
||||
for field in required_fields_user_interaction:
|
||||
value = data[field]
|
||||
assert isinstance(value, str) and value.strip(), (
|
||||
f"Node {node_id} has invalid value for '{field}': {value!r}"
|
||||
)
|
||||
|
||||
if data.get("type") == "CogneeUserFeedback":
|
||||
assert required_fields_feedback.issubset(data.keys()), (
|
||||
f"Node {node_id} is missing fields: {required_fields_feedback - set(data.keys())}"
|
||||
)
|
||||
|
||||
for field in required_fields_feedback:
|
||||
value = data[field]
|
||||
assert isinstance(value, str) and value.strip(), (
|
||||
f"Node {node_id} has invalid value for '{field}': {value!r}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue