Merge branch 'dev' into feature/cog-3532-empower-test_search-db-retrievers-tests-reorg-3
This commit is contained in:
commit
646894d7c5
53 changed files with 3033 additions and 601 deletions
25
.github/workflows/e2e_tests.yml
vendored
25
.github/workflows/e2e_tests.yml
vendored
|
|
@ -237,6 +237,31 @@ jobs:
|
||||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
run: uv run python ./cognee/tests/test_dataset_database_handler.py
|
run: uv run python ./cognee/tests/test_dataset_database_handler.py
|
||||||
|
|
||||||
|
test-dataset-database-deletion:
|
||||||
|
name: Test dataset database deletion in Cognee
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Check out repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Cognee Setup
|
||||||
|
uses: ./.github/actions/cognee_setup
|
||||||
|
with:
|
||||||
|
python-version: '3.11.x'
|
||||||
|
|
||||||
|
- name: Run dataset databases deletion test
|
||||||
|
env:
|
||||||
|
ENV: 'dev'
|
||||||
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||||
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
|
run: uv run python ./cognee/tests/test_dataset_delete.py
|
||||||
|
|
||||||
test-permissions:
|
test-permissions:
|
||||||
name: Test permissions with different situations in Cognee
|
name: Test permissions with different situations in Cognee
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
|
|
|
||||||
154
.github/workflows/release.yml
vendored
Normal file
154
.github/workflows/release.yml
vendored
Normal file
|
|
@ -0,0 +1,154 @@
|
||||||
|
name: release.yml
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
flavour:
|
||||||
|
required: true
|
||||||
|
default: dev
|
||||||
|
type: choice
|
||||||
|
options:
|
||||||
|
- dev
|
||||||
|
- main
|
||||||
|
description: Dev or Main release
|
||||||
|
test_mode:
|
||||||
|
required: true
|
||||||
|
type: boolean
|
||||||
|
description: Aka Dry Run. If true, it won't affect public indices or repositories
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
release-github:
|
||||||
|
name: Create GitHub Release from ${{ inputs.flavour }}
|
||||||
|
outputs:
|
||||||
|
tag: ${{ steps.create_tag.outputs.tag }}
|
||||||
|
version: ${{ steps.create_tag.outputs.version }}
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Check out ${{ inputs.flavour }}
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.flavour }}
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
|
- name: Create and push git tag
|
||||||
|
id: create_tag
|
||||||
|
env:
|
||||||
|
TEST_MODE: ${{ inputs.test_mode }}
|
||||||
|
run: |
|
||||||
|
VERSION="$(uv version --short)"
|
||||||
|
TAG="v${VERSION}"
|
||||||
|
|
||||||
|
echo "Tag to create: ${TAG}"
|
||||||
|
|
||||||
|
git config user.name "github-actions[bot]"
|
||||||
|
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||||
|
|
||||||
|
echo "tag=${TAG}" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
if [ "$TEST_MODE" = "false" ]; then
|
||||||
|
git tag "${TAG}"
|
||||||
|
git push origin "${TAG}"
|
||||||
|
else
|
||||||
|
echo "Test mode is enabled. Skipping tag creation and push."
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Create GitHub Release
|
||||||
|
uses: softprops/action-gh-release@v2
|
||||||
|
with:
|
||||||
|
tag_name: ${{ steps.create_tag.outputs.tag }}
|
||||||
|
generate_release_notes: true
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
release-pypi-package:
|
||||||
|
needs: release-github
|
||||||
|
name: Release PyPI Package from ${{ inputs.flavour }}
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Check out ${{ inputs.flavour }}
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.flavour }}
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
|
- name: Install Python
|
||||||
|
run: uv python install
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: uv sync --locked --all-extras
|
||||||
|
|
||||||
|
- name: Build distributions
|
||||||
|
run: uv build
|
||||||
|
|
||||||
|
- name: Publish ${{ inputs.flavour }} release to TestPyPI
|
||||||
|
if: ${{ inputs.test_mode }}
|
||||||
|
env:
|
||||||
|
UV_PUBLISH_TOKEN: ${{ secrets.TEST_PYPI_TOKEN }}
|
||||||
|
run: uv publish --publish-url https://test.pypi.org/legacy/
|
||||||
|
|
||||||
|
- name: Publish ${{ inputs.flavour }} release to PyPI
|
||||||
|
if: ${{ !inputs.test_mode }}
|
||||||
|
env:
|
||||||
|
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
|
||||||
|
run: uv publish
|
||||||
|
|
||||||
|
release-docker-image:
|
||||||
|
needs: release-github
|
||||||
|
name: Release Docker Image from ${{ inputs.flavour }}
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Check out ${{ inputs.flavour }}
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.flavour }}
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Log in to Docker Hub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKER_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||||
|
|
||||||
|
- name: Build and push Dev Docker Image
|
||||||
|
if: ${{ inputs.flavour == 'dev' }}
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
platforms: linux/amd64,linux/arm64
|
||||||
|
push: ${{ !inputs.test_mode }}
|
||||||
|
tags: cognee/cognee:${{ needs.release-github.outputs.version }}
|
||||||
|
labels: |
|
||||||
|
version=${{ needs.release-github.outputs.version }}
|
||||||
|
flavour=${{ inputs.flavour }}
|
||||||
|
cache-from: type=registry,ref=cognee/cognee:buildcache
|
||||||
|
cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max
|
||||||
|
|
||||||
|
- name: Build and push Main Docker Image
|
||||||
|
if: ${{ inputs.flavour == 'main' }}
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
platforms: linux/amd64,linux/arm64
|
||||||
|
push: ${{ !inputs.test_mode }}
|
||||||
|
tags: |
|
||||||
|
cognee/cognee:${{ needs.release-github.outputs.version }}
|
||||||
|
cognee/cognee:latest
|
||||||
|
labels: |
|
||||||
|
version=${{ needs.release-github.outputs.version }}
|
||||||
|
flavour=${{ inputs.flavour }}
|
||||||
|
cache-from: type=registry,ref=cognee/cognee:buildcache
|
||||||
|
cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max
|
||||||
46
.github/workflows/search_db_tests.yml
vendored
46
.github/workflows/search_db_tests.yml
vendored
|
|
@ -11,12 +11,21 @@ on:
|
||||||
type: string
|
type: string
|
||||||
default: "all"
|
default: "all"
|
||||||
description: "Which vector databases to test (comma-separated list or 'all')"
|
description: "Which vector databases to test (comma-separated list or 'all')"
|
||||||
|
python-versions:
|
||||||
|
required: false
|
||||||
|
type: string
|
||||||
|
default: '["3.10", "3.11", "3.12", "3.13"]'
|
||||||
|
description: "Python versions to test (JSON array)"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run-kuzu-lance-sqlite-search-tests:
|
run-kuzu-lance-sqlite-search-tests:
|
||||||
name: Search test for Kuzu/LanceDB/Sqlite
|
name: Search test for Kuzu/LanceDB/Sqlite (Python ${{ matrix.python-version }})
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'kuzu/lance/sqlite') }}
|
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'kuzu/lance/sqlite') }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||||
|
fail-fast: false
|
||||||
steps:
|
steps:
|
||||||
- name: Check out
|
- name: Check out
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
@ -26,7 +35,7 @@ jobs:
|
||||||
- name: Cognee Setup
|
- name: Cognee Setup
|
||||||
uses: ./.github/actions/cognee_setup
|
uses: ./.github/actions/cognee_setup
|
||||||
with:
|
with:
|
||||||
python-version: ${{ inputs.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
- name: Dependencies already installed
|
- name: Dependencies already installed
|
||||||
run: echo "Dependencies already installed in setup"
|
run: echo "Dependencies already installed in setup"
|
||||||
|
|
@ -45,13 +54,16 @@ jobs:
|
||||||
GRAPH_DATABASE_PROVIDER: 'kuzu'
|
GRAPH_DATABASE_PROVIDER: 'kuzu'
|
||||||
VECTOR_DB_PROVIDER: 'lancedb'
|
VECTOR_DB_PROVIDER: 'lancedb'
|
||||||
DB_PROVIDER: 'sqlite'
|
DB_PROVIDER: 'sqlite'
|
||||||
run: uv run python ./cognee/tests/test_search_db.py
|
run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO
|
||||||
|
|
||||||
run-neo4j-lance-sqlite-search-tests:
|
run-neo4j-lance-sqlite-search-tests:
|
||||||
name: Search test for Neo4j/LanceDB/Sqlite
|
name: Search test for Neo4j/LanceDB/Sqlite (Python ${{ matrix.python-version }})
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }}
|
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||||
|
fail-fast: false
|
||||||
steps:
|
steps:
|
||||||
- name: Check out
|
- name: Check out
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
@ -61,7 +73,7 @@ jobs:
|
||||||
- name: Cognee Setup
|
- name: Cognee Setup
|
||||||
uses: ./.github/actions/cognee_setup
|
uses: ./.github/actions/cognee_setup
|
||||||
with:
|
with:
|
||||||
python-version: ${{ inputs.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
- name: Setup Neo4j with GDS
|
- name: Setup Neo4j with GDS
|
||||||
uses: ./.github/actions/setup_neo4j
|
uses: ./.github/actions/setup_neo4j
|
||||||
|
|
@ -88,12 +100,16 @@ jobs:
|
||||||
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
|
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
|
||||||
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
|
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
|
||||||
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
|
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
|
||||||
run: uv run python ./cognee/tests/test_search_db.py
|
run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO
|
||||||
|
|
||||||
run-kuzu-pgvector-postgres-search-tests:
|
run-kuzu-pgvector-postgres-search-tests:
|
||||||
name: Search test for Kuzu/PGVector/Postgres
|
name: Search test for Kuzu/PGVector/Postgres (Python ${{ matrix.python-version }})
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'kuzu/pgvector/postgres') }}
|
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'kuzu/pgvector/postgres') }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||||
|
fail-fast: false
|
||||||
services:
|
services:
|
||||||
postgres:
|
postgres:
|
||||||
image: pgvector/pgvector:pg17
|
image: pgvector/pgvector:pg17
|
||||||
|
|
@ -117,7 +133,7 @@ jobs:
|
||||||
- name: Cognee Setup
|
- name: Cognee Setup
|
||||||
uses: ./.github/actions/cognee_setup
|
uses: ./.github/actions/cognee_setup
|
||||||
with:
|
with:
|
||||||
python-version: ${{ inputs.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
extra-dependencies: "postgres"
|
extra-dependencies: "postgres"
|
||||||
|
|
||||||
- name: Dependencies already installed
|
- name: Dependencies already installed
|
||||||
|
|
@ -143,12 +159,16 @@ jobs:
|
||||||
DB_PORT: 5432
|
DB_PORT: 5432
|
||||||
DB_USERNAME: cognee
|
DB_USERNAME: cognee
|
||||||
DB_PASSWORD: cognee
|
DB_PASSWORD: cognee
|
||||||
run: uv run python ./cognee/tests/test_search_db.py
|
run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO
|
||||||
|
|
||||||
run-neo4j-pgvector-postgres-search-tests:
|
run-neo4j-pgvector-postgres-search-tests:
|
||||||
name: Search test for Neo4j/PGVector/Postgres
|
name: Search test for Neo4j/PGVector/Postgres (Python ${{ matrix.python-version }})
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/pgvector/postgres') }}
|
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/pgvector/postgres') }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||||
|
fail-fast: false
|
||||||
services:
|
services:
|
||||||
postgres:
|
postgres:
|
||||||
image: pgvector/pgvector:pg17
|
image: pgvector/pgvector:pg17
|
||||||
|
|
@ -172,7 +192,7 @@ jobs:
|
||||||
- name: Cognee Setup
|
- name: Cognee Setup
|
||||||
uses: ./.github/actions/cognee_setup
|
uses: ./.github/actions/cognee_setup
|
||||||
with:
|
with:
|
||||||
python-version: ${{ inputs.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
extra-dependencies: "postgres"
|
extra-dependencies: "postgres"
|
||||||
|
|
||||||
- name: Setup Neo4j with GDS
|
- name: Setup Neo4j with GDS
|
||||||
|
|
@ -205,4 +225,4 @@ jobs:
|
||||||
DB_PORT: 5432
|
DB_PORT: 5432
|
||||||
DB_USERNAME: cognee
|
DB_USERNAME: cognee
|
||||||
DB_PASSWORD: cognee
|
DB_PASSWORD: cognee
|
||||||
run: uv run python ./cognee/tests/test_search_db.py
|
run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO
|
||||||
|
|
|
||||||
90
.github/workflows/test_llms.yml
vendored
90
.github/workflows/test_llms.yml
vendored
|
|
@ -84,3 +84,93 @@ jobs:
|
||||||
EMBEDDING_DIMENSIONS: "3072"
|
EMBEDDING_DIMENSIONS: "3072"
|
||||||
EMBEDDING_MAX_TOKENS: "8191"
|
EMBEDDING_MAX_TOKENS: "8191"
|
||||||
run: uv run python ./examples/python/simple_example.py
|
run: uv run python ./examples/python/simple_example.py
|
||||||
|
|
||||||
|
test-bedrock-api-key:
|
||||||
|
name: Run Bedrock API Key Test
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Check out repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Cognee Setup
|
||||||
|
uses: ./.github/actions/cognee_setup
|
||||||
|
with:
|
||||||
|
python-version: '3.11.x'
|
||||||
|
extra-dependencies: "aws"
|
||||||
|
|
||||||
|
- name: Run Bedrock API Key Simple Example
|
||||||
|
env:
|
||||||
|
LLM_PROVIDER: "bedrock"
|
||||||
|
LLM_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
|
||||||
|
LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
|
||||||
|
LLM_MAX_TOKENS: "16384"
|
||||||
|
AWS_REGION_NAME: "eu-west-1"
|
||||||
|
EMBEDDING_PROVIDER: "bedrock"
|
||||||
|
EMBEDDING_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
|
||||||
|
EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
|
||||||
|
EMBEDDING_DIMENSIONS: "1024"
|
||||||
|
EMBEDDING_MAX_TOKENS: "8191"
|
||||||
|
run: uv run python ./examples/python/simple_example.py
|
||||||
|
|
||||||
|
test-bedrock-aws-credentials:
|
||||||
|
name: Run Bedrock AWS Credentials Test
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Check out repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Cognee Setup
|
||||||
|
uses: ./.github/actions/cognee_setup
|
||||||
|
with:
|
||||||
|
python-version: '3.11.x'
|
||||||
|
extra-dependencies: "aws"
|
||||||
|
|
||||||
|
- name: Run Bedrock AWS Credentials Simple Example
|
||||||
|
env:
|
||||||
|
LLM_PROVIDER: "bedrock"
|
||||||
|
LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
|
||||||
|
LLM_MAX_TOKENS: "16384"
|
||||||
|
AWS_REGION_NAME: "eu-west-1"
|
||||||
|
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||||
|
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||||
|
EMBEDDING_PROVIDER: "bedrock"
|
||||||
|
EMBEDDING_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
|
||||||
|
EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
|
||||||
|
EMBEDDING_DIMENSIONS: "1024"
|
||||||
|
EMBEDDING_MAX_TOKENS: "8191"
|
||||||
|
run: uv run python ./examples/python/simple_example.py
|
||||||
|
|
||||||
|
test-bedrock-aws-profile:
|
||||||
|
name: Run Bedrock AWS Profile Test
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Check out repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Cognee Setup
|
||||||
|
uses: ./.github/actions/cognee_setup
|
||||||
|
with:
|
||||||
|
python-version: '3.11.x'
|
||||||
|
extra-dependencies: "aws"
|
||||||
|
|
||||||
|
- name: Configure AWS Profile
|
||||||
|
run: |
|
||||||
|
mkdir -p ~/.aws
|
||||||
|
cat > ~/.aws/credentials << EOF
|
||||||
|
[bedrock-test]
|
||||||
|
aws_access_key_id = ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||||
|
aws_secret_access_key = ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||||
|
EOF
|
||||||
|
|
||||||
|
- name: Run Bedrock AWS Profile Simple Example
|
||||||
|
env:
|
||||||
|
LLM_PROVIDER: "bedrock"
|
||||||
|
LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
|
||||||
|
LLM_MAX_TOKENS: "16384"
|
||||||
|
AWS_PROFILE_NAME: "bedrock-test"
|
||||||
|
AWS_REGION_NAME: "eu-west-1"
|
||||||
|
EMBEDDING_PROVIDER: "bedrock"
|
||||||
|
EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
|
||||||
|
EMBEDDING_DIMENSIONS: "1024"
|
||||||
|
EMBEDDING_MAX_TOKENS: "8191"
|
||||||
|
run: uv run python ./examples/python/simple_example.py
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
Test client for Cognee MCP Server functionality.
|
Test client for Cognee MCP Server functionality.
|
||||||
|
|
||||||
This script tests all the tools and functions available in the Cognee MCP server,
|
This script tests all the tools and functions available in the Cognee MCP server,
|
||||||
including cognify, codify, search, prune, status checks, and utility functions.
|
including cognify, search, prune, status checks, and utility functions.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
# Set your OpenAI API key first
|
# Set your OpenAI API key first
|
||||||
|
|
@ -23,6 +23,7 @@ import tempfile
|
||||||
import time
|
import time
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from cognee.shared.logging_utils import setup_logging
|
from cognee.shared.logging_utils import setup_logging
|
||||||
|
from logging import ERROR, INFO
|
||||||
|
|
||||||
from mcp import ClientSession, StdioServerParameters
|
from mcp import ClientSession, StdioServerParameters
|
||||||
from mcp.client.stdio import stdio_client
|
from mcp.client.stdio import stdio_client
|
||||||
|
|
@ -35,7 +36,7 @@ from src.server import (
|
||||||
load_class,
|
load_class,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set timeout for cognify/codify to complete in
|
# Set timeout for cognify to complete in
|
||||||
TIMEOUT = 5 * 60 # 5 min in seconds
|
TIMEOUT = 5 * 60 # 5 min in seconds
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -151,12 +152,9 @@ DEBUG = True
|
||||||
|
|
||||||
expected_tools = {
|
expected_tools = {
|
||||||
"cognify",
|
"cognify",
|
||||||
"codify",
|
|
||||||
"search",
|
"search",
|
||||||
"prune",
|
"prune",
|
||||||
"cognify_status",
|
"cognify_status",
|
||||||
"codify_status",
|
|
||||||
"cognee_add_developer_rules",
|
|
||||||
"list_data",
|
"list_data",
|
||||||
"delete",
|
"delete",
|
||||||
}
|
}
|
||||||
|
|
@ -247,106 +245,6 @@ DEBUG = True
|
||||||
}
|
}
|
||||||
print(f"❌ {test_name} test failed: {e}")
|
print(f"❌ {test_name} test failed: {e}")
|
||||||
|
|
||||||
async def test_codify(self):
|
|
||||||
"""Test the codify functionality using MCP client."""
|
|
||||||
print("\n🧪 Testing codify functionality...")
|
|
||||||
try:
|
|
||||||
async with self.mcp_server_session() as session:
|
|
||||||
codify_result = await session.call_tool(
|
|
||||||
"codify", arguments={"repo_path": self.test_repo_dir}
|
|
||||||
)
|
|
||||||
|
|
||||||
start = time.time() # mark the start
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
# Wait a moment
|
|
||||||
await asyncio.sleep(5)
|
|
||||||
|
|
||||||
# Check if codify processing is finished
|
|
||||||
status_result = await session.call_tool("codify_status", arguments={})
|
|
||||||
if hasattr(status_result, "content") and status_result.content:
|
|
||||||
status_text = (
|
|
||||||
status_result.content[0].text
|
|
||||||
if status_result.content
|
|
||||||
else str(status_result)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
status_text = str(status_result)
|
|
||||||
|
|
||||||
if str(PipelineRunStatus.DATASET_PROCESSING_COMPLETED) in status_text:
|
|
||||||
break
|
|
||||||
elif time.time() - start > TIMEOUT:
|
|
||||||
raise TimeoutError("Codify did not complete in 5min")
|
|
||||||
except DatabaseNotCreatedError:
|
|
||||||
if time.time() - start > TIMEOUT:
|
|
||||||
raise TimeoutError("Database was not created in 5min")
|
|
||||||
|
|
||||||
self.test_results["codify"] = {
|
|
||||||
"status": "PASS",
|
|
||||||
"result": codify_result,
|
|
||||||
"message": "Codify executed successfully",
|
|
||||||
}
|
|
||||||
print("✅ Codify test passed")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.test_results["codify"] = {
|
|
||||||
"status": "FAIL",
|
|
||||||
"error": str(e),
|
|
||||||
"message": "Codify test failed",
|
|
||||||
}
|
|
||||||
print(f"❌ Codify test failed: {e}")
|
|
||||||
|
|
||||||
async def test_cognee_add_developer_rules(self):
|
|
||||||
"""Test the cognee_add_developer_rules functionality using MCP client."""
|
|
||||||
print("\n🧪 Testing cognee_add_developer_rules functionality...")
|
|
||||||
try:
|
|
||||||
async with self.mcp_server_session() as session:
|
|
||||||
result = await session.call_tool(
|
|
||||||
"cognee_add_developer_rules", arguments={"base_path": self.test_data_dir}
|
|
||||||
)
|
|
||||||
|
|
||||||
start = time.time() # mark the start
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
# Wait a moment
|
|
||||||
await asyncio.sleep(5)
|
|
||||||
|
|
||||||
# Check if developer rule cognify processing is finished
|
|
||||||
status_result = await session.call_tool("cognify_status", arguments={})
|
|
||||||
if hasattr(status_result, "content") and status_result.content:
|
|
||||||
status_text = (
|
|
||||||
status_result.content[0].text
|
|
||||||
if status_result.content
|
|
||||||
else str(status_result)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
status_text = str(status_result)
|
|
||||||
|
|
||||||
if str(PipelineRunStatus.DATASET_PROCESSING_COMPLETED) in status_text:
|
|
||||||
break
|
|
||||||
elif time.time() - start > TIMEOUT:
|
|
||||||
raise TimeoutError(
|
|
||||||
"Cognify of developer rules did not complete in 5min"
|
|
||||||
)
|
|
||||||
except DatabaseNotCreatedError:
|
|
||||||
if time.time() - start > TIMEOUT:
|
|
||||||
raise TimeoutError("Database was not created in 5min")
|
|
||||||
|
|
||||||
self.test_results["cognee_add_developer_rules"] = {
|
|
||||||
"status": "PASS",
|
|
||||||
"result": result,
|
|
||||||
"message": "Developer rules addition executed successfully",
|
|
||||||
}
|
|
||||||
print("✅ Developer rules test passed")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.test_results["cognee_add_developer_rules"] = {
|
|
||||||
"status": "FAIL",
|
|
||||||
"error": str(e),
|
|
||||||
"message": "Developer rules test failed",
|
|
||||||
}
|
|
||||||
print(f"❌ Developer rules test failed: {e}")
|
|
||||||
|
|
||||||
async def test_search_functionality(self):
|
async def test_search_functionality(self):
|
||||||
"""Test the search functionality with different search types using MCP client."""
|
"""Test the search functionality with different search types using MCP client."""
|
||||||
print("\n🧪 Testing search functionality...")
|
print("\n🧪 Testing search functionality...")
|
||||||
|
|
@ -359,7 +257,11 @@ DEBUG = True
|
||||||
# Go through all Cognee search types
|
# Go through all Cognee search types
|
||||||
for search_type in SearchType:
|
for search_type in SearchType:
|
||||||
# Don't test these search types
|
# Don't test these search types
|
||||||
if search_type in [SearchType.NATURAL_LANGUAGE, SearchType.CYPHER]:
|
if search_type in [
|
||||||
|
SearchType.NATURAL_LANGUAGE,
|
||||||
|
SearchType.CYPHER,
|
||||||
|
SearchType.TRIPLET_COMPLETION,
|
||||||
|
]:
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
async with self.mcp_server_session() as session:
|
async with self.mcp_server_session() as session:
|
||||||
|
|
@ -681,9 +583,6 @@ class TestModel:
|
||||||
test_name="Cognify2",
|
test_name="Cognify2",
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.test_codify()
|
|
||||||
await self.test_cognee_add_developer_rules()
|
|
||||||
|
|
||||||
# Test list_data and delete functionality
|
# Test list_data and delete functionality
|
||||||
await self.test_list_data()
|
await self.test_list_data()
|
||||||
await self.test_delete()
|
await self.test_delete()
|
||||||
|
|
@ -739,7 +638,5 @@ async def main():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from logging import ERROR
|
|
||||||
|
|
||||||
logger = setup_logging(log_level=ERROR)
|
logger = setup_logging(log_level=ERROR)
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
|
||||||
|
|
@ -155,7 +155,7 @@ async def add(
|
||||||
- LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)
|
- LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)
|
||||||
|
|
||||||
Optional:
|
Optional:
|
||||||
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral"
|
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral", "bedrock"
|
||||||
- LLM_MODEL: Model name (default: "gpt-5-mini")
|
- LLM_MODEL: Model name (default: "gpt-5-mini")
|
||||||
- DEFAULT_USER_EMAIL: Custom default user email
|
- DEFAULT_USER_EMAIL: Custom default user email
|
||||||
- DEFAULT_USER_PASSWORD: Custom default user password
|
- DEFAULT_USER_PASSWORD: Custom default user password
|
||||||
|
|
|
||||||
|
|
@ -53,6 +53,7 @@ async def cognify(
|
||||||
custom_prompt: Optional[str] = None,
|
custom_prompt: Optional[str] = None,
|
||||||
temporal_cognify: bool = False,
|
temporal_cognify: bool = False,
|
||||||
data_per_batch: int = 20,
|
data_per_batch: int = 20,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Transform ingested data into a structured knowledge graph.
|
Transform ingested data into a structured knowledge graph.
|
||||||
|
|
@ -223,6 +224,7 @@ async def cognify(
|
||||||
config=config,
|
config=config,
|
||||||
custom_prompt=custom_prompt,
|
custom_prompt=custom_prompt,
|
||||||
chunks_per_batch=chunks_per_batch,
|
chunks_per_batch=chunks_per_batch,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
|
# By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
|
||||||
|
|
@ -251,6 +253,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
|
||||||
config: Config = None,
|
config: Config = None,
|
||||||
custom_prompt: Optional[str] = None,
|
custom_prompt: Optional[str] = None,
|
||||||
chunks_per_batch: int = 100,
|
chunks_per_batch: int = 100,
|
||||||
|
**kwargs,
|
||||||
) -> list[Task]:
|
) -> list[Task]:
|
||||||
if config is None:
|
if config is None:
|
||||||
ontology_config = get_ontology_env_config()
|
ontology_config = get_ontology_env_config()
|
||||||
|
|
@ -288,6 +291,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
|
||||||
config=config,
|
config=config,
|
||||||
custom_prompt=custom_prompt,
|
custom_prompt=custom_prompt,
|
||||||
task_config={"batch_size": chunks_per_batch},
|
task_config={"batch_size": chunks_per_batch},
|
||||||
|
**kwargs,
|
||||||
), # Generate knowledge graphs from the document chunks.
|
), # Generate knowledge graphs from the document chunks.
|
||||||
Task(
|
Task(
|
||||||
summarize_text,
|
summarize_text,
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,9 @@ class CognifyPayloadDTO(InDTO):
|
||||||
default="", description="Custom prompt for entity extraction and graph generation"
|
default="", description="Custom prompt for entity extraction and graph generation"
|
||||||
)
|
)
|
||||||
ontology_key: Optional[List[str]] = Field(
|
ontology_key: Optional[List[str]] = Field(
|
||||||
default=None, description="Reference to one or more previously uploaded ontologies"
|
default=None,
|
||||||
|
examples=[[]],
|
||||||
|
description="Reference to one or more previously uploaded ontologies",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -208,14 +208,14 @@ def get_datasets_router() -> APIRouter:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
from cognee.modules.data.methods import get_dataset, delete_dataset
|
from cognee.modules.data.methods import delete_dataset
|
||||||
|
|
||||||
dataset = await get_dataset(user.id, dataset_id)
|
dataset = await get_authorized_existing_datasets([dataset_id], "delete", user)
|
||||||
|
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise DatasetNotFoundError(message=f"Dataset ({str(dataset_id)}) not found.")
|
raise DatasetNotFoundError(message=f"Dataset ({str(dataset_id)}) not found.")
|
||||||
|
|
||||||
await delete_dataset(dataset)
|
await delete_dataset(dataset[0])
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
"/{dataset_id}/data/{data_id}",
|
"/{dataset_id}/data/{data_id}",
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException
|
from fastapi import APIRouter, File, Form, UploadFile, Depends, Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
|
|
@ -15,28 +15,25 @@ def get_ontology_router() -> APIRouter:
|
||||||
|
|
||||||
@router.post("", response_model=dict)
|
@router.post("", response_model=dict)
|
||||||
async def upload_ontology(
|
async def upload_ontology(
|
||||||
|
request: Request,
|
||||||
ontology_key: str = Form(...),
|
ontology_key: str = Form(...),
|
||||||
ontology_file: List[UploadFile] = File(...),
|
ontology_file: UploadFile = File(...),
|
||||||
descriptions: Optional[str] = Form(None),
|
description: Optional[str] = Form(None),
|
||||||
user: User = Depends(get_authenticated_user),
|
user: User = Depends(get_authenticated_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Upload ontology files with their respective keys for later use in cognify operations.
|
Upload a single ontology file for later use in cognify operations.
|
||||||
|
|
||||||
Supports both single and multiple file uploads:
|
|
||||||
- Single file: ontology_key=["key"], ontology_file=[file]
|
|
||||||
- Multiple files: ontology_key=["key1", "key2"], ontology_file=[file1, file2]
|
|
||||||
|
|
||||||
## Request Parameters
|
## Request Parameters
|
||||||
- **ontology_key** (str): JSON array string of user-defined identifiers for the ontologies
|
- **ontology_key** (str): User-defined identifier for the ontology.
|
||||||
- **ontology_file** (List[UploadFile]): OWL format ontology files
|
- **ontology_file** (UploadFile): Single OWL format ontology file
|
||||||
- **descriptions** (Optional[str]): JSON array string of optional descriptions
|
- **description** (Optional[str]): Optional description for the ontology.
|
||||||
|
|
||||||
## Response
|
## Response
|
||||||
Returns metadata about uploaded ontologies including keys, filenames, sizes, and upload timestamps.
|
Returns metadata about the uploaded ontology including key, filename, size, and upload timestamp.
|
||||||
|
|
||||||
## Error Codes
|
## Error Codes
|
||||||
- **400 Bad Request**: Invalid file format, duplicate keys, array length mismatches, file size exceeded
|
- **400 Bad Request**: Invalid file format, duplicate key, multiple files uploaded
|
||||||
- **500 Internal Server Error**: File system or processing errors
|
- **500 Internal Server Error**: File system or processing errors
|
||||||
"""
|
"""
|
||||||
send_telemetry(
|
send_telemetry(
|
||||||
|
|
@ -49,16 +46,22 @@ def get_ontology_router() -> APIRouter:
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import json
|
# Enforce: exactly one uploaded file for "ontology_file"
|
||||||
|
form = await request.form()
|
||||||
|
uploaded_files = form.getlist("ontology_file")
|
||||||
|
if len(uploaded_files) != 1:
|
||||||
|
raise ValueError("Only one ontology_file is allowed")
|
||||||
|
|
||||||
ontology_keys = json.loads(ontology_key)
|
if ontology_key.strip().startswith(("[", "{")):
|
||||||
description_list = json.loads(descriptions) if descriptions else None
|
raise ValueError("ontology_key must be a string")
|
||||||
|
if description is not None and description.strip().startswith(("[", "{")):
|
||||||
|
raise ValueError("description must be a string")
|
||||||
|
|
||||||
if not isinstance(ontology_keys, list):
|
result = await ontology_service.upload_ontology(
|
||||||
raise ValueError("ontology_key must be a JSON array")
|
ontology_key=ontology_key,
|
||||||
|
file=ontology_file,
|
||||||
results = await ontology_service.upload_ontologies(
|
user=user,
|
||||||
ontology_keys, ontology_file, user, description_list
|
description=description,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -70,10 +73,9 @@ def get_ontology_router() -> APIRouter:
|
||||||
"uploaded_at": result.uploaded_at,
|
"uploaded_at": result.uploaded_at,
|
||||||
"description": result.description,
|
"description": result.description,
|
||||||
}
|
}
|
||||||
for result in results
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
except (json.JSONDecodeError, ValueError) as e:
|
except ValueError as e:
|
||||||
return JSONResponse(status_code=400, content={"error": str(e)})
|
return JSONResponse(status_code=400, content={"error": str(e)})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return JSONResponse(status_code=500, content={"error": str(e)})
|
return JSONResponse(status_code=500, content={"error": str(e)})
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,4 @@
|
||||||
from .get_or_create_dataset_database import get_or_create_dataset_database
|
from .get_or_create_dataset_database import get_or_create_dataset_database
|
||||||
from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info
|
from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info
|
||||||
|
from .get_graph_dataset_database_handler import get_graph_dataset_database_handler
|
||||||
|
from .get_vector_dataset_database_handler import get_vector_dataset_database_handler
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
|
||||||
|
|
||||||
|
|
||||||
|
def get_graph_dataset_database_handler(dataset_database: DatasetDatabase) -> dict:
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||||
|
supported_dataset_database_handlers,
|
||||||
|
)
|
||||||
|
|
||||||
|
handler = supported_dataset_database_handlers[dataset_database.graph_dataset_database_handler]
|
||||||
|
return handler
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
|
||||||
|
|
||||||
|
|
||||||
|
def get_vector_dataset_database_handler(dataset_database: DatasetDatabase) -> dict:
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||||
|
supported_dataset_database_handlers,
|
||||||
|
)
|
||||||
|
|
||||||
|
handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler]
|
||||||
|
return handler
|
||||||
|
|
@ -1,24 +1,12 @@
|
||||||
|
from cognee.infrastructure.databases.utils.get_graph_dataset_database_handler import (
|
||||||
|
get_graph_dataset_database_handler,
|
||||||
|
)
|
||||||
|
from cognee.infrastructure.databases.utils.get_vector_dataset_database_handler import (
|
||||||
|
get_vector_dataset_database_handler,
|
||||||
|
)
|
||||||
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
|
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
|
||||||
|
|
||||||
|
|
||||||
async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
|
|
||||||
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
|
||||||
supported_dataset_database_handlers,
|
|
||||||
)
|
|
||||||
|
|
||||||
handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler]
|
|
||||||
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
|
|
||||||
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
|
||||||
supported_dataset_database_handlers,
|
|
||||||
)
|
|
||||||
|
|
||||||
handler = supported_dataset_database_handlers[dataset_database.graph_dataset_database_handler]
|
|
||||||
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
|
|
||||||
|
|
||||||
|
|
||||||
async def resolve_dataset_database_connection_info(
|
async def resolve_dataset_database_connection_info(
|
||||||
dataset_database: DatasetDatabase,
|
dataset_database: DatasetDatabase,
|
||||||
) -> DatasetDatabase:
|
) -> DatasetDatabase:
|
||||||
|
|
@ -31,6 +19,12 @@ async def resolve_dataset_database_connection_info(
|
||||||
Returns:
|
Returns:
|
||||||
DatasetDatabase instance with resolved connection info
|
DatasetDatabase instance with resolved connection info
|
||||||
"""
|
"""
|
||||||
dataset_database = await _get_vector_db_connection_info(dataset_database)
|
vector_dataset_database_handler = get_vector_dataset_database_handler(dataset_database)
|
||||||
dataset_database = await _get_graph_db_connection_info(dataset_database)
|
graph_dataset_database_handler = get_graph_dataset_database_handler(dataset_database)
|
||||||
|
dataset_database = await vector_dataset_database_handler[
|
||||||
|
"handler_instance"
|
||||||
|
].resolve_dataset_connection_info(dataset_database)
|
||||||
|
dataset_database = await graph_dataset_database_handler[
|
||||||
|
"handler_instance"
|
||||||
|
].resolve_dataset_connection_info(dataset_database)
|
||||||
return dataset_database
|
return dataset_database
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,8 @@ class S3Config(BaseSettings):
|
||||||
aws_access_key_id: Optional[str] = None
|
aws_access_key_id: Optional[str] = None
|
||||||
aws_secret_access_key: Optional[str] = None
|
aws_secret_access_key: Optional[str] = None
|
||||||
aws_session_token: Optional[str] = None
|
aws_session_token: Optional[str] = None
|
||||||
|
aws_profile_name: Optional[str] = None
|
||||||
|
aws_bedrock_runtime_endpoint: Optional[str] = None
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ class LLMGateway:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def acreate_structured_output(
|
def acreate_structured_output(
|
||||||
text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
|
||||||
) -> Coroutine:
|
) -> Coroutine:
|
||||||
llm_config = get_llm_config()
|
llm_config = get_llm_config()
|
||||||
if llm_config.structured_output_framework.upper() == "BAML":
|
if llm_config.structured_output_framework.upper() == "BAML":
|
||||||
|
|
@ -31,7 +31,10 @@ class LLMGateway:
|
||||||
|
|
||||||
llm_client = get_llm_client()
|
llm_client = get_llm_client()
|
||||||
return llm_client.acreate_structured_output(
|
return llm_client.acreate_structured_output(
|
||||||
text_input=text_input, system_prompt=system_prompt, response_model=response_model
|
text_input=text_input,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
response_model=response_model,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from cognee.infrastructure.llm.config import (
|
||||||
|
|
||||||
|
|
||||||
async def extract_content_graph(
|
async def extract_content_graph(
|
||||||
content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None
|
content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None, **kwargs
|
||||||
):
|
):
|
||||||
if custom_prompt:
|
if custom_prompt:
|
||||||
system_prompt = custom_prompt
|
system_prompt = custom_prompt
|
||||||
|
|
@ -30,7 +30,7 @@ async def extract_content_graph(
|
||||||
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
|
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
|
||||||
|
|
||||||
content_graph = await LLMGateway.acreate_structured_output(
|
content_graph = await LLMGateway.acreate_structured_output(
|
||||||
content, system_prompt, response_model
|
content, system_prompt, response_model, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
return content_graph
|
return content_graph
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@ class AnthropicAdapter(LLMInterface):
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
async def acreate_structured_output(
|
async def acreate_structured_output(
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
|
||||||
) -> BaseModel:
|
) -> BaseModel:
|
||||||
"""
|
"""
|
||||||
Generate a response from a user query.
|
Generate a response from a user query.
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
"""Bedrock LLM adapter module."""
|
||||||
|
|
||||||
|
from .adapter import BedrockAdapter
|
||||||
|
|
||||||
|
__all__ = ["BedrockAdapter"]
|
||||||
|
|
@ -0,0 +1,153 @@
|
||||||
|
import litellm
|
||||||
|
import instructor
|
||||||
|
from typing import Type
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from litellm.exceptions import ContentPolicyViolationError
|
||||||
|
from instructor.exceptions import InstructorRetryException
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||||
|
LLMInterface,
|
||||||
|
)
|
||||||
|
from cognee.infrastructure.llm.exceptions import (
|
||||||
|
ContentPolicyFilterError,
|
||||||
|
MissingSystemPromptPathError,
|
||||||
|
)
|
||||||
|
from cognee.infrastructure.files.storage.s3_config import get_s3_config
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
|
||||||
|
rate_limit_async,
|
||||||
|
rate_limit_sync,
|
||||||
|
sleep_and_retry_async,
|
||||||
|
sleep_and_retry_sync,
|
||||||
|
)
|
||||||
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
|
|
||||||
|
observe = get_observe()
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockAdapter(LLMInterface):
|
||||||
|
"""
|
||||||
|
Adapter for AWS Bedrock API with support for three authentication methods:
|
||||||
|
1. API Key (Bearer Token)
|
||||||
|
2. AWS Credentials (access key + secret key)
|
||||||
|
3. AWS Profile (boto3 credential chain)
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "Bedrock"
|
||||||
|
model: str
|
||||||
|
api_key: str
|
||||||
|
default_instructor_mode = "json_schema_mode"
|
||||||
|
|
||||||
|
MAX_RETRIES = 5
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
api_key: str = None,
|
||||||
|
max_completion_tokens: int = 16384,
|
||||||
|
streaming: bool = False,
|
||||||
|
instructor_mode: str = None,
|
||||||
|
):
|
||||||
|
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||||
|
|
||||||
|
self.aclient = instructor.from_litellm(
|
||||||
|
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
|
||||||
|
)
|
||||||
|
self.client = instructor.from_litellm(litellm.completion)
|
||||||
|
self.model = model
|
||||||
|
self.api_key = api_key
|
||||||
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
self.streaming = streaming
|
||||||
|
|
||||||
|
def _create_bedrock_request(
|
||||||
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
|
) -> dict:
|
||||||
|
"""Create Bedrock request with authentication."""
|
||||||
|
|
||||||
|
request_params = {
|
||||||
|
"model": self.model,
|
||||||
|
"custom_llm_provider": "bedrock",
|
||||||
|
"drop_params": True,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": text_input},
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
],
|
||||||
|
"response_model": response_model,
|
||||||
|
"max_retries": self.MAX_RETRIES,
|
||||||
|
"max_completion_tokens": self.max_completion_tokens,
|
||||||
|
"stream": self.streaming,
|
||||||
|
}
|
||||||
|
|
||||||
|
s3_config = get_s3_config()
|
||||||
|
|
||||||
|
# Add authentication parameters
|
||||||
|
if self.api_key:
|
||||||
|
request_params["api_key"] = self.api_key
|
||||||
|
elif s3_config.aws_access_key_id and s3_config.aws_secret_access_key:
|
||||||
|
request_params["aws_access_key_id"] = s3_config.aws_access_key_id
|
||||||
|
request_params["aws_secret_access_key"] = s3_config.aws_secret_access_key
|
||||||
|
if s3_config.aws_session_token:
|
||||||
|
request_params["aws_session_token"] = s3_config.aws_session_token
|
||||||
|
elif s3_config.aws_profile_name:
|
||||||
|
request_params["aws_profile_name"] = s3_config.aws_profile_name
|
||||||
|
|
||||||
|
if s3_config.aws_region:
|
||||||
|
request_params["aws_region_name"] = s3_config.aws_region
|
||||||
|
|
||||||
|
# Add optional parameters
|
||||||
|
if s3_config.aws_bedrock_runtime_endpoint:
|
||||||
|
request_params["aws_bedrock_runtime_endpoint"] = s3_config.aws_bedrock_runtime_endpoint
|
||||||
|
|
||||||
|
return request_params
|
||||||
|
|
||||||
|
@observe(as_type="generation")
|
||||||
|
@sleep_and_retry_async()
|
||||||
|
@rate_limit_async
|
||||||
|
async def acreate_structured_output(
|
||||||
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
|
) -> BaseModel:
|
||||||
|
"""Generate structured output from AWS Bedrock API."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
|
||||||
|
return await self.aclient.chat.completions.create(**request_params)
|
||||||
|
|
||||||
|
except (
|
||||||
|
ContentPolicyViolationError,
|
||||||
|
InstructorRetryException,
|
||||||
|
) as error:
|
||||||
|
if (
|
||||||
|
isinstance(error, InstructorRetryException)
|
||||||
|
and "content management policy" not in str(error).lower()
|
||||||
|
):
|
||||||
|
raise error
|
||||||
|
|
||||||
|
raise ContentPolicyFilterError(
|
||||||
|
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@observe
|
||||||
|
@sleep_and_retry_sync()
|
||||||
|
@rate_limit_sync
|
||||||
|
def create_structured_output(
|
||||||
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
|
) -> BaseModel:
|
||||||
|
"""Generate structured output from AWS Bedrock API (synchronous)."""
|
||||||
|
|
||||||
|
request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
|
||||||
|
return self.client.chat.completions.create(**request_params)
|
||||||
|
|
||||||
|
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
||||||
|
"""Format and display the prompt for a user query."""
|
||||||
|
if not text_input:
|
||||||
|
text_input = "No user input provided."
|
||||||
|
if not system_prompt:
|
||||||
|
raise MissingSystemPromptPathError()
|
||||||
|
system_prompt = LLMGateway.read_query_prompt(system_prompt)
|
||||||
|
|
||||||
|
formatted_prompt = (
|
||||||
|
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
|
||||||
|
if system_prompt
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
return formatted_prompt
|
||||||
|
|
@ -80,7 +80,7 @@ class GeminiAdapter(LLMInterface):
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
async def acreate_structured_output(
|
async def acreate_structured_output(
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
|
||||||
) -> BaseModel:
|
) -> BaseModel:
|
||||||
"""
|
"""
|
||||||
Generate a response from a user query.
|
Generate a response from a user query.
|
||||||
|
|
|
||||||
|
|
@ -80,7 +80,7 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
async def acreate_structured_output(
|
async def acreate_structured_output(
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
|
||||||
) -> BaseModel:
|
) -> BaseModel:
|
||||||
"""
|
"""
|
||||||
Generate a response from a user query.
|
Generate a response from a user query.
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ class LLMProvider(Enum):
|
||||||
- CUSTOM: Represents a custom provider option.
|
- CUSTOM: Represents a custom provider option.
|
||||||
- GEMINI: Represents the Gemini provider.
|
- GEMINI: Represents the Gemini provider.
|
||||||
- MISTRAL: Represents the Mistral AI provider.
|
- MISTRAL: Represents the Mistral AI provider.
|
||||||
|
- BEDROCK: Represents the AWS Bedrock provider.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
|
|
@ -32,6 +33,7 @@ class LLMProvider(Enum):
|
||||||
CUSTOM = "custom"
|
CUSTOM = "custom"
|
||||||
GEMINI = "gemini"
|
GEMINI = "gemini"
|
||||||
MISTRAL = "mistral"
|
MISTRAL = "mistral"
|
||||||
|
BEDROCK = "bedrock"
|
||||||
|
|
||||||
|
|
||||||
def get_llm_client(raise_api_key_error: bool = True):
|
def get_llm_client(raise_api_key_error: bool = True):
|
||||||
|
|
@ -154,7 +156,7 @@ def get_llm_client(raise_api_key_error: bool = True):
|
||||||
)
|
)
|
||||||
|
|
||||||
elif provider == LLMProvider.MISTRAL:
|
elif provider == LLMProvider.MISTRAL:
|
||||||
if llm_config.llm_api_key is None:
|
if llm_config.llm_api_key is None and raise_api_key_error:
|
||||||
raise LLMAPIKeyNotSetError()
|
raise LLMAPIKeyNotSetError()
|
||||||
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
|
||||||
|
|
@ -169,5 +171,21 @@ def get_llm_client(raise_api_key_error: bool = True):
|
||||||
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif provider == LLMProvider.BEDROCK:
|
||||||
|
# if llm_config.llm_api_key is None and raise_api_key_error:
|
||||||
|
# raise LLMAPIKeyNotSetError()
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.bedrock.adapter import (
|
||||||
|
BedrockAdapter,
|
||||||
|
)
|
||||||
|
|
||||||
|
return BedrockAdapter(
|
||||||
|
model=llm_config.llm_model,
|
||||||
|
api_key=llm_config.llm_api_key,
|
||||||
|
max_completion_tokens=max_completion_tokens,
|
||||||
|
streaming=llm_config.llm_streaming,
|
||||||
|
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise UnsupportedLLMProviderError(provider)
|
raise UnsupportedLLMProviderError(provider)
|
||||||
|
|
|
||||||
|
|
@ -69,7 +69,7 @@ class MistralAdapter(LLMInterface):
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
async def acreate_structured_output(
|
async def acreate_structured_output(
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
|
||||||
) -> BaseModel:
|
) -> BaseModel:
|
||||||
"""
|
"""
|
||||||
Generate a response from the user query.
|
Generate a response from the user query.
|
||||||
|
|
|
||||||
|
|
@ -76,7 +76,7 @@ class OllamaAPIAdapter(LLMInterface):
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
async def acreate_structured_output(
|
async def acreate_structured_output(
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
|
||||||
) -> BaseModel:
|
) -> BaseModel:
|
||||||
"""
|
"""
|
||||||
Generate a structured output from the LLM using the provided text and system prompt.
|
Generate a structured output from the LLM using the provided text and system prompt.
|
||||||
|
|
@ -123,7 +123,7 @@ class OllamaAPIAdapter(LLMInterface):
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
async def create_transcript(self, input_file: str) -> str:
|
async def create_transcript(self, input_file: str, **kwargs) -> str:
|
||||||
"""
|
"""
|
||||||
Generate an audio transcript from a user query.
|
Generate an audio transcript from a user query.
|
||||||
|
|
||||||
|
|
@ -162,7 +162,7 @@ class OllamaAPIAdapter(LLMInterface):
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
async def transcribe_image(self, input_file: str) -> str:
|
async def transcribe_image(self, input_file: str, **kwargs) -> str:
|
||||||
"""
|
"""
|
||||||
Transcribe content from an image using base64 encoding.
|
Transcribe content from an image using base64 encoding.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -112,7 +112,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
async def acreate_structured_output(
|
async def acreate_structured_output(
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
|
||||||
) -> BaseModel:
|
) -> BaseModel:
|
||||||
"""
|
"""
|
||||||
Generate a response from a user query.
|
Generate a response from a user query.
|
||||||
|
|
@ -154,6 +154,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
api_version=self.api_version,
|
api_version=self.api_version,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
max_retries=self.MAX_RETRIES,
|
max_retries=self.MAX_RETRIES,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
except (
|
except (
|
||||||
ContentFilterFinishReasonError,
|
ContentFilterFinishReasonError,
|
||||||
|
|
@ -180,6 +181,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
# api_base=self.fallback_endpoint,
|
# api_base=self.fallback_endpoint,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
max_retries=self.MAX_RETRIES,
|
max_retries=self.MAX_RETRIES,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
except (
|
except (
|
||||||
ContentFilterFinishReasonError,
|
ContentFilterFinishReasonError,
|
||||||
|
|
@ -205,7 +207,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def create_structured_output(
|
def create_structured_output(
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
|
||||||
) -> BaseModel:
|
) -> BaseModel:
|
||||||
"""
|
"""
|
||||||
Generate a response from a user query.
|
Generate a response from a user query.
|
||||||
|
|
@ -245,6 +247,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
api_version=self.api_version,
|
api_version=self.api_version,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
max_retries=self.MAX_RETRIES,
|
max_retries=self.MAX_RETRIES,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
|
|
@ -254,7 +257,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
async def create_transcript(self, input):
|
async def create_transcript(self, input, **kwargs):
|
||||||
"""
|
"""
|
||||||
Generate an audio transcript from a user query.
|
Generate an audio transcript from a user query.
|
||||||
|
|
||||||
|
|
@ -281,6 +284,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
api_base=self.endpoint,
|
api_base=self.endpoint,
|
||||||
api_version=self.api_version,
|
api_version=self.api_version,
|
||||||
max_retries=self.MAX_RETRIES,
|
max_retries=self.MAX_RETRIES,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return transcription
|
return transcription
|
||||||
|
|
@ -292,7 +296,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
async def transcribe_image(self, input) -> BaseModel:
|
async def transcribe_image(self, input, **kwargs) -> BaseModel:
|
||||||
"""
|
"""
|
||||||
Generate a transcription of an image from a user query.
|
Generate a transcription of an image from a user query.
|
||||||
|
|
||||||
|
|
@ -337,4 +341,5 @@ class OpenAIAdapter(LLMInterface):
|
||||||
api_version=self.api_version,
|
api_version=self.api_version,
|
||||||
max_completion_tokens=300,
|
max_completion_tokens=300,
|
||||||
max_retries=self.MAX_RETRIES,
|
max_retries=self.MAX_RETRIES,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,10 @@ from cognee.context_global_variables import backend_access_control_enabled
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
from cognee.infrastructure.databases.utils import (
|
||||||
|
get_graph_dataset_database_handler,
|
||||||
|
get_vector_dataset_database_handler,
|
||||||
|
)
|
||||||
from cognee.shared.cache import delete_cache
|
from cognee.shared.cache import delete_cache
|
||||||
from cognee.modules.users.models import DatasetDatabase
|
from cognee.modules.users.models import DatasetDatabase
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
@ -13,22 +17,13 @@ logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
async def prune_graph_databases():
|
async def prune_graph_databases():
|
||||||
async def _prune_graph_db(dataset_database: DatasetDatabase) -> dict:
|
|
||||||
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
|
||||||
supported_dataset_database_handlers,
|
|
||||||
)
|
|
||||||
|
|
||||||
handler = supported_dataset_database_handlers[
|
|
||||||
dataset_database.graph_dataset_database_handler
|
|
||||||
]
|
|
||||||
return await handler["handler_instance"].delete_dataset(dataset_database)
|
|
||||||
|
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
try:
|
try:
|
||||||
data = await db_engine.get_all_data_from_table("dataset_database")
|
dataset_databases = await db_engine.get_all_data_from_table("dataset_database")
|
||||||
# Go through each dataset database and delete the graph database
|
# Go through each dataset database and delete the graph database
|
||||||
for data_item in data:
|
for dataset_database in dataset_databases:
|
||||||
await _prune_graph_db(data_item)
|
handler = get_graph_dataset_database_handler(dataset_database)
|
||||||
|
await handler["handler_instance"].delete_dataset(dataset_database)
|
||||||
except (OperationalError, EntityNotFoundError) as e:
|
except (OperationalError, EntityNotFoundError) as e:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Skipping pruning of graph DB. Error when accessing dataset_database table: %s",
|
"Skipping pruning of graph DB. Error when accessing dataset_database table: %s",
|
||||||
|
|
@ -38,22 +33,13 @@ async def prune_graph_databases():
|
||||||
|
|
||||||
|
|
||||||
async def prune_vector_databases():
|
async def prune_vector_databases():
|
||||||
async def _prune_vector_db(dataset_database: DatasetDatabase) -> dict:
|
|
||||||
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
|
||||||
supported_dataset_database_handlers,
|
|
||||||
)
|
|
||||||
|
|
||||||
handler = supported_dataset_database_handlers[
|
|
||||||
dataset_database.vector_dataset_database_handler
|
|
||||||
]
|
|
||||||
return await handler["handler_instance"].delete_dataset(dataset_database)
|
|
||||||
|
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
try:
|
try:
|
||||||
data = await db_engine.get_all_data_from_table("dataset_database")
|
dataset_databases = await db_engine.get_all_data_from_table("dataset_database")
|
||||||
# Go through each dataset database and delete the vector database
|
# Go through each dataset database and delete the vector database
|
||||||
for data_item in data:
|
for dataset_database in dataset_databases:
|
||||||
await _prune_vector_db(data_item)
|
handler = get_vector_dataset_database_handler(dataset_database)
|
||||||
|
await handler["handler_instance"].delete_dataset(dataset_database)
|
||||||
except (OperationalError, EntityNotFoundError) as e:
|
except (OperationalError, EntityNotFoundError) as e:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Skipping pruning of vector DB. Error when accessing dataset_database table: %s",
|
"Skipping pruning of vector DB. Error when accessing dataset_database table: %s",
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,34 @@
|
||||||
|
from cognee.modules.users.models import DatasetDatabase
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
from cognee.modules.data.models import Dataset
|
from cognee.modules.data.models import Dataset
|
||||||
|
from cognee.infrastructure.databases.utils.get_vector_dataset_database_handler import (
|
||||||
|
get_vector_dataset_database_handler,
|
||||||
|
)
|
||||||
|
from cognee.infrastructure.databases.utils.get_graph_dataset_database_handler import (
|
||||||
|
get_graph_dataset_database_handler,
|
||||||
|
)
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
|
||||||
|
|
||||||
async def delete_dataset(dataset: Dataset):
|
async def delete_dataset(dataset: Dataset):
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
|
async with db_engine.get_async_session() as session:
|
||||||
|
stmt = select(DatasetDatabase).where(
|
||||||
|
DatasetDatabase.dataset_id == dataset.id,
|
||||||
|
)
|
||||||
|
dataset_database: DatasetDatabase = await session.scalar(stmt)
|
||||||
|
if dataset_database:
|
||||||
|
graph_dataset_database_handler = get_graph_dataset_database_handler(dataset_database)
|
||||||
|
vector_dataset_database_handler = get_vector_dataset_database_handler(dataset_database)
|
||||||
|
await graph_dataset_database_handler["handler_instance"].delete_dataset(
|
||||||
|
dataset_database
|
||||||
|
)
|
||||||
|
await vector_dataset_database_handler["handler_instance"].delete_dataset(
|
||||||
|
dataset_database
|
||||||
|
)
|
||||||
|
# TODO: Remove dataset from pipeline_run_status in Data objects related to dataset as well
|
||||||
|
# This blocks recreation of the dataset with the same name and data after deletion as
|
||||||
|
# it's marked as completed and will be just skipped even though it's empty.
|
||||||
return await db_engine.delete_entity_by_id(dataset.__tablename__, dataset.id)
|
return await db_engine.delete_entity_by_id(dataset.__tablename__, dataset.id)
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ class TripletRetriever(BaseRetriever):
|
||||||
"""Initialize retriever with optional custom prompt paths."""
|
"""Initialize retriever with optional custom prompt paths."""
|
||||||
self.user_prompt_path = user_prompt_path
|
self.user_prompt_path = user_prompt_path
|
||||||
self.system_prompt_path = system_prompt_path
|
self.system_prompt_path = system_prompt_path
|
||||||
self.top_k = top_k if top_k is not None else 1
|
self.top_k = top_k if top_k is not None else 5
|
||||||
self.system_prompt = system_prompt
|
self.system_prompt = system_prompt
|
||||||
|
|
||||||
async def get_context(self, query: str) -> str:
|
async def get_context(self, query: str) -> str:
|
||||||
|
|
|
||||||
|
|
@ -16,24 +16,6 @@ logger = get_logger(level=ERROR)
|
||||||
|
|
||||||
|
|
||||||
def format_triplets(edges):
|
def format_triplets(edges):
|
||||||
print("\n\n\n")
|
|
||||||
|
|
||||||
def filter_attributes(obj, attributes):
|
|
||||||
"""Helper function to filter out non-None properties, including nested dicts."""
|
|
||||||
result = {}
|
|
||||||
for attr in attributes:
|
|
||||||
value = getattr(obj, attr, None)
|
|
||||||
if value is not None:
|
|
||||||
# If the value is a dict, extract relevant keys from it
|
|
||||||
if isinstance(value, dict):
|
|
||||||
nested_values = {
|
|
||||||
k: v for k, v in value.items() if k in attributes and v is not None
|
|
||||||
}
|
|
||||||
result[attr] = nested_values
|
|
||||||
else:
|
|
||||||
result[attr] = value
|
|
||||||
return result
|
|
||||||
|
|
||||||
triplets = []
|
triplets = []
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
node1 = edge.node1
|
node1 = edge.node1
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ class ModelName(Enum):
|
||||||
anthropic = "anthropic"
|
anthropic = "anthropic"
|
||||||
gemini = "gemini"
|
gemini = "gemini"
|
||||||
mistral = "mistral"
|
mistral = "mistral"
|
||||||
|
bedrock = "bedrock"
|
||||||
|
|
||||||
|
|
||||||
class LLMConfig(BaseModel):
|
class LLMConfig(BaseModel):
|
||||||
|
|
@ -77,6 +78,10 @@ def get_settings() -> SettingsDict:
|
||||||
"value": "mistral",
|
"value": "mistral",
|
||||||
"label": "Mistral",
|
"label": "Mistral",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"value": "bedrock",
|
||||||
|
"label": "Bedrock",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
return SettingsDict.model_validate(
|
return SettingsDict.model_validate(
|
||||||
|
|
@ -157,6 +162,20 @@ def get_settings() -> SettingsDict:
|
||||||
"label": "Mistral Large 2.1",
|
"label": "Mistral Large 2.1",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
"bedrock": [
|
||||||
|
{
|
||||||
|
"value": "eu.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||||
|
"label": "Claude 4.5 Sonnet",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"value": "eu.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||||
|
"label": "Claude 4.5 Haiku",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"value": "eu.amazon.nova-lite-v1:0",
|
||||||
|
"label": "Amazon Nova Lite",
|
||||||
|
},
|
||||||
|
],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
vector_db={
|
vector_db={
|
||||||
|
|
|
||||||
|
|
@ -97,6 +97,7 @@ async def extract_graph_from_data(
|
||||||
graph_model: Type[BaseModel],
|
graph_model: Type[BaseModel],
|
||||||
config: Config = None,
|
config: Config = None,
|
||||||
custom_prompt: Optional[str] = None,
|
custom_prompt: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
) -> List[DocumentChunk]:
|
) -> List[DocumentChunk]:
|
||||||
"""
|
"""
|
||||||
Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
|
Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
|
||||||
|
|
@ -111,7 +112,7 @@ async def extract_graph_from_data(
|
||||||
|
|
||||||
chunk_graphs = await asyncio.gather(
|
chunk_graphs = await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt)
|
extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt, **kwargs)
|
||||||
for chunk in data_chunks
|
for chunk in data_chunks
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from typing import AsyncGenerator, Dict, Any, List, Optional
|
from typing import AsyncGenerator, Dict, Any, List, Optional
|
||||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||||||
|
from cognee.modules.engine.utils import generate_node_id
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
@ -155,7 +156,12 @@ def _process_single_triplet(
|
||||||
|
|
||||||
embeddable_text = f"{start_node_text}-›{relationship_text}-›{end_node_text}".strip()
|
embeddable_text = f"{start_node_text}-›{relationship_text}-›{end_node_text}".strip()
|
||||||
|
|
||||||
triplet_obj = Triplet(from_node_id=start_node_id, to_node_id=end_node_id, text=embeddable_text)
|
relationship_name = relationship.get("relationship_name", "")
|
||||||
|
triplet_id = generate_node_id(str(start_node_id) + str(relationship_name) + str(end_node_id))
|
||||||
|
|
||||||
|
triplet_obj = Triplet(
|
||||||
|
id=triplet_id, from_node_id=start_node_id, to_node_id=end_node_id, text=embeddable_text
|
||||||
|
)
|
||||||
|
|
||||||
return triplet_obj, None
|
return triplet_obj, None
|
||||||
|
|
||||||
|
|
|
||||||
252
cognee/tests/integration/retrieval/test_chunks_retriever.py
Normal file
252
cognee/tests/integration/retrieval/test_chunks_retriever.py
Normal file
|
|
@ -0,0 +1,252 @@
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
import pytest_asyncio
|
||||||
|
from typing import List
|
||||||
|
import cognee
|
||||||
|
|
||||||
|
from cognee.low_level import setup
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.modules.chunking.models import DocumentChunk
|
||||||
|
from cognee.modules.data.processing.document_types import TextDocument
|
||||||
|
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||||
|
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.modules.data.processing.document_types import Document
|
||||||
|
from cognee.modules.engine.models import Entity
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentChunkWithEntities(DataPoint):
|
||||||
|
text: str
|
||||||
|
chunk_size: int
|
||||||
|
chunk_index: int
|
||||||
|
cut_type: str
|
||||||
|
is_part_of: Document
|
||||||
|
contains: List[Entity] = None
|
||||||
|
|
||||||
|
metadata: dict = {"index_fields": ["text"]}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def setup_test_environment_with_chunks_simple():
|
||||||
|
"""Set up a clean test environment with simple chunks."""
|
||||||
|
base_dir = pathlib.Path(__file__).parent.parent.parent.parent
|
||||||
|
system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_simple")
|
||||||
|
data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_simple")
|
||||||
|
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
document = TextDocument(
|
||||||
|
name="Steve Rodger's career",
|
||||||
|
raw_data_location="somewhere",
|
||||||
|
external_metadata="",
|
||||||
|
mime_type="text/plain",
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk1 = DocumentChunk(
|
||||||
|
text="Steve Rodger",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk2 = DocumentChunk(
|
||||||
|
text="Mike Broski",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=1,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk3 = DocumentChunk(
|
||||||
|
text="Christina Mayer",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=2,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
entities = [chunk1, chunk2, chunk3]
|
||||||
|
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
try:
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def setup_test_environment_with_chunks_complex():
|
||||||
|
"""Set up a clean test environment with complex chunks."""
|
||||||
|
base_dir = pathlib.Path(__file__).parent.parent.parent.parent
|
||||||
|
system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_complex")
|
||||||
|
data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_complex")
|
||||||
|
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
document1 = TextDocument(
|
||||||
|
name="Employee List",
|
||||||
|
raw_data_location="somewhere",
|
||||||
|
external_metadata="",
|
||||||
|
mime_type="text/plain",
|
||||||
|
)
|
||||||
|
|
||||||
|
document2 = TextDocument(
|
||||||
|
name="Car List",
|
||||||
|
raw_data_location="somewhere",
|
||||||
|
external_metadata="",
|
||||||
|
mime_type="text/plain",
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk1 = DocumentChunk(
|
||||||
|
text="Steve Rodger",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document1,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk2 = DocumentChunk(
|
||||||
|
text="Mike Broski",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=1,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document1,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk3 = DocumentChunk(
|
||||||
|
text="Christina Mayer",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=2,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document1,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk4 = DocumentChunk(
|
||||||
|
text="Range Rover",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document2,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk5 = DocumentChunk(
|
||||||
|
text="Hyundai",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=1,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document2,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk6 = DocumentChunk(
|
||||||
|
text="Chrysler",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=2,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document2,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]
|
||||||
|
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
try:
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def setup_test_environment_empty():
|
||||||
|
"""Set up a clean test environment without chunks."""
|
||||||
|
base_dir = pathlib.Path(__file__).parent.parent.parent.parent
|
||||||
|
system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_empty")
|
||||||
|
data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_empty")
|
||||||
|
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
try:
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chunks_retriever_context_multiple_chunks(setup_test_environment_with_chunks_simple):
|
||||||
|
"""Integration test: verify ChunksRetriever can retrieve multiple chunks."""
|
||||||
|
retriever = ChunksRetriever()
|
||||||
|
|
||||||
|
context = await retriever.get_context("Steve")
|
||||||
|
|
||||||
|
assert isinstance(context, list), "Context should be a list"
|
||||||
|
assert len(context) > 0, "Context should not be empty"
|
||||||
|
assert any(chunk["text"] == "Steve Rodger" for chunk in context), (
|
||||||
|
"Failed to get Steve Rodger chunk"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chunks_retriever_top_k_limit(setup_test_environment_with_chunks_complex):
|
||||||
|
"""Integration test: verify ChunksRetriever respects top_k parameter."""
|
||||||
|
retriever = ChunksRetriever(top_k=2)
|
||||||
|
|
||||||
|
context = await retriever.get_context("Employee")
|
||||||
|
|
||||||
|
assert isinstance(context, list), "Context should be a list"
|
||||||
|
assert len(context) <= 2, "Should respect top_k limit"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chunks_retriever_context_complex(setup_test_environment_with_chunks_complex):
|
||||||
|
"""Integration test: verify ChunksRetriever can retrieve chunk context (complex)."""
|
||||||
|
retriever = ChunksRetriever(top_k=20)
|
||||||
|
|
||||||
|
context = await retriever.get_context("Christina")
|
||||||
|
|
||||||
|
assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chunks_retriever_context_on_empty_graph(setup_test_environment_empty):
|
||||||
|
"""Integration test: verify ChunksRetriever handles empty graph correctly."""
|
||||||
|
retriever = ChunksRetriever()
|
||||||
|
|
||||||
|
with pytest.raises(NoDataError):
|
||||||
|
await retriever.get_context("Christina Mayer")
|
||||||
|
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
await vector_engine.create_collection(
|
||||||
|
"DocumentChunk_text", payload_schema=DocumentChunkWithEntities
|
||||||
|
)
|
||||||
|
|
||||||
|
context = await retriever.get_context("Christina Mayer")
|
||||||
|
assert len(context) == 0, "Found chunks when none should exist"
|
||||||
|
|
@ -0,0 +1,268 @@
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
import pytest_asyncio
|
||||||
|
from typing import Optional, Union
|
||||||
|
import cognee
|
||||||
|
|
||||||
|
from cognee.low_level import setup, DataPoint
|
||||||
|
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def setup_test_environment_simple():
|
||||||
|
"""Set up a clean test environment with simple graph data."""
|
||||||
|
base_dir = pathlib.Path(__file__).parent.parent.parent.parent
|
||||||
|
system_directory_path = str(base_dir / ".cognee_system/test_graph_completion_context_simple")
|
||||||
|
data_directory_path = str(base_dir / ".data_storage/test_graph_completion_context_simple")
|
||||||
|
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
class Company(DataPoint):
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
class Person(DataPoint):
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
works_for: Company
|
||||||
|
|
||||||
|
company1 = Company(name="Figma", description="Figma is a company")
|
||||||
|
company2 = Company(name="Canva", description="Canvas is a company")
|
||||||
|
person1 = Person(
|
||||||
|
name="Steve Rodger",
|
||||||
|
description="This is description about Steve Rodger",
|
||||||
|
works_for=company1,
|
||||||
|
)
|
||||||
|
person2 = Person(
|
||||||
|
name="Ike Loma", description="This is description about Ike Loma", works_for=company1
|
||||||
|
)
|
||||||
|
person3 = Person(
|
||||||
|
name="Jason Statham",
|
||||||
|
description="This is description about Jason Statham",
|
||||||
|
works_for=company1,
|
||||||
|
)
|
||||||
|
person4 = Person(
|
||||||
|
name="Mike Broski",
|
||||||
|
description="This is description about Mike Broski",
|
||||||
|
works_for=company2,
|
||||||
|
)
|
||||||
|
person5 = Person(
|
||||||
|
name="Christina Mayer",
|
||||||
|
description="This is description about Christina Mayer",
|
||||||
|
works_for=company2,
|
||||||
|
)
|
||||||
|
|
||||||
|
entities = [company1, company2, person1, person2, person3, person4, person5]
|
||||||
|
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
try:
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def setup_test_environment_complex():
|
||||||
|
"""Set up a clean test environment with complex graph data."""
|
||||||
|
base_dir = pathlib.Path(__file__).parent.parent.parent.parent
|
||||||
|
system_directory_path = str(base_dir / ".cognee_system/test_graph_completion_context_complex")
|
||||||
|
data_directory_path = str(base_dir / ".data_storage/test_graph_completion_context_complex")
|
||||||
|
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
class Company(DataPoint):
|
||||||
|
name: str
|
||||||
|
metadata: dict = {"index_fields": ["name"]}
|
||||||
|
|
||||||
|
class Car(DataPoint):
|
||||||
|
brand: str
|
||||||
|
model: str
|
||||||
|
year: int
|
||||||
|
|
||||||
|
class Location(DataPoint):
|
||||||
|
country: str
|
||||||
|
city: str
|
||||||
|
|
||||||
|
class Home(DataPoint):
|
||||||
|
location: Location
|
||||||
|
rooms: int
|
||||||
|
sqm: int
|
||||||
|
|
||||||
|
class Person(DataPoint):
|
||||||
|
name: str
|
||||||
|
works_for: Company
|
||||||
|
owns: Optional[list[Union[Car, Home]]] = None
|
||||||
|
|
||||||
|
company1 = Company(name="Figma")
|
||||||
|
company2 = Company(name="Canva")
|
||||||
|
|
||||||
|
person1 = Person(name="Mike Rodger", works_for=company1)
|
||||||
|
person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
|
||||||
|
|
||||||
|
person2 = Person(name="Ike Loma", works_for=company1)
|
||||||
|
person2.owns = [
|
||||||
|
Car(brand="Tesla", model="Model S", year=2021),
|
||||||
|
Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
|
||||||
|
]
|
||||||
|
|
||||||
|
person3 = Person(name="Jason Statham", works_for=company1)
|
||||||
|
|
||||||
|
person4 = Person(name="Mike Broski", works_for=company2)
|
||||||
|
person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
|
||||||
|
|
||||||
|
person5 = Person(name="Christina Mayer", works_for=company2)
|
||||||
|
person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
|
||||||
|
|
||||||
|
entities = [company1, company2, person1, person2, person3, person4, person5]
|
||||||
|
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
try:
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def setup_test_environment_empty():
|
||||||
|
"""Set up a clean test environment without graph data."""
|
||||||
|
base_dir = pathlib.Path(__file__).parent.parent.parent.parent
|
||||||
|
system_directory_path = str(
|
||||||
|
base_dir / ".cognee_system/test_get_graph_completion_context_on_empty_graph"
|
||||||
|
)
|
||||||
|
data_directory_path = str(
|
||||||
|
base_dir / ".data_storage/test_get_graph_completion_context_on_empty_graph"
|
||||||
|
)
|
||||||
|
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
try:
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_graph_completion_context_simple(setup_test_environment_simple):
|
||||||
|
"""Integration test: verify GraphCompletionRetriever can retrieve context (simple)."""
|
||||||
|
retriever = GraphCompletionRetriever()
|
||||||
|
|
||||||
|
context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
|
||||||
|
|
||||||
|
# Ensure the top-level sections are present
|
||||||
|
assert "Nodes:" in context, "Missing 'Nodes:' section in context"
|
||||||
|
assert "Connections:" in context, "Missing 'Connections:' section in context"
|
||||||
|
|
||||||
|
# --- Nodes headers ---
|
||||||
|
assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger"
|
||||||
|
assert "Node: Figma" in context, "Missing node header for Figma"
|
||||||
|
assert "Node: Ike Loma" in context, "Missing node header for Ike Loma"
|
||||||
|
assert "Node: Jason Statham" in context, "Missing node header for Jason Statham"
|
||||||
|
assert "Node: Mike Broski" in context, "Missing node header for Mike Broski"
|
||||||
|
assert "Node: Canva" in context, "Missing node header for Canva"
|
||||||
|
assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer"
|
||||||
|
|
||||||
|
# --- Node contents ---
|
||||||
|
assert (
|
||||||
|
"__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__"
|
||||||
|
in context
|
||||||
|
), "Description block for Steve Rodger altered"
|
||||||
|
assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, (
|
||||||
|
"Description block for Figma altered"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
"__node_content_start__\nThis is description about Ike Loma\n__node_content_end__"
|
||||||
|
in context
|
||||||
|
), "Description block for Ike Loma altered"
|
||||||
|
assert (
|
||||||
|
"__node_content_start__\nThis is description about Jason Statham\n__node_content_end__"
|
||||||
|
in context
|
||||||
|
), "Description block for Jason Statham altered"
|
||||||
|
assert (
|
||||||
|
"__node_content_start__\nThis is description about Mike Broski\n__node_content_end__"
|
||||||
|
in context
|
||||||
|
), "Description block for Mike Broski altered"
|
||||||
|
assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, (
|
||||||
|
"Description block for Canva altered"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
"__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__"
|
||||||
|
in context
|
||||||
|
), "Description block for Christina Mayer altered"
|
||||||
|
|
||||||
|
# --- Connections ---
|
||||||
|
assert "Steve Rodger --[works_for]--> Figma" in context, (
|
||||||
|
"Connection Steve Rodger→Figma missing or changed"
|
||||||
|
)
|
||||||
|
assert "Ike Loma --[works_for]--> Figma" in context, (
|
||||||
|
"Connection Ike Loma→Figma missing or changed"
|
||||||
|
)
|
||||||
|
assert "Jason Statham --[works_for]--> Figma" in context, (
|
||||||
|
"Connection Jason Statham→Figma missing or changed"
|
||||||
|
)
|
||||||
|
assert "Mike Broski --[works_for]--> Canva" in context, (
|
||||||
|
"Connection Mike Broski→Canva missing or changed"
|
||||||
|
)
|
||||||
|
assert "Christina Mayer --[works_for]--> Canva" in context, (
|
||||||
|
"Connection Christina Mayer→Canva missing or changed"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_graph_completion_context_complex(setup_test_environment_complex):
|
||||||
|
"""Integration test: verify GraphCompletionRetriever can retrieve context (complex)."""
|
||||||
|
retriever = GraphCompletionRetriever(top_k=20)
|
||||||
|
|
||||||
|
context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
|
||||||
|
|
||||||
|
assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
|
||||||
|
assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
|
||||||
|
assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_graph_completion_context_on_empty_graph(setup_test_environment_empty):
|
||||||
|
"""Integration test: verify GraphCompletionRetriever handles empty graph correctly."""
|
||||||
|
retriever = GraphCompletionRetriever()
|
||||||
|
|
||||||
|
context = await retriever.get_context("Who works at Figma?")
|
||||||
|
assert context == [], "Context should be empty on an empty graph"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_graph_completion_get_triplets_empty(setup_test_environment_empty):
|
||||||
|
"""Integration test: verify GraphCompletionRetriever get_triplets handles empty graph."""
|
||||||
|
retriever = GraphCompletionRetriever()
|
||||||
|
|
||||||
|
triplets = await retriever.get_triplets("Who works at Figma?")
|
||||||
|
|
||||||
|
assert isinstance(triplets, list), "Triplets should be a list"
|
||||||
|
assert len(triplets) == 0, "Should return empty list on empty graph"
|
||||||
|
|
@ -0,0 +1,226 @@
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
import pytest_asyncio
|
||||||
|
from typing import Optional, Union
|
||||||
|
import cognee
|
||||||
|
|
||||||
|
from cognee.low_level import setup, DataPoint
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||||
|
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||||
|
GraphCompletionContextExtensionRetriever,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def setup_test_environment_simple():
|
||||||
|
"""Set up a clean test environment with simple graph data."""
|
||||||
|
base_dir = pathlib.Path(__file__).parent.parent.parent.parent
|
||||||
|
system_directory_path = str(
|
||||||
|
base_dir / ".cognee_system/test_graph_completion_extension_context_simple"
|
||||||
|
)
|
||||||
|
data_directory_path = str(
|
||||||
|
base_dir / ".data_storage/test_graph_completion_extension_context_simple"
|
||||||
|
)
|
||||||
|
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
class Company(DataPoint):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
class Person(DataPoint):
|
||||||
|
name: str
|
||||||
|
works_for: Company
|
||||||
|
|
||||||
|
company1 = Company(name="Figma")
|
||||||
|
company2 = Company(name="Canva")
|
||||||
|
person1 = Person(name="Steve Rodger", works_for=company1)
|
||||||
|
person2 = Person(name="Ike Loma", works_for=company1)
|
||||||
|
person3 = Person(name="Jason Statham", works_for=company1)
|
||||||
|
person4 = Person(name="Mike Broski", works_for=company2)
|
||||||
|
person5 = Person(name="Christina Mayer", works_for=company2)
|
||||||
|
|
||||||
|
entities = [company1, company2, person1, person2, person3, person4, person5]
|
||||||
|
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
try:
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def setup_test_environment_complex():
    """Prepare an isolated cognee store holding a company/person/ownership graph.

    Before the test: points cognee at dedicated system/data directories, prunes
    data and system state, runs ``setup()`` and stores two companies and five
    people (four of whom own cars and/or a home) via ``add_data_points``.
    After the test: prunes again; any exception raised by that cleanup is ignored.
    """
    root = pathlib.Path(__file__).parent.parent.parent.parent
    system_root = root / ".cognee_system/test_graph_completion_extension_context_complex"
    data_root = root / ".data_storage/test_graph_completion_extension_context_complex"

    cognee.config.system_root_directory(str(system_root))
    cognee.config.data_root_directory(str(data_root))

    # Start from a clean slate before creating any graph data.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Data models are local to the fixture; only this scenario uses them.
    class Company(DataPoint):
        name: str
        metadata: dict = {"index_fields": ["name"]}

    class Car(DataPoint):
        brand: str
        model: str
        year: int

    class Location(DataPoint):
        country: str
        city: str

    class Home(DataPoint):
        location: Location
        rooms: int
        sqm: int

    class Person(DataPoint):
        name: str
        works_for: Company
        owns: Optional[list[Union[Car, Home]]] = None

    figma = Company(name="Figma")
    canva = Company(name="Canva")

    mike_rodger = Person(name="Mike Rodger", works_for=figma)
    mike_rodger.owns = [Car(brand="Toyota", model="Camry", year=2020)]

    ike_loma = Person(name="Ike Loma", works_for=figma)
    ike_loma.owns = [
        Car(brand="Tesla", model="Model S", year=2021),
        Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
    ]

    # Jason Statham intentionally owns nothing (``owns`` keeps its default).
    jason_statham = Person(name="Jason Statham", works_for=figma)

    mike_broski = Person(name="Mike Broski", works_for=canva)
    mike_broski.owns = [Car(brand="Ford", model="Mustang", year=1978)]

    christina_mayer = Person(name="Christina Mayer", works_for=canva)
    christina_mayer.owns = [Car(brand="Honda", model="Civic", year=2023)]

    await add_data_points(
        [figma, canva, mike_rodger, ike_loma, jason_statham, mike_broski, christina_mayer]
    )

    yield

    # Best-effort teardown: a failing cleanup must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Prepare an isolated cognee store that contains no graph data.

    Configures dedicated system/data directories, prunes data and system state
    and runs ``setup()``. After the test the store is pruned again; exceptions
    raised by that cleanup are ignored.
    """
    root = pathlib.Path(__file__).parent.parent.parent.parent
    system_root = (
        root / ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph"
    )
    data_root = (
        root / ".data_storage/test_get_graph_completion_extension_context_on_empty_graph"
    )

    cognee.config.system_root_directory(str(system_root))
    cognee.config.data_root_directory(str(data_root))

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    yield

    # Best-effort teardown: a failing cleanup must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_graph_completion_extension_context_simple(setup_test_environment_simple):
    """Context for a Canva question lists both Canva employees; completion is a list of text."""
    question = "Who works at Canva?"
    extension_retriever = GraphCompletionContextExtensionRetriever()

    retrieved_edges = await extension_retriever.get_context(question)
    context_text = await resolve_edges_to_text(retrieved_edges)

    assert "Mike Broski --[works_for]--> Canva" in context_text, "Failed to get Mike Broski"
    assert "Christina Mayer --[works_for]--> Canva" in context_text, "Failed to get Christina Mayer"

    completion = await extension_retriever.get_completion(question)

    # Only the shape of the completion is checked, not its wording.
    assert isinstance(completion, list), f"Expected list, got {type(completion).__name__}"
    assert all(isinstance(entry, str) and entry.strip() for entry in completion), (
        "Answer must contain only non-empty strings"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_graph_completion_extension_context_complex(setup_test_environment_complex):
    """With ``top_k=20`` the context contains all three Figma employment edges."""
    extension_retriever = GraphCompletionContextExtensionRetriever(top_k=20)

    retrieved_edges = await extension_retriever.get_context(
        "Who works at Figma and drives Tesla?"
    )
    context_text = await resolve_edges_to_text(retrieved_edges)

    assert "Mike Rodger --[works_for]--> Figma" in context_text, "Failed to get Mike Rodger"
    assert "Ike Loma --[works_for]--> Figma" in context_text, "Failed to get Ike Loma"
    assert "Jason Statham --[works_for]--> Figma" in context_text, "Failed to get Jason Statham"

    # The completion is requested with a shorter question than the context lookup.
    completion = await extension_retriever.get_completion("Who works at Figma?")

    assert isinstance(completion, list), f"Expected list, got {type(completion).__name__}"
    assert all(isinstance(entry, str) and entry.strip() for entry in completion), (
        "Answer must contain only non-empty strings"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_graph_completion_extension_context_on_empty_graph(setup_test_environment_empty):
    """On an empty graph the context is ``[]`` and the completion is still a list of text."""
    question = "Who works at Figma?"
    extension_retriever = GraphCompletionContextExtensionRetriever()

    retrieved_context = await extension_retriever.get_context(question)
    assert retrieved_context == [], "Context should be empty on an empty graph"

    completion = await extension_retriever.get_completion(question)

    assert isinstance(completion, list), f"Expected list, got {type(completion).__name__}"
    assert all(isinstance(entry, str) and entry.strip() for entry in completion), (
        "Answer must contain only non-empty strings"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_graph_completion_extension_get_triplets_empty(setup_test_environment_empty):
    """``get_triplets`` returns an empty list when the graph holds no data."""
    extension_retriever = GraphCompletionContextExtensionRetriever()

    found_triplets = await extension_retriever.get_triplets("Who works at Figma?")

    assert isinstance(found_triplets, list), "Triplets should be a list"
    assert len(found_triplets) == 0, "Should return empty list on empty graph"
|
||||||
|
|
@ -0,0 +1,218 @@
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
import pytest_asyncio
|
||||||
|
from typing import Optional, Union
|
||||||
|
import cognee
|
||||||
|
|
||||||
|
from cognee.low_level import setup, DataPoint
|
||||||
|
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def setup_test_environment_simple():
    """Prepare an isolated cognee store holding a small company/person graph.

    Before the test: points cognee at dedicated system/data directories, prunes
    data and system state, runs ``setup()`` and stores two companies and five
    people via ``add_data_points``. After the test: prunes again; any exception
    raised by that cleanup is ignored.
    """
    root = pathlib.Path(__file__).parent.parent.parent.parent
    system_root = root / ".cognee_system/test_graph_completion_cot_context_simple"
    data_root = root / ".data_storage/test_graph_completion_cot_context_simple"

    cognee.config.system_root_directory(str(system_root))
    cognee.config.data_root_directory(str(data_root))

    # Start from a clean slate before creating any graph data.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Data models are local to the fixture; only this scenario uses them.
    class Company(DataPoint):
        name: str

    class Person(DataPoint):
        name: str
        works_for: Company

    figma = Company(name="Figma")
    canva = Company(name="Canva")
    steve_rodger = Person(name="Steve Rodger", works_for=figma)
    ike_loma = Person(name="Ike Loma", works_for=figma)
    jason_statham = Person(name="Jason Statham", works_for=figma)
    mike_broski = Person(name="Mike Broski", works_for=canva)
    christina_mayer = Person(name="Christina Mayer", works_for=canva)

    await add_data_points(
        [figma, canva, steve_rodger, ike_loma, jason_statham, mike_broski, christina_mayer]
    )

    yield

    # Best-effort teardown: a failing cleanup must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def setup_test_environment_complex():
    """Prepare an isolated cognee store holding a company/person/ownership graph.

    Before the test: points cognee at dedicated system/data directories, prunes
    data and system state, runs ``setup()`` and stores two companies and five
    people (four of whom own cars and/or a home) via ``add_data_points``.
    After the test: prunes again; any exception raised by that cleanup is ignored.
    """
    root = pathlib.Path(__file__).parent.parent.parent.parent
    system_root = root / ".cognee_system/test_graph_completion_cot_context_complex"
    data_root = root / ".data_storage/test_graph_completion_cot_context_complex"

    cognee.config.system_root_directory(str(system_root))
    cognee.config.data_root_directory(str(data_root))

    # Start from a clean slate before creating any graph data.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Data models are local to the fixture; only this scenario uses them.
    class Company(DataPoint):
        name: str
        metadata: dict = {"index_fields": ["name"]}

    class Car(DataPoint):
        brand: str
        model: str
        year: int

    class Location(DataPoint):
        country: str
        city: str

    class Home(DataPoint):
        location: Location
        rooms: int
        sqm: int

    class Person(DataPoint):
        name: str
        works_for: Company
        owns: Optional[list[Union[Car, Home]]] = None

    figma = Company(name="Figma")
    canva = Company(name="Canva")

    mike_rodger = Person(name="Mike Rodger", works_for=figma)
    mike_rodger.owns = [Car(brand="Toyota", model="Camry", year=2020)]

    ike_loma = Person(name="Ike Loma", works_for=figma)
    ike_loma.owns = [
        Car(brand="Tesla", model="Model S", year=2021),
        Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
    ]

    # Jason Statham intentionally owns nothing (``owns`` keeps its default).
    jason_statham = Person(name="Jason Statham", works_for=figma)

    mike_broski = Person(name="Mike Broski", works_for=canva)
    mike_broski.owns = [Car(brand="Ford", model="Mustang", year=1978)]

    christina_mayer = Person(name="Christina Mayer", works_for=canva)
    christina_mayer.owns = [Car(brand="Honda", model="Civic", year=2023)]

    await add_data_points(
        [figma, canva, mike_rodger, ike_loma, jason_statham, mike_broski, christina_mayer]
    )

    yield

    # Best-effort teardown: a failing cleanup must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Prepare an isolated cognee store that contains no graph data.

    Configures dedicated system/data directories, prunes data and system state
    and runs ``setup()``. After the test the store is pruned again; exceptions
    raised by that cleanup are ignored.
    """
    root = pathlib.Path(__file__).parent.parent.parent.parent
    system_root = root / ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph"
    data_root = root / ".data_storage/test_get_graph_completion_cot_context_on_empty_graph"

    cognee.config.system_root_directory(str(system_root))
    cognee.config.data_root_directory(str(data_root))

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    yield

    # Best-effort teardown: a failing cleanup must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_graph_completion_cot_context_simple(setup_test_environment_simple):
    """Context for a Canva question lists both Canva employees; completion is a list of text."""
    question = "Who works at Canva?"
    cot_retriever = GraphCompletionCotRetriever()

    retrieved_edges = await cot_retriever.get_context(question)
    context_text = await resolve_edges_to_text(retrieved_edges)

    assert "Mike Broski --[works_for]--> Canva" in context_text, "Failed to get Mike Broski"
    assert "Christina Mayer --[works_for]--> Canva" in context_text, "Failed to get Christina Mayer"

    completion = await cot_retriever.get_completion(question)

    # Only the shape of the completion is checked, not its wording.
    assert isinstance(completion, list), f"Expected list, got {type(completion).__name__}"
    assert all(isinstance(entry, str) and entry.strip() for entry in completion), (
        "Answer must contain only non-empty strings"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_graph_completion_cot_context_complex(setup_test_environment_complex):
    """With ``top_k=20`` the context contains all three Figma employment edges."""
    question = "Who works at Figma?"
    cot_retriever = GraphCompletionCotRetriever(top_k=20)

    retrieved_edges = await cot_retriever.get_context(question)
    context_text = await resolve_edges_to_text(retrieved_edges)

    assert "Mike Rodger --[works_for]--> Figma" in context_text, "Failed to get Mike Rodger"
    assert "Ike Loma --[works_for]--> Figma" in context_text, "Failed to get Ike Loma"
    assert "Jason Statham --[works_for]--> Figma" in context_text, "Failed to get Jason Statham"

    completion = await cot_retriever.get_completion(question)

    assert isinstance(completion, list), f"Expected list, got {type(completion).__name__}"
    assert all(isinstance(entry, str) and entry.strip() for entry in completion), (
        "Answer must contain only non-empty strings"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_graph_completion_cot_context_on_empty_graph(setup_test_environment_empty):
    """On an empty graph the context is ``[]`` and the completion is still a list of text."""
    question = "Who works at Figma?"
    cot_retriever = GraphCompletionCotRetriever()

    retrieved_context = await cot_retriever.get_context(question)
    assert retrieved_context == [], "Context should be empty on an empty graph"

    completion = await cot_retriever.get_completion(question)

    assert isinstance(completion, list), f"Expected list, got {type(completion).__name__}"
    assert all(isinstance(entry, str) and entry.strip() for entry in completion), (
        "Answer must contain only non-empty strings"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_graph_completion_cot_get_triplets_empty(setup_test_environment_empty):
    """``get_triplets`` returns an empty list when the graph holds no data."""
    cot_retriever = GraphCompletionCotRetriever()

    found_triplets = await cot_retriever.get_triplets("Who works at Figma?")

    assert isinstance(found_triplets, list), "Triplets should be a list"
    assert len(found_triplets) == 0, "Should return empty list on empty graph"
|
||||||
|
|
@ -0,0 +1,254 @@
|
||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
import pytest_asyncio
|
||||||
|
import cognee
|
||||||
|
|
||||||
|
from cognee.low_level import setup
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.modules.chunking.models import DocumentChunk
|
||||||
|
from cognee.modules.data.processing.document_types import TextDocument
|
||||||
|
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||||
|
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.modules.data.processing.document_types import Document
|
||||||
|
from cognee.modules.engine.models import Entity
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentChunkWithEntities(DataPoint):
    """Chunk-shaped data point used as ``payload_schema`` for a vector collection.

    The empty-graph test passes this class to ``vector_engine.create_collection``
    when it creates the ``DocumentChunk_text`` collection by hand.
    """

    # Chunk payload fields.
    text: str
    chunk_size: int
    chunk_index: int
    cut_type: str
    # Parent document of the chunk.
    is_part_of: Document
    # Entities attached to the chunk; defaults to None rather than an empty list.
    contains: List[Entity] = None

    # NOTE(review): presumably marks ``text`` as the field to index — confirm
    # against how DataPoint consumes ``metadata``.
    metadata: dict = {"index_fields": ["text"]}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def setup_test_environment_with_chunks_simple():
    """Prepare an isolated cognee store holding three chunks of one document.

    Before the test: points cognee at dedicated system/data directories, prunes
    data and system state, runs ``setup()`` and stores three ``DocumentChunk``
    objects that all belong to a single ``TextDocument``. After the test: prunes
    again; any exception raised by that cleanup is ignored.
    """
    root = pathlib.Path(__file__).parent.parent.parent.parent
    system_root = root / ".cognee_system/test_rag_completion_context_simple"
    data_root = root / ".data_storage/test_rag_completion_context_simple"

    cognee.config.system_root_directory(str(system_root))
    cognee.config.data_root_directory(str(data_root))

    # Start from a clean slate before creating any chunks.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    career_document = TextDocument(
        name="Steve Rodger's career",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )

    # One chunk per name; chunk_index follows the position in the list.
    chunks = [
        DocumentChunk(
            text=chunk_text,
            chunk_size=2,
            chunk_index=position,
            cut_type="sentence_end",
            is_part_of=career_document,
            contains=[],
        )
        for position, chunk_text in enumerate(["Steve Rodger", "Mike Broski", "Christina Mayer"])
    ]

    await add_data_points(chunks)

    yield

    # Best-effort teardown: a failing cleanup must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def setup_test_environment_with_chunks_complex():
    """Prepare an isolated cognee store holding six chunks spread over two documents.

    Before the test: points cognee at dedicated system/data directories, prunes
    data and system state, runs ``setup()`` and stores three chunks for an
    "Employee List" document and three for a "Car List" document. After the
    test: prunes again; any exception raised by that cleanup is ignored.
    """
    root = pathlib.Path(__file__).parent.parent.parent.parent
    system_root = root / ".cognee_system/test_rag_completion_context_complex"
    data_root = root / ".data_storage/test_rag_completion_context_complex"

    cognee.config.system_root_directory(str(system_root))
    cognee.config.data_root_directory(str(data_root))

    # Start from a clean slate before creating any chunks.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    employee_document = TextDocument(
        name="Employee List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )

    car_document = TextDocument(
        name="Car List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )

    # chunk_index restarts at 0 for each document.
    employee_chunks = [
        DocumentChunk(
            text=chunk_text,
            chunk_size=2,
            chunk_index=position,
            cut_type="sentence_end",
            is_part_of=employee_document,
            contains=[],
        )
        for position, chunk_text in enumerate(["Steve Rodger", "Mike Broski", "Christina Mayer"])
    ]

    car_chunks = [
        DocumentChunk(
            text=chunk_text,
            chunk_size=2,
            chunk_index=position,
            cut_type="sentence_end",
            is_part_of=car_document,
            contains=[],
        )
        for position, chunk_text in enumerate(["Range Rover", "Hyundai", "Chrysler"])
    ]

    await add_data_points(employee_chunks + car_chunks)

    yield

    # Best-effort teardown: a failing cleanup must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Prepare an isolated cognee store that contains no chunks.

    Configures dedicated system/data directories and prunes data and system
    state. Unlike the chunk fixtures in this file, ``setup()`` is not called
    here. After the test the store is pruned again; exceptions raised by that
    cleanup are ignored.
    """
    root = pathlib.Path(__file__).parent.parent.parent.parent
    system_root = root / ".cognee_system/test_get_rag_completion_context_on_empty_graph"
    data_root = root / ".data_storage/test_get_rag_completion_context_on_empty_graph"

    cognee.config.system_root_directory(str(system_root))
    cognee.config.data_root_directory(str(data_root))

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    yield

    # Best-effort teardown: a failing cleanup must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_rag_completion_context_simple(setup_test_environment_with_chunks_simple):
    """Querying "Mike" yields a string context that mentions Mike Broski."""
    completion_retriever = CompletionRetriever()

    retrieved_context = await completion_retriever.get_context("Mike")

    assert isinstance(retrieved_context, str), "Context should be a string"
    assert "Mike Broski" in retrieved_context, "Failed to get Mike Broski"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_rag_completion_context_multiple_chunks(setup_test_environment_with_chunks_simple):
    """Querying "Steve" against the three-chunk store yields a context with Steve Rodger."""
    completion_retriever = CompletionRetriever()

    retrieved_context = await completion_retriever.get_context("Steve")

    assert isinstance(retrieved_context, str), "Context should be a string"
    assert "Steve Rodger" in retrieved_context, "Failed to get Steve Rodger"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_rag_completion_context_complex(setup_test_environment_with_chunks_complex):
    """Querying "Christina" yields a context whose first 15 characters are her full name."""
    # TODO: top_k doesn't affect the output, it should be fixed.
    completion_retriever = CompletionRetriever(top_k=20)

    retrieved_context = await completion_retriever.get_context("Christina")

    # Compare only the leading slice, i.e. the best match is expected first.
    leading_text = retrieved_context[:15]
    assert leading_text == "Christina Mayer", "Failed to get Christina Mayer"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_rag_completion_context_on_empty_graph(setup_test_environment_empty):
    """``get_context`` raises ``NoDataError`` first, then returns "" once the collection exists."""
    query = "Christina Mayer"
    completion_retriever = CompletionRetriever()

    # Phase 1: the DocumentChunk_text collection has not been created yet.
    with pytest.raises(NoDataError):
        await completion_retriever.get_context(query)

    # Phase 2: create the collection without adding any chunks, then query again.
    engine = get_vector_engine()
    await engine.create_collection(
        "DocumentChunk_text", payload_schema=DocumentChunkWithEntities
    )

    retrieved_context = await completion_retriever.get_context(query)
    assert retrieved_context == "", "Returned context should be empty on an empty graph"
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import pytest
|
|
||||||
import cognee
|
|
||||||
import pathlib
|
|
||||||
import os
|
import os
|
||||||
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
import pytest_asyncio
|
||||||
|
import cognee
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from cognee.low_level import setup, DataPoint
|
from cognee.low_level import setup, DataPoint
|
||||||
|
|
@ -125,80 +125,90 @@ async def _test_get_structured_entity_completion():
|
||||||
_assert_structured_answer(structured_answer)
|
_assert_structured_answer(structured_answer)
|
||||||
|
|
||||||
|
|
||||||
class TestStructuredOutputCompletion:
|
@pytest_asyncio.fixture
|
||||||
@pytest.mark.asyncio
|
async def setup_test_environment():
|
||||||
async def test_get_structured_completion(self):
|
"""Set up a clean test environment with graph and document data."""
|
||||||
system_directory_path = os.path.join(
|
base_dir = pathlib.Path(__file__).parent.parent.parent.parent
|
||||||
pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
|
system_directory_path = str(base_dir / ".cognee_system/test_get_structured_completion")
|
||||||
)
|
data_directory_path = str(base_dir / ".data_storage/test_get_structured_completion")
|
||||||
cognee.config.system_root_directory(system_directory_path)
|
|
||||||
data_directory_path = os.path.join(
|
|
||||||
pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
|
|
||||||
)
|
|
||||||
cognee.config.data_root_directory(data_directory_path)
|
|
||||||
|
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
class Company(DataPoint):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
class Person(DataPoint):
|
||||||
|
name: str
|
||||||
|
works_for: Company
|
||||||
|
works_since: int
|
||||||
|
|
||||||
|
company1 = Company(name="Figma")
|
||||||
|
person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015)
|
||||||
|
|
||||||
|
entities = [company1, person1]
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
document = TextDocument(
|
||||||
|
name="Steve Rodger's career",
|
||||||
|
raw_data_location="somewhere",
|
||||||
|
external_metadata="",
|
||||||
|
mime_type="text/plain",
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk1 = DocumentChunk(
|
||||||
|
text="Steve Rodger",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk2 = DocumentChunk(
|
||||||
|
text="Mike Broski",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=1,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk3 = DocumentChunk(
|
||||||
|
text="Christina Mayer",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=2,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
entities = [chunk1, chunk2, chunk3]
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
entity_type = EntityType(name="Person", description="A human individual")
|
||||||
|
entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")
|
||||||
|
|
||||||
|
entities = [entity]
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
try:
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
await cognee.prune.prune_system(metadata=True)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
await setup()
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
class Company(DataPoint):
|
|
||||||
name: str
|
|
||||||
|
|
||||||
class Person(DataPoint):
|
@pytest.mark.asyncio
|
||||||
name: str
|
async def test_get_structured_completion(setup_test_environment):
|
||||||
works_for: Company
|
"""Integration test: verify structured output completion for all retrievers."""
|
||||||
works_since: int
|
await _test_get_structured_graph_completion_cot()
|
||||||
|
await _test_get_structured_graph_completion()
|
||||||
company1 = Company(name="Figma")
|
await _test_get_structured_graph_completion_temporal()
|
||||||
person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015)
|
await _test_get_structured_graph_completion_rag()
|
||||||
|
await _test_get_structured_graph_completion_context_extension()
|
||||||
entities = [company1, person1]
|
await _test_get_structured_entity_completion()
|
||||||
await add_data_points(entities)
|
|
||||||
|
|
||||||
document = TextDocument(
|
|
||||||
name="Steve Rodger's career",
|
|
||||||
raw_data_location="somewhere",
|
|
||||||
external_metadata="",
|
|
||||||
mime_type="text/plain",
|
|
||||||
)
|
|
||||||
|
|
||||||
chunk1 = DocumentChunk(
|
|
||||||
text="Steve Rodger",
|
|
||||||
chunk_size=2,
|
|
||||||
chunk_index=0,
|
|
||||||
cut_type="sentence_end",
|
|
||||||
is_part_of=document,
|
|
||||||
contains=[],
|
|
||||||
)
|
|
||||||
chunk2 = DocumentChunk(
|
|
||||||
text="Mike Broski",
|
|
||||||
chunk_size=2,
|
|
||||||
chunk_index=1,
|
|
||||||
cut_type="sentence_end",
|
|
||||||
is_part_of=document,
|
|
||||||
contains=[],
|
|
||||||
)
|
|
||||||
chunk3 = DocumentChunk(
|
|
||||||
text="Christina Mayer",
|
|
||||||
chunk_size=2,
|
|
||||||
chunk_index=2,
|
|
||||||
cut_type="sentence_end",
|
|
||||||
is_part_of=document,
|
|
||||||
contains=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
entities = [chunk1, chunk2, chunk3]
|
|
||||||
await add_data_points(entities)
|
|
||||||
|
|
||||||
entity_type = EntityType(name="Person", description="A human individual")
|
|
||||||
entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")
|
|
||||||
|
|
||||||
entities = [entity]
|
|
||||||
await add_data_points(entities)
|
|
||||||
|
|
||||||
await _test_get_structured_graph_completion_cot()
|
|
||||||
await _test_get_structured_graph_completion()
|
|
||||||
await _test_get_structured_graph_completion_temporal()
|
|
||||||
await _test_get_structured_graph_completion_rag()
|
|
||||||
await _test_get_structured_graph_completion_context_extension()
|
|
||||||
await _test_get_structured_entity_completion()
|
|
||||||
184
cognee/tests/integration/retrieval/test_summaries_retriever.py
Normal file
184
cognee/tests/integration/retrieval/test_summaries_retriever.py
Normal file
|
|
@ -0,0 +1,184 @@
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
import pytest_asyncio
|
||||||
|
import cognee
|
||||||
|
|
||||||
|
from cognee.low_level import setup
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.modules.chunking.models import DocumentChunk
|
||||||
|
from cognee.tasks.summarization.models import TextSummary
|
||||||
|
from cognee.modules.data.processing.document_types import TextDocument
|
||||||
|
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||||
|
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def setup_test_environment_with_summaries():
    """Provide a pruned cognee store seeded with six chunk summaries.

    Two documents ("Employee List" and "Car List") each own three chunks, and
    every chunk gets a ``TextSummary`` holding its initials. Only the summaries
    are handed to ``add_data_points``. After the test body runs, the data and
    system stores are pruned again; errors during that cleanup are ignored.
    """
    repo_root = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(repo_root / ".cognee_system/test_summaries_retriever_context")
    )
    cognee.config.data_root_directory(
        str(repo_root / ".data_storage/test_summaries_retriever_context")
    )

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Both documents are created up front, before any chunk refers to them.
    employee_list, car_list = [
        TextDocument(
            name=document_name,
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )
        for document_name in ("Employee List", "Car List")
    ]

    # (owning document, chunk_index, chunk text, summary text), in creation order.
    chunk_specs = (
        (employee_list, 0, "Steve Rodger", "S.R."),
        (employee_list, 1, "Mike Broski", "M.B."),
        (employee_list, 2, "Christina Mayer", "C.M."),
        (car_list, 0, "Range Rover", "R.R."),
        (car_list, 1, "Hyundai", "H.Y."),
        (car_list, 2, "Chrysler", "C.H."),
    )

    summaries = []
    for owner, index, chunk_text, summary_text in chunk_specs:
        chunk = DocumentChunk(
            text=chunk_text,
            chunk_size=2,
            chunk_index=index,
            cut_type="sentence_end",
            is_part_of=owner,
            contains=[],
        )
        summaries.append(TextSummary(text=summary_text, made_from=chunk))

    await add_data_points(summaries)

    yield

    # Best-effort teardown: a failed prune must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Provide a pruned cognee store that contains no summaries.

    Points cognee at dedicated ``*_context_empty`` directories and prunes data
    and system state before yielding. The same pruning is attempted afterwards,
    with any error ignored.
    """
    repo_root = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(repo_root / ".cognee_system/test_summaries_retriever_context_empty")
    )
    cognee.config.data_root_directory(
        str(repo_root / ".data_storage/test_summaries_retriever_context_empty")
    )

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    yield

    # Best-effort teardown: a failed prune must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_summaries_retriever_context(setup_test_environment_with_summaries):
    """Integration test: the top summary hit for "Christina" is the "C.M." summary."""
    summaries_retriever = SummariesRetriever(top_k=20)

    context = await summaries_retriever.get_context("Christina")

    assert isinstance(context, list), "Context should be a list"
    assert len(context) > 0, "Context should not be empty"

    best_match = context[0]
    assert best_match["text"] == "C.M.", "Failed to get Christina Mayer"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_summaries_retriever_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: SummariesRetriever behaviour when nothing has been stored.

    Before the ``TextSummary_text`` collection exists the retriever is expected
    to raise ``NoDataError``; once the collection is created (still empty) it is
    expected to return an empty list.
    """
    summaries_retriever = SummariesRetriever()
    query = "Christina Mayer"

    # Phase 1: no collection yet.
    with pytest.raises(NoDataError):
        await summaries_retriever.get_context(query)

    # Phase 2: collection exists but holds no rows.
    await get_vector_engine().create_collection("TextSummary_text", payload_schema=TextSummary)

    context = await summaries_retriever.get_context(query)
    assert context == [], "Returned context should be empty on an empty graph"
|
||||||
306
cognee/tests/integration/retrieval/test_temporal_retriever.py
Normal file
306
cognee/tests/integration/retrieval/test_temporal_retriever.py
Normal file
|
|
@ -0,0 +1,306 @@
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
import pytest_asyncio
|
||||||
|
import cognee
|
||||||
|
|
||||||
|
from cognee.low_level import setup, DataPoint
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
||||||
|
from cognee.modules.engine.models.Event import Event
|
||||||
|
from cognee.modules.engine.models.Timestamp import Timestamp
|
||||||
|
from cognee.modules.engine.models.Interval import Interval
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def setup_test_environment_with_events():
    """Provide a pruned cognee store seeded with four dated events.

    Five ``Timestamp`` points (first day of Jan, Feb, Mar, Jul and Oct 2021) are
    built; Feb and Mar form an ``Interval``. Three events are pinned to a single
    timestamp via ``at`` and one ("Team Meeting") spans the interval via
    ``during``. After the test body runs, stores are pruned again on a
    best-effort basis.
    """
    repo_root = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(repo_root / ".cognee_system/test_temporal_retriever_with_events")
    )
    cognee.config.data_root_directory(
        str(repo_root / ".data_storage/test_temporal_retriever_with_events")
    )

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    def _first_of_month(epoch, month):
        # Midnight on the first day of the given 2021 month; `epoch` is the
        # matching time_at value supplied by the caller.
        return Timestamp(
            time_at=epoch,
            year=2021,
            month=month,
            day=1,
            hour=0,
            minute=0,
            second=0,
            timestamp_str=f"2021-{month:02d}-01T00:00:00",
        )

    january = _first_of_month(1609459200, 1)  # 2021-01-01 00:00:00
    february = _first_of_month(1612137600, 2)  # 2021-02-01 00:00:00
    march = _first_of_month(1614556800, 3)  # 2021-03-01 00:00:00
    july = _first_of_month(1625097600, 7)  # 2021-07-01 00:00:00
    october = _first_of_month(1633046400, 10)  # 2021-10-01 00:00:00

    # Interval used by the one event that spans more than a single timestamp.
    feb_to_march = Interval(time_from=february, time_to=march)

    events = [
        Event(
            name="Project Alpha Launch",
            description="Launched Project Alpha at the beginning of 2021",
            at=january,
            location="San Francisco",
        ),
        Event(
            name="Team Meeting",
            description="Monthly team meeting discussing Q1 goals",
            during=feb_to_march,
            location="New York",
        ),
        Event(
            name="Product Release",
            description="Released new product features in July",
            at=july,
            location="Remote",
        ),
        Event(
            name="Company Retreat",
            description="Annual company retreat in October",
            at=october,
            location="Lake Tahoe",
        ),
    ]

    await add_data_points(events)

    yield

    # Best-effort teardown: a failed prune must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def setup_test_environment_with_graph_data():
    """Provide a pruned cognee store holding one person linked to one company.

    The data has no temporal information; tests use it to exercise the
    retriever on plain graph data (one ``Person`` with ``works_for`` pointing
    at one ``Company``). Stores are pruned again afterwards, best-effort.
    """
    repo_root = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(repo_root / ".cognee_system/test_temporal_retriever_with_graph")
    )
    cognee.config.data_root_directory(
        str(repo_root / ".data_storage/test_temporal_retriever_with_graph")
    )

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Minimal ad-hoc data point types, local to this fixture.
    class Company(DataPoint):
        name: str
        description: str

    class Person(DataPoint):
        name: str
        description: str
        works_for: Company

    figma = Company(name="Figma", description="Figma is a company")
    steve = Person(
        name="Steve Rodger",
        description="This is description about Steve Rodger",
        works_for=figma,
    )

    await add_data_points([figma, steve])

    yield

    # Best-effort teardown: a failed prune must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Provide a pruned, set-up cognee store that contains no data points.

    Unlike the seeded fixtures nothing is added after ``setup()``. Stores are
    pruned again afterwards, with any error ignored.
    """
    repo_root = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(repo_root / ".cognee_system/test_temporal_retriever_empty")
    )
    cognee.config.data_root_directory(
        str(repo_root / ".data_storage/test_temporal_retriever_empty")
    )

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    yield

    # Best-effort teardown: a failed prune must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_temporal_retriever_context_with_time_range(setup_test_environment_with_events):
    """Integration test: a January 2021 question surfaces the Project Alpha launch."""
    temporal_retriever = TemporalRetriever(top_k=5)

    context = await temporal_retriever.get_context("What happened in January 2021?")

    assert isinstance(context, str), "Context should be a string"
    assert len(context) > 0, "Context should not be empty"

    # Either fragment of the event name is accepted.
    mentions_launch = "Project Alpha" in context or "Launch" in context
    assert mentions_launch, "Should retrieve Project Alpha Launch event from January 2021"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_temporal_retriever_context_with_single_time(setup_test_environment_with_events):
    """Integration test: a July 2021 question surfaces the Product Release event."""
    temporal_retriever = TemporalRetriever(top_k=5)

    context = await temporal_retriever.get_context("What happened in July 2021?")

    assert isinstance(context, str), "Context should be a string"
    assert len(context) > 0, "Context should not be empty"

    # The event name or the month word from its description is accepted.
    mentions_release = "Product Release" in context or "July" in context
    assert mentions_release, "Should retrieve Product Release event from July 2021"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_temporal_retriever_context_fallback_to_triplets(
    setup_test_environment_with_graph_data,
):
    """Integration test: a question with no time expression still yields graph context.

    Intended to exercise TemporalRetriever's triplet-search fallback on the
    Person/Company data.
    """
    temporal_retriever = TemporalRetriever(top_k=5)

    context = await temporal_retriever.get_context("Who works at Figma?")

    assert isinstance(context, str), "Context should be a string"
    assert len(context) > 0, "Context should not be empty"

    mentions_graph_data = "Steve" in context or "Figma" in context
    assert mentions_graph_data, "Should retrieve graph data via triplet search fallback"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_temporal_retriever_context_empty_graph(setup_test_environment_empty):
    """Integration test: verify TemporalRetriever handles empty graph correctly.

    With nothing stored, ``get_context`` must complete without raising and
    return a string; that string may be empty.
    """
    retriever = TemporalRetriever()

    context = await retriever.get_context("What happened?")

    # No length assertion on purpose: `len(context) >= 0` holds for every string
    # and would verify nothing. The type check is the whole contract here.
    assert isinstance(context, str), "Context should be a string (possibly empty)"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_temporal_retriever_get_completion(setup_test_environment_with_events):
    """Integration test: get_completion returns a non-empty list of non-blank strings."""
    temporal_retriever = TemporalRetriever()

    completion = await temporal_retriever.get_completion("What happened in January 2021?")

    assert isinstance(completion, list), "Completion should be a list"
    assert len(completion) > 0, "Completion should not be empty"

    items_are_text = all(isinstance(item, str) and item.strip() for item in completion)
    assert items_are_text, "Completion items should be non-empty strings"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_temporal_retriever_get_completion_fallback(setup_test_environment_with_graph_data):
    """Integration test: get_completion also works on graph data with no time expression."""
    temporal_retriever = TemporalRetriever()

    completion = await temporal_retriever.get_completion("Who works at Figma?")

    assert isinstance(completion, list), "Completion should be a list"
    assert len(completion) > 0, "Completion should not be empty"

    items_are_text = all(isinstance(item, str) and item.strip() for item in completion)
    assert items_are_text, "Completion items should be non-empty strings"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_temporal_retriever_top_k_limit(setup_test_environment_with_events):
    """Integration test: with top_k=2 the context holds at most one separator line."""
    limited_retriever = TemporalRetriever(top_k=2)

    context = await limited_retriever.get_context("What happened in 2021?")

    assert isinstance(context, str), "Context should be a string"

    # NOTE(review): presumably events in the context are joined by this run of
    # '#' characters, so two events give one separator - confirm against
    # TemporalRetriever's formatting.
    assert context.count("#####################") <= 1, "Should respect top_k limit of 2 events"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_temporal_retriever_multiple_events(setup_test_environment_with_events):
    """Integration test: a whole-year question surfaces at least one seeded event."""
    temporal_retriever = TemporalRetriever(top_k=10)

    context = await temporal_retriever.get_context("What events occurred in 2021?")

    assert isinstance(context, str), "Context should be a string"
    assert len(context) > 0, "Context should not be empty"

    seeded_event_names = (
        "Project Alpha",
        "Team Meeting",
        "Product Release",
        "Company Retreat",
    )
    assert any(event_name in context for event_name in seeded_event_names), (
        "Should retrieve at least one event from 2021"
    )
|
||||||
|
|
@ -82,3 +82,38 @@ async def test_triplet_retriever_context_simple(setup_test_environment_with_trip
|
||||||
context = await retriever.get_context("Alice")
|
context = await retriever.get_context("Alice")
|
||||||
|
|
||||||
assert "Alice knows Bob" in context, "Failed to get Alice triplet"
|
assert "Alice knows Bob" in context, "Failed to get Alice triplet"
|
||||||
|
assert isinstance(context, str), "Context should be a string"
|
||||||
|
assert len(context) > 0, "Context should not be empty"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_triplet_retriever_context_multiple_triplets(setup_test_environment_with_triplets):
    """Integration test: querying "Bob" returns at least one triplet that mentions Bob."""
    triplet_retriever = TripletRetriever(top_k=5)

    context = await triplet_retriever.get_context("Bob")

    bob_triplets = ("Alice knows Bob", "Bob works at Tech Corp")
    assert any(triplet in context for triplet in bob_triplets), (
        "Failed to get Bob-related triplets"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_triplet_retriever_top_k_limit(setup_test_environment_with_triplets):
    """Integration test: TripletRetriever with top_k=1 returns a string context."""
    single_hit_retriever = TripletRetriever(top_k=1)

    context = await single_hit_retriever.get_context("Alice")

    # NOTE(review): only the return type is checked; nothing here verifies that
    # a single triplet came back. Tighten once the context format is confirmed.
    assert isinstance(context, str), "Context should be a string"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_triplet_retriever_context_empty(setup_test_environment_empty):
    """Integration test: with no triplets stored, get_context raises NoDataError."""
    # Run setup() first so that only the triplet data is missing.
    await setup()

    triplet_retriever = TripletRetriever()

    with pytest.raises(NoDataError):
        await triplet_retriever.get_context("Alice")
|
||||||
|
|
|
||||||
|
|
@ -148,8 +148,8 @@ class TestCogneeServerStart(unittest.TestCase):
|
||||||
headers=headers,
|
headers=headers,
|
||||||
files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
|
files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
|
||||||
data={
|
data={
|
||||||
"ontology_key": json.dumps([ontology_key]),
|
"ontology_key": ontology_key,
|
||||||
"description": json.dumps(["Test ontology"]),
|
"description": "Test ontology",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertEqual(ontology_response.status_code, 200)
|
self.assertEqual(ontology_response.status_code, 200)
|
||||||
|
|
|
||||||
76
cognee/tests/test_dataset_delete.py
Normal file
76
cognee/tests/test_dataset_delete.py
Normal file
|
|
@ -0,0 +1,76 @@
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import pathlib
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.shared.logging_utils import setup_logging, ERROR
|
||||||
|
from cognee.modules.data.methods.delete_dataset import delete_dataset
|
||||||
|
from cognee.modules.data.methods.get_dataset import get_dataset
|
||||||
|
from cognee.modules.users.methods import get_default_user
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """End-to-end check that deleting a dataset removes its database files.

    Steps:
      1. Point cognee at test-local data/system directories and prune them.
      2. Add one text to each of two datasets and run ``cognify()``.
      3. For every dataset id returned by ``cognify()``, assert that the graph
         (``<id>.pkl``) and vector (``<id>.lance.db``) files exist under
         ``<system dir>/databases/<user id>/``, delete the dataset, then assert
         that both files are gone.

    Raises:
        AssertionError: if ``cognify()`` reports no datasets, or a database
            file is missing before deletion or still present after it.
    """
    tests_dir = pathlib.Path(__file__).parent

    # Set data and system directory paths
    data_directory_path = str((tests_dir / ".data_storage/test_dataset_delete").resolve())
    cognee.config.data_root_directory(data_directory_path)
    cognee_directory_path = str((tests_dir / ".cognee_system/test_dataset_delete").resolve())
    cognee.config.system_root_directory(cognee_directory_path)

    # Create a clean slate for cognee -- reset data and system state
    print("Resetting cognee data...")
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    print("Data reset complete.\n")

    # cognee knowledge graph will be created based on this text
    text = """
    Natural language processing (NLP) is an interdisciplinary
    subfield of computer science and information retrieval.
    """

    # Add the text, and make it available for cognify
    await cognee.add(text, "nlp_dataset")
    await cognee.add("Quantum computing is the study of quantum computers.", "quantum_dataset")

    # Use LLMs and cognee to create knowledge graph
    ret_val = await cognee.cognify()
    user = await get_default_user()

    # Without this guard an empty result would skip the loop below and the
    # script would exit successfully without having verified anything.
    assert ret_val, "cognify() returned no datasets; nothing to verify."

    user_databases_dir = os.path.join(cognee_directory_path, "databases", str(user.id))

    for val in ret_val:
        dataset_id = str(val)
        vector_db_path = os.path.join(user_databases_dir, dataset_id + ".lance.db")
        graph_db_path = os.path.join(user_databases_dir, dataset_id + ".pkl")

        # Check if databases are properly created and exist before deletion
        assert os.path.exists(graph_db_path), "Graph database file not found."
        assert os.path.exists(vector_db_path), "Vector database file not found."

        dataset = await get_dataset(user_id=user.id, dataset_id=UUID(dataset_id))
        await delete_dataset(dataset)

        # Confirm databases have been deleted
        assert not os.path.exists(graph_db_path), "Graph database file found."
        assert not os.path.exists(vector_db_path), "Vector database file found."
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: run main() on a dedicated event loop with logging
    # limited to errors.
    logger = setup_logging(log_level=ERROR)
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        try:
            # Finalize any async generators still open on this loop.
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            # Release the loop's resources even if main() or the generator
            # shutdown raised; the loop was previously left open.
            loop.close()
|
||||||
|
|
@ -1,5 +1,10 @@
|
||||||
import pathlib
|
import pathlib
|
||||||
import os
|
import os
|
||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
@ -13,127 +18,172 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
|
||||||
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||||
GraphSummaryCompletionRetriever,
|
GraphSummaryCompletionRetriever,
|
||||||
)
|
)
|
||||||
|
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
||||||
|
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
||||||
|
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
||||||
|
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
||||||
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
|
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.search.types import SearchType
|
from cognee.modules.search.types import SearchType
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from collections import Counter
|
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def _reset_engines_and_prune() -> None:
|
||||||
# This test runs for multiple db settings, to run this locally set the corresponding db envs
|
"""Reset db engine caches and prune data/system.
|
||||||
|
|
||||||
|
Kept intentionally identical to the inlined setup logic to avoid event loop issues when
|
||||||
|
using deployed databases (Neo4j, PostgreSQL) and to ensure fresh instances per run.
|
||||||
|
"""
|
||||||
|
# Dispose of existing engines and clear caches to ensure fresh instances for each test
|
||||||
|
try:
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
# Dispose SQLAlchemy engine connection pool if it exists
|
||||||
|
if hasattr(vector_engine, "engine") and hasattr(vector_engine.engine, "dispose"):
|
||||||
|
await vector_engine.engine.dispose(close=True)
|
||||||
|
except Exception:
|
||||||
|
# Engine might not exist yet
|
||||||
|
pass
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
|
||||||
|
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
|
||||||
|
from cognee.infrastructure.databases.relational.create_relational_engine import (
|
||||||
|
create_relational_engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
create_graph_engine.cache_clear()
|
||||||
|
create_vector_engine.cache_clear()
|
||||||
|
create_relational_engine.cache_clear()
|
||||||
|
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
await cognee.prune.prune_system(metadata=True)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
dataset_name = "test_dataset"
|
|
||||||
|
|
||||||
|
async def _seed_default_dataset(dataset_name: str) -> dict:
|
||||||
|
"""Add the shared test dataset contents and run cognify (same steps/order as before)."""
|
||||||
text_1 = """Germany is located in europe right next to the Netherlands"""
|
text_1 = """Germany is located in europe right next to the Netherlands"""
|
||||||
|
|
||||||
|
logger.info(f"Adding text data to dataset: {dataset_name}")
|
||||||
await cognee.add(text_1, dataset_name)
|
await cognee.add(text_1, dataset_name)
|
||||||
|
|
||||||
explanation_file_path_quantum = os.path.join(
|
explanation_file_path_quantum = os.path.join(
|
||||||
pathlib.Path(__file__).parent, "test_data/Quantum_computers.txt"
|
pathlib.Path(__file__).parent, "test_data/Quantum_computers.txt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info(f"Adding file data to dataset: {dataset_name}")
|
||||||
await cognee.add([explanation_file_path_quantum], dataset_name)
|
await cognee.add([explanation_file_path_quantum], dataset_name)
|
||||||
|
|
||||||
|
logger.info(f"Running cognify on dataset: {dataset_name}")
|
||||||
await cognee.cognify([dataset_name])
|
await cognee.cognify([dataset_name])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"dataset_name": dataset_name,
|
||||||
|
"text_1": text_1,
|
||||||
|
"explanation_file_path_quantum": explanation_file_path_quantum,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
def event_loop():
    """Provide one asyncio event loop shared by the whole test session.

    Sharing a loop helps avoid "Future attached to a different loop" errors
    when several async tests reuse the same clients/engines. The loop is
    closed when the session ends.
    """
    shared_loop = asyncio.new_event_loop()
    try:
        yield shared_loop
    finally:
        shared_loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def setup_test_environment():
|
||||||
|
"""Helper function to set up test environment with data, cognify, and triplet embeddings."""
|
||||||
|
# This test runs for multiple db settings, to run this locally set the corresponding db envs
|
||||||
|
|
||||||
|
dataset_name = "test_dataset"
|
||||||
|
logger.info("Starting test setup: pruning data and system")
|
||||||
|
await _reset_engines_and_prune()
|
||||||
|
state = await _seed_default_dataset(dataset_name=dataset_name)
|
||||||
|
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings
|
from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings
|
||||||
|
|
||||||
|
logger.info("Creating triplet embeddings")
|
||||||
await create_triplet_embeddings(user=user, dataset=dataset_name, triplets_batch_size=5)
|
await create_triplet_embeddings(user=user, dataset=dataset_name, triplets_batch_size=5)
|
||||||
|
|
||||||
|
# Check if Triplet_text collection was created
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
has_collection = await vector_engine.has_collection(collection_name="Triplet_text")
|
||||||
|
logger.info(f"Triplet_text collection exists after creation: {has_collection}")
|
||||||
|
|
||||||
|
if has_collection:
|
||||||
|
collection = await vector_engine.get_collection("Triplet_text")
|
||||||
|
count = await collection.count_rows() if hasattr(collection, "count_rows") else "unknown"
|
||||||
|
logger.info(f"Triplet_text collection row count: {count}")
|
||||||
|
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
async def setup_test_environment_for_feedback():
    """Reset engines and stores, then seed "test_dataset" for the feedback weight test.

    Returns whatever ``_seed_default_dataset`` returns for that dataset.
    """
    await _reset_engines_and_prune()
    seeded_state = await _seed_default_dataset(dataset_name="test_dataset")
    return seeded_state
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
async def e2e_state():
|
||||||
|
"""Compute E2E artifacts once; tests only assert.
|
||||||
|
|
||||||
|
This avoids repeating expensive setup and LLM calls across multiple tests.
|
||||||
|
"""
|
||||||
|
await setup_test_environment()
|
||||||
|
|
||||||
|
# --- Graph/vector engine consistency ---
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
nodes, edges = await graph_engine.get_graph_data()
|
_nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
collection = await vector_engine.search(
|
collection = await vector_engine.search(
|
||||||
query_text="Test", limit=None, collection_name="Triplet_text"
|
collection_name="Triplet_text", query_text="Test", limit=None
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(edges) == len(collection), (
|
# --- Retriever contexts ---
|
||||||
f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection"
|
query = "Next to which country is Germany located?"
|
||||||
)
|
|
||||||
|
|
||||||
context_gk = await GraphCompletionRetriever().get_context(
|
contexts = {
|
||||||
query="Next to which country is Germany located?"
|
"graph_completion": await GraphCompletionRetriever().get_context(query=query),
|
||||||
)
|
"graph_completion_cot": await GraphCompletionCotRetriever().get_context(query=query),
|
||||||
context_gk_cot = await GraphCompletionCotRetriever().get_context(
|
"graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_context(
|
||||||
query="Next to which country is Germany located?"
|
query=query
|
||||||
)
|
),
|
||||||
context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context(
|
"graph_summary_completion": await GraphSummaryCompletionRetriever().get_context(
|
||||||
query="Next to which country is Germany located?"
|
query=query
|
||||||
)
|
),
|
||||||
context_gk_sum = await GraphSummaryCompletionRetriever().get_context(
|
"chunks": await ChunksRetriever(top_k=5).get_context(query=query),
|
||||||
query="Next to which country is Germany located?"
|
"summaries": await SummariesRetriever(top_k=5).get_context(query=query),
|
||||||
)
|
"rag_completion": await CompletionRetriever(top_k=3).get_context(query=query),
|
||||||
context_triplet = await TripletRetriever().get_context(
|
"temporal": await TemporalRetriever(top_k=5).get_context(query=query),
|
||||||
query="Next to which country is Germany located?"
|
"triplet": await TripletRetriever().get_context(query=query),
|
||||||
)
|
}
|
||||||
|
|
||||||
for name, context in [
|
# --- Retriever triplets + vector distance validation ---
|
||||||
("GraphCompletionRetriever", context_gk),
|
triplets = {
|
||||||
("GraphCompletionCotRetriever", context_gk_cot),
|
"graph_completion": await GraphCompletionRetriever().get_triplets(query=query),
|
||||||
("GraphCompletionContextExtensionRetriever", context_gk_ext),
|
"graph_completion_cot": await GraphCompletionCotRetriever().get_triplets(query=query),
|
||||||
("GraphSummaryCompletionRetriever", context_gk_sum),
|
"graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_triplets(
|
||||||
]:
|
query=query
|
||||||
assert isinstance(context, list), f"{name}: Context should be a list"
|
),
|
||||||
assert len(context) > 0, f"{name}: Context should not be empty"
|
"graph_summary_completion": await GraphSummaryCompletionRetriever().get_triplets(
|
||||||
|
query=query
|
||||||
context_text = await resolve_edges_to_text(context)
|
),
|
||||||
lower = context_text.lower()
|
}
|
||||||
assert "germany" in lower or "netherlands" in lower, (
|
|
||||||
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(context_triplet, str), "TripletRetriever: Context should be a string"
|
|
||||||
assert len(context_triplet) > 0, "TripletRetriever: Context should not be empty"
|
|
||||||
lower_triplet = context_triplet.lower()
|
|
||||||
assert "germany" in lower_triplet or "netherlands" in lower_triplet, (
|
|
||||||
f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
|
|
||||||
)
|
|
||||||
|
|
||||||
triplets_gk = await GraphCompletionRetriever().get_triplets(
|
|
||||||
query="Next to which country is Germany located?"
|
|
||||||
)
|
|
||||||
triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets(
|
|
||||||
query="Next to which country is Germany located?"
|
|
||||||
)
|
|
||||||
triplets_gk_ext = await GraphCompletionContextExtensionRetriever().get_triplets(
|
|
||||||
query="Next to which country is Germany located?"
|
|
||||||
)
|
|
||||||
triplets_gk_sum = await GraphSummaryCompletionRetriever().get_triplets(
|
|
||||||
query="Next to which country is Germany located?"
|
|
||||||
)
|
|
||||||
|
|
||||||
for name, triplets in [
|
|
||||||
("GraphCompletionRetriever", triplets_gk),
|
|
||||||
("GraphCompletionCotRetriever", triplets_gk_cot),
|
|
||||||
("GraphCompletionContextExtensionRetriever", triplets_gk_ext),
|
|
||||||
("GraphSummaryCompletionRetriever", triplets_gk_sum),
|
|
||||||
]:
|
|
||||||
assert isinstance(triplets, list), f"{name}: Triplets should be a list"
|
|
||||||
assert triplets, f"{name}: Triplets list should not be empty"
|
|
||||||
for edge in triplets:
|
|
||||||
assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
|
|
||||||
distance = edge.attributes.get("vector_distance")
|
|
||||||
node1_distance = edge.node1.attributes.get("vector_distance")
|
|
||||||
node2_distance = edge.node2.attributes.get("vector_distance")
|
|
||||||
assert isinstance(distance, float), (
|
|
||||||
f"{name}: vector_distance should be float, got {type(distance)}"
|
|
||||||
)
|
|
||||||
assert 0 <= distance <= 1, (
|
|
||||||
f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen"
|
|
||||||
)
|
|
||||||
assert 0 <= node1_distance <= 1, (
|
|
||||||
f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen"
|
|
||||||
)
|
|
||||||
assert 0 <= node2_distance <= 1, (
|
|
||||||
f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# --- Search operations + graph side effects ---
|
||||||
completion_gk = await cognee.search(
|
completion_gk = await cognee.search(
|
||||||
query_type=SearchType.GRAPH_COMPLETION,
|
query_type=SearchType.GRAPH_COMPLETION,
|
||||||
query_text="Where is germany located, next to which country?",
|
query_text="Where is germany located, next to which country?",
|
||||||
|
|
@ -164,6 +214,26 @@ async def main():
|
||||||
query_text="Next to which country is Germany located?",
|
query_text="Next to which country is Germany located?",
|
||||||
save_interaction=True,
|
save_interaction=True,
|
||||||
)
|
)
|
||||||
|
completion_chunks = await cognee.search(
|
||||||
|
query_type=SearchType.CHUNKS,
|
||||||
|
query_text="Germany",
|
||||||
|
save_interaction=False,
|
||||||
|
)
|
||||||
|
completion_summaries = await cognee.search(
|
||||||
|
query_type=SearchType.SUMMARIES,
|
||||||
|
query_text="Germany",
|
||||||
|
save_interaction=False,
|
||||||
|
)
|
||||||
|
completion_rag = await cognee.search(
|
||||||
|
query_type=SearchType.RAG_COMPLETION,
|
||||||
|
query_text="Next to which country is Germany located?",
|
||||||
|
save_interaction=False,
|
||||||
|
)
|
||||||
|
completion_temporal = await cognee.search(
|
||||||
|
query_type=SearchType.TEMPORAL,
|
||||||
|
query_text="Next to which country is Germany located?",
|
||||||
|
save_interaction=False,
|
||||||
|
)
|
||||||
|
|
||||||
await cognee.search(
|
await cognee.search(
|
||||||
query_type=SearchType.FEEDBACK,
|
query_type=SearchType.FEEDBACK,
|
||||||
|
|
@ -171,134 +241,217 @@ async def main():
|
||||||
last_k=1,
|
last_k=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, search_results in [
|
# Snapshot after all E2E operations above (used by assertion-only tests).
|
||||||
("GRAPH_COMPLETION", completion_gk),
|
graph_snapshot = await (await get_graph_engine()).get_graph_data()
|
||||||
("GRAPH_COMPLETION_COT", completion_cot),
|
|
||||||
("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
|
|
||||||
("GRAPH_SUMMARY_COMPLETION", completion_sum),
|
|
||||||
("TRIPLET_COMPLETION", completion_triplet),
|
|
||||||
]:
|
|
||||||
assert isinstance(search_results, list), f"{name}: should return a list"
|
|
||||||
assert len(search_results) == 1, (
|
|
||||||
f"{name}: expected single-element list, got {len(search_results)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
from cognee.context_global_variables import backend_access_control_enabled
|
return {
|
||||||
|
"graph_edges": edges,
|
||||||
|
"triplet_collection": collection,
|
||||||
|
"vector_collection_edges_count": len(collection),
|
||||||
|
"graph_edges_count": len(edges),
|
||||||
|
"contexts": contexts,
|
||||||
|
"triplets": triplets,
|
||||||
|
"search_results": {
|
||||||
|
"graph_completion": completion_gk,
|
||||||
|
"graph_completion_cot": completion_cot,
|
||||||
|
"graph_completion_context_extension": completion_ext,
|
||||||
|
"graph_summary_completion": completion_sum,
|
||||||
|
"triplet_completion": completion_triplet,
|
||||||
|
"chunks": completion_chunks,
|
||||||
|
"summaries": completion_summaries,
|
||||||
|
"rag_completion": completion_rag,
|
||||||
|
"temporal": completion_temporal,
|
||||||
|
},
|
||||||
|
"graph_snapshot": graph_snapshot,
|
||||||
|
}
|
||||||
|
|
||||||
if backend_access_control_enabled():
|
|
||||||
text = search_results[0]["search_result"][0]
|
|
||||||
else:
|
|
||||||
text = search_results[0]
|
|
||||||
assert isinstance(text, str), f"{name}: element should be a string"
|
|
||||||
assert text.strip(), f"{name}: string should not be empty"
|
|
||||||
assert "netherlands" in text.lower(), (
|
|
||||||
f"{name}: expected 'netherlands' in result, got: {text!r}"
|
|
||||||
)
|
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
@pytest_asyncio.fixture(scope="session")
|
||||||
graph = await graph_engine.get_graph_data()
|
async def feedback_state():
|
||||||
|
"""Feedback-weight scenario computed once (fresh environment)."""
|
||||||
type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0])
|
await setup_test_environment_for_feedback()
|
||||||
|
|
||||||
edge_type_counts = Counter(edge_type[2] for edge_type in graph[1])
|
|
||||||
|
|
||||||
# Assert there are exactly 4 CogneeUserInteraction nodes.
|
|
||||||
assert type_counts.get("CogneeUserInteraction", 0) == 4, (
|
|
||||||
f"Expected exactly four CogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert there is exactly two CogneeUserFeedback nodes.
|
|
||||||
assert type_counts.get("CogneeUserFeedback", 0) == 2, (
|
|
||||||
f"Expected exactly two CogneeUserFeedback nodes, but found {type_counts.get('CogneeUserFeedback', 0)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert there is exactly two NodeSet.
|
|
||||||
assert type_counts.get("NodeSet", 0) == 2, (
|
|
||||||
f"Expected exactly two NodeSet nodes, but found {type_counts.get('NodeSet', 0)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert that there are at least 10 'used_graph_element_to_answer' edges.
|
|
||||||
assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10, (
|
|
||||||
f"Expected at least ten 'used_graph_element_to_answer' edges, but found {edge_type_counts.get('used_graph_element_to_answer', 0)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert that there are exactly 2 'gives_feedback_to' edges.
|
|
||||||
assert edge_type_counts.get("gives_feedback_to", 0) == 2, (
|
|
||||||
f"Expected exactly two 'gives_feedback_to' edges, but found {edge_type_counts.get('gives_feedback_to', 0)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert that there are at least 6 'belongs_to_set' edges.
|
|
||||||
assert edge_type_counts.get("belongs_to_set", 0) == 6, (
|
|
||||||
f"Expected at least six 'belongs_to_set' edges, but found {edge_type_counts.get('belongs_to_set', 0)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
nodes = graph[0]
|
|
||||||
|
|
||||||
required_fields_user_interaction = {"question", "answer", "context"}
|
|
||||||
required_fields_feedback = {"feedback", "sentiment"}
|
|
||||||
|
|
||||||
for node_id, data in nodes:
|
|
||||||
if data.get("type") == "CogneeUserInteraction":
|
|
||||||
assert required_fields_user_interaction.issubset(data.keys()), (
|
|
||||||
f"Node {node_id} is missing fields: {required_fields_user_interaction - set(data.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
for field in required_fields_user_interaction:
|
|
||||||
value = data[field]
|
|
||||||
assert isinstance(value, str) and value.strip(), (
|
|
||||||
f"Node {node_id} has invalid value for '{field}': {value!r}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if data.get("type") == "CogneeUserFeedback":
|
|
||||||
assert required_fields_feedback.issubset(data.keys()), (
|
|
||||||
f"Node {node_id} is missing fields: {required_fields_feedback - set(data.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
for field in required_fields_feedback:
|
|
||||||
value = data[field]
|
|
||||||
assert isinstance(value, str) and value.strip(), (
|
|
||||||
f"Node {node_id} has invalid value for '{field}': {value!r}"
|
|
||||||
)
|
|
||||||
|
|
||||||
await cognee.prune.prune_data()
|
|
||||||
await cognee.prune.prune_system(metadata=True)
|
|
||||||
|
|
||||||
await cognee.add(text_1, dataset_name)
|
|
||||||
|
|
||||||
await cognee.add([text], dataset_name)
|
|
||||||
|
|
||||||
await cognee.cognify([dataset_name])
|
|
||||||
|
|
||||||
await cognee.search(
|
await cognee.search(
|
||||||
query_type=SearchType.GRAPH_COMPLETION,
|
query_type=SearchType.GRAPH_COMPLETION,
|
||||||
query_text="Next to which country is Germany located?",
|
query_text="Next to which country is Germany located?",
|
||||||
save_interaction=True,
|
save_interaction=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
await cognee.search(
|
await cognee.search(
|
||||||
query_type=SearchType.FEEDBACK,
|
query_type=SearchType.FEEDBACK,
|
||||||
query_text="This was the best answer I've ever seen",
|
query_text="This was the best answer I've ever seen",
|
||||||
last_k=1,
|
last_k=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
await cognee.search(
|
await cognee.search(
|
||||||
query_type=SearchType.FEEDBACK,
|
query_type=SearchType.FEEDBACK,
|
||||||
query_text="Wow the correctness of this answer blows my mind",
|
query_text="Wow the correctness of this answer blows my mind",
|
||||||
last_k=1,
|
last_k=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
graph = await graph_engine.get_graph_data()
|
graph = await graph_engine.get_graph_data()
|
||||||
|
return {"graph_snapshot": graph}
|
||||||
|
|
||||||
edges = graph[1]
|
|
||||||
|
|
||||||
for from_node, to_node, relationship_name, properties in edges:
|
@pytest.mark.asyncio
|
||||||
|
async def test_e2e_graph_vector_consistency(e2e_state):
|
||||||
|
"""Graph and vector stores contain the same triplet edges."""
|
||||||
|
assert e2e_state["graph_edges_count"] == e2e_state["vector_collection_edges_count"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_e2e_retriever_contexts(e2e_state):
|
||||||
|
"""All retrievers return non-empty, well-typed contexts."""
|
||||||
|
contexts = e2e_state["contexts"]
|
||||||
|
|
||||||
|
for name in [
|
||||||
|
"graph_completion",
|
||||||
|
"graph_completion_cot",
|
||||||
|
"graph_completion_context_extension",
|
||||||
|
"graph_summary_completion",
|
||||||
|
]:
|
||||||
|
ctx = contexts[name]
|
||||||
|
assert isinstance(ctx, list), f"{name}: Context should be a list"
|
||||||
|
assert ctx, f"{name}: Context should not be empty"
|
||||||
|
ctx_text = await resolve_edges_to_text(ctx)
|
||||||
|
lower = ctx_text.lower()
|
||||||
|
assert "germany" in lower or "netherlands" in lower, (
|
||||||
|
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {ctx!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
triplet_ctx = contexts["triplet"]
|
||||||
|
assert isinstance(triplet_ctx, str), "triplet: Context should be a string"
|
||||||
|
assert triplet_ctx.strip(), "triplet: Context should not be empty"
|
||||||
|
|
||||||
|
chunks_ctx = contexts["chunks"]
|
||||||
|
assert isinstance(chunks_ctx, list), "chunks: Context should be a list"
|
||||||
|
assert chunks_ctx, "chunks: Context should not be empty"
|
||||||
|
chunks_text = "\n".join(str(item.get("text", "")) for item in chunks_ctx).lower()
|
||||||
|
assert "germany" in chunks_text or "netherlands" in chunks_text
|
||||||
|
|
||||||
|
summaries_ctx = contexts["summaries"]
|
||||||
|
assert isinstance(summaries_ctx, list), "summaries: Context should be a list"
|
||||||
|
assert summaries_ctx, "summaries: Context should not be empty"
|
||||||
|
assert any(str(item.get("text", "")).strip() for item in summaries_ctx)
|
||||||
|
|
||||||
|
rag_ctx = contexts["rag_completion"]
|
||||||
|
assert isinstance(rag_ctx, str), "rag_completion: Context should be a string"
|
||||||
|
assert rag_ctx.strip(), "rag_completion: Context should not be empty"
|
||||||
|
|
||||||
|
temporal_ctx = contexts["temporal"]
|
||||||
|
assert isinstance(temporal_ctx, str), "temporal: Context should be a string"
|
||||||
|
assert temporal_ctx.strip(), "temporal: Context should not be empty"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_e2e_retriever_triplets_have_vector_distances(e2e_state):
|
||||||
|
"""Graph retriever triplets include sane vector_distance metadata."""
|
||||||
|
for name, triplets in e2e_state["triplets"].items():
|
||||||
|
assert isinstance(triplets, list), f"{name}: Triplets should be a list"
|
||||||
|
assert triplets, f"{name}: Triplets list should not be empty"
|
||||||
|
for edge in triplets:
|
||||||
|
assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
|
||||||
|
distance = edge.attributes.get("vector_distance")
|
||||||
|
node1_distance = edge.node1.attributes.get("vector_distance")
|
||||||
|
node2_distance = edge.node2.attributes.get("vector_distance")
|
||||||
|
assert isinstance(distance, float), f"{name}: vector_distance should be float"
|
||||||
|
assert 0 <= distance <= 1
|
||||||
|
assert 0 <= node1_distance <= 1
|
||||||
|
assert 0 <= node2_distance <= 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_e2e_search_results_and_wrappers(e2e_state):
|
||||||
|
"""Search returns expected shapes across search types and access modes."""
|
||||||
|
from cognee.context_global_variables import backend_access_control_enabled
|
||||||
|
|
||||||
|
sr = e2e_state["search_results"]
|
||||||
|
|
||||||
|
# Completion-like search types: validate wrapper + content
|
||||||
|
for name in [
|
||||||
|
"graph_completion",
|
||||||
|
"graph_completion_cot",
|
||||||
|
"graph_completion_context_extension",
|
||||||
|
"graph_summary_completion",
|
||||||
|
"triplet_completion",
|
||||||
|
"rag_completion",
|
||||||
|
"temporal",
|
||||||
|
]:
|
||||||
|
search_results = sr[name]
|
||||||
|
assert isinstance(search_results, list), f"{name}: should return a list"
|
||||||
|
assert len(search_results) == 1, f"{name}: expected single-element list"
|
||||||
|
|
||||||
|
if backend_access_control_enabled():
|
||||||
|
wrapper = search_results[0]
|
||||||
|
assert isinstance(wrapper, dict), (
|
||||||
|
f"{name}: expected wrapper dict in access control mode"
|
||||||
|
)
|
||||||
|
assert wrapper.get("dataset_id"), f"{name}: missing dataset_id in wrapper"
|
||||||
|
assert wrapper.get("dataset_name") == "test_dataset"
|
||||||
|
assert "graphs" in wrapper
|
||||||
|
text = wrapper["search_result"][0]
|
||||||
|
else:
|
||||||
|
text = search_results[0]
|
||||||
|
|
||||||
|
assert isinstance(text, str) and text.strip()
|
||||||
|
assert "netherlands" in text.lower()
|
||||||
|
|
||||||
|
# Non-LLM search types: CHUNKS / SUMMARIES validate payload list + text
|
||||||
|
for name in ["chunks", "summaries"]:
|
||||||
|
search_results = sr[name]
|
||||||
|
assert isinstance(search_results, list), f"{name}: should return a list"
|
||||||
|
assert search_results, f"{name}: should not be empty"
|
||||||
|
|
||||||
|
first = search_results[0]
|
||||||
|
assert isinstance(first, dict), f"{name}: expected dict entries"
|
||||||
|
|
||||||
|
payloads = search_results
|
||||||
|
if "search_result" in first and "text" not in first:
|
||||||
|
payloads = (first.get("search_result") or [None])[0]
|
||||||
|
|
||||||
|
assert isinstance(payloads, list) and payloads
|
||||||
|
assert isinstance(payloads[0], dict)
|
||||||
|
assert str(payloads[0].get("text", "")).strip()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_e2e_graph_side_effects_and_node_fields(e2e_state):
|
||||||
|
"""Search interactions create expected graph nodes/edges and required fields."""
|
||||||
|
graph = e2e_state["graph_snapshot"]
|
||||||
|
nodes, edges = graph
|
||||||
|
|
||||||
|
type_counts = Counter(node_data[1].get("type", {}) for node_data in nodes)
|
||||||
|
edge_type_counts = Counter(edge_type[2] for edge_type in edges)
|
||||||
|
|
||||||
|
assert type_counts.get("CogneeUserInteraction", 0) == 4
|
||||||
|
assert type_counts.get("CogneeUserFeedback", 0) == 2
|
||||||
|
assert type_counts.get("NodeSet", 0) == 2
|
||||||
|
assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10
|
||||||
|
assert edge_type_counts.get("gives_feedback_to", 0) == 2
|
||||||
|
assert edge_type_counts.get("belongs_to_set", 0) >= 6
|
||||||
|
|
||||||
|
required_fields_user_interaction = {"question", "answer", "context"}
|
||||||
|
required_fields_feedback = {"feedback", "sentiment"}
|
||||||
|
|
||||||
|
for node_id, data in nodes:
|
||||||
|
if data.get("type") == "CogneeUserInteraction":
|
||||||
|
assert required_fields_user_interaction.issubset(data.keys())
|
||||||
|
for field in required_fields_user_interaction:
|
||||||
|
value = data[field]
|
||||||
|
assert isinstance(value, str) and value.strip()
|
||||||
|
|
||||||
|
if data.get("type") == "CogneeUserFeedback":
|
||||||
|
assert required_fields_feedback.issubset(data.keys())
|
||||||
|
for field in required_fields_feedback:
|
||||||
|
value = data[field]
|
||||||
|
assert isinstance(value, str) and value.strip()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_e2e_feedback_weight_calculation(feedback_state):
|
||||||
|
"""Positive feedback increases used_graph_element_to_answer feedback_weight."""
|
||||||
|
_nodes, edges = feedback_state["graph_snapshot"]
|
||||||
|
for _from_node, _to_node, relationship_name, properties in edges:
|
||||||
if relationship_name == "used_graph_element_to_answer":
|
if relationship_name == "used_graph_element_to_answer":
|
||||||
assert properties["feedback_weight"] >= 6, (
|
assert properties["feedback_weight"] >= 6, (
|
||||||
"Feedback weight calculation is not correct, it should be more then 6."
|
"Feedback weight calculation is not correct, it should be more then 6."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,28 @@
|
||||||
import pytest
|
import pytest
|
||||||
import uuid
|
import uuid
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from unittest.mock import patch, Mock, AsyncMock
|
from unittest.mock import Mock
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
import importlib
|
|
||||||
from cognee.api.client import app
|
from cognee.api.client import app
|
||||||
|
from cognee.modules.users.methods import get_authenticated_user
|
||||||
|
|
||||||
gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def test_client():
|
||||||
|
# Keep a single TestClient (and event loop) for the whole module.
|
||||||
|
# Re-creating TestClient repeatedly can break async DB connections (asyncpg loop mismatch).
|
||||||
|
with TestClient(app) as c:
|
||||||
|
yield c
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client():
|
def client(test_client, mock_default_user):
|
||||||
return TestClient(app)
|
async def override_get_authenticated_user():
|
||||||
|
return mock_default_user
|
||||||
|
|
||||||
|
app.dependency_overrides[get_authenticated_user] = override_get_authenticated_user
|
||||||
|
yield test_client
|
||||||
|
app.dependency_overrides.pop(get_authenticated_user, None)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -32,12 +43,8 @@ def mock_default_user():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
def test_upload_ontology_success(client):
|
||||||
def test_upload_ontology_success(mock_get_default_user, client, mock_default_user):
|
|
||||||
"""Test successful ontology upload"""
|
"""Test successful ontology upload"""
|
||||||
import json
|
|
||||||
|
|
||||||
mock_get_default_user.return_value = mock_default_user
|
|
||||||
ontology_content = (
|
ontology_content = (
|
||||||
b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
|
b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
|
||||||
)
|
)
|
||||||
|
|
@ -46,7 +53,7 @@ def test_upload_ontology_success(mock_get_default_user, client, mock_default_use
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/api/v1/ontologies",
|
"/api/v1/ontologies",
|
||||||
files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
|
files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
|
||||||
data={"ontology_key": json.dumps([unique_key]), "description": json.dumps(["Test"])},
|
data={"ontology_key": unique_key, "description": "Test"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
@ -55,10 +62,8 @@ def test_upload_ontology_success(mock_get_default_user, client, mock_default_use
|
||||||
assert "uploaded_at" in data["uploaded_ontologies"][0]
|
assert "uploaded_at" in data["uploaded_ontologies"][0]
|
||||||
|
|
||||||
|
|
||||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
def test_upload_ontology_invalid_file(client):
|
||||||
def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_default_user):
|
|
||||||
"""Test 400 response for non-.owl files"""
|
"""Test 400 response for non-.owl files"""
|
||||||
mock_get_default_user.return_value = mock_default_user
|
|
||||||
unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
|
unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/api/v1/ontologies",
|
"/api/v1/ontologies",
|
||||||
|
|
@ -68,14 +73,10 @@ def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_defaul
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
def test_upload_ontology_missing_data(client):
|
||||||
def test_upload_ontology_missing_data(mock_get_default_user, client, mock_default_user):
|
|
||||||
"""Test 400 response for missing file or key"""
|
"""Test 400 response for missing file or key"""
|
||||||
import json
|
|
||||||
|
|
||||||
mock_get_default_user.return_value = mock_default_user
|
|
||||||
# Missing file
|
# Missing file
|
||||||
response = client.post("/api/v1/ontologies", data={"ontology_key": json.dumps(["test"])})
|
response = client.post("/api/v1/ontologies", data={"ontology_key": "test"})
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|
||||||
# Missing key
|
# Missing key
|
||||||
|
|
@ -85,34 +86,25 @@ def test_upload_ontology_missing_data(mock_get_default_user, client, mock_defaul
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
def test_upload_ontology_without_auth_header(client):
|
||||||
def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_default_user):
|
"""Test behavior when no explicit authentication header is provided."""
|
||||||
"""Test behavior when default user is provided (no explicit authentication)"""
|
|
||||||
import json
|
|
||||||
|
|
||||||
unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
|
unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
|
||||||
mock_get_default_user.return_value = mock_default_user
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/api/v1/ontologies",
|
"/api/v1/ontologies",
|
||||||
files=[("ontology_file", ("test.owl", b"<rdf></rdf>", "application/xml"))],
|
files=[("ontology_file", ("test.owl", b"<rdf></rdf>", "application/xml"))],
|
||||||
data={"ontology_key": json.dumps([unique_key])},
|
data={"ontology_key": unique_key},
|
||||||
)
|
)
|
||||||
|
|
||||||
# The current system provides a default user when no explicit authentication is given
|
|
||||||
# This test verifies the system works with conditional authentication
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key
|
assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key
|
||||||
assert "uploaded_at" in data["uploaded_ontologies"][0]
|
assert "uploaded_at" in data["uploaded_ontologies"][0]
|
||||||
|
|
||||||
|
|
||||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
def test_upload_multiple_ontologies_in_single_request_is_rejected(client):
|
||||||
def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_user):
|
"""Uploading multiple ontology files in a single request should fail."""
|
||||||
"""Test uploading multiple ontology files in single request"""
|
|
||||||
import io
|
import io
|
||||||
|
|
||||||
mock_get_default_user.return_value = mock_default_user
|
|
||||||
# Create mock files
|
|
||||||
file1_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
|
file1_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
|
||||||
file2_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
|
file2_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
|
||||||
|
|
||||||
|
|
@ -120,45 +112,34 @@ def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_
|
||||||
("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")),
|
("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")),
|
||||||
("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")),
|
("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")),
|
||||||
]
|
]
|
||||||
data = {
|
data = {"ontology_key": "vehicles", "description": "Base vehicles"}
|
||||||
"ontology_key": '["vehicles", "manufacturers"]',
|
|
||||||
"descriptions": '["Base vehicles", "Car manufacturers"]',
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.post("/api/v1/ontologies", files=files, data=data)
|
response = client.post("/api/v1/ontologies", files=files, data=data)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 400
|
||||||
result = response.json()
|
assert "Only one ontology_file is allowed" in response.json()["error"]
|
||||||
assert "uploaded_ontologies" in result
|
|
||||||
assert len(result["uploaded_ontologies"]) == 2
|
|
||||||
assert result["uploaded_ontologies"][0]["ontology_key"] == "vehicles"
|
|
||||||
assert result["uploaded_ontologies"][1]["ontology_key"] == "manufacturers"
|
|
||||||
|
|
||||||
|
|
||||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
def test_upload_endpoint_rejects_array_style_fields(client):
|
||||||
def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_default_user):
|
"""Array-style form values should be rejected (no backwards compatibility)."""
|
||||||
"""Test that upload endpoint accepts array parameters"""
|
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
|
||||||
mock_get_default_user.return_value = mock_default_user
|
|
||||||
file_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
|
file_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
|
||||||
|
|
||||||
files = [("ontology_file", ("single.owl", io.BytesIO(file_content), "application/xml"))]
|
files = [("ontology_file", ("single.owl", io.BytesIO(file_content), "application/xml"))]
|
||||||
data = {
|
data = {
|
||||||
"ontology_key": json.dumps(["single_key"]),
|
"ontology_key": json.dumps(["single_key"]),
|
||||||
"descriptions": json.dumps(["Single ontology"]),
|
"description": json.dumps(["Single ontology"]),
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/ontologies", files=files, data=data)
|
response = client.post("/api/v1/ontologies", files=files, data=data)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 400
|
||||||
result = response.json()
|
assert "ontology_key must be a string" in response.json()["error"]
|
||||||
assert result["uploaded_ontologies"][0]["ontology_key"] == "single_key"
|
|
||||||
|
|
||||||
|
|
||||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
def test_cognify_with_multiple_ontologies(client):
|
||||||
def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_default_user):
|
|
||||||
"""Test cognify endpoint accepts multiple ontology keys"""
|
"""Test cognify endpoint accepts multiple ontology keys"""
|
||||||
payload = {
|
payload = {
|
||||||
"datasets": ["test_dataset"],
|
"datasets": ["test_dataset"],
|
||||||
|
|
@ -172,14 +153,11 @@ def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_de
|
||||||
assert response.status_code in [200, 400, 409] # May fail for other reasons, not type
|
assert response.status_code in [200, 400, 409] # May fail for other reasons, not type
|
||||||
|
|
||||||
|
|
||||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
def test_complete_multifile_workflow(client):
|
||||||
def test_complete_multifile_workflow(mock_get_default_user, client, mock_default_user):
|
"""Test workflow: upload ontologies one-by-one → cognify with multiple keys"""
|
||||||
"""Test complete workflow: upload multiple ontologies → cognify with multiple keys"""
|
|
||||||
import io
|
import io
|
||||||
import json
|
|
||||||
|
|
||||||
mock_get_default_user.return_value = mock_default_user
|
# Step 1: Upload two ontologies (one-by-one)
|
||||||
# Step 1: Upload multiple ontologies
|
|
||||||
file1_content = b"""<?xml version="1.0"?>
|
file1_content = b"""<?xml version="1.0"?>
|
||||||
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
|
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
|
||||||
xmlns:owl="http://www.w3.org/2002/07/owl#">
|
xmlns:owl="http://www.w3.org/2002/07/owl#">
|
||||||
|
|
@ -192,17 +170,21 @@ def test_complete_multifile_workflow(mock_get_default_user, client, mock_default
|
||||||
<owl:Class rdf:ID="Manufacturer"/>
|
<owl:Class rdf:ID="Manufacturer"/>
|
||||||
</rdf:RDF>"""
|
</rdf:RDF>"""
|
||||||
|
|
||||||
files = [
|
upload_response_1 = client.post(
|
||||||
("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")),
|
"/api/v1/ontologies",
|
||||||
("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")),
|
files=[("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml"))],
|
||||||
]
|
data={"ontology_key": "vehicles", "description": "Vehicle ontology"},
|
||||||
data = {
|
)
|
||||||
"ontology_key": json.dumps(["vehicles", "manufacturers"]),
|
assert upload_response_1.status_code == 200
|
||||||
"descriptions": json.dumps(["Vehicle ontology", "Manufacturer ontology"]),
|
|
||||||
}
|
|
||||||
|
|
||||||
upload_response = client.post("/api/v1/ontologies", files=files, data=data)
|
upload_response_2 = client.post(
|
||||||
assert upload_response.status_code == 200
|
"/api/v1/ontologies",
|
||||||
|
files=[
|
||||||
|
("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml"))
|
||||||
|
],
|
||||||
|
data={"ontology_key": "manufacturers", "description": "Manufacturer ontology"},
|
||||||
|
)
|
||||||
|
assert upload_response_2.status_code == 200
|
||||||
|
|
||||||
# Step 2: Verify ontologies are listed
|
# Step 2: Verify ontologies are listed
|
||||||
list_response = client.get("/api/v1/ontologies")
|
list_response = client.get("/api/v1/ontologies")
|
||||||
|
|
@ -223,44 +205,42 @@ def test_complete_multifile_workflow(mock_get_default_user, client, mock_default
|
||||||
assert cognify_response.status_code != 400 # Not a validation error
|
assert cognify_response.status_code != 400 # Not a validation error
|
||||||
|
|
||||||
|
|
||||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
def test_upload_error_handling(client):
|
||||||
def test_multifile_error_handling(mock_get_default_user, client, mock_default_user):
|
"""Test error handling for invalid uploads (single-file endpoint)."""
|
||||||
"""Test error handling for invalid multifile uploads"""
|
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
|
||||||
# Test mismatched array lengths
|
# Array-style key should be rejected
|
||||||
file_content = b"<rdf:RDF></rdf:RDF>"
|
file_content = b"<rdf:RDF></rdf:RDF>"
|
||||||
files = [("ontology_file", ("test.owl", io.BytesIO(file_content), "application/xml"))]
|
files = [("ontology_file", ("test.owl", io.BytesIO(file_content), "application/xml"))]
|
||||||
data = {
|
data = {
|
||||||
"ontology_key": json.dumps(["key1", "key2"]), # 2 keys, 1 file
|
"ontology_key": json.dumps(["key1", "key2"]),
|
||||||
"descriptions": json.dumps(["desc1"]),
|
"description": "desc1",
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/ontologies", files=files, data=data)
|
response = client.post("/api/v1/ontologies", files=files, data=data)
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert "Number of keys must match number of files" in response.json()["error"]
|
assert "ontology_key must be a string" in response.json()["error"]
|
||||||
|
|
||||||
# Test duplicate keys
|
# Duplicate key should be rejected
|
||||||
files = [
|
response_1 = client.post(
|
||||||
("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml")),
|
"/api/v1/ontologies",
|
||||||
("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml")),
|
files=[("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml"))],
|
||||||
]
|
data={"ontology_key": "duplicate", "description": "desc1"},
|
||||||
data = {
|
)
|
||||||
"ontology_key": json.dumps(["duplicate", "duplicate"]),
|
assert response_1.status_code == 200
|
||||||
"descriptions": json.dumps(["desc1", "desc2"]),
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.post("/api/v1/ontologies", files=files, data=data)
|
response_2 = client.post(
|
||||||
assert response.status_code == 400
|
"/api/v1/ontologies",
|
||||||
assert "Duplicate ontology keys not allowed" in response.json()["error"]
|
files=[("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml"))],
|
||||||
|
data={"ontology_key": "duplicate", "description": "desc2"},
|
||||||
|
)
|
||||||
|
assert response_2.status_code == 400
|
||||||
|
assert "already exists" in response_2.json()["error"]
|
||||||
|
|
||||||
|
|
||||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
def test_cognify_missing_ontology_key(client):
|
||||||
def test_cognify_missing_ontology_key(mock_get_default_user, client, mock_default_user):
|
|
||||||
"""Test cognify with non-existent ontology key"""
|
"""Test cognify with non-existent ontology key"""
|
||||||
mock_get_default_user.return_value = mock_default_user
|
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"datasets": ["test_dataset"],
|
"datasets": ["test_dataset"],
|
||||||
"ontology_key": ["nonexistent_key"],
|
"ontology_key": ["nonexistent_key"],
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,22 @@ MOCK_JSONL_DATA = """\
|
||||||
{"id": "2", "question": "What is ML?", "answer": "Machine Learning", "paragraphs": [{"paragraph_text": "ML is a subset of AI."}]}
|
{"id": "2", "question": "What is ML?", "answer": "Machine Learning", "paragraphs": [{"paragraph_text": "ML is a subset of AI."}]}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
MOCK_HOTPOT_CORPUS = [
|
||||||
|
{
|
||||||
|
"_id": "1",
|
||||||
|
"question": "Next to which country is Germany located?",
|
||||||
|
"answer": "Netherlands",
|
||||||
|
# HotpotQA uses "level"; TwoWikiMultiHop uses "type".
|
||||||
|
"level": "easy",
|
||||||
|
"type": "comparison",
|
||||||
|
"context": [
|
||||||
|
["Germany", ["Germany is in Europe."]],
|
||||||
|
["Netherlands", ["The Netherlands borders Germany."]],
|
||||||
|
],
|
||||||
|
"supporting_facts": [["Netherlands", 0]],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
ADAPTER_CLASSES = [
|
ADAPTER_CLASSES = [
|
||||||
HotpotQAAdapter,
|
HotpotQAAdapter,
|
||||||
|
|
@ -35,6 +51,11 @@ def test_adapter_can_instantiate_and_load(AdapterClass):
|
||||||
adapter = AdapterClass()
|
adapter = AdapterClass()
|
||||||
result = adapter.load_corpus()
|
result = adapter.load_corpus()
|
||||||
|
|
||||||
|
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
|
||||||
|
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
||||||
|
adapter = AdapterClass()
|
||||||
|
result = adapter.load_corpus()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
adapter = AdapterClass()
|
adapter = AdapterClass()
|
||||||
result = adapter.load_corpus()
|
result = adapter.load_corpus()
|
||||||
|
|
@ -64,6 +85,10 @@ def test_adapter_returns_some_content(AdapterClass):
|
||||||
):
|
):
|
||||||
adapter = AdapterClass()
|
adapter = AdapterClass()
|
||||||
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
||||||
|
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
|
||||||
|
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
||||||
|
adapter = AdapterClass()
|
||||||
|
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
||||||
else:
|
else:
|
||||||
adapter = AdapterClass()
|
adapter = AdapterClass()
|
||||||
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,38 @@ import pytest
|
||||||
from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor
|
from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from cognee.eval_framework.benchmark_adapters.hotpot_qa_adapter import HotpotQAAdapter
|
||||||
|
|
||||||
benchmark_options = ["HotPotQA", "Dummy", "TwoWikiMultiHop"]
|
benchmark_options = ["HotPotQA", "Dummy", "TwoWikiMultiHop"]
|
||||||
|
|
||||||
|
MOCK_HOTPOT_CORPUS = [
|
||||||
|
{
|
||||||
|
"_id": "1",
|
||||||
|
"question": "Next to which country is Germany located?",
|
||||||
|
"answer": "Netherlands",
|
||||||
|
# HotpotQA uses "level"; TwoWikiMultiHop uses "type".
|
||||||
|
"level": "easy",
|
||||||
|
"type": "comparison",
|
||||||
|
"context": [
|
||||||
|
["Germany", ["Germany is in Europe."]],
|
||||||
|
["Netherlands", ["The Netherlands borders Germany."]],
|
||||||
|
],
|
||||||
|
"supporting_facts": [["Netherlands", 0]],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("benchmark", benchmark_options)
|
@pytest.mark.parametrize("benchmark", benchmark_options)
|
||||||
def test_corpus_builder_load_corpus(benchmark):
|
def test_corpus_builder_load_corpus(benchmark):
|
||||||
limit = 2
|
limit = 2
|
||||||
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
|
||||||
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
||||||
|
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||||
|
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
||||||
|
else:
|
||||||
|
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||||
|
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
||||||
|
|
||||||
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
|
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
|
||||||
assert len(questions) <= 2, (
|
assert len(questions) <= 2, (
|
||||||
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
||||||
|
|
@ -22,8 +45,14 @@ def test_corpus_builder_load_corpus(benchmark):
|
||||||
@patch.object(CorpusBuilderExecutor, "run_cognee", new_callable=AsyncMock)
|
@patch.object(CorpusBuilderExecutor, "run_cognee", new_callable=AsyncMock)
|
||||||
async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
|
async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
|
||||||
limit = 2
|
limit = 2
|
||||||
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
|
||||||
questions = await corpus_builder.build_corpus(limit=limit)
|
with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
||||||
|
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||||
|
questions = await corpus_builder.build_corpus(limit=limit)
|
||||||
|
else:
|
||||||
|
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||||
|
questions = await corpus_builder.build_corpus(limit=limit)
|
||||||
|
|
||||||
assert len(questions) <= 2, (
|
assert len(questions) <= 2, (
|
||||||
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "cognee"
|
name = "cognee"
|
||||||
|
|
||||||
version = "0.5.0.dev0"
|
version = "0.5.0.dev1"
|
||||||
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
|
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Vasilije Markovic" },
|
{ name = "Vasilije Markovic" },
|
||||||
|
|
|
||||||
4
uv.lock
generated
4
uv.lock
generated
|
|
@ -1,5 +1,5 @@
|
||||||
version = 1
|
version = 1
|
||||||
revision = 2
|
revision = 3
|
||||||
requires-python = ">=3.10, <3.14"
|
requires-python = ">=3.10, <3.14"
|
||||||
resolution-markers = [
|
resolution-markers = [
|
||||||
"python_full_version >= '3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'",
|
"python_full_version >= '3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'",
|
||||||
|
|
@ -946,7 +946,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cognee"
|
name = "cognee"
|
||||||
version = "0.5.0.dev0"
|
version = "0.5.0.dev1"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "aiofiles" },
|
{ name = "aiofiles" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue