adds e2e tests (old test broken into separate tests)
This commit is contained in:
parent
84058d4525
commit
0bef029e34
1 changed files with 287 additions and 0 deletions
|
|
@ -77,6 +77,12 @@ async def setup_search_db_environment():
|
||||||
gk_sum_retriever_context = await GraphSummaryCompletionRetriever().get_context(query=query)
|
gk_sum_retriever_context = await GraphSummaryCompletionRetriever().get_context(query=query)
|
||||||
triplet_retriever_context = await TripletRetriever().get_context(query=query)
|
triplet_retriever_context = await TripletRetriever().get_context(query=query)
|
||||||
|
|
||||||
|
# Pre-compute triplets for test_retriever_triplets
|
||||||
|
triplets_gk = await GraphCompletionRetriever().get_triplets(query=query)
|
||||||
|
triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets(query=query)
|
||||||
|
triplets_gk_ext = await GraphCompletionContextExtensionRetriever().get_triplets(query=query)
|
||||||
|
triplets_gk_sum = await GraphSummaryCompletionRetriever().get_triplets(query=query)
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"dataset_name": dataset_name,
|
"dataset_name": dataset_name,
|
||||||
"text_1": text_1,
|
"text_1": text_1,
|
||||||
|
|
@ -88,6 +94,10 @@ async def setup_search_db_environment():
|
||||||
"gk_ext_retriever_context": gk_ext_retriever_context,
|
"gk_ext_retriever_context": gk_ext_retriever_context,
|
||||||
"gk_sum_retriever_context": gk_sum_retriever_context,
|
"gk_sum_retriever_context": gk_sum_retriever_context,
|
||||||
"triplet_retriever_context": triplet_retriever_context,
|
"triplet_retriever_context": triplet_retriever_context,
|
||||||
|
"triplets_gk": triplets_gk,
|
||||||
|
"triplets_gk_cot": triplets_gk_cot,
|
||||||
|
"triplets_gk_ext": triplets_gk_ext,
|
||||||
|
"triplets_gk_sum": triplets_gk_sum,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info("Fixture teardown: pruning data and system")
|
logger.info("Fixture teardown: pruning data and system")
|
||||||
|
|
@ -179,3 +189,280 @@ async def test_retriever_contexts(setup_search_db_environment):
|
||||||
assert "germany" in lower_triplet or "netherlands" in lower_triplet, (
|
assert "germany" in lower_triplet or "netherlands" in lower_triplet, (
|
||||||
f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
|
f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retriever_triplets(setup_search_db_environment):
|
||||||
|
"""Test that retrievers return valid triplets with proper vector distances."""
|
||||||
|
triplets_gk = setup_search_db_environment["triplets_gk"]
|
||||||
|
triplets_gk_cot = setup_search_db_environment["triplets_gk_cot"]
|
||||||
|
triplets_gk_ext = setup_search_db_environment["triplets_gk_ext"]
|
||||||
|
triplets_gk_sum = setup_search_db_environment["triplets_gk_sum"]
|
||||||
|
|
||||||
|
for name, triplets in [
|
||||||
|
("GraphCompletionRetriever", triplets_gk),
|
||||||
|
("GraphCompletionCotRetriever", triplets_gk_cot),
|
||||||
|
("GraphCompletionContextExtensionRetriever", triplets_gk_ext),
|
||||||
|
("GraphSummaryCompletionRetriever", triplets_gk_sum),
|
||||||
|
]:
|
||||||
|
assert isinstance(triplets, list), f"{name}: Triplets should be a list"
|
||||||
|
assert triplets, f"{name}: Triplets list should not be empty"
|
||||||
|
for edge in triplets:
|
||||||
|
assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
|
||||||
|
distance = edge.attributes.get("vector_distance")
|
||||||
|
node1_distance = edge.node1.attributes.get("vector_distance")
|
||||||
|
node2_distance = edge.node2.attributes.get("vector_distance")
|
||||||
|
assert isinstance(distance, float), (
|
||||||
|
f"{name}: vector_distance should be float, got {type(distance)}"
|
||||||
|
)
|
||||||
|
assert 0 <= distance <= 1, (
|
||||||
|
f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen"
|
||||||
|
)
|
||||||
|
assert 0 <= node1_distance <= 1, (
|
||||||
|
f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen"
|
||||||
|
)
|
||||||
|
assert 0 <= node2_distance <= 1, (
|
||||||
|
f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_operations(setup_search_db_environment):
|
||||||
|
"""Test different search types and verify results contain expected content."""
|
||||||
|
completion_gk = await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_COMPLETION,
|
||||||
|
query_text="Where is germany located, next to which country?",
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
completion_cot = await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_COMPLETION_COT,
|
||||||
|
query_text="What is the country next to germany??",
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
completion_ext = await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION,
|
||||||
|
query_text="What is the name of the country next to germany",
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.FEEDBACK, query_text="This was not the best answer", last_k=1
|
||||||
|
)
|
||||||
|
|
||||||
|
completion_sum = await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_SUMMARY_COMPLETION,
|
||||||
|
query_text="Next to which country is Germany located?",
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
completion_triplet = await cognee.search(
|
||||||
|
query_type=SearchType.TRIPLET_COMPLETION,
|
||||||
|
query_text="Next to which country is Germany located?",
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.FEEDBACK,
|
||||||
|
query_text="This answer was great",
|
||||||
|
last_k=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
from cognee.context_global_variables import backend_access_control_enabled
|
||||||
|
|
||||||
|
for name, search_results in [
|
||||||
|
("GRAPH_COMPLETION", completion_gk),
|
||||||
|
("GRAPH_COMPLETION_COT", completion_cot),
|
||||||
|
("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
|
||||||
|
("GRAPH_SUMMARY_COMPLETION", completion_sum),
|
||||||
|
("TRIPLET_COMPLETION", completion_triplet),
|
||||||
|
]:
|
||||||
|
assert isinstance(search_results, list), f"{name}: should return a list"
|
||||||
|
assert len(search_results) == 1, (
|
||||||
|
f"{name}: expected single-element list, got {len(search_results)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if backend_access_control_enabled():
|
||||||
|
text = search_results[0]["search_result"][0]
|
||||||
|
else:
|
||||||
|
text = search_results[0]
|
||||||
|
assert isinstance(text, str), f"{name}: element should be a string"
|
||||||
|
assert text.strip(), f"{name}: string should not be empty"
|
||||||
|
assert "netherlands" in text.lower(), (
|
||||||
|
f"{name}: expected 'netherlands' in result, got: {text!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_graph_node_and_edge_counts(setup_search_db_environment):
|
||||||
|
"""Test that graph contains expected node and edge counts after search operations."""
|
||||||
|
# First perform searches to create interaction and feedback nodes
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_COMPLETION,
|
||||||
|
query_text="Where is germany located, next to which country?",
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_COMPLETION_COT,
|
||||||
|
query_text="What is the country next to germany??",
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION,
|
||||||
|
query_text="What is the name of the country next to germany",
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.FEEDBACK, query_text="This was not the best answer", last_k=1
|
||||||
|
)
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_SUMMARY_COMPLETION,
|
||||||
|
query_text="Next to which country is Germany located?",
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.TRIPLET_COMPLETION,
|
||||||
|
query_text="Next to which country is Germany located?",
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.FEEDBACK,
|
||||||
|
query_text="This answer was great",
|
||||||
|
last_k=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_engine = setup_search_db_environment["graph_engine"]
|
||||||
|
graph = await graph_engine.get_graph_data()
|
||||||
|
|
||||||
|
type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0])
|
||||||
|
edge_type_counts = Counter(edge_type[2] for edge_type in graph[1])
|
||||||
|
|
||||||
|
# Assert there are exactly 4 CogneeUserInteraction nodes.
|
||||||
|
assert type_counts.get("CogneeUserInteraction", 0) == 4, (
|
||||||
|
f"Expected exactly four CogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert there is exactly two CogneeUserFeedback nodes.
|
||||||
|
assert type_counts.get("CogneeUserFeedback", 0) == 2, (
|
||||||
|
f"Expected exactly two CogneeUserFeedback nodes, but found {type_counts.get('CogneeUserFeedback', 0)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert there is exactly two NodeSet.
|
||||||
|
assert type_counts.get("NodeSet", 0) == 2, (
|
||||||
|
f"Expected exactly two NodeSet nodes, but found {type_counts.get('NodeSet', 0)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that there are at least 10 'used_graph_element_to_answer' edges.
|
||||||
|
assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10, (
|
||||||
|
f"Expected at least ten 'used_graph_element_to_answer' edges, but found {edge_type_counts.get('used_graph_element_to_answer', 0)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that there are exactly 2 'gives_feedback_to' edges.
|
||||||
|
assert edge_type_counts.get("gives_feedback_to", 0) == 2, (
|
||||||
|
f"Expected exactly two 'gives_feedback_to' edges, but found {edge_type_counts.get('gives_feedback_to', 0)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that there are at least 6 'belongs_to_set' edges.
|
||||||
|
assert edge_type_counts.get("belongs_to_set", 0) >= 6, (
|
||||||
|
f"Expected at least six 'belongs_to_set' edges, but found {edge_type_counts.get('belongs_to_set', 0)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_node_field_validation(setup_search_db_environment):
|
||||||
|
"""Test that user interaction and feedback nodes have required fields with valid values."""
|
||||||
|
# First perform searches to create interaction and feedback nodes
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_COMPLETION,
|
||||||
|
query_text="Where is germany located, next to which country?",
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_COMPLETION_COT,
|
||||||
|
query_text="What is the country next to germany??",
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION,
|
||||||
|
query_text="What is the name of the country next to germany",
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.FEEDBACK, query_text="This was not the best answer", last_k=1
|
||||||
|
)
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_SUMMARY_COMPLETION,
|
||||||
|
query_text="Next to which country is Germany located?",
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.TRIPLET_COMPLETION,
|
||||||
|
query_text="Next to which country is Germany located?",
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.FEEDBACK,
|
||||||
|
query_text="This answer was great",
|
||||||
|
last_k=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_engine = setup_search_db_environment["graph_engine"]
|
||||||
|
graph = await graph_engine.get_graph_data()
|
||||||
|
nodes = graph[0]
|
||||||
|
|
||||||
|
required_fields_user_interaction = {"question", "answer", "context"}
|
||||||
|
required_fields_feedback = {"feedback", "sentiment"}
|
||||||
|
|
||||||
|
for node_id, data in nodes:
|
||||||
|
if data.get("type") == "CogneeUserInteraction":
|
||||||
|
assert required_fields_user_interaction.issubset(data.keys()), (
|
||||||
|
f"Node {node_id} is missing fields: {required_fields_user_interaction - set(data.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for field in required_fields_user_interaction:
|
||||||
|
value = data[field]
|
||||||
|
assert isinstance(value, str) and value.strip(), (
|
||||||
|
f"Node {node_id} has invalid value for '{field}': {value!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if data.get("type") == "CogneeUserFeedback":
|
||||||
|
assert required_fields_feedback.issubset(data.keys()), (
|
||||||
|
f"Node {node_id} is missing fields: {required_fields_feedback - set(data.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for field in required_fields_feedback:
|
||||||
|
value = data[field]
|
||||||
|
assert isinstance(value, str) and value.strip(), (
|
||||||
|
f"Node {node_id} has invalid value for '{field}': {value!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_feedback_weight_calculation(setup_search_db_environment_for_feedback):
|
||||||
|
"""Test that feedback weight is correctly calculated after multiple positive feedbacks."""
|
||||||
|
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_COMPLETION,
|
||||||
|
query_text="Next to which country is Germany located?",
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.FEEDBACK,
|
||||||
|
query_text="This was the best answer I've ever seen",
|
||||||
|
last_k=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.FEEDBACK,
|
||||||
|
query_text="Wow the correctness of this answer blows my mind",
|
||||||
|
last_k=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
graph = await graph_engine.get_graph_data()
|
||||||
|
edges = graph[1]
|
||||||
|
|
||||||
|
for from_node, to_node, relationship_name, properties in edges:
|
||||||
|
if relationship_name == "used_graph_element_to_answer":
|
||||||
|
assert properties["feedback_weight"] >= 6, (
|
||||||
|
"Feedback weight calculation is not correct, it should be more then 6."
|
||||||
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue