diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml
index cb69e9ef6..8cd62910c 100644
--- a/.github/workflows/e2e_tests.yml
+++ b/.github/workflows/e2e_tests.yml
@@ -237,6 +237,31 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_dataset_database_handler.py
+ test-dataset-database-deletion:
+ name: Test dataset database deletion in Cognee
+ runs-on: ubuntu-22.04
+ steps:
+ - name: Check out repository
+ uses: actions/checkout@v4
+
+ - name: Cognee Setup
+ uses: ./.github/actions/cognee_setup
+ with:
+ python-version: '3.11.x'
+
+ - name: Run dataset databases deletion test
+ env:
+ ENV: 'dev'
+ LLM_MODEL: ${{ secrets.LLM_MODEL }}
+ LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+ LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+ LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+ EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+ EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+ EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+ EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
+ run: uv run python ./cognee/tests/test_dataset_delete.py
+
test-permissions:
name: Test permissions with different situations in Cognee
runs-on: ubuntu-22.04
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
new file mode 100644
index 000000000..ff2f809f3
--- /dev/null
+++ b/.github/workflows/release.yml
@@ -0,0 +1,154 @@
+name: release.yml
+on:
+ workflow_dispatch:
+ inputs:
+ flavour:
+ required: true
+ default: dev
+ type: choice
+ options:
+ - dev
+ - main
+ description: Dev or Main release
+ test_mode:
+ required: true
+ type: boolean
+ description: Aka Dry Run. If true, it won't affect public indices or repositories
+
+jobs:
+ release-github:
+ name: Create GitHub Release from ${{ inputs.flavour }}
+ outputs:
+ tag: ${{ steps.create_tag.outputs.tag }}
+ version: ${{ steps.create_tag.outputs.version }}
+ permissions:
+ contents: write
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Check out ${{ inputs.flavour }}
+ uses: actions/checkout@v4
+ with:
+ ref: ${{ inputs.flavour }}
+ - name: Install uv
+ uses: astral-sh/setup-uv@v7
+
+ - name: Create and push git tag
+ id: create_tag
+ env:
+ TEST_MODE: ${{ inputs.test_mode }}
+ run: |
+ VERSION="$(uv version --short)"
+ TAG="v${VERSION}"
+
+ echo "Tag to create: ${TAG}"
+
+ git config user.name "github-actions[bot]"
+ git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
+
+ echo "tag=${TAG}" >> "$GITHUB_OUTPUT"
+ echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
+
+ if [ "$TEST_MODE" = "false" ]; then
+ git tag "${TAG}"
+ git push origin "${TAG}"
+ else
+ echo "Test mode is enabled. Skipping tag creation and push."
+ fi
+
+ - name: Create GitHub Release
+ uses: softprops/action-gh-release@v2
+ with:
+ tag_name: ${{ steps.create_tag.outputs.tag }}
+ generate_release_notes: true
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+
+ release-pypi-package:
+ needs: release-github
+ name: Release PyPI Package from ${{ inputs.flavour }}
+ permissions:
+ contents: read
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Check out ${{ inputs.flavour }}
+ uses: actions/checkout@v4
+ with:
+ ref: ${{ inputs.flavour }}
+
+ - name: Install uv
+ uses: astral-sh/setup-uv@v7
+
+ - name: Install Python
+ run: uv python install
+
+ - name: Install dependencies
+ run: uv sync --locked --all-extras
+
+ - name: Build distributions
+ run: uv build
+
+ - name: Publish ${{ inputs.flavour }} release to TestPyPI
+ if: ${{ inputs.test_mode }}
+ env:
+ UV_PUBLISH_TOKEN: ${{ secrets.TEST_PYPI_TOKEN }}
+ run: uv publish --publish-url https://test.pypi.org/legacy/
+
+ - name: Publish ${{ inputs.flavour }} release to PyPI
+ if: ${{ !inputs.test_mode }}
+ env:
+ UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
+ run: uv publish
+
+ release-docker-image:
+ needs: release-github
+ name: Release Docker Image from ${{ inputs.flavour }}
+ permissions:
+ contents: read
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Check out ${{ inputs.flavour }}
+ uses: actions/checkout@v4
+ with:
+ ref: ${{ inputs.flavour }}
+
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3
+
+ - name: Log in to Docker Hub
+ uses: docker/login-action@v3
+ with:
+ username: ${{ secrets.DOCKER_USERNAME }}
+ password: ${{ secrets.DOCKER_PASSWORD }}
+
+ - name: Build and push Dev Docker Image
+ if: ${{ inputs.flavour == 'dev' }}
+ uses: docker/build-push-action@v5
+ with:
+ context: .
+ platforms: linux/amd64,linux/arm64
+ push: ${{ !inputs.test_mode }}
+ tags: cognee/cognee:${{ needs.release-github.outputs.version }}
+ labels: |
+ version=${{ needs.release-github.outputs.version }}
+ flavour=${{ inputs.flavour }}
+ cache-from: type=registry,ref=cognee/cognee:buildcache
+ cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max
+
+ - name: Build and push Main Docker Image
+ if: ${{ inputs.flavour == 'main' }}
+ uses: docker/build-push-action@v5
+ with:
+ context: .
+ platforms: linux/amd64,linux/arm64
+ push: ${{ !inputs.test_mode }}
+ tags: |
+ cognee/cognee:${{ needs.release-github.outputs.version }}
+ cognee/cognee:latest
+ labels: |
+ version=${{ needs.release-github.outputs.version }}
+ flavour=${{ inputs.flavour }}
+ cache-from: type=registry,ref=cognee/cognee:buildcache
+ cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max
diff --git a/.github/workflows/search_db_tests.yml b/.github/workflows/search_db_tests.yml
index 118c1c06c..f0c7817cd 100644
--- a/.github/workflows/search_db_tests.yml
+++ b/.github/workflows/search_db_tests.yml
@@ -11,12 +11,21 @@ on:
type: string
default: "all"
description: "Which vector databases to test (comma-separated list or 'all')"
+ python-versions:
+ required: false
+ type: string
+ default: '["3.10", "3.11", "3.12", "3.13"]'
+ description: "Python versions to test (JSON array)"
jobs:
run-kuzu-lance-sqlite-search-tests:
- name: Search test for Kuzu/LanceDB/Sqlite
+ name: Search test for Kuzu/LanceDB/Sqlite (Python ${{ matrix.python-version }})
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'kuzu/lance/sqlite') }}
+ strategy:
+ matrix:
+ python-version: ${{ fromJSON(inputs.python-versions) }}
+ fail-fast: false
steps:
- name: Check out
uses: actions/checkout@v4
@@ -26,7 +35,7 @@ jobs:
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
- python-version: ${{ inputs.python-version }}
+ python-version: ${{ matrix.python-version }}
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
@@ -45,13 +54,16 @@ jobs:
GRAPH_DATABASE_PROVIDER: 'kuzu'
VECTOR_DB_PROVIDER: 'lancedb'
DB_PROVIDER: 'sqlite'
- run: uv run python ./cognee/tests/test_search_db.py
+ run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO
run-neo4j-lance-sqlite-search-tests:
- name: Search test for Neo4j/LanceDB/Sqlite
+ name: Search test for Neo4j/LanceDB/Sqlite (Python ${{ matrix.python-version }})
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }}
-
+ strategy:
+ matrix:
+ python-version: ${{ fromJSON(inputs.python-versions) }}
+ fail-fast: false
steps:
- name: Check out
uses: actions/checkout@v4
@@ -61,7 +73,7 @@ jobs:
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
- python-version: ${{ inputs.python-version }}
+ python-version: ${{ matrix.python-version }}
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
@@ -88,12 +100,16 @@ jobs:
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
- run: uv run python ./cognee/tests/test_search_db.py
+ run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO
run-kuzu-pgvector-postgres-search-tests:
- name: Search test for Kuzu/PGVector/Postgres
+ name: Search test for Kuzu/PGVector/Postgres (Python ${{ matrix.python-version }})
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'kuzu/pgvector/postgres') }}
+ strategy:
+ matrix:
+ python-version: ${{ fromJSON(inputs.python-versions) }}
+ fail-fast: false
services:
postgres:
image: pgvector/pgvector:pg17
@@ -117,7 +133,7 @@ jobs:
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
- python-version: ${{ inputs.python-version }}
+ python-version: ${{ matrix.python-version }}
extra-dependencies: "postgres"
- name: Dependencies already installed
@@ -143,12 +159,16 @@ jobs:
DB_PORT: 5432
DB_USERNAME: cognee
DB_PASSWORD: cognee
- run: uv run python ./cognee/tests/test_search_db.py
+ run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO
run-neo4j-pgvector-postgres-search-tests:
- name: Search test for Neo4j/PGVector/Postgres
+ name: Search test for Neo4j/PGVector/Postgres (Python ${{ matrix.python-version }})
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/pgvector/postgres') }}
+ strategy:
+ matrix:
+ python-version: ${{ fromJSON(inputs.python-versions) }}
+ fail-fast: false
services:
postgres:
image: pgvector/pgvector:pg17
@@ -172,7 +192,7 @@ jobs:
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
- python-version: ${{ inputs.python-version }}
+ python-version: ${{ matrix.python-version }}
extra-dependencies: "postgres"
- name: Setup Neo4j with GDS
@@ -205,4 +225,4 @@ jobs:
DB_PORT: 5432
DB_USERNAME: cognee
DB_PASSWORD: cognee
- run: uv run python ./cognee/tests/test_search_db.py
+ run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO
diff --git a/.github/workflows/test_llms.yml b/.github/workflows/test_llms.yml
index 6b0221309..8f9d30d10 100644
--- a/.github/workflows/test_llms.yml
+++ b/.github/workflows/test_llms.yml
@@ -84,3 +84,93 @@ jobs:
EMBEDDING_DIMENSIONS: "3072"
EMBEDDING_MAX_TOKENS: "8191"
run: uv run python ./examples/python/simple_example.py
+
+ test-bedrock-api-key:
+ name: Run Bedrock API Key Test
+ runs-on: ubuntu-22.04
+ steps:
+ - name: Check out repository
+ uses: actions/checkout@v4
+
+ - name: Cognee Setup
+ uses: ./.github/actions/cognee_setup
+ with:
+ python-version: '3.11.x'
+ extra-dependencies: "aws"
+
+ - name: Run Bedrock API Key Simple Example
+ env:
+ LLM_PROVIDER: "bedrock"
+ LLM_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
+ LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
+ LLM_MAX_TOKENS: "16384"
+ AWS_REGION_NAME: "eu-west-1"
+ EMBEDDING_PROVIDER: "bedrock"
+ EMBEDDING_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
+ EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
+ EMBEDDING_DIMENSIONS: "1024"
+ EMBEDDING_MAX_TOKENS: "8191"
+ run: uv run python ./examples/python/simple_example.py
+
+ test-bedrock-aws-credentials:
+ name: Run Bedrock AWS Credentials Test
+ runs-on: ubuntu-22.04
+ steps:
+ - name: Check out repository
+ uses: actions/checkout@v4
+
+ - name: Cognee Setup
+ uses: ./.github/actions/cognee_setup
+ with:
+ python-version: '3.11.x'
+ extra-dependencies: "aws"
+
+ - name: Run Bedrock AWS Credentials Simple Example
+ env:
+ LLM_PROVIDER: "bedrock"
+ LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
+ LLM_MAX_TOKENS: "16384"
+ AWS_REGION_NAME: "eu-west-1"
+ AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
+ AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+ EMBEDDING_PROVIDER: "bedrock"
+ EMBEDDING_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
+ EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
+ EMBEDDING_DIMENSIONS: "1024"
+ EMBEDDING_MAX_TOKENS: "8191"
+ run: uv run python ./examples/python/simple_example.py
+
+ test-bedrock-aws-profile:
+ name: Run Bedrock AWS Profile Test
+ runs-on: ubuntu-22.04
+ steps:
+ - name: Check out repository
+ uses: actions/checkout@v4
+
+ - name: Cognee Setup
+ uses: ./.github/actions/cognee_setup
+ with:
+ python-version: '3.11.x'
+ extra-dependencies: "aws"
+
+ - name: Configure AWS Profile
+ run: |
+ mkdir -p ~/.aws
+ cat > ~/.aws/credentials << EOF
+ [bedrock-test]
+ aws_access_key_id = ${{ secrets.AWS_ACCESS_KEY_ID }}
+ aws_secret_access_key = ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+ EOF
+
+ - name: Run Bedrock AWS Profile Simple Example
+ env:
+ LLM_PROVIDER: "bedrock"
+ LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
+ LLM_MAX_TOKENS: "16384"
+ AWS_PROFILE_NAME: "bedrock-test"
+ AWS_REGION_NAME: "eu-west-1"
+ EMBEDDING_PROVIDER: "bedrock"
+ EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
+ EMBEDDING_DIMENSIONS: "1024"
+ EMBEDDING_MAX_TOKENS: "8191"
+ run: uv run python ./examples/python/simple_example.py
\ No newline at end of file
diff --git a/cognee-mcp/src/test_client.py b/cognee-mcp/src/test_client.py
index 23160d8b2..bce7f807f 100755
--- a/cognee-mcp/src/test_client.py
+++ b/cognee-mcp/src/test_client.py
@@ -3,7 +3,7 @@
Test client for Cognee MCP Server functionality.
This script tests all the tools and functions available in the Cognee MCP server,
-including cognify, codify, search, prune, status checks, and utility functions.
+including cognify, search, prune, status checks, and utility functions.
Usage:
# Set your OpenAI API key first
@@ -23,6 +23,7 @@ import tempfile
import time
from contextlib import asynccontextmanager
from cognee.shared.logging_utils import setup_logging
+from logging import ERROR, INFO
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
@@ -35,7 +36,7 @@ from src.server import (
load_class,
)
-# Set timeout for cognify/codify to complete in
+# Set timeout for cognify to complete in
TIMEOUT = 5 * 60 # 5 min in seconds
@@ -151,12 +152,9 @@ DEBUG = True
expected_tools = {
"cognify",
- "codify",
"search",
"prune",
"cognify_status",
- "codify_status",
- "cognee_add_developer_rules",
"list_data",
"delete",
}
@@ -247,106 +245,6 @@ DEBUG = True
}
print(f"โ {test_name} test failed: {e}")
- async def test_codify(self):
- """Test the codify functionality using MCP client."""
- print("\n๐งช Testing codify functionality...")
- try:
- async with self.mcp_server_session() as session:
- codify_result = await session.call_tool(
- "codify", arguments={"repo_path": self.test_repo_dir}
- )
-
- start = time.time() # mark the start
- while True:
- try:
- # Wait a moment
- await asyncio.sleep(5)
-
- # Check if codify processing is finished
- status_result = await session.call_tool("codify_status", arguments={})
- if hasattr(status_result, "content") and status_result.content:
- status_text = (
- status_result.content[0].text
- if status_result.content
- else str(status_result)
- )
- else:
- status_text = str(status_result)
-
- if str(PipelineRunStatus.DATASET_PROCESSING_COMPLETED) in status_text:
- break
- elif time.time() - start > TIMEOUT:
- raise TimeoutError("Codify did not complete in 5min")
- except DatabaseNotCreatedError:
- if time.time() - start > TIMEOUT:
- raise TimeoutError("Database was not created in 5min")
-
- self.test_results["codify"] = {
- "status": "PASS",
- "result": codify_result,
- "message": "Codify executed successfully",
- }
- print("โ
Codify test passed")
-
- except Exception as e:
- self.test_results["codify"] = {
- "status": "FAIL",
- "error": str(e),
- "message": "Codify test failed",
- }
- print(f"โ Codify test failed: {e}")
-
- async def test_cognee_add_developer_rules(self):
- """Test the cognee_add_developer_rules functionality using MCP client."""
- print("\n๐งช Testing cognee_add_developer_rules functionality...")
- try:
- async with self.mcp_server_session() as session:
- result = await session.call_tool(
- "cognee_add_developer_rules", arguments={"base_path": self.test_data_dir}
- )
-
- start = time.time() # mark the start
- while True:
- try:
- # Wait a moment
- await asyncio.sleep(5)
-
- # Check if developer rule cognify processing is finished
- status_result = await session.call_tool("cognify_status", arguments={})
- if hasattr(status_result, "content") and status_result.content:
- status_text = (
- status_result.content[0].text
- if status_result.content
- else str(status_result)
- )
- else:
- status_text = str(status_result)
-
- if str(PipelineRunStatus.DATASET_PROCESSING_COMPLETED) in status_text:
- break
- elif time.time() - start > TIMEOUT:
- raise TimeoutError(
- "Cognify of developer rules did not complete in 5min"
- )
- except DatabaseNotCreatedError:
- if time.time() - start > TIMEOUT:
- raise TimeoutError("Database was not created in 5min")
-
- self.test_results["cognee_add_developer_rules"] = {
- "status": "PASS",
- "result": result,
- "message": "Developer rules addition executed successfully",
- }
- print("โ
Developer rules test passed")
-
- except Exception as e:
- self.test_results["cognee_add_developer_rules"] = {
- "status": "FAIL",
- "error": str(e),
- "message": "Developer rules test failed",
- }
- print(f"โ Developer rules test failed: {e}")
-
async def test_search_functionality(self):
"""Test the search functionality with different search types using MCP client."""
print("\n๐งช Testing search functionality...")
@@ -359,7 +257,11 @@ DEBUG = True
# Go through all Cognee search types
for search_type in SearchType:
# Don't test these search types
- if search_type in [SearchType.NATURAL_LANGUAGE, SearchType.CYPHER]:
+ if search_type in [
+ SearchType.NATURAL_LANGUAGE,
+ SearchType.CYPHER,
+ SearchType.TRIPLET_COMPLETION,
+ ]:
break
try:
async with self.mcp_server_session() as session:
@@ -681,9 +583,6 @@ class TestModel:
test_name="Cognify2",
)
- await self.test_codify()
- await self.test_cognee_add_developer_rules()
-
# Test list_data and delete functionality
await self.test_list_data()
await self.test_delete()
@@ -739,7 +638,5 @@ async def main():
if __name__ == "__main__":
- from logging import ERROR
-
logger = setup_logging(log_level=ERROR)
asyncio.run(main())
diff --git a/cognee/api/v1/add/add.py b/cognee/api/v1/add/add.py
index 1ea4caca4..90ea32ae7 100644
--- a/cognee/api/v1/add/add.py
+++ b/cognee/api/v1/add/add.py
@@ -155,7 +155,7 @@ async def add(
- LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)
Optional:
- - LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral"
+ - LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral", "bedrock"
- LLM_MODEL: Model name (default: "gpt-5-mini")
- DEFAULT_USER_EMAIL: Custom default user email
- DEFAULT_USER_PASSWORD: Custom default user password
diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py
index 9862edd49..ffc903d68 100644
--- a/cognee/api/v1/cognify/cognify.py
+++ b/cognee/api/v1/cognify/cognify.py
@@ -53,6 +53,7 @@ async def cognify(
custom_prompt: Optional[str] = None,
temporal_cognify: bool = False,
data_per_batch: int = 20,
+ **kwargs,
):
"""
Transform ingested data into a structured knowledge graph.
@@ -223,6 +224,7 @@ async def cognify(
config=config,
custom_prompt=custom_prompt,
chunks_per_batch=chunks_per_batch,
+ **kwargs,
)
# By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
@@ -251,6 +253,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
config: Config = None,
custom_prompt: Optional[str] = None,
chunks_per_batch: int = 100,
+ **kwargs,
) -> list[Task]:
if config is None:
ontology_config = get_ontology_env_config()
@@ -288,6 +291,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
config=config,
custom_prompt=custom_prompt,
task_config={"batch_size": chunks_per_batch},
+ **kwargs,
), # Generate knowledge graphs from the document chunks.
Task(
summarize_text,
diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py
index 4f1497e3c..a499b3ca3 100644
--- a/cognee/api/v1/cognify/routers/get_cognify_router.py
+++ b/cognee/api/v1/cognify/routers/get_cognify_router.py
@@ -42,7 +42,9 @@ class CognifyPayloadDTO(InDTO):
default="", description="Custom prompt for entity extraction and graph generation"
)
ontology_key: Optional[List[str]] = Field(
- default=None, description="Reference to one or more previously uploaded ontologies"
+ default=None,
+ examples=[[]],
+ description="Reference to one or more previously uploaded ontologies",
)
diff --git a/cognee/api/v1/datasets/routers/get_datasets_router.py b/cognee/api/v1/datasets/routers/get_datasets_router.py
index eff87b3af..ca738dfbe 100644
--- a/cognee/api/v1/datasets/routers/get_datasets_router.py
+++ b/cognee/api/v1/datasets/routers/get_datasets_router.py
@@ -208,14 +208,14 @@ def get_datasets_router() -> APIRouter:
},
)
- from cognee.modules.data.methods import get_dataset, delete_dataset
+ from cognee.modules.data.methods import delete_dataset
- dataset = await get_dataset(user.id, dataset_id)
+ dataset = await get_authorized_existing_datasets([dataset_id], "delete", user)
if dataset is None:
raise DatasetNotFoundError(message=f"Dataset ({str(dataset_id)}) not found.")
- await delete_dataset(dataset)
+ await delete_dataset(dataset[0])
@router.delete(
"/{dataset_id}/data/{data_id}",
diff --git a/cognee/api/v1/ontologies/routers/get_ontology_router.py b/cognee/api/v1/ontologies/routers/get_ontology_router.py
index ee31c683f..77667d88d 100644
--- a/cognee/api/v1/ontologies/routers/get_ontology_router.py
+++ b/cognee/api/v1/ontologies/routers/get_ontology_router.py
@@ -1,4 +1,4 @@
-from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException
+from fastapi import APIRouter, File, Form, UploadFile, Depends, Request
from fastapi.responses import JSONResponse
from typing import Optional, List
@@ -15,28 +15,25 @@ def get_ontology_router() -> APIRouter:
@router.post("", response_model=dict)
async def upload_ontology(
+ request: Request,
ontology_key: str = Form(...),
- ontology_file: List[UploadFile] = File(...),
- descriptions: Optional[str] = Form(None),
+ ontology_file: UploadFile = File(...),
+ description: Optional[str] = Form(None),
user: User = Depends(get_authenticated_user),
):
"""
- Upload ontology files with their respective keys for later use in cognify operations.
-
- Supports both single and multiple file uploads:
- - Single file: ontology_key=["key"], ontology_file=[file]
- - Multiple files: ontology_key=["key1", "key2"], ontology_file=[file1, file2]
+ Upload a single ontology file for later use in cognify operations.
## Request Parameters
- - **ontology_key** (str): JSON array string of user-defined identifiers for the ontologies
- - **ontology_file** (List[UploadFile]): OWL format ontology files
- - **descriptions** (Optional[str]): JSON array string of optional descriptions
+ - **ontology_key** (str): User-defined identifier for the ontology.
+ - **ontology_file** (UploadFile): Single OWL format ontology file
+ - **description** (Optional[str]): Optional description for the ontology.
## Response
- Returns metadata about uploaded ontologies including keys, filenames, sizes, and upload timestamps.
+ Returns metadata about the uploaded ontology including key, filename, size, and upload timestamp.
## Error Codes
- - **400 Bad Request**: Invalid file format, duplicate keys, array length mismatches, file size exceeded
+ - **400 Bad Request**: Invalid file format, duplicate key, multiple files uploaded
- **500 Internal Server Error**: File system or processing errors
"""
send_telemetry(
@@ -49,16 +46,22 @@ def get_ontology_router() -> APIRouter:
)
try:
- import json
+ # Enforce: exactly one uploaded file for "ontology_file"
+ form = await request.form()
+ uploaded_files = form.getlist("ontology_file")
+ if len(uploaded_files) != 1:
+ raise ValueError("Only one ontology_file is allowed")
- ontology_keys = json.loads(ontology_key)
- description_list = json.loads(descriptions) if descriptions else None
+ if ontology_key.strip().startswith(("[", "{")):
+ raise ValueError("ontology_key must be a string")
+ if description is not None and description.strip().startswith(("[", "{")):
+ raise ValueError("description must be a string")
- if not isinstance(ontology_keys, list):
- raise ValueError("ontology_key must be a JSON array")
-
- results = await ontology_service.upload_ontologies(
- ontology_keys, ontology_file, user, description_list
+ result = await ontology_service.upload_ontology(
+ ontology_key=ontology_key,
+ file=ontology_file,
+ user=user,
+ description=description,
)
return {
@@ -70,10 +73,9 @@ def get_ontology_router() -> APIRouter:
"uploaded_at": result.uploaded_at,
"description": result.description,
}
- for result in results
]
}
- except (json.JSONDecodeError, ValueError) as e:
+ except ValueError as e:
return JSONResponse(status_code=400, content={"error": str(e)})
except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)})
diff --git a/cognee/infrastructure/databases/utils/__init__.py b/cognee/infrastructure/databases/utils/__init__.py
index f31d1e0dc..3907b4325 100644
--- a/cognee/infrastructure/databases/utils/__init__.py
+++ b/cognee/infrastructure/databases/utils/__init__.py
@@ -1,2 +1,4 @@
from .get_or_create_dataset_database import get_or_create_dataset_database
from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info
+from .get_graph_dataset_database_handler import get_graph_dataset_database_handler
+from .get_vector_dataset_database_handler import get_vector_dataset_database_handler
diff --git a/cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py b/cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py
new file mode 100644
index 000000000..d88685b48
--- /dev/null
+++ b/cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py
@@ -0,0 +1,10 @@
+from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
+
+
+def get_graph_dataset_database_handler(dataset_database: DatasetDatabase) -> dict:
+ from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
+ supported_dataset_database_handlers,
+ )
+
+ handler = supported_dataset_database_handlers[dataset_database.graph_dataset_database_handler]
+ return handler
diff --git a/cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py b/cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py
new file mode 100644
index 000000000..5d1152c04
--- /dev/null
+++ b/cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py
@@ -0,0 +1,10 @@
+from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
+
+
+def get_vector_dataset_database_handler(dataset_database: DatasetDatabase) -> dict:
+ from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
+ supported_dataset_database_handlers,
+ )
+
+ handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler]
+ return handler
diff --git a/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py b/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py
index d33169642..561268eaf 100644
--- a/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py
+++ b/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py
@@ -1,24 +1,12 @@
+from cognee.infrastructure.databases.utils.get_graph_dataset_database_handler import (
+ get_graph_dataset_database_handler,
+)
+from cognee.infrastructure.databases.utils.get_vector_dataset_database_handler import (
+ get_vector_dataset_database_handler,
+)
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
-async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
- from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
- supported_dataset_database_handlers,
- )
-
- handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler]
- return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
-
-
-async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
- from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
- supported_dataset_database_handlers,
- )
-
- handler = supported_dataset_database_handlers[dataset_database.graph_dataset_database_handler]
- return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
-
-
async def resolve_dataset_database_connection_info(
dataset_database: DatasetDatabase,
) -> DatasetDatabase:
@@ -31,6 +19,12 @@ async def resolve_dataset_database_connection_info(
Returns:
DatasetDatabase instance with resolved connection info
"""
- dataset_database = await _get_vector_db_connection_info(dataset_database)
- dataset_database = await _get_graph_db_connection_info(dataset_database)
+ vector_dataset_database_handler = get_vector_dataset_database_handler(dataset_database)
+ graph_dataset_database_handler = get_graph_dataset_database_handler(dataset_database)
+ dataset_database = await vector_dataset_database_handler[
+ "handler_instance"
+ ].resolve_dataset_connection_info(dataset_database)
+ dataset_database = await graph_dataset_database_handler[
+ "handler_instance"
+ ].resolve_dataset_connection_info(dataset_database)
return dataset_database
diff --git a/cognee/infrastructure/files/storage/s3_config.py b/cognee/infrastructure/files/storage/s3_config.py
index cefe5cd2f..4cc6b1d63 100644
--- a/cognee/infrastructure/files/storage/s3_config.py
+++ b/cognee/infrastructure/files/storage/s3_config.py
@@ -9,6 +9,8 @@ class S3Config(BaseSettings):
aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None
aws_session_token: Optional[str] = None
+ aws_profile_name: Optional[str] = None
+ aws_bedrock_runtime_endpoint: Optional[str] = None
model_config = SettingsConfigDict(env_file=".env", extra="allow")
diff --git a/cognee/infrastructure/llm/LLMGateway.py b/cognee/infrastructure/llm/LLMGateway.py
index ab5bb35d7..7bec9ca01 100644
--- a/cognee/infrastructure/llm/LLMGateway.py
+++ b/cognee/infrastructure/llm/LLMGateway.py
@@ -11,7 +11,7 @@ class LLMGateway:
@staticmethod
def acreate_structured_output(
- text_input: str, system_prompt: str, response_model: Type[BaseModel]
+ text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> Coroutine:
llm_config = get_llm_config()
if llm_config.structured_output_framework.upper() == "BAML":
@@ -31,7 +31,10 @@ class LLMGateway:
llm_client = get_llm_client()
return llm_client.acreate_structured_output(
- text_input=text_input, system_prompt=system_prompt, response_model=response_model
+ text_input=text_input,
+ system_prompt=system_prompt,
+ response_model=response_model,
+ **kwargs,
)
@staticmethod
diff --git a/cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py b/cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py
index 59e6f563a..4a40979f4 100644
--- a/cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py
+++ b/cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py
@@ -10,7 +10,7 @@ from cognee.infrastructure.llm.config import (
async def extract_content_graph(
- content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None
+ content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None, **kwargs
):
if custom_prompt:
system_prompt = custom_prompt
@@ -30,7 +30,7 @@ async def extract_content_graph(
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
content_graph = await LLMGateway.acreate_structured_output(
- content, system_prompt, response_model
+ content, system_prompt, response_model, **kwargs
)
return content_graph
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
index b6f218022..58b68436c 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
@@ -52,7 +52,7 @@ class AnthropicAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
- self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+ self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from a user query.
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py
new file mode 100644
index 000000000..ad7cdf994
--- /dev/null
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py
@@ -0,0 +1,5 @@
+"""Bedrock LLM adapter module."""
+
+from .adapter import BedrockAdapter
+
+__all__ = ["BedrockAdapter"]
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py
new file mode 100644
index 000000000..1faec2d0b
--- /dev/null
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py
@@ -0,0 +1,153 @@
+import litellm
+import instructor
+from typing import Type
+from pydantic import BaseModel
+from litellm.exceptions import ContentPolicyViolationError
+from instructor.exceptions import InstructorRetryException
+
+from cognee.infrastructure.llm.LLMGateway import LLMGateway
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
+ LLMInterface,
+)
+from cognee.infrastructure.llm.exceptions import (
+ ContentPolicyFilterError,
+ MissingSystemPromptPathError,
+)
+from cognee.infrastructure.files.storage.s3_config import get_s3_config
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
+ rate_limit_async,
+ rate_limit_sync,
+ sleep_and_retry_async,
+ sleep_and_retry_sync,
+)
+from cognee.modules.observability.get_observe import get_observe
+
+observe = get_observe()
+
+
+class BedrockAdapter(LLMInterface):
+ """
+ Adapter for AWS Bedrock API with support for three authentication methods:
+ 1. API Key (Bearer Token)
+ 2. AWS Credentials (access key + secret key)
+ 3. AWS Profile (boto3 credential chain)
+ """
+
+ name = "Bedrock"
+ model: str
+ api_key: str
+ default_instructor_mode = "json_schema_mode"
+
+ MAX_RETRIES = 5
+
+ def __init__(
+ self,
+ model: str,
+ api_key: str = None,
+ max_completion_tokens: int = 16384,
+ streaming: bool = False,
+ instructor_mode: str = None,
+ ):
+ self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
+ self.aclient = instructor.from_litellm(
+ litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
+ )
+ self.client = instructor.from_litellm(litellm.completion)
+ self.model = model
+ self.api_key = api_key
+ self.max_completion_tokens = max_completion_tokens
+ self.streaming = streaming
+
+ def _create_bedrock_request(
+ self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+ ) -> dict:
+ """Create Bedrock request with authentication."""
+
+ request_params = {
+ "model": self.model,
+ "custom_llm_provider": "bedrock",
+ "drop_params": True,
+ "messages": [
+ {"role": "user", "content": text_input},
+ {"role": "system", "content": system_prompt},
+ ],
+ "response_model": response_model,
+ "max_retries": self.MAX_RETRIES,
+ "max_completion_tokens": self.max_completion_tokens,
+ "stream": self.streaming,
+ }
+
+ s3_config = get_s3_config()
+
+ # Add authentication parameters
+ if self.api_key:
+ request_params["api_key"] = self.api_key
+ elif s3_config.aws_access_key_id and s3_config.aws_secret_access_key:
+ request_params["aws_access_key_id"] = s3_config.aws_access_key_id
+ request_params["aws_secret_access_key"] = s3_config.aws_secret_access_key
+ if s3_config.aws_session_token:
+ request_params["aws_session_token"] = s3_config.aws_session_token
+ elif s3_config.aws_profile_name:
+ request_params["aws_profile_name"] = s3_config.aws_profile_name
+
+ if s3_config.aws_region:
+ request_params["aws_region_name"] = s3_config.aws_region
+
+ # Add optional parameters
+ if s3_config.aws_bedrock_runtime_endpoint:
+ request_params["aws_bedrock_runtime_endpoint"] = s3_config.aws_bedrock_runtime_endpoint
+
+ return request_params
+
+ @observe(as_type="generation")
+ @sleep_and_retry_async()
+ @rate_limit_async
+ async def acreate_structured_output(
+ self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+ ) -> BaseModel:
+ """Generate structured output from AWS Bedrock API."""
+
+ try:
+ request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
+ return await self.aclient.chat.completions.create(**request_params)
+
+ except (
+ ContentPolicyViolationError,
+ InstructorRetryException,
+ ) as error:
+ if (
+ isinstance(error, InstructorRetryException)
+ and "content management policy" not in str(error).lower()
+ ):
+ raise error
+
+ raise ContentPolicyFilterError(
+ f"The provided input contains content that is not aligned with our content policy: {text_input}"
+ )
+
+ @observe
+ @sleep_and_retry_sync()
+ @rate_limit_sync
+ def create_structured_output(
+ self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+ ) -> BaseModel:
+ """Generate structured output from AWS Bedrock API (synchronous)."""
+
+ request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
+ return self.client.chat.completions.create(**request_params)
+
+ def show_prompt(self, text_input: str, system_prompt: str) -> str:
+ """Format and display the prompt for a user query."""
+ if not text_input:
+ text_input = "No user input provided."
+ if not system_prompt:
+ raise MissingSystemPromptPathError()
+ system_prompt = LLMGateway.read_query_prompt(system_prompt)
+
+ formatted_prompt = (
+ f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
+ if system_prompt
+ else None
+ )
+ return formatted_prompt
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
index a8fcebbee..208c3729d 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
@@ -80,7 +80,7 @@ class GeminiAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
- self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+ self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from a user query.
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py
index 9beb702e5..d6e00d40a 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py
@@ -80,7 +80,7 @@ class GenericAPIAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
- self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+ self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from a user query.
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
index 39558f36d..954d85c1d 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
@@ -24,6 +24,7 @@ class LLMProvider(Enum):
- CUSTOM: Represents a custom provider option.
- GEMINI: Represents the Gemini provider.
- MISTRAL: Represents the Mistral AI provider.
+ - BEDROCK: Represents the AWS Bedrock provider.
"""
OPENAI = "openai"
@@ -32,6 +33,7 @@ class LLMProvider(Enum):
CUSTOM = "custom"
GEMINI = "gemini"
MISTRAL = "mistral"
+ BEDROCK = "bedrock"
def get_llm_client(raise_api_key_error: bool = True):
@@ -154,7 +156,7 @@ def get_llm_client(raise_api_key_error: bool = True):
)
elif provider == LLMProvider.MISTRAL:
- if llm_config.llm_api_key is None:
+ if llm_config.llm_api_key is None and raise_api_key_error:
raise LLMAPIKeyNotSetError()
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
@@ -169,5 +171,21 @@ def get_llm_client(raise_api_key_error: bool = True):
instructor_mode=llm_config.llm_instructor_mode.lower(),
)
+ elif provider == LLMProvider.BEDROCK:
+ # if llm_config.llm_api_key is None and raise_api_key_error:
+ # raise LLMAPIKeyNotSetError()
+
+ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.bedrock.adapter import (
+ BedrockAdapter,
+ )
+
+ return BedrockAdapter(
+ model=llm_config.llm_model,
+ api_key=llm_config.llm_api_key,
+ max_completion_tokens=max_completion_tokens,
+ streaming=llm_config.llm_streaming,
+ instructor_mode=llm_config.llm_instructor_mode.lower(),
+ )
+
else:
raise UnsupportedLLMProviderError(provider)
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
index e9580faeb..e1131524d 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
@@ -69,7 +69,7 @@ class MistralAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
- self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+ self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from the user query.
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py
index 877da23ef..211e49694 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py
@@ -76,7 +76,7 @@ class OllamaAPIAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
- self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+ self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a structured output from the LLM using the provided text and system prompt.
@@ -123,7 +123,7 @@ class OllamaAPIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
- async def create_transcript(self, input_file: str) -> str:
+ async def create_transcript(self, input_file: str, **kwargs) -> str:
"""
Generate an audio transcript from a user query.
@@ -162,7 +162,7 @@ class OllamaAPIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
- async def transcribe_image(self, input_file: str) -> str:
+ async def transcribe_image(self, input_file: str, **kwargs) -> str:
"""
Transcribe content from an image using base64 encoding.
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
index 407b720a8..ca9b583b7 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
@@ -112,7 +112,7 @@ class OpenAIAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
- self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+ self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from a user query.
@@ -154,6 +154,7 @@ class OpenAIAdapter(LLMInterface):
api_version=self.api_version,
response_model=response_model,
max_retries=self.MAX_RETRIES,
+ **kwargs,
)
except (
ContentFilterFinishReasonError,
@@ -180,6 +181,7 @@ class OpenAIAdapter(LLMInterface):
# api_base=self.fallback_endpoint,
response_model=response_model,
max_retries=self.MAX_RETRIES,
+ **kwargs,
)
except (
ContentFilterFinishReasonError,
@@ -205,7 +207,7 @@ class OpenAIAdapter(LLMInterface):
reraise=True,
)
def create_structured_output(
- self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+ self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from a user query.
@@ -245,6 +247,7 @@ class OpenAIAdapter(LLMInterface):
api_version=self.api_version,
response_model=response_model,
max_retries=self.MAX_RETRIES,
+ **kwargs,
)
@retry(
@@ -254,7 +257,7 @@ class OpenAIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
- async def create_transcript(self, input):
+ async def create_transcript(self, input, **kwargs):
"""
Generate an audio transcript from a user query.
@@ -281,6 +284,7 @@ class OpenAIAdapter(LLMInterface):
api_base=self.endpoint,
api_version=self.api_version,
max_retries=self.MAX_RETRIES,
+ **kwargs,
)
return transcription
@@ -292,7 +296,7 @@ class OpenAIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
- async def transcribe_image(self, input) -> BaseModel:
+ async def transcribe_image(self, input, **kwargs) -> BaseModel:
"""
Generate a transcription of an image from a user query.
@@ -337,4 +341,5 @@ class OpenAIAdapter(LLMInterface):
api_version=self.api_version,
max_completion_tokens=300,
max_retries=self.MAX_RETRIES,
+ **kwargs,
)
diff --git a/cognee/modules/data/deletion/prune_system.py b/cognee/modules/data/deletion/prune_system.py
index 645e1a223..22a0fde5f 100644
--- a/cognee/modules/data/deletion/prune_system.py
+++ b/cognee/modules/data/deletion/prune_system.py
@@ -5,6 +5,10 @@ from cognee.context_global_variables import backend_access_control_enabled
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine
+from cognee.infrastructure.databases.utils import (
+ get_graph_dataset_database_handler,
+ get_vector_dataset_database_handler,
+)
from cognee.shared.cache import delete_cache
from cognee.modules.users.models import DatasetDatabase
from cognee.shared.logging_utils import get_logger
@@ -13,22 +17,13 @@ logger = get_logger()
async def prune_graph_databases():
- async def _prune_graph_db(dataset_database: DatasetDatabase) -> dict:
- from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
- supported_dataset_database_handlers,
- )
-
- handler = supported_dataset_database_handlers[
- dataset_database.graph_dataset_database_handler
- ]
- return await handler["handler_instance"].delete_dataset(dataset_database)
-
db_engine = get_relational_engine()
try:
- data = await db_engine.get_all_data_from_table("dataset_database")
+ dataset_databases = await db_engine.get_all_data_from_table("dataset_database")
# Go through each dataset database and delete the graph database
- for data_item in data:
- await _prune_graph_db(data_item)
+ for dataset_database in dataset_databases:
+ handler = get_graph_dataset_database_handler(dataset_database)
+ await handler["handler_instance"].delete_dataset(dataset_database)
except (OperationalError, EntityNotFoundError) as e:
logger.debug(
"Skipping pruning of graph DB. Error when accessing dataset_database table: %s",
@@ -38,22 +33,13 @@ async def prune_graph_databases():
async def prune_vector_databases():
- async def _prune_vector_db(dataset_database: DatasetDatabase) -> dict:
- from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
- supported_dataset_database_handlers,
- )
-
- handler = supported_dataset_database_handlers[
- dataset_database.vector_dataset_database_handler
- ]
- return await handler["handler_instance"].delete_dataset(dataset_database)
-
db_engine = get_relational_engine()
try:
- data = await db_engine.get_all_data_from_table("dataset_database")
+ dataset_databases = await db_engine.get_all_data_from_table("dataset_database")
# Go through each dataset database and delete the vector database
- for data_item in data:
- await _prune_vector_db(data_item)
+ for dataset_database in dataset_databases:
+ handler = get_vector_dataset_database_handler(dataset_database)
+ await handler["handler_instance"].delete_dataset(dataset_database)
except (OperationalError, EntityNotFoundError) as e:
logger.debug(
"Skipping pruning of vector DB. Error when accessing dataset_database table: %s",
diff --git a/cognee/modules/data/methods/delete_dataset.py b/cognee/modules/data/methods/delete_dataset.py
index ff20ff9e7..dea10e741 100644
--- a/cognee/modules/data/methods/delete_dataset.py
+++ b/cognee/modules/data/methods/delete_dataset.py
@@ -1,8 +1,34 @@
+from cognee.modules.users.models import DatasetDatabase
+from sqlalchemy import select
+
from cognee.modules.data.models import Dataset
+from cognee.infrastructure.databases.utils.get_vector_dataset_database_handler import (
+ get_vector_dataset_database_handler,
+)
+from cognee.infrastructure.databases.utils.get_graph_dataset_database_handler import (
+ get_graph_dataset_database_handler,
+)
from cognee.infrastructure.databases.relational import get_relational_engine
async def delete_dataset(dataset: Dataset):
db_engine = get_relational_engine()
+ async with db_engine.get_async_session() as session:
+ stmt = select(DatasetDatabase).where(
+ DatasetDatabase.dataset_id == dataset.id,
+ )
+ dataset_database: DatasetDatabase = await session.scalar(stmt)
+ if dataset_database:
+ graph_dataset_database_handler = get_graph_dataset_database_handler(dataset_database)
+ vector_dataset_database_handler = get_vector_dataset_database_handler(dataset_database)
+ await graph_dataset_database_handler["handler_instance"].delete_dataset(
+ dataset_database
+ )
+ await vector_dataset_database_handler["handler_instance"].delete_dataset(
+ dataset_database
+ )
+ # TODO: Remove dataset from pipeline_run_status in Data objects related to dataset as well
+ # This blocks recreation of the dataset with the same name and data after deletion as
+ # it's marked as completed and will be just skipped even though it's empty.
return await db_engine.delete_entity_by_id(dataset.__tablename__, dataset.id)
diff --git a/cognee/modules/retrieval/triplet_retriever.py b/cognee/modules/retrieval/triplet_retriever.py
index d251d113a..b9d006312 100644
--- a/cognee/modules/retrieval/triplet_retriever.py
+++ b/cognee/modules/retrieval/triplet_retriever.py
@@ -36,7 +36,7 @@ class TripletRetriever(BaseRetriever):
"""Initialize retriever with optional custom prompt paths."""
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
- self.top_k = top_k if top_k is not None else 1
+ self.top_k = top_k if top_k is not None else 5
self.system_prompt = system_prompt
async def get_context(self, query: str) -> str:
diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py
index bd412e0ca..a70fa661b 100644
--- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py
+++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py
@@ -16,24 +16,6 @@ logger = get_logger(level=ERROR)
def format_triplets(edges):
- print("\n\n\n")
-
- def filter_attributes(obj, attributes):
- """Helper function to filter out non-None properties, including nested dicts."""
- result = {}
- for attr in attributes:
- value = getattr(obj, attr, None)
- if value is not None:
- # If the value is a dict, extract relevant keys from it
- if isinstance(value, dict):
- nested_values = {
- k: v for k, v in value.items() if k in attributes and v is not None
- }
- result[attr] = nested_values
- else:
- result[attr] = value
- return result
-
triplets = []
for edge in edges:
node1 = edge.node1
diff --git a/cognee/modules/settings/get_settings.py b/cognee/modules/settings/get_settings.py
index 071bcca36..4132ba048 100644
--- a/cognee/modules/settings/get_settings.py
+++ b/cognee/modules/settings/get_settings.py
@@ -16,6 +16,7 @@ class ModelName(Enum):
anthropic = "anthropic"
gemini = "gemini"
mistral = "mistral"
+ bedrock = "bedrock"
class LLMConfig(BaseModel):
@@ -77,6 +78,10 @@ def get_settings() -> SettingsDict:
"value": "mistral",
"label": "Mistral",
},
+ {
+ "value": "bedrock",
+ "label": "Bedrock",
+ },
]
return SettingsDict.model_validate(
@@ -157,6 +162,20 @@ def get_settings() -> SettingsDict:
"label": "Mistral Large 2.1",
},
],
+ "bedrock": [
+ {
+ "value": "eu.anthropic.claude-sonnet-4-5-20250929-v1:0",
+ "label": "Claude 4.5 Sonnet",
+ },
+ {
+ "value": "eu.anthropic.claude-haiku-4-5-20251001-v1:0",
+ "label": "Claude 4.5 Haiku",
+ },
+ {
+ "value": "eu.amazon.nova-lite-v1:0",
+ "label": "Amazon Nova Lite",
+ },
+ ],
},
},
vector_db={
diff --git a/cognee/tasks/graph/extract_graph_from_data.py b/cognee/tasks/graph/extract_graph_from_data.py
index 2d1eca17e..5b762d40c 100644
--- a/cognee/tasks/graph/extract_graph_from_data.py
+++ b/cognee/tasks/graph/extract_graph_from_data.py
@@ -97,6 +97,7 @@ async def extract_graph_from_data(
graph_model: Type[BaseModel],
config: Config = None,
custom_prompt: Optional[str] = None,
+ **kwargs,
) -> List[DocumentChunk]:
"""
Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
@@ -111,7 +112,7 @@ async def extract_graph_from_data(
chunk_graphs = await asyncio.gather(
*[
- extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt)
+ extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt, **kwargs)
for chunk in data_chunks
]
)
diff --git a/cognee/tasks/memify/get_triplet_datapoints.py b/cognee/tasks/memify/get_triplet_datapoints.py
index bfc02ec6a..764adfb63 100644
--- a/cognee/tasks/memify/get_triplet_datapoints.py
+++ b/cognee/tasks/memify/get_triplet_datapoints.py
@@ -1,5 +1,6 @@
from typing import AsyncGenerator, Dict, Any, List, Optional
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
+from cognee.modules.engine.utils import generate_node_id
from cognee.shared.logging_utils import get_logger
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.infrastructure.engine import DataPoint
@@ -155,7 +156,12 @@ def _process_single_triplet(
embeddable_text = f"{start_node_text}-โบ{relationship_text}-โบ{end_node_text}".strip()
- triplet_obj = Triplet(from_node_id=start_node_id, to_node_id=end_node_id, text=embeddable_text)
+ relationship_name = relationship.get("relationship_name", "")
+ triplet_id = generate_node_id(str(start_node_id) + str(relationship_name) + str(end_node_id))
+
+ triplet_obj = Triplet(
+ id=triplet_id, from_node_id=start_node_id, to_node_id=end_node_id, text=embeddable_text
+ )
return triplet_obj, None
diff --git a/cognee/tests/integration/retrieval/test_chunks_retriever.py b/cognee/tests/integration/retrieval/test_chunks_retriever.py
new file mode 100644
index 000000000..d2e5e6149
--- /dev/null
+++ b/cognee/tests/integration/retrieval/test_chunks_retriever.py
@@ -0,0 +1,252 @@
+import os
+import pytest
+import pathlib
+import pytest_asyncio
+from typing import List
+import cognee
+
+from cognee.low_level import setup
+from cognee.tasks.storage import add_data_points
+from cognee.infrastructure.databases.vector import get_vector_engine
+from cognee.modules.chunking.models import DocumentChunk
+from cognee.modules.data.processing.document_types import TextDocument
+from cognee.modules.retrieval.exceptions.exceptions import NoDataError
+from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
+from cognee.infrastructure.engine import DataPoint
+from cognee.modules.data.processing.document_types import Document
+from cognee.modules.engine.models import Entity
+
+
+class DocumentChunkWithEntities(DataPoint):
+ text: str
+ chunk_size: int
+ chunk_index: int
+ cut_type: str
+ is_part_of: Document
+ contains: List[Entity] = None
+
+ metadata: dict = {"index_fields": ["text"]}
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_with_chunks_simple():
+ """Set up a clean test environment with simple chunks."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_simple")
+ data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_simple")
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ document = TextDocument(
+ name="Steve Rodger's career",
+ raw_data_location="somewhere",
+ external_metadata="",
+ mime_type="text/plain",
+ )
+
+ chunk1 = DocumentChunk(
+ text="Steve Rodger",
+ chunk_size=2,
+ chunk_index=0,
+ cut_type="sentence_end",
+ is_part_of=document,
+ contains=[],
+ )
+ chunk2 = DocumentChunk(
+ text="Mike Broski",
+ chunk_size=2,
+ chunk_index=1,
+ cut_type="sentence_end",
+ is_part_of=document,
+ contains=[],
+ )
+ chunk3 = DocumentChunk(
+ text="Christina Mayer",
+ chunk_size=2,
+ chunk_index=2,
+ cut_type="sentence_end",
+ is_part_of=document,
+ contains=[],
+ )
+
+ entities = [chunk1, chunk2, chunk3]
+
+ await add_data_points(entities)
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_with_chunks_complex():
+ """Set up a clean test environment with complex chunks."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_complex")
+ data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_complex")
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ document1 = TextDocument(
+ name="Employee List",
+ raw_data_location="somewhere",
+ external_metadata="",
+ mime_type="text/plain",
+ )
+
+ document2 = TextDocument(
+ name="Car List",
+ raw_data_location="somewhere",
+ external_metadata="",
+ mime_type="text/plain",
+ )
+
+ chunk1 = DocumentChunk(
+ text="Steve Rodger",
+ chunk_size=2,
+ chunk_index=0,
+ cut_type="sentence_end",
+ is_part_of=document1,
+ contains=[],
+ )
+ chunk2 = DocumentChunk(
+ text="Mike Broski",
+ chunk_size=2,
+ chunk_index=1,
+ cut_type="sentence_end",
+ is_part_of=document1,
+ contains=[],
+ )
+ chunk3 = DocumentChunk(
+ text="Christina Mayer",
+ chunk_size=2,
+ chunk_index=2,
+ cut_type="sentence_end",
+ is_part_of=document1,
+ contains=[],
+ )
+
+ chunk4 = DocumentChunk(
+ text="Range Rover",
+ chunk_size=2,
+ chunk_index=0,
+ cut_type="sentence_end",
+ is_part_of=document2,
+ contains=[],
+ )
+ chunk5 = DocumentChunk(
+ text="Hyundai",
+ chunk_size=2,
+ chunk_index=1,
+ cut_type="sentence_end",
+ is_part_of=document2,
+ contains=[],
+ )
+ chunk6 = DocumentChunk(
+ text="Chrysler",
+ chunk_size=2,
+ chunk_index=2,
+ cut_type="sentence_end",
+ is_part_of=document2,
+ contains=[],
+ )
+
+ entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]
+
+ await add_data_points(entities)
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_empty():
+ """Set up a clean test environment without chunks."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_empty")
+ data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_empty")
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_chunks_retriever_context_multiple_chunks(setup_test_environment_with_chunks_simple):
+ """Integration test: verify ChunksRetriever can retrieve multiple chunks."""
+ retriever = ChunksRetriever()
+
+ context = await retriever.get_context("Steve")
+
+ assert isinstance(context, list), "Context should be a list"
+ assert len(context) > 0, "Context should not be empty"
+ assert any(chunk["text"] == "Steve Rodger" for chunk in context), (
+ "Failed to get Steve Rodger chunk"
+ )
+
+
+@pytest.mark.asyncio
+async def test_chunks_retriever_top_k_limit(setup_test_environment_with_chunks_complex):
+ """Integration test: verify ChunksRetriever respects top_k parameter."""
+ retriever = ChunksRetriever(top_k=2)
+
+ context = await retriever.get_context("Employee")
+
+ assert isinstance(context, list), "Context should be a list"
+ assert len(context) <= 2, "Should respect top_k limit"
+
+
+@pytest.mark.asyncio
+async def test_chunks_retriever_context_complex(setup_test_environment_with_chunks_complex):
+ """Integration test: verify ChunksRetriever can retrieve chunk context (complex)."""
+ retriever = ChunksRetriever(top_k=20)
+
+ context = await retriever.get_context("Christina")
+
+ assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer"
+
+
+@pytest.mark.asyncio
+async def test_chunks_retriever_context_on_empty_graph(setup_test_environment_empty):
+ """Integration test: verify ChunksRetriever handles empty graph correctly."""
+ retriever = ChunksRetriever()
+
+ with pytest.raises(NoDataError):
+ await retriever.get_context("Christina Mayer")
+
+ vector_engine = get_vector_engine()
+ await vector_engine.create_collection(
+ "DocumentChunk_text", payload_schema=DocumentChunkWithEntities
+ )
+
+ context = await retriever.get_context("Christina Mayer")
+ assert len(context) == 0, "Found chunks when none should exist"
diff --git a/cognee/tests/unit/modules/retrieval/test_completion.py b/cognee/tests/integration/retrieval/test_completion.py
similarity index 100%
rename from cognee/tests/unit/modules/retrieval/test_completion.py
rename to cognee/tests/integration/retrieval/test_completion.py
diff --git a/cognee/tests/integration/retrieval/test_graph_completion_retriever.py b/cognee/tests/integration/retrieval/test_graph_completion_retriever.py
new file mode 100644
index 000000000..7367b353b
--- /dev/null
+++ b/cognee/tests/integration/retrieval/test_graph_completion_retriever.py
@@ -0,0 +1,268 @@
+import os
+import pytest
+import pathlib
+import pytest_asyncio
+from typing import Optional, Union
+import cognee
+
+from cognee.low_level import setup, DataPoint
+from cognee.modules.graph.utils import resolve_edges_to_text
+from cognee.tasks.storage import add_data_points
+from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_simple():
+ """Set up a clean test environment with simple graph data."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(base_dir / ".cognee_system/test_graph_completion_context_simple")
+ data_directory_path = str(base_dir / ".data_storage/test_graph_completion_context_simple")
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ class Company(DataPoint):
+ name: str
+ description: str
+
+ class Person(DataPoint):
+ name: str
+ description: str
+ works_for: Company
+
+ company1 = Company(name="Figma", description="Figma is a company")
+ company2 = Company(name="Canva", description="Canvas is a company")
+ person1 = Person(
+ name="Steve Rodger",
+ description="This is description about Steve Rodger",
+ works_for=company1,
+ )
+ person2 = Person(
+ name="Ike Loma", description="This is description about Ike Loma", works_for=company1
+ )
+ person3 = Person(
+ name="Jason Statham",
+ description="This is description about Jason Statham",
+ works_for=company1,
+ )
+ person4 = Person(
+ name="Mike Broski",
+ description="This is description about Mike Broski",
+ works_for=company2,
+ )
+ person5 = Person(
+ name="Christina Mayer",
+ description="This is description about Christina Mayer",
+ works_for=company2,
+ )
+
+ entities = [company1, company2, person1, person2, person3, person4, person5]
+
+ await add_data_points(entities)
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_complex():
+ """Set up a clean test environment with complex graph data."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(base_dir / ".cognee_system/test_graph_completion_context_complex")
+ data_directory_path = str(base_dir / ".data_storage/test_graph_completion_context_complex")
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ class Company(DataPoint):
+ name: str
+ metadata: dict = {"index_fields": ["name"]}
+
+ class Car(DataPoint):
+ brand: str
+ model: str
+ year: int
+
+ class Location(DataPoint):
+ country: str
+ city: str
+
+ class Home(DataPoint):
+ location: Location
+ rooms: int
+ sqm: int
+
+ class Person(DataPoint):
+ name: str
+ works_for: Company
+ owns: Optional[list[Union[Car, Home]]] = None
+
+ company1 = Company(name="Figma")
+ company2 = Company(name="Canva")
+
+ person1 = Person(name="Mike Rodger", works_for=company1)
+ person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
+
+ person2 = Person(name="Ike Loma", works_for=company1)
+ person2.owns = [
+ Car(brand="Tesla", model="Model S", year=2021),
+ Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
+ ]
+
+ person3 = Person(name="Jason Statham", works_for=company1)
+
+ person4 = Person(name="Mike Broski", works_for=company2)
+ person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
+
+ person5 = Person(name="Christina Mayer", works_for=company2)
+ person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
+
+ entities = [company1, company2, person1, person2, person3, person4, person5]
+
+ await add_data_points(entities)
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_empty():
+ """Set up a clean test environment without graph data."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(
+ base_dir / ".cognee_system/test_get_graph_completion_context_on_empty_graph"
+ )
+ data_directory_path = str(
+ base_dir / ".data_storage/test_get_graph_completion_context_on_empty_graph"
+ )
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_graph_completion_context_simple(setup_test_environment_simple):
+ """Integration test: verify GraphCompletionRetriever can retrieve context (simple)."""
+ retriever = GraphCompletionRetriever()
+
+ context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
+
+ # Ensure the top-level sections are present
+ assert "Nodes:" in context, "Missing 'Nodes:' section in context"
+ assert "Connections:" in context, "Missing 'Connections:' section in context"
+
+ # --- Nodes headers ---
+ assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger"
+ assert "Node: Figma" in context, "Missing node header for Figma"
+ assert "Node: Ike Loma" in context, "Missing node header for Ike Loma"
+ assert "Node: Jason Statham" in context, "Missing node header for Jason Statham"
+ assert "Node: Mike Broski" in context, "Missing node header for Mike Broski"
+ assert "Node: Canva" in context, "Missing node header for Canva"
+ assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer"
+
+ # --- Node contents ---
+ assert (
+ "__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__"
+ in context
+ ), "Description block for Steve Rodger altered"
+ assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, (
+ "Description block for Figma altered"
+ )
+ assert (
+ "__node_content_start__\nThis is description about Ike Loma\n__node_content_end__"
+ in context
+ ), "Description block for Ike Loma altered"
+ assert (
+ "__node_content_start__\nThis is description about Jason Statham\n__node_content_end__"
+ in context
+ ), "Description block for Jason Statham altered"
+ assert (
+ "__node_content_start__\nThis is description about Mike Broski\n__node_content_end__"
+ in context
+ ), "Description block for Mike Broski altered"
+ assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, (
+ "Description block for Canva altered"
+ )
+ assert (
+ "__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__"
+ in context
+ ), "Description block for Christina Mayer altered"
+
+ # --- Connections ---
+ assert "Steve Rodger --[works_for]--> Figma" in context, (
+ "Connection Steve RodgerโFigma missing or changed"
+ )
+ assert "Ike Loma --[works_for]--> Figma" in context, (
+ "Connection Ike LomaโFigma missing or changed"
+ )
+ assert "Jason Statham --[works_for]--> Figma" in context, (
+ "Connection Jason StathamโFigma missing or changed"
+ )
+ assert "Mike Broski --[works_for]--> Canva" in context, (
+ "Connection Mike BroskiโCanva missing or changed"
+ )
+ assert "Christina Mayer --[works_for]--> Canva" in context, (
+ "Connection Christina MayerโCanva missing or changed"
+ )
+
+
+@pytest.mark.asyncio
+async def test_graph_completion_context_complex(setup_test_environment_complex):
+ """Integration test: verify GraphCompletionRetriever can retrieve context (complex)."""
+ retriever = GraphCompletionRetriever(top_k=20)
+
+ context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
+
+ assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
+ assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
+ assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
+
+
+@pytest.mark.asyncio
+async def test_get_graph_completion_context_on_empty_graph(setup_test_environment_empty):
+ """Integration test: verify GraphCompletionRetriever handles empty graph correctly."""
+ retriever = GraphCompletionRetriever()
+
+ context = await retriever.get_context("Who works at Figma?")
+ assert context == [], "Context should be empty on an empty graph"
+
+
+@pytest.mark.asyncio
+async def test_graph_completion_get_triplets_empty(setup_test_environment_empty):
+ """Integration test: verify GraphCompletionRetriever get_triplets handles empty graph."""
+ retriever = GraphCompletionRetriever()
+
+ triplets = await retriever.get_triplets("Who works at Figma?")
+
+ assert isinstance(triplets, list), "Triplets should be a list"
+ assert len(triplets) == 0, "Should return empty list on empty graph"
diff --git a/cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py b/cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py
new file mode 100644
index 000000000..c87de16ef
--- /dev/null
+++ b/cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py
@@ -0,0 +1,226 @@
+import os
+import pytest
+import pathlib
+import pytest_asyncio
+from typing import Optional, Union
+import cognee
+
+from cognee.low_level import setup, DataPoint
+from cognee.tasks.storage import add_data_points
+from cognee.modules.graph.utils import resolve_edges_to_text
+from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
+ GraphCompletionContextExtensionRetriever,
+)
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_simple():
+ """Set up a clean test environment with simple graph data."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(
+ base_dir / ".cognee_system/test_graph_completion_extension_context_simple"
+ )
+ data_directory_path = str(
+ base_dir / ".data_storage/test_graph_completion_extension_context_simple"
+ )
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ class Company(DataPoint):
+ name: str
+
+ class Person(DataPoint):
+ name: str
+ works_for: Company
+
+ company1 = Company(name="Figma")
+ company2 = Company(name="Canva")
+ person1 = Person(name="Steve Rodger", works_for=company1)
+ person2 = Person(name="Ike Loma", works_for=company1)
+ person3 = Person(name="Jason Statham", works_for=company1)
+ person4 = Person(name="Mike Broski", works_for=company2)
+ person5 = Person(name="Christina Mayer", works_for=company2)
+
+ entities = [company1, company2, person1, person2, person3, person4, person5]
+
+ await add_data_points(entities)
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_complex():
+ """Set up a clean test environment with complex graph data."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(
+ base_dir / ".cognee_system/test_graph_completion_extension_context_complex"
+ )
+ data_directory_path = str(
+ base_dir / ".data_storage/test_graph_completion_extension_context_complex"
+ )
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ class Company(DataPoint):
+ name: str
+ metadata: dict = {"index_fields": ["name"]}
+
+ class Car(DataPoint):
+ brand: str
+ model: str
+ year: int
+
+ class Location(DataPoint):
+ country: str
+ city: str
+
+ class Home(DataPoint):
+ location: Location
+ rooms: int
+ sqm: int
+
+ class Person(DataPoint):
+ name: str
+ works_for: Company
+ owns: Optional[list[Union[Car, Home]]] = None
+
+ company1 = Company(name="Figma")
+ company2 = Company(name="Canva")
+
+ person1 = Person(name="Mike Rodger", works_for=company1)
+ person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
+
+ person2 = Person(name="Ike Loma", works_for=company1)
+ person2.owns = [
+ Car(brand="Tesla", model="Model S", year=2021),
+ Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
+ ]
+
+ person3 = Person(name="Jason Statham", works_for=company1)
+
+ person4 = Person(name="Mike Broski", works_for=company2)
+ person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
+
+ person5 = Person(name="Christina Mayer", works_for=company2)
+ person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
+
+ entities = [company1, company2, person1, person2, person3, person4, person5]
+
+ await add_data_points(entities)
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_empty():
+ """Set up a clean test environment without graph data."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(
+ base_dir / ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph"
+ )
+ data_directory_path = str(
+ base_dir / ".data_storage/test_get_graph_completion_extension_context_on_empty_graph"
+ )
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_graph_completion_extension_context_simple(setup_test_environment_simple):
+ """Integration test: verify GraphCompletionContextExtensionRetriever can retrieve context (simple)."""
+ retriever = GraphCompletionContextExtensionRetriever()
+
+ context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
+
+ assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
+ assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
+
+ answer = await retriever.get_completion("Who works at Canva?")
+
+ assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+ assert all(isinstance(item, str) and item.strip() for item in answer), (
+ "Answer must contain only non-empty strings"
+ )
+
+
+@pytest.mark.asyncio
+async def test_graph_completion_extension_context_complex(setup_test_environment_complex):
+ """Integration test: verify GraphCompletionContextExtensionRetriever can retrieve context (complex)."""
+ retriever = GraphCompletionContextExtensionRetriever(top_k=20)
+
+ context = await resolve_edges_to_text(
+ await retriever.get_context("Who works at Figma and drives Tesla?")
+ )
+
+ assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
+ assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
+ assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
+
+ answer = await retriever.get_completion("Who works at Figma?")
+
+ assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+ assert all(isinstance(item, str) and item.strip() for item in answer), (
+ "Answer must contain only non-empty strings"
+ )
+
+
+@pytest.mark.asyncio
+async def test_get_graph_completion_extension_context_on_empty_graph(setup_test_environment_empty):
+ """Integration test: verify GraphCompletionContextExtensionRetriever handles empty graph correctly."""
+ retriever = GraphCompletionContextExtensionRetriever()
+
+ context = await retriever.get_context("Who works at Figma?")
+ assert context == [], "Context should be empty on an empty graph"
+
+ answer = await retriever.get_completion("Who works at Figma?")
+
+ assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+ assert all(isinstance(item, str) and item.strip() for item in answer), (
+ "Answer must contain only non-empty strings"
+ )
+
+
+@pytest.mark.asyncio
+async def test_graph_completion_extension_get_triplets_empty(setup_test_environment_empty):
+ """Integration test: verify GraphCompletionContextExtensionRetriever get_triplets handles empty graph."""
+ retriever = GraphCompletionContextExtensionRetriever()
+
+ triplets = await retriever.get_triplets("Who works at Figma?")
+
+ assert isinstance(triplets, list), "Triplets should be a list"
+ assert len(triplets) == 0, "Should return empty list on empty graph"
diff --git a/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py b/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py
new file mode 100644
index 000000000..0db035e03
--- /dev/null
+++ b/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py
@@ -0,0 +1,218 @@
+import os
+import pytest
+import pathlib
+import pytest_asyncio
+from typing import Optional, Union
+import cognee
+
+from cognee.low_level import setup, DataPoint
+from cognee.modules.graph.utils import resolve_edges_to_text
+from cognee.tasks.storage import add_data_points
+from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_simple():
+ """Set up a clean test environment with simple graph data."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(
+ base_dir / ".cognee_system/test_graph_completion_cot_context_simple"
+ )
+ data_directory_path = str(base_dir / ".data_storage/test_graph_completion_cot_context_simple")
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ class Company(DataPoint):
+ name: str
+
+ class Person(DataPoint):
+ name: str
+ works_for: Company
+
+ company1 = Company(name="Figma")
+ company2 = Company(name="Canva")
+ person1 = Person(name="Steve Rodger", works_for=company1)
+ person2 = Person(name="Ike Loma", works_for=company1)
+ person3 = Person(name="Jason Statham", works_for=company1)
+ person4 = Person(name="Mike Broski", works_for=company2)
+ person5 = Person(name="Christina Mayer", works_for=company2)
+
+ entities = [company1, company2, person1, person2, person3, person4, person5]
+
+ await add_data_points(entities)
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_complex():
+ """Set up a clean test environment with complex graph data."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(
+ base_dir / ".cognee_system/test_graph_completion_cot_context_complex"
+ )
+ data_directory_path = str(base_dir / ".data_storage/test_graph_completion_cot_context_complex")
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ class Company(DataPoint):
+ name: str
+ metadata: dict = {"index_fields": ["name"]}
+
+ class Car(DataPoint):
+ brand: str
+ model: str
+ year: int
+
+ class Location(DataPoint):
+ country: str
+ city: str
+
+ class Home(DataPoint):
+ location: Location
+ rooms: int
+ sqm: int
+
+ class Person(DataPoint):
+ name: str
+ works_for: Company
+ owns: Optional[list[Union[Car, Home]]] = None
+
+ company1 = Company(name="Figma")
+ company2 = Company(name="Canva")
+
+ person1 = Person(name="Mike Rodger", works_for=company1)
+ person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
+
+ person2 = Person(name="Ike Loma", works_for=company1)
+ person2.owns = [
+ Car(brand="Tesla", model="Model S", year=2021),
+ Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
+ ]
+
+ person3 = Person(name="Jason Statham", works_for=company1)
+
+ person4 = Person(name="Mike Broski", works_for=company2)
+ person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
+
+ person5 = Person(name="Christina Mayer", works_for=company2)
+ person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
+
+ entities = [company1, company2, person1, person2, person3, person4, person5]
+
+ await add_data_points(entities)
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_empty():
+ """Set up a clean test environment without graph data."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(
+ base_dir / ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph"
+ )
+ data_directory_path = str(
+ base_dir / ".data_storage/test_get_graph_completion_cot_context_on_empty_graph"
+ )
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_graph_completion_cot_context_simple(setup_test_environment_simple):
+ """Integration test: verify GraphCompletionCotRetriever can retrieve context (simple)."""
+ retriever = GraphCompletionCotRetriever()
+
+ context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
+
+ assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
+ assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
+
+ answer = await retriever.get_completion("Who works at Canva?")
+
+ assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+ assert all(isinstance(item, str) and item.strip() for item in answer), (
+ "Answer must contain only non-empty strings"
+ )
+
+
+@pytest.mark.asyncio
+async def test_graph_completion_cot_context_complex(setup_test_environment_complex):
+ """Integration test: verify GraphCompletionCotRetriever can retrieve context (complex)."""
+ retriever = GraphCompletionCotRetriever(top_k=20)
+
+ context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
+
+ assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
+ assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
+ assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
+
+ answer = await retriever.get_completion("Who works at Figma?")
+
+ assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+ assert all(isinstance(item, str) and item.strip() for item in answer), (
+ "Answer must contain only non-empty strings"
+ )
+
+
+@pytest.mark.asyncio
+async def test_get_graph_completion_cot_context_on_empty_graph(setup_test_environment_empty):
+ """Integration test: verify GraphCompletionCotRetriever handles empty graph correctly."""
+ retriever = GraphCompletionCotRetriever()
+
+ context = await retriever.get_context("Who works at Figma?")
+ assert context == [], "Context should be empty on an empty graph"
+
+ answer = await retriever.get_completion("Who works at Figma?")
+
+ assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+ assert all(isinstance(item, str) and item.strip() for item in answer), (
+ "Answer must contain only non-empty strings"
+ )
+
+
+@pytest.mark.asyncio
+async def test_graph_completion_cot_get_triplets_empty(setup_test_environment_empty):
+ """Integration test: verify GraphCompletionCotRetriever get_triplets handles empty graph."""
+ retriever = GraphCompletionCotRetriever()
+
+ triplets = await retriever.get_triplets("Who works at Figma?")
+
+ assert isinstance(triplets, list), "Triplets should be a list"
+ assert len(triplets) == 0, "Should return empty list on empty graph"
diff --git a/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py b/cognee/tests/integration/retrieval/test_graph_summary_completion_retriever.py
similarity index 100%
rename from cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py
rename to cognee/tests/integration/retrieval/test_graph_summary_completion_retriever.py
diff --git a/cognee/tests/integration/retrieval/test_rag_completion_retriever.py b/cognee/tests/integration/retrieval/test_rag_completion_retriever.py
new file mode 100644
index 000000000..b01d58160
--- /dev/null
+++ b/cognee/tests/integration/retrieval/test_rag_completion_retriever.py
@@ -0,0 +1,254 @@
+import os
+from typing import List
+import pytest
+import pathlib
+import pytest_asyncio
+import cognee
+
+from cognee.low_level import setup
+from cognee.tasks.storage import add_data_points
+from cognee.infrastructure.databases.vector import get_vector_engine
+from cognee.modules.chunking.models import DocumentChunk
+from cognee.modules.data.processing.document_types import TextDocument
+from cognee.modules.retrieval.exceptions.exceptions import NoDataError
+from cognee.modules.retrieval.completion_retriever import CompletionRetriever
+from cognee.infrastructure.engine import DataPoint
+from cognee.modules.data.processing.document_types import Document
+from cognee.modules.engine.models import Entity
+
+
+class DocumentChunkWithEntities(DataPoint):
+ text: str
+ chunk_size: int
+ chunk_index: int
+ cut_type: str
+ is_part_of: Document
+ contains: List[Entity] = None
+
+ metadata: dict = {"index_fields": ["text"]}
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_with_chunks_simple():
+ """Set up a clean test environment with simple chunks."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(base_dir / ".cognee_system/test_rag_completion_context_simple")
+ data_directory_path = str(base_dir / ".data_storage/test_rag_completion_context_simple")
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ document = TextDocument(
+ name="Steve Rodger's career",
+ raw_data_location="somewhere",
+ external_metadata="",
+ mime_type="text/plain",
+ )
+
+ chunk1 = DocumentChunk(
+ text="Steve Rodger",
+ chunk_size=2,
+ chunk_index=0,
+ cut_type="sentence_end",
+ is_part_of=document,
+ contains=[],
+ )
+ chunk2 = DocumentChunk(
+ text="Mike Broski",
+ chunk_size=2,
+ chunk_index=1,
+ cut_type="sentence_end",
+ is_part_of=document,
+ contains=[],
+ )
+ chunk3 = DocumentChunk(
+ text="Christina Mayer",
+ chunk_size=2,
+ chunk_index=2,
+ cut_type="sentence_end",
+ is_part_of=document,
+ contains=[],
+ )
+
+ entities = [chunk1, chunk2, chunk3]
+
+ await add_data_points(entities)
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_with_chunks_complex():
+ """Set up a clean test environment with complex chunks."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(base_dir / ".cognee_system/test_rag_completion_context_complex")
+ data_directory_path = str(base_dir / ".data_storage/test_rag_completion_context_complex")
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ document1 = TextDocument(
+ name="Employee List",
+ raw_data_location="somewhere",
+ external_metadata="",
+ mime_type="text/plain",
+ )
+
+ document2 = TextDocument(
+ name="Car List",
+ raw_data_location="somewhere",
+ external_metadata="",
+ mime_type="text/plain",
+ )
+
+ chunk1 = DocumentChunk(
+ text="Steve Rodger",
+ chunk_size=2,
+ chunk_index=0,
+ cut_type="sentence_end",
+ is_part_of=document1,
+ contains=[],
+ )
+ chunk2 = DocumentChunk(
+ text="Mike Broski",
+ chunk_size=2,
+ chunk_index=1,
+ cut_type="sentence_end",
+ is_part_of=document1,
+ contains=[],
+ )
+ chunk3 = DocumentChunk(
+ text="Christina Mayer",
+ chunk_size=2,
+ chunk_index=2,
+ cut_type="sentence_end",
+ is_part_of=document1,
+ contains=[],
+ )
+
+ chunk4 = DocumentChunk(
+ text="Range Rover",
+ chunk_size=2,
+ chunk_index=0,
+ cut_type="sentence_end",
+ is_part_of=document2,
+ contains=[],
+ )
+ chunk5 = DocumentChunk(
+ text="Hyundai",
+ chunk_size=2,
+ chunk_index=1,
+ cut_type="sentence_end",
+ is_part_of=document2,
+ contains=[],
+ )
+ chunk6 = DocumentChunk(
+ text="Chrysler",
+ chunk_size=2,
+ chunk_index=2,
+ cut_type="sentence_end",
+ is_part_of=document2,
+ contains=[],
+ )
+
+ entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]
+
+ await add_data_points(entities)
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_empty():
+ """Set up a clean test environment without chunks."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(
+ base_dir / ".cognee_system/test_get_rag_completion_context_on_empty_graph"
+ )
+ data_directory_path = str(
+ base_dir / ".data_storage/test_get_rag_completion_context_on_empty_graph"
+ )
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_rag_completion_context_simple(setup_test_environment_with_chunks_simple):
+ """Integration test: verify CompletionRetriever can retrieve context (simple)."""
+ retriever = CompletionRetriever()
+
+ context = await retriever.get_context("Mike")
+
+ assert isinstance(context, str), "Context should be a string"
+ assert "Mike Broski" in context, "Failed to get Mike Broski"
+
+
+@pytest.mark.asyncio
+async def test_rag_completion_context_multiple_chunks(setup_test_environment_with_chunks_simple):
+ """Integration test: verify CompletionRetriever can retrieve context from multiple chunks."""
+ retriever = CompletionRetriever()
+
+ context = await retriever.get_context("Steve")
+
+ assert isinstance(context, str), "Context should be a string"
+ assert "Steve Rodger" in context, "Failed to get Steve Rodger"
+
+
+@pytest.mark.asyncio
+async def test_rag_completion_context_complex(setup_test_environment_with_chunks_complex):
+ """Integration test: verify CompletionRetriever can retrieve context (complex)."""
+ # TODO: top_k doesn't affect the output, it should be fixed.
+ retriever = CompletionRetriever(top_k=20)
+
+ context = await retriever.get_context("Christina")
+
+ assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer"
+
+
+@pytest.mark.asyncio
+async def test_get_rag_completion_context_on_empty_graph(setup_test_environment_empty):
+ """Integration test: verify CompletionRetriever handles empty graph correctly."""
+ retriever = CompletionRetriever()
+
+ with pytest.raises(NoDataError):
+ await retriever.get_context("Christina Mayer")
+
+ vector_engine = get_vector_engine()
+ await vector_engine.create_collection(
+ "DocumentChunk_text", payload_schema=DocumentChunkWithEntities
+ )
+
+ context = await retriever.get_context("Christina Mayer")
+ assert context == "", "Returned context should be empty on an empty graph"
diff --git a/cognee/tests/unit/modules/retrieval/structured_output_test.py b/cognee/tests/integration/retrieval/test_structured_output.py
similarity index 65%
rename from cognee/tests/unit/modules/retrieval/structured_output_test.py
rename to cognee/tests/integration/retrieval/test_structured_output.py
index 4ad3019ff..13ffd8eef 100644
--- a/cognee/tests/unit/modules/retrieval/structured_output_test.py
+++ b/cognee/tests/integration/retrieval/test_structured_output.py
@@ -1,9 +1,9 @@
import asyncio
-
-import pytest
-import cognee
-import pathlib
import os
+import pytest
+import pathlib
+import pytest_asyncio
+import cognee
from pydantic import BaseModel
from cognee.low_level import setup, DataPoint
@@ -125,80 +125,90 @@ async def _test_get_structured_entity_completion():
_assert_structured_answer(structured_answer)
-class TestStructuredOutputCompletion:
- @pytest.mark.asyncio
- async def test_get_structured_completion(self):
- system_directory_path = os.path.join(
- pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
- )
- cognee.config.system_root_directory(system_directory_path)
- data_directory_path = os.path.join(
- pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
- )
- cognee.config.data_root_directory(data_directory_path)
+@pytest_asyncio.fixture
+async def setup_test_environment():
+ """Set up a clean test environment with graph and document data."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(base_dir / ".cognee_system/test_get_structured_completion")
+ data_directory_path = str(base_dir / ".data_storage/test_get_structured_completion")
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ class Company(DataPoint):
+ name: str
+
+ class Person(DataPoint):
+ name: str
+ works_for: Company
+ works_since: int
+
+ company1 = Company(name="Figma")
+ person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015)
+
+ entities = [company1, person1]
+ await add_data_points(entities)
+
+ document = TextDocument(
+ name="Steve Rodger's career",
+ raw_data_location="somewhere",
+ external_metadata="",
+ mime_type="text/plain",
+ )
+
+ chunk1 = DocumentChunk(
+ text="Steve Rodger",
+ chunk_size=2,
+ chunk_index=0,
+ cut_type="sentence_end",
+ is_part_of=document,
+ contains=[],
+ )
+ chunk2 = DocumentChunk(
+ text="Mike Broski",
+ chunk_size=2,
+ chunk_index=1,
+ cut_type="sentence_end",
+ is_part_of=document,
+ contains=[],
+ )
+ chunk3 = DocumentChunk(
+ text="Christina Mayer",
+ chunk_size=2,
+ chunk_index=2,
+ cut_type="sentence_end",
+ is_part_of=document,
+ contains=[],
+ )
+
+ entities = [chunk1, chunk2, chunk3]
+ await add_data_points(entities)
+
+ entity_type = EntityType(name="Person", description="A human individual")
+ entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")
+
+ entities = [entity]
+ await add_data_points(entities)
+
+ yield
+
+ try:
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
- await setup()
+ except Exception:
+ pass
- class Company(DataPoint):
- name: str
- class Person(DataPoint):
- name: str
- works_for: Company
- works_since: int
-
- company1 = Company(name="Figma")
- person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015)
-
- entities = [company1, person1]
- await add_data_points(entities)
-
- document = TextDocument(
- name="Steve Rodger's career",
- raw_data_location="somewhere",
- external_metadata="",
- mime_type="text/plain",
- )
-
- chunk1 = DocumentChunk(
- text="Steve Rodger",
- chunk_size=2,
- chunk_index=0,
- cut_type="sentence_end",
- is_part_of=document,
- contains=[],
- )
- chunk2 = DocumentChunk(
- text="Mike Broski",
- chunk_size=2,
- chunk_index=1,
- cut_type="sentence_end",
- is_part_of=document,
- contains=[],
- )
- chunk3 = DocumentChunk(
- text="Christina Mayer",
- chunk_size=2,
- chunk_index=2,
- cut_type="sentence_end",
- is_part_of=document,
- contains=[],
- )
-
- entities = [chunk1, chunk2, chunk3]
- await add_data_points(entities)
-
- entity_type = EntityType(name="Person", description="A human individual")
- entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")
-
- entities = [entity]
- await add_data_points(entities)
-
- await _test_get_structured_graph_completion_cot()
- await _test_get_structured_graph_completion()
- await _test_get_structured_graph_completion_temporal()
- await _test_get_structured_graph_completion_rag()
- await _test_get_structured_graph_completion_context_extension()
- await _test_get_structured_entity_completion()
+@pytest.mark.asyncio
+async def test_get_structured_completion(setup_test_environment):
+ """Integration test: verify structured output completion for all retrievers."""
+ await _test_get_structured_graph_completion_cot()
+ await _test_get_structured_graph_completion()
+ await _test_get_structured_graph_completion_temporal()
+ await _test_get_structured_graph_completion_rag()
+ await _test_get_structured_graph_completion_context_extension()
+ await _test_get_structured_entity_completion()
diff --git a/cognee/tests/integration/retrieval/test_summaries_retriever.py b/cognee/tests/integration/retrieval/test_summaries_retriever.py
new file mode 100644
index 000000000..a2f4e40b3
--- /dev/null
+++ b/cognee/tests/integration/retrieval/test_summaries_retriever.py
@@ -0,0 +1,184 @@
+import os
+import pytest
+import pathlib
+import pytest_asyncio
+import cognee
+
+from cognee.low_level import setup
+from cognee.tasks.storage import add_data_points
+from cognee.infrastructure.databases.vector import get_vector_engine
+from cognee.modules.chunking.models import DocumentChunk
+from cognee.tasks.summarization.models import TextSummary
+from cognee.modules.data.processing.document_types import TextDocument
+from cognee.modules.retrieval.exceptions.exceptions import NoDataError
+from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_with_summaries():
+ """Set up a clean test environment with summaries."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(base_dir / ".cognee_system/test_summaries_retriever_context")
+ data_directory_path = str(base_dir / ".data_storage/test_summaries_retriever_context")
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ document1 = TextDocument(
+ name="Employee List",
+ raw_data_location="somewhere",
+ external_metadata="",
+ mime_type="text/plain",
+ )
+
+ document2 = TextDocument(
+ name="Car List",
+ raw_data_location="somewhere",
+ external_metadata="",
+ mime_type="text/plain",
+ )
+
+ chunk1 = DocumentChunk(
+ text="Steve Rodger",
+ chunk_size=2,
+ chunk_index=0,
+ cut_type="sentence_end",
+ is_part_of=document1,
+ contains=[],
+ )
+ chunk1_summary = TextSummary(
+ text="S.R.",
+ made_from=chunk1,
+ )
+ chunk2 = DocumentChunk(
+ text="Mike Broski",
+ chunk_size=2,
+ chunk_index=1,
+ cut_type="sentence_end",
+ is_part_of=document1,
+ contains=[],
+ )
+ chunk2_summary = TextSummary(
+ text="M.B.",
+ made_from=chunk2,
+ )
+ chunk3 = DocumentChunk(
+ text="Christina Mayer",
+ chunk_size=2,
+ chunk_index=2,
+ cut_type="sentence_end",
+ is_part_of=document1,
+ contains=[],
+ )
+ chunk3_summary = TextSummary(
+ text="C.M.",
+ made_from=chunk3,
+ )
+ chunk4 = DocumentChunk(
+ text="Range Rover",
+ chunk_size=2,
+ chunk_index=0,
+ cut_type="sentence_end",
+ is_part_of=document2,
+ contains=[],
+ )
+ chunk4_summary = TextSummary(
+ text="R.R.",
+ made_from=chunk4,
+ )
+ chunk5 = DocumentChunk(
+ text="Hyundai",
+ chunk_size=2,
+ chunk_index=1,
+ cut_type="sentence_end",
+ is_part_of=document2,
+ contains=[],
+ )
+ chunk5_summary = TextSummary(
+ text="H.Y.",
+ made_from=chunk5,
+ )
+ chunk6 = DocumentChunk(
+ text="Chrysler",
+ chunk_size=2,
+ chunk_index=2,
+ cut_type="sentence_end",
+ is_part_of=document2,
+ contains=[],
+ )
+ chunk6_summary = TextSummary(
+ text="C.H.",
+ made_from=chunk6,
+ )
+
+ entities = [
+ chunk1_summary,
+ chunk2_summary,
+ chunk3_summary,
+ chunk4_summary,
+ chunk5_summary,
+ chunk6_summary,
+ ]
+
+ await add_data_points(entities)
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_empty():
+ """Set up a clean test environment without summaries."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(base_dir / ".cognee_system/test_summaries_retriever_context_empty")
+ data_directory_path = str(base_dir / ".data_storage/test_summaries_retriever_context_empty")
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_summaries_retriever_context(setup_test_environment_with_summaries):
+ """Integration test: verify SummariesRetriever can retrieve summary context."""
+ retriever = SummariesRetriever(top_k=20)
+
+ context = await retriever.get_context("Christina")
+
+ assert isinstance(context, list), "Context should be a list"
+ assert len(context) > 0, "Context should not be empty"
+ assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer"
+
+
+@pytest.mark.asyncio
+async def test_summaries_retriever_context_on_empty_graph(setup_test_environment_empty):
+ """Integration test: verify SummariesRetriever handles empty graph correctly."""
+ retriever = SummariesRetriever()
+
+ with pytest.raises(NoDataError):
+ await retriever.get_context("Christina Mayer")
+
+ vector_engine = get_vector_engine()
+ await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary)
+
+ context = await retriever.get_context("Christina Mayer")
+ assert context == [], "Returned context should be empty on an empty graph"
diff --git a/cognee/tests/integration/retrieval/test_temporal_retriever.py b/cognee/tests/integration/retrieval/test_temporal_retriever.py
new file mode 100644
index 000000000..8ce3b32f4
--- /dev/null
+++ b/cognee/tests/integration/retrieval/test_temporal_retriever.py
@@ -0,0 +1,306 @@
+import os
+import pytest
+import pathlib
+import pytest_asyncio
+import cognee
+
+from cognee.low_level import setup, DataPoint
+from cognee.tasks.storage import add_data_points
+from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
+from cognee.modules.engine.models.Event import Event
+from cognee.modules.engine.models.Timestamp import Timestamp
+from cognee.modules.engine.models.Interval import Interval
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_with_events():
+ """Set up a clean test environment with temporal events."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_with_events")
+ data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_with_events")
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ # Create timestamps for events
+ timestamp1 = Timestamp(
+ time_at=1609459200, # 2021-01-01 00:00:00
+ year=2021,
+ month=1,
+ day=1,
+ hour=0,
+ minute=0,
+ second=0,
+ timestamp_str="2021-01-01T00:00:00",
+ )
+
+ timestamp2 = Timestamp(
+ time_at=1612137600, # 2021-02-01 00:00:00
+ year=2021,
+ month=2,
+ day=1,
+ hour=0,
+ minute=0,
+ second=0,
+ timestamp_str="2021-02-01T00:00:00",
+ )
+
+ timestamp3 = Timestamp(
+ time_at=1614556800, # 2021-03-01 00:00:00
+ year=2021,
+ month=3,
+ day=1,
+ hour=0,
+ minute=0,
+ second=0,
+ timestamp_str="2021-03-01T00:00:00",
+ )
+
+ timestamp4 = Timestamp(
+ time_at=1625097600, # 2021-07-01 00:00:00
+ year=2021,
+ month=7,
+ day=1,
+ hour=0,
+ minute=0,
+ second=0,
+ timestamp_str="2021-07-01T00:00:00",
+ )
+
+ timestamp5 = Timestamp(
+ time_at=1633046400, # 2021-10-01 00:00:00
+ year=2021,
+ month=10,
+ day=1,
+ hour=0,
+ minute=0,
+ second=0,
+ timestamp_str="2021-10-01T00:00:00",
+ )
+
+ # Create interval for event spanning multiple timestamps
+ interval1 = Interval(time_from=timestamp2, time_to=timestamp3)
+
+ # Create events with timestamps
+ event1 = Event(
+ name="Project Alpha Launch",
+ description="Launched Project Alpha at the beginning of 2021",
+ at=timestamp1,
+ location="San Francisco",
+ )
+
+ event2 = Event(
+ name="Team Meeting",
+ description="Monthly team meeting discussing Q1 goals",
+ during=interval1,
+ location="New York",
+ )
+
+ event3 = Event(
+ name="Product Release",
+ description="Released new product features in July",
+ at=timestamp4,
+ location="Remote",
+ )
+
+ event4 = Event(
+ name="Company Retreat",
+ description="Annual company retreat in October",
+ at=timestamp5,
+ location="Lake Tahoe",
+ )
+
+ entities = [event1, event2, event3, event4]
+
+ await add_data_points(entities)
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_with_graph_data():
+ """Set up a clean test environment with graph data (for fallback to triplets)."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_with_graph")
+ data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_with_graph")
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ class Company(DataPoint):
+ name: str
+ description: str
+
+ class Person(DataPoint):
+ name: str
+ description: str
+ works_for: Company
+
+ company1 = Company(name="Figma", description="Figma is a company")
+ person1 = Person(
+ name="Steve Rodger",
+ description="This is description about Steve Rodger",
+ works_for=company1,
+ )
+
+ entities = [company1, person1]
+
+ await add_data_points(entities)
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest_asyncio.fixture
+async def setup_test_environment_empty():
+ """Set up a clean test environment without data."""
+ base_dir = pathlib.Path(__file__).parent.parent.parent.parent
+ system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_empty")
+ data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_empty")
+
+ cognee.config.system_root_directory(system_directory_path)
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ yield
+
+ try:
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ except Exception:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_temporal_retriever_context_with_time_range(setup_test_environment_with_events):
+ """Integration test: verify TemporalRetriever can retrieve events within time range."""
+ retriever = TemporalRetriever(top_k=5)
+
+ context = await retriever.get_context("What happened in January 2021?")
+
+ assert isinstance(context, str), "Context should be a string"
+ assert len(context) > 0, "Context should not be empty"
+ assert "Project Alpha" in context or "Launch" in context, (
+ "Should retrieve Project Alpha Launch event from January 2021"
+ )
+
+
+@pytest.mark.asyncio
+async def test_temporal_retriever_context_with_single_time(setup_test_environment_with_events):
+ """Integration test: verify TemporalRetriever can retrieve events at specific time."""
+ retriever = TemporalRetriever(top_k=5)
+
+ context = await retriever.get_context("What happened in July 2021?")
+
+ assert isinstance(context, str), "Context should be a string"
+ assert len(context) > 0, "Context should not be empty"
+ assert "Product Release" in context or "July" in context, (
+ "Should retrieve Product Release event from July 2021"
+ )
+
+
+@pytest.mark.asyncio
+async def test_temporal_retriever_context_fallback_to_triplets(
+ setup_test_environment_with_graph_data,
+):
+ """Integration test: verify TemporalRetriever falls back to triplets when no time extracted."""
+ retriever = TemporalRetriever(top_k=5)
+
+ context = await retriever.get_context("Who works at Figma?")
+
+ assert isinstance(context, str), "Context should be a string"
+ assert len(context) > 0, "Context should not be empty"
+ assert "Steve" in context or "Figma" in context, (
+ "Should retrieve graph data via triplet search fallback"
+ )
+
+
+@pytest.mark.asyncio
+async def test_temporal_retriever_context_empty_graph(setup_test_environment_empty):
+ """Integration test: verify TemporalRetriever handles empty graph correctly."""
+ retriever = TemporalRetriever()
+
+ context = await retriever.get_context("What happened?")
+
+ assert isinstance(context, str), "Context should be a string"
+ assert len(context) >= 0, "Context should be a string (possibly empty)"
+
+
+@pytest.mark.asyncio
+async def test_temporal_retriever_get_completion(setup_test_environment_with_events):
+ """Integration test: verify TemporalRetriever can generate completions."""
+ retriever = TemporalRetriever()
+
+ completion = await retriever.get_completion("What happened in January 2021?")
+
+ assert isinstance(completion, list), "Completion should be a list"
+ assert len(completion) > 0, "Completion should not be empty"
+ assert all(isinstance(item, str) and item.strip() for item in completion), (
+ "Completion items should be non-empty strings"
+ )
+
+
+@pytest.mark.asyncio
+async def test_temporal_retriever_get_completion_fallback(setup_test_environment_with_graph_data):
+ """Integration test: verify TemporalRetriever get_completion works with triplet fallback."""
+ retriever = TemporalRetriever()
+
+ completion = await retriever.get_completion("Who works at Figma?")
+
+ assert isinstance(completion, list), "Completion should be a list"
+ assert len(completion) > 0, "Completion should not be empty"
+ assert all(isinstance(item, str) and item.strip() for item in completion), (
+ "Completion items should be non-empty strings"
+ )
+
+
+@pytest.mark.asyncio
+async def test_temporal_retriever_top_k_limit(setup_test_environment_with_events):
+ """Integration test: verify TemporalRetriever respects top_k parameter."""
+ retriever = TemporalRetriever(top_k=2)
+
+ context = await retriever.get_context("What happened in 2021?")
+
+ assert isinstance(context, str), "Context should be a string"
+ separator_count = context.count("#####################")
+ assert separator_count <= 1, "Should respect top_k limit of 2 events"
+
+
+@pytest.mark.asyncio
+async def test_temporal_retriever_multiple_events(setup_test_environment_with_events):
+ """Integration test: verify TemporalRetriever can retrieve multiple events."""
+ retriever = TemporalRetriever(top_k=10)
+
+ context = await retriever.get_context("What events occurred in 2021?")
+
+ assert isinstance(context, str), "Context should be a string"
+ assert len(context) > 0, "Context should not be empty"
+
+ assert (
+ "Project Alpha" in context
+ or "Team Meeting" in context
+ or "Product Release" in context
+ or "Company Retreat" in context
+ ), "Should retrieve at least one event from 2021"
diff --git a/cognee/tests/integration/retrieval/test_triplet_retriever.py b/cognee/tests/integration/retrieval/test_triplet_retriever.py
index e547b6cbe..ebe853e08 100644
--- a/cognee/tests/integration/retrieval/test_triplet_retriever.py
+++ b/cognee/tests/integration/retrieval/test_triplet_retriever.py
@@ -82,3 +82,38 @@ async def test_triplet_retriever_context_simple(setup_test_environment_with_trip
context = await retriever.get_context("Alice")
assert "Alice knows Bob" in context, "Failed to get Alice triplet"
+ assert isinstance(context, str), "Context should be a string"
+ assert len(context) > 0, "Context should not be empty"
+
+
+@pytest.mark.asyncio
+async def test_triplet_retriever_context_multiple_triplets(setup_test_environment_with_triplets):
+ """Integration test: verify TripletRetriever can retrieve multiple triplets."""
+ retriever = TripletRetriever(top_k=5)
+
+ context = await retriever.get_context("Bob")
+
+ assert "Alice knows Bob" in context or "Bob works at Tech Corp" in context, (
+ "Failed to get Bob-related triplets"
+ )
+
+
+@pytest.mark.asyncio
+async def test_triplet_retriever_top_k_limit(setup_test_environment_with_triplets):
+ """Integration test: verify TripletRetriever respects top_k parameter."""
+ retriever = TripletRetriever(top_k=1)
+
+ context = await retriever.get_context("Alice")
+
+ assert isinstance(context, str), "Context should be a string"
+
+
+@pytest.mark.asyncio
+async def test_triplet_retriever_context_empty(setup_test_environment_empty):
+ """Integration test: verify TripletRetriever handles empty graph correctly."""
+ await setup()
+
+ retriever = TripletRetriever()
+
+ with pytest.raises(NoDataError):
+ await retriever.get_context("Alice")
diff --git a/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py b/cognee/tests/integration/retrieval/test_user_qa_feedback.py
similarity index 100%
rename from cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py
rename to cognee/tests/integration/retrieval/test_user_qa_feedback.py
diff --git a/cognee/tests/test_cognee_server_start.py b/cognee/tests/test_cognee_server_start.py
index fece88240..a626088a3 100644
--- a/cognee/tests/test_cognee_server_start.py
+++ b/cognee/tests/test_cognee_server_start.py
@@ -148,8 +148,8 @@ class TestCogneeServerStart(unittest.TestCase):
headers=headers,
files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
data={
- "ontology_key": json.dumps([ontology_key]),
- "description": json.dumps(["Test ontology"]),
+ "ontology_key": ontology_key,
+ "description": "Test ontology",
},
)
self.assertEqual(ontology_response.status_code, 200)
diff --git a/cognee/tests/test_dataset_delete.py b/cognee/tests/test_dataset_delete.py
new file mode 100644
index 000000000..372945bdb
--- /dev/null
+++ b/cognee/tests/test_dataset_delete.py
@@ -0,0 +1,76 @@
+import os
+import asyncio
+import pathlib
+from uuid import UUID
+
+import cognee
+from cognee.shared.logging_utils import setup_logging, ERROR
+from cognee.modules.data.methods.delete_dataset import delete_dataset
+from cognee.modules.data.methods.get_dataset import get_dataset
+from cognee.modules.users.methods import get_default_user
+
+
+async def main():
+ # Set data and system directory paths
+ data_directory_path = str(
+ pathlib.Path(
+ os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_dataset_delete")
+ ).resolve()
+ )
+ cognee.config.data_root_directory(data_directory_path)
+ cognee_directory_path = str(
+ pathlib.Path(
+ os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_dataset_delete")
+ ).resolve()
+ )
+ cognee.config.system_root_directory(cognee_directory_path)
+
+ # Create a clean slate for cognee -- reset data and system state
+ print("Resetting cognee data...")
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ print("Data reset complete.\n")
+
+ # cognee knowledge graph will be created based on this text
+ text = """
+ Natural language processing (NLP) is an interdisciplinary
+ subfield of computer science and information retrieval.
+ """
+
+ # Add the text, and make it available for cognify
+ await cognee.add(text, "nlp_dataset")
+ await cognee.add("Quantum computing is the study of quantum computers.", "quantum_dataset")
+
+ # Use LLMs and cognee to create knowledge graph
+ ret_val = await cognee.cognify()
+ user = await get_default_user()
+
+ for val in ret_val:
+ dataset_id = str(val)
+ vector_db_path = os.path.join(
+ cognee_directory_path, "databases", str(user.id), dataset_id + ".lance.db"
+ )
+ graph_db_path = os.path.join(
+ cognee_directory_path, "databases", str(user.id), dataset_id + ".pkl"
+ )
+
+ # Check if databases are properly created and exist before deletion
+ assert os.path.exists(graph_db_path), "Graph database file not found."
+ assert os.path.exists(vector_db_path), "Vector database file not found."
+
+ dataset = await get_dataset(user_id=user.id, dataset_id=UUID(dataset_id))
+ await delete_dataset(dataset)
+
+ # Confirm databases have been deleted
+ assert not os.path.exists(graph_db_path), "Graph database file found."
+ assert not os.path.exists(vector_db_path), "Vector database file found."
+
+
+if __name__ == "__main__":
+ logger = setup_logging(log_level=ERROR)
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ try:
+ loop.run_until_complete(main())
+ finally:
+ loop.run_until_complete(loop.shutdown_asyncgens())
diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py
index ba150f813..0916be322 100644
--- a/cognee/tests/test_search_db.py
+++ b/cognee/tests/test_search_db.py
@@ -1,5 +1,10 @@
import pathlib
import os
+import asyncio
+import pytest
+import pytest_asyncio
+from collections import Counter
+
import cognee
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
@@ -13,127 +18,172 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
+from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
+from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
+from cognee.modules.retrieval.completion_retriever import CompletionRetriever
+from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user
-from collections import Counter
logger = get_logger()
-async def main():
- # This test runs for multiple db settings, to run this locally set the corresponding db envs
+async def _reset_engines_and_prune() -> None:
+ """Reset db engine caches and prune data/system.
+
+ Kept intentionally identical to the inlined setup logic to avoid event loop issues when
+ using deployed databases (Neo4j, PostgreSQL) and to ensure fresh instances per run.
+ """
+ # Dispose of existing engines and clear caches to ensure fresh instances for each test
+ try:
+ from cognee.infrastructure.databases.vector import get_vector_engine
+
+ vector_engine = get_vector_engine()
+ # Dispose SQLAlchemy engine connection pool if it exists
+ if hasattr(vector_engine, "engine") and hasattr(vector_engine.engine, "dispose"):
+ await vector_engine.engine.dispose(close=True)
+ except Exception:
+ # Engine might not exist yet
+ pass
+
+ from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
+ from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
+ from cognee.infrastructure.databases.relational.create_relational_engine import (
+ create_relational_engine,
+ )
+
+ create_graph_engine.cache_clear()
+ create_vector_engine.cache_clear()
+ create_relational_engine.cache_clear()
+
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
- dataset_name = "test_dataset"
+async def _seed_default_dataset(dataset_name: str) -> dict:
+ """Add the shared test dataset contents and run cognify (same steps/order as before)."""
text_1 = """Germany is located in europe right next to the Netherlands"""
+
+ logger.info(f"Adding text data to dataset: {dataset_name}")
await cognee.add(text_1, dataset_name)
explanation_file_path_quantum = os.path.join(
pathlib.Path(__file__).parent, "test_data/Quantum_computers.txt"
)
+ logger.info(f"Adding file data to dataset: {dataset_name}")
await cognee.add([explanation_file_path_quantum], dataset_name)
+ logger.info(f"Running cognify on dataset: {dataset_name}")
await cognee.cognify([dataset_name])
+ return {
+ "dataset_name": dataset_name,
+ "text_1": text_1,
+ "explanation_file_path_quantum": explanation_file_path_quantum,
+ }
+
+
+@pytest.fixture(scope="session")
+def event_loop():
+ """Use a single asyncio event loop for this test module.
+
+ This helps avoid "Future attached to a different loop" when running multiple async
+ tests that share clients/engines.
+ """
+ loop = asyncio.new_event_loop()
+ try:
+ yield loop
+ finally:
+ loop.close()
+
+
+async def setup_test_environment():
+ """Helper function to set up test environment with data, cognify, and triplet embeddings."""
+ # This test runs for multiple db settings, to run this locally set the corresponding db envs
+
+ dataset_name = "test_dataset"
+ logger.info("Starting test setup: pruning data and system")
+ await _reset_engines_and_prune()
+ state = await _seed_default_dataset(dataset_name=dataset_name)
+
user = await get_default_user()
from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings
+ logger.info("Creating triplet embeddings")
await create_triplet_embeddings(user=user, dataset=dataset_name, triplets_batch_size=5)
+ # Check if Triplet_text collection was created
+ vector_engine = get_vector_engine()
+ has_collection = await vector_engine.has_collection(collection_name="Triplet_text")
+ logger.info(f"Triplet_text collection exists after creation: {has_collection}")
+
+ if has_collection:
+ collection = await vector_engine.get_collection("Triplet_text")
+ count = await collection.count_rows() if hasattr(collection, "count_rows") else "unknown"
+ logger.info(f"Triplet_text collection row count: {count}")
+
+ return state
+
+
+async def setup_test_environment_for_feedback():
+ """Helper function to set up test environment for feedback weight calculation test."""
+ dataset_name = "test_dataset"
+ await _reset_engines_and_prune()
+ return await _seed_default_dataset(dataset_name=dataset_name)
+
+
+@pytest_asyncio.fixture(scope="session")
+async def e2e_state():
+ """Compute E2E artifacts once; tests only assert.
+
+ This avoids repeating expensive setup and LLM calls across multiple tests.
+ """
+ await setup_test_environment()
+
+ # --- Graph/vector engine consistency ---
graph_engine = await get_graph_engine()
- nodes, edges = await graph_engine.get_graph_data()
+ _nodes, edges = await graph_engine.get_graph_data()
vector_engine = get_vector_engine()
collection = await vector_engine.search(
- query_text="Test", limit=None, collection_name="Triplet_text"
+ collection_name="Triplet_text", query_text="Test", limit=None
)
- assert len(edges) == len(collection), (
- f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection"
- )
+ # --- Retriever contexts ---
+ query = "Next to which country is Germany located?"
- context_gk = await GraphCompletionRetriever().get_context(
- query="Next to which country is Germany located?"
- )
- context_gk_cot = await GraphCompletionCotRetriever().get_context(
- query="Next to which country is Germany located?"
- )
- context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context(
- query="Next to which country is Germany located?"
- )
- context_gk_sum = await GraphSummaryCompletionRetriever().get_context(
- query="Next to which country is Germany located?"
- )
- context_triplet = await TripletRetriever().get_context(
- query="Next to which country is Germany located?"
- )
+ contexts = {
+ "graph_completion": await GraphCompletionRetriever().get_context(query=query),
+ "graph_completion_cot": await GraphCompletionCotRetriever().get_context(query=query),
+ "graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_context(
+ query=query
+ ),
+ "graph_summary_completion": await GraphSummaryCompletionRetriever().get_context(
+ query=query
+ ),
+ "chunks": await ChunksRetriever(top_k=5).get_context(query=query),
+ "summaries": await SummariesRetriever(top_k=5).get_context(query=query),
+ "rag_completion": await CompletionRetriever(top_k=3).get_context(query=query),
+ "temporal": await TemporalRetriever(top_k=5).get_context(query=query),
+ "triplet": await TripletRetriever().get_context(query=query),
+ }
- for name, context in [
- ("GraphCompletionRetriever", context_gk),
- ("GraphCompletionCotRetriever", context_gk_cot),
- ("GraphCompletionContextExtensionRetriever", context_gk_ext),
- ("GraphSummaryCompletionRetriever", context_gk_sum),
- ]:
- assert isinstance(context, list), f"{name}: Context should be a list"
- assert len(context) > 0, f"{name}: Context should not be empty"
-
- context_text = await resolve_edges_to_text(context)
- lower = context_text.lower()
- assert "germany" in lower or "netherlands" in lower, (
- f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
- )
-
- assert isinstance(context_triplet, str), "TripletRetriever: Context should be a string"
- assert len(context_triplet) > 0, "TripletRetriever: Context should not be empty"
- lower_triplet = context_triplet.lower()
- assert "germany" in lower_triplet or "netherlands" in lower_triplet, (
- f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
- )
-
- triplets_gk = await GraphCompletionRetriever().get_triplets(
- query="Next to which country is Germany located?"
- )
- triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets(
- query="Next to which country is Germany located?"
- )
- triplets_gk_ext = await GraphCompletionContextExtensionRetriever().get_triplets(
- query="Next to which country is Germany located?"
- )
- triplets_gk_sum = await GraphSummaryCompletionRetriever().get_triplets(
- query="Next to which country is Germany located?"
- )
-
- for name, triplets in [
- ("GraphCompletionRetriever", triplets_gk),
- ("GraphCompletionCotRetriever", triplets_gk_cot),
- ("GraphCompletionContextExtensionRetriever", triplets_gk_ext),
- ("GraphSummaryCompletionRetriever", triplets_gk_sum),
- ]:
- assert isinstance(triplets, list), f"{name}: Triplets should be a list"
- assert triplets, f"{name}: Triplets list should not be empty"
- for edge in triplets:
- assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
- distance = edge.attributes.get("vector_distance")
- node1_distance = edge.node1.attributes.get("vector_distance")
- node2_distance = edge.node2.attributes.get("vector_distance")
- assert isinstance(distance, float), (
- f"{name}: vector_distance should be float, got {type(distance)}"
- )
- assert 0 <= distance <= 1, (
- f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen"
- )
- assert 0 <= node1_distance <= 1, (
- f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen"
- )
- assert 0 <= node2_distance <= 1, (
- f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen"
- )
+ # --- Retriever triplets + vector distance validation ---
+ triplets = {
+ "graph_completion": await GraphCompletionRetriever().get_triplets(query=query),
+ "graph_completion_cot": await GraphCompletionCotRetriever().get_triplets(query=query),
+ "graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_triplets(
+ query=query
+ ),
+ "graph_summary_completion": await GraphSummaryCompletionRetriever().get_triplets(
+ query=query
+ ),
+ }
+ # --- Search operations + graph side effects ---
completion_gk = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text="Where is germany located, next to which country?",
@@ -164,6 +214,26 @@ async def main():
query_text="Next to which country is Germany located?",
save_interaction=True,
)
+ completion_chunks = await cognee.search(
+ query_type=SearchType.CHUNKS,
+ query_text="Germany",
+ save_interaction=False,
+ )
+ completion_summaries = await cognee.search(
+ query_type=SearchType.SUMMARIES,
+ query_text="Germany",
+ save_interaction=False,
+ )
+ completion_rag = await cognee.search(
+ query_type=SearchType.RAG_COMPLETION,
+ query_text="Next to which country is Germany located?",
+ save_interaction=False,
+ )
+ completion_temporal = await cognee.search(
+ query_type=SearchType.TEMPORAL,
+ query_text="Next to which country is Germany located?",
+ save_interaction=False,
+ )
await cognee.search(
query_type=SearchType.FEEDBACK,
@@ -171,134 +241,217 @@ async def main():
last_k=1,
)
- for name, search_results in [
- ("GRAPH_COMPLETION", completion_gk),
- ("GRAPH_COMPLETION_COT", completion_cot),
- ("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
- ("GRAPH_SUMMARY_COMPLETION", completion_sum),
- ("TRIPLET_COMPLETION", completion_triplet),
- ]:
- assert isinstance(search_results, list), f"{name}: should return a list"
- assert len(search_results) == 1, (
- f"{name}: expected single-element list, got {len(search_results)}"
- )
+ # Snapshot after all E2E operations above (used by assertion-only tests).
+ graph_snapshot = await (await get_graph_engine()).get_graph_data()
- from cognee.context_global_variables import backend_access_control_enabled
+ return {
+ "graph_edges": edges,
+ "triplet_collection": collection,
+ "vector_collection_edges_count": len(collection),
+ "graph_edges_count": len(edges),
+ "contexts": contexts,
+ "triplets": triplets,
+ "search_results": {
+ "graph_completion": completion_gk,
+ "graph_completion_cot": completion_cot,
+ "graph_completion_context_extension": completion_ext,
+ "graph_summary_completion": completion_sum,
+ "triplet_completion": completion_triplet,
+ "chunks": completion_chunks,
+ "summaries": completion_summaries,
+ "rag_completion": completion_rag,
+ "temporal": completion_temporal,
+ },
+ "graph_snapshot": graph_snapshot,
+ }
- if backend_access_control_enabled():
- text = search_results[0]["search_result"][0]
- else:
- text = search_results[0]
- assert isinstance(text, str), f"{name}: element should be a string"
- assert text.strip(), f"{name}: string should not be empty"
- assert "netherlands" in text.lower(), (
- f"{name}: expected 'netherlands' in result, got: {text!r}"
- )
- graph_engine = await get_graph_engine()
- graph = await graph_engine.get_graph_data()
-
- type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0])
-
- edge_type_counts = Counter(edge_type[2] for edge_type in graph[1])
-
- # Assert there are exactly 4 CogneeUserInteraction nodes.
- assert type_counts.get("CogneeUserInteraction", 0) == 4, (
- f"Expected exactly four CogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}"
- )
-
- # Assert there is exactly two CogneeUserFeedback nodes.
- assert type_counts.get("CogneeUserFeedback", 0) == 2, (
- f"Expected exactly two CogneeUserFeedback nodes, but found {type_counts.get('CogneeUserFeedback', 0)}"
- )
-
- # Assert there is exactly two NodeSet.
- assert type_counts.get("NodeSet", 0) == 2, (
- f"Expected exactly two NodeSet nodes, but found {type_counts.get('NodeSet', 0)}"
- )
-
- # Assert that there are at least 10 'used_graph_element_to_answer' edges.
- assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10, (
- f"Expected at least ten 'used_graph_element_to_answer' edges, but found {edge_type_counts.get('used_graph_element_to_answer', 0)}"
- )
-
- # Assert that there are exactly 2 'gives_feedback_to' edges.
- assert edge_type_counts.get("gives_feedback_to", 0) == 2, (
- f"Expected exactly two 'gives_feedback_to' edges, but found {edge_type_counts.get('gives_feedback_to', 0)}"
- )
-
- # Assert that there are at least 6 'belongs_to_set' edges.
- assert edge_type_counts.get("belongs_to_set", 0) == 6, (
- f"Expected at least six 'belongs_to_set' edges, but found {edge_type_counts.get('belongs_to_set', 0)}"
- )
-
- nodes = graph[0]
-
- required_fields_user_interaction = {"question", "answer", "context"}
- required_fields_feedback = {"feedback", "sentiment"}
-
- for node_id, data in nodes:
- if data.get("type") == "CogneeUserInteraction":
- assert required_fields_user_interaction.issubset(data.keys()), (
- f"Node {node_id} is missing fields: {required_fields_user_interaction - set(data.keys())}"
- )
-
- for field in required_fields_user_interaction:
- value = data[field]
- assert isinstance(value, str) and value.strip(), (
- f"Node {node_id} has invalid value for '{field}': {value!r}"
- )
-
- if data.get("type") == "CogneeUserFeedback":
- assert required_fields_feedback.issubset(data.keys()), (
- f"Node {node_id} is missing fields: {required_fields_feedback - set(data.keys())}"
- )
-
- for field in required_fields_feedback:
- value = data[field]
- assert isinstance(value, str) and value.strip(), (
- f"Node {node_id} has invalid value for '{field}': {value!r}"
- )
-
- await cognee.prune.prune_data()
- await cognee.prune.prune_system(metadata=True)
-
- await cognee.add(text_1, dataset_name)
-
- await cognee.add([text], dataset_name)
-
- await cognee.cognify([dataset_name])
+@pytest_asyncio.fixture(scope="session")
+async def feedback_state():
+ """Feedback-weight scenario computed once (fresh environment)."""
+ await setup_test_environment_for_feedback()
await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text="Next to which country is Germany located?",
save_interaction=True,
)
-
await cognee.search(
query_type=SearchType.FEEDBACK,
query_text="This was the best answer I've ever seen",
last_k=1,
)
-
await cognee.search(
query_type=SearchType.FEEDBACK,
query_text="Wow the correctness of this answer blows my mind",
last_k=1,
)
+ graph_engine = await get_graph_engine()
graph = await graph_engine.get_graph_data()
+ return {"graph_snapshot": graph}
- edges = graph[1]
- for from_node, to_node, relationship_name, properties in edges:
+@pytest.mark.asyncio
+async def test_e2e_graph_vector_consistency(e2e_state):
+ """Graph and vector stores contain the same triplet edges."""
+ assert e2e_state["graph_edges_count"] == e2e_state["vector_collection_edges_count"]
+
+
+@pytest.mark.asyncio
+async def test_e2e_retriever_contexts(e2e_state):
+ """All retrievers return non-empty, well-typed contexts."""
+ contexts = e2e_state["contexts"]
+
+ for name in [
+ "graph_completion",
+ "graph_completion_cot",
+ "graph_completion_context_extension",
+ "graph_summary_completion",
+ ]:
+ ctx = contexts[name]
+ assert isinstance(ctx, list), f"{name}: Context should be a list"
+ assert ctx, f"{name}: Context should not be empty"
+ ctx_text = await resolve_edges_to_text(ctx)
+ lower = ctx_text.lower()
+ assert "germany" in lower or "netherlands" in lower, (
+ f"{name}: Context did not contain 'germany' or 'netherlands'; got: {ctx!r}"
+ )
+
+ triplet_ctx = contexts["triplet"]
+ assert isinstance(triplet_ctx, str), "triplet: Context should be a string"
+ assert triplet_ctx.strip(), "triplet: Context should not be empty"
+
+ chunks_ctx = contexts["chunks"]
+ assert isinstance(chunks_ctx, list), "chunks: Context should be a list"
+ assert chunks_ctx, "chunks: Context should not be empty"
+ chunks_text = "\n".join(str(item.get("text", "")) for item in chunks_ctx).lower()
+ assert "germany" in chunks_text or "netherlands" in chunks_text
+
+ summaries_ctx = contexts["summaries"]
+ assert isinstance(summaries_ctx, list), "summaries: Context should be a list"
+ assert summaries_ctx, "summaries: Context should not be empty"
+ assert any(str(item.get("text", "")).strip() for item in summaries_ctx)
+
+ rag_ctx = contexts["rag_completion"]
+ assert isinstance(rag_ctx, str), "rag_completion: Context should be a string"
+ assert rag_ctx.strip(), "rag_completion: Context should not be empty"
+
+ temporal_ctx = contexts["temporal"]
+ assert isinstance(temporal_ctx, str), "temporal: Context should be a string"
+ assert temporal_ctx.strip(), "temporal: Context should not be empty"
+
+
+@pytest.mark.asyncio
+async def test_e2e_retriever_triplets_have_vector_distances(e2e_state):
+ """Graph retriever triplets include sane vector_distance metadata."""
+ for name, triplets in e2e_state["triplets"].items():
+ assert isinstance(triplets, list), f"{name}: Triplets should be a list"
+ assert triplets, f"{name}: Triplets list should not be empty"
+ for edge in triplets:
+ assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
+ distance = edge.attributes.get("vector_distance")
+ node1_distance = edge.node1.attributes.get("vector_distance")
+ node2_distance = edge.node2.attributes.get("vector_distance")
+ assert isinstance(distance, float), f"{name}: vector_distance should be float"
+ assert 0 <= distance <= 1
+ assert 0 <= node1_distance <= 1
+ assert 0 <= node2_distance <= 1
+
+
+@pytest.mark.asyncio
+async def test_e2e_search_results_and_wrappers(e2e_state):
+ """Search returns expected shapes across search types and access modes."""
+ from cognee.context_global_variables import backend_access_control_enabled
+
+ sr = e2e_state["search_results"]
+
+ # Completion-like search types: validate wrapper + content
+ for name in [
+ "graph_completion",
+ "graph_completion_cot",
+ "graph_completion_context_extension",
+ "graph_summary_completion",
+ "triplet_completion",
+ "rag_completion",
+ "temporal",
+ ]:
+ search_results = sr[name]
+ assert isinstance(search_results, list), f"{name}: should return a list"
+ assert len(search_results) == 1, f"{name}: expected single-element list"
+
+ if backend_access_control_enabled():
+ wrapper = search_results[0]
+ assert isinstance(wrapper, dict), (
+ f"{name}: expected wrapper dict in access control mode"
+ )
+ assert wrapper.get("dataset_id"), f"{name}: missing dataset_id in wrapper"
+ assert wrapper.get("dataset_name") == "test_dataset"
+ assert "graphs" in wrapper
+ text = wrapper["search_result"][0]
+ else:
+ text = search_results[0]
+
+ assert isinstance(text, str) and text.strip()
+ assert "netherlands" in text.lower()
+
+ # Non-LLM search types: CHUNKS / SUMMARIES validate payload list + text
+ for name in ["chunks", "summaries"]:
+ search_results = sr[name]
+ assert isinstance(search_results, list), f"{name}: should return a list"
+ assert search_results, f"{name}: should not be empty"
+
+ first = search_results[0]
+ assert isinstance(first, dict), f"{name}: expected dict entries"
+
+ payloads = search_results
+ if "search_result" in first and "text" not in first:
+ payloads = (first.get("search_result") or [None])[0]
+
+ assert isinstance(payloads, list) and payloads
+ assert isinstance(payloads[0], dict)
+ assert str(payloads[0].get("text", "")).strip()
+
+
+@pytest.mark.asyncio
+async def test_e2e_graph_side_effects_and_node_fields(e2e_state):
+ """Search interactions create expected graph nodes/edges and required fields."""
+ graph = e2e_state["graph_snapshot"]
+ nodes, edges = graph
+
+ type_counts = Counter(node_data[1].get("type", {}) for node_data in nodes)
+ edge_type_counts = Counter(edge_type[2] for edge_type in edges)
+
+ assert type_counts.get("CogneeUserInteraction", 0) == 4
+ assert type_counts.get("CogneeUserFeedback", 0) == 2
+ assert type_counts.get("NodeSet", 0) == 2
+ assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10
+ assert edge_type_counts.get("gives_feedback_to", 0) == 2
+ assert edge_type_counts.get("belongs_to_set", 0) >= 6
+
+ required_fields_user_interaction = {"question", "answer", "context"}
+ required_fields_feedback = {"feedback", "sentiment"}
+
+ for node_id, data in nodes:
+ if data.get("type") == "CogneeUserInteraction":
+ assert required_fields_user_interaction.issubset(data.keys())
+ for field in required_fields_user_interaction:
+ value = data[field]
+ assert isinstance(value, str) and value.strip()
+
+ if data.get("type") == "CogneeUserFeedback":
+ assert required_fields_feedback.issubset(data.keys())
+ for field in required_fields_feedback:
+ value = data[field]
+ assert isinstance(value, str) and value.strip()
+
+
+@pytest.mark.asyncio
+async def test_e2e_feedback_weight_calculation(feedback_state):
+ """Positive feedback increases used_graph_element_to_answer feedback_weight."""
+ _nodes, edges = feedback_state["graph_snapshot"]
+ for _from_node, _to_node, relationship_name, properties in edges:
if relationship_name == "used_graph_element_to_answer":
assert properties["feedback_weight"] >= 6, (
"Feedback weight calculation is not correct, it should be more then 6."
)
-
-
-if __name__ == "__main__":
- import asyncio
-
- asyncio.run(main())
diff --git a/cognee/tests/unit/api/test_ontology_endpoint.py b/cognee/tests/unit/api/test_ontology_endpoint.py
index af3a4d90e..e072ceda8 100644
--- a/cognee/tests/unit/api/test_ontology_endpoint.py
+++ b/cognee/tests/unit/api/test_ontology_endpoint.py
@@ -1,17 +1,28 @@
import pytest
import uuid
from fastapi.testclient import TestClient
-from unittest.mock import patch, Mock, AsyncMock
+from unittest.mock import Mock
from types import SimpleNamespace
-import importlib
from cognee.api.client import app
+from cognee.modules.users.methods import get_authenticated_user
-gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
+
+@pytest.fixture(scope="session")
+def test_client():
+ # Keep a single TestClient (and event loop) for the whole module.
+ # Re-creating TestClient repeatedly can break async DB connections (asyncpg loop mismatch).
+ with TestClient(app) as c:
+ yield c
@pytest.fixture
-def client():
- return TestClient(app)
+def client(test_client, mock_default_user):
+ async def override_get_authenticated_user():
+ return mock_default_user
+
+ app.dependency_overrides[get_authenticated_user] = override_get_authenticated_user
+ yield test_client
+ app.dependency_overrides.pop(get_authenticated_user, None)
@pytest.fixture
@@ -32,12 +43,8 @@ def mock_default_user():
)
-@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
-def test_upload_ontology_success(mock_get_default_user, client, mock_default_user):
+def test_upload_ontology_success(client):
"""Test successful ontology upload"""
- import json
-
- mock_get_default_user.return_value = mock_default_user
ontology_content = (
b""
)
@@ -46,7 +53,7 @@ def test_upload_ontology_success(mock_get_default_user, client, mock_default_use
response = client.post(
"/api/v1/ontologies",
files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
- data={"ontology_key": json.dumps([unique_key]), "description": json.dumps(["Test"])},
+ data={"ontology_key": unique_key, "description": "Test"},
)
assert response.status_code == 200
@@ -55,10 +62,8 @@ def test_upload_ontology_success(mock_get_default_user, client, mock_default_use
assert "uploaded_at" in data["uploaded_ontologies"][0]
-@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
-def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_default_user):
+def test_upload_ontology_invalid_file(client):
"""Test 400 response for non-.owl files"""
- mock_get_default_user.return_value = mock_default_user
unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
response = client.post(
"/api/v1/ontologies",
@@ -68,14 +73,10 @@ def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_defaul
assert response.status_code == 400
-@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
-def test_upload_ontology_missing_data(mock_get_default_user, client, mock_default_user):
+def test_upload_ontology_missing_data(client):
"""Test 400 response for missing file or key"""
- import json
-
- mock_get_default_user.return_value = mock_default_user
# Missing file
- response = client.post("/api/v1/ontologies", data={"ontology_key": json.dumps(["test"])})
+ response = client.post("/api/v1/ontologies", data={"ontology_key": "test"})
assert response.status_code == 400
# Missing key
@@ -85,34 +86,25 @@ def test_upload_ontology_missing_data(mock_get_default_user, client, mock_defaul
assert response.status_code == 400
-@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
-def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_default_user):
- """Test behavior when default user is provided (no explicit authentication)"""
- import json
-
+def test_upload_ontology_without_auth_header(client):
+ """Test behavior when no explicit authentication header is provided."""
unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
- mock_get_default_user.return_value = mock_default_user
response = client.post(
"/api/v1/ontologies",
files=[("ontology_file", ("test.owl", b"", "application/xml"))],
- data={"ontology_key": json.dumps([unique_key])},
+ data={"ontology_key": unique_key},
)
- # The current system provides a default user when no explicit authentication is given
- # This test verifies the system works with conditional authentication
assert response.status_code == 200
data = response.json()
assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key
assert "uploaded_at" in data["uploaded_ontologies"][0]
-@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
-def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_user):
- """Test uploading multiple ontology files in single request"""
+def test_upload_multiple_ontologies_in_single_request_is_rejected(client):
+ """Uploading multiple ontology files in a single request should fail."""
import io
- mock_get_default_user.return_value = mock_default_user
- # Create mock files
file1_content = b""
file2_content = b""
@@ -120,45 +112,34 @@ def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_
("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")),
("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")),
]
- data = {
- "ontology_key": '["vehicles", "manufacturers"]',
- "descriptions": '["Base vehicles", "Car manufacturers"]',
- }
+ data = {"ontology_key": "vehicles", "description": "Base vehicles"}
response = client.post("/api/v1/ontologies", files=files, data=data)
- assert response.status_code == 200
- result = response.json()
- assert "uploaded_ontologies" in result
- assert len(result["uploaded_ontologies"]) == 2
- assert result["uploaded_ontologies"][0]["ontology_key"] == "vehicles"
- assert result["uploaded_ontologies"][1]["ontology_key"] == "manufacturers"
+ assert response.status_code == 400
+ assert "Only one ontology_file is allowed" in response.json()["error"]
-@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
-def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_default_user):
- """Test that upload endpoint accepts array parameters"""
+def test_upload_endpoint_rejects_array_style_fields(client):
+ """Array-style form values should be rejected (no backwards compatibility)."""
import io
import json
- mock_get_default_user.return_value = mock_default_user
file_content = b""
files = [("ontology_file", ("single.owl", io.BytesIO(file_content), "application/xml"))]
data = {
"ontology_key": json.dumps(["single_key"]),
- "descriptions": json.dumps(["Single ontology"]),
+ "description": json.dumps(["Single ontology"]),
}
response = client.post("/api/v1/ontologies", files=files, data=data)
- assert response.status_code == 200
- result = response.json()
- assert result["uploaded_ontologies"][0]["ontology_key"] == "single_key"
+ assert response.status_code == 400
+ assert "ontology_key must be a string" in response.json()["error"]
-@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
-def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_default_user):
+def test_cognify_with_multiple_ontologies(client):
"""Test cognify endpoint accepts multiple ontology keys"""
payload = {
"datasets": ["test_dataset"],
@@ -172,14 +153,11 @@ def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_de
assert response.status_code in [200, 400, 409] # May fail for other reasons, not type
-@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
-def test_complete_multifile_workflow(mock_get_default_user, client, mock_default_user):
- """Test complete workflow: upload multiple ontologies โ cognify with multiple keys"""
+def test_complete_multifile_workflow(client):
+ """Test workflow: upload ontologies one-by-one โ cognify with multiple keys"""
import io
- import json
- mock_get_default_user.return_value = mock_default_user
- # Step 1: Upload multiple ontologies
+ # Step 1: Upload two ontologies (one-by-one)
file1_content = b"""
@@ -192,17 +170,21 @@ def test_complete_multifile_workflow(mock_get_default_user, client, mock_default
"""
- files = [
- ("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")),
- ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")),
- ]
- data = {
- "ontology_key": json.dumps(["vehicles", "manufacturers"]),
- "descriptions": json.dumps(["Vehicle ontology", "Manufacturer ontology"]),
- }
+ upload_response_1 = client.post(
+ "/api/v1/ontologies",
+ files=[("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml"))],
+ data={"ontology_key": "vehicles", "description": "Vehicle ontology"},
+ )
+ assert upload_response_1.status_code == 200
- upload_response = client.post("/api/v1/ontologies", files=files, data=data)
- assert upload_response.status_code == 200
+ upload_response_2 = client.post(
+ "/api/v1/ontologies",
+ files=[
+ ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml"))
+ ],
+ data={"ontology_key": "manufacturers", "description": "Manufacturer ontology"},
+ )
+ assert upload_response_2.status_code == 200
# Step 2: Verify ontologies are listed
list_response = client.get("/api/v1/ontologies")
@@ -223,44 +205,42 @@ def test_complete_multifile_workflow(mock_get_default_user, client, mock_default
assert cognify_response.status_code != 400 # Not a validation error
-@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
-def test_multifile_error_handling(mock_get_default_user, client, mock_default_user):
- """Test error handling for invalid multifile uploads"""
+def test_upload_error_handling(client):
+ """Test error handling for invalid uploads (single-file endpoint)."""
import io
import json
- # Test mismatched array lengths
+ # Array-style key should be rejected
file_content = b""
files = [("ontology_file", ("test.owl", io.BytesIO(file_content), "application/xml"))]
data = {
- "ontology_key": json.dumps(["key1", "key2"]), # 2 keys, 1 file
- "descriptions": json.dumps(["desc1"]),
+ "ontology_key": json.dumps(["key1", "key2"]),
+ "description": "desc1",
}
response = client.post("/api/v1/ontologies", files=files, data=data)
assert response.status_code == 400
- assert "Number of keys must match number of files" in response.json()["error"]
+ assert "ontology_key must be a string" in response.json()["error"]
- # Test duplicate keys
- files = [
- ("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml")),
- ("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml")),
- ]
- data = {
- "ontology_key": json.dumps(["duplicate", "duplicate"]),
- "descriptions": json.dumps(["desc1", "desc2"]),
- }
+ # Duplicate key should be rejected
+ response_1 = client.post(
+ "/api/v1/ontologies",
+ files=[("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml"))],
+ data={"ontology_key": "duplicate", "description": "desc1"},
+ )
+ assert response_1.status_code == 200
- response = client.post("/api/v1/ontologies", files=files, data=data)
- assert response.status_code == 400
- assert "Duplicate ontology keys not allowed" in response.json()["error"]
+ response_2 = client.post(
+ "/api/v1/ontologies",
+ files=[("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml"))],
+ data={"ontology_key": "duplicate", "description": "desc2"},
+ )
+ assert response_2.status_code == 400
+ assert "already exists" in response_2.json()["error"]
-@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
-def test_cognify_missing_ontology_key(mock_get_default_user, client, mock_default_user):
+def test_cognify_missing_ontology_key(client):
"""Test cognify with non-existent ontology key"""
- mock_get_default_user.return_value = mock_default_user
-
payload = {
"datasets": ["test_dataset"],
"ontology_key": ["nonexistent_key"],
diff --git a/cognee/tests/unit/eval_framework/benchmark_adapters_test.py b/cognee/tests/unit/eval_framework/benchmark_adapters_test.py
index 70ec43cf8..b18012594 100644
--- a/cognee/tests/unit/eval_framework/benchmark_adapters_test.py
+++ b/cognee/tests/unit/eval_framework/benchmark_adapters_test.py
@@ -11,6 +11,22 @@ MOCK_JSONL_DATA = """\
{"id": "2", "question": "What is ML?", "answer": "Machine Learning", "paragraphs": [{"paragraph_text": "ML is a subset of AI."}]}
"""
+MOCK_HOTPOT_CORPUS = [
+ {
+ "_id": "1",
+ "question": "Next to which country is Germany located?",
+ "answer": "Netherlands",
+ # HotpotQA uses "level"; TwoWikiMultiHop uses "type".
+ "level": "easy",
+ "type": "comparison",
+ "context": [
+ ["Germany", ["Germany is in Europe."]],
+ ["Netherlands", ["The Netherlands borders Germany."]],
+ ],
+ "supporting_facts": [["Netherlands", 0]],
+ }
+]
+
ADAPTER_CLASSES = [
HotpotQAAdapter,
@@ -35,6 +51,11 @@ def test_adapter_can_instantiate_and_load(AdapterClass):
adapter = AdapterClass()
result = adapter.load_corpus()
+ elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
+ with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
+ adapter = AdapterClass()
+ result = adapter.load_corpus()
+
else:
adapter = AdapterClass()
result = adapter.load_corpus()
@@ -64,6 +85,10 @@ def test_adapter_returns_some_content(AdapterClass):
):
adapter = AdapterClass()
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
+ elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
+ with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
+ adapter = AdapterClass()
+ corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
else:
adapter = AdapterClass()
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
diff --git a/cognee/tests/unit/eval_framework/corpus_builder_test.py b/cognee/tests/unit/eval_framework/corpus_builder_test.py
index 14136bea5..53f886b58 100644
--- a/cognee/tests/unit/eval_framework/corpus_builder_test.py
+++ b/cognee/tests/unit/eval_framework/corpus_builder_test.py
@@ -2,15 +2,38 @@ import pytest
from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor
from cognee.infrastructure.databases.graph import get_graph_engine
from unittest.mock import AsyncMock, patch
+from cognee.eval_framework.benchmark_adapters.hotpot_qa_adapter import HotpotQAAdapter
benchmark_options = ["HotPotQA", "Dummy", "TwoWikiMultiHop"]
+MOCK_HOTPOT_CORPUS = [
+ {
+ "_id": "1",
+ "question": "Next to which country is Germany located?",
+ "answer": "Netherlands",
+ # HotpotQA uses "level"; TwoWikiMultiHop uses "type".
+ "level": "easy",
+ "type": "comparison",
+ "context": [
+ ["Germany", ["Germany is in Europe."]],
+ ["Netherlands", ["The Netherlands borders Germany."]],
+ ],
+ "supporting_facts": [["Netherlands", 0]],
+ }
+]
+
@pytest.mark.parametrize("benchmark", benchmark_options)
def test_corpus_builder_load_corpus(benchmark):
limit = 2
- corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
- raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
+ if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
+ with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
+ corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
+ raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
+ else:
+ corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
+ raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
+
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
assert len(questions) <= 2, (
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
@@ -22,8 +45,14 @@ def test_corpus_builder_load_corpus(benchmark):
@patch.object(CorpusBuilderExecutor, "run_cognee", new_callable=AsyncMock)
async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
limit = 2
- corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
- questions = await corpus_builder.build_corpus(limit=limit)
+ if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
+ with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
+ corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
+ questions = await corpus_builder.build_corpus(limit=limit)
+ else:
+ corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
+ questions = await corpus_builder.build_corpus(limit=limit)
+
assert len(questions) <= 2, (
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
)
diff --git a/pyproject.toml b/pyproject.toml
index 8e4ed8a0d..cf2081d0a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "cognee"
-version = "0.5.0.dev0"
+version = "0.5.0.dev1"
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
authors = [
{ name = "Vasilije Markovic" },
diff --git a/uv.lock b/uv.lock
index fccab8c40..884fb63be 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,5 @@
version = 1
-revision = 2
+revision = 3
requires-python = ">=3.10, <3.14"
resolution-markers = [
"python_full_version >= '3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'",
@@ -946,7 +946,7 @@ wheels = [
[[package]]
name = "cognee"
-version = "0.5.0.dev0"
+version = "0.5.0.dev1"
source = { editable = "." }
dependencies = [
{ name = "aiofiles" },