removes fixtures
This commit is contained in:
parent
0bef029e34
commit
7e0c9f0c91
1 changed files with 92 additions and 73 deletions
|
|
@ -1,9 +1,7 @@
|
|||
import pathlib
|
||||
import os
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import cognee
|
||||
from cognee.context_global_variables import set_session_user_context_variable
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
|
|
@ -25,11 +23,10 @@ from collections import Counter
|
|||
logger = get_logger()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def setup_search_db_environment():
|
||||
"""Set up test environment with data, cognify, and triplet embeddings."""
|
||||
async def setup_test_environment():
|
||||
"""Helper function to set up test environment with data, cognify, and triplet embeddings."""
|
||||
# This test runs for multiple db settings, to run this locally set the corresponding db envs
|
||||
logger.info("Starting fixture setup: pruning data and system")
|
||||
logger.info("Starting test setup: pruning data and system")
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
|
|
@ -65,52 +62,15 @@ async def setup_search_db_environment():
|
|||
count = await collection.count_rows() if hasattr(collection, "count_rows") else "unknown"
|
||||
logger.info(f"Triplet_text collection row count: {count}")
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
query = "Next to which country is Germany located?"
|
||||
gk_retriever_context = await GraphCompletionCotRetriever().get_context(query=query)
|
||||
gk_cot_retriever_context = await GraphCompletionCotRetriever().get_context(query=query)
|
||||
gk_ext_retriever_context = await GraphCompletionContextExtensionRetriever().get_context(
|
||||
query=query
|
||||
)
|
||||
gk_sum_retriever_context = await GraphSummaryCompletionRetriever().get_context(query=query)
|
||||
triplet_retriever_context = await TripletRetriever().get_context(query=query)
|
||||
|
||||
# Pre-compute triplets for test_retriever_triplets
|
||||
triplets_gk = await GraphCompletionRetriever().get_triplets(query=query)
|
||||
triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets(query=query)
|
||||
triplets_gk_ext = await GraphCompletionContextExtensionRetriever().get_triplets(query=query)
|
||||
triplets_gk_sum = await GraphSummaryCompletionRetriever().get_triplets(query=query)
|
||||
|
||||
yield {
|
||||
return {
|
||||
"dataset_name": dataset_name,
|
||||
"text_1": text_1,
|
||||
"explanation_file_path_quantum": explanation_file_path_quantum,
|
||||
"graph_engine": graph_engine,
|
||||
"vector_engine": vector_engine,
|
||||
"gk_retriever_context": gk_retriever_context,
|
||||
"gk_cot_retriever_context": gk_cot_retriever_context,
|
||||
"gk_ext_retriever_context": gk_ext_retriever_context,
|
||||
"gk_sum_retriever_context": gk_sum_retriever_context,
|
||||
"triplet_retriever_context": triplet_retriever_context,
|
||||
"triplets_gk": triplets_gk,
|
||||
"triplets_gk_cot": triplets_gk_cot,
|
||||
"triplets_gk_ext": triplets_gk_ext,
|
||||
"triplets_gk_sum": triplets_gk_sum,
|
||||
}
|
||||
|
||||
logger.info("Fixture teardown: pruning data and system")
|
||||
try:
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def setup_search_db_environment_for_feedback():
|
||||
"""Set up test environment for feedback weight calculation test."""
|
||||
async def setup_test_environment_for_feedback():
|
||||
"""Helper function to set up test environment for feedback weight calculation test."""
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
|
|
@ -127,26 +87,22 @@ async def setup_search_db_environment_for_feedback():
|
|||
|
||||
await cognee.cognify([dataset_name])
|
||||
|
||||
yield {
|
||||
return {
|
||||
"dataset_name": dataset_name,
|
||||
"text_1": text_1,
|
||||
"explanation_file_path_quantum": explanation_file_path_quantum,
|
||||
}
|
||||
|
||||
try:
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_vector_engine_consistency(setup_search_db_environment):
|
||||
async def test_graph_vector_engine_consistency():
|
||||
"""Test that graph edges match triplet collection in vector engine."""
|
||||
vector_engine = setup_search_db_environment["vector_engine"]
|
||||
graph_engine = setup_search_db_environment["graph_engine"]
|
||||
await setup_test_environment()
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
collection = await vector_engine.search(
|
||||
collection_name="Triplet_text", query_text="Test", limit=None
|
||||
)
|
||||
|
|
@ -155,16 +111,26 @@ async def test_graph_vector_engine_consistency(setup_search_db_environment):
|
|||
f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection"
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
try:
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retriever_contexts(setup_search_db_environment):
|
||||
async def test_retriever_contexts():
|
||||
"""Test that all retrievers return valid contexts with expected content."""
|
||||
await setup_test_environment()
|
||||
|
||||
context_gk = setup_search_db_environment["gk_retriever_context"]
|
||||
context_gk_cot = setup_search_db_environment["gk_cot_retriever_context"]
|
||||
context_gk_ext = setup_search_db_environment["gk_ext_retriever_context"]
|
||||
context_gk_sum = setup_search_db_environment["gk_sum_retriever_context"]
|
||||
context_triplet = setup_search_db_environment["triplet_retriever_context"]
|
||||
query = "Next to which country is Germany located?"
|
||||
|
||||
context_gk = await GraphCompletionRetriever().get_context(query=query)
|
||||
context_gk_cot = await GraphCompletionCotRetriever().get_context(query=query)
|
||||
context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context(query=query)
|
||||
context_gk_sum = await GraphSummaryCompletionRetriever().get_context(query=query)
|
||||
context_triplet = await TripletRetriever().get_context(query=query)
|
||||
|
||||
# Test graph-based retrievers (should return lists)
|
||||
for name, context in [
|
||||
|
|
@ -190,14 +156,25 @@ async def test_retriever_contexts(setup_search_db_environment):
|
|||
f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
try:
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retriever_triplets(setup_search_db_environment):
|
||||
async def test_retriever_triplets():
|
||||
"""Test that retrievers return valid triplets with proper vector distances."""
|
||||
triplets_gk = setup_search_db_environment["triplets_gk"]
|
||||
triplets_gk_cot = setup_search_db_environment["triplets_gk_cot"]
|
||||
triplets_gk_ext = setup_search_db_environment["triplets_gk_ext"]
|
||||
triplets_gk_sum = setup_search_db_environment["triplets_gk_sum"]
|
||||
await setup_test_environment()
|
||||
|
||||
query = "Next to which country is Germany located?"
|
||||
|
||||
triplets_gk = await GraphCompletionRetriever().get_triplets(query=query)
|
||||
triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets(query=query)
|
||||
triplets_gk_ext = await GraphCompletionContextExtensionRetriever().get_triplets(query=query)
|
||||
triplets_gk_sum = await GraphSummaryCompletionRetriever().get_triplets(query=query)
|
||||
|
||||
for name, triplets in [
|
||||
("GraphCompletionRetriever", triplets_gk),
|
||||
|
|
@ -225,10 +202,19 @@ async def test_retriever_triplets(setup_search_db_environment):
|
|||
f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen"
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
try:
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_operations(setup_search_db_environment):
|
||||
async def test_search_operations():
|
||||
"""Test different search types and verify results contain expected content."""
|
||||
await setup_test_environment()
|
||||
|
||||
completion_gk = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text="Where is germany located, next to which country?",
|
||||
|
|
@ -290,10 +276,19 @@ async def test_search_operations(setup_search_db_environment):
|
|||
f"{name}: expected 'netherlands' in result, got: {text!r}"
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
try:
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_node_and_edge_counts(setup_search_db_environment):
|
||||
async def test_graph_node_and_edge_counts():
|
||||
"""Test that graph contains expected node and edge counts after search operations."""
|
||||
await setup_test_environment()
|
||||
|
||||
# First perform searches to create interaction and feedback nodes
|
||||
await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
|
|
@ -329,7 +324,7 @@ async def test_graph_node_and_edge_counts(setup_search_db_environment):
|
|||
last_k=1,
|
||||
)
|
||||
|
||||
graph_engine = setup_search_db_environment["graph_engine"]
|
||||
graph_engine = await get_graph_engine()
|
||||
graph = await graph_engine.get_graph_data()
|
||||
|
||||
type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0])
|
||||
|
|
@ -365,10 +360,19 @@ async def test_graph_node_and_edge_counts(setup_search_db_environment):
|
|||
f"Expected at least six 'belongs_to_set' edges, but found {edge_type_counts.get('belongs_to_set', 0)}"
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
try:
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_field_validation(setup_search_db_environment):
|
||||
async def test_node_field_validation():
|
||||
"""Test that user interaction and feedback nodes have required fields with valid values."""
|
||||
await setup_test_environment()
|
||||
|
||||
# First perform searches to create interaction and feedback nodes
|
||||
await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
|
|
@ -404,7 +408,7 @@ async def test_node_field_validation(setup_search_db_environment):
|
|||
last_k=1,
|
||||
)
|
||||
|
||||
graph_engine = setup_search_db_environment["graph_engine"]
|
||||
graph_engine = await get_graph_engine()
|
||||
graph = await graph_engine.get_graph_data()
|
||||
nodes = graph[0]
|
||||
|
||||
|
|
@ -434,10 +438,18 @@ async def test_node_field_validation(setup_search_db_environment):
|
|||
f"Node {node_id} has invalid value for '{field}': {value!r}"
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
try:
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_feedback_weight_calculation(setup_search_db_environment_for_feedback):
|
||||
async def test_feedback_weight_calculation():
|
||||
"""Test that feedback weight is correctly calculated after multiple positive feedbacks."""
|
||||
await setup_test_environment_for_feedback()
|
||||
|
||||
await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
|
|
@ -466,3 +478,10 @@ async def test_feedback_weight_calculation(setup_search_db_environment_for_feedb
|
|||
assert properties["feedback_weight"] >= 6, (
|
||||
"Feedback weight calculation is not correct, it should be more then 6."
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
try:
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue