testing another approach

This commit is contained in:
hajdul88 2025-12-12 12:07:04 +01:00
parent ed21432942
commit eaf29f2e52

View file

@ -1,6 +1,8 @@
import pathlib import pathlib
import os import os
import asyncio
import pytest import pytest
import pytest_asyncio
from collections import Counter from collections import Counter
import cognee import cognee
@ -28,6 +30,20 @@ from cognee.modules.users.methods import get_default_user
logger = get_logger() logger = get_logger()
@pytest.fixture(scope="session")
def event_loop():
"""Use a single asyncio event loop for this test module.
This helps avoid "Future attached to a different loop" when running multiple async
tests that share clients/engines.
"""
loop = asyncio.new_event_loop()
try:
yield loop
finally:
loop.close()
async def setup_test_environment(): async def setup_test_environment():
"""Helper function to set up test environment with data, cognify, and triplet embeddings.""" """Helper function to set up test environment with data, cognify, and triplet embeddings."""
# This test runs for multiple db settings, to run this locally set the corresponding db envs # This test runs for multiple db settings, to run this locally set the corresponding db envs
@ -143,12 +159,11 @@ async def setup_test_environment_for_feedback():
} }
@pytest.mark.asyncio @pytest_asyncio.fixture(scope="session")
async def test_search_db(): async def e2e_state():
"""Run all search-db checks in one test. """Compute E2E artifacts once; tests only assert.
This intentionally keeps everything in a single test to avoid event loop churn between This avoids repeating expensive setup and LLM calls across multiple tests.
tests when running against deployed databases.
""" """
await setup_test_environment() await setup_test_environment()
@ -161,96 +176,36 @@ async def test_search_db():
collection_name="Triplet_text", query_text="Test", limit=None collection_name="Triplet_text", query_text="Test", limit=None
) )
assert len(edges) == len(collection), (
f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection"
)
# --- Retriever contexts --- # --- Retriever contexts ---
query = "Next to which country is Germany located?" query = "Next to which country is Germany located?"
context_gk = await GraphCompletionRetriever().get_context(query=query) contexts = {
context_gk_cot = await GraphCompletionCotRetriever().get_context(query=query) "graph_completion": await GraphCompletionRetriever().get_context(query=query),
context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context(query=query) "graph_completion_cot": await GraphCompletionCotRetriever().get_context(query=query),
context_gk_sum = await GraphSummaryCompletionRetriever().get_context(query=query) "graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_context(
context_chunks = await ChunksRetriever(top_k=5).get_context(query=query) query=query
context_summaries = await SummariesRetriever(top_k=5).get_context(query=query) ),
context_rag = await CompletionRetriever(top_k=3).get_context(query=query) "graph_summary_completion": await GraphSummaryCompletionRetriever().get_context(
context_temporal = await TemporalRetriever(top_k=5).get_context(query=query) query=query
context_triplet = await TripletRetriever().get_context(query=query) ),
"chunks": await ChunksRetriever(top_k=5).get_context(query=query),
for name, context in [ "summaries": await SummariesRetriever(top_k=5).get_context(query=query),
("GraphCompletionRetriever", context_gk), "rag_completion": await CompletionRetriever(top_k=3).get_context(query=query),
("GraphCompletionCotRetriever", context_gk_cot), "temporal": await TemporalRetriever(top_k=5).get_context(query=query),
("GraphCompletionContextExtensionRetriever", context_gk_ext), "triplet": await TripletRetriever().get_context(query=query),
("GraphSummaryCompletionRetriever", context_gk_sum), }
]:
assert isinstance(context, list), f"{name}: Context should be a list"
assert len(context) > 0, f"{name}: Context should not be empty"
context_text = await resolve_edges_to_text(context)
lower = context_text.lower()
assert "germany" in lower or "netherlands" in lower, (
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
)
assert isinstance(context_triplet, str), "TripletRetriever: Context should be a string"
assert len(context_triplet) > 0, "TripletRetriever: Context should not be empty"
lower_triplet = context_triplet.lower()
assert "germany" in lower_triplet or "netherlands" in lower_triplet, (
f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
)
assert isinstance(context_chunks, list), "ChunksRetriever: Context should be a list"
assert context_chunks, "ChunksRetriever: Context should not be empty"
chunks_text = "\n".join(str(item.get("text", "")) for item in context_chunks).lower()
assert "germany" in chunks_text or "netherlands" in chunks_text, (
"ChunksRetriever: Context did not contain 'germany' or 'netherlands'; "
f"got: {context_chunks!r}"
)
assert isinstance(context_summaries, list), "SummariesRetriever: Context should be a list"
assert context_summaries, "SummariesRetriever: Context should not be empty"
assert any(str(item.get("text", "")).strip() for item in context_summaries), (
"SummariesRetriever: Expected at least one non-empty 'text' field in summary payloads"
)
assert isinstance(context_rag, str), "CompletionRetriever: Context should be a string"
assert context_rag.strip(), "CompletionRetriever: Context should not be empty"
assert isinstance(context_temporal, str), "TemporalRetriever: Context should be a string"
assert context_temporal.strip(), "TemporalRetriever: Context should not be empty"
# --- Retriever triplets + vector distance validation --- # --- Retriever triplets + vector distance validation ---
triplets_gk = await GraphCompletionRetriever().get_triplets(query=query) triplets = {
triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets(query=query) "graph_completion": await GraphCompletionRetriever().get_triplets(query=query),
triplets_gk_ext = await GraphCompletionContextExtensionRetriever().get_triplets(query=query) "graph_completion_cot": await GraphCompletionCotRetriever().get_triplets(query=query),
triplets_gk_sum = await GraphSummaryCompletionRetriever().get_triplets(query=query) "graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_triplets(
query=query
for name, triplets in [ ),
("GraphCompletionRetriever", triplets_gk), "graph_summary_completion": await GraphSummaryCompletionRetriever().get_triplets(
("GraphCompletionCotRetriever", triplets_gk_cot), query=query
("GraphCompletionContextExtensionRetriever", triplets_gk_ext), ),
("GraphSummaryCompletionRetriever", triplets_gk_sum), }
]:
assert isinstance(triplets, list), f"{name}: Triplets should be a list"
assert triplets, f"{name}: Triplets list should not be empty"
for edge in triplets:
assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
distance = edge.attributes.get("vector_distance")
node1_distance = edge.node1.attributes.get("vector_distance")
node2_distance = edge.node2.attributes.get("vector_distance")
assert isinstance(distance, float), (
f"{name}: vector_distance should be float, got {type(distance)}"
)
assert 0 <= distance <= 1, (
f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen"
)
assert 0 <= node1_distance <= 1, (
f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen"
)
assert 0 <= node2_distance <= 1, (
f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen"
)
# --- Search operations + graph side effects --- # --- Search operations + graph side effects ---
completion_gk = await cognee.search( completion_gk = await cognee.search(
@ -310,139 +265,34 @@ async def test_search_db():
last_k=1, last_k=1,
) )
from cognee.context_global_variables import backend_access_control_enabled # Snapshot after all E2E operations above (used by assertion-only tests).
graph_snapshot = await (await get_graph_engine()).get_graph_data()
for name, search_results in [ return {
("GRAPH_COMPLETION", completion_gk), "graph_edges": edges,
("GRAPH_COMPLETION_COT", completion_cot), "triplet_collection": collection,
("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext), "vector_collection_edges_count": len(collection),
("GRAPH_SUMMARY_COMPLETION", completion_sum), "graph_edges_count": len(edges),
("TRIPLET_COMPLETION", completion_triplet), "contexts": contexts,
("RAG_COMPLETION", completion_rag), "triplets": triplets,
("TEMPORAL", completion_temporal), "search_results": {
]: "graph_completion": completion_gk,
assert isinstance(search_results, list), f"{name}: should return a list" "graph_completion_cot": completion_cot,
assert len(search_results) == 1, ( "graph_completion_context_extension": completion_ext,
f"{name}: expected single-element list, got {len(search_results)}" "graph_summary_completion": completion_sum,
) "triplet_completion": completion_triplet,
"chunks": completion_chunks,
"summaries": completion_summaries,
"rag_completion": completion_rag,
"temporal": completion_temporal,
},
"graph_snapshot": graph_snapshot,
}
if backend_access_control_enabled():
wrapper = search_results[0]
assert isinstance(wrapper, dict), (
f"{name}: expected wrapper dict in access control mode"
)
assert wrapper.get("dataset_id"), f"{name}: missing dataset_id in wrapper"
assert wrapper.get("dataset_name") == "test_dataset", (
f"{name}: unexpected dataset_name {wrapper.get('dataset_name')!r}"
)
assert "graphs" in wrapper, f"{name}: missing graphs key in wrapper"
text = search_results[0]["search_result"][0]
else:
text = search_results[0]
assert isinstance(text, str), f"{name}: element should be a string"
assert text.strip(), f"{name}: string should not be empty"
assert "netherlands" in text.lower(), (
f"{name}: expected 'netherlands' in result, got: {text!r}"
)
for name, search_results in [ @pytest_asyncio.fixture(scope="session")
("CHUNKS", completion_chunks), async def feedback_state():
("SUMMARIES", completion_summaries), """Feedback-weight scenario computed once (fresh environment)."""
]:
assert isinstance(search_results, list), f"{name}: should return a list"
assert search_results, f"{name}: should not be empty"
first = search_results[0]
assert isinstance(first, dict), f"{name}: expected dict entries, got {type(first).__name__}"
payloads = search_results
if "search_result" in first and "text" not in first:
assert first.get("dataset_id"), f"{name}: missing dataset_id in wrapper"
assert first.get("dataset_name") == "test_dataset", (
f"{name}: unexpected dataset_name {first.get('dataset_name')!r}"
)
assert "graphs" in first, f"{name}: missing graphs key in wrapper"
payloads = (first.get("search_result") or [None])[0]
assert isinstance(payloads, list), (
f"{name}: expected list payloads, got {type(payloads).__name__}"
)
assert payloads, f"{name}: expected non-empty payload list"
assert isinstance(payloads[0], dict), f"{name}: expected dict payloads"
assert str(payloads[0].get("text", "")).strip(), f"{name}: missing non-empty 'text'"
graph_engine = await get_graph_engine()
graph = await graph_engine.get_graph_data()
type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0])
edge_type_counts = Counter(edge_type[2] for edge_type in graph[1])
# Assert there are exactly 4 CogneeUserInteraction nodes.
assert type_counts.get("CogneeUserInteraction", 0) == 4, (
"Expected exactly four CogneeUserInteraction nodes, "
f"but found {type_counts.get('CogneeUserInteraction', 0)}"
)
# Assert there is exactly two CogneeUserFeedback nodes.
assert type_counts.get("CogneeUserFeedback", 0) == 2, (
"Expected exactly two CogneeUserFeedback nodes, "
f"but found {type_counts.get('CogneeUserFeedback', 0)}"
)
# Assert there is exactly two NodeSet.
assert type_counts.get("NodeSet", 0) == 2, (
f"Expected exactly two NodeSet nodes, but found {type_counts.get('NodeSet', 0)}"
)
# Assert that there are at least 10 'used_graph_element_to_answer' edges.
assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10, (
"Expected at least ten 'used_graph_element_to_answer' edges, but found "
f"{edge_type_counts.get('used_graph_element_to_answer', 0)}"
)
# Assert that there are exactly 2 'gives_feedback_to' edges.
assert edge_type_counts.get("gives_feedback_to", 0) == 2, (
"Expected exactly two 'gives_feedback_to' edges, but found "
f"{edge_type_counts.get('gives_feedback_to', 0)}"
)
# Assert that there are at least 6 'belongs_to_set' edges.
assert edge_type_counts.get("belongs_to_set", 0) >= 6, (
"Expected at least six 'belongs_to_set' edges, but found "
f"{edge_type_counts.get('belongs_to_set', 0)}"
)
# Node field validation on the same graph produced above
nodes = graph[0]
required_fields_user_interaction = {"question", "answer", "context"}
required_fields_feedback = {"feedback", "sentiment"}
for node_id, data in nodes:
if data.get("type") == "CogneeUserInteraction":
assert required_fields_user_interaction.issubset(data.keys()), (
f"Node {node_id} is missing fields: "
f"{required_fields_user_interaction - set(data.keys())}"
)
for field in required_fields_user_interaction:
value = data[field]
assert isinstance(value, str) and value.strip(), (
f"Node {node_id} has invalid value for '{field}': {value!r}"
)
if data.get("type") == "CogneeUserFeedback":
assert required_fields_feedback.issubset(data.keys()), (
f"Node {node_id} is missing fields: {required_fields_feedback - set(data.keys())}"
)
for field in required_fields_feedback:
value = data[field]
assert isinstance(value, str) and value.strip(), (
f"Node {node_id} has invalid value for '{field}': {value!r}"
)
# --- Feedback weight calculation (run in fresh environment) ---
await setup_test_environment_for_feedback() await setup_test_environment_for_feedback()
await cognee.search( await cognee.search(
@ -450,13 +300,11 @@ async def test_search_db():
query_text="Next to which country is Germany located?", query_text="Next to which country is Germany located?",
save_interaction=True, save_interaction=True,
) )
await cognee.search( await cognee.search(
query_type=SearchType.FEEDBACK, query_type=SearchType.FEEDBACK,
query_text="This was the best answer I've ever seen", query_text="This was the best answer I've ever seen",
last_k=1, last_k=1,
) )
await cognee.search( await cognee.search(
query_type=SearchType.FEEDBACK, query_type=SearchType.FEEDBACK,
query_text="Wow the correctness of this answer blows my mind", query_text="Wow the correctness of this answer blows my mind",
@ -465,9 +313,163 @@ async def test_search_db():
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
graph = await graph_engine.get_graph_data() graph = await graph_engine.get_graph_data()
edges = graph[1] return {"graph_snapshot": graph}
for from_node, to_node, relationship_name, properties in edges:
@pytest.mark.asyncio
async def test_e2e_graph_vector_consistency(e2e_state):
assert e2e_state["graph_edges_count"] == e2e_state["vector_collection_edges_count"]
@pytest.mark.asyncio
async def test_e2e_retriever_contexts(e2e_state):
query = "Next to which country is Germany located?"
contexts = e2e_state["contexts"]
for name in [
"graph_completion",
"graph_completion_cot",
"graph_completion_context_extension",
"graph_summary_completion",
]:
ctx = contexts[name]
assert isinstance(ctx, list), f"{name}: Context should be a list"
assert ctx, f"{name}: Context should not be empty"
ctx_text = await resolve_edges_to_text(ctx)
lower = ctx_text.lower()
assert "germany" in lower or "netherlands" in lower, (
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {ctx!r}"
)
triplet_ctx = contexts["triplet"]
assert isinstance(triplet_ctx, str), "triplet: Context should be a string"
assert triplet_ctx.strip(), "triplet: Context should not be empty"
chunks_ctx = contexts["chunks"]
assert isinstance(chunks_ctx, list), "chunks: Context should be a list"
assert chunks_ctx, "chunks: Context should not be empty"
chunks_text = "\n".join(str(item.get("text", "")) for item in chunks_ctx).lower()
assert "germany" in chunks_text or "netherlands" in chunks_text
summaries_ctx = contexts["summaries"]
assert isinstance(summaries_ctx, list), "summaries: Context should be a list"
assert summaries_ctx, "summaries: Context should not be empty"
assert any(str(item.get("text", "")).strip() for item in summaries_ctx)
rag_ctx = contexts["rag_completion"]
assert isinstance(rag_ctx, str), "rag_completion: Context should be a string"
assert rag_ctx.strip(), "rag_completion: Context should not be empty"
temporal_ctx = contexts["temporal"]
assert isinstance(temporal_ctx, str), "temporal: Context should be a string"
assert temporal_ctx.strip(), "temporal: Context should not be empty"
@pytest.mark.asyncio
async def test_e2e_retriever_triplets_have_vector_distances(e2e_state):
for name, triplets in e2e_state["triplets"].items():
assert isinstance(triplets, list), f"{name}: Triplets should be a list"
assert triplets, f"{name}: Triplets list should not be empty"
for edge in triplets:
assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
distance = edge.attributes.get("vector_distance")
node1_distance = edge.node1.attributes.get("vector_distance")
node2_distance = edge.node2.attributes.get("vector_distance")
assert isinstance(distance, float), f"{name}: vector_distance should be float"
assert 0 <= distance <= 1
assert 0 <= node1_distance <= 1
assert 0 <= node2_distance <= 1
@pytest.mark.asyncio
async def test_e2e_search_results_and_wrappers(e2e_state):
from cognee.context_global_variables import backend_access_control_enabled
sr = e2e_state["search_results"]
# Completion-like search types: validate wrapper + content
for name in [
"graph_completion",
"graph_completion_cot",
"graph_completion_context_extension",
"graph_summary_completion",
"triplet_completion",
"rag_completion",
"temporal",
]:
search_results = sr[name]
assert isinstance(search_results, list), f"{name}: should return a list"
assert len(search_results) == 1, f"{name}: expected single-element list"
if backend_access_control_enabled():
wrapper = search_results[0]
assert isinstance(wrapper, dict), (
f"{name}: expected wrapper dict in access control mode"
)
assert wrapper.get("dataset_id"), f"{name}: missing dataset_id in wrapper"
assert wrapper.get("dataset_name") == "test_dataset"
assert "graphs" in wrapper
text = wrapper["search_result"][0]
else:
text = search_results[0]
assert isinstance(text, str) and text.strip()
assert "netherlands" in text.lower()
# Non-LLM search types: CHUNKS / SUMMARIES validate payload list + text
for name in ["chunks", "summaries"]:
search_results = sr[name]
assert isinstance(search_results, list), f"{name}: should return a list"
assert search_results, f"{name}: should not be empty"
first = search_results[0]
assert isinstance(first, dict), f"{name}: expected dict entries"
payloads = search_results
if "search_result" in first and "text" not in first:
payloads = (first.get("search_result") or [None])[0]
assert isinstance(payloads, list) and payloads
assert isinstance(payloads[0], dict)
assert str(payloads[0].get("text", "")).strip()
@pytest.mark.asyncio
async def test_e2e_graph_side_effects_and_node_fields(e2e_state):
graph = e2e_state["graph_snapshot"]
nodes, edges = graph
type_counts = Counter(node_data[1].get("type", {}) for node_data in nodes)
edge_type_counts = Counter(edge_type[2] for edge_type in edges)
assert type_counts.get("CogneeUserInteraction", 0) == 4
assert type_counts.get("CogneeUserFeedback", 0) == 2
assert type_counts.get("NodeSet", 0) == 2
assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10
assert edge_type_counts.get("gives_feedback_to", 0) == 2
assert edge_type_counts.get("belongs_to_set", 0) >= 6
required_fields_user_interaction = {"question", "answer", "context"}
required_fields_feedback = {"feedback", "sentiment"}
for node_id, data in nodes:
if data.get("type") == "CogneeUserInteraction":
assert required_fields_user_interaction.issubset(data.keys())
for field in required_fields_user_interaction:
value = data[field]
assert isinstance(value, str) and value.strip()
if data.get("type") == "CogneeUserFeedback":
assert required_fields_feedback.issubset(data.keys())
for field in required_fields_feedback:
value = data[field]
assert isinstance(value, str) and value.strip()
@pytest.mark.asyncio
async def test_e2e_feedback_weight_calculation(feedback_state):
_nodes, edges = feedback_state["graph_snapshot"]
for _from_node, _to_node, relationship_name, properties in edges:
if relationship_name == "used_graph_element_to_answer": if relationship_name == "used_graph_element_to_answer":
assert properties["feedback_weight"] >= 6, ( assert properties["feedback_weight"] >= 6, (
"Feedback weight calculation is not correct, it should be more then 6." "Feedback weight calculation is not correct, it should be more then 6."