From b4aaa7faefce804d9ad6fee93d9907b352206f25 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 11:59:33 +0100 Subject: [PATCH] chore: retriever test reorganization + adding new tests (smoke e2e) (STEP 1.5) (#1888) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR restructures the end-to-end tests for the multi-database search layer to improve maintainability, consistency, and coverage across supported Python versions and database settings. Key Changes - Migrates the existing E2E tests to pytest for a more standard and extensible testing framework. - Introduces pytest fixtures to centralize and reuse test setup logic. - Implements proper event loop management to support multiple asynchronous pytest tests reliably. - Improves SQLAlchemy handling in tests, ensuring clean setup and teardown of database state. - Extends multi-database E2E test coverage across all supported Python versions. Benefits - Cleaner and more modular test structure. - Reduced duplication and clearer test intent through fixtures. - More reliable async test execution. - Better alignment with our supported Python version matrix. 
## Type of Change - [ ] Bug fix (non-breaking change that fixes an issue) - [x] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [x] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) ## Pre-submission Checklist - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have added necessary documentation (if applicable) - [x] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [x] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. ## Summary by CodeRabbit * **Tests** * Expanded end-to-end test suite for the search database with comprehensive setup/teardown, new session-scoped fixtures, and multiple tests validating graph/vector consistency, retriever contexts, triplet metadata, search result shapes, side effects, and feedback-weight behavior. * **Chores** * CI updated to run matrixed test jobs across multiple Python versions and standardize test execution for more consistent, parallelized runs. ✏️ Tip: You can customize this high-level summary in your review settings. 
--- .github/workflows/search_db_tests.yml | 46 ++- cognee/tests/test_search_db.py | 529 +++++++++++++++++--------- 2 files changed, 374 insertions(+), 201 deletions(-) diff --git a/.github/workflows/search_db_tests.yml b/.github/workflows/search_db_tests.yml index 118c1c06c..f0c7817cd 100644 --- a/.github/workflows/search_db_tests.yml +++ b/.github/workflows/search_db_tests.yml @@ -11,12 +11,21 @@ on: type: string default: "all" description: "Which vector databases to test (comma-separated list or 'all')" + python-versions: + required: false + type: string + default: '["3.10", "3.11", "3.12", "3.13"]' + description: "Python versions to test (JSON array)" jobs: run-kuzu-lance-sqlite-search-tests: - name: Search test for Kuzu/LanceDB/Sqlite + name: Search test for Kuzu/LanceDB/Sqlite (Python ${{ matrix.python-version }}) runs-on: ubuntu-22.04 if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'kuzu/lance/sqlite') }} + strategy: + matrix: + python-version: ${{ fromJSON(inputs.python-versions) }} + fail-fast: false steps: - name: Check out uses: actions/checkout@v4 @@ -26,7 +35,7 @@ jobs: - name: Cognee Setup uses: ./.github/actions/cognee_setup with: - python-version: ${{ inputs.python-version }} + python-version: ${{ matrix.python-version }} - name: Dependencies already installed run: echo "Dependencies already installed in setup" @@ -45,13 +54,16 @@ jobs: GRAPH_DATABASE_PROVIDER: 'kuzu' VECTOR_DB_PROVIDER: 'lancedb' DB_PROVIDER: 'sqlite' - run: uv run python ./cognee/tests/test_search_db.py + run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO run-neo4j-lance-sqlite-search-tests: - name: Search test for Neo4j/LanceDB/Sqlite + name: Search test for Neo4j/LanceDB/Sqlite (Python ${{ matrix.python-version }}) runs-on: ubuntu-22.04 if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }} - + strategy: + matrix: + python-version: ${{ fromJSON(inputs.python-versions) }} + fail-fast: false steps: - name: Check 
out uses: actions/checkout@v4 @@ -61,7 +73,7 @@ jobs: - name: Cognee Setup uses: ./.github/actions/cognee_setup with: - python-version: ${{ inputs.python-version }} + python-version: ${{ matrix.python-version }} - name: Setup Neo4j with GDS uses: ./.github/actions/setup_neo4j @@ -88,12 +100,16 @@ jobs: GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }} GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }} GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }} - run: uv run python ./cognee/tests/test_search_db.py + run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO run-kuzu-pgvector-postgres-search-tests: - name: Search test for Kuzu/PGVector/Postgres + name: Search test for Kuzu/PGVector/Postgres (Python ${{ matrix.python-version }}) runs-on: ubuntu-22.04 if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'kuzu/pgvector/postgres') }} + strategy: + matrix: + python-version: ${{ fromJSON(inputs.python-versions) }} + fail-fast: false services: postgres: image: pgvector/pgvector:pg17 @@ -117,7 +133,7 @@ jobs: - name: Cognee Setup uses: ./.github/actions/cognee_setup with: - python-version: ${{ inputs.python-version }} + python-version: ${{ matrix.python-version }} extra-dependencies: "postgres" - name: Dependencies already installed @@ -143,12 +159,16 @@ jobs: DB_PORT: 5432 DB_USERNAME: cognee DB_PASSWORD: cognee - run: uv run python ./cognee/tests/test_search_db.py + run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO run-neo4j-pgvector-postgres-search-tests: - name: Search test for Neo4j/PGVector/Postgres + name: Search test for Neo4j/PGVector/Postgres (Python ${{ matrix.python-version }}) runs-on: ubuntu-22.04 if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/pgvector/postgres') }} + strategy: + matrix: + python-version: ${{ fromJSON(inputs.python-versions) }} + fail-fast: false services: postgres: image: pgvector/pgvector:pg17 @@ -172,7 +192,7 @@ jobs: - name: 
Cognee Setup uses: ./.github/actions/cognee_setup with: - python-version: ${{ inputs.python-version }} + python-version: ${{ matrix.python-version }} extra-dependencies: "postgres" - name: Setup Neo4j with GDS @@ -205,4 +225,4 @@ jobs: DB_PORT: 5432 DB_USERNAME: cognee DB_PASSWORD: cognee - run: uv run python ./cognee/tests/test_search_db.py + run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index ba150f813..0916be322 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -1,5 +1,10 @@ import pathlib import os +import asyncio +import pytest +import pytest_asyncio +from collections import Counter + import cognee from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine @@ -13,127 +18,172 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet from cognee.modules.retrieval.graph_summary_completion_retriever import ( GraphSummaryCompletionRetriever, ) +from cognee.modules.retrieval.chunks_retriever import ChunksRetriever +from cognee.modules.retrieval.summaries_retriever import SummariesRetriever +from cognee.modules.retrieval.completion_retriever import CompletionRetriever +from cognee.modules.retrieval.temporal_retriever import TemporalRetriever from cognee.modules.retrieval.triplet_retriever import TripletRetriever from cognee.shared.logging_utils import get_logger from cognee.modules.search.types import SearchType from cognee.modules.users.methods import get_default_user -from collections import Counter logger = get_logger() -async def main(): - # This test runs for multiple db settings, to run this locally set the corresponding db envs +async def _reset_engines_and_prune() -> None: + """Reset db engine caches and prune data/system. 
+ + Kept intentionally identical to the inlined setup logic to avoid event loop issues when + using deployed databases (Neo4j, PostgreSQL) and to ensure fresh instances per run. + """ + # Dispose of existing engines and clear caches to ensure fresh instances for each test + try: + from cognee.infrastructure.databases.vector import get_vector_engine + + vector_engine = get_vector_engine() + # Dispose SQLAlchemy engine connection pool if it exists + if hasattr(vector_engine, "engine") and hasattr(vector_engine.engine, "dispose"): + await vector_engine.engine.dispose(close=True) + except Exception: + # Engine might not exist yet + pass + + from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine + from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine + from cognee.infrastructure.databases.relational.create_relational_engine import ( + create_relational_engine, + ) + + create_graph_engine.cache_clear() + create_vector_engine.cache_clear() + create_relational_engine.cache_clear() + await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) - dataset_name = "test_dataset" +async def _seed_default_dataset(dataset_name: str) -> dict: + """Add the shared test dataset contents and run cognify (same steps/order as before).""" text_1 = """Germany is located in europe right next to the Netherlands""" + + logger.info(f"Adding text data to dataset: {dataset_name}") await cognee.add(text_1, dataset_name) explanation_file_path_quantum = os.path.join( pathlib.Path(__file__).parent, "test_data/Quantum_computers.txt" ) + logger.info(f"Adding file data to dataset: {dataset_name}") await cognee.add([explanation_file_path_quantum], dataset_name) + logger.info(f"Running cognify on dataset: {dataset_name}") await cognee.cognify([dataset_name]) + return { + "dataset_name": dataset_name, + "text_1": text_1, + "explanation_file_path_quantum": explanation_file_path_quantum, + } + + 
+@pytest.fixture(scope="session") +def event_loop(): + """Use a single asyncio event loop for this test module. + + This helps avoid "Future attached to a different loop" when running multiple async + tests that share clients/engines. + """ + loop = asyncio.new_event_loop() + try: + yield loop + finally: + loop.close() + + +async def setup_test_environment(): + """Helper function to set up test environment with data, cognify, and triplet embeddings.""" + # This test runs for multiple db settings, to run this locally set the corresponding db envs + + dataset_name = "test_dataset" + logger.info("Starting test setup: pruning data and system") + await _reset_engines_and_prune() + state = await _seed_default_dataset(dataset_name=dataset_name) + user = await get_default_user() from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings + logger.info("Creating triplet embeddings") await create_triplet_embeddings(user=user, dataset=dataset_name, triplets_batch_size=5) + # Check if Triplet_text collection was created + vector_engine = get_vector_engine() + has_collection = await vector_engine.has_collection(collection_name="Triplet_text") + logger.info(f"Triplet_text collection exists after creation: {has_collection}") + + if has_collection: + collection = await vector_engine.get_collection("Triplet_text") + count = await collection.count_rows() if hasattr(collection, "count_rows") else "unknown" + logger.info(f"Triplet_text collection row count: {count}") + + return state + + +async def setup_test_environment_for_feedback(): + """Helper function to set up test environment for feedback weight calculation test.""" + dataset_name = "test_dataset" + await _reset_engines_and_prune() + return await _seed_default_dataset(dataset_name=dataset_name) + + +@pytest_asyncio.fixture(scope="session") +async def e2e_state(): + """Compute E2E artifacts once; tests only assert. + + This avoids repeating expensive setup and LLM calls across multiple tests. 
+ """ + await setup_test_environment() + + # --- Graph/vector engine consistency --- graph_engine = await get_graph_engine() - nodes, edges = await graph_engine.get_graph_data() + _nodes, edges = await graph_engine.get_graph_data() vector_engine = get_vector_engine() collection = await vector_engine.search( - query_text="Test", limit=None, collection_name="Triplet_text" + collection_name="Triplet_text", query_text="Test", limit=None ) - assert len(edges) == len(collection), ( - f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection" - ) + # --- Retriever contexts --- + query = "Next to which country is Germany located?" - context_gk = await GraphCompletionRetriever().get_context( - query="Next to which country is Germany located?" - ) - context_gk_cot = await GraphCompletionCotRetriever().get_context( - query="Next to which country is Germany located?" - ) - context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context( - query="Next to which country is Germany located?" - ) - context_gk_sum = await GraphSummaryCompletionRetriever().get_context( - query="Next to which country is Germany located?" - ) - context_triplet = await TripletRetriever().get_context( - query="Next to which country is Germany located?" 
- ) + contexts = { + "graph_completion": await GraphCompletionRetriever().get_context(query=query), + "graph_completion_cot": await GraphCompletionCotRetriever().get_context(query=query), + "graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_context( + query=query + ), + "graph_summary_completion": await GraphSummaryCompletionRetriever().get_context( + query=query + ), + "chunks": await ChunksRetriever(top_k=5).get_context(query=query), + "summaries": await SummariesRetriever(top_k=5).get_context(query=query), + "rag_completion": await CompletionRetriever(top_k=3).get_context(query=query), + "temporal": await TemporalRetriever(top_k=5).get_context(query=query), + "triplet": await TripletRetriever().get_context(query=query), + } - for name, context in [ - ("GraphCompletionRetriever", context_gk), - ("GraphCompletionCotRetriever", context_gk_cot), - ("GraphCompletionContextExtensionRetriever", context_gk_ext), - ("GraphSummaryCompletionRetriever", context_gk_sum), - ]: - assert isinstance(context, list), f"{name}: Context should be a list" - assert len(context) > 0, f"{name}: Context should not be empty" - - context_text = await resolve_edges_to_text(context) - lower = context_text.lower() - assert "germany" in lower or "netherlands" in lower, ( - f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}" - ) - - assert isinstance(context_triplet, str), "TripletRetriever: Context should be a string" - assert len(context_triplet) > 0, "TripletRetriever: Context should not be empty" - lower_triplet = context_triplet.lower() - assert "germany" in lower_triplet or "netherlands" in lower_triplet, ( - f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}" - ) - - triplets_gk = await GraphCompletionRetriever().get_triplets( - query="Next to which country is Germany located?" 
- ) - triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets( - query="Next to which country is Germany located?" - ) - triplets_gk_ext = await GraphCompletionContextExtensionRetriever().get_triplets( - query="Next to which country is Germany located?" - ) - triplets_gk_sum = await GraphSummaryCompletionRetriever().get_triplets( - query="Next to which country is Germany located?" - ) - - for name, triplets in [ - ("GraphCompletionRetriever", triplets_gk), - ("GraphCompletionCotRetriever", triplets_gk_cot), - ("GraphCompletionContextExtensionRetriever", triplets_gk_ext), - ("GraphSummaryCompletionRetriever", triplets_gk_sum), - ]: - assert isinstance(triplets, list), f"{name}: Triplets should be a list" - assert triplets, f"{name}: Triplets list should not be empty" - for edge in triplets: - assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances" - distance = edge.attributes.get("vector_distance") - node1_distance = edge.node1.attributes.get("vector_distance") - node2_distance = edge.node2.attributes.get("vector_distance") - assert isinstance(distance, float), ( - f"{name}: vector_distance should be float, got {type(distance)}" - ) - assert 0 <= distance <= 1, ( - f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen" - ) - assert 0 <= node1_distance <= 1, ( - f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen" - ) - assert 0 <= node2_distance <= 1, ( - f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen" - ) + # --- Retriever triplets + vector distance validation --- + triplets = { + "graph_completion": await GraphCompletionRetriever().get_triplets(query=query), + "graph_completion_cot": await GraphCompletionCotRetriever().get_triplets(query=query), + "graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_triplets( + query=query + ), + "graph_summary_completion": await GraphSummaryCompletionRetriever().get_triplets( + 
query=query + ), + } + # --- Search operations + graph side effects --- completion_gk = await cognee.search( query_type=SearchType.GRAPH_COMPLETION, query_text="Where is germany located, next to which country?", @@ -164,6 +214,26 @@ async def main(): query_text="Next to which country is Germany located?", save_interaction=True, ) + completion_chunks = await cognee.search( + query_type=SearchType.CHUNKS, + query_text="Germany", + save_interaction=False, + ) + completion_summaries = await cognee.search( + query_type=SearchType.SUMMARIES, + query_text="Germany", + save_interaction=False, + ) + completion_rag = await cognee.search( + query_type=SearchType.RAG_COMPLETION, + query_text="Next to which country is Germany located?", + save_interaction=False, + ) + completion_temporal = await cognee.search( + query_type=SearchType.TEMPORAL, + query_text="Next to which country is Germany located?", + save_interaction=False, + ) await cognee.search( query_type=SearchType.FEEDBACK, @@ -171,134 +241,217 @@ async def main(): last_k=1, ) - for name, search_results in [ - ("GRAPH_COMPLETION", completion_gk), - ("GRAPH_COMPLETION_COT", completion_cot), - ("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext), - ("GRAPH_SUMMARY_COMPLETION", completion_sum), - ("TRIPLET_COMPLETION", completion_triplet), - ]: - assert isinstance(search_results, list), f"{name}: should return a list" - assert len(search_results) == 1, ( - f"{name}: expected single-element list, got {len(search_results)}" - ) + # Snapshot after all E2E operations above (used by assertion-only tests). 
+ graph_snapshot = await (await get_graph_engine()).get_graph_data() - from cognee.context_global_variables import backend_access_control_enabled + return { + "graph_edges": edges, + "triplet_collection": collection, + "vector_collection_edges_count": len(collection), + "graph_edges_count": len(edges), + "contexts": contexts, + "triplets": triplets, + "search_results": { + "graph_completion": completion_gk, + "graph_completion_cot": completion_cot, + "graph_completion_context_extension": completion_ext, + "graph_summary_completion": completion_sum, + "triplet_completion": completion_triplet, + "chunks": completion_chunks, + "summaries": completion_summaries, + "rag_completion": completion_rag, + "temporal": completion_temporal, + }, + "graph_snapshot": graph_snapshot, + } - if backend_access_control_enabled(): - text = search_results[0]["search_result"][0] - else: - text = search_results[0] - assert isinstance(text, str), f"{name}: element should be a string" - assert text.strip(), f"{name}: string should not be empty" - assert "netherlands" in text.lower(), ( - f"{name}: expected 'netherlands' in result, got: {text!r}" - ) - graph_engine = await get_graph_engine() - graph = await graph_engine.get_graph_data() - - type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0]) - - edge_type_counts = Counter(edge_type[2] for edge_type in graph[1]) - - # Assert there are exactly 4 CogneeUserInteraction nodes. - assert type_counts.get("CogneeUserInteraction", 0) == 4, ( - f"Expected exactly four CogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}" - ) - - # Assert there is exactly two CogneeUserFeedback nodes. - assert type_counts.get("CogneeUserFeedback", 0) == 2, ( - f"Expected exactly two CogneeUserFeedback nodes, but found {type_counts.get('CogneeUserFeedback', 0)}" - ) - - # Assert there is exactly two NodeSet. 
- assert type_counts.get("NodeSet", 0) == 2, ( - f"Expected exactly two NodeSet nodes, but found {type_counts.get('NodeSet', 0)}" - ) - - # Assert that there are at least 10 'used_graph_element_to_answer' edges. - assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10, ( - f"Expected at least ten 'used_graph_element_to_answer' edges, but found {edge_type_counts.get('used_graph_element_to_answer', 0)}" - ) - - # Assert that there are exactly 2 'gives_feedback_to' edges. - assert edge_type_counts.get("gives_feedback_to", 0) == 2, ( - f"Expected exactly two 'gives_feedback_to' edges, but found {edge_type_counts.get('gives_feedback_to', 0)}" - ) - - # Assert that there are at least 6 'belongs_to_set' edges. - assert edge_type_counts.get("belongs_to_set", 0) == 6, ( - f"Expected at least six 'belongs_to_set' edges, but found {edge_type_counts.get('belongs_to_set', 0)}" - ) - - nodes = graph[0] - - required_fields_user_interaction = {"question", "answer", "context"} - required_fields_feedback = {"feedback", "sentiment"} - - for node_id, data in nodes: - if data.get("type") == "CogneeUserInteraction": - assert required_fields_user_interaction.issubset(data.keys()), ( - f"Node {node_id} is missing fields: {required_fields_user_interaction - set(data.keys())}" - ) - - for field in required_fields_user_interaction: - value = data[field] - assert isinstance(value, str) and value.strip(), ( - f"Node {node_id} has invalid value for '{field}': {value!r}" - ) - - if data.get("type") == "CogneeUserFeedback": - assert required_fields_feedback.issubset(data.keys()), ( - f"Node {node_id} is missing fields: {required_fields_feedback - set(data.keys())}" - ) - - for field in required_fields_feedback: - value = data[field] - assert isinstance(value, str) and value.strip(), ( - f"Node {node_id} has invalid value for '{field}': {value!r}" - ) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - await cognee.add(text_1, dataset_name) - - 
await cognee.add([text], dataset_name) - - await cognee.cognify([dataset_name]) +@pytest_asyncio.fixture(scope="session") +async def feedback_state(): + """Feedback-weight scenario computed once (fresh environment).""" + await setup_test_environment_for_feedback() await cognee.search( query_type=SearchType.GRAPH_COMPLETION, query_text="Next to which country is Germany located?", save_interaction=True, ) - await cognee.search( query_type=SearchType.FEEDBACK, query_text="This was the best answer I've ever seen", last_k=1, ) - await cognee.search( query_type=SearchType.FEEDBACK, query_text="Wow the correctness of this answer blows my mind", last_k=1, ) + graph_engine = await get_graph_engine() graph = await graph_engine.get_graph_data() + return {"graph_snapshot": graph} - edges = graph[1] - for from_node, to_node, relationship_name, properties in edges: +@pytest.mark.asyncio +async def test_e2e_graph_vector_consistency(e2e_state): + """Graph and vector stores contain the same triplet edges.""" + assert e2e_state["graph_edges_count"] == e2e_state["vector_collection_edges_count"] + + +@pytest.mark.asyncio +async def test_e2e_retriever_contexts(e2e_state): + """All retrievers return non-empty, well-typed contexts.""" + contexts = e2e_state["contexts"] + + for name in [ + "graph_completion", + "graph_completion_cot", + "graph_completion_context_extension", + "graph_summary_completion", + ]: + ctx = contexts[name] + assert isinstance(ctx, list), f"{name}: Context should be a list" + assert ctx, f"{name}: Context should not be empty" + ctx_text = await resolve_edges_to_text(ctx) + lower = ctx_text.lower() + assert "germany" in lower or "netherlands" in lower, ( + f"{name}: Context did not contain 'germany' or 'netherlands'; got: {ctx!r}" + ) + + triplet_ctx = contexts["triplet"] + assert isinstance(triplet_ctx, str), "triplet: Context should be a string" + assert triplet_ctx.strip(), "triplet: Context should not be empty" + + chunks_ctx = contexts["chunks"] + assert 
isinstance(chunks_ctx, list), "chunks: Context should be a list" + assert chunks_ctx, "chunks: Context should not be empty" + chunks_text = "\n".join(str(item.get("text", "")) for item in chunks_ctx).lower() + assert "germany" in chunks_text or "netherlands" in chunks_text + + summaries_ctx = contexts["summaries"] + assert isinstance(summaries_ctx, list), "summaries: Context should be a list" + assert summaries_ctx, "summaries: Context should not be empty" + assert any(str(item.get("text", "")).strip() for item in summaries_ctx) + + rag_ctx = contexts["rag_completion"] + assert isinstance(rag_ctx, str), "rag_completion: Context should be a string" + assert rag_ctx.strip(), "rag_completion: Context should not be empty" + + temporal_ctx = contexts["temporal"] + assert isinstance(temporal_ctx, str), "temporal: Context should be a string" + assert temporal_ctx.strip(), "temporal: Context should not be empty" + + +@pytest.mark.asyncio +async def test_e2e_retriever_triplets_have_vector_distances(e2e_state): + """Graph retriever triplets include sane vector_distance metadata.""" + for name, triplets in e2e_state["triplets"].items(): + assert isinstance(triplets, list), f"{name}: Triplets should be a list" + assert triplets, f"{name}: Triplets list should not be empty" + for edge in triplets: + assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances" + distance = edge.attributes.get("vector_distance") + node1_distance = edge.node1.attributes.get("vector_distance") + node2_distance = edge.node2.attributes.get("vector_distance") + assert isinstance(distance, float), f"{name}: vector_distance should be float" + assert 0 <= distance <= 1 + assert 0 <= node1_distance <= 1 + assert 0 <= node2_distance <= 1 + + +@pytest.mark.asyncio +async def test_e2e_search_results_and_wrappers(e2e_state): + """Search returns expected shapes across search types and access modes.""" + from cognee.context_global_variables import backend_access_control_enabled + + sr = 
e2e_state["search_results"] + + # Completion-like search types: validate wrapper + content + for name in [ + "graph_completion", + "graph_completion_cot", + "graph_completion_context_extension", + "graph_summary_completion", + "triplet_completion", + "rag_completion", + "temporal", + ]: + search_results = sr[name] + assert isinstance(search_results, list), f"{name}: should return a list" + assert len(search_results) == 1, f"{name}: expected single-element list" + + if backend_access_control_enabled(): + wrapper = search_results[0] + assert isinstance(wrapper, dict), ( + f"{name}: expected wrapper dict in access control mode" + ) + assert wrapper.get("dataset_id"), f"{name}: missing dataset_id in wrapper" + assert wrapper.get("dataset_name") == "test_dataset" + assert "graphs" in wrapper + text = wrapper["search_result"][0] + else: + text = search_results[0] + + assert isinstance(text, str) and text.strip() + assert "netherlands" in text.lower() + + # Non-LLM search types: CHUNKS / SUMMARIES validate payload list + text + for name in ["chunks", "summaries"]: + search_results = sr[name] + assert isinstance(search_results, list), f"{name}: should return a list" + assert search_results, f"{name}: should not be empty" + + first = search_results[0] + assert isinstance(first, dict), f"{name}: expected dict entries" + + payloads = search_results + if "search_result" in first and "text" not in first: + payloads = (first.get("search_result") or [None])[0] + + assert isinstance(payloads, list) and payloads + assert isinstance(payloads[0], dict) + assert str(payloads[0].get("text", "")).strip() + + +@pytest.mark.asyncio +async def test_e2e_graph_side_effects_and_node_fields(e2e_state): + """Search interactions create expected graph nodes/edges and required fields.""" + graph = e2e_state["graph_snapshot"] + nodes, edges = graph + + type_counts = Counter(node_data[1].get("type", {}) for node_data in nodes) + edge_type_counts = Counter(edge_type[2] for edge_type in edges) + + 
assert type_counts.get("CogneeUserInteraction", 0) == 4 + assert type_counts.get("CogneeUserFeedback", 0) == 2 + assert type_counts.get("NodeSet", 0) == 2 + assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10 + assert edge_type_counts.get("gives_feedback_to", 0) == 2 + assert edge_type_counts.get("belongs_to_set", 0) >= 6 + + required_fields_user_interaction = {"question", "answer", "context"} + required_fields_feedback = {"feedback", "sentiment"} + + for node_id, data in nodes: + if data.get("type") == "CogneeUserInteraction": + assert required_fields_user_interaction.issubset(data.keys()) + for field in required_fields_user_interaction: + value = data[field] + assert isinstance(value, str) and value.strip() + + if data.get("type") == "CogneeUserFeedback": + assert required_fields_feedback.issubset(data.keys()) + for field in required_fields_feedback: + value = data[field] + assert isinstance(value, str) and value.strip() + + +@pytest.mark.asyncio +async def test_e2e_feedback_weight_calculation(feedback_state): + """Positive feedback increases used_graph_element_to_answer feedback_weight.""" + _nodes, edges = feedback_state["graph_snapshot"] + for _from_node, _to_node, relationship_name, properties in edges: if relationship_name == "used_graph_element_to_answer": assert properties["feedback_weight"] >= 6, ( "Feedback weight calculation is not correct, it should be more then 6." ) - - -if __name__ == "__main__": - import asyncio - - asyncio.run(main())