From 7e0c9f0c91e6249e6da1c2c2c056fd3463f28952 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 11 Dec 2025 19:01:37 +0100 Subject: [PATCH] removes fixtures --- cognee/tests/test_search_db.py | 165 ++++++++++++++++++--------------- 1 file changed, 92 insertions(+), 73 deletions(-) diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index 285971c56..1c4941b55 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -1,9 +1,7 @@ import pathlib import os import pytest -import pytest_asyncio import cognee -from cognee.context_global_variables import set_session_user_context_variable from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge @@ -25,11 +23,10 @@ from collections import Counter logger = get_logger() -@pytest_asyncio.fixture -async def setup_search_db_environment(): - """Set up test environment with data, cognify, and triplet embeddings.""" +async def setup_test_environment(): + """Helper function to set up test environment with data, cognify, and triplet embeddings.""" # This test runs for multiple db settings, to run this locally set the corresponding db envs - logger.info("Starting fixture setup: pruning data and system") + logger.info("Starting test setup: pruning data and system") await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) @@ -65,52 +62,15 @@ async def setup_search_db_environment(): count = await collection.count_rows() if hasattr(collection, "count_rows") else "unknown" logger.info(f"Triplet_text collection row count: {count}") - graph_engine = await get_graph_engine() - - vector_engine = get_vector_engine() - query = "Next to which country is Germany located?" - gk_retriever_context = await GraphCompletionCotRetriever().get_context(query=query) - gk_cot_retriever_context = await GraphCompletionCotRetriever().get_context(query=query) - gk_ext_retriever_context = await GraphCompletionContextExtensionRetriever().get_context( - query=query - ) - gk_sum_retriever_context = await GraphSummaryCompletionRetriever().get_context(query=query) - triplet_retriever_context = await TripletRetriever().get_context(query=query) - - # Pre-compute triplets for test_retriever_triplets - triplets_gk = await GraphCompletionRetriever().get_triplets(query=query) - triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets(query=query) - triplets_gk_ext = await GraphCompletionContextExtensionRetriever().get_triplets(query=query) - triplets_gk_sum = await GraphSummaryCompletionRetriever().get_triplets(query=query) - - yield { + return { "dataset_name": dataset_name, "text_1": text_1, "explanation_file_path_quantum": explanation_file_path_quantum, - "graph_engine": graph_engine, - "vector_engine": vector_engine, - "gk_retriever_context": gk_retriever_context, - "gk_cot_retriever_context": gk_cot_retriever_context, - "gk_ext_retriever_context": gk_ext_retriever_context, - "gk_sum_retriever_context": gk_sum_retriever_context, - "triplet_retriever_context": triplet_retriever_context, - "triplets_gk": triplets_gk, - "triplets_gk_cot": triplets_gk_cot, - "triplets_gk_ext": triplets_gk_ext, - "triplets_gk_sum": triplets_gk_sum, } - logger.info("Fixture teardown: pruning data and system") - try: - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - except Exception: - pass - -@pytest_asyncio.fixture -async def setup_search_db_environment_for_feedback(): - """Set up test environment for feedback weight calculation test.""" +async def setup_test_environment_for_feedback(): + """Helper function to set up test environment for feedback weight calculation test.""" await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) @@ -127,26 +87,22 @@ async def setup_search_db_environment_for_feedback(): await cognee.cognify([dataset_name]) - yield { + return { "dataset_name": dataset_name, "text_1": text_1, "explanation_file_path_quantum": explanation_file_path_quantum, } - try: - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - except Exception: - pass - @pytest.mark.asyncio -async def test_graph_vector_engine_consistency(setup_search_db_environment): +async def test_graph_vector_engine_consistency(): """Test that graph edges match triplet collection in vector engine.""" - vector_engine = setup_search_db_environment["vector_engine"] - graph_engine = setup_search_db_environment["graph_engine"] + await setup_test_environment() + graph_engine = await get_graph_engine() nodes, edges = await graph_engine.get_graph_data() + + vector_engine = get_vector_engine() collection = await vector_engine.search( collection_name="Triplet_text", query_text="Test", limit=None ) @@ -155,16 +111,26 @@ async def test_graph_vector_engine_consistency(setup_search_db_environment): f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection" ) + # Cleanup + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + @pytest.mark.asyncio -async def test_retriever_contexts(setup_search_db_environment): +async def test_retriever_contexts(): """Test that all retrievers return valid contexts with expected content.""" + await setup_test_environment() - context_gk = setup_search_db_environment["gk_retriever_context"] - context_gk_cot = setup_search_db_environment["gk_cot_retriever_context"] - context_gk_ext = setup_search_db_environment["gk_ext_retriever_context"] - context_gk_sum = setup_search_db_environment["gk_sum_retriever_context"] - context_triplet = setup_search_db_environment["triplet_retriever_context"] + query = "Next to which country is Germany located?" + + context_gk = await GraphCompletionRetriever().get_context(query=query) + context_gk_cot = await GraphCompletionCotRetriever().get_context(query=query) + context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context(query=query) + context_gk_sum = await GraphSummaryCompletionRetriever().get_context(query=query) + context_triplet = await TripletRetriever().get_context(query=query) # Test graph-based retrievers (should return lists) for name, context in [ @@ -190,14 +156,25 @@ async def test_retriever_contexts(setup_search_db_environment): f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}" ) + # Cleanup + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + @pytest.mark.asyncio -async def test_retriever_triplets(setup_search_db_environment): +async def test_retriever_triplets(): """Test that retrievers return valid triplets with proper vector distances.""" - triplets_gk = setup_search_db_environment["triplets_gk"] - triplets_gk_cot = setup_search_db_environment["triplets_gk_cot"] - triplets_gk_ext = setup_search_db_environment["triplets_gk_ext"] - triplets_gk_sum = setup_search_db_environment["triplets_gk_sum"] + await setup_test_environment() + + query = "Next to which country is Germany located?" + + triplets_gk = await GraphCompletionRetriever().get_triplets(query=query) + triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets(query=query) + triplets_gk_ext = await GraphCompletionContextExtensionRetriever().get_triplets(query=query) + triplets_gk_sum = await GraphSummaryCompletionRetriever().get_triplets(query=query) for name, triplets in [ ("GraphCompletionRetriever", triplets_gk), @@ -225,10 +202,19 @@ async def test_retriever_triplets(setup_search_db_environment): f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen" ) + # Cleanup + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + @pytest.mark.asyncio -async def test_search_operations(setup_search_db_environment): +async def test_search_operations(): """Test different search types and verify results contain expected content.""" + await setup_test_environment() + completion_gk = await cognee.search( query_type=SearchType.GRAPH_COMPLETION, query_text="Where is germany located, next to which country?", @@ -290,10 +276,19 @@ async def test_search_operations(setup_search_db_environment): f"{name}: expected 'netherlands' in result, got: {text!r}" ) + # Cleanup + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + @pytest.mark.asyncio -async def test_graph_node_and_edge_counts(setup_search_db_environment): +async def test_graph_node_and_edge_counts(): """Test that graph contains expected node and edge counts after search operations.""" + await setup_test_environment() + # First perform searches to create interaction and feedback nodes await cognee.search( query_type=SearchType.GRAPH_COMPLETION, @@ -329,7 +324,7 @@ async def test_graph_node_and_edge_counts(setup_search_db_environment): last_k=1, ) - graph_engine = setup_search_db_environment["graph_engine"] + graph_engine = await get_graph_engine() graph = await graph_engine.get_graph_data() type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0]) @@ -365,10 +360,19 @@ async def test_graph_node_and_edge_counts(setup_search_db_environment): f"Expected at least six 'belongs_to_set' edges, but found {edge_type_counts.get('belongs_to_set', 0)}" ) + # Cleanup + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + @pytest.mark.asyncio -async def test_node_field_validation(setup_search_db_environment): +async def test_node_field_validation(): """Test that user interaction and feedback nodes have required fields with valid values.""" + await setup_test_environment() + # First perform searches to create interaction and feedback nodes await cognee.search( query_type=SearchType.GRAPH_COMPLETION, @@ -404,7 +408,7 @@ async def test_node_field_validation(setup_search_db_environment): last_k=1, ) - graph_engine = setup_search_db_environment["graph_engine"] + graph_engine = await get_graph_engine() graph = await graph_engine.get_graph_data() nodes = graph[0] @@ -434,10 +438,18 @@ async def test_node_field_validation(setup_search_db_environment): f"Node {node_id} has invalid value for '{field}': {value!r}" ) + # Cleanup + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + @pytest.mark.asyncio -async def test_feedback_weight_calculation(setup_search_db_environment_for_feedback): +async def test_feedback_weight_calculation(): """Test that feedback weight is correctly calculated after multiple positive feedbacks.""" + await setup_test_environment_for_feedback() await cognee.search( query_type=SearchType.GRAPH_COMPLETION, @@ -466,3 +478,10 @@ async def test_feedback_weight_calculation(setup_search_db_environment_for_feedb assert properties["feedback_weight"] >= 6, ( "Feedback weight calculation is not correct, it should be more then 6." ) + + # Cleanup + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass