Merge branch 'dev' into feature/cog-3532-empower-test_search-db-retrievers-tests-reorg-3

This commit is contained in:
hajdul88 2025-12-16 12:04:11 +01:00
commit 646894d7c5
53 changed files with 3033 additions and 601 deletions

View file

@ -237,6 +237,31 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_dataset_database_handler.py
test-dataset-database-deletion:
name: Test dataset database deletion in Cognee
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Run dataset databases deletion test
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_dataset_delete.py
test-permissions:
name: Test permissions with different situations in Cognee
runs-on: ubuntu-22.04

154
.github/workflows/release.yml vendored Normal file
View file

@ -0,0 +1,154 @@
name: release.yml
on:
workflow_dispatch:
inputs:
flavour:
required: true
default: dev
type: choice
options:
- dev
- main
description: Dev or Main release
test_mode:
required: true
type: boolean
description: Aka Dry Run. If true, it won't affect public indices or repositories
jobs:
release-github:
name: Create GitHub Release from ${{ inputs.flavour }}
outputs:
tag: ${{ steps.create_tag.outputs.tag }}
version: ${{ steps.create_tag.outputs.version }}
permissions:
contents: write
runs-on: ubuntu-latest
steps:
- name: Check out ${{ inputs.flavour }}
uses: actions/checkout@v4
with:
ref: ${{ inputs.flavour }}
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Create and push git tag
id: create_tag
env:
TEST_MODE: ${{ inputs.test_mode }}
run: |
VERSION="$(uv version --short)"
TAG="v${VERSION}"
echo "Tag to create: ${TAG}"
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
echo "tag=${TAG}" >> "$GITHUB_OUTPUT"
echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
if [ "$TEST_MODE" = "false" ]; then
git tag "${TAG}"
git push origin "${TAG}"
else
echo "Test mode is enabled. Skipping tag creation and push."
fi
- name: Create GitHub Release
uses: softprops/action-gh-release@v2
with:
tag_name: ${{ steps.create_tag.outputs.tag }}
generate_release_notes: true
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
release-pypi-package:
needs: release-github
name: Release PyPI Package from ${{ inputs.flavour }}
permissions:
contents: read
runs-on: ubuntu-latest
steps:
- name: Check out ${{ inputs.flavour }}
uses: actions/checkout@v4
with:
ref: ${{ inputs.flavour }}
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Install Python
run: uv python install
- name: Install dependencies
run: uv sync --locked --all-extras
- name: Build distributions
run: uv build
- name: Publish ${{ inputs.flavour }} release to TestPyPI
if: ${{ inputs.test_mode }}
env:
UV_PUBLISH_TOKEN: ${{ secrets.TEST_PYPI_TOKEN }}
run: uv publish --publish-url https://test.pypi.org/legacy/
- name: Publish ${{ inputs.flavour }} release to PyPI
if: ${{ !inputs.test_mode }}
env:
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
run: uv publish
release-docker-image:
needs: release-github
name: Release Docker Image from ${{ inputs.flavour }}
permissions:
contents: read
runs-on: ubuntu-latest
steps:
- name: Check out ${{ inputs.flavour }}
uses: actions/checkout@v4
with:
ref: ${{ inputs.flavour }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Build and push Dev Docker Image
if: ${{ inputs.flavour == 'dev' }}
uses: docker/build-push-action@v5
with:
context: .
platforms: linux/amd64,linux/arm64
push: ${{ !inputs.test_mode }}
tags: cognee/cognee:${{ needs.release-github.outputs.version }}
labels: |
version=${{ needs.release-github.outputs.version }}
flavour=${{ inputs.flavour }}
cache-from: type=registry,ref=cognee/cognee:buildcache
cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max
- name: Build and push Main Docker Image
if: ${{ inputs.flavour == 'main' }}
uses: docker/build-push-action@v5
with:
context: .
platforms: linux/amd64,linux/arm64
push: ${{ !inputs.test_mode }}
tags: |
cognee/cognee:${{ needs.release-github.outputs.version }}
cognee/cognee:latest
labels: |
version=${{ needs.release-github.outputs.version }}
flavour=${{ inputs.flavour }}
cache-from: type=registry,ref=cognee/cognee:buildcache
cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max

View file

@ -11,12 +11,21 @@ on:
type: string
default: "all"
description: "Which vector databases to test (comma-separated list or 'all')"
python-versions:
required: false
type: string
default: '["3.10", "3.11", "3.12", "3.13"]'
description: "Python versions to test (JSON array)"
jobs:
run-kuzu-lance-sqlite-search-tests:
name: Search test for Kuzu/LanceDB/Sqlite
name: Search test for Kuzu/LanceDB/Sqlite (Python ${{ matrix.python-version }})
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'kuzu/lance/sqlite') }}
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
fail-fast: false
steps:
- name: Check out
uses: actions/checkout@v4
@ -26,7 +35,7 @@ jobs:
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: ${{ inputs.python-version }}
python-version: ${{ matrix.python-version }}
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
@ -45,13 +54,16 @@ jobs:
GRAPH_DATABASE_PROVIDER: 'kuzu'
VECTOR_DB_PROVIDER: 'lancedb'
DB_PROVIDER: 'sqlite'
run: uv run python ./cognee/tests/test_search_db.py
run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO
run-neo4j-lance-sqlite-search-tests:
name: Search test for Neo4j/LanceDB/Sqlite
name: Search test for Neo4j/LanceDB/Sqlite (Python ${{ matrix.python-version }})
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }}
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
fail-fast: false
steps:
- name: Check out
uses: actions/checkout@v4
@ -61,7 +73,7 @@ jobs:
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: ${{ inputs.python-version }}
python-version: ${{ matrix.python-version }}
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
@ -88,12 +100,16 @@ jobs:
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
run: uv run python ./cognee/tests/test_search_db.py
run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO
run-kuzu-pgvector-postgres-search-tests:
name: Search test for Kuzu/PGVector/Postgres
name: Search test for Kuzu/PGVector/Postgres (Python ${{ matrix.python-version }})
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'kuzu/pgvector/postgres') }}
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
fail-fast: false
services:
postgres:
image: pgvector/pgvector:pg17
@ -117,7 +133,7 @@ jobs:
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: ${{ inputs.python-version }}
python-version: ${{ matrix.python-version }}
extra-dependencies: "postgres"
- name: Dependencies already installed
@ -143,12 +159,16 @@ jobs:
DB_PORT: 5432
DB_USERNAME: cognee
DB_PASSWORD: cognee
run: uv run python ./cognee/tests/test_search_db.py
run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO
run-neo4j-pgvector-postgres-search-tests:
name: Search test for Neo4j/PGVector/Postgres
name: Search test for Neo4j/PGVector/Postgres (Python ${{ matrix.python-version }})
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/pgvector/postgres') }}
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
fail-fast: false
services:
postgres:
image: pgvector/pgvector:pg17
@ -172,7 +192,7 @@ jobs:
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: ${{ inputs.python-version }}
python-version: ${{ matrix.python-version }}
extra-dependencies: "postgres"
- name: Setup Neo4j with GDS
@ -205,4 +225,4 @@ jobs:
DB_PORT: 5432
DB_USERNAME: cognee
DB_PASSWORD: cognee
run: uv run python ./cognee/tests/test_search_db.py
run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO

View file

@ -84,3 +84,93 @@ jobs:
EMBEDDING_DIMENSIONS: "3072"
EMBEDDING_MAX_TOKENS: "8191"
run: uv run python ./examples/python/simple_example.py
test-bedrock-api-key:
name: Run Bedrock API Key Test
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: "aws"
- name: Run Bedrock API Key Simple Example
env:
LLM_PROVIDER: "bedrock"
LLM_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
LLM_MAX_TOKENS: "16384"
AWS_REGION_NAME: "eu-west-1"
EMBEDDING_PROVIDER: "bedrock"
EMBEDDING_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
EMBEDDING_DIMENSIONS: "1024"
EMBEDDING_MAX_TOKENS: "8191"
run: uv run python ./examples/python/simple_example.py
test-bedrock-aws-credentials:
name: Run Bedrock AWS Credentials Test
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: "aws"
- name: Run Bedrock AWS Credentials Simple Example
env:
LLM_PROVIDER: "bedrock"
LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
LLM_MAX_TOKENS: "16384"
AWS_REGION_NAME: "eu-west-1"
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
EMBEDDING_PROVIDER: "bedrock"
EMBEDDING_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
EMBEDDING_DIMENSIONS: "1024"
EMBEDDING_MAX_TOKENS: "8191"
run: uv run python ./examples/python/simple_example.py
test-bedrock-aws-profile:
name: Run Bedrock AWS Profile Test
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: "aws"
- name: Configure AWS Profile
run: |
mkdir -p ~/.aws
cat > ~/.aws/credentials << EOF
[bedrock-test]
aws_access_key_id = ${{ secrets.AWS_ACCESS_KEY_ID }}
aws_secret_access_key = ${{ secrets.AWS_SECRET_ACCESS_KEY }}
EOF
- name: Run Bedrock AWS Profile Simple Example
env:
LLM_PROVIDER: "bedrock"
LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
LLM_MAX_TOKENS: "16384"
AWS_PROFILE_NAME: "bedrock-test"
AWS_REGION_NAME: "eu-west-1"
EMBEDDING_PROVIDER: "bedrock"
EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
EMBEDDING_DIMENSIONS: "1024"
EMBEDDING_MAX_TOKENS: "8191"
run: uv run python ./examples/python/simple_example.py

View file

@ -3,7 +3,7 @@
Test client for Cognee MCP Server functionality.
This script tests all the tools and functions available in the Cognee MCP server,
including cognify, codify, search, prune, status checks, and utility functions.
including cognify, search, prune, status checks, and utility functions.
Usage:
# Set your OpenAI API key first
@ -23,6 +23,7 @@ import tempfile
import time
from contextlib import asynccontextmanager
from cognee.shared.logging_utils import setup_logging
from logging import ERROR, INFO
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
@ -35,7 +36,7 @@ from src.server import (
load_class,
)
# Set timeout for cognify/codify to complete in
# Set timeout for cognify to complete in
TIMEOUT = 5 * 60 # 5 min in seconds
@ -151,12 +152,9 @@ DEBUG = True
expected_tools = {
"cognify",
"codify",
"search",
"prune",
"cognify_status",
"codify_status",
"cognee_add_developer_rules",
"list_data",
"delete",
}
@ -247,106 +245,6 @@ DEBUG = True
}
print(f"{test_name} test failed: {e}")
async def test_codify(self):
"""Test the codify functionality using MCP client."""
print("\n🧪 Testing codify functionality...")
try:
async with self.mcp_server_session() as session:
codify_result = await session.call_tool(
"codify", arguments={"repo_path": self.test_repo_dir}
)
start = time.time() # mark the start
while True:
try:
# Wait a moment
await asyncio.sleep(5)
# Check if codify processing is finished
status_result = await session.call_tool("codify_status", arguments={})
if hasattr(status_result, "content") and status_result.content:
status_text = (
status_result.content[0].text
if status_result.content
else str(status_result)
)
else:
status_text = str(status_result)
if str(PipelineRunStatus.DATASET_PROCESSING_COMPLETED) in status_text:
break
elif time.time() - start > TIMEOUT:
raise TimeoutError("Codify did not complete in 5min")
except DatabaseNotCreatedError:
if time.time() - start > TIMEOUT:
raise TimeoutError("Database was not created in 5min")
self.test_results["codify"] = {
"status": "PASS",
"result": codify_result,
"message": "Codify executed successfully",
}
print("✅ Codify test passed")
except Exception as e:
self.test_results["codify"] = {
"status": "FAIL",
"error": str(e),
"message": "Codify test failed",
}
print(f"❌ Codify test failed: {e}")
async def test_cognee_add_developer_rules(self):
"""Test the cognee_add_developer_rules functionality using MCP client."""
print("\n🧪 Testing cognee_add_developer_rules functionality...")
try:
async with self.mcp_server_session() as session:
result = await session.call_tool(
"cognee_add_developer_rules", arguments={"base_path": self.test_data_dir}
)
start = time.time() # mark the start
while True:
try:
# Wait a moment
await asyncio.sleep(5)
# Check if developer rule cognify processing is finished
status_result = await session.call_tool("cognify_status", arguments={})
if hasattr(status_result, "content") and status_result.content:
status_text = (
status_result.content[0].text
if status_result.content
else str(status_result)
)
else:
status_text = str(status_result)
if str(PipelineRunStatus.DATASET_PROCESSING_COMPLETED) in status_text:
break
elif time.time() - start > TIMEOUT:
raise TimeoutError(
"Cognify of developer rules did not complete in 5min"
)
except DatabaseNotCreatedError:
if time.time() - start > TIMEOUT:
raise TimeoutError("Database was not created in 5min")
self.test_results["cognee_add_developer_rules"] = {
"status": "PASS",
"result": result,
"message": "Developer rules addition executed successfully",
}
print("✅ Developer rules test passed")
except Exception as e:
self.test_results["cognee_add_developer_rules"] = {
"status": "FAIL",
"error": str(e),
"message": "Developer rules test failed",
}
print(f"❌ Developer rules test failed: {e}")
async def test_search_functionality(self):
"""Test the search functionality with different search types using MCP client."""
print("\n🧪 Testing search functionality...")
@ -359,7 +257,11 @@ DEBUG = True
# Go through all Cognee search types
for search_type in SearchType:
# Don't test these search types
if search_type in [SearchType.NATURAL_LANGUAGE, SearchType.CYPHER]:
if search_type in [
SearchType.NATURAL_LANGUAGE,
SearchType.CYPHER,
SearchType.TRIPLET_COMPLETION,
]:
break
try:
async with self.mcp_server_session() as session:
@ -681,9 +583,6 @@ class TestModel:
test_name="Cognify2",
)
await self.test_codify()
await self.test_cognee_add_developer_rules()
# Test list_data and delete functionality
await self.test_list_data()
await self.test_delete()
@ -739,7 +638,5 @@ async def main():
if __name__ == "__main__":
from logging import ERROR
logger = setup_logging(log_level=ERROR)
asyncio.run(main())

View file

@ -155,7 +155,7 @@ async def add(
- LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)
Optional:
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral"
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral", "bedrock"
- LLM_MODEL: Model name (default: "gpt-5-mini")
- DEFAULT_USER_EMAIL: Custom default user email
- DEFAULT_USER_PASSWORD: Custom default user password

View file

@ -53,6 +53,7 @@ async def cognify(
custom_prompt: Optional[str] = None,
temporal_cognify: bool = False,
data_per_batch: int = 20,
**kwargs,
):
"""
Transform ingested data into a structured knowledge graph.
@ -223,6 +224,7 @@ async def cognify(
config=config,
custom_prompt=custom_prompt,
chunks_per_batch=chunks_per_batch,
**kwargs,
)
# By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
@ -251,6 +253,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
config: Config = None,
custom_prompt: Optional[str] = None,
chunks_per_batch: int = 100,
**kwargs,
) -> list[Task]:
if config is None:
ontology_config = get_ontology_env_config()
@ -288,6 +291,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
config=config,
custom_prompt=custom_prompt,
task_config={"batch_size": chunks_per_batch},
**kwargs,
), # Generate knowledge graphs from the document chunks.
Task(
summarize_text,

View file

@ -42,7 +42,9 @@ class CognifyPayloadDTO(InDTO):
default="", description="Custom prompt for entity extraction and graph generation"
)
ontology_key: Optional[List[str]] = Field(
default=None, description="Reference to one or more previously uploaded ontologies"
default=None,
examples=[[]],
description="Reference to one or more previously uploaded ontologies",
)

View file

@ -208,14 +208,14 @@ def get_datasets_router() -> APIRouter:
},
)
from cognee.modules.data.methods import get_dataset, delete_dataset
from cognee.modules.data.methods import delete_dataset
dataset = await get_dataset(user.id, dataset_id)
dataset = await get_authorized_existing_datasets([dataset_id], "delete", user)
if dataset is None:
raise DatasetNotFoundError(message=f"Dataset ({str(dataset_id)}) not found.")
await delete_dataset(dataset)
await delete_dataset(dataset[0])
@router.delete(
"/{dataset_id}/data/{data_id}",

View file

@ -1,4 +1,4 @@
from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException
from fastapi import APIRouter, File, Form, UploadFile, Depends, Request
from fastapi.responses import JSONResponse
from typing import Optional, List
@ -15,28 +15,25 @@ def get_ontology_router() -> APIRouter:
@router.post("", response_model=dict)
async def upload_ontology(
request: Request,
ontology_key: str = Form(...),
ontology_file: List[UploadFile] = File(...),
descriptions: Optional[str] = Form(None),
ontology_file: UploadFile = File(...),
description: Optional[str] = Form(None),
user: User = Depends(get_authenticated_user),
):
"""
Upload ontology files with their respective keys for later use in cognify operations.
Supports both single and multiple file uploads:
- Single file: ontology_key=["key"], ontology_file=[file]
- Multiple files: ontology_key=["key1", "key2"], ontology_file=[file1, file2]
Upload a single ontology file for later use in cognify operations.
## Request Parameters
- **ontology_key** (str): JSON array string of user-defined identifiers for the ontologies
- **ontology_file** (List[UploadFile]): OWL format ontology files
- **descriptions** (Optional[str]): JSON array string of optional descriptions
- **ontology_key** (str): User-defined identifier for the ontology.
- **ontology_file** (UploadFile): Single OWL format ontology file
- **description** (Optional[str]): Optional description for the ontology.
## Response
Returns metadata about uploaded ontologies including keys, filenames, sizes, and upload timestamps.
Returns metadata about the uploaded ontology including key, filename, size, and upload timestamp.
## Error Codes
- **400 Bad Request**: Invalid file format, duplicate keys, array length mismatches, file size exceeded
- **400 Bad Request**: Invalid file format, duplicate key, multiple files uploaded
- **500 Internal Server Error**: File system or processing errors
"""
send_telemetry(
@ -49,16 +46,22 @@ def get_ontology_router() -> APIRouter:
)
try:
import json
# Enforce: exactly one uploaded file for "ontology_file"
form = await request.form()
uploaded_files = form.getlist("ontology_file")
if len(uploaded_files) != 1:
raise ValueError("Only one ontology_file is allowed")
ontology_keys = json.loads(ontology_key)
description_list = json.loads(descriptions) if descriptions else None
if ontology_key.strip().startswith(("[", "{")):
raise ValueError("ontology_key must be a string")
if description is not None and description.strip().startswith(("[", "{")):
raise ValueError("description must be a string")
if not isinstance(ontology_keys, list):
raise ValueError("ontology_key must be a JSON array")
results = await ontology_service.upload_ontologies(
ontology_keys, ontology_file, user, description_list
result = await ontology_service.upload_ontology(
ontology_key=ontology_key,
file=ontology_file,
user=user,
description=description,
)
return {
@ -70,10 +73,9 @@ def get_ontology_router() -> APIRouter:
"uploaded_at": result.uploaded_at,
"description": result.description,
}
for result in results
]
}
except (json.JSONDecodeError, ValueError) as e:
except ValueError as e:
return JSONResponse(status_code=400, content={"error": str(e)})
except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)})

View file

@ -1,2 +1,4 @@
from .get_or_create_dataset_database import get_or_create_dataset_database
from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info
from .get_graph_dataset_database_handler import get_graph_dataset_database_handler
from .get_vector_dataset_database_handler import get_vector_dataset_database_handler

View file

@ -0,0 +1,10 @@
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
def get_graph_dataset_database_handler(dataset_database: DatasetDatabase) -> dict:
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[dataset_database.graph_dataset_database_handler]
return handler

View file

@ -0,0 +1,10 @@
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
def get_vector_dataset_database_handler(dataset_database: DatasetDatabase) -> dict:
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler]
return handler

View file

@ -1,24 +1,12 @@
from cognee.infrastructure.databases.utils.get_graph_dataset_database_handler import (
get_graph_dataset_database_handler,
)
from cognee.infrastructure.databases.utils.get_vector_dataset_database_handler import (
get_vector_dataset_database_handler,
)
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler]
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[dataset_database.graph_dataset_database_handler]
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
async def resolve_dataset_database_connection_info(
dataset_database: DatasetDatabase,
) -> DatasetDatabase:
@ -31,6 +19,12 @@ async def resolve_dataset_database_connection_info(
Returns:
DatasetDatabase instance with resolved connection info
"""
dataset_database = await _get_vector_db_connection_info(dataset_database)
dataset_database = await _get_graph_db_connection_info(dataset_database)
vector_dataset_database_handler = get_vector_dataset_database_handler(dataset_database)
graph_dataset_database_handler = get_graph_dataset_database_handler(dataset_database)
dataset_database = await vector_dataset_database_handler[
"handler_instance"
].resolve_dataset_connection_info(dataset_database)
dataset_database = await graph_dataset_database_handler[
"handler_instance"
].resolve_dataset_connection_info(dataset_database)
return dataset_database

View file

@ -9,6 +9,8 @@ class S3Config(BaseSettings):
aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None
aws_session_token: Optional[str] = None
aws_profile_name: Optional[str] = None
aws_bedrock_runtime_endpoint: Optional[str] = None
model_config = SettingsConfigDict(env_file=".env", extra="allow")

View file

@ -11,7 +11,7 @@ class LLMGateway:
@staticmethod
def acreate_structured_output(
text_input: str, system_prompt: str, response_model: Type[BaseModel]
text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> Coroutine:
llm_config = get_llm_config()
if llm_config.structured_output_framework.upper() == "BAML":
@ -31,7 +31,10 @@ class LLMGateway:
llm_client = get_llm_client()
return llm_client.acreate_structured_output(
text_input=text_input, system_prompt=system_prompt, response_model=response_model
text_input=text_input,
system_prompt=system_prompt,
response_model=response_model,
**kwargs,
)
@staticmethod

View file

@ -10,7 +10,7 @@ from cognee.infrastructure.llm.config import (
async def extract_content_graph(
content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None
content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None, **kwargs
):
if custom_prompt:
system_prompt = custom_prompt
@ -30,7 +30,7 @@ async def extract_content_graph(
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
content_graph = await LLMGateway.acreate_structured_output(
content, system_prompt, response_model
content, system_prompt, response_model, **kwargs
)
return content_graph

View file

@ -52,7 +52,7 @@ class AnthropicAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from a user query.

View file

@ -0,0 +1,5 @@
"""Bedrock LLM adapter module."""
from .adapter import BedrockAdapter
__all__ = ["BedrockAdapter"]

View file

@ -0,0 +1,153 @@
import litellm
import instructor
from typing import Type
from pydantic import BaseModel
from litellm.exceptions import ContentPolicyViolationError
from instructor.exceptions import InstructorRetryException
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.exceptions import (
ContentPolicyFilterError,
MissingSystemPromptPathError,
)
from cognee.infrastructure.files.storage.s3_config import get_s3_config
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async,
rate_limit_sync,
sleep_and_retry_async,
sleep_and_retry_sync,
)
from cognee.modules.observability.get_observe import get_observe
observe = get_observe()
class BedrockAdapter(LLMInterface):
"""
Adapter for AWS Bedrock API with support for three authentication methods:
1. API Key (Bearer Token)
2. AWS Credentials (access key + secret key)
3. AWS Profile (boto3 credential chain)
"""
name = "Bedrock"
model: str
api_key: str
default_instructor_mode = "json_schema_mode"
MAX_RETRIES = 5
def __init__(
self,
model: str,
api_key: str = None,
max_completion_tokens: int = 16384,
streaming: bool = False,
instructor_mode: str = None,
):
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
)
self.client = instructor.from_litellm(litellm.completion)
self.model = model
self.api_key = api_key
self.max_completion_tokens = max_completion_tokens
self.streaming = streaming
def _create_bedrock_request(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> dict:
"""Create Bedrock request with authentication."""
request_params = {
"model": self.model,
"custom_llm_provider": "bedrock",
"drop_params": True,
"messages": [
{"role": "user", "content": text_input},
{"role": "system", "content": system_prompt},
],
"response_model": response_model,
"max_retries": self.MAX_RETRIES,
"max_completion_tokens": self.max_completion_tokens,
"stream": self.streaming,
}
s3_config = get_s3_config()
# Add authentication parameters
if self.api_key:
request_params["api_key"] = self.api_key
elif s3_config.aws_access_key_id and s3_config.aws_secret_access_key:
request_params["aws_access_key_id"] = s3_config.aws_access_key_id
request_params["aws_secret_access_key"] = s3_config.aws_secret_access_key
if s3_config.aws_session_token:
request_params["aws_session_token"] = s3_config.aws_session_token
elif s3_config.aws_profile_name:
request_params["aws_profile_name"] = s3_config.aws_profile_name
if s3_config.aws_region:
request_params["aws_region_name"] = s3_config.aws_region
# Add optional parameters
if s3_config.aws_bedrock_runtime_endpoint:
request_params["aws_bedrock_runtime_endpoint"] = s3_config.aws_bedrock_runtime_endpoint
return request_params
@observe(as_type="generation")
@sleep_and_retry_async()
@rate_limit_async
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
"""Generate structured output from AWS Bedrock API."""
try:
request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
return await self.aclient.chat.completions.create(**request_params)
except (
ContentPolicyViolationError,
InstructorRetryException,
) as error:
if (
isinstance(error, InstructorRetryException)
and "content management policy" not in str(error).lower()
):
raise error
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)
@observe
@sleep_and_retry_sync()
@rate_limit_sync
def create_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
"""Generate structured output from AWS Bedrock API (synchronous)."""
request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
return self.client.chat.completions.create(**request_params)
def show_prompt(self, text_input: str, system_prompt: str) -> str:
"""Format and display the prompt for a user query."""
if not text_input:
text_input = "No user input provided."
if not system_prompt:
raise MissingSystemPromptPathError()
system_prompt = LLMGateway.read_query_prompt(system_prompt)
formatted_prompt = (
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
if system_prompt
else None
)
return formatted_prompt

View file

@ -80,7 +80,7 @@ class GeminiAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from a user query.

View file

@ -80,7 +80,7 @@ class GenericAPIAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from a user query.

View file

@ -24,6 +24,7 @@ class LLMProvider(Enum):
- CUSTOM: Represents a custom provider option.
- GEMINI: Represents the Gemini provider.
- MISTRAL: Represents the Mistral AI provider.
- BEDROCK: Represents the AWS Bedrock provider.
"""
OPENAI = "openai"
@ -32,6 +33,7 @@ class LLMProvider(Enum):
CUSTOM = "custom"
GEMINI = "gemini"
MISTRAL = "mistral"
BEDROCK = "bedrock"
def get_llm_client(raise_api_key_error: bool = True):
@ -154,7 +156,7 @@ def get_llm_client(raise_api_key_error: bool = True):
)
elif provider == LLMProvider.MISTRAL:
if llm_config.llm_api_key is None:
if llm_config.llm_api_key is None and raise_api_key_error:
raise LLMAPIKeyNotSetError()
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
@ -169,5 +171,21 @@ def get_llm_client(raise_api_key_error: bool = True):
instructor_mode=llm_config.llm_instructor_mode.lower(),
)
elif provider == LLMProvider.BEDROCK:
# if llm_config.llm_api_key is None and raise_api_key_error:
# raise LLMAPIKeyNotSetError()
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.bedrock.adapter import (
BedrockAdapter,
)
return BedrockAdapter(
model=llm_config.llm_model,
api_key=llm_config.llm_api_key,
max_completion_tokens=max_completion_tokens,
streaming=llm_config.llm_streaming,
instructor_mode=llm_config.llm_instructor_mode.lower(),
)
else:
raise UnsupportedLLMProviderError(provider)

View file

@ -69,7 +69,7 @@ class MistralAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from the user query.

View file

@ -76,7 +76,7 @@ class OllamaAPIAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a structured output from the LLM using the provided text and system prompt.
@ -123,7 +123,7 @@ class OllamaAPIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input_file: str) -> str:
async def create_transcript(self, input_file: str, **kwargs) -> str:
"""
Generate an audio transcript from a user query.
@ -162,7 +162,7 @@ class OllamaAPIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def transcribe_image(self, input_file: str) -> str:
async def transcribe_image(self, input_file: str, **kwargs) -> str:
"""
Transcribe content from an image using base64 encoding.

View file

@ -112,7 +112,7 @@ class OpenAIAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from a user query.
@ -154,6 +154,7 @@ class OpenAIAdapter(LLMInterface):
api_version=self.api_version,
response_model=response_model,
max_retries=self.MAX_RETRIES,
**kwargs,
)
except (
ContentFilterFinishReasonError,
@ -180,6 +181,7 @@ class OpenAIAdapter(LLMInterface):
# api_base=self.fallback_endpoint,
response_model=response_model,
max_retries=self.MAX_RETRIES,
**kwargs,
)
except (
ContentFilterFinishReasonError,
@ -205,7 +207,7 @@ class OpenAIAdapter(LLMInterface):
reraise=True,
)
def create_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from a user query.
@ -245,6 +247,7 @@ class OpenAIAdapter(LLMInterface):
api_version=self.api_version,
response_model=response_model,
max_retries=self.MAX_RETRIES,
**kwargs,
)
@retry(
@ -254,7 +257,7 @@ class OpenAIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input):
async def create_transcript(self, input, **kwargs):
"""
Generate an audio transcript from a user query.
@ -281,6 +284,7 @@ class OpenAIAdapter(LLMInterface):
api_base=self.endpoint,
api_version=self.api_version,
max_retries=self.MAX_RETRIES,
**kwargs,
)
return transcription
@ -292,7 +296,7 @@ class OpenAIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def transcribe_image(self, input) -> BaseModel:
async def transcribe_image(self, input, **kwargs) -> BaseModel:
"""
Generate a transcription of an image from a user query.
@ -337,4 +341,5 @@ class OpenAIAdapter(LLMInterface):
api_version=self.api_version,
max_completion_tokens=300,
max_retries=self.MAX_RETRIES,
**kwargs,
)

View file

@ -5,6 +5,10 @@ from cognee.context_global_variables import backend_access_control_enabled
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.utils import (
get_graph_dataset_database_handler,
get_vector_dataset_database_handler,
)
from cognee.shared.cache import delete_cache
from cognee.modules.users.models import DatasetDatabase
from cognee.shared.logging_utils import get_logger
@ -13,22 +17,13 @@ logger = get_logger()
async def prune_graph_databases():
async def _prune_graph_db(dataset_database: DatasetDatabase) -> dict:
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[
dataset_database.graph_dataset_database_handler
]
return await handler["handler_instance"].delete_dataset(dataset_database)
db_engine = get_relational_engine()
try:
data = await db_engine.get_all_data_from_table("dataset_database")
dataset_databases = await db_engine.get_all_data_from_table("dataset_database")
# Go through each dataset database and delete the graph database
for data_item in data:
await _prune_graph_db(data_item)
for dataset_database in dataset_databases:
handler = get_graph_dataset_database_handler(dataset_database)
await handler["handler_instance"].delete_dataset(dataset_database)
except (OperationalError, EntityNotFoundError) as e:
logger.debug(
"Skipping pruning of graph DB. Error when accessing dataset_database table: %s",
@ -38,22 +33,13 @@ async def prune_graph_databases():
async def prune_vector_databases():
async def _prune_vector_db(dataset_database: DatasetDatabase) -> dict:
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[
dataset_database.vector_dataset_database_handler
]
return await handler["handler_instance"].delete_dataset(dataset_database)
db_engine = get_relational_engine()
try:
data = await db_engine.get_all_data_from_table("dataset_database")
dataset_databases = await db_engine.get_all_data_from_table("dataset_database")
# Go through each dataset database and delete the vector database
for data_item in data:
await _prune_vector_db(data_item)
for dataset_database in dataset_databases:
handler = get_vector_dataset_database_handler(dataset_database)
await handler["handler_instance"].delete_dataset(dataset_database)
except (OperationalError, EntityNotFoundError) as e:
logger.debug(
"Skipping pruning of vector DB. Error when accessing dataset_database table: %s",

View file

@ -1,8 +1,34 @@
from cognee.modules.users.models import DatasetDatabase
from sqlalchemy import select
from cognee.modules.data.models import Dataset
from cognee.infrastructure.databases.utils.get_vector_dataset_database_handler import (
get_vector_dataset_database_handler,
)
from cognee.infrastructure.databases.utils.get_graph_dataset_database_handler import (
get_graph_dataset_database_handler,
)
from cognee.infrastructure.databases.relational import get_relational_engine
async def delete_dataset(dataset: Dataset):
    """Delete a dataset and its backing graph/vector databases.

    Looks up the DatasetDatabase row for the dataset; when one exists, the
    dataset-specific graph and vector databases are deleted through their
    registered handlers before the dataset row itself is removed.

    Returns whatever ``delete_entity_by_id`` returns for the dataset row.
    """
    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        stmt = select(DatasetDatabase).where(
            DatasetDatabase.dataset_id == dataset.id,
        )
        dataset_database: DatasetDatabase = await session.scalar(stmt)

        # Only datasets provisioned with their own databases have a
        # DatasetDatabase row; others skip straight to row deletion.
        if dataset_database:
            graph_dataset_database_handler = get_graph_dataset_database_handler(dataset_database)
            vector_dataset_database_handler = get_vector_dataset_database_handler(dataset_database)

            await graph_dataset_database_handler["handler_instance"].delete_dataset(
                dataset_database
            )
            await vector_dataset_database_handler["handler_instance"].delete_dataset(
                dataset_database
            )

    # TODO: Remove dataset from pipeline_run_status in Data objects related to dataset as well
    # This blocks recreation of the dataset with the same name and data after deletion as
    # it's marked as completed and will be just skipped even though it's empty.
    return await db_engine.delete_entity_by_id(dataset.__tablename__, dataset.id)

View file

@ -36,7 +36,7 @@ class TripletRetriever(BaseRetriever):
"""Initialize retriever with optional custom prompt paths."""
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 1
self.top_k = top_k if top_k is not None else 5
self.system_prompt = system_prompt
async def get_context(self, query: str) -> str:

View file

@ -16,24 +16,6 @@ logger = get_logger(level=ERROR)
def format_triplets(edges):
print("\n\n\n")
def filter_attributes(obj, attributes):
"""Helper function to filter out non-None properties, including nested dicts."""
result = {}
for attr in attributes:
value = getattr(obj, attr, None)
if value is not None:
# If the value is a dict, extract relevant keys from it
if isinstance(value, dict):
nested_values = {
k: v for k, v in value.items() if k in attributes and v is not None
}
result[attr] = nested_values
else:
result[attr] = value
return result
triplets = []
for edge in edges:
node1 = edge.node1

View file

@ -16,6 +16,7 @@ class ModelName(Enum):
anthropic = "anthropic"
gemini = "gemini"
mistral = "mistral"
bedrock = "bedrock"
class LLMConfig(BaseModel):
@ -77,6 +78,10 @@ def get_settings() -> SettingsDict:
"value": "mistral",
"label": "Mistral",
},
{
"value": "bedrock",
"label": "Bedrock",
},
]
return SettingsDict.model_validate(
@ -157,6 +162,20 @@ def get_settings() -> SettingsDict:
"label": "Mistral Large 2.1",
},
],
"bedrock": [
{
"value": "eu.anthropic.claude-sonnet-4-5-20250929-v1:0",
"label": "Claude 4.5 Sonnet",
},
{
"value": "eu.anthropic.claude-haiku-4-5-20251001-v1:0",
"label": "Claude 4.5 Haiku",
},
{
"value": "eu.amazon.nova-lite-v1:0",
"label": "Amazon Nova Lite",
},
],
},
},
vector_db={

View file

@ -97,6 +97,7 @@ async def extract_graph_from_data(
graph_model: Type[BaseModel],
config: Config = None,
custom_prompt: Optional[str] = None,
**kwargs,
) -> List[DocumentChunk]:
"""
Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
@ -111,7 +112,7 @@ async def extract_graph_from_data(
chunk_graphs = await asyncio.gather(
*[
extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt)
extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt, **kwargs)
for chunk in data_chunks
]
)

View file

@ -1,5 +1,6 @@
from typing import AsyncGenerator, Dict, Any, List, Optional
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.modules.engine.utils import generate_node_id
from cognee.shared.logging_utils import get_logger
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.infrastructure.engine import DataPoint
@ -155,7 +156,12 @@ def _process_single_triplet(
embeddable_text = f"{start_node_text}-{relationship_text}-{end_node_text}".strip()
triplet_obj = Triplet(from_node_id=start_node_id, to_node_id=end_node_id, text=embeddable_text)
relationship_name = relationship.get("relationship_name", "")
triplet_id = generate_node_id(str(start_node_id) + str(relationship_name) + str(end_node_id))
triplet_obj = Triplet(
id=triplet_id, from_node_id=start_node_id, to_node_id=end_node_id, text=embeddable_text
)
return triplet_obj, None

View file

@ -0,0 +1,252 @@
import os
import pytest
import pathlib
import pytest_asyncio
from typing import List
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity
class DocumentChunkWithEntities(DataPoint):
    """Test-only DataPoint mirroring DocumentChunk with an entity list.

    Used as the payload schema when creating the ``DocumentChunk_text``
    collection in the empty-graph test below.
    """

    text: str
    chunk_size: int
    chunk_index: int
    cut_type: str
    is_part_of: Document
    # NOTE(review): default of None for a List-typed field — presumably treated
    # as "no entities"; confirm DataPoint tolerates None here.
    contains: List[Entity] = None
    # Tells the vector store which field(s) to embed/index.
    metadata: dict = {"index_fields": ["text"]}
@pytest_asyncio.fixture
async def setup_test_environment_with_chunks_simple():
    """Set up a clean test environment with simple chunks.

    Points cognee at dedicated system/data directories, wipes prior state,
    and stores three two-word chunks of a single document. Best-effort
    cleanup runs after the test.
    """
    # Isolate this test's storage so runs don't interfere with other tests
    # or a developer's local cognee state.
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_simple")
    data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_simple")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()
    document = TextDocument(
        name="Steve Rodger's career",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    chunk1 = DocumentChunk(
        text="Steve Rodger",
        chunk_size=2,
        chunk_index=0,
        cut_type="sentence_end",
        is_part_of=document,
        contains=[],
    )
    chunk2 = DocumentChunk(
        text="Mike Broski",
        chunk_size=2,
        chunk_index=1,
        cut_type="sentence_end",
        is_part_of=document,
        contains=[],
    )
    chunk3 = DocumentChunk(
        text="Christina Mayer",
        chunk_size=2,
        chunk_index=2,
        cut_type="sentence_end",
        is_part_of=document,
        contains=[],
    )
    entities = [chunk1, chunk2, chunk3]
    await add_data_points(entities)
    yield
    # Best-effort cleanup; failures here must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_with_chunks_complex():
    """Set up a clean test environment with complex chunks.

    Stores six chunks split across two documents (employees vs. cars) so
    tests can exercise top_k limits and relevance ordering across topics.
    """
    # Isolate this test's storage from other tests and local state.
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_complex")
    data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_complex")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()
    document1 = TextDocument(
        name="Employee List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    document2 = TextDocument(
        name="Car List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    # Chunks 1-3 belong to the employee document.
    chunk1 = DocumentChunk(
        text="Steve Rodger",
        chunk_size=2,
        chunk_index=0,
        cut_type="sentence_end",
        is_part_of=document1,
        contains=[],
    )
    chunk2 = DocumentChunk(
        text="Mike Broski",
        chunk_size=2,
        chunk_index=1,
        cut_type="sentence_end",
        is_part_of=document1,
        contains=[],
    )
    chunk3 = DocumentChunk(
        text="Christina Mayer",
        chunk_size=2,
        chunk_index=2,
        cut_type="sentence_end",
        is_part_of=document1,
        contains=[],
    )
    # Chunks 4-6 belong to the car document.
    chunk4 = DocumentChunk(
        text="Range Rover",
        chunk_size=2,
        chunk_index=0,
        cut_type="sentence_end",
        is_part_of=document2,
        contains=[],
    )
    chunk5 = DocumentChunk(
        text="Hyundai",
        chunk_size=2,
        chunk_index=1,
        cut_type="sentence_end",
        is_part_of=document2,
        contains=[],
    )
    chunk6 = DocumentChunk(
        text="Chrysler",
        chunk_size=2,
        chunk_index=2,
        cut_type="sentence_end",
        is_part_of=document2,
        contains=[],
    )
    entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]
    await add_data_points(entities)
    yield
    # Best-effort cleanup; failures here must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without chunks.

    NOTE(review): unlike the other fixtures this does not call setup(), so no
    collections exist — the NoDataError path in the empty-graph test appears
    to rely on that; confirm if setup() is intentionally omitted.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_empty")
    data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_empty")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    yield
    # Best-effort cleanup; failures here must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest.mark.asyncio
async def test_chunks_retriever_context_multiple_chunks(setup_test_environment_with_chunks_simple):
    """Integration test: verify ChunksRetriever can retrieve multiple chunks."""
    retriever = ChunksRetriever()
    context = await retriever.get_context("Steve")
    assert isinstance(context, list), "Context should be a list"
    assert len(context) > 0, "Context should not be empty"
    # Vector search may also return the other chunks, so only require that the
    # matching chunk appears somewhere in the results.
    assert any(chunk["text"] == "Steve Rodger" for chunk in context), (
        "Failed to get Steve Rodger chunk"
    )
@pytest.mark.asyncio
async def test_chunks_retriever_top_k_limit(setup_test_environment_with_chunks_complex):
    """Integration test: verify ChunksRetriever respects top_k parameter."""
    # Six chunks exist in the fixture, so a result capped at 2 proves top_k works.
    retriever = ChunksRetriever(top_k=2)
    context = await retriever.get_context("Employee")
    assert isinstance(context, list), "Context should be a list"
    assert len(context) <= 2, "Should respect top_k limit"
@pytest.mark.asyncio
async def test_chunks_retriever_context_complex(setup_test_environment_with_chunks_complex):
    """Integration test: verify ChunksRetriever can retrieve chunk context (complex)."""
    # top_k=20 exceeds the six stored chunks, so ranking alone decides order;
    # the best match for the query must come first.
    retriever = ChunksRetriever(top_k=20)
    context = await retriever.get_context("Christina")
    assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer"
@pytest.mark.asyncio
async def test_chunks_retriever_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: verify ChunksRetriever handles empty graph correctly."""
    retriever = ChunksRetriever()
    # Phase 1: no collection exists at all -> retriever must raise NoDataError.
    with pytest.raises(NoDataError):
        await retriever.get_context("Christina Mayer")
    # Phase 2: an empty collection exists -> retriever must return no results
    # instead of raising.
    vector_engine = get_vector_engine()
    await vector_engine.create_collection(
        "DocumentChunk_text", payload_schema=DocumentChunkWithEntities
    )
    context = await retriever.get_context("Christina Mayer")
    assert len(context) == 0, "Found chunks when none should exist"

View file

@ -0,0 +1,268 @@
import os
import pytest
import pathlib
import pytest_asyncio
from typing import Optional, Union
import cognee
from cognee.low_level import setup, DataPoint
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@pytest_asyncio.fixture
async def setup_test_environment_simple():
    """Set up a clean test environment with simple graph data.

    Builds a small works_for graph: five Person nodes (with descriptions)
    split across two Company nodes (Figma, Canva).
    """
    # Isolate this test's storage from other tests and local state.
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_graph_completion_context_simple")
    data_directory_path = str(base_dir / ".data_storage/test_graph_completion_context_simple")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Test-local DataPoint models; the works_for field becomes a graph edge.
    class Company(DataPoint):
        name: str
        description: str

    class Person(DataPoint):
        name: str
        description: str
        works_for: Company

    company1 = Company(name="Figma", description="Figma is a company")
    company2 = Company(name="Canva", description="Canvas is a company")
    person1 = Person(
        name="Steve Rodger",
        description="This is description about Steve Rodger",
        works_for=company1,
    )
    person2 = Person(
        name="Ike Loma", description="This is description about Ike Loma", works_for=company1
    )
    person3 = Person(
        name="Jason Statham",
        description="This is description about Jason Statham",
        works_for=company1,
    )
    person4 = Person(
        name="Mike Broski",
        description="This is description about Mike Broski",
        works_for=company2,
    )
    person5 = Person(
        name="Christina Mayer",
        description="This is description about Christina Mayer",
        works_for=company2,
    )
    entities = [company1, company2, person1, person2, person3, person4, person5]
    await add_data_points(entities)
    yield
    # Best-effort cleanup; failures here must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_complex():
    """Set up a clean test environment with complex graph data.

    Builds a richer graph: people working for two companies, some owning
    cars and/or homes (nested Location), to exercise multi-type traversal.
    """
    # Isolate this test's storage from other tests and local state.
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_graph_completion_context_complex")
    data_directory_path = str(base_dir / ".data_storage/test_graph_completion_context_complex")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Test-local DataPoint models; relationships become graph edges.
    class Company(DataPoint):
        name: str
        # Embed/index companies by name so they are findable via vector search.
        metadata: dict = {"index_fields": ["name"]}

    class Car(DataPoint):
        brand: str
        model: str
        year: int

    class Location(DataPoint):
        country: str
        city: str

    class Home(DataPoint):
        location: Location
        rooms: int
        sqm: int

    class Person(DataPoint):
        name: str
        works_for: Company
        owns: Optional[list[Union[Car, Home]]] = None

    company1 = Company(name="Figma")
    company2 = Company(name="Canva")
    person1 = Person(name="Mike Rodger", works_for=company1)
    person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
    person2 = Person(name="Ike Loma", works_for=company1)
    person2.owns = [
        Car(brand="Tesla", model="Model S", year=2021),
        Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
    ]
    person3 = Person(name="Jason Statham", works_for=company1)
    person4 = Person(name="Mike Broski", works_for=company2)
    person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
    person5 = Person(name="Christina Mayer", works_for=company2)
    person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
    entities = [company1, company2, person1, person2, person3, person4, person5]
    await add_data_points(entities)
    yield
    # Best-effort cleanup; failures here must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without graph data.

    Runs setup() so the infrastructure exists, but adds no data points —
    the graph is initialized and empty.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_get_graph_completion_context_on_empty_graph"
    )
    data_directory_path = str(
        base_dir / ".data_storage/test_get_graph_completion_context_on_empty_graph"
    )
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()
    yield
    # Best-effort cleanup; failures here must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest.mark.asyncio
async def test_graph_completion_context_simple(setup_test_environment_simple):
    """Integration test: verify GraphCompletionRetriever can retrieve context (simple).

    Checks the full rendered context from resolve_edges_to_text: section
    headers, per-node content blocks, and edge lines.
    """
    retriever = GraphCompletionRetriever()
    context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
    # Ensure the top-level sections are present
    assert "Nodes:" in context, "Missing 'Nodes:' section in context"
    assert "Connections:" in context, "Missing 'Connections:' section in context"
    # --- Nodes headers ---
    assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger"
    assert "Node: Figma" in context, "Missing node header for Figma"
    assert "Node: Ike Loma" in context, "Missing node header for Ike Loma"
    assert "Node: Jason Statham" in context, "Missing node header for Jason Statham"
    assert "Node: Mike Broski" in context, "Missing node header for Mike Broski"
    assert "Node: Canva" in context, "Missing node header for Canva"
    assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer"
    # --- Node contents ---
    # Each node's description is wrapped in __node_content_start__/__node_content_end__
    # markers by the text rendering; assert the blocks verbatim.
    assert (
        "__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__"
        in context
    ), "Description block for Steve Rodger altered"
    assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, (
        "Description block for Figma altered"
    )
    assert (
        "__node_content_start__\nThis is description about Ike Loma\n__node_content_end__"
        in context
    ), "Description block for Ike Loma altered"
    assert (
        "__node_content_start__\nThis is description about Jason Statham\n__node_content_end__"
        in context
    ), "Description block for Jason Statham altered"
    assert (
        "__node_content_start__\nThis is description about Mike Broski\n__node_content_end__"
        in context
    ), "Description block for Mike Broski altered"
    assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, (
        "Description block for Canva altered"
    )
    assert (
        "__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__"
        in context
    ), "Description block for Christina Mayer altered"
    # --- Connections ---
    assert "Steve Rodger --[works_for]--> Figma" in context, (
        "Connection Steve Rodger→Figma missing or changed"
    )
    assert "Ike Loma --[works_for]--> Figma" in context, (
        "Connection Ike Loma→Figma missing or changed"
    )
    assert "Jason Statham --[works_for]--> Figma" in context, (
        "Connection Jason Statham→Figma missing or changed"
    )
    assert "Mike Broski --[works_for]--> Canva" in context, (
        "Connection Mike Broski→Canva missing or changed"
    )
    assert "Christina Mayer --[works_for]--> Canva" in context, (
        "Connection Christina Mayer→Canva missing or changed"
    )
@pytest.mark.asyncio
async def test_graph_completion_context_complex(setup_test_environment_complex):
    """Integration test: verify GraphCompletionRetriever can retrieve context (complex)."""
    # top_k=20 widens retrieval enough to pull in all Figma employees.
    retriever = GraphCompletionRetriever(top_k=20)
    context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
    assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
    assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
    assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
@pytest.mark.asyncio
async def test_get_graph_completion_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: verify GraphCompletionRetriever handles empty graph correctly."""
    retriever = GraphCompletionRetriever()
    context = await retriever.get_context("Who works at Figma?")
    # An initialized-but-empty graph must yield an empty context, not an error.
    assert context == [], "Context should be empty on an empty graph"
@pytest.mark.asyncio
async def test_graph_completion_get_triplets_empty(setup_test_environment_empty):
    """Integration test: verify GraphCompletionRetriever get_triplets handles empty graph."""
    retriever = GraphCompletionRetriever()
    triplets = await retriever.get_triplets("Who works at Figma?")
    # get_triplets must degrade to an empty list, not raise, on an empty graph.
    assert isinstance(triplets, list), "Triplets should be a list"
    assert len(triplets) == 0, "Should return empty list on empty graph"

View file

@ -0,0 +1,226 @@
import os
import pytest
import pathlib
import pytest_asyncio
from typing import Optional, Union
import cognee
from cognee.low_level import setup, DataPoint
from cognee.tasks.storage import add_data_points
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
)
@pytest_asyncio.fixture
async def setup_test_environment_simple():
    """Set up a clean test environment with simple graph data.

    Builds a minimal works_for graph: five Person nodes across two Company
    nodes (Figma, Canva), without descriptions.
    """
    # Isolate this test's storage from other tests and local state.
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_graph_completion_extension_context_simple"
    )
    data_directory_path = str(
        base_dir / ".data_storage/test_graph_completion_extension_context_simple"
    )
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Test-local DataPoint models; works_for becomes a graph edge.
    class Company(DataPoint):
        name: str

    class Person(DataPoint):
        name: str
        works_for: Company

    company1 = Company(name="Figma")
    company2 = Company(name="Canva")
    person1 = Person(name="Steve Rodger", works_for=company1)
    person2 = Person(name="Ike Loma", works_for=company1)
    person3 = Person(name="Jason Statham", works_for=company1)
    person4 = Person(name="Mike Broski", works_for=company2)
    person5 = Person(name="Christina Mayer", works_for=company2)
    entities = [company1, company2, person1, person2, person3, person4, person5]
    await add_data_points(entities)
    yield
    # Best-effort cleanup; failures here must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_complex():
    """Set up a clean test environment with complex graph data.

    Same shape as the complex graph-completion fixture: people across two
    companies, some owning cars and/or homes with nested locations.
    """
    # Isolate this test's storage from other tests and local state.
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_graph_completion_extension_context_complex"
    )
    data_directory_path = str(
        base_dir / ".data_storage/test_graph_completion_extension_context_complex"
    )
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Test-local DataPoint models; relationships become graph edges.
    class Company(DataPoint):
        name: str
        # Embed/index companies by name so they are findable via vector search.
        metadata: dict = {"index_fields": ["name"]}

    class Car(DataPoint):
        brand: str
        model: str
        year: int

    class Location(DataPoint):
        country: str
        city: str

    class Home(DataPoint):
        location: Location
        rooms: int
        sqm: int

    class Person(DataPoint):
        name: str
        works_for: Company
        owns: Optional[list[Union[Car, Home]]] = None

    company1 = Company(name="Figma")
    company2 = Company(name="Canva")
    person1 = Person(name="Mike Rodger", works_for=company1)
    person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
    person2 = Person(name="Ike Loma", works_for=company1)
    person2.owns = [
        Car(brand="Tesla", model="Model S", year=2021),
        Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
    ]
    person3 = Person(name="Jason Statham", works_for=company1)
    person4 = Person(name="Mike Broski", works_for=company2)
    person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
    person5 = Person(name="Christina Mayer", works_for=company2)
    person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
    entities = [company1, company2, person1, person2, person3, person4, person5]
    await add_data_points(entities)
    yield
    # Best-effort cleanup; failures here must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without graph data.

    Runs setup() so the infrastructure exists, but adds no data points —
    the graph is initialized and empty.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph"
    )
    data_directory_path = str(
        base_dir / ".data_storage/test_get_graph_completion_extension_context_on_empty_graph"
    )
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()
    yield
    # Best-effort cleanup; failures here must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest.mark.asyncio
async def test_graph_completion_extension_context_simple(setup_test_environment_simple):
    """Integration test: GraphCompletionContextExtensionRetriever finds both Canva employees."""
    question = "Who works at Canva?"
    retriever = GraphCompletionContextExtensionRetriever()
    edges = await retriever.get_context(question)
    context = await resolve_edges_to_text(edges)
    assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
    assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
    answer = await retriever.get_completion(question)
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    only_non_empty_strings = all(isinstance(item, str) and item.strip() for item in answer)
    assert only_non_empty_strings, "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_graph_completion_extension_context_complex(setup_test_environment_complex):
    """Integration test: GraphCompletionContextExtensionRetriever finds all Figma employees."""
    retriever = GraphCompletionContextExtensionRetriever(top_k=20)
    edges = await retriever.get_context("Who works at Figma and drives Tesla?")
    context = await resolve_edges_to_text(edges)
    expected_edges = [
        ("Mike Rodger --[works_for]--> Figma", "Failed to get Mike Rodger"),
        ("Ike Loma --[works_for]--> Figma", "Failed to get Ike Loma"),
        ("Jason Statham --[works_for]--> Figma", "Failed to get Jason Statham"),
    ]
    for edge, message in expected_edges:
        assert edge in context, message
    answer = await retriever.get_completion("Who works at Figma?")
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    only_non_empty_strings = all(isinstance(item, str) and item.strip() for item in answer)
    assert only_non_empty_strings, "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_get_graph_completion_extension_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: the retriever degrades gracefully when the graph is empty."""
    question = "Who works at Figma?"
    retriever = GraphCompletionContextExtensionRetriever()
    context = await retriever.get_context(question)
    assert context == [], "Context should be empty on an empty graph"
    answer = await retriever.get_completion(question)
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    only_non_empty_strings = all(isinstance(item, str) and item.strip() for item in answer)
    assert only_non_empty_strings, "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_graph_completion_extension_get_triplets_empty(setup_test_environment_empty):
    """Integration test: get_triplets yields an empty list when the graph has no data."""
    retriever = GraphCompletionContextExtensionRetriever()
    result = await retriever.get_triplets("Who works at Figma?")
    assert isinstance(result, list), "Triplets should be a list"
    assert not result, "Should return empty list on empty graph"

View file

@ -0,0 +1,218 @@
import os
import pytest
import pathlib
import pytest_asyncio
from typing import Optional, Union
import cognee
from cognee.low_level import setup, DataPoint
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
@pytest_asyncio.fixture
async def setup_test_environment_simple():
    """Set up a clean test environment with simple graph data.

    Wipes prior cognee state under test-specific directories, then stores a
    two-company / five-person "works_for" graph for the CoT retriever tests.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_graph_completion_cot_context_simple"
    )
    data_directory_path = str(base_dir / ".data_storage/test_graph_completion_cot_context_simple")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    # Configure directories before pruning so only this test's state is wiped.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Ad-hoc node types used only by this fixture's graph.
    class Company(DataPoint):
        name: str

    class Person(DataPoint):
        name: str
        works_for: Company

    company1 = Company(name="Figma")
    company2 = Company(name="Canva")
    person1 = Person(name="Steve Rodger", works_for=company1)
    person2 = Person(name="Ike Loma", works_for=company1)
    person3 = Person(name="Jason Statham", works_for=company1)
    person4 = Person(name="Mike Broski", works_for=company2)
    person5 = Person(name="Christina Mayer", works_for=company2)
    entities = [company1, company2, person1, person2, person3, person4, person5]
    await add_data_points(entities)
    yield
    # Best-effort teardown; a cleanup failure must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_complex():
    """Set up a clean test environment with complex graph data.

    Stores an ownership graph (companies, people, cars, homes) for the CoT
    retriever tests; databases are pruned again after the test.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_graph_completion_cot_context_complex"
    )
    data_directory_path = str(base_dir / ".data_storage/test_graph_completion_cot_context_complex")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    # Order matters: configure directories first, then wipe, then re-create
    # system state before inserting data points.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Ad-hoc node types used only by this fixture's graph.
    class Company(DataPoint):
        name: str
        metadata: dict = {"index_fields": ["name"]}

    class Car(DataPoint):
        brand: str
        model: str
        year: int

    class Location(DataPoint):
        country: str
        city: str

    class Home(DataPoint):
        location: Location
        rooms: int
        sqm: int

    class Person(DataPoint):
        name: str
        works_for: Company
        owns: Optional[list[Union[Car, Home]]] = None

    company1 = Company(name="Figma")
    company2 = Company(name="Canva")
    person1 = Person(name="Mike Rodger", works_for=company1)
    person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
    person2 = Person(name="Ike Loma", works_for=company1)
    person2.owns = [
        Car(brand="Tesla", model="Model S", year=2021),
        Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
    ]
    person3 = Person(name="Jason Statham", works_for=company1)
    person4 = Person(name="Mike Broski", works_for=company2)
    person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
    person5 = Person(name="Christina Mayer", works_for=company2)
    person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
    entities = [company1, company2, person1, person2, person3, person4, person5]
    await add_data_points(entities)
    yield
    # Best-effort teardown; a cleanup failure must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without graph data.

    No data points are added — the graph stays empty for the negative tests.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph"
    )
    data_directory_path = str(
        base_dir / ".data_storage/test_get_graph_completion_cot_context_on_empty_graph"
    )
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()
    yield
    # Best-effort teardown; a cleanup failure must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest.mark.asyncio
async def test_graph_completion_cot_context_simple(setup_test_environment_simple):
    """Integration test: GraphCompletionCotRetriever finds both Canva employees."""
    question = "Who works at Canva?"
    retriever = GraphCompletionCotRetriever()
    edges = await retriever.get_context(question)
    context = await resolve_edges_to_text(edges)
    assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
    assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
    answer = await retriever.get_completion(question)
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    only_non_empty_strings = all(isinstance(item, str) and item.strip() for item in answer)
    assert only_non_empty_strings, "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_graph_completion_cot_context_complex(setup_test_environment_complex):
    """Integration test: GraphCompletionCotRetriever finds all Figma employees."""
    question = "Who works at Figma?"
    retriever = GraphCompletionCotRetriever(top_k=20)
    edges = await retriever.get_context(question)
    context = await resolve_edges_to_text(edges)
    expected_edges = [
        ("Mike Rodger --[works_for]--> Figma", "Failed to get Mike Rodger"),
        ("Ike Loma --[works_for]--> Figma", "Failed to get Ike Loma"),
        ("Jason Statham --[works_for]--> Figma", "Failed to get Jason Statham"),
    ]
    for edge, message in expected_edges:
        assert edge in context, message
    answer = await retriever.get_completion(question)
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    only_non_empty_strings = all(isinstance(item, str) and item.strip() for item in answer)
    assert only_non_empty_strings, "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_get_graph_completion_cot_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: GraphCompletionCotRetriever degrades gracefully on an empty graph."""
    question = "Who works at Figma?"
    retriever = GraphCompletionCotRetriever()
    context = await retriever.get_context(question)
    assert context == [], "Context should be empty on an empty graph"
    answer = await retriever.get_completion(question)
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    only_non_empty_strings = all(isinstance(item, str) and item.strip() for item in answer)
    assert only_non_empty_strings, "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_graph_completion_cot_get_triplets_empty(setup_test_environment_empty):
    """Integration test: get_triplets yields an empty list when the graph has no data."""
    retriever = GraphCompletionCotRetriever()
    result = await retriever.get_triplets("Who works at Figma?")
    assert isinstance(result, list), "Triplets should be a list"
    assert not result, "Should return empty list on empty graph"

View file

@ -0,0 +1,254 @@
import os
from typing import List
import pytest
import pathlib
import pytest_asyncio
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity
class DocumentChunkWithEntities(DataPoint):
    # Minimal document-chunk schema used as the payload schema when the empty
    # "DocumentChunk_text" vector collection is created in the tests below.
    text: str  # raw chunk text; indexed per `metadata` below
    chunk_size: int
    chunk_index: int
    cut_type: str
    is_part_of: Document
    # NOTE(review): default of None on a non-Optional List — presumably
    # tolerated by DataPoint's model validation; confirm.
    contains: List[Entity] = None
    metadata: dict = {"index_fields": ["text"]}
@pytest_asyncio.fixture
async def setup_test_environment_with_chunks_simple():
    """Set up a clean test environment with simple chunks.

    Stores three single-name document chunks belonging to one document so the
    completion retriever has a small collection to search.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_rag_completion_context_simple")
    data_directory_path = str(base_dir / ".data_storage/test_rag_completion_context_simple")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()
    document = TextDocument(
        name="Steve Rodger's career",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    chunk1 = DocumentChunk(
        text="Steve Rodger",
        chunk_size=2,
        chunk_index=0,
        cut_type="sentence_end",
        is_part_of=document,
        contains=[],
    )
    chunk2 = DocumentChunk(
        text="Mike Broski",
        chunk_size=2,
        chunk_index=1,
        cut_type="sentence_end",
        is_part_of=document,
        contains=[],
    )
    chunk3 = DocumentChunk(
        text="Christina Mayer",
        chunk_size=2,
        chunk_index=2,
        cut_type="sentence_end",
        is_part_of=document,
        contains=[],
    )
    entities = [chunk1, chunk2, chunk3]
    await add_data_points(entities)
    yield
    # Best-effort teardown; a cleanup failure must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_with_chunks_complex():
    """Set up a clean test environment with complex chunks.

    Stores six chunks across two documents (employees and cars) so retrieval
    has to pick the right chunk among unrelated ones.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_rag_completion_context_complex")
    data_directory_path = str(base_dir / ".data_storage/test_rag_completion_context_complex")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()
    document1 = TextDocument(
        name="Employee List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    document2 = TextDocument(
        name="Car List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    # Chunks 1-3: employee names (document1).
    chunk1 = DocumentChunk(
        text="Steve Rodger",
        chunk_size=2,
        chunk_index=0,
        cut_type="sentence_end",
        is_part_of=document1,
        contains=[],
    )
    chunk2 = DocumentChunk(
        text="Mike Broski",
        chunk_size=2,
        chunk_index=1,
        cut_type="sentence_end",
        is_part_of=document1,
        contains=[],
    )
    chunk3 = DocumentChunk(
        text="Christina Mayer",
        chunk_size=2,
        chunk_index=2,
        cut_type="sentence_end",
        is_part_of=document1,
        contains=[],
    )
    # Chunks 4-6: car brands (document2).
    chunk4 = DocumentChunk(
        text="Range Rover",
        chunk_size=2,
        chunk_index=0,
        cut_type="sentence_end",
        is_part_of=document2,
        contains=[],
    )
    chunk5 = DocumentChunk(
        text="Hyundai",
        chunk_size=2,
        chunk_index=1,
        cut_type="sentence_end",
        is_part_of=document2,
        contains=[],
    )
    chunk6 = DocumentChunk(
        text="Chrysler",
        chunk_size=2,
        chunk_index=2,
        cut_type="sentence_end",
        is_part_of=document2,
        contains=[],
    )
    entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]
    await add_data_points(entities)
    yield
    # Best-effort teardown; a cleanup failure must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without chunks.

    Deliberately does NOT call setup(): no collections exist afterwards, and
    the test itself creates the empty collection once it has verified the
    NoDataError path.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_get_rag_completion_context_on_empty_graph"
    )
    data_directory_path = str(
        base_dir / ".data_storage/test_get_rag_completion_context_on_empty_graph"
    )
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    yield
    # Best-effort teardown; a cleanup failure must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest.mark.asyncio
async def test_rag_completion_context_simple(setup_test_environment_with_chunks_simple):
    """Integration test: CompletionRetriever returns the matching chunk text."""
    retriever = CompletionRetriever()
    retrieved = await retriever.get_context("Mike")
    assert isinstance(retrieved, str), "Context should be a string"
    assert "Mike Broski" in retrieved, "Failed to get Mike Broski"
@pytest.mark.asyncio
async def test_rag_completion_context_multiple_chunks(setup_test_environment_with_chunks_simple):
    """Integration test: CompletionRetriever retrieves the right chunk among several."""
    retriever = CompletionRetriever()
    retrieved = await retriever.get_context("Steve")
    assert isinstance(retrieved, str), "Context should be a string"
    assert "Steve Rodger" in retrieved, "Failed to get Steve Rodger"
@pytest.mark.asyncio
async def test_rag_completion_context_complex(setup_test_environment_with_chunks_complex):
    """Integration test: verify CompletionRetriever can retrieve context (complex)."""
    # TODO: top_k doesn't affect the output, it should be fixed.
    retriever = CompletionRetriever(top_k=20)
    context = await retriever.get_context("Christina")
    # str.startswith is clearer and less brittle than comparing against a
    # hard-coded slice length (context[0:15]).
    assert context.startswith("Christina Mayer"), "Failed to get Christina Mayer"
