diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index 8a99ebd8f..21dc1d3bf 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -4,6 +4,7 @@ import pathlib from dns.e164 import query import cognee +from cognee.infrastructure.databases.graph import get_graph_engine from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( @@ -18,6 +19,7 @@ from cognee.modules.users.methods import get_default_user from cognee.shared.logging_utils import get_logger from cognee.modules.search.types import SearchType from cognee.modules.engine.models import NodeSet +from collections import Counter logger = get_logger() @@ -112,18 +114,33 @@ async def main(): completion_gk = await cognee.search( query_type=SearchType.GRAPH_COMPLETION, query_text="Next to which country is Germany located?", + save_interaction=True, ) completion_cot = await cognee.search( query_type=SearchType.GRAPH_COMPLETION_COT, query_text="Next to which country is Germany located?", + save_interaction=True, ) completion_ext = await cognee.search( query_type=SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION, query_text="Next to which country is Germany located?", + save_interaction=True, ) + + feedback_sum_1 = await cognee.search( + query_type=SearchType.FEEDBACK, query_text="This was not the best answer", last_k=1 + ) + completion_sum = await cognee.search( query_type=SearchType.GRAPH_SUMMARY_COMPLETION, query_text="Next to which country is Germany located?", + save_interaction=True, + ) + + feedback_sum_2 = await cognee.search( + query_type=SearchType.FEEDBACK, + query_text="This answer was great", + last_k=1, ) for name, completion in [ @@ -141,6 +158,71 @@ async def main(): f"{name}: expected 'netherlands' in result, got: {text!r}" ) + graph_engine = await get_graph_engine() + graph = await graph_engine.get_graph_data() + + type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0]) + + edge_type_counts = Counter(edge_type[2] for edge_type in graph[1]) + + # Assert there are exactly 4 CogneeUserInteraction nodes. + assert type_counts.get("CogneeUserInteraction", 0) == 4, ( + f"Expected exactly four DCogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}" + ) + + # Assert there is exactly two CogneeUserFeedback nodes. + assert type_counts.get("CogneeUserFeedback", 0) == 2, ( + f"Expected exactly two CogneeUserFeedback nodes, but found {type_counts.get('CogneeUserFeedback', 0)}" + ) + + # Assert there is exactly two NodeSet. + assert type_counts.get("NodeSet", 0) == 2, ( + f"Expected exactly two NodeSet nodes, but found {type_counts.get('NodeSet', 0)}" + ) + + # Assert that there are at least 10 'used_graph_element_to_answer' edges. + assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10, ( + f"Expected at least ten 'used_graph_element_to_answer' edges, but found {edge_type_counts.get('used_graph_element_to_answer', 0)}" + ) + + # Assert that there are exactly 2 'gives_feedback_to' edges. + assert edge_type_counts.get("gives_feedback_to", 0) == 2, ( + f"Expected exactly two 'gives_feedback_to' edges, but found {edge_type_counts.get('gives_feedback_to', 0)}" + ) + + # Assert that there are at least 6 'belongs_to_set' edges. + assert edge_type_counts.get("belongs_to_set", 0) == 6, ( + f"Expected at least six 'belongs_to_set' edges, but found {edge_type_counts.get('belongs_to_set', 0)}" + ) + + nodes = graph[0] + + required_fields_user_interaction = {"question", "answer", "context"} + required_fields_feedback = {"feedback", "sentiment"} + + for node_id, data in nodes: # nodes = your list + if data.get("type") == "CogneeUserInteraction": + assert required_fields_user_interaction.issubset(data.keys()), ( + f"Node {node_id} is missing fields: {required_fields_user_interaction - set(data.keys())}" + ) + + for field in required_fields_user_interaction: + value = data[field] + assert isinstance(value, str) and value.strip(), ( + f"Node {node_id} has invalid value for '{field}': {value!r}" + ) + + if data.get("type") == "CogneeUserFeedback": + assert required_fields_feedback.issubset(data.keys()), ( + f"Node {node_id} is missing fields: {required_fields_feedback - set(data.keys())}" + ) + + for field in required_fields_feedback: + value = data[field] + assert isinstance(value, str) and value.strip(), ( + f"Node {node_id} has invalid value for '{field}': {value!r}" + ) + if __name__ == "__main__": import asyncio