chore: retriever test reorganization + adding new tests (smoke e2e) (STEP 1.5) (#1888)
<!-- .github/pull_request_template.md -->
This PR restructures the end-to-end tests for the multi-database search
layer to improve maintainability, consistency, and coverage across
supported Python versions and database settings.

## Key Changes

- Migrates the existing E2E tests to pytest for a more standard and
  extensible testing framework.
- Introduces pytest fixtures to centralize and reuse test setup logic.
- Implements proper event loop management to support multiple
  asynchronous pytest tests reliably (see the sketch after the Benefits
  list below).
- Improves SQLAlchemy handling in tests, ensuring clean setup and
  teardown of database state.
- Extends multi-database E2E test coverage across all supported Python
  versions.
## Benefits

- Cleaner and more modular test structure.
- Reduced duplication and clearer test intent through fixtures.
- More reliable async test execution.
- Better alignment with our supported Python version matrix.
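For reviewers unfamiliar with the pattern, here is a minimal, self-contained sketch of the session-scoped fixture and event-loop setup the suite now uses (pytest and pytest-asyncio assumed; `expensive_setup` is a hypothetical stand-in for the real prune/add/cognify helper; newer pytest-asyncio versions would use `loop_scope` instead of overriding `event_loop`):

```python
import asyncio

import pytest
import pytest_asyncio


@pytest.fixture(scope="session")
def event_loop():
    # One loop for the whole session avoids "Future attached to a
    # different loop" errors when async tests share clients/engines.
    loop = asyncio.new_event_loop()
    try:
        yield loop
    finally:
        loop.close()


async def expensive_setup() -> dict:
    # Hypothetical stand-in for the real prune/add/cognify pipeline.
    await asyncio.sleep(0)
    return {"ready": True}


@pytest_asyncio.fixture(scope="session")
async def e2e_state():
    # Computed once per session; assertion-only tests just consume it.
    return await expensive_setup()


@pytest.mark.asyncio
async def test_uses_shared_state(e2e_state):
    assert e2e_state["ready"]
```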
## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **Tests**
* Expanded end-to-end test suite for the search database with
comprehensive setup/teardown, new session-scoped fixtures, and multiple
tests validating graph/vector consistency, retriever contexts, triplet
metadata, search result shapes, side effects, and feedback-weight
behavior.
* **Chores**
* CI updated to run matrixed test jobs across multiple Python versions
and standardize test execution for more consistent, parallelized runs.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Parent: 4e8845c117
Commit: b4aaa7faef

2 changed files with 374 additions and 201 deletions
**`.github/workflows/search_db_tests.yml`** (vendored, 46 lines changed)

```diff
@@ -11,12 +11,21 @@ on:
         type: string
         default: "all"
         description: "Which vector databases to test (comma-separated list or 'all')"
+      python-versions:
+        required: false
+        type: string
+        default: '["3.10", "3.11", "3.12", "3.13"]'
+        description: "Python versions to test (JSON array)"
 
 jobs:
   run-kuzu-lance-sqlite-search-tests:
-    name: Search test for Kuzu/LanceDB/Sqlite
+    name: Search test for Kuzu/LanceDB/Sqlite (Python ${{ matrix.python-version }})
     runs-on: ubuntu-22.04
     if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'kuzu/lance/sqlite') }}
+    strategy:
+      matrix:
+        python-version: ${{ fromJSON(inputs.python-versions) }}
+      fail-fast: false
     steps:
       - name: Check out
         uses: actions/checkout@v4
@@ -26,7 +35,7 @@ jobs:
       - name: Cognee Setup
         uses: ./.github/actions/cognee_setup
         with:
-          python-version: ${{ inputs.python-version }}
+          python-version: ${{ matrix.python-version }}
 
       - name: Dependencies already installed
         run: echo "Dependencies already installed in setup"
@@ -45,13 +54,16 @@ jobs:
           GRAPH_DATABASE_PROVIDER: 'kuzu'
           VECTOR_DB_PROVIDER: 'lancedb'
           DB_PROVIDER: 'sqlite'
-        run: uv run python ./cognee/tests/test_search_db.py
+        run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO
 
   run-neo4j-lance-sqlite-search-tests:
-    name: Search test for Neo4j/LanceDB/Sqlite
+    name: Search test for Neo4j/LanceDB/Sqlite (Python ${{ matrix.python-version }})
     runs-on: ubuntu-22.04
     if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }}
+    strategy:
+      matrix:
+        python-version: ${{ fromJSON(inputs.python-versions) }}
+      fail-fast: false
     steps:
       - name: Check out
         uses: actions/checkout@v4
@@ -61,7 +73,7 @@ jobs:
       - name: Cognee Setup
         uses: ./.github/actions/cognee_setup
         with:
-          python-version: ${{ inputs.python-version }}
+          python-version: ${{ matrix.python-version }}
 
       - name: Setup Neo4j with GDS
         uses: ./.github/actions/setup_neo4j
@@ -88,12 +100,16 @@ jobs:
           GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
           GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
           GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
-        run: uv run python ./cognee/tests/test_search_db.py
+        run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO
 
   run-kuzu-pgvector-postgres-search-tests:
-    name: Search test for Kuzu/PGVector/Postgres
+    name: Search test for Kuzu/PGVector/Postgres (Python ${{ matrix.python-version }})
     runs-on: ubuntu-22.04
     if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'kuzu/pgvector/postgres') }}
+    strategy:
+      matrix:
+        python-version: ${{ fromJSON(inputs.python-versions) }}
+      fail-fast: false
     services:
       postgres:
         image: pgvector/pgvector:pg17
@@ -117,7 +133,7 @@ jobs:
       - name: Cognee Setup
         uses: ./.github/actions/cognee_setup
         with:
-          python-version: ${{ inputs.python-version }}
+          python-version: ${{ matrix.python-version }}
           extra-dependencies: "postgres"
 
       - name: Dependencies already installed
@@ -143,12 +159,16 @@ jobs:
           DB_PORT: 5432
           DB_USERNAME: cognee
           DB_PASSWORD: cognee
-        run: uv run python ./cognee/tests/test_search_db.py
+        run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO
 
   run-neo4j-pgvector-postgres-search-tests:
-    name: Search test for Neo4j/PGVector/Postgres
+    name: Search test for Neo4j/PGVector/Postgres (Python ${{ matrix.python-version }})
     runs-on: ubuntu-22.04
     if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/pgvector/postgres') }}
+    strategy:
+      matrix:
+        python-version: ${{ fromJSON(inputs.python-versions) }}
+      fail-fast: false
     services:
       postgres:
         image: pgvector/pgvector:pg17
@@ -172,7 +192,7 @@ jobs:
       - name: Cognee Setup
         uses: ./.github/actions/cognee_setup
         with:
-          python-version: ${{ inputs.python-version }}
+          python-version: ${{ matrix.python-version }}
           extra-dependencies: "postgres"
 
       - name: Setup Neo4j with GDS
@@ -205,4 +225,4 @@ jobs:
           DB_PORT: 5432
           DB_USERNAME: cognee
           DB_PASSWORD: cognee
-        run: uv run python ./cognee/tests/test_search_db.py
+        run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO
```
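A quick note on the matrix mechanics: `fromJSON(inputs.python-versions)` parses the JSON-encoded string input into the list each job's matrix fans out over. A rough Python illustration of the resulting fan-out (hypothetical, for intuition only):

```python
import json
from itertools import product

# The workflow input arrives as a JSON-encoded string (see the default above).
python_versions = json.loads('["3.10", "3.11", "3.12", "3.13"]')

# Four database jobs, each matrixed over every Python version:
db_jobs = [
    "kuzu/lance/sqlite",
    "neo4j/lance/sqlite",
    "kuzu/pgvector/postgres",
    "neo4j/pgvector/postgres",
]
for job, version in product(db_jobs, python_versions):
    print(f"{job} on Python {version}")  # 16 runs in total
```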
**`cognee/tests/test_search_db.py`**

```diff
@@ -1,5 +1,10 @@
 import pathlib
 import os
+import asyncio
+import pytest
+import pytest_asyncio
+from collections import Counter
 
 import cognee
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.vector import get_vector_engine
@@ -13,127 +18,172 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
 from cognee.modules.retrieval.graph_summary_completion_retriever import (
     GraphSummaryCompletionRetriever,
 )
+from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
+from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
+from cognee.modules.retrieval.completion_retriever import CompletionRetriever
+from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
 from cognee.modules.retrieval.triplet_retriever import TripletRetriever
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.search.types import SearchType
 from cognee.modules.users.methods import get_default_user
-from collections import Counter
 
 logger = get_logger()
 
 
-async def main():
-    # This test runs for multiple db settings, to run this locally set the corresponding db envs
+async def _reset_engines_and_prune() -> None:
+    """Reset db engine caches and prune data/system.
+
+    Kept intentionally identical to the inlined setup logic to avoid event loop issues when
+    using deployed databases (Neo4j, PostgreSQL) and to ensure fresh instances per run.
+    """
+    # Dispose of existing engines and clear caches to ensure fresh instances for each test
+    try:
+        from cognee.infrastructure.databases.vector import get_vector_engine
+
+        vector_engine = get_vector_engine()
+        # Dispose SQLAlchemy engine connection pool if it exists
+        if hasattr(vector_engine, "engine") and hasattr(vector_engine.engine, "dispose"):
+            await vector_engine.engine.dispose(close=True)
+    except Exception:
+        # Engine might not exist yet
+        pass
+
+    from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
+    from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
+    from cognee.infrastructure.databases.relational.create_relational_engine import (
+        create_relational_engine,
+    )
+
+    create_graph_engine.cache_clear()
+    create_vector_engine.cache_clear()
+    create_relational_engine.cache_clear()
+
     await cognee.prune.prune_data()
     await cognee.prune.prune_system(metadata=True)
 
-    dataset_name = "test_dataset"
+
+async def _seed_default_dataset(dataset_name: str) -> dict:
+    """Add the shared test dataset contents and run cognify (same steps/order as before)."""
     text_1 = """Germany is located in europe right next to the Netherlands"""
+
+    logger.info(f"Adding text data to dataset: {dataset_name}")
     await cognee.add(text_1, dataset_name)
 
     explanation_file_path_quantum = os.path.join(
         pathlib.Path(__file__).parent, "test_data/Quantum_computers.txt"
     )
+
+    logger.info(f"Adding file data to dataset: {dataset_name}")
     await cognee.add([explanation_file_path_quantum], dataset_name)
+
+    logger.info(f"Running cognify on dataset: {dataset_name}")
     await cognee.cognify([dataset_name])
 
+    return {
+        "dataset_name": dataset_name,
+        "text_1": text_1,
+        "explanation_file_path_quantum": explanation_file_path_quantum,
+    }
+
+
+@pytest.fixture(scope="session")
+def event_loop():
+    """Use a single asyncio event loop for this test module.
+
+    This helps avoid "Future attached to a different loop" when running multiple async
+    tests that share clients/engines.
+    """
+    loop = asyncio.new_event_loop()
+    try:
+        yield loop
+    finally:
+        loop.close()
+
+
+async def setup_test_environment():
+    """Helper function to set up test environment with data, cognify, and triplet embeddings."""
+    # This test runs for multiple db settings, to run this locally set the corresponding db envs
+
+    dataset_name = "test_dataset"
+    logger.info("Starting test setup: pruning data and system")
+    await _reset_engines_and_prune()
+    state = await _seed_default_dataset(dataset_name=dataset_name)
+
     user = await get_default_user()
     from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings
+
+    logger.info("Creating triplet embeddings")
     await create_triplet_embeddings(user=user, dataset=dataset_name, triplets_batch_size=5)
+
+    # Check if Triplet_text collection was created
+    vector_engine = get_vector_engine()
+    has_collection = await vector_engine.has_collection(collection_name="Triplet_text")
+    logger.info(f"Triplet_text collection exists after creation: {has_collection}")
+
+    if has_collection:
+        collection = await vector_engine.get_collection("Triplet_text")
+        count = await collection.count_rows() if hasattr(collection, "count_rows") else "unknown"
+        logger.info(f"Triplet_text collection row count: {count}")
+
+    return state
+
+
+async def setup_test_environment_for_feedback():
+    """Helper function to set up test environment for feedback weight calculation test."""
+    dataset_name = "test_dataset"
+    await _reset_engines_and_prune()
+    return await _seed_default_dataset(dataset_name=dataset_name)
+
+
+@pytest_asyncio.fixture(scope="session")
+async def e2e_state():
+    """Compute E2E artifacts once; tests only assert.
+
+    This avoids repeating expensive setup and LLM calls across multiple tests.
+    """
+    await setup_test_environment()
+
+    # --- Graph/vector engine consistency ---
     graph_engine = await get_graph_engine()
-    nodes, edges = await graph_engine.get_graph_data()
+    _nodes, edges = await graph_engine.get_graph_data()
 
     vector_engine = get_vector_engine()
     collection = await vector_engine.search(
-        query_text="Test", limit=None, collection_name="Triplet_text"
+        collection_name="Triplet_text", query_text="Test", limit=None
     )
 
-    assert len(edges) == len(collection), (
-        f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection"
-    )
-
-    context_gk = await GraphCompletionRetriever().get_context(
-        query="Next to which country is Germany located?"
-    )
-    context_gk_cot = await GraphCompletionCotRetriever().get_context(
-        query="Next to which country is Germany located?"
-    )
-    context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context(
-        query="Next to which country is Germany located?"
-    )
-    context_gk_sum = await GraphSummaryCompletionRetriever().get_context(
-        query="Next to which country is Germany located?"
-    )
-    context_triplet = await TripletRetriever().get_context(
-        query="Next to which country is Germany located?"
-    )
-
-    for name, context in [
-        ("GraphCompletionRetriever", context_gk),
-        ("GraphCompletionCotRetriever", context_gk_cot),
-        ("GraphCompletionContextExtensionRetriever", context_gk_ext),
-        ("GraphSummaryCompletionRetriever", context_gk_sum),
-    ]:
-        assert isinstance(context, list), f"{name}: Context should be a list"
-        assert len(context) > 0, f"{name}: Context should not be empty"
-
-        context_text = await resolve_edges_to_text(context)
-        lower = context_text.lower()
-        assert "germany" in lower or "netherlands" in lower, (
-            f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
-        )
-
-    assert isinstance(context_triplet, str), "TripletRetriever: Context should be a string"
-    assert len(context_triplet) > 0, "TripletRetriever: Context should not be empty"
-    lower_triplet = context_triplet.lower()
-    assert "germany" in lower_triplet or "netherlands" in lower_triplet, (
-        f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
-    )
-
-    triplets_gk = await GraphCompletionRetriever().get_triplets(
-        query="Next to which country is Germany located?"
-    )
-    triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets(
-        query="Next to which country is Germany located?"
-    )
-    triplets_gk_ext = await GraphCompletionContextExtensionRetriever().get_triplets(
-        query="Next to which country is Germany located?"
-    )
-    triplets_gk_sum = await GraphSummaryCompletionRetriever().get_triplets(
-        query="Next to which country is Germany located?"
-    )
-
-    for name, triplets in [
-        ("GraphCompletionRetriever", triplets_gk),
-        ("GraphCompletionCotRetriever", triplets_gk_cot),
-        ("GraphCompletionContextExtensionRetriever", triplets_gk_ext),
-        ("GraphSummaryCompletionRetriever", triplets_gk_sum),
-    ]:
-        assert isinstance(triplets, list), f"{name}: Triplets should be a list"
-        assert triplets, f"{name}: Triplets list should not be empty"
-        for edge in triplets:
-            assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
-            distance = edge.attributes.get("vector_distance")
-            node1_distance = edge.node1.attributes.get("vector_distance")
-            node2_distance = edge.node2.attributes.get("vector_distance")
-            assert isinstance(distance, float), (
-                f"{name}: vector_distance should be float, got {type(distance)}"
-            )
-            assert 0 <= distance <= 1, (
-                f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen"
-            )
-            assert 0 <= node1_distance <= 1, (
-                f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen"
-            )
-            assert 0 <= node2_distance <= 1, (
-                f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen"
-            )
-
+    # --- Retriever contexts ---
+    query = "Next to which country is Germany located?"
+
+    contexts = {
+        "graph_completion": await GraphCompletionRetriever().get_context(query=query),
+        "graph_completion_cot": await GraphCompletionCotRetriever().get_context(query=query),
+        "graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_context(
+            query=query
+        ),
+        "graph_summary_completion": await GraphSummaryCompletionRetriever().get_context(
+            query=query
+        ),
+        "chunks": await ChunksRetriever(top_k=5).get_context(query=query),
+        "summaries": await SummariesRetriever(top_k=5).get_context(query=query),
+        "rag_completion": await CompletionRetriever(top_k=3).get_context(query=query),
+        "temporal": await TemporalRetriever(top_k=5).get_context(query=query),
+        "triplet": await TripletRetriever().get_context(query=query),
+    }
+
+    # --- Retriever triplets + vector distance validation ---
+    triplets = {
+        "graph_completion": await GraphCompletionRetriever().get_triplets(query=query),
+        "graph_completion_cot": await GraphCompletionCotRetriever().get_triplets(query=query),
+        "graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_triplets(
+            query=query
+        ),
+        "graph_summary_completion": await GraphSummaryCompletionRetriever().get_triplets(
+            query=query
+        ),
+    }
+
+    # --- Search operations + graph side effects ---
     completion_gk = await cognee.search(
         query_type=SearchType.GRAPH_COMPLETION,
         query_text="Where is germany located, next to which country?",
@@ -164,6 +214,26 @@ async def main():
         query_text="Next to which country is Germany located?",
         save_interaction=True,
     )
+    completion_chunks = await cognee.search(
+        query_type=SearchType.CHUNKS,
+        query_text="Germany",
+        save_interaction=False,
+    )
+    completion_summaries = await cognee.search(
+        query_type=SearchType.SUMMARIES,
+        query_text="Germany",
+        save_interaction=False,
+    )
+    completion_rag = await cognee.search(
+        query_type=SearchType.RAG_COMPLETION,
+        query_text="Next to which country is Germany located?",
+        save_interaction=False,
+    )
+    completion_temporal = await cognee.search(
+        query_type=SearchType.TEMPORAL,
+        query_text="Next to which country is Germany located?",
+        save_interaction=False,
+    )
 
     await cognee.search(
         query_type=SearchType.FEEDBACK,
@@ -171,134 +241,217 @@ async def main():
         last_k=1,
     )
 
-    for name, search_results in [
-        ("GRAPH_COMPLETION", completion_gk),
-        ("GRAPH_COMPLETION_COT", completion_cot),
-        ("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
-        ("GRAPH_SUMMARY_COMPLETION", completion_sum),
-        ("TRIPLET_COMPLETION", completion_triplet),
-    ]:
-        assert isinstance(search_results, list), f"{name}: should return a list"
-        assert len(search_results) == 1, (
-            f"{name}: expected single-element list, got {len(search_results)}"
-        )
-
-        from cognee.context_global_variables import backend_access_control_enabled
-
-        if backend_access_control_enabled():
-            text = search_results[0]["search_result"][0]
-        else:
-            text = search_results[0]
-        assert isinstance(text, str), f"{name}: element should be a string"
-        assert text.strip(), f"{name}: string should not be empty"
-        assert "netherlands" in text.lower(), (
-            f"{name}: expected 'netherlands' in result, got: {text!r}"
-        )
-
-    graph_engine = await get_graph_engine()
-    graph = await graph_engine.get_graph_data()
-
-    type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0])
-    edge_type_counts = Counter(edge_type[2] for edge_type in graph[1])
-
-    # Assert there are exactly 4 CogneeUserInteraction nodes.
-    assert type_counts.get("CogneeUserInteraction", 0) == 4, (
-        f"Expected exactly four CogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}"
-    )
-
-    # Assert there is exactly two CogneeUserFeedback nodes.
-    assert type_counts.get("CogneeUserFeedback", 0) == 2, (
-        f"Expected exactly two CogneeUserFeedback nodes, but found {type_counts.get('CogneeUserFeedback', 0)}"
-    )
-
-    # Assert there is exactly two NodeSet.
-    assert type_counts.get("NodeSet", 0) == 2, (
-        f"Expected exactly two NodeSet nodes, but found {type_counts.get('NodeSet', 0)}"
-    )
-
-    # Assert that there are at least 10 'used_graph_element_to_answer' edges.
-    assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10, (
-        f"Expected at least ten 'used_graph_element_to_answer' edges, but found {edge_type_counts.get('used_graph_element_to_answer', 0)}"
-    )
-
-    # Assert that there are exactly 2 'gives_feedback_to' edges.
-    assert edge_type_counts.get("gives_feedback_to", 0) == 2, (
-        f"Expected exactly two 'gives_feedback_to' edges, but found {edge_type_counts.get('gives_feedback_to', 0)}"
-    )
-
-    # Assert that there are at least 6 'belongs_to_set' edges.
-    assert edge_type_counts.get("belongs_to_set", 0) == 6, (
-        f"Expected at least six 'belongs_to_set' edges, but found {edge_type_counts.get('belongs_to_set', 0)}"
-    )
-
-    nodes = graph[0]
-
-    required_fields_user_interaction = {"question", "answer", "context"}
-    required_fields_feedback = {"feedback", "sentiment"}
-
-    for node_id, data in nodes:
-        if data.get("type") == "CogneeUserInteraction":
-            assert required_fields_user_interaction.issubset(data.keys()), (
-                f"Node {node_id} is missing fields: {required_fields_user_interaction - set(data.keys())}"
-            )
-
-            for field in required_fields_user_interaction:
-                value = data[field]
-                assert isinstance(value, str) and value.strip(), (
-                    f"Node {node_id} has invalid value for '{field}': {value!r}"
-                )
-
-        if data.get("type") == "CogneeUserFeedback":
-            assert required_fields_feedback.issubset(data.keys()), (
-                f"Node {node_id} is missing fields: {required_fields_feedback - set(data.keys())}"
-            )
-
-            for field in required_fields_feedback:
-                value = data[field]
-                assert isinstance(value, str) and value.strip(), (
-                    f"Node {node_id} has invalid value for '{field}': {value!r}"
-                )
-
-    await cognee.prune.prune_data()
-    await cognee.prune.prune_system(metadata=True)
-
-    await cognee.add(text_1, dataset_name)
-    await cognee.add([text], dataset_name)
-    await cognee.cognify([dataset_name])
+    # Snapshot after all E2E operations above (used by assertion-only tests).
+    graph_snapshot = await (await get_graph_engine()).get_graph_data()
+
+    return {
+        "graph_edges": edges,
+        "triplet_collection": collection,
+        "vector_collection_edges_count": len(collection),
+        "graph_edges_count": len(edges),
+        "contexts": contexts,
+        "triplets": triplets,
+        "search_results": {
+            "graph_completion": completion_gk,
+            "graph_completion_cot": completion_cot,
+            "graph_completion_context_extension": completion_ext,
+            "graph_summary_completion": completion_sum,
+            "triplet_completion": completion_triplet,
+            "chunks": completion_chunks,
+            "summaries": completion_summaries,
+            "rag_completion": completion_rag,
+            "temporal": completion_temporal,
+        },
+        "graph_snapshot": graph_snapshot,
+    }
+
+
+@pytest_asyncio.fixture(scope="session")
+async def feedback_state():
+    """Feedback-weight scenario computed once (fresh environment)."""
+    await setup_test_environment_for_feedback()
 
     await cognee.search(
         query_type=SearchType.GRAPH_COMPLETION,
         query_text="Next to which country is Germany located?",
        save_interaction=True,
     )
 
     await cognee.search(
         query_type=SearchType.FEEDBACK,
         query_text="This was the best answer I've ever seen",
         last_k=1,
     )
 
     await cognee.search(
         query_type=SearchType.FEEDBACK,
         query_text="Wow the correctness of this answer blows my mind",
         last_k=1,
     )
 
+    graph_engine = await get_graph_engine()
     graph = await graph_engine.get_graph_data()
+    return {"graph_snapshot": graph}
 
-    edges = graph[1]
-
-    for from_node, to_node, relationship_name, properties in edges:
+
+@pytest.mark.asyncio
+async def test_e2e_graph_vector_consistency(e2e_state):
+    """Graph and vector stores contain the same triplet edges."""
+    assert e2e_state["graph_edges_count"] == e2e_state["vector_collection_edges_count"]
+
+
+@pytest.mark.asyncio
+async def test_e2e_retriever_contexts(e2e_state):
+    """All retrievers return non-empty, well-typed contexts."""
+    contexts = e2e_state["contexts"]
+
+    for name in [
+        "graph_completion",
+        "graph_completion_cot",
+        "graph_completion_context_extension",
+        "graph_summary_completion",
+    ]:
+        ctx = contexts[name]
+        assert isinstance(ctx, list), f"{name}: Context should be a list"
+        assert ctx, f"{name}: Context should not be empty"
+        ctx_text = await resolve_edges_to_text(ctx)
+        lower = ctx_text.lower()
+        assert "germany" in lower or "netherlands" in lower, (
+            f"{name}: Context did not contain 'germany' or 'netherlands'; got: {ctx!r}"
+        )
+
+    triplet_ctx = contexts["triplet"]
+    assert isinstance(triplet_ctx, str), "triplet: Context should be a string"
+    assert triplet_ctx.strip(), "triplet: Context should not be empty"
+
+    chunks_ctx = contexts["chunks"]
+    assert isinstance(chunks_ctx, list), "chunks: Context should be a list"
+    assert chunks_ctx, "chunks: Context should not be empty"
+    chunks_text = "\n".join(str(item.get("text", "")) for item in chunks_ctx).lower()
+    assert "germany" in chunks_text or "netherlands" in chunks_text
+
+    summaries_ctx = contexts["summaries"]
+    assert isinstance(summaries_ctx, list), "summaries: Context should be a list"
+    assert summaries_ctx, "summaries: Context should not be empty"
+    assert any(str(item.get("text", "")).strip() for item in summaries_ctx)
+
+    rag_ctx = contexts["rag_completion"]
+    assert isinstance(rag_ctx, str), "rag_completion: Context should be a string"
+    assert rag_ctx.strip(), "rag_completion: Context should not be empty"
+
+    temporal_ctx = contexts["temporal"]
+    assert isinstance(temporal_ctx, str), "temporal: Context should be a string"
+    assert temporal_ctx.strip(), "temporal: Context should not be empty"
+
+
+@pytest.mark.asyncio
+async def test_e2e_retriever_triplets_have_vector_distances(e2e_state):
+    """Graph retriever triplets include sane vector_distance metadata."""
+    for name, triplets in e2e_state["triplets"].items():
+        assert isinstance(triplets, list), f"{name}: Triplets should be a list"
+        assert triplets, f"{name}: Triplets list should not be empty"
+        for edge in triplets:
+            assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
+            distance = edge.attributes.get("vector_distance")
+            node1_distance = edge.node1.attributes.get("vector_distance")
+            node2_distance = edge.node2.attributes.get("vector_distance")
+            assert isinstance(distance, float), f"{name}: vector_distance should be float"
+            assert 0 <= distance <= 1
+            assert 0 <= node1_distance <= 1
+            assert 0 <= node2_distance <= 1
+
+
+@pytest.mark.asyncio
+async def test_e2e_search_results_and_wrappers(e2e_state):
+    """Search returns expected shapes across search types and access modes."""
+    from cognee.context_global_variables import backend_access_control_enabled
+
+    sr = e2e_state["search_results"]
+
+    # Completion-like search types: validate wrapper + content
+    for name in [
+        "graph_completion",
+        "graph_completion_cot",
+        "graph_completion_context_extension",
+        "graph_summary_completion",
+        "triplet_completion",
+        "rag_completion",
+        "temporal",
+    ]:
+        search_results = sr[name]
+        assert isinstance(search_results, list), f"{name}: should return a list"
+        assert len(search_results) == 1, f"{name}: expected single-element list"
+
+        if backend_access_control_enabled():
+            wrapper = search_results[0]
+            assert isinstance(wrapper, dict), (
+                f"{name}: expected wrapper dict in access control mode"
+            )
+            assert wrapper.get("dataset_id"), f"{name}: missing dataset_id in wrapper"
+            assert wrapper.get("dataset_name") == "test_dataset"
+            assert "graphs" in wrapper
+            text = wrapper["search_result"][0]
+        else:
+            text = search_results[0]
+
+        assert isinstance(text, str) and text.strip()
+        assert "netherlands" in text.lower()
+
+    # Non-LLM search types: CHUNKS / SUMMARIES validate payload list + text
+    for name in ["chunks", "summaries"]:
+        search_results = sr[name]
+        assert isinstance(search_results, list), f"{name}: should return a list"
+        assert search_results, f"{name}: should not be empty"
+
+        first = search_results[0]
+        assert isinstance(first, dict), f"{name}: expected dict entries"
+
+        payloads = search_results
+        if "search_result" in first and "text" not in first:
+            payloads = (first.get("search_result") or [None])[0]
+
+        assert isinstance(payloads, list) and payloads
+        assert isinstance(payloads[0], dict)
+        assert str(payloads[0].get("text", "")).strip()
+
+
+@pytest.mark.asyncio
+async def test_e2e_graph_side_effects_and_node_fields(e2e_state):
+    """Search interactions create expected graph nodes/edges and required fields."""
+    graph = e2e_state["graph_snapshot"]
+    nodes, edges = graph
+
+    type_counts = Counter(node_data[1].get("type", {}) for node_data in nodes)
+    edge_type_counts = Counter(edge_type[2] for edge_type in edges)
+
+    assert type_counts.get("CogneeUserInteraction", 0) == 4
+    assert type_counts.get("CogneeUserFeedback", 0) == 2
+    assert type_counts.get("NodeSet", 0) == 2
+    assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10
+    assert edge_type_counts.get("gives_feedback_to", 0) == 2
+    assert edge_type_counts.get("belongs_to_set", 0) >= 6
+
+    required_fields_user_interaction = {"question", "answer", "context"}
+    required_fields_feedback = {"feedback", "sentiment"}
+
+    for node_id, data in nodes:
+        if data.get("type") == "CogneeUserInteraction":
+            assert required_fields_user_interaction.issubset(data.keys())
+            for field in required_fields_user_interaction:
+                value = data[field]
+                assert isinstance(value, str) and value.strip()
+
+        if data.get("type") == "CogneeUserFeedback":
+            assert required_fields_feedback.issubset(data.keys())
+            for field in required_fields_feedback:
+                value = data[field]
+                assert isinstance(value, str) and value.strip()
+
+
+@pytest.mark.asyncio
+async def test_e2e_feedback_weight_calculation(feedback_state):
+    """Positive feedback increases used_graph_element_to_answer feedback_weight."""
+    _nodes, edges = feedback_state["graph_snapshot"]
+    for _from_node, _to_node, relationship_name, properties in edges:
         if relationship_name == "used_graph_element_to_answer":
             assert properties["feedback_weight"] >= 6, (
                 "Feedback weight calculation is not correct, it should be more then 6."
             )
-
-
-if __name__ == "__main__":
-    import asyncio
-
-    asyncio.run(main())
```
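To reproduce one matrix cell locally, the test module's comment says to set the corresponding db envs; a rough sketch under that assumption (env names copied from the Kuzu/LanceDB/Sqlite job above, and the pytest invocation matches the updated CI step, minus the `uv run` wrapper):

```python
import os
import subprocess

# Select one database configuration, mirroring the CI job's env block.
os.environ.update(
    {
        "GRAPH_DATABASE_PROVIDER": "kuzu",
        "VECTOR_DB_PROVIDER": "lancedb",
        "DB_PROVIDER": "sqlite",
    }
)

# Run the migrated pytest suite the same way the workflow does.
subprocess.run(
    ["pytest", "cognee/tests/test_search_db.py", "-v", "--log-level=INFO"],
    check=True,
)
```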