@pytest.mark.asyncio
async def test_get_rag_completion_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: CompletionRetriever behaves correctly when no data exists."""
    query = "Christina Mayer"
    retriever = CompletionRetriever()
    # Without any collection, the retriever must signal missing data.
    with pytest.raises(NoDataError):
        await retriever.get_context(query)
    # Once an (empty) collection exists, an empty context is returned instead.
    engine = get_vector_engine()
    await engine.create_collection(
        "DocumentChunk_text", payload_schema=DocumentChunkWithEntities
    )
    context = await retriever.get_context(query)
    assert context == "", "Returned context should be empty on an empty graph"

View file

@ -1,9 +1,9 @@
import asyncio
import pytest
import cognee
import pathlib
import os
import pytest
import pathlib
import pytest_asyncio
import cognee
from pydantic import BaseModel
from cognee.low_level import setup, DataPoint
@ -125,80 +125,90 @@ async def _test_get_structured_entity_completion():
_assert_structured_answer(structured_answer)
class TestStructuredOutputCompletion:
@pytest.mark.asyncio
async def test_get_structured_completion(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
)
cognee.config.data_root_directory(data_directory_path)
@pytest_asyncio.fixture
async def setup_test_environment():
"""Set up a clean test environment with graph and document data."""
base_dir = pathlib.Path(__file__).parent.parent.parent.parent
system_directory_path = str(base_dir / ".cognee_system/test_get_structured_completion")
data_directory_path = str(base_dir / ".data_storage/test_get_structured_completion")
cognee.config.system_root_directory(system_directory_path)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
class Company(DataPoint):
name: str
class Person(DataPoint):
name: str
works_for: Company
works_since: int
company1 = Company(name="Figma")
person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015)
entities = [company1, person1]
await add_data_points(entities)
document = TextDocument(
name="Steve Rodger's career",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
chunk1 = DocumentChunk(
text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
entities = [chunk1, chunk2, chunk3]
await add_data_points(entities)
entity_type = EntityType(name="Person", description="A human individual")
entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")
entities = [entity]
await add_data_points(entities)
yield
try:
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
except Exception:
pass
class Company(DataPoint):
name: str
class Person(DataPoint):
name: str
works_for: Company
works_since: int
company1 = Company(name="Figma")
person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015)
entities = [company1, person1]
await add_data_points(entities)
document = TextDocument(
name="Steve Rodger's career",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
chunk1 = DocumentChunk(
text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
entities = [chunk1, chunk2, chunk3]
await add_data_points(entities)
entity_type = EntityType(name="Person", description="A human individual")
entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")
entities = [entity]
await add_data_points(entities)
await _test_get_structured_graph_completion_cot()
await _test_get_structured_graph_completion()
await _test_get_structured_graph_completion_temporal()
await _test_get_structured_graph_completion_rag()
await _test_get_structured_graph_completion_context_extension()
await _test_get_structured_entity_completion()
@pytest.mark.asyncio
async def test_get_structured_completion(setup_test_environment):
"""Integration test: verify structured output completion for all retrievers."""
await _test_get_structured_graph_completion_cot()
await _test_get_structured_graph_completion()
await _test_get_structured_graph_completion_temporal()
await _test_get_structured_graph_completion_rag()
await _test_get_structured_graph_completion_context_extension()
await _test_get_structured_entity_completion()

View file

@ -0,0 +1,184 @@
import os
import pytest
import pathlib
import pytest_asyncio
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.tasks.summarization.models import TextSummary
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
@pytest_asyncio.fixture
async def setup_test_environment_with_summaries():
    """Set up a clean test environment with summaries.

    Builds six document chunks across two documents and one TextSummary per
    chunk; only the summaries are handed to add_data_points.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_summaries_retriever_context")
    data_directory_path = str(base_dir / ".data_storage/test_summaries_retriever_context")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()
    document1 = TextDocument(
        name="Employee List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    document2 = TextDocument(
        name="Car List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    # Each chunk gets an initials-style summary (e.g. "Steve Rodger" -> "S.R.").
    chunk1 = DocumentChunk(
        text="Steve Rodger",
        chunk_size=2,
        chunk_index=0,
        cut_type="sentence_end",
        is_part_of=document1,
        contains=[],
    )
    chunk1_summary = TextSummary(
        text="S.R.",
        made_from=chunk1,
    )
    chunk2 = DocumentChunk(
        text="Mike Broski",
        chunk_size=2,
        chunk_index=1,
        cut_type="sentence_end",
        is_part_of=document1,
        contains=[],
    )
    chunk2_summary = TextSummary(
        text="M.B.",
        made_from=chunk2,
    )
    chunk3 = DocumentChunk(
        text="Christina Mayer",
        chunk_size=2,
        chunk_index=2,
        cut_type="sentence_end",
        is_part_of=document1,
        contains=[],
    )
    chunk3_summary = TextSummary(
        text="C.M.",
        made_from=chunk3,
    )
    chunk4 = DocumentChunk(
        text="Range Rover",
        chunk_size=2,
        chunk_index=0,
        cut_type="sentence_end",
        is_part_of=document2,
        contains=[],
    )
    chunk4_summary = TextSummary(
        text="R.R.",
        made_from=chunk4,
    )
    chunk5 = DocumentChunk(
        text="Hyundai",
        chunk_size=2,
        chunk_index=1,
        cut_type="sentence_end",
        is_part_of=document2,
        contains=[],
    )
    chunk5_summary = TextSummary(
        text="H.Y.",
        made_from=chunk5,
    )
    chunk6 = DocumentChunk(
        text="Chrysler",
        chunk_size=2,
        chunk_index=2,
        cut_type="sentence_end",
        is_part_of=document2,
        contains=[],
    )
    chunk6_summary = TextSummary(
        text="C.H.",
        made_from=chunk6,
    )
    # Only the summaries are stored; the chunks travel along as `made_from`.
    entities = [
        chunk1_summary,
        chunk2_summary,
        chunk3_summary,
        chunk4_summary,
        chunk5_summary,
        chunk6_summary,
    ]
    await add_data_points(entities)
    yield
    # Best-effort teardown; a cleanup failure must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without summaries.

    Deliberately does NOT call setup(): no collections exist afterwards, and
    the test itself creates the empty TextSummary collection once it has
    verified the NoDataError path.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_summaries_retriever_context_empty")
    data_directory_path = str(base_dir / ".data_storage/test_summaries_retriever_context_empty")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    yield
    # Best-effort teardown; a cleanup failure must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest.mark.asyncio
async def test_summaries_retriever_context(setup_test_environment_with_summaries):
    """Integration test: SummariesRetriever returns the summary matching the query."""
    retriever = SummariesRetriever(top_k=20)
    summaries = await retriever.get_context("Christina")
    assert isinstance(summaries, list), "Context should be a list"
    assert len(summaries) > 0, "Context should not be empty"
    assert summaries[0]["text"] == "C.M.", "Failed to get Christina Mayer"
@pytest.mark.asyncio
async def test_summaries_retriever_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: SummariesRetriever behaves correctly when no data exists."""
    query = "Christina Mayer"
    retriever = SummariesRetriever()
    # Without a summaries collection, the retriever must signal missing data.
    with pytest.raises(NoDataError):
        await retriever.get_context(query)
    # After an empty collection is created, the context is simply empty.
    engine = get_vector_engine()
    await engine.create_collection("TextSummary_text", payload_schema=TextSummary)
    context = await retriever.get_context(query)
    assert context == [], "Returned context should be empty on an empty graph"

View file

@ -0,0 +1,306 @@
import os
import pytest
import pathlib
import pytest_asyncio
import cognee
from cognee.low_level import setup, DataPoint
from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
from cognee.modules.engine.models.Event import Event
from cognee.modules.engine.models.Timestamp import Timestamp
from cognee.modules.engine.models.Interval import Interval
@pytest_asyncio.fixture
async def setup_test_environment_with_events():
    """Set up a clean test environment with temporal events.

    Stores four events across 2021: three pinned to single timestamps (`at`)
    and one spanning an interval (`during`), so time-range and point-in-time
    queries both have material to match.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_with_events")
    data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_with_events")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()
    # Create timestamps for events
    timestamp1 = Timestamp(
        time_at=1609459200,  # 2021-01-01 00:00:00
        year=2021,
        month=1,
        day=1,
        hour=0,
        minute=0,
        second=0,
        timestamp_str="2021-01-01T00:00:00",
    )
    timestamp2 = Timestamp(
        time_at=1612137600,  # 2021-02-01 00:00:00
        year=2021,
        month=2,
        day=1,
        hour=0,
        minute=0,
        second=0,
        timestamp_str="2021-02-01T00:00:00",
    )
    timestamp3 = Timestamp(
        time_at=1614556800,  # 2021-03-01 00:00:00
        year=2021,
        month=3,
        day=1,
        hour=0,
        minute=0,
        second=0,
        timestamp_str="2021-03-01T00:00:00",
    )
    timestamp4 = Timestamp(
        time_at=1625097600,  # 2021-07-01 00:00:00
        year=2021,
        month=7,
        day=1,
        hour=0,
        minute=0,
        second=0,
        timestamp_str="2021-07-01T00:00:00",
    )
    timestamp5 = Timestamp(
        time_at=1633046400,  # 2021-10-01 00:00:00
        year=2021,
        month=10,
        day=1,
        hour=0,
        minute=0,
        second=0,
        timestamp_str="2021-10-01T00:00:00",
    )
    # Create interval for event spanning multiple timestamps
    interval1 = Interval(time_from=timestamp2, time_to=timestamp3)
    # Create events with timestamps
    event1 = Event(
        name="Project Alpha Launch",
        description="Launched Project Alpha at the beginning of 2021",
        at=timestamp1,
        location="San Francisco",
    )
    event2 = Event(
        name="Team Meeting",
        description="Monthly team meeting discussing Q1 goals",
        during=interval1,
        location="New York",
    )
    event3 = Event(
        name="Product Release",
        description="Released new product features in July",
        at=timestamp4,
        location="Remote",
    )
    event4 = Event(
        name="Company Retreat",
        description="Annual company retreat in October",
        at=timestamp5,
        location="Lake Tahoe",
    )
    entities = [event1, event2, event3, event4]
    await add_data_points(entities)
    yield
    # Best-effort teardown; a cleanup failure must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_with_graph_data():
    """Set up a clean test environment with graph data (for fallback to triplets).

    Contains no Event/Timestamp nodes — presumably forcing the temporal
    retriever down its plain triplet-search fallback path (see the fallback
    tests below).
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_with_graph")
    data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_with_graph")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Ad-hoc node types used only by this fixture's graph.
    class Company(DataPoint):
        name: str
        description: str

    class Person(DataPoint):
        name: str
        description: str
        works_for: Company

    company1 = Company(name="Figma", description="Figma is a company")
    person1 = Person(
        name="Steve Rodger",
        description="This is description about Steve Rodger",
        works_for=company1,
    )
    entities = [company1, person1]
    await add_data_points(entities)
    yield
    # Best-effort teardown; a cleanup failure must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without data.

    System tables are created via setup(), but nothing is stored — the graph
    stays empty for the negative test.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_empty")
    data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_empty")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()
    yield
    # Best-effort teardown; a cleanup failure must not mask the test outcome.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest.mark.asyncio
async def test_temporal_retriever_context_with_time_range(setup_test_environment_with_events):
    """Integration test: TemporalRetriever retrieves events inside a queried time range."""
    retriever = TemporalRetriever(top_k=5)
    context = await retriever.get_context("What happened in January 2021?")
    assert isinstance(context, str), "Context should be a string"
    assert len(context) > 0, "Context should not be empty"
    mentions_launch = "Project Alpha" in context or "Launch" in context
    assert mentions_launch, "Should retrieve Project Alpha Launch event from January 2021"
@pytest.mark.asyncio
async def test_temporal_retriever_context_with_single_time(setup_test_environment_with_events):
    """Integration test: TemporalRetriever retrieves the event at a specific point in time."""
    retriever = TemporalRetriever(top_k=5)
    context = await retriever.get_context("What happened in July 2021?")
    assert isinstance(context, str), "Context should be a string"
    assert len(context) > 0, "Context should not be empty"
    mentions_release = "Product Release" in context or "July" in context
    assert mentions_release, "Should retrieve Product Release event from July 2021"
@pytest.mark.asyncio
async def test_temporal_retriever_context_fallback_to_triplets(
    setup_test_environment_with_graph_data,
):
    """Integration test: TemporalRetriever uses triplet search when no time can be extracted."""
    retriever = TemporalRetriever(top_k=5)
    context = await retriever.get_context("Who works at Figma?")
    assert isinstance(context, str), "Context should be a string"
    assert len(context) > 0, "Context should not be empty"
    found_graph_data = "Steve" in context or "Figma" in context
    assert found_graph_data, "Should retrieve graph data via triplet search fallback"
@pytest.mark.asyncio
async def test_temporal_retriever_context_empty_graph(setup_test_environment_empty):
    """Integration test: an empty graph still yields a (possibly empty) string context."""
    retriever = TemporalRetriever()
    result = await retriever.get_context("What happened?")
    assert isinstance(result, str), "Context should be a string"
    assert len(result) >= 0, "Context should be a string (possibly empty)"
@pytest.mark.asyncio
async def test_temporal_retriever_get_completion(setup_test_environment_with_events):
    """Integration test: verify TemporalRetriever can generate completions."""
    retriever = TemporalRetriever()
    # get_completion exercises the full retrieval + completion path, so the
    # result is the final answer list rather than raw context.
    completion = await retriever.get_completion("What happened in January 2021?")
    assert isinstance(completion, list), "Completion should be a list"
    assert len(completion) > 0, "Completion should not be empty"
    assert all(isinstance(item, str) and item.strip() for item in completion), (
        "Completion items should be non-empty strings"
    )
@pytest.mark.asyncio
async def test_temporal_retriever_get_completion_fallback(setup_test_environment_with_graph_data):
    """Integration test: get_completion also works on the triplet-search fallback path."""
    retriever = TemporalRetriever()
    completion = await retriever.get_completion("Who works at Figma?")
    assert isinstance(completion, list), "Completion should be a list"
    assert len(completion) > 0, "Completion should not be empty"
    for item in completion:
        assert isinstance(item, str) and item.strip(), (
            "Completion items should be non-empty strings"
        )
@pytest.mark.asyncio
async def test_temporal_retriever_top_k_limit(setup_test_environment_with_events):
    """Integration test: verify TemporalRetriever respects top_k parameter."""
    retriever = TemporalRetriever(top_k=2)
    context = await retriever.get_context("What happened in 2021?")
    assert isinstance(context, str), "Context should be a string"
    # Presumably events are joined with this separator string, so top_k=2
    # events produce at most one separator — TODO confirm against the
    # TemporalRetriever context formatting.
    separator_count = context.count("#####################")
    assert separator_count <= 1, "Should respect top_k limit of 2 events"
@pytest.mark.asyncio
async def test_temporal_retriever_multiple_events(setup_test_environment_with_events):
    """Integration test: a broad yearly query surfaces at least one known event."""
    retriever = TemporalRetriever(top_k=10)
    context = await retriever.get_context("What events occurred in 2021?")
    assert isinstance(context, str), "Context should be a string"
    assert len(context) > 0, "Context should not be empty"
    known_events = (
        "Project Alpha",
        "Team Meeting",
        "Product Release",
        "Company Retreat",
    )
    assert any(event in context for event in known_events), (
        "Should retrieve at least one event from 2021"
    )

View file

@ -82,3 +82,38 @@ async def test_triplet_retriever_context_simple(setup_test_environment_with_trip
context = await retriever.get_context("Alice")
assert "Alice knows Bob" in context, "Failed to get Alice triplet"
assert isinstance(context, str), "Context should be a string"
assert len(context) > 0, "Context should not be empty"
@pytest.mark.asyncio
async def test_triplet_retriever_context_multiple_triplets(setup_test_environment_with_triplets):
    """Integration test: verify TripletRetriever can retrieve multiple triplets."""
    retriever = TripletRetriever(top_k=5)
    context = await retriever.get_context("Bob")
    bob_facts = ("Alice knows Bob", "Bob works at Tech Corp")
    assert any(fact in context for fact in bob_facts), (
        "Failed to get Bob-related triplets"
    )
@pytest.mark.asyncio
async def test_triplet_retriever_top_k_limit(setup_test_environment_with_triplets):
    """Integration test: TripletRetriever accepts top_k=1 and still returns a string."""
    retriever = TripletRetriever(top_k=1)
    result = await retriever.get_context("Alice")
    assert isinstance(result, str), "Context should be a string"
@pytest.mark.asyncio
async def test_triplet_retriever_context_empty(setup_test_environment_empty):
    """Integration test: verify TripletRetriever handles empty graph correctly."""
    # Re-run setup on top of the empty-environment fixture so the retriever
    # targets a freshly initialized but data-free store.
    await setup()
    retriever = TripletRetriever()
    # With no triplet data present the retriever is expected to raise.
    with pytest.raises(NoDataError):
        await retriever.get_context("Alice")

View file

@ -148,8 +148,8 @@ class TestCogneeServerStart(unittest.TestCase):
headers=headers,
files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
data={
"ontology_key": json.dumps([ontology_key]),
"description": json.dumps(["Test ontology"]),
"ontology_key": ontology_key,
"description": "Test ontology",
},
)
self.assertEqual(ontology_response.status_code, 200)

View file

@ -0,0 +1,76 @@
import os
import asyncio
import pathlib
from uuid import UUID
import cognee
from cognee.shared.logging_utils import setup_logging, ERROR
from cognee.modules.data.methods.delete_dataset import delete_dataset
from cognee.modules.data.methods.get_dataset import get_dataset
from cognee.modules.users.methods import get_default_user
async def main():
    """End-to-end check that deleting a dataset removes its per-dataset database files.

    Seeds two datasets, runs cognify, then for every resulting dataset id
    verifies the on-disk ``.lance.db`` vector store and ``.pkl`` graph store
    exist, deletes the dataset, and verifies both files are gone.
    """
    # Set data and system directory paths
    data_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_dataset_delete")
        ).resolve()
    )
    cognee.config.data_root_directory(data_directory_path)
    cognee_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_dataset_delete")
        ).resolve()
    )
    cognee.config.system_root_directory(cognee_directory_path)
    # Create a clean slate for cognee -- reset data and system state
    print("Resetting cognee data...")
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    print("Data reset complete.\n")
    # cognee knowledge graph will be created based on this text
    text = """
    Natural language processing (NLP) is an interdisciplinary
    subfield of computer science and information retrieval.
    """
    # Add the text, and make it available for cognify
    await cognee.add(text, "nlp_dataset")
    await cognee.add("Quantum computing is the study of quantum computers.", "quantum_dataset")
    # Use LLMs and cognee to create knowledge graph
    ret_val = await cognee.cognify()
    user = await get_default_user()
    for val in ret_val:
        dataset_id = str(val)
        # Per-dataset database files live under databases/<user_id>/<dataset_id>.*
        vector_db_path = os.path.join(
            cognee_directory_path, "databases", str(user.id), dataset_id + ".lance.db"
        )
        graph_db_path = os.path.join(
            cognee_directory_path, "databases", str(user.id), dataset_id + ".pkl"
        )
        # Check if databases are properly created and exist before deletion
        assert os.path.exists(graph_db_path), "Graph database file not found."
        assert os.path.exists(vector_db_path), "Vector database file not found."
        dataset = await get_dataset(user_id=user.id, dataset_id=UUID(dataset_id))
        await delete_dataset(dataset)
        # Confirm databases have been deleted
        assert not os.path.exists(graph_db_path), "Graph database file found."
        assert not os.path.exists(vector_db_path), "Vector database file found."
if __name__ == "__main__":
    logger = setup_logging(log_level=ERROR)
    # Run on an explicitly created loop so async generators can be shut down
    # cleanly in the finally block even if main() raises.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())

View file

@ -1,5 +1,10 @@
import pathlib
import os
import asyncio
import pytest
import pytest_asyncio
from collections import Counter
import cognee
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
@ -13,127 +18,172 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user
from collections import Counter
logger = get_logger()
async def main():
# This test runs for multiple db settings, to run this locally set the corresponding db envs
async def _reset_engines_and_prune() -> None:
    """Reset db engine caches and prune data/system.

    Kept intentionally identical to the inlined setup logic to avoid event loop issues when
    using deployed databases (Neo4j, PostgreSQL) and to ensure fresh instances per run.
    """
    # Dispose of existing engines and clear caches to ensure fresh instances for each test
    try:
        from cognee.infrastructure.databases.vector import get_vector_engine

        vector_engine = get_vector_engine()
        # Dispose SQLAlchemy engine connection pool if it exists
        if hasattr(vector_engine, "engine") and hasattr(vector_engine.engine, "dispose"):
            await vector_engine.engine.dispose(close=True)
    except Exception:
        # Engine might not exist yet
        pass

    from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
    from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
    from cognee.infrastructure.databases.relational.create_relational_engine import (
        create_relational_engine,
    )

    # Clearing the factory caches forces brand-new engine instances on next use.
    create_graph_engine.cache_clear()
    create_vector_engine.cache_clear()
    create_relational_engine.cache_clear()

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
dataset_name = "test_dataset"
async def _seed_default_dataset(dataset_name: str) -> dict:
    """Add the shared test dataset contents and run cognify (same steps/order as before).

    Seeds one short text and the Quantum_computers.txt file into
    ``dataset_name``, then runs cognify on that dataset only.
    """
    text_1 = """Germany is located in europe right next to the Netherlands"""
    logger.info(f"Adding text data to dataset: {dataset_name}")
    await cognee.add(text_1, dataset_name)
    explanation_file_path_quantum = os.path.join(
        pathlib.Path(__file__).parent, "test_data/Quantum_computers.txt"
    )
    logger.info(f"Adding file data to dataset: {dataset_name}")
    await cognee.add([explanation_file_path_quantum], dataset_name)
    logger.info(f"Running cognify on dataset: {dataset_name}")
    await cognee.cognify([dataset_name])
    # Return what was seeded so assertion-only tests can reference it.
    return {
        "dataset_name": dataset_name,
        "text_1": text_1,
        "explanation_file_path_quantum": explanation_file_path_quantum,
    }
@pytest.fixture(scope="session")
def event_loop():
    """Use a single asyncio event loop for this test module.

    This helps avoid "Future attached to a different loop" when running multiple async
    tests that share clients/engines.
    """
    loop = asyncio.new_event_loop()
    try:
        yield loop
    finally:
        # Always close the loop, even if a test tears down abnormally.
        loop.close()
async def setup_test_environment():
    """Helper function to set up test environment with data, cognify, and triplet embeddings."""
    # This test runs for multiple db settings, to run this locally set the corresponding db envs
    dataset_name = "test_dataset"
    logger.info("Starting test setup: pruning data and system")
    await _reset_engines_and_prune()
    state = await _seed_default_dataset(dataset_name=dataset_name)
    user = await get_default_user()

    from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings

    logger.info("Creating triplet embeddings")
    await create_triplet_embeddings(user=user, dataset=dataset_name, triplets_batch_size=5)
    # Check if Triplet_text collection was created
    vector_engine = get_vector_engine()
    has_collection = await vector_engine.has_collection(collection_name="Triplet_text")
    logger.info(f"Triplet_text collection exists after creation: {has_collection}")
    if has_collection:
        collection = await vector_engine.get_collection("Triplet_text")
        # count_rows is not available on every vector backend, hence the hasattr fallback.
        count = await collection.count_rows() if hasattr(collection, "count_rows") else "unknown"
        logger.info(f"Triplet_text collection row count: {count}")
    return state
async def setup_test_environment_for_feedback():
    """Helper function to set up test environment for feedback weight calculation test."""
    dataset_name = "test_dataset"
    # Fresh environment; unlike setup_test_environment this does not build
    # triplet embeddings.
    await _reset_engines_and_prune()
    return await _seed_default_dataset(dataset_name=dataset_name)
@pytest_asyncio.fixture(scope="session")
async def e2e_state():
"""Compute E2E artifacts once; tests only assert.
This avoids repeating expensive setup and LLM calls across multiple tests.
"""
await setup_test_environment()
# --- Graph/vector engine consistency ---
graph_engine = await get_graph_engine()
nodes, edges = await graph_engine.get_graph_data()
_nodes, edges = await graph_engine.get_graph_data()
vector_engine = get_vector_engine()
collection = await vector_engine.search(
query_text="Test", limit=None, collection_name="Triplet_text"
collection_name="Triplet_text", query_text="Test", limit=None
)
assert len(edges) == len(collection), (
f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection"
)
# --- Retriever contexts ---
query = "Next to which country is Germany located?"
context_gk = await GraphCompletionRetriever().get_context(
query="Next to which country is Germany located?"
)
context_gk_cot = await GraphCompletionCotRetriever().get_context(
query="Next to which country is Germany located?"
)
context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context(
query="Next to which country is Germany located?"
)
context_gk_sum = await GraphSummaryCompletionRetriever().get_context(
query="Next to which country is Germany located?"
)
context_triplet = await TripletRetriever().get_context(
query="Next to which country is Germany located?"
)
contexts = {
"graph_completion": await GraphCompletionRetriever().get_context(query=query),
"graph_completion_cot": await GraphCompletionCotRetriever().get_context(query=query),
"graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_context(
query=query
),
"graph_summary_completion": await GraphSummaryCompletionRetriever().get_context(
query=query
),
"chunks": await ChunksRetriever(top_k=5).get_context(query=query),
"summaries": await SummariesRetriever(top_k=5).get_context(query=query),
"rag_completion": await CompletionRetriever(top_k=3).get_context(query=query),
"temporal": await TemporalRetriever(top_k=5).get_context(query=query),
"triplet": await TripletRetriever().get_context(query=query),
}
for name, context in [
("GraphCompletionRetriever", context_gk),
("GraphCompletionCotRetriever", context_gk_cot),
("GraphCompletionContextExtensionRetriever", context_gk_ext),
("GraphSummaryCompletionRetriever", context_gk_sum),
]:
assert isinstance(context, list), f"{name}: Context should be a list"
assert len(context) > 0, f"{name}: Context should not be empty"
context_text = await resolve_edges_to_text(context)
lower = context_text.lower()
assert "germany" in lower or "netherlands" in lower, (
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
)
assert isinstance(context_triplet, str), "TripletRetriever: Context should be a string"
assert len(context_triplet) > 0, "TripletRetriever: Context should not be empty"
lower_triplet = context_triplet.lower()
assert "germany" in lower_triplet or "netherlands" in lower_triplet, (
f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
)
triplets_gk = await GraphCompletionRetriever().get_triplets(
query="Next to which country is Germany located?"
)
triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets(
query="Next to which country is Germany located?"
)
triplets_gk_ext = await GraphCompletionContextExtensionRetriever().get_triplets(
query="Next to which country is Germany located?"
)
triplets_gk_sum = await GraphSummaryCompletionRetriever().get_triplets(
query="Next to which country is Germany located?"
)
for name, triplets in [
("GraphCompletionRetriever", triplets_gk),
("GraphCompletionCotRetriever", triplets_gk_cot),
("GraphCompletionContextExtensionRetriever", triplets_gk_ext),
("GraphSummaryCompletionRetriever", triplets_gk_sum),
]:
assert isinstance(triplets, list), f"{name}: Triplets should be a list"
assert triplets, f"{name}: Triplets list should not be empty"
for edge in triplets:
assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
distance = edge.attributes.get("vector_distance")
node1_distance = edge.node1.attributes.get("vector_distance")
node2_distance = edge.node2.attributes.get("vector_distance")
assert isinstance(distance, float), (
f"{name}: vector_distance should be float, got {type(distance)}"
)
assert 0 <= distance <= 1, (
f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen"
)
assert 0 <= node1_distance <= 1, (
f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen"
)
assert 0 <= node2_distance <= 1, (
f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen"
)
# --- Retriever triplets + vector distance validation ---
triplets = {
"graph_completion": await GraphCompletionRetriever().get_triplets(query=query),
"graph_completion_cot": await GraphCompletionCotRetriever().get_triplets(query=query),
"graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_triplets(
query=query
),
"graph_summary_completion": await GraphSummaryCompletionRetriever().get_triplets(
query=query
),
}
# --- Search operations + graph side effects ---
completion_gk = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text="Where is germany located, next to which country?",
@ -164,6 +214,26 @@ async def main():
query_text="Next to which country is Germany located?",
save_interaction=True,
)
completion_chunks = await cognee.search(
query_type=SearchType.CHUNKS,
query_text="Germany",
save_interaction=False,
)
completion_summaries = await cognee.search(
query_type=SearchType.SUMMARIES,
query_text="Germany",
save_interaction=False,
)
completion_rag = await cognee.search(
query_type=SearchType.RAG_COMPLETION,
query_text="Next to which country is Germany located?",
save_interaction=False,
)
completion_temporal = await cognee.search(
query_type=SearchType.TEMPORAL,
query_text="Next to which country is Germany located?",
save_interaction=False,
)
await cognee.search(
query_type=SearchType.FEEDBACK,
@ -171,134 +241,217 @@ async def main():
last_k=1,
)
for name, search_results in [
("GRAPH_COMPLETION", completion_gk),
("GRAPH_COMPLETION_COT", completion_cot),
("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
("GRAPH_SUMMARY_COMPLETION", completion_sum),
("TRIPLET_COMPLETION", completion_triplet),
]:
assert isinstance(search_results, list), f"{name}: should return a list"
assert len(search_results) == 1, (
f"{name}: expected single-element list, got {len(search_results)}"
)
# Snapshot after all E2E operations above (used by assertion-only tests).
graph_snapshot = await (await get_graph_engine()).get_graph_data()
from cognee.context_global_variables import backend_access_control_enabled
return {
"graph_edges": edges,
"triplet_collection": collection,
"vector_collection_edges_count": len(collection),
"graph_edges_count": len(edges),
"contexts": contexts,
"triplets": triplets,
"search_results": {
"graph_completion": completion_gk,
"graph_completion_cot": completion_cot,
"graph_completion_context_extension": completion_ext,
"graph_summary_completion": completion_sum,
"triplet_completion": completion_triplet,
"chunks": completion_chunks,
"summaries": completion_summaries,
"rag_completion": completion_rag,
"temporal": completion_temporal,
},
"graph_snapshot": graph_snapshot,
}
if backend_access_control_enabled():
text = search_results[0]["search_result"][0]
else:
text = search_results[0]
assert isinstance(text, str), f"{name}: element should be a string"
assert text.strip(), f"{name}: string should not be empty"
assert "netherlands" in text.lower(), (
f"{name}: expected 'netherlands' in result, got: {text!r}"
)
graph_engine = await get_graph_engine()
graph = await graph_engine.get_graph_data()
type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0])
edge_type_counts = Counter(edge_type[2] for edge_type in graph[1])
# Assert there are exactly 4 CogneeUserInteraction nodes.
assert type_counts.get("CogneeUserInteraction", 0) == 4, (
f"Expected exactly four CogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}"
)
# Assert there is exactly two CogneeUserFeedback nodes.
assert type_counts.get("CogneeUserFeedback", 0) == 2, (
f"Expected exactly two CogneeUserFeedback nodes, but found {type_counts.get('CogneeUserFeedback', 0)}"
)
# Assert there is exactly two NodeSet.
assert type_counts.get("NodeSet", 0) == 2, (
f"Expected exactly two NodeSet nodes, but found {type_counts.get('NodeSet', 0)}"
)
# Assert that there are at least 10 'used_graph_element_to_answer' edges.
assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10, (
f"Expected at least ten 'used_graph_element_to_answer' edges, but found {edge_type_counts.get('used_graph_element_to_answer', 0)}"
)
# Assert that there are exactly 2 'gives_feedback_to' edges.
assert edge_type_counts.get("gives_feedback_to", 0) == 2, (
f"Expected exactly two 'gives_feedback_to' edges, but found {edge_type_counts.get('gives_feedback_to', 0)}"
)
# Assert that there are at least 6 'belongs_to_set' edges.
assert edge_type_counts.get("belongs_to_set", 0) == 6, (
f"Expected at least six 'belongs_to_set' edges, but found {edge_type_counts.get('belongs_to_set', 0)}"
)
nodes = graph[0]
required_fields_user_interaction = {"question", "answer", "context"}
required_fields_feedback = {"feedback", "sentiment"}
for node_id, data in nodes:
if data.get("type") == "CogneeUserInteraction":
assert required_fields_user_interaction.issubset(data.keys()), (
f"Node {node_id} is missing fields: {required_fields_user_interaction - set(data.keys())}"
)
for field in required_fields_user_interaction:
value = data[field]
assert isinstance(value, str) and value.strip(), (
f"Node {node_id} has invalid value for '{field}': {value!r}"
)
if data.get("type") == "CogneeUserFeedback":
assert required_fields_feedback.issubset(data.keys()), (
f"Node {node_id} is missing fields: {required_fields_feedback - set(data.keys())}"
)
for field in required_fields_feedback:
value = data[field]
assert isinstance(value, str) and value.strip(), (
f"Node {node_id} has invalid value for '{field}': {value!r}"
)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await cognee.add(text_1, dataset_name)
await cognee.add([text], dataset_name)
await cognee.cognify([dataset_name])
@pytest_asyncio.fixture(scope="session")
async def feedback_state():
    """Feedback-weight scenario computed once (fresh environment).

    Records one saved interaction followed by two positive feedback searches,
    then snapshots the graph for assertion-only tests.
    """
    await setup_test_environment_for_feedback()
    await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="Next to which country is Germany located?",
        save_interaction=True,
    )
    # NOTE(review): last_k=1 presumably scopes each feedback to the most
    # recent interaction — confirm against the FEEDBACK search implementation.
    await cognee.search(
        query_type=SearchType.FEEDBACK,
        query_text="This was the best answer I've ever seen",
        last_k=1,
    )
    await cognee.search(
        query_type=SearchType.FEEDBACK,
        query_text="Wow the correctness of this answer blows my mind",
        last_k=1,
    )
    graph_engine = await get_graph_engine()
    graph = await graph_engine.get_graph_data()
    return {"graph_snapshot": graph}
edges = graph[1]
for from_node, to_node, relationship_name, properties in edges:
@pytest.mark.asyncio
async def test_e2e_graph_vector_consistency(e2e_state):
    """Graph and vector stores contain the same triplet edges."""
    graph_count = e2e_state["graph_edges_count"]
    vector_count = e2e_state["vector_collection_edges_count"]
    assert graph_count == vector_count
@pytest.mark.asyncio
async def test_e2e_retriever_contexts(e2e_state):
    """All retrievers return non-empty, well-typed contexts."""
    contexts = e2e_state["contexts"]
    # Graph-based retrievers return edge lists that must mention the seeded facts.
    for name in [
        "graph_completion",
        "graph_completion_cot",
        "graph_completion_context_extension",
        "graph_summary_completion",
    ]:
        ctx = contexts[name]
        assert isinstance(ctx, list), f"{name}: Context should be a list"
        assert ctx, f"{name}: Context should not be empty"
        ctx_text = await resolve_edges_to_text(ctx)
        lower = ctx_text.lower()
        assert "germany" in lower or "netherlands" in lower, (
            f"{name}: Context did not contain 'germany' or 'netherlands'; got: {ctx!r}"
        )
    # Triplet retriever returns a plain string context.
    triplet_ctx = contexts["triplet"]
    assert isinstance(triplet_ctx, str), "triplet: Context should be a string"
    assert triplet_ctx.strip(), "triplet: Context should not be empty"
    # Chunks and summaries return lists of dict payloads with a "text" field.
    chunks_ctx = contexts["chunks"]
    assert isinstance(chunks_ctx, list), "chunks: Context should be a list"
    assert chunks_ctx, "chunks: Context should not be empty"
    chunks_text = "\n".join(str(item.get("text", "")) for item in chunks_ctx).lower()
    assert "germany" in chunks_text or "netherlands" in chunks_text
    summaries_ctx = contexts["summaries"]
    assert isinstance(summaries_ctx, list), "summaries: Context should be a list"
    assert summaries_ctx, "summaries: Context should not be empty"
    assert any(str(item.get("text", "")).strip() for item in summaries_ctx)
    # RAG and temporal retrievers both return string contexts.
    rag_ctx = contexts["rag_completion"]
    assert isinstance(rag_ctx, str), "rag_completion: Context should be a string"
    assert rag_ctx.strip(), "rag_completion: Context should not be empty"
    temporal_ctx = contexts["temporal"]
    assert isinstance(temporal_ctx, str), "temporal: Context should be a string"
    assert temporal_ctx.strip(), "temporal: Context should not be empty"
@pytest.mark.asyncio
async def test_e2e_retriever_triplets_have_vector_distances(e2e_state):
    """Graph retriever triplets include sane vector_distance metadata."""
    for name, triplets in e2e_state["triplets"].items():
        assert isinstance(triplets, list), f"{name}: Triplets should be a list"
        assert triplets, f"{name}: Triplets list should not be empty"
        for edge in triplets:
            assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
            # Both the edge and its two endpoint nodes carry a vector_distance
            # attribute; all of them must be in [0, 1].
            distance = edge.attributes.get("vector_distance")
            node1_distance = edge.node1.attributes.get("vector_distance")
            node2_distance = edge.node2.attributes.get("vector_distance")
            assert isinstance(distance, float), f"{name}: vector_distance should be float"
            assert 0 <= distance <= 1
            assert 0 <= node1_distance <= 1
            assert 0 <= node2_distance <= 1
@pytest.mark.asyncio
async def test_e2e_search_results_and_wrappers(e2e_state):
    """Search returns expected shapes across search types and access modes."""
    from cognee.context_global_variables import backend_access_control_enabled

    sr = e2e_state["search_results"]
    # Completion-like search types: validate wrapper + content
    for name in [
        "graph_completion",
        "graph_completion_cot",
        "graph_completion_context_extension",
        "graph_summary_completion",
        "triplet_completion",
        "rag_completion",
        "temporal",
    ]:
        search_results = sr[name]
        assert isinstance(search_results, list), f"{name}: should return a list"
        assert len(search_results) == 1, f"{name}: expected single-element list"
        if backend_access_control_enabled():
            # Access-control mode wraps the answer in a per-dataset dict.
            wrapper = search_results[0]
            assert isinstance(wrapper, dict), (
                f"{name}: expected wrapper dict in access control mode"
            )
            assert wrapper.get("dataset_id"), f"{name}: missing dataset_id in wrapper"
            assert wrapper.get("dataset_name") == "test_dataset"
            assert "graphs" in wrapper
            text = wrapper["search_result"][0]
        else:
            text = search_results[0]
        assert isinstance(text, str) and text.strip()
        assert "netherlands" in text.lower()
    # Non-LLM search types: CHUNKS / SUMMARIES validate payload list + text
    for name in ["chunks", "summaries"]:
        search_results = sr[name]
        assert isinstance(search_results, list), f"{name}: should return a list"
        assert search_results, f"{name}: should not be empty"
        first = search_results[0]
        assert isinstance(first, dict), f"{name}: expected dict entries"
        payloads = search_results
        # In access-control mode entries are wrappers; unwrap to the payload list.
        if "search_result" in first and "text" not in first:
            payloads = (first.get("search_result") or [None])[0]
        assert isinstance(payloads, list) and payloads
        assert isinstance(payloads[0], dict)
        assert str(payloads[0].get("text", "")).strip()
@pytest.mark.asyncio
async def test_e2e_graph_side_effects_and_node_fields(e2e_state):
    """Search interactions create expected graph nodes/edges and required fields."""
    graph = e2e_state["graph_snapshot"]
    nodes, edges = graph
    # Count node types and edge relationship names across the snapshot.
    type_counts = Counter(node_data[1].get("type", {}) for node_data in nodes)
    edge_type_counts = Counter(edge_type[2] for edge_type in edges)
    assert type_counts.get("CogneeUserInteraction", 0) == 4
    assert type_counts.get("CogneeUserFeedback", 0) == 2
    assert type_counts.get("NodeSet", 0) == 2
    assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10
    assert edge_type_counts.get("gives_feedback_to", 0) == 2
    assert edge_type_counts.get("belongs_to_set", 0) >= 6
    # Interaction/feedback nodes must carry non-empty string payload fields.
    required_fields_user_interaction = {"question", "answer", "context"}
    required_fields_feedback = {"feedback", "sentiment"}
    for node_id, data in nodes:
        if data.get("type") == "CogneeUserInteraction":
            assert required_fields_user_interaction.issubset(data.keys())
            for field in required_fields_user_interaction:
                value = data[field]
                assert isinstance(value, str) and value.strip()
        if data.get("type") == "CogneeUserFeedback":
            assert required_fields_feedback.issubset(data.keys())
            for field in required_fields_feedback:
                value = data[field]
                assert isinstance(value, str) and value.strip()
@pytest.mark.asyncio
async def test_e2e_feedback_weight_calculation(feedback_state):
    """Positive feedback increases used_graph_element_to_answer feedback_weight.

    The fixture records two positive feedback entries; every
    'used_graph_element_to_answer' edge must end up with feedback_weight >= 6.
    """
    _nodes, edges = feedback_state["graph_snapshot"]
    for _from_node, _to_node, relationship_name, properties in edges:
        if relationship_name == "used_graph_element_to_answer":
            # Message fixed: it previously read "more then 6" while the actual
            # check is >= 6 (i.e. at least 6).
            assert properties["feedback_weight"] >= 6, (
                "Feedback weight calculation is not correct, it should be at least 6."
            )
if __name__ == "__main__":
    import asyncio

    # NOTE(review): `main` does not appear to be defined in this module after
    # the pytest-fixture refactor — confirm this entry point still resolves,
    # or remove it in favor of running the file through pytest.
    asyncio.run(main())

View file

@ -1,17 +1,28 @@
import pytest
import uuid
from fastapi.testclient import TestClient
from unittest.mock import patch, Mock, AsyncMock
from unittest.mock import Mock
from types import SimpleNamespace
import importlib
from cognee.api.client import app
from cognee.modules.users.methods import get_authenticated_user
gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
@pytest.fixture(scope="session")
def test_client():
# Keep a single TestClient (and event loop) for the whole module.
# Re-creating TestClient repeatedly can break async DB connections (asyncpg loop mismatch).
with TestClient(app) as c:
yield c
@pytest.fixture
def client():
return TestClient(app)
def client(test_client, mock_default_user):
async def override_get_authenticated_user():
return mock_default_user
app.dependency_overrides[get_authenticated_user] = override_get_authenticated_user
yield test_client
app.dependency_overrides.pop(get_authenticated_user, None)
@pytest.fixture
@ -32,12 +43,8 @@ def mock_default_user():
)
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_upload_ontology_success(mock_get_default_user, client, mock_default_user):
def test_upload_ontology_success(client):
"""Test successful ontology upload"""
import json
mock_get_default_user.return_value = mock_default_user
ontology_content = (
b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
)
@ -46,7 +53,7 @@ def test_upload_ontology_success(mock_get_default_user, client, mock_default_use
response = client.post(
"/api/v1/ontologies",
files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
data={"ontology_key": json.dumps([unique_key]), "description": json.dumps(["Test"])},
data={"ontology_key": unique_key, "description": "Test"},
)
assert response.status_code == 200
@ -55,10 +62,8 @@ def test_upload_ontology_success(mock_get_default_user, client, mock_default_use
assert "uploaded_at" in data["uploaded_ontologies"][0]
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_default_user):
def test_upload_ontology_invalid_file(client):
"""Test 400 response for non-.owl files"""
mock_get_default_user.return_value = mock_default_user
unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
response = client.post(
"/api/v1/ontologies",
@ -68,14 +73,10 @@ def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_defaul
assert response.status_code == 400
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_upload_ontology_missing_data(mock_get_default_user, client, mock_default_user):
def test_upload_ontology_missing_data(client):
"""Test 400 response for missing file or key"""
import json
mock_get_default_user.return_value = mock_default_user
# Missing file
response = client.post("/api/v1/ontologies", data={"ontology_key": json.dumps(["test"])})
response = client.post("/api/v1/ontologies", data={"ontology_key": "test"})
assert response.status_code == 400
# Missing key
@ -85,34 +86,25 @@ def test_upload_ontology_missing_data(mock_get_default_user, client, mock_defaul
assert response.status_code == 400
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_default_user):
"""Test behavior when default user is provided (no explicit authentication)"""
import json
def test_upload_ontology_without_auth_header(client):
"""Test behavior when no explicit authentication header is provided."""
unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
mock_get_default_user.return_value = mock_default_user
response = client.post(
"/api/v1/ontologies",
files=[("ontology_file", ("test.owl", b"<rdf></rdf>", "application/xml"))],
data={"ontology_key": json.dumps([unique_key])},
data={"ontology_key": unique_key},
)
# The current system provides a default user when no explicit authentication is given
# This test verifies the system works with conditional authentication
assert response.status_code == 200
data = response.json()
assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key
assert "uploaded_at" in data["uploaded_ontologies"][0]
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_user):
"""Test uploading multiple ontology files in single request"""
def test_upload_multiple_ontologies_in_single_request_is_rejected(client):
"""Uploading multiple ontology files in a single request should fail."""
import io
mock_get_default_user.return_value = mock_default_user
# Create mock files
file1_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
file2_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
@ -120,45 +112,34 @@ def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_
("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")),
("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")),
]
data = {
"ontology_key": '["vehicles", "manufacturers"]',
"descriptions": '["Base vehicles", "Car manufacturers"]',
}
data = {"ontology_key": "vehicles", "description": "Base vehicles"}
response = client.post("/api/v1/ontologies", files=files, data=data)
assert response.status_code == 200
result = response.json()
assert "uploaded_ontologies" in result
assert len(result["uploaded_ontologies"]) == 2
assert result["uploaded_ontologies"][0]["ontology_key"] == "vehicles"
assert result["uploaded_ontologies"][1]["ontology_key"] == "manufacturers"
assert response.status_code == 400
assert "Only one ontology_file is allowed" in response.json()["error"]
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_default_user):
"""Test that upload endpoint accepts array parameters"""
def test_upload_endpoint_rejects_array_style_fields(client):
"""Array-style form values should be rejected (no backwards compatibility)."""
import io
import json
mock_get_default_user.return_value = mock_default_user
file_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
files = [("ontology_file", ("single.owl", io.BytesIO(file_content), "application/xml"))]
data = {
"ontology_key": json.dumps(["single_key"]),
"descriptions": json.dumps(["Single ontology"]),
"description": json.dumps(["Single ontology"]),
}
response = client.post("/api/v1/ontologies", files=files, data=data)
assert response.status_code == 200
result = response.json()
assert result["uploaded_ontologies"][0]["ontology_key"] == "single_key"
assert response.status_code == 400
assert "ontology_key must be a string" in response.json()["error"]
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_default_user):
def test_cognify_with_multiple_ontologies(client):
"""Test cognify endpoint accepts multiple ontology keys"""
payload = {
"datasets": ["test_dataset"],
@ -172,14 +153,11 @@ def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_de
assert response.status_code in [200, 400, 409] # May fail for other reasons, not type
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_complete_multifile_workflow(mock_get_default_user, client, mock_default_user):
"""Test complete workflow: upload multiple ontologies → cognify with multiple keys"""
def test_complete_multifile_workflow(client):
"""Test workflow: upload ontologies one-by-one → cognify with multiple keys"""
import io
import json
mock_get_default_user.return_value = mock_default_user
# Step 1: Upload multiple ontologies
# Step 1: Upload two ontologies (one-by-one)
file1_content = b"""<?xml version="1.0"?>
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns:owl="http://www.w3.org/2002/07/owl#">
@ -192,17 +170,21 @@ def test_complete_multifile_workflow(mock_get_default_user, client, mock_default
<owl:Class rdf:ID="Manufacturer"/>
</rdf:RDF>"""
files = [
("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")),
("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")),
]
data = {
"ontology_key": json.dumps(["vehicles", "manufacturers"]),
"descriptions": json.dumps(["Vehicle ontology", "Manufacturer ontology"]),
}
upload_response_1 = client.post(
"/api/v1/ontologies",
files=[("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml"))],
data={"ontology_key": "vehicles", "description": "Vehicle ontology"},
)
assert upload_response_1.status_code == 200
upload_response = client.post("/api/v1/ontologies", files=files, data=data)
assert upload_response.status_code == 200
upload_response_2 = client.post(
"/api/v1/ontologies",
files=[
("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml"))
],
data={"ontology_key": "manufacturers", "description": "Manufacturer ontology"},
)
assert upload_response_2.status_code == 200
# Step 2: Verify ontologies are listed
list_response = client.get("/api/v1/ontologies")
@ -223,44 +205,42 @@ def test_complete_multifile_workflow(mock_get_default_user, client, mock_default
assert cognify_response.status_code != 400 # Not a validation error
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_multifile_error_handling(mock_get_default_user, client, mock_default_user):
"""Test error handling for invalid multifile uploads"""
def test_upload_error_handling(client):
"""Test error handling for invalid uploads (single-file endpoint)."""
import io
import json
# Test mismatched array lengths
# Array-style key should be rejected
file_content = b"<rdf:RDF></rdf:RDF>"
files = [("ontology_file", ("test.owl", io.BytesIO(file_content), "application/xml"))]
data = {
"ontology_key": json.dumps(["key1", "key2"]), # 2 keys, 1 file
"descriptions": json.dumps(["desc1"]),
"ontology_key": json.dumps(["key1", "key2"]),
"description": "desc1",
}
response = client.post("/api/v1/ontologies", files=files, data=data)
assert response.status_code == 400
assert "Number of keys must match number of files" in response.json()["error"]
assert "ontology_key must be a string" in response.json()["error"]
# Test duplicate keys
files = [
("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml")),
("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml")),
]
data = {
"ontology_key": json.dumps(["duplicate", "duplicate"]),
"descriptions": json.dumps(["desc1", "desc2"]),
}
# Duplicate key should be rejected
response_1 = client.post(
"/api/v1/ontologies",
files=[("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml"))],
data={"ontology_key": "duplicate", "description": "desc1"},
)
assert response_1.status_code == 200
response = client.post("/api/v1/ontologies", files=files, data=data)
assert response.status_code == 400
assert "Duplicate ontology keys not allowed" in response.json()["error"]
response_2 = client.post(
"/api/v1/ontologies",
files=[("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml"))],
data={"ontology_key": "duplicate", "description": "desc2"},
)
assert response_2.status_code == 400
assert "already exists" in response_2.json()["error"]
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_cognify_missing_ontology_key(mock_get_default_user, client, mock_default_user):
def test_cognify_missing_ontology_key(client):
"""Test cognify with non-existent ontology key"""
mock_get_default_user.return_value = mock_default_user
payload = {
"datasets": ["test_dataset"],
"ontology_key": ["nonexistent_key"],

View file

@ -11,6 +11,22 @@ MOCK_JSONL_DATA = """\
{"id": "2", "question": "What is ML?", "answer": "Machine Learning", "paragraphs": [{"paragraph_text": "ML is a subset of AI."}]}
"""
MOCK_HOTPOT_CORPUS = [
{
"_id": "1",
"question": "Next to which country is Germany located?",
"answer": "Netherlands",
# HotpotQA uses "level"; TwoWikiMultiHop uses "type".
"level": "easy",
"type": "comparison",
"context": [
["Germany", ["Germany is in Europe."]],
["Netherlands", ["The Netherlands borders Germany."]],
],
"supporting_facts": [["Netherlands", 0]],
}
]
ADAPTER_CLASSES = [
HotpotQAAdapter,
@ -35,6 +51,11 @@ def test_adapter_can_instantiate_and_load(AdapterClass):
adapter = AdapterClass()
result = adapter.load_corpus()
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
adapter = AdapterClass()
result = adapter.load_corpus()
else:
adapter = AdapterClass()
result = adapter.load_corpus()
@ -64,6 +85,10 @@ def test_adapter_returns_some_content(AdapterClass):
):
adapter = AdapterClass()
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
adapter = AdapterClass()
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
else:
adapter = AdapterClass()
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)

View file

@ -2,15 +2,38 @@ import pytest
from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor
from cognee.infrastructure.databases.graph import get_graph_engine
from unittest.mock import AsyncMock, patch
from cognee.eval_framework.benchmark_adapters.hotpot_qa_adapter import HotpotQAAdapter
benchmark_options = ["HotPotQA", "Dummy", "TwoWikiMultiHop"]
MOCK_HOTPOT_CORPUS = [
{
"_id": "1",
"question": "Next to which country is Germany located?",
"answer": "Netherlands",
# HotpotQA uses "level"; TwoWikiMultiHop uses "type".
"level": "easy",
"type": "comparison",
"context": [
["Germany", ["Germany is in Europe."]],
["Netherlands", ["The Netherlands borders Germany."]],
],
"supporting_facts": [["Netherlands", 0]],
}
]
@pytest.mark.parametrize("benchmark", benchmark_options)
def test_corpus_builder_load_corpus(benchmark):
limit = 2
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
else:
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
assert len(questions) <= 2, (
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
@ -22,8 +45,14 @@ def test_corpus_builder_load_corpus(benchmark):
@patch.object(CorpusBuilderExecutor, "run_cognee", new_callable=AsyncMock)
async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
limit = 2
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
questions = await corpus_builder.build_corpus(limit=limit)
if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
questions = await corpus_builder.build_corpus(limit=limit)
else:
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
questions = await corpus_builder.build_corpus(limit=limit)
assert len(questions) <= 2, (
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
)

View file

@ -1,7 +1,7 @@
[project]
name = "cognee"
version = "0.5.0.dev0"
version = "0.5.0.dev1"
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
authors = [
{ name = "Vasilije Markovic" },

4
uv.lock generated
View file

@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = ">=3.10, <3.14"
resolution-markers = [
"python_full_version >= '3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'",
@ -946,7 +946,7 @@ wheels = [
[[package]]
name = "cognee"
version = "0.5.0.dev0"
version = "0.5.0.dev1"
source = { editable = "." }
dependencies = [
{ name = "aiofiles" },