diff --git a/.env.template b/.env.template
index 7dcd4f346..61853b983 100644
--- a/.env.template
+++ b/.env.template
@@ -21,6 +21,10 @@ LLM_PROVIDER="openai"
LLM_ENDPOINT=""
LLM_API_VERSION=""
LLM_MAX_TOKENS="16384"
+# Instructor's mode determines how structured data is requested from and extracted from LLM responses.
+# You can override the mode via this environment variable.
+# Each LLM has its own default, e.g. gpt-5 models default to "json_schema_mode".
+LLM_INSTRUCTOR_MODE=""
EMBEDDING_PROVIDER="openai"
EMBEDDING_MODEL="openai/text-embedding-3-large"
@@ -169,8 +173,9 @@ REQUIRE_AUTHENTICATION=False
# Vector: LanceDB
# Graph: KuzuDB
#
-# It enforces LanceDB and KuzuDB use and uses them to create databases per Cognee user + dataset
-ENABLE_BACKEND_ACCESS_CONTROL=False
+# It enforces creation of databases per Cognee user + dataset. Does not work with some graph and vector database providers.
+# Disable this mode when using unsupported graph/vector databases.
+ENABLE_BACKEND_ACCESS_CONTROL=True
################################################################################
# ☁️ Cloud Sync Settings
diff --git a/.github/actions/cognee_setup/action.yml b/.github/actions/cognee_setup/action.yml
index 4017d524b..3f5726015 100644
--- a/.github/actions/cognee_setup/action.yml
+++ b/.github/actions/cognee_setup/action.yml
@@ -42,3 +42,8 @@ runs:
done
fi
uv sync --extra api --extra docs --extra evals --extra codegraph --extra ollama --extra dev --extra neo4j --extra redis $EXTRA_ARGS
+
+ - name: Add telemetry identifier for the telemetry test and in case telemetry is accidentally enabled
+ shell: bash
+ run: |
+ echo "test-machine" > .anon_id
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index 0e6f74188..be9d219c1 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -6,6 +6,14 @@ Please provide a clear, human-generated description of the changes in this PR.
DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning.
-->
+## Acceptance Criteria
+
+
## Type of Change
- [ ] Bug fix (non-breaking change that fixes an issue)
diff --git a/.github/workflows/basic_tests.yml b/.github/workflows/basic_tests.yml
index b7f324310..98ced21dc 100644
--- a/.github/workflows/basic_tests.yml
+++ b/.github/workflows/basic_tests.yml
@@ -75,6 +75,7 @@ jobs:
name: Run Unit Tests
runs-on: ubuntu-22.04
env:
+ ENV: 'dev'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@@ -104,6 +105,7 @@ jobs:
name: Run Integration Tests
runs-on: ubuntu-22.04
env:
+ ENV: 'dev'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@@ -132,6 +134,7 @@ jobs:
name: Run Simple Examples
runs-on: ubuntu-22.04
env:
+ ENV: 'dev'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@@ -161,6 +164,7 @@ jobs:
name: Run Simple Examples BAML
runs-on: ubuntu-22.04
env:
+ ENV: 'dev'
STRUCTURED_OUTPUT_FRAMEWORK: "BAML"
BAML_LLM_PROVIDER: openai
BAML_LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
@@ -198,6 +202,7 @@ jobs:
name: Run Basic Graph Tests
runs-on: ubuntu-22.04
env:
+ ENV: 'dev'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
diff --git a/.github/workflows/cli_tests.yml b/.github/workflows/cli_tests.yml
index 958d341ae..d4f8e5ac0 100644
--- a/.github/workflows/cli_tests.yml
+++ b/.github/workflows/cli_tests.yml
@@ -39,6 +39,7 @@ jobs:
name: CLI Unit Tests
runs-on: ubuntu-22.04
env:
+ ENV: 'dev'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@@ -66,6 +67,7 @@ jobs:
name: CLI Integration Tests
runs-on: ubuntu-22.04
env:
+ ENV: 'dev'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@@ -93,6 +95,7 @@ jobs:
name: CLI Functionality Tests
runs-on: ubuntu-22.04
env:
+ ENV: 'dev'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
diff --git a/.github/workflows/db_examples_tests.yml b/.github/workflows/db_examples_tests.yml
index 51ac9a82a..c58bc48ef 100644
--- a/.github/workflows/db_examples_tests.yml
+++ b/.github/workflows/db_examples_tests.yml
@@ -60,7 +60,7 @@ jobs:
- name: Run Neo4j Example
env:
- ENV: dev
+ ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
@@ -95,7 +95,7 @@ jobs:
- name: Run Kuzu Example
env:
- ENV: dev
+ ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
@@ -141,7 +141,7 @@ jobs:
- name: Run PGVector Example
env:
- ENV: dev
+ ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml
index 70a4b56e6..3dea2548c 100644
--- a/.github/workflows/e2e_tests.yml
+++ b/.github/workflows/e2e_tests.yml
@@ -226,7 +226,7 @@ jobs:
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
- - name: Run parallel databases test
+ - name: Run permissions test
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
@@ -239,6 +239,31 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_permissions.py
+ test-multi-tenancy:
+ name: Test multi tenancy with different situations in Cognee
+ runs-on: ubuntu-22.04
+ steps:
+ - name: Check out repository
+ uses: actions/checkout@v4
+
+ - name: Cognee Setup
+ uses: ./.github/actions/cognee_setup
+ with:
+ python-version: '3.11.x'
+
+ - name: Run multi tenancy test
+ env:
+ ENV: 'dev'
+ LLM_MODEL: ${{ secrets.LLM_MODEL }}
+ LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+ LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+ LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+ EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+ EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+ EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+ EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
+ run: uv run python ./cognee/tests/test_multi_tenancy.py
+
test-graph-edges:
name: Test graph edge ingestion
runs-on: ubuntu-22.04
@@ -308,7 +333,7 @@ jobs:
python-version: '3.11.x'
extra-dependencies: "postgres redis"
- - name: Run Concurrent subprocess access test (Kuzu/Lancedb/Postgres)
+ - name: Run Concurrent subprocess access test (Kuzu/Lancedb/Postgres/Redis)
env:
ENV: dev
LLM_MODEL: ${{ secrets.LLM_MODEL }}
@@ -321,6 +346,7 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: 'kuzu'
CACHING: true
+ CACHE_BACKEND: 'redis'
SHARED_KUZU_LOCK: true
DB_PROVIDER: 'postgres'
DB_NAME: 'cognee_db'
@@ -386,8 +412,8 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_feedback_enrichment.py
- run_conversation_sessions_test:
- name: Conversation sessions test
+ run_conversation_sessions_test_redis:
+ name: Conversation sessions test (Redis)
runs-on: ubuntu-latest
defaults:
run:
@@ -427,7 +453,60 @@ jobs:
python-version: '3.11.x'
extra-dependencies: "postgres redis"
- - name: Run Conversation session tests
+ - name: Run Conversation session tests (Redis)
+ env:
+ ENV: 'dev'
+ LLM_MODEL: ${{ secrets.LLM_MODEL }}
+ LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+ LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+ LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+ EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+ EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+ EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+ EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
+ GRAPH_DATABASE_PROVIDER: 'kuzu'
+ CACHING: true
+ CACHE_BACKEND: 'redis'
+ DB_PROVIDER: 'postgres'
+ DB_NAME: 'cognee_db'
+ DB_HOST: '127.0.0.1'
+ DB_PORT: 5432
+ DB_USERNAME: cognee
+ DB_PASSWORD: cognee
+ run: uv run python ./cognee/tests/test_conversation_history.py
+
+ run_conversation_sessions_test_fs:
+ name: Conversation sessions test (FS)
+ runs-on: ubuntu-latest
+ defaults:
+ run:
+ shell: bash
+ services:
+ postgres:
+ image: pgvector/pgvector:pg17
+ env:
+ POSTGRES_USER: cognee
+ POSTGRES_PASSWORD: cognee
+ POSTGRES_DB: cognee_db
+ options: >-
+ --health-cmd pg_isready
+ --health-interval 10s
+ --health-timeout 5s
+ --health-retries 5
+ ports:
+ - 5432:5432
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ - name: Cognee Setup
+ uses: ./.github/actions/cognee_setup
+ with:
+ python-version: '3.11.x'
+ extra-dependencies: "postgres"
+
+ - name: Run Conversation session tests (FS)
env:
ENV: dev
LLM_MODEL: ${{ secrets.LLM_MODEL }}
@@ -440,6 +519,7 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: 'kuzu'
CACHING: true
+ CACHE_BACKEND: 'fs'
DB_PROVIDER: 'postgres'
DB_NAME: 'cognee_db'
DB_HOST: '127.0.0.1'
diff --git a/.github/workflows/examples_tests.yml b/.github/workflows/examples_tests.yml
index 57bc88157..f7cc278cb 100644
--- a/.github/workflows/examples_tests.yml
+++ b/.github/workflows/examples_tests.yml
@@ -21,6 +21,7 @@ jobs:
- name: Run Multimedia Example
env:
+ ENV: 'dev'
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: uv run python ./examples/python/multimedia_example.py
@@ -40,6 +41,7 @@ jobs:
- name: Run Evaluation Framework Example
env:
+ ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
@@ -69,6 +71,7 @@ jobs:
- name: Run Descriptive Graph Metrics Example
env:
+ ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
@@ -99,6 +102,7 @@ jobs:
- name: Run Dynamic Steps Tests
env:
+ ENV: 'dev'
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@@ -124,6 +128,7 @@ jobs:
- name: Run Temporal Example
env:
+ ENV: 'dev'
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@@ -149,6 +154,7 @@ jobs:
- name: Run Ontology Demo Example
env:
+ ENV: 'dev'
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@@ -174,6 +180,7 @@ jobs:
- name: Run Agentic Reasoning Example
env:
+ ENV: 'dev'
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@@ -199,6 +206,7 @@ jobs:
- name: Run Memify Tests
env:
+ ENV: 'dev'
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@@ -210,6 +218,32 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./examples/python/memify_coding_agent_example.py
+ test-custom-pipeline:
+ name: Run Custom Pipeline Example
+ runs-on: ubuntu-22.04
+ steps:
+ - name: Check out repository
+ uses: actions/checkout@v4
+
+ - name: Cognee Setup
+ uses: ./.github/actions/cognee_setup
+ with:
+ python-version: '3.11.x'
+
+ - name: Run Custom Pipeline Example
+ env:
+ ENV: 'dev'
+ OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+ LLM_MODEL: ${{ secrets.LLM_MODEL }}
+ LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+ LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+ LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+ EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+ EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+ EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+ EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
+ run: uv run python ./examples/python/run_custom_pipeline_example.py
+
test-permissions-example:
name: Run Permissions Example
runs-on: ubuntu-22.04
@@ -224,6 +258,7 @@ jobs:
- name: Run Memify Tests
env:
+ ENV: 'dev'
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@@ -249,6 +284,7 @@ jobs:
- name: Run Docling Test
env:
+ ENV: 'dev'
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
diff --git a/.github/workflows/load_tests.yml b/.github/workflows/load_tests.yml
new file mode 100644
index 000000000..f5b64d8ce
--- /dev/null
+++ b/.github/workflows/load_tests.yml
@@ -0,0 +1,70 @@
+name: Load tests
+
+permissions:
+ contents: read
+
+on:
+ workflow_dispatch:
+ workflow_call:
+ secrets:
+ LLM_MODEL:
+ required: true
+ LLM_ENDPOINT:
+ required: true
+ LLM_API_KEY:
+ required: true
+ LLM_API_VERSION:
+ required: true
+ EMBEDDING_MODEL:
+ required: true
+ EMBEDDING_ENDPOINT:
+ required: true
+ EMBEDDING_API_KEY:
+ required: true
+ EMBEDDING_API_VERSION:
+ required: true
+ OPENAI_API_KEY:
+ required: true
+ AWS_S3_DEV_USER_KEY_ID:
+ required: true
+ AWS_S3_DEV_USER_SECRET_KEY:
+ required: true
+
+jobs:
+ test-load:
+ name: Test Load
+ runs-on: ubuntu-22.04
+ timeout-minutes: 60
+ steps:
+ - name: Check out repository
+ uses: actions/checkout@v4
+
+ - name: Cognee Setup
+ uses: ./.github/actions/cognee_setup
+ with:
+ python-version: '3.11.x'
+ extra-dependencies: "aws"
+
+ - name: Verify File Descriptor Limit
+ run: ulimit -n
+
+ - name: Run Load Test
+ env:
+ ENV: 'dev'
+ ENABLE_BACKEND_ACCESS_CONTROL: True
+ LLM_MODEL: ${{ secrets.LLM_MODEL }}
+ LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+ LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+ LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+ EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+ EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+ EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+ EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
+ STORAGE_BACKEND: s3
+ AWS_REGION: eu-west-1
+ AWS_ENDPOINT_URL: https://s3-eu-west-1.amazonaws.com
+ AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_DEV_USER_KEY_ID }}
+ AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_DEV_USER_SECRET_KEY }}
+ run: uv run python ./cognee/tests/test_load.py
+
+
diff --git a/.github/workflows/release_test.yml b/.github/workflows/release_test.yml
new file mode 100644
index 000000000..6ac3ca515
--- /dev/null
+++ b/.github/workflows/release_test.yml
@@ -0,0 +1,17 @@
+# Long-running, heavy and resource-consuming tests for release validation
+name: Release Test Workflow
+
+permissions:
+ contents: read
+
+on:
+ workflow_dispatch:
+ pull_request:
+ branches:
+ - main
+
+jobs:
+ load-tests:
+ name: Load Tests
+ uses: ./.github/workflows/load_tests.yml
+ secrets: inherit
diff --git a/.github/workflows/search_db_tests.yml b/.github/workflows/search_db_tests.yml
index e3e46dd97..118c1c06c 100644
--- a/.github/workflows/search_db_tests.yml
+++ b/.github/workflows/search_db_tests.yml
@@ -84,6 +84,7 @@ jobs:
GRAPH_DATABASE_PROVIDER: 'neo4j'
VECTOR_DB_PROVIDER: 'lancedb'
DB_PROVIDER: 'sqlite'
+ ENABLE_BACKEND_ACCESS_CONTROL: 'false'
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
@@ -135,6 +136,7 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: 'kuzu'
VECTOR_DB_PROVIDER: 'pgvector'
+ ENABLE_BACKEND_ACCESS_CONTROL: 'false'
DB_PROVIDER: 'postgres'
DB_NAME: 'cognee_db'
DB_HOST: '127.0.0.1'
@@ -197,6 +199,7 @@ jobs:
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
+ ENABLE_BACKEND_ACCESS_CONTROL: 'false'
DB_NAME: cognee_db
DB_HOST: 127.0.0.1
DB_PORT: 5432
diff --git a/.github/workflows/test_different_operating_systems.yml b/.github/workflows/test_different_operating_systems.yml
index 64f1a14f9..02651b474 100644
--- a/.github/workflows/test_different_operating_systems.yml
+++ b/.github/workflows/test_different_operating_systems.yml
@@ -10,6 +10,10 @@ on:
required: false
type: string
default: '["3.10.x", "3.12.x", "3.13.x"]'
+ os:
+ required: false
+ type: string
+ default: '["ubuntu-22.04", "macos-15", "windows-latest"]'
secrets:
LLM_PROVIDER:
required: true
@@ -40,10 +44,11 @@ jobs:
run-unit-tests:
name: Unit tests ${{ matrix.python-version }} on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
+ timeout-minutes: 60
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
- os: [ubuntu-22.04, macos-15, windows-latest]
+ os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out
@@ -76,10 +81,11 @@ jobs:
run-integration-tests:
name: Integration tests ${{ matrix.python-version }} on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
+ timeout-minutes: 60
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
- os: [ ubuntu-22.04, macos-15, windows-latest ]
+ os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out
@@ -112,10 +118,11 @@ jobs:
run-library-test:
name: Library test ${{ matrix.python-version }} on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
+ timeout-minutes: 60
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
- os: [ ubuntu-22.04, macos-15, windows-latest ]
+ os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out
@@ -148,10 +155,11 @@ jobs:
run-build-test:
name: Build test ${{ matrix.python-version }} on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
+ timeout-minutes: 60
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
- os: [ ubuntu-22.04, macos-15, windows-latest ]
+ os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out
@@ -177,10 +185,11 @@ jobs:
run-soft-deletion-test:
name: Soft Delete test ${{ matrix.python-version }} on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
+ timeout-minutes: 60
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
- os: [ ubuntu-22.04, macos-15, windows-latest ]
+ os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out
@@ -214,10 +223,11 @@ jobs:
run-hard-deletion-test:
name: Hard Delete test ${{ matrix.python-version }} on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
+ timeout-minutes: 60
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
- os: [ ubuntu-22.04, macos-15, windows-latest ]
+ os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out
diff --git a/.github/workflows/test_suites.yml b/.github/workflows/test_suites.yml
index 5c1597a93..be1e354fc 100644
--- a/.github/workflows/test_suites.yml
+++ b/.github/workflows/test_suites.yml
@@ -1,4 +1,6 @@
name: Test Suites
+permissions:
+ contents: read
on:
push:
@@ -80,12 +82,22 @@ jobs:
uses: ./.github/workflows/notebooks_tests.yml
secrets: inherit
- different-operating-systems-tests:
- name: Operating System and Python Tests
+ different-os-tests-basic:
+ name: OS and Python Tests Ubuntu
needs: [basic-tests, e2e-tests]
uses: ./.github/workflows/test_different_operating_systems.yml
with:
python-versions: '["3.10.x", "3.11.x", "3.12.x", "3.13.x"]'
+ os: '["ubuntu-22.04"]'
+ secrets: inherit
+
+ different-os-tests-extended:
+ name: OS and Python Tests Extended
+ needs: [basic-tests, e2e-tests]
+ uses: ./.github/workflows/test_different_operating_systems.yml
+ with:
+ python-versions: '["3.13.x"]'
+ os: '["macos-15", "windows-latest"]'
secrets: inherit
# Matrix-based vector database tests
@@ -135,7 +147,8 @@ jobs:
e2e-tests,
graph-db-tests,
notebook-tests,
- different-operating-systems-tests,
+ different-os-tests-basic,
+ different-os-tests-extended,
vector-db-tests,
example-tests,
llm-tests,
@@ -155,7 +168,8 @@ jobs:
cli-tests,
graph-db-tests,
notebook-tests,
- different-operating-systems-tests,
+ different-os-tests-basic,
+ different-os-tests-extended,
vector-db-tests,
example-tests,
db-examples-tests,
@@ -176,7 +190,8 @@ jobs:
"${{ needs.cli-tests.result }}" == "success" &&
"${{ needs.graph-db-tests.result }}" == "success" &&
"${{ needs.notebook-tests.result }}" == "success" &&
- "${{ needs.different-operating-systems-tests.result }}" == "success" &&
+ "${{ needs.different-os-tests-basic.result }}" == "success" &&
+ "${{ needs.different-os-tests-extended.result }}" == "success" &&
"${{ needs.vector-db-tests.result }}" == "success" &&
"${{ needs.example-tests.result }}" == "success" &&
"${{ needs.db-examples-tests.result }}" == "success" &&
diff --git a/.github/workflows/weighted_edges_tests.yml b/.github/workflows/weighted_edges_tests.yml
index 874ef6ea4..2b4a043bf 100644
--- a/.github/workflows/weighted_edges_tests.yml
+++ b/.github/workflows/weighted_edges_tests.yml
@@ -2,7 +2,7 @@ name: Weighted Edges Tests
on:
push:
- branches: [ main, weighted_edges ]
+ branches: [ main, dev, weighted_edges ]
paths:
- 'cognee/modules/graph/utils/get_graph_from_model.py'
- 'cognee/infrastructure/engine/models/Edge.py'
@@ -10,7 +10,7 @@ on:
- 'examples/python/weighted_edges_example.py'
- '.github/workflows/weighted_edges_tests.yml'
pull_request:
- branches: [ main ]
+ branches: [ main, dev ]
paths:
- 'cognee/modules/graph/utils/get_graph_from_model.py'
- 'cognee/infrastructure/engine/models/Edge.py'
@@ -32,7 +32,7 @@ jobs:
env:
LLM_PROVIDER: openai
LLM_MODEL: gpt-5-mini
- LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+ LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
steps:
- name: Check out repository
@@ -67,14 +67,13 @@ jobs:
env:
LLM_PROVIDER: openai
LLM_MODEL: gpt-5-mini
- LLM_ENDPOINT: https://api.openai.com/v1/
- LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+ LLM_ENDPOINT: https://api.openai.com/v1
+ LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_API_VERSION: "2024-02-01"
- EMBEDDING_PROVIDER: openai
- EMBEDDING_MODEL: text-embedding-3-small
- EMBEDDING_ENDPOINT: https://api.openai.com/v1/
- EMBEDDING_API_KEY: ${{ secrets.LLM_API_KEY }}
- EMBEDDING_API_VERSION: "2024-02-01"
+ EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+ EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+ EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+ EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
steps:
- name: Check out repository
uses: actions/checkout@v4
@@ -108,14 +107,14 @@ jobs:
env:
LLM_PROVIDER: openai
LLM_MODEL: gpt-5-mini
- LLM_ENDPOINT: https://api.openai.com/v1/
- LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+ LLM_ENDPOINT: https://api.openai.com/v1
+ LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_API_VERSION: "2024-02-01"
- EMBEDDING_PROVIDER: openai
- EMBEDDING_MODEL: text-embedding-3-small
- EMBEDDING_ENDPOINT: https://api.openai.com/v1/
- EMBEDDING_API_KEY: ${{ secrets.LLM_API_KEY }}
- EMBEDDING_API_VERSION: "2024-02-01"
+ EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+ EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+ EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+ EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
+
steps:
- name: Check out repository
uses: actions/checkout@v4
diff --git a/alembic/env.py b/alembic/env.py
index 1cbef65f7..8ca09968d 100644
--- a/alembic/env.py
+++ b/alembic/env.py
@@ -87,11 +87,6 @@ db_engine = get_relational_engine()
print("Using database:", db_engine.db_uri)
-if "sqlite" in db_engine.db_uri:
- from cognee.infrastructure.utils.run_sync import run_sync
-
- run_sync(db_engine.create_database())
-
config.set_section_option(
config.config_ini_section,
"SQLALCHEMY_DATABASE_URI",
diff --git a/alembic/versions/211ab850ef3d_add_sync_operations_table.py b/alembic/versions/211ab850ef3d_add_sync_operations_table.py
index 370aab1a4..30049b44b 100644
--- a/alembic/versions/211ab850ef3d_add_sync_operations_table.py
+++ b/alembic/versions/211ab850ef3d_add_sync_operations_table.py
@@ -10,6 +10,7 @@ from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
@@ -26,7 +27,34 @@ def upgrade() -> None:
connection = op.get_bind()
inspector = sa.inspect(connection)
+ if op.get_context().dialect.name == "postgresql":
+ syncstatus_enum = postgresql.ENUM(
+ "STARTED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED", name="syncstatus"
+ )
+ syncstatus_enum.create(op.get_bind(), checkfirst=True)
+
if "sync_operations" not in inspector.get_table_names():
+ if op.get_context().dialect.name == "postgresql":
+ syncstatus = postgresql.ENUM(
+ "STARTED",
+ "IN_PROGRESS",
+ "COMPLETED",
+ "FAILED",
+ "CANCELLED",
+ name="syncstatus",
+ create_type=False,
+ )
+ else:
+ syncstatus = sa.Enum(
+ "STARTED",
+ "IN_PROGRESS",
+ "COMPLETED",
+ "FAILED",
+ "CANCELLED",
+ name="syncstatus",
+ create_type=False,
+ )
+
# Table doesn't exist, create it normally
op.create_table(
"sync_operations",
@@ -34,15 +62,7 @@ def upgrade() -> None:
sa.Column("run_id", sa.Text(), nullable=True),
sa.Column(
"status",
- sa.Enum(
- "STARTED",
- "IN_PROGRESS",
- "COMPLETED",
- "FAILED",
- "CANCELLED",
- name="syncstatus",
- create_type=False,
- ),
+ syncstatus,
nullable=True,
),
sa.Column("progress_percentage", sa.Integer(), nullable=True),
diff --git a/alembic/versions/482cd6517ce4_add_default_user.py b/alembic/versions/482cd6517ce4_add_default_user.py
index d85f0f146..c8a3dc5d5 100644
--- a/alembic/versions/482cd6517ce4_add_default_user.py
+++ b/alembic/versions/482cd6517ce4_add_default_user.py
@@ -23,11 +23,8 @@ depends_on: Union[str, Sequence[str], None] = "8057ae7329c2"
def upgrade() -> None:
- try:
- await_only(create_default_user())
- except UserAlreadyExists:
- pass # It's fine if the default user already exists
+ pass
def downgrade() -> None:
- await_only(delete_user("default_user@example.com"))
+ pass
diff --git a/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py b/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py
new file mode 100644
index 000000000..7e13898ae
--- /dev/null
+++ b/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py
@@ -0,0 +1,98 @@
+"""Expand dataset database for multi user
+
+Revision ID: 76625596c5c3
+Revises: c946955da633
+Create Date: 2025-10-30 12:55:20.239562
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "76625596c5c3"
+down_revision: Union[str, None] = "c946955da633"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def _get_column(inspector, table, name, schema=None):
+ for col in inspector.get_columns(table, schema=schema):
+ if col["name"] == name:
+ return col
+ return None
+
+
+def upgrade() -> None:
+ conn = op.get_bind()
+ insp = sa.inspect(conn)
+
+ vector_database_provider_column = _get_column(
+ insp, "dataset_database", "vector_database_provider"
+ )
+ if not vector_database_provider_column:
+ op.add_column(
+ "dataset_database",
+ sa.Column(
+ "vector_database_provider",
+ sa.String(),
+ unique=False,
+ nullable=False,
+ server_default="lancedb",
+ ),
+ )
+
+ graph_database_provider_column = _get_column(
+ insp, "dataset_database", "graph_database_provider"
+ )
+ if not graph_database_provider_column:
+ op.add_column(
+ "dataset_database",
+ sa.Column(
+ "graph_database_provider",
+ sa.String(),
+ unique=False,
+ nullable=False,
+ server_default="kuzu",
+ ),
+ )
+
+ vector_database_url_column = _get_column(insp, "dataset_database", "vector_database_url")
+ if not vector_database_url_column:
+ op.add_column(
+ "dataset_database",
+ sa.Column("vector_database_url", sa.String(), unique=False, nullable=True),
+ )
+
+ graph_database_url_column = _get_column(insp, "dataset_database", "graph_database_url")
+ if not graph_database_url_column:
+ op.add_column(
+ "dataset_database",
+ sa.Column("graph_database_url", sa.String(), unique=False, nullable=True),
+ )
+
+ vector_database_key_column = _get_column(insp, "dataset_database", "vector_database_key")
+ if not vector_database_key_column:
+ op.add_column(
+ "dataset_database",
+ sa.Column("vector_database_key", sa.String(), unique=False, nullable=True),
+ )
+
+ graph_database_key_column = _get_column(insp, "dataset_database", "graph_database_key")
+ if not graph_database_key_column:
+ op.add_column(
+ "dataset_database",
+ sa.Column("graph_database_key", sa.String(), unique=False, nullable=True),
+ )
+
+
+def downgrade() -> None:
+ op.drop_column("dataset_database", "vector_database_provider")
+ op.drop_column("dataset_database", "graph_database_provider")
+ op.drop_column("dataset_database", "vector_database_url")
+ op.drop_column("dataset_database", "graph_database_url")
+ op.drop_column("dataset_database", "vector_database_key")
+ op.drop_column("dataset_database", "graph_database_key")
diff --git a/alembic/versions/8057ae7329c2_initial_migration.py b/alembic/versions/8057ae7329c2_initial_migration.py
index aa0ecd4b8..42e9904a8 100644
--- a/alembic/versions/8057ae7329c2_initial_migration.py
+++ b/alembic/versions/8057ae7329c2_initial_migration.py
@@ -18,11 +18,8 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
- db_engine = get_relational_engine()
- # we might want to delete this
- await_only(db_engine.create_database())
+ pass
def downgrade() -> None:
- db_engine = get_relational_engine()
- await_only(db_engine.delete_database())
+ pass
diff --git a/alembic/versions/ab7e313804ae_permission_system_rework.py b/alembic/versions/ab7e313804ae_permission_system_rework.py
index bd69b9b41..d83f946a6 100644
--- a/alembic/versions/ab7e313804ae_permission_system_rework.py
+++ b/alembic/versions/ab7e313804ae_permission_system_rework.py
@@ -144,44 +144,58 @@ def _create_data_permission(conn, user_id, data_id, permission_name):
)
+def _get_column(inspector, table, name, schema=None):
+ for col in inspector.get_columns(table, schema=schema):
+ if col["name"] == name:
+ return col
+ return None
+
+
def upgrade() -> None:
conn = op.get_bind()
+ insp = sa.inspect(conn)
- # Recreate ACLs table with default permissions set to datasets instead of documents
- op.drop_table("acls")
+ dataset_id_column = _get_column(insp, "acls", "dataset_id")
+ if not dataset_id_column:
+ # Recreate ACLs table with default permissions set to datasets instead of documents
+ op.drop_table("acls")
- acls_table = op.create_table(
- "acls",
- sa.Column("id", UUID, primary_key=True, default=uuid4),
- sa.Column(
- "created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
- ),
- sa.Column(
- "updated_at", sa.DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)
- ),
- sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")),
- sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")),
- sa.Column("dataset_id", UUID, sa.ForeignKey("datasets.id", ondelete="CASCADE")),
- )
+ acls_table = op.create_table(
+ "acls",
+ sa.Column("id", UUID, primary_key=True, default=uuid4),
+ sa.Column(
+ "created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
+ ),
+ sa.Column(
+ "updated_at",
+ sa.DateTime(timezone=True),
+ onupdate=lambda: datetime.now(timezone.utc),
+ ),
+ sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")),
+ sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")),
+ sa.Column("dataset_id", UUID, sa.ForeignKey("datasets.id", ondelete="CASCADE")),
+ )
- # Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
- # definition or load what is in the database
- dataset_table = _define_dataset_table()
- datasets = conn.execute(sa.select(dataset_table)).fetchall()
+ # Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
+ # definition or load what is in the database
+ dataset_table = _define_dataset_table()
+ datasets = conn.execute(sa.select(dataset_table)).fetchall()
- if not datasets:
- return
+ if not datasets:
+ return
- acl_list = []
+ acl_list = []
- for dataset in datasets:
- acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "read"))
- acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "write"))
- acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "share"))
- acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "delete"))
+ for dataset in datasets:
+ acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "read"))
+ acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "write"))
+ acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "share"))
+ acl_list.append(
+ _create_dataset_permission(conn, dataset.owner_id, dataset.id, "delete")
+ )
- if acl_list:
- op.bulk_insert(acls_table, acl_list)
+ if acl_list:
+ op.bulk_insert(acls_table, acl_list)
def downgrade() -> None:
diff --git a/alembic/versions/c946955da633_multi_tenant_support.py b/alembic/versions/c946955da633_multi_tenant_support.py
new file mode 100644
index 000000000..d8fccdfbf
--- /dev/null
+++ b/alembic/versions/c946955da633_multi_tenant_support.py
@@ -0,0 +1,137 @@
+"""Multi Tenant Support
+
+Revision ID: c946955da633
+Revises: 211ab850ef3d
+Create Date: 2025-11-04 18:11:09.325158
+
+"""
+
+from typing import Sequence, Union
+from datetime import datetime, timezone
+from uuid import uuid4
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision: str = "c946955da633"
+down_revision: Union[str, None] = "211ab850ef3d"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def _now():
+ return datetime.now(timezone.utc)
+
+
+def _define_user_table() -> sa.Table:
+ table = sa.Table(
+ "users",
+ sa.MetaData(),
+ sa.Column(
+ "id",
+ sa.UUID,
+ sa.ForeignKey("principals.id", ondelete="CASCADE"),
+ primary_key=True,
+ nullable=False,
+ ),
+ sa.Column("tenant_id", sa.UUID, sa.ForeignKey("tenants.id"), index=True, nullable=True),
+ )
+ return table
+
+
+def _define_dataset_table() -> sa.Table:
+ # Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
+ # definition or load what is in the database
+ table = sa.Table(
+ "datasets",
+ sa.MetaData(),
+ sa.Column("id", sa.UUID, primary_key=True, default=uuid4),
+ sa.Column("name", sa.Text),
+ sa.Column(
+ "created_at",
+ sa.DateTime(timezone=True),
+ default=lambda: datetime.now(timezone.utc),
+ ),
+ sa.Column(
+ "updated_at",
+ sa.DateTime(timezone=True),
+ onupdate=lambda: datetime.now(timezone.utc),
+ ),
+ sa.Column("owner_id", sa.UUID(), sa.ForeignKey("principals.id"), index=True),
+ sa.Column("tenant_id", sa.UUID(), sa.ForeignKey("tenants.id"), index=True, nullable=True),
+ )
+
+ return table
+
+
+def _get_column(inspector, table, name, schema=None):
+ for col in inspector.get_columns(table, schema=schema):
+ if col["name"] == name:
+ return col
+ return None
+
+
+def upgrade() -> None:
+ conn = op.get_bind()
+ insp = sa.inspect(conn)
+
+ dataset = _define_dataset_table()
+ user = _define_user_table()
+
+ if "user_tenants" not in insp.get_table_names():
+ # Define table with all necessary columns including primary key
+ user_tenants = op.create_table(
+ "user_tenants",
+ sa.Column("user_id", sa.UUID, sa.ForeignKey("users.id"), primary_key=True),
+ sa.Column("tenant_id", sa.UUID, sa.ForeignKey("tenants.id"), primary_key=True),
+ sa.Column(
+ "created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
+ ),
+ )
+
+ # Get all users with their tenant_id
+ user_data = conn.execute(
+ sa.select(user.c.id, user.c.tenant_id).where(user.c.tenant_id.isnot(None))
+ ).fetchall()
+
+ # Insert into user_tenants table
+ if user_data:
+ op.bulk_insert(
+ user_tenants,
+ [
+ {"user_id": user_id, "tenant_id": tenant_id, "created_at": _now()}
+ for user_id, tenant_id in user_data
+ ],
+ )
+
+ tenant_id_column = _get_column(insp, "datasets", "tenant_id")
+ if not tenant_id_column:
+ op.add_column("datasets", sa.Column("tenant_id", sa.UUID(), nullable=True))
+
+ # Build subquery, select users.tenant_id for each dataset.owner_id
+ tenant_id_from_dataset_owner = (
+ sa.select(user.c.tenant_id).where(user.c.id == dataset.c.owner_id).scalar_subquery()
+ )
+
+ if op.get_context().dialect.name == "sqlite":
+ # SQLite cannot run a correlated UPDATE directly; use batch_alter_table to populate datasets.tenant_id from each owner's users.tenant_id
+ with op.batch_alter_table("datasets") as batch_op:
+ batch_op.execute(
+ dataset.update().values(
+ tenant_id=tenant_id_from_dataset_owner,
+ )
+ )
+ else:
+ conn = op.get_bind()
+ conn.execute(dataset.update().values(tenant_id=tenant_id_from_dataset_owner))
+
+ op.create_index(op.f("ix_datasets_tenant_id"), "datasets", ["tenant_id"])
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_table("user_tenants")
+ op.drop_index(op.f("ix_datasets_tenant_id"), table_name="datasets")
+ op.drop_column("datasets", "tenant_id")
+ # ### end Alembic commands ###
diff --git a/cognee-mcp/src/server.py b/cognee-mcp/src/server.py
index ce6dad88a..4131be988 100755
--- a/cognee-mcp/src/server.py
+++ b/cognee-mcp/src/server.py
@@ -194,7 +194,6 @@ async def cognify(
Prerequisites:
- **LLM_API_KEY**: Must be configured (required for entity extraction and graph generation)
- - **Data Added**: Must have data previously added via `cognee.add()`
- **Vector Database**: Must be accessible for embeddings storage
- **Graph Database**: Must be accessible for relationship storage
@@ -1096,6 +1095,10 @@ async def main():
# Skip migrations when in API mode (the API server handles its own database)
if not args.no_migration and not args.api_url:
+ from cognee.modules.engine.operations.setup import setup
+
+ await setup()
+
# Run Alembic migrations from the main cognee directory where alembic.ini is located
logger.info("Running database migrations...")
migration_result = subprocess.run(
diff --git a/cognee/__init__.py b/cognee/__init__.py
index 6e4d2a903..4d150ce4e 100644
--- a/cognee/__init__.py
+++ b/cognee/__init__.py
@@ -19,6 +19,7 @@ from .api.v1.add import add
from .api.v1.delete import delete
from .api.v1.cognify import cognify
from .modules.memify import memify
+from .modules.run_custom_pipeline import run_custom_pipeline
from .api.v1.update import update
from .api.v1.config.config import config
from .api.v1.datasets.datasets import datasets
diff --git a/cognee/api/client.py b/cognee/api/client.py
index 6766c12de..1a08aed56 100644
--- a/cognee/api/client.py
+++ b/cognee/api/client.py
@@ -23,6 +23,7 @@ from cognee.api.v1.settings.routers import get_settings_router
from cognee.api.v1.datasets.routers import get_datasets_router
from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router
from cognee.api.v1.search.routers import get_search_router
+from cognee.api.v1.ontologies.routers.get_ontology_router import get_ontology_router
from cognee.api.v1.memify.routers import get_memify_router
from cognee.api.v1.add.routers import get_add_router
from cognee.api.v1.delete.routers import get_delete_router
@@ -39,6 +40,8 @@ from cognee.api.v1.users.routers import (
)
from cognee.modules.users.methods.get_authenticated_user import REQUIRE_AUTHENTICATION
+# Ensure application logging is configured for container stdout/stderr
+setup_logging()
logger = get_logger()
if os.getenv("ENV", "prod") == "prod":
@@ -74,6 +77,9 @@ async def lifespan(app: FastAPI):
await get_default_user()
+ # Emit a clear startup message for docker logs
+ logger.info("Backend server has started")
+
yield
@@ -258,6 +264,8 @@ app.include_router(
app.include_router(get_datasets_router(), prefix="/api/v1/datasets", tags=["datasets"])
+app.include_router(get_ontology_router(), prefix="/api/v1/ontologies", tags=["ontologies"])
+
app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["settings"])
app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"])
diff --git a/cognee/api/v1/add/routers/get_add_router.py b/cognee/api/v1/add/routers/get_add_router.py
index b2e7068b0..39dc1a3e6 100644
--- a/cognee/api/v1/add/routers/get_add_router.py
+++ b/cognee/api/v1/add/routers/get_add_router.py
@@ -82,7 +82,9 @@ def get_add_router() -> APIRouter:
datasetName,
user=user,
dataset_id=datasetId,
- node_set=node_set if node_set else None,
+ node_set=node_set
+ if node_set != [""]
+ else None, # Transform default node_set endpoint value to None
)
if isinstance(add_run, PipelineRunErrored):
diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py
index 231bbcd11..4f1497e3c 100644
--- a/cognee/api/v1/cognify/routers/get_cognify_router.py
+++ b/cognee/api/v1/cognify/routers/get_cognify_router.py
@@ -41,6 +41,9 @@ class CognifyPayloadDTO(InDTO):
custom_prompt: Optional[str] = Field(
default="", description="Custom prompt for entity extraction and graph generation"
)
+ ontology_key: Optional[List[str]] = Field(
+ default=None, description="Reference to one or more previously uploaded ontologies"
+ )
def get_cognify_router() -> APIRouter:
@@ -68,6 +71,7 @@ def get_cognify_router() -> APIRouter:
- **dataset_ids** (Optional[List[UUID]]): List of existing dataset UUIDs to process. UUIDs allow processing of datasets not owned by the user (if permitted).
- **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking).
- **custom_prompt** (Optional[str]): Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts for knowledge graph extraction.
+ - **ontology_key** (Optional[List[str]]): Reference to one or more previously uploaded ontology files to use for knowledge graph construction.
## Response
- **Blocking execution**: Complete pipeline run information with entity counts, processing duration, and success/failure status
@@ -82,7 +86,8 @@ def get_cognify_router() -> APIRouter:
{
"datasets": ["research_papers", "documentation"],
"run_in_background": false,
- "custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections."
+ "custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections.",
+ "ontology_key": ["medical_ontology_v1"]
}
```
@@ -108,13 +113,35 @@ def get_cognify_router() -> APIRouter:
)
from cognee.api.v1.cognify import cognify as cognee_cognify
+ from cognee.api.v1.ontologies.ontologies import OntologyService
try:
datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets
+ config_to_use = None
+
+ if payload.ontology_key:
+ ontology_service = OntologyService()
+ ontology_contents = ontology_service.get_ontology_contents(
+ payload.ontology_key, user
+ )
+
+ from cognee.modules.ontology.ontology_config import Config
+ from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import (
+ RDFLibOntologyResolver,
+ )
+ from io import StringIO
+
+ ontology_streams = [StringIO(content) for content in ontology_contents]
+ config_to_use: Config = {
+ "ontology_config": {
+ "ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_streams)
+ }
+ }
cognify_run = await cognee_cognify(
datasets,
user,
+ config=config_to_use,
run_in_background=payload.run_in_background,
custom_prompt=payload.custom_prompt,
)
diff --git a/cognee/api/v1/ontologies/__init__.py b/cognee/api/v1/ontologies/__init__.py
new file mode 100644
index 000000000..b90d46c3d
--- /dev/null
+++ b/cognee/api/v1/ontologies/__init__.py
@@ -0,0 +1,4 @@
+from .ontologies import OntologyService
+from .routers.get_ontology_router import get_ontology_router
+
+__all__ = ["OntologyService", "get_ontology_router"]
diff --git a/cognee/api/v1/ontologies/ontologies.py b/cognee/api/v1/ontologies/ontologies.py
new file mode 100644
index 000000000..130b4a862
--- /dev/null
+++ b/cognee/api/v1/ontologies/ontologies.py
@@ -0,0 +1,183 @@
+import os
+import json
+import tempfile
+from pathlib import Path
+from datetime import datetime, timezone
+from typing import Optional, List
+from dataclasses import dataclass
+
+
+@dataclass
+class OntologyMetadata:
+ ontology_key: str
+ filename: str
+ size_bytes: int
+ uploaded_at: str
+ description: Optional[str] = None
+
+
+class OntologyService:
+ def __init__(self):
+ pass
+
+ @property
+ def base_dir(self) -> Path:
+ return Path(tempfile.gettempdir()) / "ontologies"
+
+ def _get_user_dir(self, user_id: str) -> Path:
+ user_dir = self.base_dir / str(user_id)
+ user_dir.mkdir(parents=True, exist_ok=True)
+ return user_dir
+
+ def _get_metadata_path(self, user_dir: Path) -> Path:
+ return user_dir / "metadata.json"
+
+ def _load_metadata(self, user_dir: Path) -> dict:
+ metadata_path = self._get_metadata_path(user_dir)
+ if metadata_path.exists():
+ with open(metadata_path, "r") as f:
+ return json.load(f)
+ return {}
+
+ def _save_metadata(self, user_dir: Path, metadata: dict):
+ metadata_path = self._get_metadata_path(user_dir)
+ with open(metadata_path, "w") as f:
+ json.dump(metadata, f, indent=2)
+
+ async def upload_ontology(
+ self, ontology_key: str, file, user, description: Optional[str] = None
+ ) -> OntologyMetadata:
+ if not file.filename.lower().endswith(".owl"):
+ raise ValueError("File must be in .owl format")
+
+ user_dir = self._get_user_dir(str(user.id))
+ metadata = self._load_metadata(user_dir)
+
+ if ontology_key in metadata:
+ raise ValueError(f"Ontology key '{ontology_key}' already exists")
+
+ content = await file.read()
+ if len(content) > 10 * 1024 * 1024:
+ raise ValueError("File size exceeds 10MB limit")
+
+ file_path = user_dir / f"{ontology_key}.owl"
+ with open(file_path, "wb") as f:
+ f.write(content)
+
+ ontology_metadata = {
+ "filename": file.filename,
+ "size_bytes": len(content),
+ "uploaded_at": datetime.now(timezone.utc).isoformat(),
+ "description": description,
+ }
+ metadata[ontology_key] = ontology_metadata
+ self._save_metadata(user_dir, metadata)
+
+ return OntologyMetadata(
+ ontology_key=ontology_key,
+ filename=file.filename,
+ size_bytes=len(content),
+ uploaded_at=ontology_metadata["uploaded_at"],
+ description=description,
+ )
+
+ async def upload_ontologies(
+ self, ontology_key: List[str], files: List, user, descriptions: Optional[List[str]] = None
+ ) -> List[OntologyMetadata]:
+ """
+ Upload ontology files with their respective keys.
+
+ Args:
+ ontology_key: List of unique keys for each ontology
+ files: List of UploadFile objects (same length as keys)
+ user: Authenticated user
+ descriptions: Optional list of descriptions for each file
+
+ Returns:
+ List of OntologyMetadata objects for uploaded files
+
+ Raises:
+ ValueError: If keys duplicate, file format invalid, or array lengths don't match
+ """
+ if len(ontology_key) != len(files):
+ raise ValueError("Number of keys must match number of files")
+
+ if len(set(ontology_key)) != len(ontology_key):
+ raise ValueError("Duplicate ontology keys not allowed")
+
+ if descriptions and len(descriptions) != len(files):
+ raise ValueError("Number of descriptions must match number of files")
+
+ results = []
+ user_dir = self._get_user_dir(str(user.id))
+ metadata = self._load_metadata(user_dir)
+
+ for i, (key, file) in enumerate(zip(ontology_key, files)):
+ if key in metadata:
+ raise ValueError(f"Ontology key '{key}' already exists")
+
+ if not file.filename.lower().endswith(".owl"):
+ raise ValueError(f"File '{file.filename}' must be in .owl format")
+
+ content = await file.read()
+ if len(content) > 10 * 1024 * 1024:
+ raise ValueError(f"File '{file.filename}' exceeds 10MB limit")
+
+ file_path = user_dir / f"{key}.owl"
+ with open(file_path, "wb") as f:
+ f.write(content)
+
+ ontology_metadata = {
+ "filename": file.filename,
+ "size_bytes": len(content),
+ "uploaded_at": datetime.now(timezone.utc).isoformat(),
+ "description": descriptions[i] if descriptions else None,
+ }
+ metadata[key] = ontology_metadata
+
+ results.append(
+ OntologyMetadata(
+ ontology_key=key,
+ filename=file.filename,
+ size_bytes=len(content),
+ uploaded_at=ontology_metadata["uploaded_at"],
+ description=descriptions[i] if descriptions else None,
+ )
+ )
+
+ self._save_metadata(user_dir, metadata)
+ return results
+
+ def get_ontology_contents(self, ontology_key: List[str], user) -> List[str]:
+ """
+ Retrieve ontology content for one or more keys.
+
+ Args:
+ ontology_key: List of ontology keys to retrieve (can contain single item)
+ user: Authenticated user
+
+ Returns:
+ List of ontology content strings
+
+ Raises:
+ ValueError: If any ontology key not found
+ """
+ user_dir = self._get_user_dir(str(user.id))
+ metadata = self._load_metadata(user_dir)
+
+ contents = []
+ for key in ontology_key:
+ if key not in metadata:
+ raise ValueError(f"Ontology key '{key}' not found")
+
+ file_path = user_dir / f"{key}.owl"
+ if not file_path.exists():
+ raise ValueError(f"Ontology file for key '{key}' not found")
+
+ with open(file_path, "r", encoding="utf-8") as f:
+ contents.append(f.read())
+ return contents
+
+ def list_ontologies(self, user) -> dict:
+ user_dir = self._get_user_dir(str(user.id))
+ return self._load_metadata(user_dir)
diff --git a/cognee/api/v1/ontologies/routers/__init__.py b/cognee/api/v1/ontologies/routers/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/cognee/api/v1/ontologies/routers/get_ontology_router.py b/cognee/api/v1/ontologies/routers/get_ontology_router.py
new file mode 100644
index 000000000..ee31c683f
--- /dev/null
+++ b/cognee/api/v1/ontologies/routers/get_ontology_router.py
@@ -0,0 +1,107 @@
+from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException
+from fastapi.responses import JSONResponse
+from typing import Optional, List
+
+from cognee.modules.users.models import User
+from cognee.modules.users.methods import get_authenticated_user
+from cognee.shared.utils import send_telemetry
+from cognee import __version__ as cognee_version
+from ..ontologies import OntologyService
+
+
+def get_ontology_router() -> APIRouter:
+ router = APIRouter()
+ ontology_service = OntologyService()
+
+ @router.post("", response_model=dict)
+ async def upload_ontology(
+ ontology_key: str = Form(...),
+ ontology_file: List[UploadFile] = File(...),
+ descriptions: Optional[str] = Form(None),
+ user: User = Depends(get_authenticated_user),
+ ):
+ """
+ Upload ontology files with their respective keys for later use in cognify operations.
+
+ Supports both single and multiple file uploads:
+ - Single file: ontology_key=["key"], ontology_file=[file]
+ - Multiple files: ontology_key=["key1", "key2"], ontology_file=[file1, file2]
+
+ ## Request Parameters
+ - **ontology_key** (str): JSON array string of user-defined identifiers for the ontologies
+ - **ontology_file** (List[UploadFile]): OWL format ontology files
+ - **descriptions** (Optional[str]): JSON array string of optional descriptions
+
+ ## Response
+ Returns metadata about uploaded ontologies including keys, filenames, sizes, and upload timestamps.
+
+ ## Error Codes
+ - **400 Bad Request**: Invalid file format, duplicate keys, array length mismatches, file size exceeded
+ - **500 Internal Server Error**: File system or processing errors
+ """
+ send_telemetry(
+ "Ontology Upload API Endpoint Invoked",
+ user.id,
+ additional_properties={
+ "endpoint": "POST /api/v1/ontologies",
+ "cognee_version": cognee_version,
+ },
+ )
+
+ try:
+ import json
+
+ ontology_keys = json.loads(ontology_key)
+ description_list = json.loads(descriptions) if descriptions else None
+
+ if not isinstance(ontology_keys, list):
+ raise ValueError("ontology_key must be a JSON array")
+
+ results = await ontology_service.upload_ontologies(
+ ontology_keys, ontology_file, user, description_list
+ )
+
+ return {
+ "uploaded_ontologies": [
+ {
+ "ontology_key": result.ontology_key,
+ "filename": result.filename,
+ "size_bytes": result.size_bytes,
+ "uploaded_at": result.uploaded_at,
+ "description": result.description,
+ }
+ for result in results
+ ]
+ }
+ except (json.JSONDecodeError, ValueError) as e:
+ return JSONResponse(status_code=400, content={"error": str(e)})
+ except Exception as e:
+ return JSONResponse(status_code=500, content={"error": str(e)})
+
+ @router.get("", response_model=dict)
+ async def list_ontologies(user: User = Depends(get_authenticated_user)):
+ """
+ List all uploaded ontologies for the authenticated user.
+
+ ## Response
+ Returns a dictionary mapping ontology keys to their metadata including filename, size, and upload timestamp.
+
+ ## Error Codes
+ - **500 Internal Server Error**: File system or processing errors
+ """
+ send_telemetry(
+ "Ontology List API Endpoint Invoked",
+ user.id,
+ additional_properties={
+ "endpoint": "GET /api/v1/ontologies",
+ "cognee_version": cognee_version,
+ },
+ )
+
+ try:
+ metadata = ontology_service.list_ontologies(user)
+ return metadata
+ except Exception as e:
+ return JSONResponse(status_code=500, content={"error": str(e)})
+
+ return router
diff --git a/cognee/api/v1/permissions/routers/get_permissions_router.py b/cognee/api/v1/permissions/routers/get_permissions_router.py
index 565e95732..63de97eaa 100644
--- a/cognee/api/v1/permissions/routers/get_permissions_router.py
+++ b/cognee/api/v1/permissions/routers/get_permissions_router.py
@@ -1,15 +1,20 @@
from uuid import UUID
-from typing import List
+from typing import List, Union
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from cognee.modules.users.models import User
+from cognee.api.DTO import InDTO
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
+class SelectTenantDTO(InDTO):
+ tenant_id: UUID | None = None
+
+
def get_permissions_router() -> APIRouter:
permissions_router = APIRouter()
@@ -226,4 +231,39 @@ def get_permissions_router() -> APIRouter:
status_code=200, content={"message": "Tenant created.", "tenant_id": str(tenant_id)}
)
+ @permissions_router.post("/tenants/select")
+ async def select_tenant(payload: SelectTenantDTO, user: User = Depends(get_authenticated_user)):
+ """
+ Select current tenant.
+
+ This endpoint selects a tenant with the specified UUID. Tenants are used
+ to organize users and resources in multi-tenant environments, providing
+ isolation and access control between different groups or organizations.
+
+ Sending a null/None value as tenant_id selects the user's default single-user tenant.
+
+ ## Request Parameters
+ - **tenant_id** (Union[UUID, None]): UUID of the tenant to select. If null/None is provided, the default single-user tenant is used.
+
+ ## Response
+ Returns a success message along with selected tenant id.
+ """
+ send_telemetry(
+ "Permissions API Endpoint Invoked",
+ user.id,
+ additional_properties={
+ "endpoint": f"POST /v1/permissions/tenants/{str(payload.tenant_id)}",
+ "tenant_id": str(payload.tenant_id),
+ },
+ )
+
+ from cognee.modules.users.tenants.methods import select_tenant as select_tenant_method
+
+ await select_tenant_method(user_id=user.id, tenant_id=payload.tenant_id)
+
+ return JSONResponse(
+ status_code=200,
+ content={"message": "Tenant selected.", "tenant_id": str(payload.tenant_id)},
+ )
+
return permissions_router
diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py
index d4e5fbbe6..354331c57 100644
--- a/cognee/api/v1/search/search.py
+++ b/cognee/api/v1/search/search.py
@@ -31,6 +31,8 @@ async def search(
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[List[SearchResult], CombinedSearchResult]:
"""
Search and query the knowledge graph for insights, information, and connections.
@@ -200,6 +202,8 @@ async def search(
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
return filtered_search_results
diff --git a/cognee/cli/commands/cognify_command.py b/cognee/cli/commands/cognify_command.py
index 16eaf0454..b89c1f70e 100644
--- a/cognee/cli/commands/cognify_command.py
+++ b/cognee/cli/commands/cognify_command.py
@@ -22,7 +22,7 @@ relationships, and creates semantic connections for enhanced search and reasonin
Processing Pipeline:
1. **Document Classification**: Identifies document types and structures
-2. **Permission Validation**: Ensures user has processing rights
+2. **Permission Validation**: Ensures user has processing rights
3. **Text Chunking**: Breaks content into semantically meaningful segments
4. **Entity Extraction**: Identifies key concepts, people, places, organizations
5. **Relationship Detection**: Discovers connections between entities
@@ -97,6 +97,13 @@ After successful cognify processing, use `cognee search` to query the knowledge
chunker_class = LangchainChunker
except ImportError:
fmt.warning("LangchainChunker not available, using TextChunker")
+ elif args.chunker == "CsvChunker":
+ try:
+ from cognee.modules.chunking.CsvChunker import CsvChunker
+
+ chunker_class = CsvChunker
+ except ImportError:
+ fmt.warning("CsvChunker not available, using TextChunker")
result = await cognee.cognify(
datasets=datasets,
diff --git a/cognee/cli/config.py b/cognee/cli/config.py
index d016608c1..082adbaec 100644
--- a/cognee/cli/config.py
+++ b/cognee/cli/config.py
@@ -26,7 +26,7 @@ SEARCH_TYPE_CHOICES = [
]
# Chunker choices
-CHUNKER_CHOICES = ["TextChunker", "LangchainChunker"]
+CHUNKER_CHOICES = ["TextChunker", "LangchainChunker", "CsvChunker"]
# Output format choices
OUTPUT_FORMAT_CHOICES = ["json", "pretty", "simple"]
diff --git a/cognee/context_global_variables.py b/cognee/context_global_variables.py
index d52de4b4e..62e06fc64 100644
--- a/cognee/context_global_variables.py
+++ b/cognee/context_global_variables.py
@@ -4,6 +4,8 @@ from typing import Union
from uuid import UUID
from cognee.base_config import get_base_config
+from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
+from cognee.infrastructure.databases.graph.config import get_graph_context_config
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
from cognee.infrastructure.files.storage.config import file_storage_config
from cognee.modules.users.methods import get_user
@@ -14,11 +16,40 @@ vector_db_config = ContextVar("vector_db_config", default=None)
graph_db_config = ContextVar("graph_db_config", default=None)
session_user = ContextVar("session_user", default=None)
+VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
+GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
+
async def set_session_user_context_variable(user):
session_user.set(user)
+def multi_user_support_possible():
+ graph_db_config = get_graph_context_config()
+ vector_db_config = get_vectordb_context_config()
+ return (
+ graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT
+ and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT
+ )
+
+
+def backend_access_control_enabled():
+ backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
+ if backend_access_control is None:
+ # If backend access control is not defined in environment variables,
+ # enable it by default if graph and vector DBs can support it, otherwise disable it
+ return multi_user_support_possible()
+ elif backend_access_control.lower() == "true":
+ # If enabled, ensure that the current graph and vector DBs can support it
+ multi_user_support = multi_user_support_possible()
+ if not multi_user_support:
+ raise EnvironmentError(
+ "ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
+ )
+ return True
+ return False
+
+
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
"""
If backend access control is enabled this function will ensure all datasets have their own databases,
@@ -38,9 +69,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
"""
- base_config = get_base_config()
-
- if not os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
+ if not backend_access_control_enabled():
return
user = await get_user(user_id)
@@ -48,6 +77,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
# To ensure permissions are enforced properly all datasets will have their own databases
dataset_database = await get_or_create_dataset_database(dataset, user)
+ base_config = get_base_config()
data_root_directory = os.path.join(
base_config.data_root_directory, str(user.tenant_id or user.id)
)
@@ -57,15 +87,17 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
# Set vector and graph database configuration based on dataset database information
vector_config = {
- "vector_db_url": os.path.join(
- databases_directory_path, dataset_database.vector_database_name
- ),
- "vector_db_key": "",
- "vector_db_provider": "lancedb",
+ "vector_db_provider": dataset_database.vector_database_provider,
+ "vector_db_url": dataset_database.vector_database_url,
+ "vector_db_key": dataset_database.vector_database_key,
+ "vector_db_name": dataset_database.vector_database_name,
}
graph_config = {
- "graph_database_provider": "kuzu",
+ "graph_database_provider": dataset_database.graph_database_provider,
+ "graph_database_url": dataset_database.graph_database_url,
+ "graph_database_name": dataset_database.graph_database_name,
+ "graph_database_key": dataset_database.graph_database_key,
"graph_file_path": os.path.join(
databases_directory_path, dataset_database.graph_database_name
),
diff --git a/cognee/eval_framework/Dockerfile b/cognee/eval_framework/Dockerfile
new file mode 100644
index 000000000..e83be3da4
--- /dev/null
+++ b/cognee/eval_framework/Dockerfile
@@ -0,0 +1,29 @@
+FROM python:3.11-slim
+
+# Set environment variables
+ENV PIP_NO_CACHE_DIR=true
+ENV PATH="${PATH}:/root/.poetry/bin"
+ENV PYTHONPATH=/app
+ENV SKIP_MIGRATIONS=true
+
+# System dependencies
+RUN apt-get update && apt-get install -y \
+ gcc \
+ libpq-dev \
+ git \
+ curl \
+ build-essential \
+ && rm -rf /var/lib/apt/lists/*
+
+WORKDIR /app
+
+COPY pyproject.toml poetry.lock README.md /app/
+
+RUN pip install poetry
+
+RUN poetry config virtualenvs.create false
+
+RUN poetry install --extras distributed --extras evals --extras deepeval --no-root
+
+COPY cognee/ /app/cognee
+COPY distributed/ /app/distributed
diff --git a/cognee/eval_framework/answer_generation/answer_generation_executor.py b/cognee/eval_framework/answer_generation/answer_generation_executor.py
index 6f166657e..29b3ede68 100644
--- a/cognee/eval_framework/answer_generation/answer_generation_executor.py
+++ b/cognee/eval_framework/answer_generation/answer_generation_executor.py
@@ -35,6 +35,16 @@ class AnswerGeneratorExecutor:
retrieval_context = await retriever.get_context(query_text)
search_results = await retriever.get_completion(query_text, retrieval_context)
+ ############
+ # TODO: Quick fix until retriever results are structured properly — needed now because the combined retriever's result structure changed; do not leave it like this.
+ if isinstance(retrieval_context, list):
+ retrieval_context = await retriever.convert_retrieved_objects_to_context(
+ triplets=retrieval_context
+ )
+
+ if isinstance(search_results, str):
+ search_results = [search_results]
+ #############
answer = {
"question": query_text,
"answer": search_results[0],
diff --git a/cognee/eval_framework/answer_generation/run_question_answering_module.py b/cognee/eval_framework/answer_generation/run_question_answering_module.py
index d0a2ebe1e..6b55d84b2 100644
--- a/cognee/eval_framework/answer_generation/run_question_answering_module.py
+++ b/cognee/eval_framework/answer_generation/run_question_answering_module.py
@@ -35,7 +35,7 @@ async def create_and_insert_answers_table(questions_payload):
async def run_question_answering(
- params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None
+ params: dict, system_prompt="answer_simple_question_benchmark.txt", top_k: Optional[int] = None
) -> List[dict]:
if params.get("answering_questions"):
logger.info("Question answering started...")
diff --git a/cognee/eval_framework/eval_config.py b/cognee/eval_framework/eval_config.py
index 6edcc0454..9e6f26688 100644
--- a/cognee/eval_framework/eval_config.py
+++ b/cognee/eval_framework/eval_config.py
@@ -14,7 +14,7 @@ class EvalConfig(BaseSettings):
# Question answering params
answering_questions: bool = True
- qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
+ qa_engine: str = "cognee_graph_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
# Evaluation params
evaluating_answers: bool = True
@@ -25,7 +25,7 @@ class EvalConfig(BaseSettings):
"EM",
"f1",
] # Use only 'correctness' for DirectLLM
- deepeval_model: str = "gpt-5-mini"
+ deepeval_model: str = "gpt-4o-mini"
# Metrics params
calculate_metrics: bool = True
diff --git a/cognee/eval_framework/modal_run_eval.py b/cognee/eval_framework/modal_run_eval.py
index aca2686a5..bc2ff77c5 100644
--- a/cognee/eval_framework/modal_run_eval.py
+++ b/cognee/eval_framework/modal_run_eval.py
@@ -2,7 +2,6 @@ import modal
import os
import asyncio
import datetime
-import hashlib
import json
from cognee.shared.logging_utils import get_logger
from cognee.eval_framework.eval_config import EvalConfig
@@ -10,6 +9,9 @@ from cognee.eval_framework.corpus_builder.run_corpus_builder import run_corpus_b
from cognee.eval_framework.answer_generation.run_question_answering_module import (
run_question_answering,
)
+import pathlib
+from os import path
+from modal import Image
from cognee.eval_framework.evaluation.run_evaluation_module import run_evaluation
from cognee.eval_framework.metrics_dashboard import create_dashboard
@@ -38,22 +40,19 @@ def read_and_combine_metrics(eval_params: dict) -> dict:
app = modal.App("modal-run-eval")
-image = (
- modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False)
- .copy_local_file("pyproject.toml", "pyproject.toml")
- .copy_local_file("poetry.lock", "poetry.lock")
- .env(
- {
- "ENV": os.getenv("ENV"),
- "LLM_API_KEY": os.getenv("LLM_API_KEY"),
- "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
- }
- )
- .pip_install("protobuf", "h2", "deepeval", "gdown", "plotly")
+image = Image.from_dockerfile(
+ path=pathlib.Path(path.join(path.dirname(__file__), "Dockerfile")).resolve(),
+ force_build=False,
+).add_local_python_source("cognee")
+
+
+@app.function(
+ image=image,
+ max_containers=10,
+ timeout=86400,
+ volumes={"/data": vol},
+ secrets=[modal.Secret.from_name("eval_secrets")],
)
-
-
-@app.function(image=image, concurrency_limit=10, timeout=86400, volumes={"/data": vol})
async def modal_run_eval(eval_params=None):
"""Runs evaluation pipeline and returns combined metrics results."""
if eval_params is None:
@@ -105,18 +104,7 @@ async def main():
configs = [
EvalConfig(
task_getter_type="Default",
- number_of_samples_in_corpus=10,
- benchmark="HotPotQA",
- qa_engine="cognee_graph_completion",
- building_corpus_from_scratch=True,
- answering_questions=True,
- evaluating_answers=True,
- calculate_metrics=True,
- dashboard=True,
- ),
- EvalConfig(
- task_getter_type="Default",
- number_of_samples_in_corpus=10,
+ number_of_samples_in_corpus=25,
benchmark="TwoWikiMultiHop",
qa_engine="cognee_graph_completion",
building_corpus_from_scratch=True,
@@ -127,7 +115,7 @@ async def main():
),
EvalConfig(
task_getter_type="Default",
- number_of_samples_in_corpus=10,
+ number_of_samples_in_corpus=25,
benchmark="Musique",
qa_engine="cognee_graph_completion",
building_corpus_from_scratch=True,
diff --git a/cognee/infrastructure/databases/cache/config.py b/cognee/infrastructure/databases/cache/config.py
index 3a28827fe..88ac05885 100644
--- a/cognee/infrastructure/databases/cache/config.py
+++ b/cognee/infrastructure/databases/cache/config.py
@@ -1,6 +1,6 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from functools import lru_cache
-from typing import Optional
+from typing import Optional, Literal
class CacheConfig(BaseSettings):
@@ -15,6 +15,7 @@ class CacheConfig(BaseSettings):
- agentic_lock_timeout: Maximum time (in seconds) to wait for the lock release.
"""
+ cache_backend: Literal["redis", "fs"] = "fs"
caching: bool = False
shared_kuzu_lock: bool = False
cache_host: str = "localhost"
@@ -28,6 +29,7 @@ class CacheConfig(BaseSettings):
def to_dict(self) -> dict:
return {
+ "cache_backend": self.cache_backend,
"caching": self.caching,
"shared_kuzu_lock": self.shared_kuzu_lock,
"cache_host": self.cache_host,
diff --git a/cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py b/cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py
new file mode 100644
index 000000000..497e6afec
--- /dev/null
+++ b/cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py
@@ -0,0 +1,151 @@
+import asyncio
+import json
+import os
+from datetime import datetime
+import time
+import threading
+import diskcache as dc
+
+from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
+from cognee.infrastructure.databases.exceptions.exceptions import (
+ CacheConnectionError,
+ SharedKuzuLockRequiresRedisError,
+)
+from cognee.infrastructure.files.storage.get_storage_config import get_storage_config
+from cognee.shared.logging_utils import get_logger
+
+logger = get_logger("FSCacheAdapter")
+
+
+class FSCacheAdapter(CacheDBInterface):
+ def __init__(self):
+ default_key = "sessions_db"
+
+ storage_config = get_storage_config()
+ data_root_directory = storage_config["data_root_directory"]
+ cache_directory = os.path.join(data_root_directory, ".cognee_fs_cache", default_key)
+ os.makedirs(cache_directory, exist_ok=True)
+ self.cache = dc.Cache(directory=cache_directory)
+ self.cache.expire()
+
+ logger.debug(f"FSCacheAdapter initialized with cache directory: {cache_directory}")
+
+ def acquire_lock(self):
+ """Lock acquisition is not available for filesystem cache backend."""
+ message = "Shared Kuzu lock requires Redis cache backend."
+ logger.error(message)
+ raise SharedKuzuLockRequiresRedisError()
+
+ def release_lock(self):
+ """Lock release is not available for filesystem cache backend."""
+ message = "Shared Kuzu lock requires Redis cache backend."
+ logger.error(message)
+ raise SharedKuzuLockRequiresRedisError()
+
+ async def add_qa(
+ self,
+ user_id: str,
+ session_id: str,
+ question: str,
+ context: str,
+ answer: str,
+ ttl: int | None = 86400,
+ ):
+ try:
+ session_key = f"agent_sessions:{user_id}:{session_id}"
+
+ qa_entry = {
+ "time": datetime.utcnow().isoformat(),
+ "question": question,
+ "context": context,
+ "answer": answer,
+ }
+
+ existing_value = self.cache.get(session_key)
+ if existing_value is not None:
+ value: list = json.loads(existing_value)
+ value.append(qa_entry)
+ else:
+ value = [qa_entry]
+
+ self.cache.set(session_key, json.dumps(value), expire=ttl)
+ except Exception as e:
+ error_msg = f"Unexpected error while adding Q&A to diskcache: {str(e)}"
+ logger.error(error_msg)
+ raise CacheConnectionError(error_msg) from e
+
+ async def get_latest_qa(self, user_id: str, session_id: str, last_n: int = 5):
+ session_key = f"agent_sessions:{user_id}:{session_id}"
+ value = self.cache.get(session_key)
+ if value is None:
+ return None
+ entries = json.loads(value)
+ return entries[-last_n:] if len(entries) > last_n else entries
+
+ async def get_all_qas(self, user_id: str, session_id: str):
+ session_key = f"agent_sessions:{user_id}:{session_id}"
+ value = self.cache.get(session_key)
+ if value is None:
+ return None
+ return json.loads(value)
+
+ async def close(self):
+ if self.cache is not None:
+ self.cache.expire()
+ self.cache.close()
+
+
+async def main():
+ adapter = FSCacheAdapter()
+ session_id = "demo_session"
+ user_id = "demo_user_id"
+
+ print("\nAdding sample Q/A pairs...")
+ await adapter.add_qa(
+ user_id,
+ session_id,
+ "What is Redis?",
+ "Basic DB context",
+ "Redis is an in-memory data store.",
+ )
+ await adapter.add_qa(
+ user_id,
+ session_id,
+ "Who created Redis?",
+ "Historical context",
+ "Salvatore Sanfilippo (antirez).",
+ )
+
+ print("\nLatest QA:")
+ latest = await adapter.get_latest_qa(user_id, session_id)
+ print(json.dumps(latest, indent=2))
+
+ print("\nLast 2 QAs:")
+ last_two = await adapter.get_latest_qa(user_id, session_id, last_n=2)
+ print(json.dumps(last_two, indent=2))
+
+ session_id = "session_expire_demo"
+
+ await adapter.add_qa(
+ user_id,
+ session_id,
+ "What is Redis?",
+ "Database context",
+ "Redis is an in-memory data store.",
+ )
+
+ await adapter.add_qa(
+ user_id,
+ session_id,
+ "Who created Redis?",
+ "History context",
+ "Salvatore Sanfilippo (antirez).",
+ )
+
+ print(await adapter.get_all_qas(user_id, session_id))
+
+ await adapter.close()
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/cognee/infrastructure/databases/cache/get_cache_engine.py b/cognee/infrastructure/databases/cache/get_cache_engine.py
index c1fa3311c..f70358607 100644
--- a/cognee/infrastructure/databases/cache/get_cache_engine.py
+++ b/cognee/infrastructure/databases/cache/get_cache_engine.py
@@ -1,9 +1,11 @@
"""Factory to get the appropriate cache coordination engine (e.g., Redis)."""
from functools import lru_cache
+import os
from typing import Optional
from cognee.infrastructure.databases.cache.config import get_cache_config
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
+from cognee.infrastructure.databases.cache.fscache.FsCacheAdapter import FSCacheAdapter
config = get_cache_config()
@@ -33,20 +35,28 @@ def create_cache_engine(
Returns:
--------
- - CacheDBInterface: An instance of the appropriate cache adapter. :TODO: Now we support only Redis. later if we add more here we can split the logic
+ - CacheDBInterface: An instance of the appropriate cache adapter.
"""
if config.caching:
from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter
- return RedisAdapter(
- host=cache_host,
- port=cache_port,
- username=cache_username,
- password=cache_password,
- lock_name=lock_key,
- timeout=agentic_lock_expire,
- blocking_timeout=agentic_lock_timeout,
- )
+ if config.cache_backend == "redis":
+ return RedisAdapter(
+ host=cache_host,
+ port=cache_port,
+ username=cache_username,
+ password=cache_password,
+ lock_name=lock_key,
+ timeout=agentic_lock_expire,
+ blocking_timeout=agentic_lock_timeout,
+ )
+ elif config.cache_backend == "fs":
+ return FSCacheAdapter()
+ else:
+ raise ValueError(
+ f"Unsupported cache backend: '{config.cache_backend}'. "
+ f"Supported backends are: 'redis', 'fs'"
+ )
else:
return None
diff --git a/cognee/infrastructure/databases/exceptions/exceptions.py b/cognee/infrastructure/databases/exceptions/exceptions.py
index 72b13e3a2..d8dd99c17 100644
--- a/cognee/infrastructure/databases/exceptions/exceptions.py
+++ b/cognee/infrastructure/databases/exceptions/exceptions.py
@@ -148,3 +148,19 @@ class CacheConnectionError(CogneeConfigurationError):
status_code: int = status.HTTP_503_SERVICE_UNAVAILABLE,
):
super().__init__(message, name, status_code)
+
+
+class SharedKuzuLockRequiresRedisError(CogneeConfigurationError):
+ """
+ Raised when shared Kuzu locking is requested without configuring the Redis backend.
+ """
+
+ def __init__(
+ self,
+ message: str = (
+ "Shared Kuzu lock requires Redis cache backend. Configure Redis to enable shared Kuzu locking."
+ ),
+ name: str = "SharedKuzuLockRequiresRedisError",
+ status_code: int = status.HTTP_400_BAD_REQUEST,
+ ):
+ super().__init__(message, name, status_code)
diff --git a/cognee/infrastructure/databases/graph/config.py b/cognee/infrastructure/databases/graph/config.py
index b7907313c..23687b359 100644
--- a/cognee/infrastructure/databases/graph/config.py
+++ b/cognee/infrastructure/databases/graph/config.py
@@ -26,6 +26,7 @@ class GraphConfig(BaseSettings):
- graph_database_username
- graph_database_password
- graph_database_port
+ - graph_database_key
- graph_file_path
- graph_model
- graph_topology
@@ -41,6 +42,7 @@ class GraphConfig(BaseSettings):
graph_database_username: str = ""
graph_database_password: str = ""
graph_database_port: int = 123
+ graph_database_key: str = ""
graph_file_path: str = ""
graph_filename: str = ""
graph_model: object = KnowledgeGraph
@@ -90,6 +92,7 @@ class GraphConfig(BaseSettings):
"graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port,
+ "graph_database_key": self.graph_database_key,
"graph_file_path": self.graph_file_path,
"graph_model": self.graph_model,
"graph_topology": self.graph_topology,
@@ -116,6 +119,7 @@ class GraphConfig(BaseSettings):
"graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port,
+ "graph_database_key": self.graph_database_key,
"graph_file_path": self.graph_file_path,
}
diff --git a/cognee/infrastructure/databases/graph/get_graph_engine.py b/cognee/infrastructure/databases/graph/get_graph_engine.py
index 1ea61d29f..82e3cad6e 100644
--- a/cognee/infrastructure/databases/graph/get_graph_engine.py
+++ b/cognee/infrastructure/databases/graph/get_graph_engine.py
@@ -33,6 +33,7 @@ def create_graph_engine(
graph_database_username="",
graph_database_password="",
graph_database_port="",
+ graph_database_key="",
):
"""
Create a graph engine based on the specified provider type.
@@ -69,6 +70,7 @@ def create_graph_engine(
graph_database_url=graph_database_url,
graph_database_username=graph_database_username,
graph_database_password=graph_database_password,
+ database_name=graph_database_name,
)
if graph_database_provider == "neo4j":
diff --git a/cognee/infrastructure/databases/graph/graph_db_interface.py b/cognee/infrastructure/databases/graph/graph_db_interface.py
index 67df1a27c..8f8c96e79 100644
--- a/cognee/infrastructure/databases/graph/graph_db_interface.py
+++ b/cognee/infrastructure/databases/graph/graph_db_interface.py
@@ -398,3 +398,18 @@ class GraphDBInterface(ABC):
- node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections.
"""
raise NotImplementedError
+
+ @abstractmethod
+ async def get_filtered_graph_data(
+ self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
+ ) -> Tuple[List[Node], List[EdgeData]]:
+ """
+ Retrieve nodes and edges filtered by the provided attribute criteria.
+
+ Parameters:
+ -----------
+
+ - attribute_filters: A list of dictionaries where keys are attribute names and values
+ are lists of attribute values to filter by.
+ """
+ raise NotImplementedError
diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py
index 8dd160665..9dbc9c1bc 100644
--- a/cognee/infrastructure/databases/graph/kuzu/adapter.py
+++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py
@@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Union, Optional, Tuple, Type
+from cognee.exceptions import CogneeValidationError
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.infrastructure.files.storage import get_file_storage
@@ -1186,6 +1187,11 @@ class KuzuAdapter(GraphDBInterface):
A tuple with two elements: a list of tuples of (node_id, properties) and a list of
tuples of (source_id, target_id, relationship_name, properties).
"""
+
+ import time
+
+ start_time = time.time()
+
try:
nodes_query = """
MATCH (n:Node)
@@ -1249,6 +1255,11 @@ class KuzuAdapter(GraphDBInterface):
},
)
)
+
+ retrieval_time = time.time() - start_time
+ logger.info(
+ f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds"
+ )
return formatted_nodes, formatted_edges
except Exception as e:
logger.error(f"Failed to get graph data: {e}")
@@ -1417,6 +1428,92 @@ class KuzuAdapter(GraphDBInterface):
formatted_edges.append((source_id, target_id, rel_type, props))
return formatted_nodes, formatted_edges
+ async def get_id_filtered_graph_data(self, target_ids: list[str]):
+ """
+ Retrieve graph data filtered by specific node IDs, including their direct neighbors
+ and only edges where one endpoint matches those IDs.
+
+ Returns:
+ nodes: List[dict] -> Each dict includes "id" and all node properties
+ edges: List[dict] -> Each dict includes "source", "target", "type", "properties"
+ """
+ import time
+
+ start_time = time.time()
+
+ try:
+ if not target_ids:
+ logger.warning("No target IDs provided for ID-filtered graph retrieval.")
+ return [], []
+
+ if not all(isinstance(x, str) for x in target_ids):
+ raise CogneeValidationError("target_ids must be a list of strings")
+
+ query = """
+ MATCH (n:Node)-[r]->(m:Node)
+ WHERE n.id IN $target_ids OR m.id IN $target_ids
+ RETURN n.id, {
+ name: n.name,
+ type: n.type,
+ properties: n.properties
+ }, m.id, {
+ name: m.name,
+ type: m.type,
+ properties: m.properties
+ }, r.relationship_name, r.properties
+ """
+
+ result = await self.query(query, {"target_ids": target_ids})
+
+ if not result:
+ logger.info("No data returned for the supplied IDs")
+ return [], []
+
+ nodes_dict = {}
+ edges = []
+
+ for n_id, n_props, m_id, m_props, r_type, r_props_raw in result:
+ if n_props.get("properties"):
+ try:
+ additional_props = json.loads(n_props["properties"])
+ n_props.update(additional_props)
+ del n_props["properties"]
+ except json.JSONDecodeError:
+ logger.warning(f"Failed to parse properties JSON for node {n_id}")
+
+ if m_props.get("properties"):
+ try:
+ additional_props = json.loads(m_props["properties"])
+ m_props.update(additional_props)
+ del m_props["properties"]
+ except json.JSONDecodeError:
+ logger.warning(f"Failed to parse properties JSON for node {m_id}")
+
+ nodes_dict[n_id] = (n_id, n_props)
+ nodes_dict[m_id] = (m_id, m_props)
+
+ edge_props = {}
+ if r_props_raw:
+ try:
+ edge_props = json.loads(r_props_raw)
+ except (json.JSONDecodeError, TypeError):
+ logger.warning(f"Failed to parse edge properties for {n_id}->{m_id}")
+
+ source_id = edge_props.get("source_node_id", n_id)
+ target_id = edge_props.get("target_node_id", m_id)
+ edges.append((source_id, target_id, r_type, edge_props))
+
+ retrieval_time = time.time() - start_time
+ logger.info(
+ f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
+ )
+
+ return list(nodes_dict.values()), edges
+
+ except Exception as e:
+ logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
+ raise
+
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
"""
Get metrics on graph structure and connectivity.
diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
index 6216e107e..f3bb8e173 100644
--- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
+++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
@@ -964,6 +964,63 @@ class Neo4jAdapter(GraphDBInterface):
logger.error(f"Error during graph data retrieval: {str(e)}")
raise
+ async def get_id_filtered_graph_data(self, target_ids: list[str]):
+ """
+ Retrieve graph data filtered by specific node IDs, including their direct neighbors
+ and only edges where one endpoint matches those IDs.
+
+ This version uses a single Cypher query for efficiency.
+ """
+ import time
+
+ start_time = time.time()
+
+ try:
+ if not target_ids:
+ logger.warning("No target IDs provided for ID-filtered graph retrieval.")
+ return [], []
+
+ query = """
+ MATCH ()-[r]-()
+ WHERE startNode(r).id IN $target_ids
+ OR endNode(r).id IN $target_ids
+ WITH DISTINCT r, startNode(r) AS a, endNode(r) AS b
+ RETURN
+ properties(a) AS n_properties,
+ properties(b) AS m_properties,
+ type(r) AS type,
+ properties(r) AS properties
+ """
+
+ result = await self.query(query, {"target_ids": target_ids})
+
+ nodes_dict = {}
+ edges = []
+
+ for record in result:
+ n_props = record["n_properties"]
+ m_props = record["m_properties"]
+ r_props = record["properties"]
+ r_type = record["type"]
+
+ nodes_dict[n_props["id"]] = (n_props["id"], n_props)
+ nodes_dict[m_props["id"]] = (m_props["id"], m_props)
+
+ source_id = r_props.get("source_node_id", n_props["id"])
+ target_id = r_props.get("target_node_id", m_props["id"])
+ edges.append((source_id, target_id, r_type, r_props))
+
+ retrieval_time = time.time() - start_time
+ logger.info(
+ f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
+ )
+
+ return list(nodes_dict.values()), edges
+
+ except Exception as e:
+ logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
+ raise
+
async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
diff --git a/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py b/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py
index 5357f3d7c..1e16642b5 100644
--- a/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py
+++ b/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py
@@ -416,6 +416,15 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
self._client.query(f"MATCH (n :{self._VECTOR_NODE_LABEL}) DETACH DELETE n")
pass
+ async def is_empty(self) -> bool:
+ query = """
+ MATCH (n)
+ RETURN true
+ LIMIT 1;
+ """
+ query_result = await self._client.query(query)
+ return len(query_result) == 0
+
@staticmethod
def _get_scored_result(
item: dict, with_vector: bool = False, with_score: bool = False
diff --git a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py
index 29156025d..3684bb100 100644
--- a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py
+++ b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py
@@ -1,11 +1,15 @@
+import os
from uuid import UUID
from typing import Union
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
-from cognee.modules.data.methods import create_dataset
+from cognee.base_config import get_base_config
+from cognee.modules.data.methods import create_dataset
from cognee.infrastructure.databases.relational import get_relational_engine
+from cognee.infrastructure.databases.vector import get_vectordb_config
+from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.modules.data.methods import get_unique_dataset_id
from cognee.modules.users.models import DatasetDatabase
from cognee.modules.users.models import User
@@ -32,8 +36,32 @@ async def get_or_create_dataset_database(
dataset_id = await get_unique_dataset_id(dataset, user)
- vector_db_name = f"{dataset_id}.lance.db"
- graph_db_name = f"{dataset_id}.pkl"
+ vector_config = get_vectordb_config()
+ graph_config = get_graph_config()
+
+    # Note: for hybrid databases, the graph and vector DB names must be identical
+ if graph_config.graph_database_provider == "kuzu":
+ graph_db_name = f"{dataset_id}.pkl"
+ else:
+ graph_db_name = f"{dataset_id}"
+
+ if vector_config.vector_db_provider == "lancedb":
+ vector_db_name = f"{dataset_id}.lance.db"
+ else:
+ vector_db_name = f"{dataset_id}"
+
+ base_config = get_base_config()
+ databases_directory_path = os.path.join(
+ base_config.system_root_directory, "databases", str(user.id)
+ )
+
+ # Determine vector database URL
+ if vector_config.vector_db_provider == "lancedb":
+ vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name)
+ else:
+ vector_db_url = vector_config.vector_database_url
+
+ # Determine graph database URL
async with db_engine.get_async_session() as session:
# Create dataset if it doesn't exist
@@ -55,6 +83,12 @@ async def get_or_create_dataset_database(
dataset_id=dataset_id,
vector_database_name=vector_db_name,
graph_database_name=graph_db_name,
+ vector_database_provider=vector_config.vector_db_provider,
+ graph_database_provider=graph_config.graph_database_provider,
+ vector_database_url=vector_db_url,
+ graph_database_url=graph_config.graph_database_url,
+ vector_database_key=vector_config.vector_db_key,
+ graph_database_key=graph_config.graph_database_key,
)
try:
diff --git a/cognee/infrastructure/databases/vector/config.py b/cognee/infrastructure/databases/vector/config.py
index b6d3ae644..7d28f1668 100644
--- a/cognee/infrastructure/databases/vector/config.py
+++ b/cognee/infrastructure/databases/vector/config.py
@@ -18,12 +18,14 @@ class VectorConfig(BaseSettings):
Instance variables:
- vector_db_url: The URL of the vector database.
- vector_db_port: The port for the vector database.
+ - vector_db_name: The name of the vector database.
- vector_db_key: The key for accessing the vector database.
- vector_db_provider: The provider for the vector database.
"""
vector_db_url: str = ""
vector_db_port: int = 1234
+ vector_db_name: str = ""
vector_db_key: str = ""
vector_db_provider: str = "lancedb"
@@ -58,6 +60,7 @@ class VectorConfig(BaseSettings):
return {
"vector_db_url": self.vector_db_url,
"vector_db_port": self.vector_db_port,
+ "vector_db_name": self.vector_db_name,
"vector_db_key": self.vector_db_key,
"vector_db_provider": self.vector_db_provider,
}
diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py
index d1cf855d7..b182f084b 100644
--- a/cognee/infrastructure/databases/vector/create_vector_engine.py
+++ b/cognee/infrastructure/databases/vector/create_vector_engine.py
@@ -1,5 +1,6 @@
from .supported_databases import supported_databases
from .embeddings import get_embedding_engine
+from cognee.infrastructure.databases.graph.config import get_graph_context_config
from functools import lru_cache
@@ -8,6 +9,7 @@ from functools import lru_cache
def create_vector_engine(
vector_db_provider: str,
vector_db_url: str,
+ vector_db_name: str,
vector_db_port: str = "",
vector_db_key: str = "",
):
@@ -27,6 +29,7 @@ def create_vector_engine(
- vector_db_url (str): The URL for the vector database instance.
- vector_db_port (str): The port for the vector database instance. Required for some
providers.
+ - vector_db_name (str): The name of the vector database instance.
- vector_db_key (str): The API key or access token for the vector database instance.
- vector_db_provider (str): The name of the vector database provider to use (e.g.,
'pgvector').
@@ -45,6 +48,7 @@ def create_vector_engine(
url=vector_db_url,
api_key=vector_db_key,
embedding_engine=embedding_engine,
+ database_name=vector_db_name,
)
if vector_db_provider.lower() == "pgvector":
@@ -133,6 +137,6 @@ def create_vector_engine(
else:
raise EnvironmentError(
- f"Unsupported graph database provider: {vector_db_provider}. "
+ f"Unsupported vector database provider: {vector_db_provider}. "
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['LanceDB', 'PGVector', 'neptune_analytics', 'ChromaDB'])}"
)
diff --git a/cognee/infrastructure/engine/models/Edge.py b/cognee/infrastructure/engine/models/Edge.py
index 5ad9c84dd..59f01a9ab 100644
--- a/cognee/infrastructure/engine/models/Edge.py
+++ b/cognee/infrastructure/engine/models/Edge.py
@@ -1,4 +1,4 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, field_validator
from typing import Optional, Any, Dict
@@ -18,9 +18,21 @@ class Edge(BaseModel):
# Mixed usage
has_items: (Edge(weight=0.5, weights={"confidence": 0.9}), list[Item])
+
+ # With edge_text for rich embedding representation
+ contains: (Edge(relationship_type="contains", edge_text="relationship_name: contains; entity_description: Alice"), Entity)
"""
weight: Optional[float] = None
weights: Optional[Dict[str, float]] = None
relationship_type: Optional[str] = None
properties: Optional[Dict[str, Any]] = None
+ edge_text: Optional[str] = None
+
+ @field_validator("edge_text", mode="before")
+ @classmethod
+ def ensure_edge_text(cls, v, info):
+ """Auto-populate edge_text from relationship_type if not explicitly provided."""
+ if v is None and info.data.get("relationship_type"):
+ return info.data["relationship_type"]
+ return v
diff --git a/cognee/infrastructure/files/utils/guess_file_type.py b/cognee/infrastructure/files/utils/guess_file_type.py
index 78b20c93d..4bc96fe80 100644
--- a/cognee/infrastructure/files/utils/guess_file_type.py
+++ b/cognee/infrastructure/files/utils/guess_file_type.py
@@ -55,6 +55,10 @@ def guess_file_type(file: BinaryIO, name: Optional[str] = None) -> filetype.Type
file_type = Type("text/plain", "txt")
return file_type
+ if ext in [".csv"]:
+ file_type = Type("text/csv", "csv")
+ return file_type
+
file_type = filetype.guess(file)
# If file type could not be determined consider it a plain text file as they don't have magic number encoding
diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py
index 8fd196eaf..2e300dc0c 100644
--- a/cognee/infrastructure/llm/config.py
+++ b/cognee/infrastructure/llm/config.py
@@ -38,6 +38,7 @@ class LLMConfig(BaseSettings):
"""
structured_output_framework: str = "instructor"
+ llm_instructor_mode: str = ""
llm_provider: str = "openai"
llm_model: str = "openai/gpt-5-mini"
llm_endpoint: str = ""
@@ -181,6 +182,7 @@ class LLMConfig(BaseSettings):
instance.
"""
return {
+ "llm_instructor_mode": self.llm_instructor_mode.lower(),
"provider": self.llm_provider,
"model": self.llm_model,
"endpoint": self.llm_endpoint,
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
index bf19d6e86..dbf0dfbea 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
@@ -28,13 +28,16 @@ class AnthropicAdapter(LLMInterface):
name = "Anthropic"
model: str
+ default_instructor_mode = "anthropic_tools"
- def __init__(self, max_completion_tokens: int, model: str = None):
+ def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None):
import anthropic
+ self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
self.aclient = instructor.patch(
create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create,
- mode=instructor.Mode.ANTHROPIC_TOOLS,
+ mode=instructor.Mode(self.instructor_mode),
)
self.model = model
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
index 1187e0cad..226f291d7 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
@@ -41,6 +41,7 @@ class GeminiAdapter(LLMInterface):
name: str
model: str
api_key: str
+ default_instructor_mode = "json_mode"
def __init__(
self,
@@ -49,6 +50,7 @@ class GeminiAdapter(LLMInterface):
model: str,
api_version: str,
max_completion_tokens: int,
+ instructor_mode: str = None,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
@@ -63,7 +65,11 @@ class GeminiAdapter(LLMInterface):
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
- self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
+ self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
+ self.aclient = instructor.from_litellm(
+ litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
+ )
@retry(
stop=stop_after_delay(128),
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py
index 8bbbaa2cc..9d7f25fc5 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py
@@ -41,6 +41,7 @@ class GenericAPIAdapter(LLMInterface):
name: str
model: str
api_key: str
+ default_instructor_mode = "json_mode"
def __init__(
self,
@@ -49,6 +50,7 @@ class GenericAPIAdapter(LLMInterface):
model: str,
name: str,
max_completion_tokens: int,
+ instructor_mode: str = None,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
@@ -63,7 +65,11 @@ class GenericAPIAdapter(LLMInterface):
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
- self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
+ self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
+ self.aclient = instructor.from_litellm(
+ litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
+ )
@retry(
stop=stop_after_delay(128),
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
index c7dcecc56..39558f36d 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
@@ -81,6 +81,7 @@ def get_llm_client(raise_api_key_error: bool = True):
model=llm_config.llm_model,
transcription_model=llm_config.transcription_model,
max_completion_tokens=max_completion_tokens,
+ instructor_mode=llm_config.llm_instructor_mode.lower(),
streaming=llm_config.llm_streaming,
fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint,
@@ -101,6 +102,7 @@ def get_llm_client(raise_api_key_error: bool = True):
llm_config.llm_model,
"Ollama",
max_completion_tokens=max_completion_tokens,
+ instructor_mode=llm_config.llm_instructor_mode.lower(),
)
elif provider == LLMProvider.ANTHROPIC:
@@ -109,7 +111,9 @@ def get_llm_client(raise_api_key_error: bool = True):
)
return AnthropicAdapter(
- max_completion_tokens=max_completion_tokens, model=llm_config.llm_model
+ max_completion_tokens=max_completion_tokens,
+ model=llm_config.llm_model,
+ instructor_mode=llm_config.llm_instructor_mode.lower(),
)
elif provider == LLMProvider.CUSTOM:
@@ -126,6 +130,7 @@ def get_llm_client(raise_api_key_error: bool = True):
llm_config.llm_model,
"Custom",
max_completion_tokens=max_completion_tokens,
+ instructor_mode=llm_config.llm_instructor_mode.lower(),
fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint,
fallback_model=llm_config.fallback_model,
@@ -145,6 +150,7 @@ def get_llm_client(raise_api_key_error: bool = True):
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
api_version=llm_config.llm_api_version,
+ instructor_mode=llm_config.llm_instructor_mode.lower(),
)
elif provider == LLMProvider.MISTRAL:
@@ -160,21 +166,7 @@ def get_llm_client(raise_api_key_error: bool = True):
model=llm_config.llm_model,
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
- )
-
- elif provider == LLMProvider.MISTRAL:
- if llm_config.llm_api_key is None:
- raise LLMAPIKeyNotSetError()
-
- from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
- MistralAdapter,
- )
-
- return MistralAdapter(
- api_key=llm_config.llm_api_key,
- model=llm_config.llm_model,
- max_completion_tokens=max_completion_tokens,
- endpoint=llm_config.llm_endpoint,
+ instructor_mode=llm_config.llm_instructor_mode.lower(),
)
else:
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
index 78a3cbff5..355cdae0b 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
@@ -37,16 +37,26 @@ class MistralAdapter(LLMInterface):
model: str
api_key: str
max_completion_tokens: int
+ default_instructor_mode = "mistral_tools"
- def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None):
+ def __init__(
+ self,
+ api_key: str,
+ model: str,
+ max_completion_tokens: int,
+ endpoint: str = None,
+ instructor_mode: str = None,
+ ):
from mistralai import Mistral
self.model = model
self.max_completion_tokens = max_completion_tokens
+ self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
self.aclient = instructor.from_litellm(
litellm.acompletion,
- mode=instructor.Mode.MISTRAL_TOOLS,
+ mode=instructor.Mode(self.instructor_mode),
api_key=get_llm_config().llm_api_key,
)
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py
index 9c3d185aa..aabd19867 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py
@@ -42,8 +42,16 @@ class OllamaAPIAdapter(LLMInterface):
- aclient
"""
+ default_instructor_mode = "json_mode"
+
def __init__(
- self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int
+ self,
+ endpoint: str,
+ api_key: str,
+ model: str,
+ name: str,
+ max_completion_tokens: int,
+ instructor_mode: str = None,
):
self.name = name
self.model = model
@@ -51,8 +59,11 @@ class OllamaAPIAdapter(LLMInterface):
self.endpoint = endpoint
self.max_completion_tokens = max_completion_tokens
+ self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
self.aclient = instructor.from_openai(
- OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
+ OpenAI(base_url=self.endpoint, api_key=self.api_key),
+ mode=instructor.Mode(self.instructor_mode),
)
@retry(
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
index 305b426b8..778c8eec7 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
@@ -56,6 +56,7 @@ class OpenAIAdapter(LLMInterface):
model: str
api_key: str
api_version: str
+ default_instructor_mode = "json_schema_mode"
MAX_RETRIES = 5
@@ -69,19 +70,21 @@ class OpenAIAdapter(LLMInterface):
model: str,
transcription_model: str,
max_completion_tokens: int,
+ instructor_mode: str = None,
streaming: bool = False,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
):
+ self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
# TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
# Make sure all new gpt models will work with this mode as well.
if "gpt-5" in model:
self.aclient = instructor.from_litellm(
- litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA
+ litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
)
self.client = instructor.from_litellm(
- litellm.completion, mode=instructor.Mode.JSON_SCHEMA
+ litellm.completion, mode=instructor.Mode(self.instructor_mode)
)
else:
self.aclient = instructor.from_litellm(litellm.acompletion)
diff --git a/cognee/infrastructure/loaders/LoaderEngine.py b/cognee/infrastructure/loaders/LoaderEngine.py
index f9511e7c5..4a363a0e6 100644
--- a/cognee/infrastructure/loaders/LoaderEngine.py
+++ b/cognee/infrastructure/loaders/LoaderEngine.py
@@ -31,6 +31,7 @@ class LoaderEngine:
"pypdf_loader",
"image_loader",
"audio_loader",
+ "csv_loader",
"unstructured_loader",
"advanced_pdf_loader",
]
diff --git a/cognee/infrastructure/loaders/core/__init__.py b/cognee/infrastructure/loaders/core/__init__.py
index 8a2df80f9..09819fbd2 100644
--- a/cognee/infrastructure/loaders/core/__init__.py
+++ b/cognee/infrastructure/loaders/core/__init__.py
@@ -3,5 +3,6 @@
from .text_loader import TextLoader
from .audio_loader import AudioLoader
from .image_loader import ImageLoader
+from .csv_loader import CsvLoader
-__all__ = ["TextLoader", "AudioLoader", "ImageLoader"]
+__all__ = ["TextLoader", "AudioLoader", "ImageLoader", "CsvLoader"]
diff --git a/cognee/infrastructure/loaders/core/csv_loader.py b/cognee/infrastructure/loaders/core/csv_loader.py
new file mode 100644
index 000000000..a314a7a24
--- /dev/null
+++ b/cognee/infrastructure/loaders/core/csv_loader.py
@@ -0,0 +1,93 @@
+import os
+from typing import List
+import csv
+from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
+from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
+from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata
+
+
+class CsvLoader(LoaderInterface):
+ """
+ Core CSV file loader that handles basic CSV file formats.
+ """
+
+ @property
+ def supported_extensions(self) -> List[str]:
+        """Supported CSV file extensions."""
+ return [
+ "csv",
+ ]
+
+ @property
+ def supported_mime_types(self) -> List[str]:
+        """Supported MIME types for CSV content."""
+ return [
+ "text/csv",
+ ]
+
+ @property
+ def loader_name(self) -> str:
+ """Unique identifier for this loader."""
+ return "csv_loader"
+
+ def can_handle(self, extension: str, mime_type: str) -> bool:
+ """
+ Check if this loader can handle the given file.
+
+ Args:
+ extension: File extension
+ mime_type: Optional MIME type
+
+ Returns:
+ True if file can be handled, False otherwise
+ """
+ if extension in self.supported_extensions and mime_type in self.supported_mime_types:
+ return True
+
+ return False
+
+ async def load(self, file_path: str, encoding: str = "utf-8", **kwargs):
+ """
+ Load and process the csv file.
+
+ Args:
+ file_path: Path to the file to load
+ encoding: Text encoding to use (default: utf-8)
+ **kwargs: Additional configuration (unused)
+
+ Returns:
+            Full path to the stored text file generated from the CSV rows
+
+ Raises:
+ FileNotFoundError: If file doesn't exist
+ UnicodeDecodeError: If file cannot be decoded with specified encoding
+ OSError: If file cannot be read
+ """
+ if not os.path.exists(file_path):
+ raise FileNotFoundError(f"File not found: {file_path}")
+
+ with open(file_path, "rb") as f:
+ file_metadata = await get_file_metadata(f)
+ # Name ingested file of current loader based on original file content hash
+ storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"
+
+ row_texts = []
+ row_index = 1
+
+ with open(file_path, "r", encoding=encoding, newline="") as file:
+ reader = csv.DictReader(file)
+ for row in reader:
+ pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()]
+ row_text = ", ".join(pairs)
+ row_texts.append(f"Row {row_index}:\n{row_text}\n")
+ row_index += 1
+
+ content = "\n".join(row_texts)
+
+ storage_config = get_storage_config()
+ data_root_directory = storage_config["data_root_directory"]
+ storage = get_file_storage(data_root_directory)
+
+ full_file_path = await storage.store(storage_file_name, content)
+
+ return full_file_path
diff --git a/cognee/infrastructure/loaders/core/text_loader.py b/cognee/infrastructure/loaders/core/text_loader.py
index a6f94be9b..e478edb22 100644
--- a/cognee/infrastructure/loaders/core/text_loader.py
+++ b/cognee/infrastructure/loaders/core/text_loader.py
@@ -16,7 +16,7 @@ class TextLoader(LoaderInterface):
@property
def supported_extensions(self) -> List[str]:
"""Supported text file extensions."""
- return ["txt", "md", "csv", "json", "xml", "yaml", "yml", "log"]
+ return ["txt", "md", "json", "xml", "yaml", "yml", "log"]
@property
def supported_mime_types(self) -> List[str]:
@@ -24,7 +24,6 @@ class TextLoader(LoaderInterface):
return [
"text/plain",
"text/markdown",
- "text/csv",
"application/json",
"text/xml",
"application/xml",
diff --git a/cognee/infrastructure/loaders/external/advanced_pdf_loader.py b/cognee/infrastructure/loaders/external/advanced_pdf_loader.py
index 6d1412b77..4b3ba296a 100644
--- a/cognee/infrastructure/loaders/external/advanced_pdf_loader.py
+++ b/cognee/infrastructure/loaders/external/advanced_pdf_loader.py
@@ -227,12 +227,3 @@ class AdvancedPdfLoader(LoaderInterface):
if value is None:
return ""
return str(value).replace("\xa0", " ").strip()
-
-
-if __name__ == "__main__":
- loader = AdvancedPdfLoader()
- asyncio.run(
- loader.load(
- "/Users/xiaotao/work/cognee/cognee/infrastructure/loaders/external/attention_is_all_you_need.pdf"
- )
- )
diff --git a/cognee/infrastructure/loaders/supported_loaders.py b/cognee/infrastructure/loaders/supported_loaders.py
index 156253b53..2b8c3e0b4 100644
--- a/cognee/infrastructure/loaders/supported_loaders.py
+++ b/cognee/infrastructure/loaders/supported_loaders.py
@@ -1,5 +1,5 @@
from cognee.infrastructure.loaders.external import PyPdfLoader
-from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader
+from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader, CsvLoader
# Registry for loader implementations
supported_loaders = {
@@ -7,6 +7,7 @@ supported_loaders = {
TextLoader.loader_name: TextLoader,
ImageLoader.loader_name: ImageLoader,
AudioLoader.loader_name: AudioLoader,
+ CsvLoader.loader_name: CsvLoader,
}
# Try adding optional loaders
diff --git a/cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py b/cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py
new file mode 100644
index 000000000..92d64c156
--- /dev/null
+++ b/cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py
@@ -0,0 +1,55 @@
+from typing import Optional, List
+
+from cognee import memify
+from cognee.context_global_variables import (
+ set_database_global_context_variables,
+ set_session_user_context_variable,
+)
+from cognee.exceptions import CogneeValidationError
+from cognee.modules.data.methods import get_authorized_existing_datasets
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.pipelines.tasks.task import Task
+from cognee.modules.users.models import User
+from cognee.tasks.memify import extract_user_sessions, cognify_session
+
+
+logger = get_logger("persist_sessions_in_knowledge_graph")
+
+
+async def persist_sessions_in_knowledge_graph_pipeline(
+ user: User,
+ session_ids: Optional[List[str]] = None,
+ dataset: str = "main_dataset",
+ run_in_background: bool = False,
+):
+ await set_session_user_context_variable(user)
+ dataset_to_write = await get_authorized_existing_datasets(
+ user=user, datasets=[dataset], permission_type="write"
+ )
+
+ if not dataset_to_write:
+ raise CogneeValidationError(
+ message=f"User (id: {str(user.id)}) does not have write access to dataset: {dataset}",
+ log=False,
+ )
+
+ await set_database_global_context_variables(
+ dataset_to_write[0].id, dataset_to_write[0].owner_id
+ )
+
+ extraction_tasks = [Task(extract_user_sessions, session_ids=session_ids)]
+
+ enrichment_tasks = [
+ Task(cognify_session, dataset_id=dataset_to_write[0].id),
+ ]
+
+ result = await memify(
+ extraction_tasks=extraction_tasks,
+ enrichment_tasks=enrichment_tasks,
+ dataset=dataset_to_write[0].id,
+ data=[{}],
+ run_in_background=run_in_background,
+ )
+
+ logger.info("Session persistence pipeline completed")
+ return result
diff --git a/cognee/modules/chunking/CsvChunker.py b/cognee/modules/chunking/CsvChunker.py
new file mode 100644
index 000000000..4ba4a969e
--- /dev/null
+++ b/cognee/modules/chunking/CsvChunker.py
@@ -0,0 +1,35 @@
+from cognee.shared.logging_utils import get_logger
+
+
+from cognee.tasks.chunks import chunk_by_row
+from cognee.modules.chunking.Chunker import Chunker
+from .models.DocumentChunk import DocumentChunk
+
+logger = get_logger()
+
+
+class CsvChunker(Chunker):
+ async def read(self):
+ async for content_text in self.get_text():
+ if content_text is None:
+ continue
+
+ for chunk_data in chunk_by_row(content_text, self.max_chunk_size):
+ if chunk_data["chunk_size"] <= self.max_chunk_size:
+ yield DocumentChunk(
+ id=chunk_data["chunk_id"],
+ text=chunk_data["text"],
+ chunk_size=chunk_data["chunk_size"],
+ is_part_of=self.document,
+ chunk_index=self.chunk_index,
+ cut_type=chunk_data["cut_type"],
+ contains=[],
+ metadata={
+ "index_fields": ["text"],
+ },
+ )
+ self.chunk_index += 1
+ else:
+ raise ValueError(
+ f"Chunk size is larger than the maximum chunk size {self.max_chunk_size}"
+ )
diff --git a/cognee/modules/chunking/models/DocumentChunk.py b/cognee/modules/chunking/models/DocumentChunk.py
index 9f8c57486..e024bf00b 100644
--- a/cognee/modules/chunking/models/DocumentChunk.py
+++ b/cognee/modules/chunking/models/DocumentChunk.py
@@ -1,6 +1,7 @@
from typing import List, Union
from cognee.infrastructure.engine import DataPoint
+from cognee.infrastructure.engine.models.Edge import Edge
from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity
from cognee.tasks.temporal_graph.models import Event
@@ -31,6 +32,6 @@ class DocumentChunk(DataPoint):
chunk_index: int
cut_type: str
is_part_of: Document
- contains: List[Union[Entity, Event]] = None
+ contains: List[Union[Entity, Event, tuple[Edge, Entity]]] = None
metadata: dict = {"index_fields": ["text"]}
diff --git a/cognee/modules/chunking/text_chunker_with_overlap.py b/cognee/modules/chunking/text_chunker_with_overlap.py
new file mode 100644
index 000000000..4b9c23079
--- /dev/null
+++ b/cognee/modules/chunking/text_chunker_with_overlap.py
@@ -0,0 +1,124 @@
+from cognee.shared.logging_utils import get_logger
+from uuid import NAMESPACE_OID, uuid5
+
+from cognee.tasks.chunks import chunk_by_paragraph
+from cognee.modules.chunking.Chunker import Chunker
+from .models.DocumentChunk import DocumentChunk
+
+logger = get_logger()
+
+
+class TextChunkerWithOverlap(Chunker):
+ def __init__(
+ self,
+ document,
+ get_text: callable,
+ max_chunk_size: int,
+ chunk_overlap_ratio: float = 0.0,
+ get_chunk_data: callable = None,
+ ):
+ super().__init__(document, get_text, max_chunk_size)
+ self._accumulated_chunk_data = []
+ self._accumulated_size = 0
+ self.chunk_overlap_ratio = chunk_overlap_ratio
+ self.chunk_overlap = int(max_chunk_size * chunk_overlap_ratio)
+
+ if get_chunk_data is not None:
+ self.get_chunk_data = get_chunk_data
+ elif chunk_overlap_ratio > 0:
+ paragraph_max_size = int(0.5 * chunk_overlap_ratio * max_chunk_size)
+ self.get_chunk_data = lambda text: chunk_by_paragraph(
+ text, paragraph_max_size, batch_paragraphs=True
+ )
+ else:
+ self.get_chunk_data = lambda text: chunk_by_paragraph(
+ text, self.max_chunk_size, batch_paragraphs=True
+ )
+
+ def _accumulation_overflows(self, chunk_data):
+ """Check if adding chunk_data would exceed max_chunk_size."""
+ return self._accumulated_size + chunk_data["chunk_size"] > self.max_chunk_size
+
+ def _accumulate_chunk_data(self, chunk_data):
+ """Add chunk_data to the current accumulation."""
+ self._accumulated_chunk_data.append(chunk_data)
+ self._accumulated_size += chunk_data["chunk_size"]
+
+ def _clear_accumulation(self):
+ """Reset accumulation, keeping overlap chunk_data based on chunk_overlap_ratio."""
+ if self.chunk_overlap == 0:
+ self._accumulated_chunk_data = []
+ self._accumulated_size = 0
+ return
+
+ # Keep chunk_data from the end that fit in overlap
+ overlap_chunk_data = []
+ overlap_size = 0
+
+ for chunk_data in reversed(self._accumulated_chunk_data):
+ if overlap_size + chunk_data["chunk_size"] <= self.chunk_overlap:
+ overlap_chunk_data.insert(0, chunk_data)
+ overlap_size += chunk_data["chunk_size"]
+ else:
+ break
+
+ self._accumulated_chunk_data = overlap_chunk_data
+ self._accumulated_size = overlap_size
+
+ def _create_chunk(self, text, size, cut_type, chunk_id=None):
+ """Create a DocumentChunk with standard metadata."""
+ try:
+ return DocumentChunk(
+ id=chunk_id or uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
+ text=text,
+ chunk_size=size,
+ is_part_of=self.document,
+ chunk_index=self.chunk_index,
+ cut_type=cut_type,
+ contains=[],
+ metadata={"index_fields": ["text"]},
+ )
+ except Exception as e:
+ logger.error(e)
+ raise e
+
+ def _create_chunk_from_accumulation(self):
+ """Create a DocumentChunk from current accumulated chunk_data."""
+ chunk_text = " ".join(chunk["text"] for chunk in self._accumulated_chunk_data)
+ return self._create_chunk(
+ text=chunk_text,
+ size=self._accumulated_size,
+ cut_type=self._accumulated_chunk_data[-1]["cut_type"],
+ )
+
+ def _emit_chunk(self, chunk_data):
+ """Emit a chunk when accumulation overflows."""
+ if len(self._accumulated_chunk_data) > 0:
+ chunk = self._create_chunk_from_accumulation()
+ self._clear_accumulation()
+ self._accumulate_chunk_data(chunk_data)
+ else:
+ # Handle single chunk_data exceeding max_chunk_size
+ chunk = self._create_chunk(
+ text=chunk_data["text"],
+ size=chunk_data["chunk_size"],
+ cut_type=chunk_data["cut_type"],
+ chunk_id=chunk_data["chunk_id"],
+ )
+
+ self.chunk_index += 1
+ return chunk
+
+ async def read(self):
+ async for content_text in self.get_text():
+ for chunk_data in self.get_chunk_data(content_text):
+ if not self._accumulation_overflows(chunk_data):
+ self._accumulate_chunk_data(chunk_data)
+ continue
+
+ yield self._emit_chunk(chunk_data)
+
+ if len(self._accumulated_chunk_data) == 0:
+ return
+
+ yield self._create_chunk_from_accumulation()
diff --git a/cognee/modules/data/methods/__init__.py b/cognee/modules/data/methods/__init__.py
index 83913085c..7936a9afd 100644
--- a/cognee/modules/data/methods/__init__.py
+++ b/cognee/modules/data/methods/__init__.py
@@ -10,6 +10,7 @@ from .get_authorized_dataset import get_authorized_dataset
from .get_authorized_dataset_by_name import get_authorized_dataset_by_name
from .get_data import get_data
from .get_unique_dataset_id import get_unique_dataset_id
+from .get_unique_data_id import get_unique_data_id
from .get_authorized_existing_datasets import get_authorized_existing_datasets
from .get_dataset_ids import get_dataset_ids
diff --git a/cognee/modules/data/methods/create_dataset.py b/cognee/modules/data/methods/create_dataset.py
index c080de0e8..7e28a8255 100644
--- a/cognee/modules/data/methods/create_dataset.py
+++ b/cognee/modules/data/methods/create_dataset.py
@@ -16,14 +16,16 @@ async def create_dataset(dataset_name: str, user: User, session: AsyncSession) -
.options(joinedload(Dataset.data))
.filter(Dataset.name == dataset_name)
.filter(Dataset.owner_id == owner_id)
+ .filter(Dataset.tenant_id == user.tenant_id)
)
).first()
if dataset is None:
# Dataset id should be generated based on dataset_name and owner_id/user so multiple users can use the same dataset_name
dataset_id = await get_unique_dataset_id(dataset_name=dataset_name, user=user)
- dataset = Dataset(id=dataset_id, name=dataset_name, data=[])
- dataset.owner_id = owner_id
+ dataset = Dataset(
+ id=dataset_id, name=dataset_name, data=[], owner_id=owner_id, tenant_id=user.tenant_id
+ )
session.add(dataset)
diff --git a/cognee/modules/data/methods/get_dataset_ids.py b/cognee/modules/data/methods/get_dataset_ids.py
index d4402ff36..a61e85310 100644
--- a/cognee/modules/data/methods/get_dataset_ids.py
+++ b/cognee/modules/data/methods/get_dataset_ids.py
@@ -27,7 +27,11 @@ async def get_dataset_ids(datasets: Union[list[str], list[UUID]], user):
# Get all user owned dataset objects (If a user wants to write to a dataset he is not the owner of it must be provided through UUID.)
user_datasets = await get_datasets(user.id)
# Filter out non name mentioned datasets
- dataset_ids = [dataset.id for dataset in user_datasets if dataset.name in datasets]
+ dataset_ids = [dataset for dataset in user_datasets if dataset.name in datasets]
+ # Filter out non current tenant datasets
+ dataset_ids = [
+ dataset.id for dataset in dataset_ids if dataset.tenant_id == user.tenant_id
+ ]
else:
raise DatasetTypeError(
f"One or more of the provided dataset types is not handled: f{datasets}"
diff --git a/cognee/modules/data/methods/get_unique_data_id.py b/cognee/modules/data/methods/get_unique_data_id.py
new file mode 100644
index 000000000..877b5930c
--- /dev/null
+++ b/cognee/modules/data/methods/get_unique_data_id.py
@@ -0,0 +1,68 @@
+from uuid import uuid5, NAMESPACE_OID, UUID
+from sqlalchemy import select
+
+from cognee.modules.data.models.Data import Data
+from cognee.infrastructure.databases.relational import get_relational_engine
+from cognee.modules.users.models import User
+
+
+async def get_unique_data_id(data_identifier: str, user: User) -> UUID:
+ """
+ Function returns a unique UUID for data based on data identifier, user id and tenant id.
+ If data with legacy ID exists, return that ID to maintain compatibility.
+
+ Args:
+ data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.)
+ user: User object adding the data
+        Note: the tenant is derived from user.tenant_id; it is not a parameter.
+
+ Returns:
+ UUID: Unique identifier for the data
+ """
+
+ def _get_deprecated_unique_data_id(data_identifier: str, user: User) -> UUID:
+ """
+ Deprecated function, returns a unique UUID for data based on data identifier and user id.
+ Needed to support legacy data without tenant information.
+ Args:
+ data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.)
+ user: User object adding the data
+
+ Returns:
+ UUID: Unique identifier for the data
+ """
+        # return UUID hash of data identifier + owner id (legacy: no tenant id)
+ return uuid5(NAMESPACE_OID, f"{data_identifier}{str(user.id)}")
+
+ def _get_modern_unique_data_id(data_identifier: str, user: User) -> UUID:
+ """
+ Function returns a unique UUID for data based on data identifier, user id and tenant id.
+ Args:
+ data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.)
+ user: User object adding the data
+            Note: the tenant is derived from user.tenant_id; it is not a parameter.
+
+ Returns:
+ UUID: Unique identifier for the data
+ """
+        # return UUID hash of data identifier + owner id + tenant id
+ return uuid5(NAMESPACE_OID, f"{data_identifier}{str(user.id)}{str(user.tenant_id)}")
+
+ # Get all possible data_id values
+ data_id = {
+ "modern_data_id": _get_modern_unique_data_id(data_identifier=data_identifier, user=user),
+ "legacy_data_id": _get_deprecated_unique_data_id(
+ data_identifier=data_identifier, user=user
+ ),
+ }
+
+ # Check if data item with legacy_data_id exists, if so use that one, else use modern_data_id
+ db_engine = get_relational_engine()
+ async with db_engine.get_async_session() as session:
+ legacy_data_point = (
+ await session.execute(select(Data).filter(Data.id == data_id["legacy_data_id"]))
+ ).scalar_one_or_none()
+
+ if not legacy_data_point:
+ return data_id["modern_data_id"]
+ return data_id["legacy_data_id"]
diff --git a/cognee/modules/data/methods/get_unique_dataset_id.py b/cognee/modules/data/methods/get_unique_dataset_id.py
index 2caf5fb55..2b765ec78 100644
--- a/cognee/modules/data/methods/get_unique_dataset_id.py
+++ b/cognee/modules/data/methods/get_unique_dataset_id.py
@@ -1,9 +1,71 @@
from uuid import UUID, uuid5, NAMESPACE_OID
-from cognee.modules.users.models import User
from typing import Union
+from sqlalchemy import select
+
+from cognee.modules.data.models.Dataset import Dataset
+from cognee.modules.users.models import User
+from cognee.infrastructure.databases.relational import get_relational_engine
async def get_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID:
- if isinstance(dataset_name, UUID):
- return dataset_name
- return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}")
+ """
+ Function returns a unique UUID for dataset based on dataset name, user id and tenant id.
+ If dataset with legacy ID exists, return that ID to maintain compatibility.
+
+ Args:
+ dataset_name: string representing the dataset name
+ user: User object adding the dataset
+        Note: the tenant is derived from user.tenant_id; it is not a parameter.
+
+ Returns:
+ UUID: Unique identifier for the dataset
+ """
+
+ def _get_legacy_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID:
+ """
+ Legacy function, returns a unique UUID for dataset based on dataset name and user id.
+ Needed to support legacy datasets without tenant information.
+ Args:
+ dataset_name: string representing the dataset name
+ user: Current User object adding the dataset
+
+ Returns:
+ UUID: Unique identifier for the dataset
+ """
+ if isinstance(dataset_name, UUID):
+ return dataset_name
+ return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}")
+
+ def _get_modern_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID:
+ """
+ Returns a unique UUID for dataset based on dataset name, user id and tenant_id.
+ Args:
+ dataset_name: string representing the dataset name
+ user: Current User object adding the dataset
+            Note: the tenant is derived from user.tenant_id; it is not a parameter.
+
+ Returns:
+ UUID: Unique identifier for the dataset
+ """
+ if isinstance(dataset_name, UUID):
+ return dataset_name
+ return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}{str(user.tenant_id)}")
+
+ # Get all possible dataset_id values
+ dataset_id = {
+ "modern_dataset_id": _get_modern_unique_dataset_id(dataset_name=dataset_name, user=user),
+ "legacy_dataset_id": _get_legacy_unique_dataset_id(dataset_name=dataset_name, user=user),
+ }
+
+ # Check if dataset with legacy_dataset_id exists, if so use that one, else use modern_dataset_id
+ db_engine = get_relational_engine()
+ async with db_engine.get_async_session() as session:
+ legacy_dataset = (
+ await session.execute(
+ select(Dataset).filter(Dataset.id == dataset_id["legacy_dataset_id"])
+ )
+ ).scalar_one_or_none()
+
+ if not legacy_dataset:
+ return dataset_id["modern_dataset_id"]
+ return dataset_id["legacy_dataset_id"]
diff --git a/cognee/modules/data/models/Dataset.py b/cognee/modules/data/models/Dataset.py
index 797401d5a..fba065253 100644
--- a/cognee/modules/data/models/Dataset.py
+++ b/cognee/modules/data/models/Dataset.py
@@ -18,6 +18,7 @@ class Dataset(Base):
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
owner_id = Column(UUID, index=True)
+ tenant_id = Column(UUID, index=True, nullable=True)
acls = relationship("ACL", back_populates="dataset", cascade="all, delete-orphan")
@@ -36,5 +37,6 @@ class Dataset(Base):
"createdAt": self.created_at.isoformat(),
"updatedAt": self.updated_at.isoformat() if self.updated_at else None,
"ownerId": str(self.owner_id),
+ "tenantId": str(self.tenant_id),
"data": [data.to_json() for data in self.data],
}
diff --git a/cognee/modules/data/processing/document_types/CsvDocument.py b/cognee/modules/data/processing/document_types/CsvDocument.py
new file mode 100644
index 000000000..3381275bd
--- /dev/null
+++ b/cognee/modules/data/processing/document_types/CsvDocument.py
@@ -0,0 +1,33 @@
+import io
+import csv
+from typing import Type
+
+from cognee.modules.chunking.Chunker import Chunker
+from cognee.infrastructure.files.utils.open_data_file import open_data_file
+from .Document import Document
+
+
+class CsvDocument(Document):
+ type: str = "csv"
+ mime_type: str = "text/csv"
+
+ async def read(self, chunker_cls: Type[Chunker], max_chunk_size: int):
+ async def get_text():
+ async with open_data_file(
+ self.raw_data_location, mode="r", encoding="utf-8", newline=""
+ ) as file:
+ content = file.read()
+ file_like_obj = io.StringIO(content)
+ reader = csv.DictReader(file_like_obj)
+
+ for row in reader:
+ pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()]
+ row_text = ", ".join(pairs)
+ if not row_text.strip():
+ break
+ yield row_text
+
+ chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=get_text)
+
+ async for chunk in chunker.read():
+ yield chunk
diff --git a/cognee/modules/data/processing/document_types/__init__.py b/cognee/modules/data/processing/document_types/__init__.py
index 2e862f4ba..133dd53f8 100644
--- a/cognee/modules/data/processing/document_types/__init__.py
+++ b/cognee/modules/data/processing/document_types/__init__.py
@@ -4,3 +4,4 @@ from .TextDocument import TextDocument
from .ImageDocument import ImageDocument
from .AudioDocument import AudioDocument
from .UnstructuredDocument import UnstructuredDocument
+from .CsvDocument import CsvDocument
diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py
index 9703928f0..2e0b82e8d 100644
--- a/cognee/modules/graph/cognee_graph/CogneeGraph.py
+++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py
@@ -56,6 +56,68 @@ class CogneeGraph(CogneeAbstractGraph):
def get_edges(self) -> List[Edge]:
return self.edges
+ async def _get_nodeset_subgraph(
+ self,
+ adapter,
+ node_type,
+ node_name,
+ ):
+ """Retrieve subgraph based on node type and name."""
+ logger.info("Retrieving graph filtered by node type and node name (NodeSet).")
+ nodes_data, edges_data = await adapter.get_nodeset_subgraph(
+ node_type=node_type, node_name=node_name
+ )
+ if not nodes_data or not edges_data:
+ raise EntityNotFoundError(
+ message="Nodeset does not exist, or empty nodeset projected from the database."
+ )
+ return nodes_data, edges_data
+
+ async def _get_full_or_id_filtered_graph(
+ self,
+ adapter,
+ relevant_ids_to_filter,
+ ):
+ """Retrieve full or ID-filtered graph with fallback."""
+ if relevant_ids_to_filter is None:
+ logger.info("Retrieving full graph.")
+ nodes_data, edges_data = await adapter.get_graph_data()
+ if not nodes_data or not edges_data:
+ raise EntityNotFoundError(message="Empty graph projected from the database.")
+ return nodes_data, edges_data
+
+ get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data)
+ if getattr(adapter.__class__, "get_id_filtered_graph_data", None):
+ logger.info("Retrieving ID-filtered graph from database.")
+ nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter)
+ else:
+ logger.info("Retrieving full graph from database.")
+ nodes_data, edges_data = await get_graph_data_fn()
+ if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data):
+ logger.warning(
+ "Id filtered graph returned empty, falling back to full graph retrieval."
+ )
+ logger.info("Retrieving full graph")
+ nodes_data, edges_data = await adapter.get_graph_data()
+
+ if not nodes_data or not edges_data:
+ raise EntityNotFoundError("Empty graph projected from the database.")
+ return nodes_data, edges_data
+
+ async def _get_filtered_graph(
+ self,
+ adapter,
+ memory_fragment_filter,
+ ):
+ """Retrieve graph filtered by attributes."""
+ logger.info("Retrieving graph filtered by memory fragment")
+ nodes_data, edges_data = await adapter.get_filtered_graph_data(
+ attribute_filters=memory_fragment_filter
+ )
+ if not nodes_data or not edges_data:
+ raise EntityNotFoundError(message="Empty filtered graph projected from the database.")
+ return nodes_data, edges_data
+
async def project_graph_from_db(
self,
adapter: Union[GraphDBInterface],
@@ -67,40 +129,39 @@ class CogneeGraph(CogneeAbstractGraph):
memory_fragment_filter=[],
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
+ relevant_ids_to_filter: Optional[List[str]] = None,
+ triplet_distance_penalty: float = 3.5,
) -> None:
if node_dimension < 1 or edge_dimension < 1:
raise InvalidDimensionsError()
try:
+ if node_type is not None and node_name not in [None, [], ""]:
+ nodes_data, edges_data = await self._get_nodeset_subgraph(
+ adapter, node_type, node_name
+ )
+ elif len(memory_fragment_filter) == 0:
+ nodes_data, edges_data = await self._get_full_or_id_filtered_graph(
+ adapter, relevant_ids_to_filter
+ )
+ else:
+ nodes_data, edges_data = await self._get_filtered_graph(
+ adapter, memory_fragment_filter
+ )
+
import time
start_time = time.time()
-
- # Determine projection strategy
- if node_type is not None and node_name not in [None, [], ""]:
- nodes_data, edges_data = await adapter.get_nodeset_subgraph(
- node_type=node_type, node_name=node_name
- )
- if not nodes_data or not edges_data:
- raise EntityNotFoundError(
- message="Nodeset does not exist, or empty nodetes projected from the database."
- )
- elif len(memory_fragment_filter) == 0:
- nodes_data, edges_data = await adapter.get_graph_data()
- if not nodes_data or not edges_data:
- raise EntityNotFoundError(message="Empty graph projected from the database.")
- else:
- nodes_data, edges_data = await adapter.get_filtered_graph_data(
- attribute_filters=memory_fragment_filter
- )
- if not nodes_data or not edges_data:
- raise EntityNotFoundError(
- message="Empty filtered graph projected from the database."
- )
-
# Process nodes
for node_id, properties in nodes_data:
node_attributes = {key: properties.get(key) for key in node_properties_to_project}
- self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))
+ self.add_node(
+ Node(
+ str(node_id),
+ node_attributes,
+ dimension=node_dimension,
+ node_penalty=triplet_distance_penalty,
+ )
+ )
# Process edges
for source_id, target_id, relationship_type, properties in edges_data:
@@ -118,6 +179,7 @@ class CogneeGraph(CogneeAbstractGraph):
attributes=edge_attributes,
directed=directed,
dimension=edge_dimension,
+ edge_penalty=triplet_distance_penalty,
)
self.add_edge(edge)
@@ -171,8 +233,10 @@ class CogneeGraph(CogneeAbstractGraph):
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
for edge in self.edges:
- relationship_type = edge.attributes.get("relationship_type")
- distance = embedding_map.get(relationship_type, None)
+ edge_key = edge.attributes.get("edge_text") or edge.attributes.get(
+ "relationship_type"
+ )
+ distance = embedding_map.get(edge_key, None)
if distance is not None:
edge.attributes["vector_distance"] = distance
diff --git a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py
index 0ca9c4fb9..62ef8d9fd 100644
--- a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py
+++ b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py
@@ -20,13 +20,17 @@ class Node:
status: np.ndarray
def __init__(
- self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1
+ self,
+ node_id: str,
+ attributes: Optional[Dict[str, Any]] = None,
+ dimension: int = 1,
+ node_penalty: float = 3.5,
):
if dimension <= 0:
raise InvalidDimensionsError()
self.id = node_id
self.attributes = attributes if attributes is not None else {}
- self.attributes["vector_distance"] = float("inf")
+ self.attributes["vector_distance"] = node_penalty
self.skeleton_neighbours = []
self.skeleton_edges = []
self.status = np.ones(dimension, dtype=int)
@@ -105,13 +109,14 @@ class Edge:
attributes: Optional[Dict[str, Any]] = None,
directed: bool = True,
dimension: int = 1,
+ edge_penalty: float = 3.5,
):
if dimension <= 0:
raise InvalidDimensionsError()
self.node1 = node1
self.node2 = node2
self.attributes = attributes if attributes is not None else {}
- self.attributes["vector_distance"] = float("inf")
+ self.attributes["vector_distance"] = edge_penalty
self.directed = directed
self.status = np.ones(dimension, dtype=int)
diff --git a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py
index 3b01f5af4..c68eb494d 100644
--- a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py
+++ b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py
@@ -1,5 +1,6 @@
from typing import Optional
+from cognee.infrastructure.engine.models.Edge import Edge
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.engine.utils import (
@@ -243,10 +244,26 @@ def _process_graph_nodes(
ontology_relationships,
)
- # Add entity to data chunk
if data_chunk.contains is None:
data_chunk.contains = []
- data_chunk.contains.append(entity_node)
+
+ edge_text = "; ".join(
+ [
+ "relationship_name: contains",
+ f"entity_name: {entity_node.name}",
+ f"entity_description: {entity_node.description}",
+ ]
+ )
+
+ data_chunk.contains.append(
+ (
+ Edge(
+ relationship_type="contains",
+ edge_text=edge_text,
+ ),
+ entity_node,
+ )
+ )
def _process_graph_edges(
diff --git a/cognee/modules/graph/utils/resolve_edges_to_text.py b/cognee/modules/graph/utils/resolve_edges_to_text.py
index eb5bedd2c..5deb13ba8 100644
--- a/cognee/modules/graph/utils/resolve_edges_to_text.py
+++ b/cognee/modules/graph/utils/resolve_edges_to_text.py
@@ -1,71 +1,70 @@
+import string
from typing import List
+from collections import Counter
+
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
+from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
+
+
+def _get_top_n_frequent_words(
+ text: str, stop_words: set = None, top_n: int = 3, separator: str = ", "
+) -> str:
+ """Concatenates the top N frequent words in text."""
+ if stop_words is None:
+ stop_words = DEFAULT_STOP_WORDS
+
+ words = [word.lower().strip(string.punctuation) for word in text.split()]
+ words = [word for word in words if word and word not in stop_words]
+
+ top_words = [word for word, freq in Counter(words).most_common(top_n)]
+ return separator.join(top_words)
+
+
+def _create_title_from_text(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
+ """Creates a title by combining first words with most frequent words from the text."""
+ first_words = text.split()[:first_n_words]
+ top_words = _get_top_n_frequent_words(text, top_n=top_n_words)
+ return f"{' '.join(first_words)}... [{top_words}]"
+
+
+def _extract_nodes_from_edges(retrieved_edges: List[Edge]) -> dict:
+ """Creates a dictionary of nodes with their names and content."""
+ nodes = {}
+
+ for edge in retrieved_edges:
+ for node in (edge.node1, edge.node2):
+ if node.id in nodes:
+ continue
+
+ text = node.attributes.get("text")
+ if text:
+ name = _create_title_from_text(text)
+ content = text
+ else:
+ name = node.attributes.get("name", "Unnamed Node")
+ content = node.attributes.get("description", name)
+
+ nodes[node.id] = {"node": node, "name": name, "content": content}
+
+ return nodes
async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str:
- """
- Converts retrieved graph edges into a human-readable string format.
+ """Converts retrieved graph edges into a human-readable string format."""
+ nodes = _extract_nodes_from_edges(retrieved_edges)
- Parameters:
- -----------
-
- - retrieved_edges (list): A list of edges retrieved from the graph.
-
- Returns:
- --------
-
- - str: A formatted string representation of the nodes and their connections.
- """
-
- def _get_nodes(retrieved_edges: List[Edge]) -> dict:
- def _get_title(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
- def _top_n_words(text, stop_words=None, top_n=3, separator=", "):
- """Concatenates the top N frequent words in text."""
- if stop_words is None:
- from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
-
- stop_words = DEFAULT_STOP_WORDS
-
- import string
-
- words = [word.lower().strip(string.punctuation) for word in text.split()]
-
- if stop_words:
- words = [word for word in words if word and word not in stop_words]
-
- from collections import Counter
-
- top_words = [word for word, freq in Counter(words).most_common(top_n)]
-
- return separator.join(top_words)
-
- """Creates a title, by combining first words with most frequent words from the text."""
- first_words = text.split()[:first_n_words]
- top_words = _top_n_words(text, top_n=first_n_words)
- return f"{' '.join(first_words)}... [{top_words}]"
-
- """Creates a dictionary of nodes with their names and content."""
- nodes = {}
- for edge in retrieved_edges:
- for node in (edge.node1, edge.node2):
- if node.id not in nodes:
- text = node.attributes.get("text")
- if text:
- name = _get_title(text)
- content = text
- else:
- name = node.attributes.get("name", "Unnamed Node")
- content = node.attributes.get("description", name)
- nodes[node.id] = {"node": node, "name": name, "content": content}
- return nodes
-
- nodes = _get_nodes(retrieved_edges)
node_section = "\n".join(
f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
for info in nodes.values()
)
- connection_section = "\n".join(
- f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
- for edge in retrieved_edges
- )
+
+ connections = []
+ for edge in retrieved_edges:
+ source_name = nodes[edge.node1.id]["name"]
+ target_name = nodes[edge.node2.id]["name"]
+ edge_label = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
+ connections.append(f"{source_name} --[{edge_label}]--> {target_name}")
+
+ connection_section = "\n".join(connections)
+
return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
diff --git a/cognee/modules/ingestion/identify.py b/cognee/modules/ingestion/identify.py
index 977ff3f0b..640fce4a2 100644
--- a/cognee/modules/ingestion/identify.py
+++ b/cognee/modules/ingestion/identify.py
@@ -1,11 +1,11 @@
-from uuid import uuid5, NAMESPACE_OID
+from uuid import UUID
from .data_types import IngestionData
from cognee.modules.users.models import User
+from cognee.modules.data.methods import get_unique_data_id
-def identify(data: IngestionData, user: User) -> str:
+async def identify(data: IngestionData, user: User) -> UUID:
data_content_hash: str = data.get_identifier()
- # return UUID hash of file contents + owner id
- return uuid5(NAMESPACE_OID, f"{data_content_hash}{user.id}")
+ return await get_unique_data_id(data_identifier=data_content_hash, user=user)
diff --git a/cognee/modules/notebooks/operations/run_in_local_sandbox.py b/cognee/modules/notebooks/operations/run_in_local_sandbox.py
index 071deafb7..46499186e 100644
--- a/cognee/modules/notebooks/operations/run_in_local_sandbox.py
+++ b/cognee/modules/notebooks/operations/run_in_local_sandbox.py
@@ -2,6 +2,8 @@ import io
import sys
import traceback
+import cognee
+
def wrap_in_async_handler(user_code: str) -> str:
return (
@@ -34,6 +36,7 @@ def run_in_local_sandbox(code, environment=None, loop=None):
environment["print"] = customPrintFunction
environment["running_loop"] = loop
+ environment["cognee"] = cognee
try:
exec(code, environment)
diff --git a/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py b/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py
index 45e32936a..34d7a946a 100644
--- a/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py
+++ b/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py
@@ -2,7 +2,7 @@ import os
import difflib
from cognee.shared.logging_utils import get_logger
from collections import deque
-from typing import List, Tuple, Dict, Optional, Any, Union
+from typing import List, Tuple, Dict, Optional, Any, Union, IO
from rdflib import Graph, URIRef, RDF, RDFS, OWL
from cognee.modules.ontology.exceptions import (
@@ -26,44 +26,76 @@ class RDFLibOntologyResolver(BaseOntologyResolver):
def __init__(
self,
- ontology_file: Optional[Union[str, List[str]]] = None,
+ ontology_file: Optional[Union[str, List[str], IO, List[IO]]] = None,
matching_strategy: Optional[MatchingStrategy] = None,
) -> None:
super().__init__(matching_strategy)
self.ontology_file = ontology_file
try:
- files_to_load = []
+ self.graph = None
if ontology_file is not None:
- if isinstance(ontology_file, str):
+ files_to_load = []
+ file_objects = []
+
+ if hasattr(ontology_file, "read"):
+ file_objects = [ontology_file]
+ elif isinstance(ontology_file, str):
files_to_load = [ontology_file]
elif isinstance(ontology_file, list):
- files_to_load = ontology_file
+ if all(hasattr(item, "read") for item in ontology_file):
+ file_objects = ontology_file
+ else:
+ files_to_load = ontology_file
else:
raise ValueError(
- f"ontology_file must be a string, list of strings, or None. Got: {type(ontology_file)}"
+ f"ontology_file must be a string, list of strings, file-like object, list of file-like objects, or None. Got: {type(ontology_file)}"
)
- if files_to_load:
- self.graph = Graph()
- loaded_files = []
- for file_path in files_to_load:
- if os.path.exists(file_path):
- self.graph.parse(file_path)
- loaded_files.append(file_path)
- logger.info("Ontology loaded successfully from file: %s", file_path)
- else:
- logger.warning(
- "Ontology file '%s' not found. Skipping this file.",
- file_path,
+ if file_objects:
+ self.graph = Graph()
+ loaded_objects = []
+ for file_obj in file_objects:
+ try:
+ content = file_obj.read()
+ self.graph.parse(data=content, format="xml")
+ loaded_objects.append(file_obj)
+ logger.info("Ontology loaded successfully from file object")
+ except Exception as e:
+ logger.warning("Failed to parse ontology file object: %s", str(e))
+
+ if not loaded_objects:
+ logger.info(
+ "No valid ontology file objects found. No owl ontology will be attached to the graph."
)
+ self.graph = None
+ else:
+ logger.info("Total ontology file objects loaded: %d", len(loaded_objects))
- if not loaded_files:
- logger.info(
- "No valid ontology files found. No owl ontology will be attached to the graph."
- )
- self.graph = None
+ elif files_to_load:
+ self.graph = Graph()
+ loaded_files = []
+ for file_path in files_to_load:
+ if os.path.exists(file_path):
+ self.graph.parse(file_path)
+ loaded_files.append(file_path)
+ logger.info("Ontology loaded successfully from file: %s", file_path)
+ else:
+ logger.warning(
+ "Ontology file '%s' not found. Skipping this file.",
+ file_path,
+ )
+
+ if not loaded_files:
+ logger.info(
+ "No valid ontology files found. No owl ontology will be attached to the graph."
+ )
+ self.graph = None
+ else:
+ logger.info("Total ontology files loaded: %d", len(loaded_files))
else:
- logger.info("Total ontology files loaded: %d", len(loaded_files))
+ logger.info(
+ "No ontology file provided. No owl ontology will be attached to the graph."
+ )
else:
logger.info(
"No ontology file provided. No owl ontology will be attached to the graph."
diff --git a/cognee/modules/pipelines/operations/run_tasks_data_item.py b/cognee/modules/pipelines/operations/run_tasks_data_item.py
index 152e72d7f..2cc449df6 100644
--- a/cognee/modules/pipelines/operations/run_tasks_data_item.py
+++ b/cognee/modules/pipelines/operations/run_tasks_data_item.py
@@ -69,7 +69,7 @@ async def run_tasks_data_item_incremental(
async with open_data_file(file_path) as file:
classified_data = ingestion.classify(file)
# data_id is the hash of file contents + owner id to avoid duplicate data
- data_id = ingestion.identify(classified_data, user)
+ data_id = await ingestion.identify(classified_data, user)
else:
# If data was already processed by Cognee get data id
data_id = data_item.id
diff --git a/cognee/modules/retrieval/EntityCompletionRetriever.py b/cognee/modules/retrieval/EntityCompletionRetriever.py
index 6086977ce..14996f902 100644
--- a/cognee/modules/retrieval/EntityCompletionRetriever.py
+++ b/cognee/modules/retrieval/EntityCompletionRetriever.py
@@ -1,5 +1,5 @@
import asyncio
-from typing import Any, Optional, List
+from typing import Any, Optional, List, Type
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
@@ -85,8 +85,12 @@ class EntityCompletionRetriever(BaseRetriever):
return None
async def get_completion(
- self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
- ) -> List[str]:
+ self,
+ query: str,
+ context: Optional[Any] = None,
+ session_id: Optional[str] = None,
+ response_model: Type = str,
+ ) -> List[Any]:
"""
Generate completion using provided context or fetch new context.
@@ -102,6 +106,7 @@ class EntityCompletionRetriever(BaseRetriever):
fetched if not provided. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
+ - response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@@ -133,6 +138,7 @@ class EntityCompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
conversation_history=conversation_history,
+ response_model=response_model,
),
)
else:
@@ -141,6 +147,7 @@ class EntityCompletionRetriever(BaseRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
+ response_model=response_model,
)
if session_save:
diff --git a/cognee/modules/retrieval/base_graph_retriever.py b/cognee/modules/retrieval/base_graph_retriever.py
index b0abc2991..b203309ba 100644
--- a/cognee/modules/retrieval/base_graph_retriever.py
+++ b/cognee/modules/retrieval/base_graph_retriever.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Any, List, Optional, Type
from abc import ABC, abstractmethod
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@@ -14,7 +14,11 @@ class BaseGraphRetriever(ABC):
@abstractmethod
async def get_completion(
- self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None
- ) -> str:
+ self,
+ query: str,
+ context: Optional[List[Edge]] = None,
+ session_id: Optional[str] = None,
+ response_model: Type = str,
+ ) -> List[Any]:
"""Generates a response using the query and optional context (triplets)."""
pass
diff --git a/cognee/modules/retrieval/base_retriever.py b/cognee/modules/retrieval/base_retriever.py
index 1533dd44f..b88c741b8 100644
--- a/cognee/modules/retrieval/base_retriever.py
+++ b/cognee/modules/retrieval/base_retriever.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any, Optional, Type, List
class BaseRetriever(ABC):
@@ -12,7 +12,11 @@ class BaseRetriever(ABC):
@abstractmethod
async def get_completion(
- self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
- ) -> Any:
+ self,
+ query: str,
+ context: Optional[Any] = None,
+ session_id: Optional[str] = None,
+ response_model: Type = str,
+ ) -> List[Any]:
"""Generates a response using the query and optional context."""
pass
diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py
index bb568924d..126ebcab8 100644
--- a/cognee/modules/retrieval/completion_retriever.py
+++ b/cognee/modules/retrieval/completion_retriever.py
@@ -1,5 +1,5 @@
import asyncio
-from typing import Any, Optional
+from typing import Any, Optional, Type, List
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
@@ -75,8 +75,12 @@ class CompletionRetriever(BaseRetriever):
raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(
- self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
- ) -> str:
+ self,
+ query: str,
+ context: Optional[Any] = None,
+ session_id: Optional[str] = None,
+ response_model: Type = str,
+ ) -> List[Any]:
"""
Generates an LLM completion using the context.
@@ -91,6 +95,7 @@ class CompletionRetriever(BaseRetriever):
completion; if None, it retrieves the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
+ - response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@@ -118,6 +123,7 @@ class CompletionRetriever(BaseRetriever):
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
+ response_model=response_model,
),
)
else:
@@ -127,6 +133,7 @@ class CompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
+ response_model=response_model,
)
if session_save:
@@ -137,4 +144,4 @@ class CompletionRetriever(BaseRetriever):
session_id=session_id,
)
- return completion
+ return [completion]
diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py
index 58b6b586f..fc49a139b 100644
--- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py
@@ -1,5 +1,5 @@
import asyncio
-from typing import Optional, List, Type
+from typing import Optional, List, Type, Any
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@@ -39,6 +39,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@@ -48,6 +50,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
async def get_completion(
@@ -56,7 +60,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
context_extension_rounds=4,
- ) -> List[str]:
+ response_model: Type = str,
+ ) -> List[Any]:
"""
Extends the context for a given query by retrieving related triplets and generating new
completions based on them.
@@ -76,6 +81,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
defaults to 'default_session'. (default None)
- context_extension_rounds: The maximum number of rounds to extend the context with
new triplets before halting. (default 4)
+ - response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@@ -143,6 +149,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
+ response_model=response_model,
),
)
else:
@@ -152,6 +159,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
+ response_model=response_model,
)
if self.save_interaction and context_text and triplets and completion:
diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py
index 299db6855..70fcb6cdb 100644
--- a/cognee/modules/retrieval/graph_completion_cot_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py
@@ -7,7 +7,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import (
- generate_structured_completion,
+ generate_completion,
summarize_text,
)
from cognee.modules.retrieval.utils.session_cache import (
@@ -44,7 +44,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
questions based on reasoning. The public methods are:
- get_completion
- - get_structured_completion
Instance variables include:
- validation_system_prompt_path
@@ -66,6 +65,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@@ -75,6 +76,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
self.validation_system_prompt_path = validation_system_prompt_path
self.validation_user_prompt_path = validation_user_prompt_path
@@ -121,7 +124,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
triplets += await self.get_context(followup_question)
context_text = await self.resolve_edges_to_text(list(set(triplets)))
- completion = await generate_structured_completion(
+ completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
@@ -165,24 +168,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
return completion, context_text, triplets
- async def get_structured_completion(
+ async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
- max_iter: int = 4,
+ max_iter=4,
response_model: Type = str,
- ) -> Any:
+ ) -> List[Any]:
"""
- Generate structured completion responses based on a user query and contextual information.
+ Generate completion responses based on a user query and contextual information.
- This method applies the same chain-of-thought logic as get_completion but returns
+ This method interacts with a language model client to retrieve a structured response,
+ using a series of iterations to refine the answers and generate follow-up questions
+ based on reasoning derived from previous outputs. It raises exceptions if the context
+ retrieval fails or if the model encounters issues in generating outputs. It returns
structured output using the provided response model.
Parameters:
-----------
+
- query (str): The user's query to be processed and answered.
- - context (Optional[List[Edge]]): Optional context that may assist in answering the query.
+ - context (Optional[List[Edge]]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
@@ -192,7 +199,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
Returns:
--------
- - Any: The generated structured completion based on the response model.
+
+ - List[Any]: A list containing the generated answer to the user's query, shaped by response_model.
"""
# Check if session saving is enabled
cache_config = CacheConfig()
@@ -228,45 +236,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
session_id=session_id,
)
- return completion
-
- async def get_completion(
- self,
- query: str,
- context: Optional[List[Edge]] = None,
- session_id: Optional[str] = None,
- max_iter=4,
- ) -> List[str]:
- """
- Generate completion responses based on a user query and contextual information.
-
- This method interacts with a language model client to retrieve a structured response,
- using a series of iterations to refine the answers and generate follow-up questions
- based on reasoning derived from previous outputs. It raises exceptions if the context
- retrieval fails or if the model encounters issues in generating outputs.
-
- Parameters:
- -----------
-
- - query (str): The user's query to be processed and answered.
- - context (Optional[Any]): Optional context that may assist in answering the query.
- If not provided, it will be fetched based on the query. (default None)
- - session_id (Optional[str]): Optional session identifier for caching. If None,
- defaults to 'default_session'. (default None)
- - max_iter: The maximum number of iterations to refine the answer and generate
- follow-up questions. (default 4)
-
- Returns:
- --------
-
- - List[str]: A list containing the generated answer to the user's query.
- """
- completion = await self.get_structured_completion(
- query=query,
- context=context,
- session_id=session_id,
- max_iter=max_iter,
- response_model=str,
- )
-
return [completion]
diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py
index b7ab4edae..89e9e47ce 100644
--- a/cognee/modules/retrieval/graph_completion_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_retriever.py
@@ -47,6 +47,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
):
"""Initialize retriever with prompt paths and search parameters."""
self.save_interaction = save_interaction
@@ -54,8 +56,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):
self.system_prompt_path = system_prompt_path
self.system_prompt = system_prompt
self.top_k = top_k if top_k is not None else 5
+ self.wide_search_top_k = wide_search_top_k
self.node_type = node_type
self.node_name = node_name
+ self.triplet_distance_penalty = triplet_distance_penalty
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
"""
@@ -105,6 +109,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
collections=vector_index_collections or None,
node_type=self.node_type,
node_name=self.node_name,
+ wide_search_top_k=self.wide_search_top_k,
+ triplet_distance_penalty=self.triplet_distance_penalty,
)
return found_triplets
@@ -141,12 +147,17 @@ class GraphCompletionRetriever(BaseGraphRetriever):
return triplets
+ async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):
+ context = await self.resolve_edges_to_text(triplets)
+ return context
+
async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
- ) -> List[str]:
+ response_model: Type = str,
+ ) -> List[Any]:
"""
Generates a completion using graph connections context based on a query.
@@ -188,6 +199,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
+ response_model=response_model,
),
)
else:
@@ -197,6 +209,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
+ response_model=response_model,
)
if self.save_interaction and context and triplets and completion:
diff --git a/cognee/modules/retrieval/graph_summary_completion_retriever.py b/cognee/modules/retrieval/graph_summary_completion_retriever.py
index 051f39b22..e31ad126e 100644
--- a/cognee/modules/retrieval/graph_summary_completion_retriever.py
+++ b/cognee/modules/retrieval/graph_summary_completion_retriever.py
@@ -26,6 +26,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
):
"""Initialize retriever with default prompt paths and search parameters."""
super().__init__(
@@ -36,6 +38,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
self.summarize_prompt_path = summarize_prompt_path
diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py
index ec68d37bb..87d2ab009 100644
--- a/cognee/modules/retrieval/temporal_retriever.py
+++ b/cognee/modules/retrieval/temporal_retriever.py
@@ -47,6 +47,8 @@ class TemporalRetriever(GraphCompletionRetriever):
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@@ -54,6 +56,8 @@ class TemporalRetriever(GraphCompletionRetriever):
top_k=top_k,
node_type=node_type,
node_name=node_name,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
@@ -146,8 +150,12 @@ class TemporalRetriever(GraphCompletionRetriever):
return self.descriptions_to_string(top_k_events)
async def get_completion(
- self, query: str, context: Optional[str] = None, session_id: Optional[str] = None
- ) -> List[str]:
+ self,
+ query: str,
+ context: Optional[str] = None,
+ session_id: Optional[str] = None,
+ response_model: Type = str,
+ ) -> List[Any]:
"""
Generates a response using the query and optional context.
@@ -159,6 +167,7 @@ class TemporalRetriever(GraphCompletionRetriever):
retrieved based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
+ - response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@@ -186,6 +195,7 @@ class TemporalRetriever(GraphCompletionRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
conversation_history=conversation_history,
+ response_model=response_model,
),
)
else:
@@ -194,6 +204,7 @@ class TemporalRetriever(GraphCompletionRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
+ response_model=response_model,
)
if session_save:
diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py
index 1ef7545c2..2f8a545f7 100644
--- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py
+++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py
@@ -58,6 +58,8 @@ async def get_memory_fragment(
properties_to_project: Optional[List[str]] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
+ relevant_ids_to_filter: Optional[List[str]] = None,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> CogneeGraph:
"""Creates and initializes a CogneeGraph memory fragment with optional property projections."""
if properties_to_project is None:
@@ -71,9 +73,11 @@ async def get_memory_fragment(
await memory_fragment.project_graph_from_db(
graph_engine,
node_properties_to_project=properties_to_project,
- edge_properties_to_project=["relationship_name"],
+ edge_properties_to_project=["relationship_name", "edge_text"],
node_type=node_type,
node_name=node_name,
+ relevant_ids_to_filter=relevant_ids_to_filter,
+ triplet_distance_penalty=triplet_distance_penalty,
)
except EntityNotFoundError:
@@ -95,6 +99,8 @@ async def brute_force_triplet_search(
memory_fragment: Optional[CogneeGraph] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> List[Edge]:
"""
Performs a brute force search to retrieve the top triplets from the graph.
@@ -107,6 +113,8 @@ async def brute_force_triplet_search(
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
node_type: node type to filter
node_name: node name to filter
+ wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
+ triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection
Returns:
list: The top triplet results.
@@ -116,10 +124,10 @@ async def brute_force_triplet_search(
if top_k <= 0:
raise ValueError("top_k must be a positive integer.")
- if memory_fragment is None:
- memory_fragment = await get_memory_fragment(
- properties_to_project, node_type=node_type, node_name=node_name
- )
+ # Setting wide search limit based on the parameters
+ non_global_search = node_name is None
+
+ wide_search_limit = wide_search_top_k if non_global_search else None
if collections is None:
collections = [
@@ -140,7 +148,7 @@ async def brute_force_triplet_search(
async def search_in_collection(collection_name: str):
try:
return await vector_engine.search(
- collection_name=collection_name, query_vector=query_vector, limit=None
+ collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
)
except CollectionNotFoundError:
return []
@@ -156,15 +164,38 @@ async def brute_force_triplet_search(
return []
# Final statistics
- projection_time = time.time() - start_time
+ vector_collection_search_time = time.time() - start_time
logger.info(
- f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s"
+ f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
)
node_distances = {collection: result for collection, result in zip(collections, results)}
edge_distances = node_distances.get("EdgeType_relationship_name", None)
+ if wide_search_limit is not None:
+ relevant_ids_to_filter = list(
+ {
+ str(getattr(scored_node, "id"))
+ for collection_name, score_collection in node_distances.items()
+ if collection_name != "EdgeType_relationship_name"
+ and isinstance(score_collection, (list, tuple))
+ for scored_node in score_collection
+ if getattr(scored_node, "id", None)
+ }
+ )
+ else:
+ relevant_ids_to_filter = None
+
+ if memory_fragment is None:
+ memory_fragment = await get_memory_fragment(
+ properties_to_project=properties_to_project,
+ node_type=node_type,
+ node_name=node_name,
+ relevant_ids_to_filter=relevant_ids_to_filter,
+ triplet_distance_penalty=triplet_distance_penalty,
+ )
+
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
await memory_fragment.map_vector_distances_to_graph_edges(
vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py
index db7a10252..c90ce77f4 100644
--- a/cognee/modules/retrieval/utils/completion.py
+++ b/cognee/modules/retrieval/utils/completion.py
@@ -3,7 +3,7 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
-async def generate_structured_completion(
+async def generate_completion(
query: str,
context: str,
user_prompt_path: str,
@@ -12,7 +12,7 @@ async def generate_structured_completion(
conversation_history: Optional[str] = None,
response_model: Type = str,
) -> Any:
- """Generates a structured completion using LLM with given context and prompts."""
+ """Generates a completion using LLM with given context and prompts."""
args = {"question": query, "context": context}
user_prompt = render_prompt(user_prompt_path, args)
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
@@ -28,26 +28,6 @@ async def generate_structured_completion(
)
-async def generate_completion(
- query: str,
- context: str,
- user_prompt_path: str,
- system_prompt_path: str,
- system_prompt: Optional[str] = None,
- conversation_history: Optional[str] = None,
-) -> str:
- """Generates a completion using LLM with given context and prompts."""
- return await generate_structured_completion(
- query=query,
- context=context,
- user_prompt_path=user_prompt_path,
- system_prompt_path=system_prompt_path,
- system_prompt=system_prompt,
- conversation_history=conversation_history,
- response_model=str,
- )
-
-
async def summarize_text(
text: str,
system_prompt_path: str = "summarize_search_results.txt",
diff --git a/cognee/modules/run_custom_pipeline/__init__.py b/cognee/modules/run_custom_pipeline/__init__.py
new file mode 100644
index 000000000..2d30e2e0c
--- /dev/null
+++ b/cognee/modules/run_custom_pipeline/__init__.py
@@ -0,0 +1 @@
+from .run_custom_pipeline import run_custom_pipeline
diff --git a/cognee/modules/run_custom_pipeline/run_custom_pipeline.py b/cognee/modules/run_custom_pipeline/run_custom_pipeline.py
new file mode 100644
index 000000000..d3df1c060
--- /dev/null
+++ b/cognee/modules/run_custom_pipeline/run_custom_pipeline.py
@@ -0,0 +1,69 @@
+from typing import Union, Optional, List, Type, Any
+from uuid import UUID
+
+from cognee.shared.logging_utils import get_logger
+
+from cognee.modules.pipelines import run_pipeline
+from cognee.modules.pipelines.tasks.task import Task
+from cognee.modules.users.models import User
+from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
+
+logger = get_logger()
+
+
+async def run_custom_pipeline(
+ tasks: Union[List[Task], List[str]] = None,
+ data: Any = None,
+ dataset: Union[str, UUID] = "main_dataset",
+ user: User = None,
+ vector_db_config: Optional[dict] = None,
+ graph_db_config: Optional[dict] = None,
+ data_per_batch: int = 20,
+ run_in_background: bool = False,
+ pipeline_name: str = "custom_pipeline",
+):
+ """
+    Runs a custom pipeline in Cognee; it can operate on already-built graphs. Data must be provided that can be
+    processed by the supplied tasks.
+
+ Provided tasks and data will be arranged to run the Cognee pipeline and execute graph enrichment/creation.
+
+ This is the core processing step in Cognee that converts raw text and documents
+ into an intelligent knowledge graph. It analyzes content, extracts entities and
+ relationships, and creates semantic connections for enhanced search and reasoning.
+
+ Args:
+ tasks: List of Cognee Tasks to execute.
+ data: The data to ingest. Can be anything when custom extraction and enrichment tasks are used.
+ Data provided here will be forwarded to the first extraction task in the pipeline as input.
+ dataset: Dataset name or dataset uuid to process.
+ user: User context for authentication and data access. Uses default if None.
+ vector_db_config: Custom vector database configuration for embeddings storage.
+ graph_db_config: Custom graph database configuration for relationship storage.
+ data_per_batch: Number of data items to be processed in parallel.
+ run_in_background: If True, starts processing asynchronously and returns immediately.
+ If False, waits for completion before returning.
+ Background mode recommended for large datasets (>100MB).
+ Use pipeline_run_id from return value to monitor progress.
+ """
+
+ custom_tasks = [
+ *tasks,
+ ]
+
+    # get_pipeline_executor returns a function that either launches run_pipeline in the background or blocks until the pipeline completes
+ pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
+
+ # Run the run_pipeline in the background or blocking based on executor
+ return await pipeline_executor_func(
+ pipeline=run_pipeline,
+ tasks=custom_tasks,
+ user=user,
+ data=data,
+ datasets=dataset,
+ vector_db_config=vector_db_config,
+ graph_db_config=graph_db_config,
+ incremental_loading=False,
+ data_per_batch=data_per_batch,
+ pipeline_name=pipeline_name,
+ )
diff --git a/cognee/modules/search/methods/get_search_type_tools.py b/cognee/modules/search/methods/get_search_type_tools.py
index 72e2db89a..165ec379b 100644
--- a/cognee/modules/search/methods/get_search_type_tools.py
+++ b/cognee/modules/search/methods/get_search_type_tools.py
@@ -37,6 +37,8 @@ async def get_search_type_tools(
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
last_k: Optional[int] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> list:
search_tasks: dict[SearchType, List[Callable]] = {
SearchType.SUMMARIES: [
@@ -67,6 +69,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphCompletionRetriever(
system_prompt_path=system_prompt_path,
@@ -75,6 +79,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.GRAPH_COMPLETION_COT: [
@@ -85,6 +91,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphCompletionCotRetriever(
system_prompt_path=system_prompt_path,
@@ -93,6 +101,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [
@@ -103,6 +113,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphCompletionContextExtensionRetriever(
system_prompt_path=system_prompt_path,
@@ -111,6 +123,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.GRAPH_SUMMARY_COMPLETION: [
@@ -121,6 +135,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path,
@@ -129,6 +145,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.CODE: [
@@ -145,8 +163,16 @@ async def get_search_type_tools(
],
SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback],
SearchType.TEMPORAL: [
- TemporalRetriever(top_k=top_k).get_completion,
- TemporalRetriever(top_k=top_k).get_context,
+ TemporalRetriever(
+ top_k=top_k,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
+ ).get_completion,
+ TemporalRetriever(
+ top_k=top_k,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
+ ).get_context,
],
SearchType.CHUNKS_LEXICAL: (
lambda _r=JaccardChunksRetriever(top_k=top_k): [
diff --git a/cognee/modules/search/methods/no_access_control_search.py b/cognee/modules/search/methods/no_access_control_search.py
index fcb02da46..3a703bbc9 100644
--- a/cognee/modules/search/methods/no_access_control_search.py
+++ b/cognee/modules/search/methods/no_access_control_search.py
@@ -24,6 +24,8 @@ async def no_access_control_search(
last_k: Optional[int] = None,
only_context: bool = False,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
search_tools = await get_search_type_tools(
query_type=query_type,
@@ -35,6 +37,8 @@ async def no_access_control_search(
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
graph_engine = await get_graph_engine()
is_empty = await graph_engine.is_empty()
diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py
index aab004924..9f180d607 100644
--- a/cognee/modules/search/methods/search.py
+++ b/cognee/modules/search/methods/search.py
@@ -1,4 +1,3 @@
-import os
import json
import asyncio
from uuid import UUID
@@ -9,6 +8,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry
from cognee.context_global_variables import set_database_global_context_variables
+from cognee.context_global_variables import backend_access_control_enabled
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@@ -47,6 +47,8 @@ async def search(
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[CombinedSearchResult, List[SearchResult]]:
"""
@@ -74,7 +76,7 @@ async def search(
)
# Use search function filtered by permissions if access control is enabled
- if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
+ if backend_access_control_enabled():
search_results = await authorized_search(
query_type=query_type,
query_text=query_text,
@@ -90,6 +92,8 @@ async def search(
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
else:
search_results = [
@@ -105,6 +109,8 @@ async def search(
last_k=last_k,
only_context=only_context,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
]
@@ -156,7 +162,7 @@ async def search(
)
else:
# This is for maintaining backwards compatibility
- if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
+ if backend_access_control_enabled():
return_value = []
for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result)
@@ -172,6 +178,7 @@ async def search(
"search_result": [context] if context else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
+ "dataset_tenant_id": datasets[0].tenant_id,
"graphs": graphs,
}
)
@@ -181,6 +188,7 @@ async def search(
"search_result": [result] if result else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
+ "dataset_tenant_id": datasets[0].tenant_id,
"graphs": graphs,
}
)
@@ -217,6 +225,8 @@ async def authorized_search(
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[
Tuple[Any, Union[List[Edge], str], List[Dataset]],
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
@@ -244,6 +254,8 @@ async def authorized_search(
last_k=last_k,
only_context=True,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
context = {}
@@ -265,6 +277,8 @@ async def authorized_search(
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
@@ -304,6 +318,7 @@ async def authorized_search(
last_k=last_k,
only_context=only_context,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
)
return search_results
@@ -323,6 +338,8 @@ async def search_in_datasets_context(
only_context: bool = False,
context: Optional[Any] = None,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
"""
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
@@ -343,6 +360,8 @@ async def search_in_datasets_context(
only_context: bool = False,
context: Optional[Any] = None,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
# Set database configuration in async context for each dataset user has access for
await set_database_global_context_variables(dataset.id, dataset.owner_id)
@@ -376,6 +395,8 @@ async def search_in_datasets_context(
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
@@ -411,6 +432,8 @@ async def search_in_datasets_context(
only_context=only_context,
context=context,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
)
diff --git a/cognee/modules/users/methods/create_user.py b/cognee/modules/users/methods/create_user.py
index 1b303bd36..ef325fb6f 100644
--- a/cognee/modules/users/methods/create_user.py
+++ b/cognee/modules/users/methods/create_user.py
@@ -18,7 +18,6 @@ from typing import Optional
async def create_user(
email: str,
password: str,
- tenant_id: Optional[str] = None,
is_superuser: bool = False,
is_active: bool = True,
is_verified: bool = False,
@@ -30,37 +29,23 @@ async def create_user(
async with relational_engine.get_async_session() as session:
async with get_user_db_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager:
- if tenant_id:
- # Check if the tenant already exists
- result = await session.execute(select(Tenant).where(Tenant.id == tenant_id))
- tenant = result.scalars().first()
- if not tenant:
- raise TenantNotFoundError
-
- user = await user_manager.create(
- UserCreate(
- email=email,
- password=password,
- tenant_id=tenant.id,
- is_superuser=is_superuser,
- is_active=is_active,
- is_verified=is_verified,
- )
- )
- else:
- user = await user_manager.create(
- UserCreate(
- email=email,
- password=password,
- is_superuser=is_superuser,
- is_active=is_active,
- is_verified=is_verified,
- )
+ user = await user_manager.create(
+ UserCreate(
+ email=email,
+ password=password,
+ is_superuser=is_superuser,
+ is_active=is_active,
+ is_verified=is_verified,
)
+ )
if auto_login:
await session.refresh(user)
+ # Update tenants and roles information for User object
+ _ = await user.awaitable_attrs.tenants
+ _ = await user.awaitable_attrs.roles
+
return user
except UserAlreadyExists as error:
print(f"User {email} already exists")
diff --git a/cognee/modules/users/methods/get_authenticated_user.py b/cognee/modules/users/methods/get_authenticated_user.py
index d78215892..d6d701737 100644
--- a/cognee/modules/users/methods/get_authenticated_user.py
+++ b/cognee/modules/users/methods/get_authenticated_user.py
@@ -5,6 +5,7 @@ from ..models import User
from ..get_fastapi_users import get_fastapi_users
from .get_default_user import get_default_user
from cognee.shared.logging_utils import get_logger
+from cognee.context_global_variables import backend_access_control_enabled
logger = get_logger("get_authenticated_user")
@@ -12,7 +13,7 @@ logger = get_logger("get_authenticated_user")
# Check environment variable to determine authentication requirement
REQUIRE_AUTHENTICATION = (
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
- or os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true"
+ or backend_access_control_enabled()
)
fastapi_users = get_fastapi_users()
diff --git a/cognee/modules/users/methods/get_default_user.py b/cognee/modules/users/methods/get_default_user.py
index 773545f8e..8dc364f32 100644
--- a/cognee/modules/users/methods/get_default_user.py
+++ b/cognee/modules/users/methods/get_default_user.py
@@ -10,7 +10,7 @@ from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.users.methods.create_default_user import create_default_user
-async def get_default_user() -> SimpleNamespace:
+async def get_default_user() -> User:
db_engine = get_relational_engine()
base_config = get_base_config()
default_email = base_config.default_user_email or "default_user@example.com"
@@ -18,7 +18,9 @@ async def get_default_user() -> SimpleNamespace:
try:
async with db_engine.get_async_session() as session:
query = (
- select(User).options(selectinload(User.roles)).where(User.email == default_email)
+ select(User)
+ .options(selectinload(User.roles), selectinload(User.tenants))
+ .where(User.email == default_email)
)
result = await session.execute(query)
diff --git a/cognee/modules/users/methods/get_user.py b/cognee/modules/users/methods/get_user.py
index 2678a5a01..a1c87aab7 100644
--- a/cognee/modules/users/methods/get_user.py
+++ b/cognee/modules/users/methods/get_user.py
@@ -14,7 +14,7 @@ async def get_user(user_id: UUID):
user = (
await session.execute(
select(User)
- .options(selectinload(User.roles), selectinload(User.tenant))
+ .options(selectinload(User.roles), selectinload(User.tenants))
.where(User.id == user_id)
)
).scalar()
diff --git a/cognee/modules/users/methods/get_user_by_email.py b/cognee/modules/users/methods/get_user_by_email.py
index c4bd5b48e..6df989251 100644
--- a/cognee/modules/users/methods/get_user_by_email.py
+++ b/cognee/modules/users/methods/get_user_by_email.py
@@ -13,7 +13,7 @@ async def get_user_by_email(user_email: str):
user = (
await session.execute(
select(User)
- .options(joinedload(User.roles), joinedload(User.tenant))
+ .options(joinedload(User.roles), joinedload(User.tenants))
.where(User.email == user_email)
)
).scalar()
diff --git a/cognee/modules/users/models/DatasetDatabase.py b/cognee/modules/users/models/DatasetDatabase.py
index 0d71d8413..25d610ab9 100644
--- a/cognee/modules/users/models/DatasetDatabase.py
+++ b/cognee/modules/users/models/DatasetDatabase.py
@@ -15,5 +15,14 @@ class DatasetDatabase(Base):
vector_database_name = Column(String, unique=True, nullable=False)
graph_database_name = Column(String, unique=True, nullable=False)
+ vector_database_provider = Column(String, unique=False, nullable=False)
+ graph_database_provider = Column(String, unique=False, nullable=False)
+
+ vector_database_url = Column(String, unique=False, nullable=True)
+ graph_database_url = Column(String, unique=False, nullable=True)
+
+ vector_database_key = Column(String, unique=False, nullable=True)
+ graph_database_key = Column(String, unique=False, nullable=True)
+
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
diff --git a/cognee/modules/users/models/Tenant.py b/cognee/modules/users/models/Tenant.py
index 95023a6ee..b8fa158c5 100644
--- a/cognee/modules/users/models/Tenant.py
+++ b/cognee/modules/users/models/Tenant.py
@@ -1,7 +1,7 @@
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import relationship, Mapped
from sqlalchemy import Column, String, ForeignKey, UUID
from .Principal import Principal
-from .User import User
+from .UserTenant import UserTenant
from .Role import Role
@@ -13,14 +13,13 @@ class Tenant(Principal):
owner_id = Column(UUID, index=True)
- # One-to-Many relationship with User; specify the join via User.tenant_id
- users = relationship(
+ users: Mapped[list["User"]] = relationship( # noqa: F821
"User",
- back_populates="tenant",
- foreign_keys=lambda: [User.tenant_id],
+ secondary=UserTenant.__tablename__,
+ back_populates="tenants",
)
- # One-to-Many relationship with Role (if needed; similar fix)
+ # One-to-Many relationship with Role
roles = relationship(
"Role",
back_populates="tenant",
diff --git a/cognee/modules/users/models/User.py b/cognee/modules/users/models/User.py
index 8972a5932..a98abd3bc 100644
--- a/cognee/modules/users/models/User.py
+++ b/cognee/modules/users/models/User.py
@@ -6,8 +6,10 @@ from sqlalchemy import ForeignKey, Column, UUID
from sqlalchemy.orm import relationship, Mapped
from .Principal import Principal
+from .UserTenant import UserTenant
from .UserRole import UserRole
from .Role import Role
+from .Tenant import Tenant
class User(SQLAlchemyBaseUserTableUUID, Principal):
@@ -15,7 +17,7 @@ class User(SQLAlchemyBaseUserTableUUID, Principal):
id = Column(UUID, ForeignKey("principals.id", ondelete="CASCADE"), primary_key=True)
- # Foreign key to Tenant (Many-to-One relationship)
+ # Foreign key to current Tenant (Many-to-One relationship)
tenant_id = Column(UUID, ForeignKey("tenants.id"))
# Many-to-Many Relationship with Roles
@@ -25,11 +27,11 @@ class User(SQLAlchemyBaseUserTableUUID, Principal):
back_populates="users",
)
- # Relationship to Tenant
- tenant = relationship(
+ # Many-to-Many Relationship with Tenants user is a part of
+ tenants: Mapped[list["Tenant"]] = relationship(
"Tenant",
+ secondary=UserTenant.__tablename__,
back_populates="users",
- foreign_keys=[tenant_id],
)
# ACL Relationship (One-to-Many)
@@ -46,7 +48,6 @@ class UserRead(schemas.BaseUser[uuid_UUID]):
class UserCreate(schemas.BaseUserCreate):
- tenant_id: Optional[uuid_UUID] = None
is_verified: bool = True
diff --git a/cognee/modules/users/models/UserTenant.py b/cognee/modules/users/models/UserTenant.py
new file mode 100644
index 000000000..bfb852aa5
--- /dev/null
+++ b/cognee/modules/users/models/UserTenant.py
@@ -0,0 +1,12 @@
+from datetime import datetime, timezone
+from sqlalchemy import Column, ForeignKey, DateTime, UUID
+from cognee.infrastructure.databases.relational import Base
+
+
+class UserTenant(Base):
+ __tablename__ = "user_tenants"
+
+ created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
+
+ user_id = Column(UUID, ForeignKey("users.id"), primary_key=True)
+ tenant_id = Column(UUID, ForeignKey("tenants.id"), primary_key=True)
diff --git a/cognee/modules/users/models/__init__.py b/cognee/modules/users/models/__init__.py
index ba2f40e49..5114cc45a 100644
--- a/cognee/modules/users/models/__init__.py
+++ b/cognee/modules/users/models/__init__.py
@@ -1,6 +1,7 @@
from .User import User
from .Role import Role
from .UserRole import UserRole
+from .UserTenant import UserTenant
from .DatasetDatabase import DatasetDatabase
from .RoleDefaultPermissions import RoleDefaultPermissions
from .UserDefaultPermissions import UserDefaultPermissions
diff --git a/cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py b/cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py
index 1185dd7ad..5eed992db 100644
--- a/cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py
+++ b/cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py
@@ -1,11 +1,8 @@
-from types import SimpleNamespace
-
from cognee.shared.logging_utils import get_logger
from ...models.User import User
from cognee.modules.data.models.Dataset import Dataset
from cognee.modules.users.permissions.methods import get_principal_datasets
-from cognee.modules.users.permissions.methods import get_role, get_tenant
logger = get_logger()
@@ -25,17 +22,14 @@ async def get_all_user_permission_datasets(user: User, permission_type: str) ->
# Get all datasets User has explicit access to
datasets.extend(await get_principal_datasets(user, permission_type))
- if user.tenant_id:
- # Get all datasets all tenants have access to
- tenant = await get_tenant(user.tenant_id)
+ # Get all tenants user is a part of
+ tenants = await user.awaitable_attrs.tenants
+ for tenant in tenants:
+ # Get all datasets all tenant members have access to
datasets.extend(await get_principal_datasets(tenant, permission_type))
- # Get all datasets Users roles have access to
- if isinstance(user, SimpleNamespace):
- # If simple namespace use roles defined in user
- roles = user.roles
- else:
- roles = await user.awaitable_attrs.roles
+ # Get all datasets accessible by roles user is a part of
+ roles = await user.awaitable_attrs.roles
for role in roles:
datasets.extend(await get_principal_datasets(role, permission_type))
@@ -45,4 +39,10 @@ async def get_all_user_permission_datasets(user: User, permission_type: str) ->
# If the dataset id key already exists, leave the dictionary unchanged.
unique.setdefault(dataset.id, dataset)
- return list(unique.values())
+    # Filter out datasets that aren't part of the selected user's tenant
+ filtered_datasets = []
+ for dataset in list(unique.values()):
+ if dataset.tenant_id == user.tenant_id:
+ filtered_datasets.append(dataset)
+
+ return filtered_datasets
diff --git a/cognee/modules/users/roles/methods/add_user_to_role.py b/cognee/modules/users/roles/methods/add_user_to_role.py
index de5e47775..23bb947f0 100644
--- a/cognee/modules/users/roles/methods/add_user_to_role.py
+++ b/cognee/modules/users/roles/methods/add_user_to_role.py
@@ -42,11 +42,13 @@ async def add_user_to_role(user_id: UUID, role_id: UUID, owner_id: UUID):
.first()
)
+ user_tenants = await user.awaitable_attrs.tenants
+
if not user:
raise UserNotFoundError
elif not role:
raise RoleNotFoundError
- elif user.tenant_id != role.tenant_id:
+ elif role.tenant_id not in [tenant.id for tenant in user_tenants]:
raise TenantNotFoundError(
message="User tenant does not match role tenant. User cannot be added to role."
)
diff --git a/cognee/modules/users/tenants/methods/__init__.py b/cognee/modules/users/tenants/methods/__init__.py
index 9a052e9c6..39e2b31bb 100644
--- a/cognee/modules/users/tenants/methods/__init__.py
+++ b/cognee/modules/users/tenants/methods/__init__.py
@@ -1,2 +1,3 @@
from .create_tenant import create_tenant
from .add_user_to_tenant import add_user_to_tenant
+from .select_tenant import select_tenant
diff --git a/cognee/modules/users/tenants/methods/add_user_to_tenant.py b/cognee/modules/users/tenants/methods/add_user_to_tenant.py
index 1374067a7..eecc49f6f 100644
--- a/cognee/modules/users/tenants/methods/add_user_to_tenant.py
+++ b/cognee/modules/users/tenants/methods/add_user_to_tenant.py
@@ -1,8 +1,11 @@
+from typing import Optional
from uuid import UUID
from sqlalchemy.exc import IntegrityError
+from sqlalchemy import insert
from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError
from cognee.infrastructure.databases.relational import get_relational_engine
+from cognee.modules.users.models.UserTenant import UserTenant
from cognee.modules.users.methods import get_user
from cognee.modules.users.permissions.methods import get_tenant
from cognee.modules.users.exceptions import (
@@ -12,14 +15,19 @@ from cognee.modules.users.exceptions import (
)
-async def add_user_to_tenant(user_id: UUID, tenant_id: UUID, owner_id: UUID):
+async def add_user_to_tenant(
+ user_id: UUID, tenant_id: UUID, owner_id: UUID, set_as_active_tenant: Optional[bool] = False
+):
"""
Add a user with the given id to the tenant with the given id.
This can only be successful if the request owner with the given id is the tenant owner.
+
+    If set_as_active_tenant is true it will automatically set the user's active tenant to the provided tenant.
Args:
user_id: Id of the user.
tenant_id: Id of the tenant.
owner_id: Id of the request owner.
+        set_as_active_tenant: If true, automatically set the user's active tenant to the provided tenant.
Returns:
None
@@ -40,17 +48,18 @@ async def add_user_to_tenant(user_id: UUID, tenant_id: UUID, owner_id: UUID):
message="Only tenant owner can add other users to organization."
)
- try:
- if user.tenant_id is None:
- user.tenant_id = tenant_id
- elif user.tenant_id == tenant_id:
- return
- else:
- raise IntegrityError
-
+ if set_as_active_tenant:
+ user.tenant_id = tenant_id
await session.merge(user)
await session.commit()
- except IntegrityError:
- raise EntityAlreadyExistsError(
- message="User is already part of a tenant. Only one tenant can be assigned to user."
+
+ try:
+ # Add association directly to the association table
+ create_user_tenant_statement = insert(UserTenant).values(
+ user_id=user_id, tenant_id=tenant_id
)
+ await session.execute(create_user_tenant_statement)
+ await session.commit()
+
+ except IntegrityError:
+ raise EntityAlreadyExistsError(message="User is already part of group.")
diff --git a/cognee/modules/users/tenants/methods/create_tenant.py b/cognee/modules/users/tenants/methods/create_tenant.py
index bfd23e08f..32baa05fd 100644
--- a/cognee/modules/users/tenants/methods/create_tenant.py
+++ b/cognee/modules/users/tenants/methods/create_tenant.py
@@ -1,19 +1,25 @@
from uuid import UUID
+from sqlalchemy import insert
from sqlalchemy.exc import IntegrityError
+from typing import Optional
+from cognee.modules.users.models.UserTenant import UserTenant
from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.users.models import Tenant
from cognee.modules.users.methods import get_user
-async def create_tenant(tenant_name: str, user_id: UUID) -> UUID:
+async def create_tenant(
+ tenant_name: str, user_id: UUID, set_as_active_tenant: Optional[bool] = True
+) -> UUID:
"""
Create a new tenant with the given name, for the user with the given id.
This user is the owner of the tenant.
Args:
tenant_name: Name of the new tenant.
user_id: Id of the user.
+ set_as_active_tenant: If true, set the newly created tenant as the active tenant for the user.
Returns:
None
@@ -22,18 +28,26 @@ async def create_tenant(tenant_name: str, user_id: UUID) -> UUID:
async with db_engine.get_async_session() as session:
try:
user = await get_user(user_id)
- if user.tenant_id:
- raise EntityAlreadyExistsError(
- message="User already has a tenant. New tenant cannot be created."
- )
tenant = Tenant(name=tenant_name, owner_id=user_id)
session.add(tenant)
await session.flush()
- user.tenant_id = tenant.id
- await session.merge(user)
- await session.commit()
+ if set_as_active_tenant:
+ user.tenant_id = tenant.id
+ await session.merge(user)
+ await session.commit()
+
+ try:
+ # Add association directly to the association table
+ create_user_tenant_statement = insert(UserTenant).values(
+ user_id=user_id, tenant_id=tenant.id
+ )
+ await session.execute(create_user_tenant_statement)
+ await session.commit()
+ except IntegrityError:
+ raise EntityAlreadyExistsError(message="User is already part of tenant.")
+
return tenant.id
except IntegrityError as e:
raise EntityAlreadyExistsError(message="Tenant already exists.") from e
diff --git a/cognee/modules/users/tenants/methods/select_tenant.py b/cognee/modules/users/tenants/methods/select_tenant.py
new file mode 100644
index 000000000..83c11dc91
--- /dev/null
+++ b/cognee/modules/users/tenants/methods/select_tenant.py
@@ -0,0 +1,62 @@
+from uuid import UUID
+from typing import Union
+
+import sqlalchemy.exc
+from sqlalchemy import select
+
+from cognee.infrastructure.databases.relational import get_relational_engine
+from cognee.modules.users.methods.get_user import get_user
+from cognee.modules.users.models.UserTenant import UserTenant
+from cognee.modules.users.models.User import User
+from cognee.modules.users.permissions.methods import get_tenant
+from cognee.modules.users.exceptions import UserNotFoundError, TenantNotFoundError
+
+
+async def select_tenant(user_id: UUID, tenant_id: Union[UUID, None]) -> User:
+ """
+    Set the user's active tenant to the provided tenant.
+
+    If tenant_id is None, reset the current tenant to the default single-user tenant.
+    Args:
+        user_id: UUID of the user.
+        tenant_id: Id of the tenant, or None to reset to the default tenant.
+
+    Returns:
+        User: The updated user.
+
+ """
+ db_engine = get_relational_engine()
+ async with db_engine.get_async_session() as session:
+ user = await get_user(user_id)
+ if tenant_id is None:
+ # If no tenant_id is provided set current Tenant to the single user-tenant
+ user.tenant_id = None
+ await session.merge(user)
+ await session.commit()
+ return user
+
+ tenant = await get_tenant(tenant_id)
+
+ if not user:
+ raise UserNotFoundError
+ elif not tenant:
+ raise TenantNotFoundError
+
+ # Check if User is part of Tenant
+ result = await session.execute(
+ select(UserTenant)
+ .where(UserTenant.user_id == user.id)
+ .where(UserTenant.tenant_id == tenant_id)
+ )
+
+ try:
+ result = result.scalar_one()
+ except sqlalchemy.exc.NoResultFound as e:
+ raise TenantNotFoundError("User is not part of the tenant.") from e
+
+ if result:
+ # If user is part of tenant update current tenant of user
+ user.tenant_id = tenant_id
+ await session.merge(user)
+ await session.commit()
+ return user
diff --git a/cognee/shared/logging_utils.py b/cognee/shared/logging_utils.py
index 0e5120b1d..e8efde72c 100644
--- a/cognee/shared/logging_utils.py
+++ b/cognee/shared/logging_utils.py
@@ -450,6 +450,8 @@ def setup_logging(log_level=None, name=None):
try:
msg = self.format(record)
stream = self.stream
+ if hasattr(stream, "closed") and stream.closed:
+ return
stream.write("\n" + msg + self.terminator)
self.flush()
except Exception:
diff --git a/cognee/tasks/chunks/__init__.py b/cognee/tasks/chunks/__init__.py
index 22ce96be8..37d4de73e 100644
--- a/cognee/tasks/chunks/__init__.py
+++ b/cognee/tasks/chunks/__init__.py
@@ -1,4 +1,5 @@
from .chunk_by_word import chunk_by_word
from .chunk_by_sentence import chunk_by_sentence
from .chunk_by_paragraph import chunk_by_paragraph
+from .chunk_by_row import chunk_by_row
from .remove_disconnected_chunks import remove_disconnected_chunks
diff --git a/cognee/tasks/chunks/chunk_by_row.py b/cognee/tasks/chunks/chunk_by_row.py
new file mode 100644
index 000000000..8daf13689
--- /dev/null
+++ b/cognee/tasks/chunks/chunk_by_row.py
@@ -0,0 +1,94 @@
+from typing import Any, Dict, Iterator
+from uuid import NAMESPACE_OID, uuid5
+
+from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
+
+
+def _get_pair_size(pair_text: str) -> int:
+ """
+ Calculate the size of a given text in terms of tokens.
+
+    If an embedding engine's tokenizer is available, count the tokens for the provided pair text.
+    If the tokenizer is not available, assume the pair counts as three tokens.
+
+ Parameters:
+ -----------
+
+ - pair_text (str): The key:value pair text for which the token size is to be calculated.
+
+ Returns:
+ --------
+
+    - int: The number of tokens representing the text, as reported by the tokenizer,
+    or the fallback estimate of 3 when no tokenizer is available.
+ """
+ embedding_engine = get_embedding_engine()
+ if embedding_engine.tokenizer:
+ return embedding_engine.tokenizer.count_tokens(pair_text)
+ else:
+ return 3
+
+
+def chunk_by_row(
+ data: str,
+ max_chunk_size,
+) -> Iterator[Dict[str, Any]]:
+ """
+    Chunk the input text by row, keeping each chunk within the size limit.
+
+    This function divides the given text data into smaller chunks on a row-by-row basis,
+    ensuring that the size of each chunk is less than or equal to the specified maximum
+    chunk size. Rows are split on blank lines and their "key: value" pairs are re-joined
+    with ", ", so the original row separators are not preserved verbatim. Tokenization is
+    handled by adapters compatible with the vector engine's embedding model.
+
+ Parameters:
+ -----------
+
+ - data (str): The input text to be chunked.
+ - max_chunk_size: The maximum allowed size for each chunk, in terms of tokens or
+ words.
+ """
+ current_chunk_list = []
+ chunk_index = 0
+ current_chunk_size = 0
+
+ lines = data.split("\n\n")
+ for line in lines:
+ pairs_text = line.split(", ")
+
+ for pair_text in pairs_text:
+ pair_size = _get_pair_size(pair_text)
+ if current_chunk_size > 0 and (current_chunk_size + pair_size > max_chunk_size):
+ # Yield current cut chunk
+ current_chunk = ", ".join(current_chunk_list)
+ chunk_dict = {
+ "text": current_chunk,
+ "chunk_size": current_chunk_size,
+ "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
+ "chunk_index": chunk_index,
+ "cut_type": "row_cut",
+ }
+
+ yield chunk_dict
+
+ # Start new chunk with current pair text
+ current_chunk_list = []
+ current_chunk_size = 0
+ chunk_index += 1
+
+ current_chunk_list.append(pair_text)
+ current_chunk_size += pair_size
+
+ # Yield row chunk
+ current_chunk = ", ".join(current_chunk_list)
+ if current_chunk:
+ chunk_dict = {
+ "text": current_chunk,
+ "chunk_size": current_chunk_size,
+ "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
+ "chunk_index": chunk_index,
+ "cut_type": "row_end",
+ }
+
+ yield chunk_dict
diff --git a/cognee/tasks/documents/classify_documents.py b/cognee/tasks/documents/classify_documents.py
index 9fa512906..e4f13ebd1 100644
--- a/cognee/tasks/documents/classify_documents.py
+++ b/cognee/tasks/documents/classify_documents.py
@@ -7,6 +7,7 @@ from cognee.modules.data.processing.document_types import (
ImageDocument,
TextDocument,
UnstructuredDocument,
+ CsvDocument,
)
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.engine.utils.generate_node_id import generate_node_id
@@ -15,6 +16,7 @@ from cognee.tasks.documents.exceptions import WrongDataDocumentInputError
EXTENSION_TO_DOCUMENT_CLASS = {
"pdf": PdfDocument, # Text documents
"txt": TextDocument,
+ "csv": CsvDocument,
"docx": UnstructuredDocument,
"doc": UnstructuredDocument,
"odt": UnstructuredDocument,
diff --git a/cognee/tasks/feedback/generate_improved_answers.py b/cognee/tasks/feedback/generate_improved_answers.py
index e439cf9e5..d2b143d29 100644
--- a/cognee/tasks/feedback/generate_improved_answers.py
+++ b/cognee/tasks/feedback/generate_improved_answers.py
@@ -61,7 +61,7 @@ async def _generate_improved_answer_for_single_interaction(
)
retrieved_context = await retriever.get_context(query_text)
- completion = await retriever.get_structured_completion(
+ completion = await retriever.get_completion(
query=query_text,
context=retrieved_context,
response_model=ImprovedAnswerResponse,
@@ -70,9 +70,9 @@ async def _generate_improved_answer_for_single_interaction(
new_context_text = await retriever.resolve_edges_to_text(retrieved_context)
if completion:
- enrichment.improved_answer = completion.answer
+ enrichment.improved_answer = completion[0].answer
enrichment.new_context = new_context_text
- enrichment.explanation = completion.explanation
+ enrichment.explanation = completion[0].explanation
return enrichment
else:
logger.warning(
diff --git a/cognee/tasks/ingestion/ingest_data.py b/cognee/tasks/ingestion/ingest_data.py
index 0572d0f1e..5987f38d5 100644
--- a/cognee/tasks/ingestion/ingest_data.py
+++ b/cognee/tasks/ingestion/ingest_data.py
@@ -99,7 +99,7 @@ async def ingest_data(
# data_id is the hash of original file contents + owner id to avoid duplicate data
- data_id = ingestion.identify(classified_data, user)
+ data_id = await ingestion.identify(classified_data, user)
original_file_metadata = classified_data.get_metadata()
# Find metadata from Cognee data storage text file
diff --git a/cognee/tasks/memify/__init__.py b/cognee/tasks/memify/__init__.py
index 692bac443..7e590ed47 100644
--- a/cognee/tasks/memify/__init__.py
+++ b/cognee/tasks/memify/__init__.py
@@ -1,2 +1,4 @@
from .extract_subgraph import extract_subgraph
from .extract_subgraph_chunks import extract_subgraph_chunks
+from .cognify_session import cognify_session
+from .extract_user_sessions import extract_user_sessions
diff --git a/cognee/tasks/memify/cognify_session.py b/cognee/tasks/memify/cognify_session.py
new file mode 100644
index 000000000..f53f9afb1
--- /dev/null
+++ b/cognee/tasks/memify/cognify_session.py
@@ -0,0 +1,41 @@
+import cognee
+
+from cognee.exceptions import CogneeValidationError, CogneeSystemError
+from cognee.shared.logging_utils import get_logger
+
+logger = get_logger("cognify_session")
+
+
+async def cognify_session(data, dataset_id=None):
+ """
+ Process and cognify session data into the knowledge graph.
+
+ Adds session content to cognee with a dedicated "user_sessions" node set,
+ then triggers the cognify pipeline to extract entities and relationships
+ from the session data.
+
+ Args:
+ data: Session string containing Question, Context, and Answer information.
+        dataset_id: Id of the dataset to add the session data to.
+
+ Raises:
+ CogneeValidationError: If data is None or empty.
+ CogneeSystemError: If cognee operations fail.
+ """
+ try:
+ if not data or (isinstance(data, str) and not data.strip()):
+ logger.warning("Empty session data provided to cognify_session task, skipping")
+ raise CogneeValidationError(message="Session data cannot be empty", log=False)
+
+ logger.info("Processing session data for cognification")
+
+ await cognee.add(data, dataset_id=dataset_id, node_set=["user_sessions_from_cache"])
+ logger.debug("Session data added to cognee with node_set: user_sessions")
+ await cognee.cognify(datasets=[dataset_id])
+ logger.info("Session data successfully cognified")
+
+ except CogneeValidationError:
+ raise
+ except Exception as e:
+ logger.error(f"Error cognifying session data: {str(e)}")
+ raise CogneeSystemError(message=f"Failed to cognify session data: {str(e)}", log=False)
diff --git a/cognee/tasks/memify/extract_user_sessions.py b/cognee/tasks/memify/extract_user_sessions.py
new file mode 100644
index 000000000..9779a363e
--- /dev/null
+++ b/cognee/tasks/memify/extract_user_sessions.py
@@ -0,0 +1,73 @@
+from typing import Optional, List
+
+from cognee.context_global_variables import session_user
+from cognee.exceptions import CogneeSystemError
+from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.users.models import User
+
+logger = get_logger("extract_user_sessions")
+
+
+async def extract_user_sessions(
+ data,
+ session_ids: Optional[List[str]] = None,
+):
+ """
+ Extract Q&A sessions for the current user from cache.
+
+ Retrieves all Q&A triplets from specified session IDs and yields them
+ as formatted strings combining question, context, and answer.
+
+ Args:
+ data: Data passed from memify. If empty dict ({}), no external data is provided.
+ session_ids: Optional list of specific session IDs to extract.
+
+ Yields:
+ String containing session ID and all Q&A pairs formatted.
+
+ Raises:
+ CogneeSystemError: If cache engine is unavailable or extraction fails.
+ """
+ try:
+ if not data or data == [{}]:
+ logger.info("Fetching session metadata for current user")
+
+ user: User = session_user.get()
+ if not user:
+ raise CogneeSystemError(message="No authenticated user found in context", log=False)
+
+ user_id = str(user.id)
+
+ cache_engine = get_cache_engine()
+ if cache_engine is None:
+ raise CogneeSystemError(
+ message="Cache engine not available for session extraction, please enable caching in order to have sessions to save",
+ log=False,
+ )
+
+ if session_ids:
+ for session_id in session_ids:
+ try:
+ qa_data = await cache_engine.get_all_qas(user_id, session_id)
+ if qa_data:
+ logger.info(f"Extracted session {session_id} with {len(qa_data)} Q&A pairs")
+ session_string = f"Session ID: {session_id}\n\n"
+ for qa_pair in qa_data:
+ question = qa_pair.get("question", "")
+ answer = qa_pair.get("answer", "")
+ session_string += f"Question: {question}\n\nAnswer: {answer}\n\n"
+ yield session_string
+ except Exception as e:
+ logger.warning(f"Failed to extract session {session_id}: {str(e)}")
+ continue
+ else:
+ logger.info(
+ "No specific session_ids provided. Please specify which sessions to extract."
+ )
+
+ except CogneeSystemError:
+ raise
+ except Exception as e:
+ logger.error(f"Error extracting user sessions: {str(e)}")
+ raise CogneeSystemError(message=f"Failed to extract user sessions: {str(e)}", log=False)
diff --git a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py
index 902789c80..b0ec3a5b4 100644
--- a/cognee/tasks/storage/index_data_points.py
+++ b/cognee/tasks/storage/index_data_points.py
@@ -8,47 +8,58 @@ logger = get_logger("index_data_points")
async def index_data_points(data_points: list[DataPoint]):
- created_indexes = {}
- index_points = {}
+ """Index data points in the vector engine by creating embeddings for specified fields.
+
+ Process:
+ 1. Groups data points into a nested dict: {type_name: {field_name: [points]}}
+ 2. Creates vector indexes for each (type, field) combination on first encounter
+ 3. Batches points per (type, field) and creates async indexing tasks
+ 4. Executes all indexing tasks in parallel for efficient embedding generation
+
+ Args:
+ data_points: List of DataPoint objects to index. Each DataPoint's metadata must
+ contain an 'index_fields' list specifying which fields to embed.
+
+ Returns:
+ The original data_points list.
+ """
+ data_points_by_type = {}
vector_engine = get_vector_engine()
for data_point in data_points:
data_point_type = type(data_point)
+ type_name = data_point_type.__name__
for field_name in data_point.metadata["index_fields"]:
if getattr(data_point, field_name, None) is None:
continue
- index_name = f"{data_point_type.__name__}_{field_name}"
+ if type_name not in data_points_by_type:
+ data_points_by_type[type_name] = {}
- if index_name not in created_indexes:
- await vector_engine.create_vector_index(data_point_type.__name__, field_name)
- created_indexes[index_name] = True
-
- if index_name not in index_points:
- index_points[index_name] = []
+ if field_name not in data_points_by_type[type_name]:
+ await vector_engine.create_vector_index(type_name, field_name)
+ data_points_by_type[type_name][field_name] = []
indexed_data_point = data_point.model_copy()
indexed_data_point.metadata["index_fields"] = [field_name]
- index_points[index_name].append(indexed_data_point)
+ data_points_by_type[type_name][field_name].append(indexed_data_point)
- tasks: list[asyncio.Task] = []
batch_size = vector_engine.embedding_engine.get_batch_size()
- for index_name_and_field, points in index_points.items():
- first = index_name_and_field.index("_")
- index_name = index_name_and_field[:first]
- field_name = index_name_and_field[first + 1 :]
+ batches = (
+ (type_name, field_name, points[i : i + batch_size])
+ for type_name, fields in data_points_by_type.items()
+ for field_name, points in fields.items()
+ for i in range(0, len(points), batch_size)
+ )
- # Create embedding requests per batch to run in parallel later
- for i in range(0, len(points), batch_size):
- batch = points[i : i + batch_size]
- tasks.append(
- asyncio.create_task(vector_engine.index_data_points(index_name, field_name, batch))
- )
+ tasks = [
+ asyncio.create_task(vector_engine.index_data_points(type_name, field_name, batch_points))
+ for type_name, field_name, batch_points in batches
+ ]
- # Run all embedding requests in parallel
await asyncio.gather(*tasks)
return data_points
diff --git a/cognee/tasks/storage/index_graph_edges.py b/cognee/tasks/storage/index_graph_edges.py
index 4fa8cfc75..03b5a25a5 100644
--- a/cognee/tasks/storage/index_graph_edges.py
+++ b/cognee/tasks/storage/index_graph_edges.py
@@ -1,17 +1,44 @@
-import asyncio
+from collections import Counter
+from typing import Optional, Dict, Any, List, Tuple, Union
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
from cognee.shared.logging_utils import get_logger
-from collections import Counter
-from typing import Optional, Dict, Any, List, Tuple, Union
-from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.models.EdgeType import EdgeType
from cognee.infrastructure.databases.graph.graph_db_interface import EdgeData
+from cognee.tasks.storage.index_data_points import index_data_points
logger = get_logger()
+def _get_edge_text(item: dict) -> str:
+ """Extract edge text for embedding - prefers edge_text field with fallback."""
+ if "edge_text" in item:
+ return item["edge_text"]
+
+ if "relationship_name" in item:
+ return item["relationship_name"]
+
+ return ""
+
+
+def create_edge_type_datapoints(edges_data) -> list[EdgeType]:
+ """Transform raw edge data into EdgeType datapoints."""
+ edge_texts = [
+ _get_edge_text(item)
+ for edge in edges_data
+ for item in edge
+ if isinstance(item, dict) and "relationship_name" in item
+ ]
+
+ edge_types = Counter(edge_texts)
+
+ return [
+ EdgeType(id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count)
+ for text, count in edge_types.items()
+ ]
+
+
async def index_graph_edges(
edges_data: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]] = None,
):
@@ -23,24 +50,17 @@ async def index_graph_edges(
the `relationship_name` field.
Steps:
- 1. Initialize the vector engine and graph engine.
- 2. Retrieve graph edge data and count relationship types (`relationship_name`).
- 3. Create vector indexes for `relationship_name` if they don't exist.
- 4. Transform the counted relationships into `EdgeType` objects.
- 5. Index the transformed data points in the vector engine.
+ 1. Initialize the graph engine if needed and retrieve edge data.
+ 2. Transform edge data into EdgeType datapoints.
+ 3. Index the EdgeType datapoints using the standard indexing function.
Raises:
- RuntimeError: If initialization of the vector engine or graph engine fails.
+ RuntimeError: If initialization of the graph engine fails.
Returns:
None
"""
try:
- created_indexes = {}
- index_points = {}
-
- vector_engine = get_vector_engine()
-
if edges_data is None:
graph_engine = await get_graph_engine()
_, edges_data = await graph_engine.get_graph_data()
@@ -51,47 +71,7 @@ async def index_graph_edges(
logger.error("Failed to initialize engines: %s", e)
raise RuntimeError("Initialization error") from e
- edge_types = Counter(
- item.get("relationship_name")
- for edge in edges_data
- for item in edge
- if isinstance(item, dict) and "relationship_name" in item
- )
-
- for text, count in edge_types.items():
- edge = EdgeType(
- id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count
- )
- data_point_type = type(edge)
-
- for field_name in edge.metadata["index_fields"]:
- index_name = f"{data_point_type.__name__}.{field_name}"
-
- if index_name not in created_indexes:
- await vector_engine.create_vector_index(data_point_type.__name__, field_name)
- created_indexes[index_name] = True
-
- if index_name not in index_points:
- index_points[index_name] = []
-
- indexed_data_point = edge.model_copy()
- indexed_data_point.metadata["index_fields"] = [field_name]
- index_points[index_name].append(indexed_data_point)
-
- # Get maximum batch size for embedding model
- batch_size = vector_engine.embedding_engine.get_batch_size()
- tasks: list[asyncio.Task] = []
-
- for index_name, indexable_points in index_points.items():
- index_name, field_name = index_name.split(".")
-
- # Create embedding tasks to run in parallel later
- for start in range(0, len(indexable_points), batch_size):
- batch = indexable_points[start : start + batch_size]
-
- tasks.append(vector_engine.index_data_points(index_name, field_name, batch))
-
- # Start all embedding tasks and wait for completion
- await asyncio.gather(*tasks)
+ edge_type_datapoints = create_edge_type_datapoints(edges_data)
+ await index_data_points(edge_type_datapoints)
return None
diff --git a/cognee/tests/integration/documents/CsvDocument_test.py b/cognee/tests/integration/documents/CsvDocument_test.py
new file mode 100644
index 000000000..421bb81bd
--- /dev/null
+++ b/cognee/tests/integration/documents/CsvDocument_test.py
@@ -0,0 +1,70 @@
+import os
+import sys
+import uuid
+import pytest
+import pathlib
+from unittest.mock import patch
+
+from cognee.modules.chunking.CsvChunker import CsvChunker
+from cognee.modules.data.processing.document_types.CsvDocument import CsvDocument
+from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine
+from cognee.tests.integration.documents.async_gen_zip import async_gen_zip
+
+chunk_by_row_module = sys.modules.get("cognee.tasks.chunks.chunk_by_row")
+
+
+GROUND_TRUTH = {
+ "chunk_size_10": [
+ {"token_count": 9, "len_text": 26, "cut_type": "row_cut", "chunk_index": 0},
+ {"token_count": 6, "len_text": 29, "cut_type": "row_end", "chunk_index": 1},
+ {"token_count": 9, "len_text": 25, "cut_type": "row_cut", "chunk_index": 2},
+ {"token_count": 6, "len_text": 30, "cut_type": "row_end", "chunk_index": 3},
+ ],
+ "chunk_size_128": [
+ {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 0},
+ {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 1},
+ ],
+}
+
+
+@pytest.mark.parametrize(
+ "input_file,chunk_size",
+ [("example_with_header.csv", 10), ("example_with_header.csv", 128)],
+)
+@patch.object(chunk_by_row_module, "get_embedding_engine", side_effect=mock_get_embedding_engine)
+@pytest.mark.asyncio
+async def test_CsvDocument(mock_engine, input_file, chunk_size):
+ # Define file paths of test data
+ csv_file_path = os.path.join(
+ pathlib.Path(__file__).parent.parent.parent,
+ "test_data",
+ input_file,
+ )
+
+ # Define test documents
+ csv_document = CsvDocument(
+ id=uuid.uuid4(),
+ name="example_with_header.csv",
+ raw_data_location=csv_file_path,
+ external_metadata="",
+ mime_type="text/csv",
+ )
+
+ # TEST CSV
+ ground_truth_key = f"chunk_size_{chunk_size}"
+ async for ground_truth, row_data in async_gen_zip(
+ GROUND_TRUTH[ground_truth_key],
+ csv_document.read(chunker_cls=CsvChunker, max_chunk_size=chunk_size),
+ ):
+ assert ground_truth["token_count"] == row_data.chunk_size, (
+ f'{ground_truth["token_count"] = } != {row_data.chunk_size = }'
+ )
+ assert ground_truth["len_text"] == len(row_data.text), (
+ f'{ground_truth["len_text"] = } != {len(row_data.text) = }'
+ )
+ assert ground_truth["cut_type"] == row_data.cut_type, (
+ f'{ground_truth["cut_type"] = } != {row_data.cut_type = }'
+ )
+ assert ground_truth["chunk_index"] == row_data.chunk_index, (
+ f'{ground_truth["chunk_index"] = } != {row_data.chunk_index = }'
+ )
diff --git a/cognee/tests/tasks/entity_extraction/entity_extraction_test.py b/cognee/tests/tasks/entity_extraction/entity_extraction_test.py
index 39e883e09..41a9254ca 100644
--- a/cognee/tests/tasks/entity_extraction/entity_extraction_test.py
+++ b/cognee/tests/tasks/entity_extraction/entity_extraction_test.py
@@ -55,7 +55,7 @@ async def main():
classified_data = ingestion.classify(file)
# data_id is the hash of original file contents + owner id to avoid duplicate data
- data_id = ingestion.identify(classified_data, await get_default_user())
+ data_id = await ingestion.identify(classified_data, await get_default_user())
await cognee.add(file_path)
diff --git a/cognee/tests/test_add_docling_document.py b/cognee/tests/test_add_docling_document.py
index 2c82af66f..c5aa4e9d1 100644
--- a/cognee/tests/test_add_docling_document.py
+++ b/cognee/tests/test_add_docling_document.py
@@ -39,12 +39,12 @@ async def main():
answer = await cognee.search("Do programmers change light bulbs?")
assert len(answer) != 0
- lowercase_answer = answer[0].lower()
+ lowercase_answer = answer[0]["search_result"][0].lower()
assert ("no" in lowercase_answer) or ("none" in lowercase_answer)
answer = await cognee.search("What colours are there in the presentation table?")
assert len(answer) != 0
- lowercase_answer = answer[0].lower()
+ lowercase_answer = answer[0]["search_result"][0].lower()
assert (
("red" in lowercase_answer)
and ("blue" in lowercase_answer)
diff --git a/cognee/tests/test_cognee_server_start.py b/cognee/tests/test_cognee_server_start.py
index ab68a8ef1..ddffe53a4 100644
--- a/cognee/tests/test_cognee_server_start.py
+++ b/cognee/tests/test_cognee_server_start.py
@@ -7,6 +7,7 @@ import requests
from pathlib import Path
import sys
import uuid
+import json
class TestCogneeServerStart(unittest.TestCase):
@@ -90,12 +91,71 @@ class TestCogneeServerStart(unittest.TestCase):
)
}
- payload = {"datasets": [dataset_name]}
+ ontology_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
+ payload = {"datasets": [dataset_name], "ontology_key": [ontology_key]}
add_response = requests.post(url, headers=headers, data=form_data, files=file, timeout=50)
if add_response.status_code not in [200, 201]:
add_response.raise_for_status()
+ ontology_content = b"""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ A failure caused by physical components.
+
+
+
+
+ An error caused by software logic or configuration.
+
+
+
+ A human being or individual.
+
+
+
+
+ Programmers
+
+
+
+
+
+ Hardware Problem
+
+
+ """
+
+ ontology_response = requests.post(
+ "http://127.0.0.1:8000/api/v1/ontologies",
+ headers=headers,
+ files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
+ data={
+ "ontology_key": json.dumps([ontology_key]),
+ "description": json.dumps(["Test ontology"]),
+ },
+ )
+ self.assertEqual(ontology_response.status_code, 200)
+
# Cognify request
url = "http://127.0.0.1:8000/api/v1/cognify"
headers = {
@@ -107,6 +167,29 @@ class TestCogneeServerStart(unittest.TestCase):
if cognify_response.status_code not in [200, 201]:
cognify_response.raise_for_status()
+ datasets_response = requests.get("http://127.0.0.1:8000/api/v1/datasets", headers=headers)
+
+ datasets = datasets_response.json()
+ dataset_id = None
+ for dataset in datasets:
+ if dataset["name"] == dataset_name:
+ dataset_id = dataset["id"]
+ break
+
+ graph_response = requests.get(
+ f"http://127.0.0.1:8000/api/v1/datasets/{dataset_id}/graph", headers=headers
+ )
+ self.assertEqual(graph_response.status_code, 200)
+
+ graph_data = graph_response.json()
+ ontology_nodes = [
+ node for node in graph_data.get("nodes") if node.get("properties").get("ontology_valid")
+ ]
+
+ self.assertGreater(
+ len(ontology_nodes), 0, "No ontology nodes found - ontology was not integrated"
+ )
+
# TODO: Add test to verify cognify pipeline is complete before testing search
# Search request
diff --git a/cognee/tests/test_conversation_history.py b/cognee/tests/test_conversation_history.py
index 30bb54ef1..783baf563 100644
--- a/cognee/tests/test_conversation_history.py
+++ b/cognee/tests/test_conversation_history.py
@@ -16,9 +16,11 @@ import cognee
import pathlib
from cognee.infrastructure.databases.cache import get_cache_engine
+from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.search.types import SearchType
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.methods import get_default_user
+from collections import Counter
logger = get_logger()
@@ -54,10 +56,10 @@ async def main():
"""DataCo is a data analytics company. They help businesses make sense of their data."""
)
- await cognee.add(text_1, dataset_name)
- await cognee.add(text_2, dataset_name)
+ await cognee.add(data=text_1, dataset_name=dataset_name)
+ await cognee.add(data=text_2, dataset_name=dataset_name)
- await cognee.cognify([dataset_name])
+ await cognee.cognify(datasets=[dataset_name])
user = await get_default_user()
@@ -188,7 +190,6 @@ async def main():
f"GRAPH_SUMMARY_COMPLETION should return non-empty list, got: {result_summary!r}"
)
- # Verify saved
history_summary = await cache_engine.get_latest_qa(str(user.id), session_id_summary, last_n=10)
our_qa_summary = [
h for h in history_summary if h["question"] == "What are the key points about TechCorp?"
@@ -228,6 +229,46 @@ async def main():
assert "CONTEXT:" in formatted_history, "Formatted history should contain 'CONTEXT:' prefix"
assert "ANSWER:" in formatted_history, "Formatted history should contain 'ANSWER:' prefix"
+ from cognee.memify_pipelines.persist_sessions_in_knowledge_graph import (
+ persist_sessions_in_knowledge_graph_pipeline,
+ )
+
+ logger.info("Starting persist_sessions_in_knowledge_graph tests")
+
+ await persist_sessions_in_knowledge_graph_pipeline(
+ user=user,
+ session_ids=[session_id_1, session_id_2],
+ dataset=dataset_name,
+ run_in_background=False,
+ )
+
+ graph_engine = await get_graph_engine()
+ graph = await graph_engine.get_graph_data()
+
+ type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0])
+
+ "Tests the correct number of NodeSet nodes after session persistence"
+ assert type_counts.get("NodeSet", 0) == 1, (
+ f"Number of NodeSets in the graph is incorrect, found {type_counts.get('NodeSet', 0)} but there should be exactly 1."
+ )
+
+ "Tests the correct number of DocumentChunk nodes after session persistence"
+ assert type_counts.get("DocumentChunk", 0) == 4, (
+ f"Number of DocumentChunk ndoes in the graph is incorrect, found {type_counts.get('DocumentChunk', 0)} but there should be exactly 4 (2 original documents, 2 sessions)."
+ )
+
+ from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
+
+ vector_engine = get_vector_engine()
+ collection_size = await vector_engine.search(
+ collection_name="DocumentChunk_text",
+ query_text="test",
+ limit=1000,
+ )
+ assert len(collection_size) == 4, (
+ f"DocumentChunk_text collection should have exactly 4 embeddings, found {len(collection_size)}"
+ )
+
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
diff --git a/cognee/tests/test_data/example_with_header.csv b/cognee/tests/test_data/example_with_header.csv
new file mode 100644
index 000000000..dc900e5ef
--- /dev/null
+++ b/cognee/tests/test_data/example_with_header.csv
@@ -0,0 +1,3 @@
+id,name,age,city,country
+1,Eric,30,Beijing,China
+2,Joe,35,Berlin,Germany
diff --git a/cognee/tests/test_edge_ingestion.py b/cognee/tests/test_edge_ingestion.py
index 5b23f7819..0d1407fab 100755
--- a/cognee/tests/test_edge_ingestion.py
+++ b/cognee/tests/test_edge_ingestion.py
@@ -52,6 +52,33 @@ async def test_edge_ingestion():
edge_type_counts = Counter(edge_type[2] for edge_type in graph[1])
+ "Tests edge_text presence and format"
+ contains_edges = [edge for edge in graph[1] if edge[2] == "contains"]
+ assert len(contains_edges) > 0, "Expected at least one contains edge for edge_text verification"
+
+ edge_properties = contains_edges[0][3]
+ assert "edge_text" in edge_properties, "Expected edge_text in edge properties"
+
+ edge_text = edge_properties["edge_text"]
+ assert "relationship_name: contains" in edge_text, (
+ f"Expected 'relationship_name: contains' in edge_text, got: {edge_text}"
+ )
+ assert "entity_name:" in edge_text, f"Expected 'entity_name:' in edge_text, got: {edge_text}"
+ assert "entity_description:" in edge_text, (
+ f"Expected 'entity_description:' in edge_text, got: {edge_text}"
+ )
+
+ all_edge_texts = [
+ edge[3].get("edge_text", "") for edge in contains_edges if "edge_text" in edge[3]
+ ]
+ expected_entities = ["dave", "ana", "bob", "dexter", "apples", "cognee"]
+ found_entity = any(
+ any(entity in text.lower() for entity in expected_entities) for text in all_edge_texts
+ )
+ assert found_entity, (
+ f"Expected to find at least one entity name in edge_text: {all_edge_texts[:3]}"
+ )
+
"Tests the presence of basic nested edges"
for basic_nested_edge in basic_nested_edges:
assert edge_type_counts.get(basic_nested_edge, 0) >= 1, (
diff --git a/cognee/tests/test_feedback_enrichment.py b/cognee/tests/test_feedback_enrichment.py
index 02d90db32..378cb0e45 100644
--- a/cognee/tests/test_feedback_enrichment.py
+++ b/cognee/tests/test_feedback_enrichment.py
@@ -133,7 +133,7 @@ async def main():
extraction_tasks=extraction_tasks,
enrichment_tasks=enrichment_tasks,
data=[{}],
- dataset="feedback_enrichment_test_memify",
+ dataset=dataset_name,
)
nodes_after, edges_after = await graph_engine.get_graph_data()
diff --git a/cognee/tests/test_library.py b/cognee/tests/test_library.py
index 81f81ee61..893b836c0 100755
--- a/cognee/tests/test_library.py
+++ b/cognee/tests/test_library.py
@@ -90,15 +90,17 @@ async def main():
)
search_results = await cognee.search(
- query_type=SearchType.GRAPH_COMPLETION, query_text="What information do you contain?"
+ query_type=SearchType.GRAPH_COMPLETION,
+ query_text="What information do you contain?",
+ dataset_ids=[pipeline_run_obj.dataset_id],
)
- assert "Mark" in search_results[0], (
+ assert "Mark" in search_results[0]["search_result"][0], (
"Failed to update document, no mention of Mark in search results"
)
- assert "Cindy" in search_results[0], (
+ assert "Cindy" in search_results[0]["search_result"][0], (
"Failed to update document, no mention of Cindy in search results"
)
- assert "Artificial intelligence" not in search_results[0], (
+ assert "Artificial intelligence" not in search_results[0]["search_result"][0], (
"Failed to update document, Artificial intelligence still mentioned in search results"
)
diff --git a/cognee/tests/test_load.py b/cognee/tests/test_load.py
new file mode 100644
index 000000000..b38466bc7
--- /dev/null
+++ b/cognee/tests/test_load.py
@@ -0,0 +1,62 @@
+import os
+import pathlib
+import asyncio
+import time
+
+import cognee
+from cognee.modules.search.types import SearchType
+from cognee.shared.logging_utils import get_logger
+
+logger = get_logger()
+
+
+async def process_and_search(num_of_searches):
+ start_time = time.time()
+
+ await cognee.cognify()
+
+ await asyncio.gather(
+ *[
+ cognee.search(
+ query_text="Tell me about the document", query_type=SearchType.GRAPH_COMPLETION
+ )
+ for _ in range(num_of_searches)
+ ]
+ )
+
+ end_time = time.time()
+
+ return end_time - start_time
+
+
+async def main():
+ data_directory_path = os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_load")
+ cognee.config.data_root_directory(data_directory_path)
+
+ cognee_directory_path = os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_load")
+ cognee.config.system_root_directory(cognee_directory_path)
+
+ num_of_pdfs = 10
+ num_of_reps = 5
+ upper_boundary_minutes = 10
+ average_minutes = 8
+
+ recorded_times = []
+ for _ in range(num_of_reps):
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+
+ s3_input = "s3://cognee-test-load-s3-bucket"
+ await cognee.add(s3_input)
+
+ recorded_times.append(await process_and_search(num_of_pdfs))
+
+ average_recorded_time = sum(recorded_times) / len(recorded_times)
+
+ assert average_recorded_time <= average_minutes * 60
+
+ assert all(rec_time <= upper_boundary_minutes * 60 for rec_time in recorded_times)
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/cognee/tests/test_multi_tenancy.py b/cognee/tests/test_multi_tenancy.py
new file mode 100644
index 000000000..7cdcda8d8
--- /dev/null
+++ b/cognee/tests/test_multi_tenancy.py
@@ -0,0 +1,165 @@
+import cognee
+import pytest
+
+from cognee.modules.users.exceptions import PermissionDeniedError
+from cognee.modules.users.tenants.methods import select_tenant
+from cognee.modules.users.methods import get_user
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.search.types import SearchType
+from cognee.modules.users.methods import create_user
+from cognee.modules.users.permissions.methods import authorized_give_permission_on_datasets
+from cognee.modules.users.roles.methods import add_user_to_role
+from cognee.modules.users.roles.methods import create_role
+from cognee.modules.users.tenants.methods import create_tenant
+from cognee.modules.users.tenants.methods import add_user_to_tenant
+from cognee.modules.engine.operations.setup import setup
+from cognee.shared.logging_utils import setup_logging, CRITICAL
+
+logger = get_logger()
+
+
+async def main():
+ # Create a clean slate for cognee -- reset data and system state
+ print("Resetting cognee data...")
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ print("Data reset complete.\n")
+
+ # Set up the necessary databases and tables for user management.
+ await setup()
+
+ # Add document for user_1, add it under dataset name AI
+ text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
+ At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages
+ this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the
+ preparation and manipulation of quantum state"""
+
+ print("Creating user_1: user_1@example.com")
+ user_1 = await create_user("user_1@example.com", "example")
+ await cognee.add([text], dataset_name="AI", user=user_1)
+
+ print("\nCreating user_2: user_2@example.com")
+ user_2 = await create_user("user_2@example.com", "example")
+
+ # Run cognify for both datasets as the appropriate user/owner
+ print("\nCreating different datasets for user_1 (AI dataset) and user_2 (QUANTUM dataset)")
+ ai_cognify_result = await cognee.cognify(["AI"], user=user_1)
+
+ # Extract dataset_ids from cognify results
+ def extract_dataset_id_from_cognify(cognify_result):
+ """Extract dataset_id from cognify output dictionary"""
+ for dataset_id, pipeline_result in cognify_result.items():
+ return dataset_id # Return the first dataset_id
+ return None
+
+ # Get dataset IDs from cognify results
+    # Note: When we want to work with datasets from other users (search, add, cognify, etc.) we must supply dataset
+    # information through dataset_id; using the dataset name only looks for datasets owned by the current user
+ ai_dataset_id = extract_dataset_id_from_cognify(ai_cognify_result)
+
+ # We can see here that user_1 can read his own dataset (AI dataset)
+ search_results = await cognee.search(
+ query_type=SearchType.GRAPH_COMPLETION,
+ query_text="What is in the document?",
+ user=user_1,
+ datasets=[ai_dataset_id],
+ )
+
+ # Verify that user_2 cannot access user_1's dataset without permission
+ with pytest.raises(PermissionDeniedError):
+ search_results = await cognee.search(
+ query_type=SearchType.GRAPH_COMPLETION,
+ query_text="What is in the document?",
+ user=user_2,
+ datasets=[ai_dataset_id],
+ )
+
+ # Create new tenant and role, add user_2 to tenant and role
+ tenant_id = await create_tenant("CogneeLab", user_1.id)
+ await select_tenant(user_id=user_1.id, tenant_id=tenant_id)
+ role_id = await create_role(role_name="Researcher", owner_id=user_1.id)
+ await add_user_to_tenant(
+ user_id=user_2.id, tenant_id=tenant_id, owner_id=user_1.id, set_as_active_tenant=True
+ )
+ await add_user_to_role(user_id=user_2.id, role_id=role_id, owner_id=user_1.id)
+
+ # Assert that user_1 cannot give permissions on his dataset to role before switching to the correct tenant
+ # AI dataset was made with default tenant and not CogneeLab tenant
+ with pytest.raises(PermissionDeniedError):
+ await authorized_give_permission_on_datasets(
+ role_id,
+ [ai_dataset_id],
+ "read",
+ user_1.id,
+ )
+
+ # We need to refresh the user object with changes made when switching tenants
+ user_1 = await get_user(user_1.id)
+ await cognee.add([text], dataset_name="AI_COGNEE_LAB", user=user_1)
+ ai_cognee_lab_cognify_result = await cognee.cognify(["AI_COGNEE_LAB"], user=user_1)
+
+ ai_cognee_lab_dataset_id = extract_dataset_id_from_cognify(ai_cognee_lab_cognify_result)
+
+ await authorized_give_permission_on_datasets(
+ role_id,
+ [ai_cognee_lab_dataset_id],
+ "read",
+ user_1.id,
+ )
+
+ search_results = await cognee.search(
+ query_type=SearchType.GRAPH_COMPLETION,
+ query_text="What is in the document?",
+ user=user_2,
+ dataset_ids=[ai_cognee_lab_dataset_id],
+ )
+ for result in search_results:
+ print(f"{result}\n")
+
+ # Let's test changing tenants
+ tenant_id = await create_tenant("CogneeLab2", user_1.id)
+ await select_tenant(user_id=user_1.id, tenant_id=tenant_id)
+
+ user_1 = await get_user(user_1.id)
+ await cognee.add([text], dataset_name="AI_COGNEE_LAB", user=user_1)
+ await cognee.cognify(["AI_COGNEE_LAB"], user=user_1)
+
+ search_results = await cognee.search(
+ query_type=SearchType.GRAPH_COMPLETION,
+ query_text="What is in the document?",
+ user=user_1,
+ )
+
+ # Assert only AI_COGNEE_LAB dataset from CogneeLab2 tenant is visible as the currently selected tenant
+ assert len(search_results) == 1, (
+ f"Search results must only contain one dataset from current tenant: {search_results}"
+ )
+ assert search_results[0]["dataset_name"] == "AI_COGNEE_LAB", (
+ f"Dict must contain dataset name 'AI_COGNEE_LAB': {search_results[0]}"
+ )
+ assert search_results[0]["dataset_tenant_id"] == user_1.tenant_id, (
+ f"Dataset tenant_id must be same as user_1 tenant_id: {search_results[0]}"
+ )
+
+ # Switch back to no tenant (default tenant)
+ await select_tenant(user_id=user_1.id, tenant_id=None)
+ # Refresh user_1 object
+ user_1 = await get_user(user_1.id)
+ search_results = await cognee.search(
+ query_type=SearchType.GRAPH_COMPLETION,
+ query_text="What is in the document?",
+ user=user_1,
+ )
+ assert len(search_results) == 1, (
+ f"Search results must only contain one dataset from default tenant: {search_results}"
+ )
+ assert search_results[0]["dataset_name"] == "AI", (
+ f"Dict must contain dataset name 'AI': {search_results[0]}"
+ )
+
+
+if __name__ == "__main__":
+ import asyncio
+
+ logger = setup_logging(log_level=CRITICAL)
+ asyncio.run(main())
diff --git a/cognee/tests/test_parallel_databases.py b/cognee/tests/test_parallel_databases.py
index 9a590921a..3164206ed 100755
--- a/cognee/tests/test_parallel_databases.py
+++ b/cognee/tests/test_parallel_databases.py
@@ -33,11 +33,13 @@ async def main():
"vector_db_url": "cognee1.test",
"vector_db_key": "",
"vector_db_provider": "lancedb",
+ "vector_db_name": "",
}
task_2_config = {
"vector_db_url": "cognee2.test",
"vector_db_key": "",
"vector_db_provider": "lancedb",
+ "vector_db_name": "",
}
task_1_graph_config = {
diff --git a/cognee/tests/test_relational_db_migration.py b/cognee/tests/test_relational_db_migration.py
index 2b69ce854..ae06e7c5d 100644
--- a/cognee/tests/test_relational_db_migration.py
+++ b/cognee/tests/test_relational_db_migration.py
@@ -1,6 +1,5 @@
import pathlib
import os
-from typing import List
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import (
get_migration_relational_engine,
@@ -10,7 +9,7 @@ from cognee.infrastructure.databases.vector.pgvector import (
create_db_and_tables as create_pgvector_db_and_tables,
)
from cognee.tasks.ingestion import migrate_relational_database
-from cognee.modules.search.types import SearchResult, SearchType
+from cognee.modules.search.types import SearchType
import cognee
@@ -27,6 +26,9 @@ def normalize_node_name(node_name: str) -> str:
async def setup_test_db():
+ # Disable backend access control to migrate relational data
+ os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
+
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
@@ -271,6 +273,55 @@ async def test_schema_only_migration():
print(f"Edge counts: {edge_counts}")
+async def test_search_result_quality():
+ from cognee.infrastructure.databases.relational import (
+ get_migration_relational_engine,
+ )
+
+ # Get relational database with original data
+ migration_engine = get_migration_relational_engine()
+ from sqlalchemy import text
+
+ async with migration_engine.engine.connect() as conn:
+ result = await conn.execute(
+ text("""
+ SELECT
+ c.CustomerId,
+ c.FirstName,
+ c.LastName,
+ GROUP_CONCAT(i.InvoiceId, ',') AS invoice_ids
+ FROM Customer AS c
+ LEFT JOIN Invoice AS i ON c.CustomerId = i.CustomerId
+ GROUP BY c.CustomerId, c.FirstName, c.LastName
+ """)
+ )
+
+ for row in result:
+ # Get expected invoice IDs from relational DB for each Customer
+ customer_id = row.CustomerId
+ invoice_ids = row.invoice_ids.split(",") if row.invoice_ids else []
+ print(f"Relational DB Customer {customer_id}: {invoice_ids}")
+
+ # Use Cognee search to get invoice IDs for the same Customer but by providing Customer name
+ search_results = await cognee.search(
+ query_type=SearchType.GRAPH_COMPLETION,
+ query_text=f"List me all the invoices of Customer:{row.FirstName} {row.LastName}.",
+ top_k=50,
+ system_prompt="Just return me the invoiceID as a number without any text. This is an example output: ['1', '2', '3']. Where 1, 2, 3 are invoiceIDs of an invoice",
+ )
+ print(f"Cognee search result: {search_results}")
+
+ import ast
+
+ lst = ast.literal_eval(search_results[0]) # converts string -> Python list
+            # Transform both lists to int for comparison, sorting, and type consistency
+ lst = sorted([int(x) for x in lst])
+ invoice_ids = sorted([int(x) for x in invoice_ids])
+ assert lst == invoice_ids, (
+ f"Search results {lst} do not match expected invoice IDs {invoice_ids} for Customer:{customer_id}"
+ )
+
+
async def test_migration_sqlite():
database_to_migrate_path = os.path.join(pathlib.Path(__file__).parent, "test_data/")
@@ -283,6 +334,7 @@ async def test_migration_sqlite():
)
await relational_db_migration()
+ await test_search_result_quality()
await test_schema_only_migration()
diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py
index e24abd0f5..bd11dc62e 100644
--- a/cognee/tests/test_search_db.py
+++ b/cognee/tests/test_search_db.py
@@ -146,7 +146,13 @@ async def main():
assert len(search_results) == 1, (
f"{name}: expected single-element list, got {len(search_results)}"
)
- text = search_results[0]
+
+ from cognee.context_global_variables import backend_access_control_enabled
+
+ if backend_access_control_enabled():
+ text = search_results[0]["search_result"][0]
+ else:
+ text = search_results[0]
assert isinstance(text, str), f"{name}: element should be a string"
assert text.strip(), f"{name}: string should not be empty"
assert "netherlands" in text.lower(), (
diff --git a/cognee/tests/unit/api/test_conditional_authentication_endpoints.py b/cognee/tests/unit/api/test_conditional_authentication_endpoints.py
index 2eabee91a..6cc37ef38 100644
--- a/cognee/tests/unit/api/test_conditional_authentication_endpoints.py
+++ b/cognee/tests/unit/api/test_conditional_authentication_endpoints.py
@@ -1,3 +1,4 @@
+import os
import pytest
from unittest.mock import patch, AsyncMock, MagicMock
from uuid import uuid4
@@ -5,8 +6,6 @@ from fastapi.testclient import TestClient
from types import SimpleNamespace
import importlib
-from cognee.api.client import app
-
# Fixtures for reuse across test classes
@pytest.fixture
@@ -32,6 +31,10 @@ def mock_authenticated_user():
)
+# To turn off authentication we need to set the environment variable before importing the module
+# Also both require_authentication and backend access control must be false
+os.environ["REQUIRE_AUTHENTICATION"] = "false"
+os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
@@ -40,6 +43,8 @@ class TestConditionalAuthenticationEndpoints:
@pytest.fixture
def client(self):
+ from cognee.api.client import app
+
"""Create a test client."""
return TestClient(app)
@@ -133,6 +138,8 @@ class TestConditionalAuthenticationBehavior:
@pytest.fixture
def client(self):
+ from cognee.api.client import app
+
return TestClient(app)
@pytest.mark.parametrize(
@@ -209,6 +216,8 @@ class TestConditionalAuthenticationErrorHandling:
@pytest.fixture
def client(self):
+ from cognee.api.client import app
+
return TestClient(app)
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
@@ -232,7 +241,7 @@ class TestConditionalAuthenticationErrorHandling:
# The exact error message may vary depending on the actual database connection
# The important thing is that we get a 500 error when user creation fails
- def test_current_environment_configuration(self):
+ def test_current_environment_configuration(self, client):
"""Test that current environment configuration is working properly."""
# This tests the actual module state without trying to change it
from cognee.modules.users.methods.get_authenticated_user import (
diff --git a/cognee/tests/unit/api/test_ontology_endpoint.py b/cognee/tests/unit/api/test_ontology_endpoint.py
new file mode 100644
index 000000000..af3a4d90e
--- /dev/null
+++ b/cognee/tests/unit/api/test_ontology_endpoint.py
@@ -0,0 +1,272 @@
+import pytest
+import uuid
+from fastapi.testclient import TestClient
+from unittest.mock import patch, Mock, AsyncMock
+from types import SimpleNamespace
+import importlib
+from cognee.api.client import app
+
+gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
+
+
+@pytest.fixture
+def client():
+ return TestClient(app)
+
+
+@pytest.fixture
+def mock_user():
+ user = Mock()
+ user.id = "test-user-123"
+ return user
+
+
+@pytest.fixture
+def mock_default_user():
+ """Mock default user for testing."""
+ return SimpleNamespace(
+ id=str(uuid.uuid4()),
+ email="default@example.com",
+ is_active=True,
+ tenant_id=str(uuid.uuid4()),
+ )
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_upload_ontology_success(mock_get_default_user, client, mock_default_user):
+ """Test successful ontology upload"""
+ import json
+
+ mock_get_default_user.return_value = mock_default_user
+ ontology_content = (
+ b""
+ )
+ unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
+
+ response = client.post(
+ "/api/v1/ontologies",
+ files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
+ data={"ontology_key": json.dumps([unique_key]), "description": json.dumps(["Test"])},
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+ assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key
+ assert "uploaded_at" in data["uploaded_ontologies"][0]
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_default_user):
+ """Test 400 response for non-.owl files"""
+ mock_get_default_user.return_value = mock_default_user
+ unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
+ response = client.post(
+ "/api/v1/ontologies",
+ files={"ontology_file": ("test.txt", b"not xml")},
+ data={"ontology_key": unique_key},
+ )
+ assert response.status_code == 400
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_upload_ontology_missing_data(mock_get_default_user, client, mock_default_user):
+ """Test 400 response for missing file or key"""
+ import json
+
+ mock_get_default_user.return_value = mock_default_user
+ # Missing file
+ response = client.post("/api/v1/ontologies", data={"ontology_key": json.dumps(["test"])})
+ assert response.status_code == 400
+
+ # Missing key
+ response = client.post(
+ "/api/v1/ontologies", files=[("ontology_file", ("test.owl", b"xml", "application/xml"))]
+ )
+ assert response.status_code == 400
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_default_user):
+ """Test behavior when default user is provided (no explicit authentication)"""
+ import json
+
+ unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
+ mock_get_default_user.return_value = mock_default_user
+ response = client.post(
+ "/api/v1/ontologies",
+ files=[("ontology_file", ("test.owl", b"", "application/xml"))],
+ data={"ontology_key": json.dumps([unique_key])},
+ )
+
+ # The current system provides a default user when no explicit authentication is given
+ # This test verifies the system works with conditional authentication
+ assert response.status_code == 200
+ data = response.json()
+ assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key
+ assert "uploaded_at" in data["uploaded_ontologies"][0]
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_user):
+ """Test uploading multiple ontology files in single request"""
+ import io
+
+ mock_get_default_user.return_value = mock_default_user
+ # Create mock files
+ file1_content = b""
+ file2_content = b""
+
+ files = [
+ ("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")),
+ ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")),
+ ]
+ data = {
+ "ontology_key": '["vehicles", "manufacturers"]',
+ "descriptions": '["Base vehicles", "Car manufacturers"]',
+ }
+
+ response = client.post("/api/v1/ontologies", files=files, data=data)
+
+ assert response.status_code == 200
+ result = response.json()
+ assert "uploaded_ontologies" in result
+ assert len(result["uploaded_ontologies"]) == 2
+ assert result["uploaded_ontologies"][0]["ontology_key"] == "vehicles"
+ assert result["uploaded_ontologies"][1]["ontology_key"] == "manufacturers"
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_default_user):
+ """Test that upload endpoint accepts array parameters"""
+ import io
+ import json
+
+ mock_get_default_user.return_value = mock_default_user
+ file_content = b""
+
+ files = [("ontology_file", ("single.owl", io.BytesIO(file_content), "application/xml"))]
+ data = {
+ "ontology_key": json.dumps(["single_key"]),
+ "descriptions": json.dumps(["Single ontology"]),
+ }
+
+ response = client.post("/api/v1/ontologies", files=files, data=data)
+
+ assert response.status_code == 200
+ result = response.json()
+ assert result["uploaded_ontologies"][0]["ontology_key"] == "single_key"
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_default_user):
+ """Test cognify endpoint accepts multiple ontology keys"""
+ payload = {
+ "datasets": ["test_dataset"],
+ "ontology_key": ["ontology1", "ontology2"], # Array instead of string
+ "run_in_background": False,
+ }
+
+ response = client.post("/api/v1/cognify", json=payload)
+
+ # Should not fail due to ontology_key type
+ assert response.status_code in [200, 400, 409] # May fail for other reasons, not type
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_complete_multifile_workflow(mock_get_default_user, client, mock_default_user):
+ """Test complete workflow: upload multiple ontologies → cognify with multiple keys"""
+ import io
+ import json
+
+ mock_get_default_user.return_value = mock_default_user
+ # Step 1: Upload multiple ontologies
+ file1_content = b"""
+
+
+ """
+
+ file2_content = b"""
+
+
+ """
+
+ files = [
+ ("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")),
+ ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")),
+ ]
+ data = {
+ "ontology_key": json.dumps(["vehicles", "manufacturers"]),
+ "descriptions": json.dumps(["Vehicle ontology", "Manufacturer ontology"]),
+ }
+
+ upload_response = client.post("/api/v1/ontologies", files=files, data=data)
+ assert upload_response.status_code == 200
+
+ # Step 2: Verify ontologies are listed
+ list_response = client.get("/api/v1/ontologies")
+ assert list_response.status_code == 200
+ ontologies = list_response.json()
+ assert "vehicles" in ontologies
+ assert "manufacturers" in ontologies
+
+ # Step 3: Test cognify with multiple ontologies
+ cognify_payload = {
+ "datasets": ["test_dataset"],
+ "ontology_key": ["vehicles", "manufacturers"],
+ "run_in_background": False,
+ }
+
+ cognify_response = client.post("/api/v1/cognify", json=cognify_payload)
+ # Should not fail due to ontology handling (may fail for dataset reasons)
+ assert cognify_response.status_code != 400 # Not a validation error
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_multifile_error_handling(mock_get_default_user, client, mock_default_user):
+ """Test error handling for invalid multifile uploads"""
+ import io
+ import json
+
+ # Test mismatched array lengths
+ file_content = b""
+ files = [("ontology_file", ("test.owl", io.BytesIO(file_content), "application/xml"))]
+ data = {
+ "ontology_key": json.dumps(["key1", "key2"]), # 2 keys, 1 file
+ "descriptions": json.dumps(["desc1"]),
+ }
+
+ response = client.post("/api/v1/ontologies", files=files, data=data)
+ assert response.status_code == 400
+ assert "Number of keys must match number of files" in response.json()["error"]
+
+ # Test duplicate keys
+ files = [
+ ("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml")),
+ ("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml")),
+ ]
+ data = {
+ "ontology_key": json.dumps(["duplicate", "duplicate"]),
+ "descriptions": json.dumps(["desc1", "desc2"]),
+ }
+
+ response = client.post("/api/v1/ontologies", files=files, data=data)
+ assert response.status_code == 400
+ assert "Duplicate ontology keys not allowed" in response.json()["error"]
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_cognify_missing_ontology_key(mock_get_default_user, client, mock_default_user):
+ """Test cognify with non-existent ontology key"""
+ mock_get_default_user.return_value = mock_default_user
+
+ payload = {
+ "datasets": ["test_dataset"],
+ "ontology_key": ["nonexistent_key"],
+ "run_in_background": False,
+ }
+
+ response = client.post("/api/v1/cognify", json=payload)
+ assert response.status_code == 409
+ assert "Ontology key 'nonexistent_key' not found" in response.json()["error"]
diff --git a/cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py b/cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py
index a8d3bda82..837a9955c 100644
--- a/cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py
+++ b/cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py
@@ -8,6 +8,7 @@ def test_cache_config_defaults():
"""Test that CacheConfig has the correct default values."""
config = CacheConfig()
+ assert config.cache_backend == "fs"
assert config.caching is False
assert config.shared_kuzu_lock is False
assert config.cache_host == "localhost"
@@ -19,6 +20,7 @@ def test_cache_config_defaults():
def test_cache_config_custom_values():
"""Test that CacheConfig accepts custom values."""
config = CacheConfig(
+ cache_backend="redis",
caching=True,
shared_kuzu_lock=True,
cache_host="redis.example.com",
@@ -27,6 +29,7 @@ def test_cache_config_custom_values():
agentic_lock_timeout=180,
)
+ assert config.cache_backend == "redis"
assert config.caching is True
assert config.shared_kuzu_lock is True
assert config.cache_host == "redis.example.com"
@@ -38,6 +41,7 @@ def test_cache_config_custom_values():
def test_cache_config_to_dict():
"""Test the to_dict method returns all configuration values."""
config = CacheConfig(
+ cache_backend="fs",
caching=True,
shared_kuzu_lock=True,
cache_host="test-host",
@@ -49,6 +53,7 @@ def test_cache_config_to_dict():
config_dict = config.to_dict()
assert config_dict == {
+ "cache_backend": "fs",
"caching": True,
"shared_kuzu_lock": True,
"cache_host": "test-host",
diff --git a/cognee/tests/unit/infrastructure/databases/test_index_data_points.py b/cognee/tests/unit/infrastructure/databases/test_index_data_points.py
new file mode 100644
index 000000000..21a5695de
--- /dev/null
+++ b/cognee/tests/unit/infrastructure/databases/test_index_data_points.py
@@ -0,0 +1,27 @@
+import pytest
+from unittest.mock import AsyncMock, patch, MagicMock
+from cognee.tasks.storage.index_data_points import index_data_points
+from cognee.infrastructure.engine import DataPoint
+
+
+class TestDataPoint(DataPoint):
+ name: str
+ metadata: dict = {"index_fields": ["name"]}
+
+
+@pytest.mark.asyncio
+async def test_index_data_points_calls_vector_engine():
+ """Test that index_data_points creates vector index and indexes data."""
+ data_points = [TestDataPoint(name="test1")]
+
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.get_batch_size = MagicMock(return_value=100)
+
+ with patch.dict(
+ index_data_points.__globals__,
+ {"get_vector_engine": lambda: mock_vector_engine},
+ ):
+ await index_data_points(data_points)
+
+ assert mock_vector_engine.create_vector_index.await_count >= 1
+ assert mock_vector_engine.index_data_points.await_count >= 1
diff --git a/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py b/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py
index 48bbc53e3..cee0896c2 100644
--- a/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py
+++ b/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py
@@ -5,8 +5,7 @@ from cognee.tasks.storage.index_graph_edges import index_graph_edges
@pytest.mark.asyncio
async def test_index_graph_edges_success():
- """Test that index_graph_edges uses the index datapoints and creates vector index."""
- # Create the mocks for the graph and vector engines.
+ """Test that index_graph_edges retrieves edges and delegates to index_data_points."""
mock_graph_engine = AsyncMock()
mock_graph_engine.get_graph_data.return_value = (
None,
@@ -15,26 +14,23 @@ async def test_index_graph_edges_success():
[{"relationship_name": "rel2"}],
],
)
- mock_vector_engine = AsyncMock()
- mock_vector_engine.embedding_engine.get_batch_size = MagicMock(return_value=100)
+ mock_index_data_points = AsyncMock()
- # Patch the globals of the function so that when it does:
- # vector_engine = get_vector_engine()
- # graph_engine = await get_graph_engine()
- # it uses the mocked versions.
with patch.dict(
index_graph_edges.__globals__,
{
"get_graph_engine": AsyncMock(return_value=mock_graph_engine),
- "get_vector_engine": lambda: mock_vector_engine,
+ "index_data_points": mock_index_data_points,
},
):
await index_graph_edges()
- # Assertions on the mock calls.
mock_graph_engine.get_graph_data.assert_awaited_once()
- assert mock_vector_engine.create_vector_index.await_count == 1
- assert mock_vector_engine.index_data_points.await_count == 1
+ mock_index_data_points.assert_awaited_once()
+
+ call_args = mock_index_data_points.call_args[0][0]
+ assert len(call_args) == 2
+ assert all(hasattr(item, "relationship_name") for item in call_args)
@pytest.mark.asyncio
@@ -42,20 +38,22 @@ async def test_index_graph_edges_no_relationships():
"""Test that index_graph_edges handles empty relationships correctly."""
mock_graph_engine = AsyncMock()
mock_graph_engine.get_graph_data.return_value = (None, [])
- mock_vector_engine = AsyncMock()
+ mock_index_data_points = AsyncMock()
with patch.dict(
index_graph_edges.__globals__,
{
"get_graph_engine": AsyncMock(return_value=mock_graph_engine),
- "get_vector_engine": lambda: mock_vector_engine,
+ "index_data_points": mock_index_data_points,
},
):
await index_graph_edges()
mock_graph_engine.get_graph_data.assert_awaited_once()
- mock_vector_engine.create_vector_index.assert_not_awaited()
- mock_vector_engine.index_data_points.assert_not_awaited()
+ mock_index_data_points.assert_awaited_once()
+
+ call_args = mock_index_data_points.call_args[0][0]
+ assert len(call_args) == 0
@pytest.mark.asyncio
diff --git a/cognee/tests/unit/modules/chunking/test_text_chunker.py b/cognee/tests/unit/modules/chunking/test_text_chunker.py
new file mode 100644
index 000000000..d535f74b0
--- /dev/null
+++ b/cognee/tests/unit/modules/chunking/test_text_chunker.py
@@ -0,0 +1,248 @@
+"""Unit tests for TextChunker and TextChunkerWithOverlap behavioral equivalence."""
+
+import pytest
+from uuid import uuid4
+
+from cognee.modules.chunking.TextChunker import TextChunker
+from cognee.modules.chunking.text_chunker_with_overlap import TextChunkerWithOverlap
+from cognee.modules.data.processing.document_types import Document
+
+
+@pytest.fixture(params=["TextChunker", "TextChunkerWithOverlap"])
+def chunker_class(request):
+ """Parametrize tests to run against both implementations."""
+ return TextChunker if request.param == "TextChunker" else TextChunkerWithOverlap
+
+
+@pytest.fixture
+def make_text_generator():
+ """Factory for async text generators."""
+
+ def _factory(*texts):
+ async def gen():
+ for text in texts:
+ yield text
+
+ return gen
+
+ return _factory
+
+
+async def collect_chunks(chunker):
+ """Consume async generator and return list of chunks."""
+ chunks = []
+ async for chunk in chunker.read():
+ chunks.append(chunk)
+ return chunks
+
+
+@pytest.mark.asyncio
+async def test_empty_input_produces_no_chunks(chunker_class, make_text_generator):
+ """Empty input should yield no chunks."""
+ document = Document(
+ id=uuid4(),
+ name="test_document",
+ raw_data_location="/test/path",
+ external_metadata=None,
+ mime_type="text/plain",
+ )
+ get_text = make_text_generator("")
+ chunker = chunker_class(document, get_text, max_chunk_size=512)
+ chunks = await collect_chunks(chunker)
+
+ assert len(chunks) == 0, "Empty input should produce no chunks"
+
+
+@pytest.mark.asyncio
+async def test_whitespace_only_input_emits_single_chunk(chunker_class, make_text_generator):
+ """Whitespace-only input should produce exactly one chunk with unchanged text."""
+ whitespace_text = " \n\t \r\n "
+ document = Document(
+ id=uuid4(),
+ name="test_document",
+ raw_data_location="/test/path",
+ external_metadata=None,
+ mime_type="text/plain",
+ )
+ get_text = make_text_generator(whitespace_text)
+ chunker = chunker_class(document, get_text, max_chunk_size=512)
+ chunks = await collect_chunks(chunker)
+
+ assert len(chunks) == 1, "Whitespace-only input should produce exactly one chunk"
+ assert chunks[0].text == whitespace_text, "Chunk text should equal input (whitespace preserved)"
+ assert chunks[0].chunk_index == 0, "First chunk should have index 0"
+
+
+@pytest.mark.asyncio
+async def test_single_paragraph_below_limit_emits_one_chunk(chunker_class, make_text_generator):
+ """Single paragraph below limit should emit exactly one chunk."""
+ text = "This is a short paragraph."
+ document = Document(
+ id=uuid4(),
+ name="test_document",
+ raw_data_location="/test/path",
+ external_metadata=None,
+ mime_type="text/plain",
+ )
+ get_text = make_text_generator(text)
+ chunker = chunker_class(document, get_text, max_chunk_size=512)
+ chunks = await collect_chunks(chunker)
+
+ assert len(chunks) == 1, "Single short paragraph should produce exactly one chunk"
+ assert chunks[0].text == text, "Chunk text should match input"
+ assert chunks[0].chunk_index == 0, "First chunk should have index 0"
+ assert chunks[0].chunk_size > 0, "Chunk should have positive size"
+
+
+@pytest.mark.asyncio
+async def test_oversized_paragraph_gets_emitted_as_a_single_chunk(
+ chunker_class, make_text_generator
+):
+ """Oversized paragraph from chunk_by_paragraph should be emitted as single chunk."""
+ text = ("A" * 1500) + ". Next sentence."
+ document = Document(
+ id=uuid4(),
+ name="test_document",
+ raw_data_location="/test/path",
+ external_metadata=None,
+ mime_type="text/plain",
+ )
+ get_text = make_text_generator(text)
+ chunker = chunker_class(document, get_text, max_chunk_size=50)
+ chunks = await collect_chunks(chunker)
+
+ assert len(chunks) == 2, "Should produce 2 chunks (oversized paragraph + next sentence)"
+ assert chunks[0].chunk_size > 50, "First chunk should be oversized"
+ assert chunks[0].chunk_index == 0, "First chunk should have index 0"
+ assert chunks[1].chunk_index == 1, "Second chunk should have index 1"
+
+
+@pytest.mark.asyncio
+async def test_overflow_on_next_paragraph_emits_separate_chunk(chunker_class, make_text_generator):
+ """First paragraph near limit plus small paragraph should produce two separate chunks."""
+ first_para = " ".join(["word"] * 5)
+ second_para = "Short text."
+ text = first_para + " " + second_para
+ document = Document(
+ id=uuid4(),
+ name="test_document",
+ raw_data_location="/test/path",
+ external_metadata=None,
+ mime_type="text/plain",
+ )
+ get_text = make_text_generator(text)
+ chunker = chunker_class(document, get_text, max_chunk_size=10)
+ chunks = await collect_chunks(chunker)
+
+ assert len(chunks) == 2, "Should produce 2 chunks due to overflow"
+ assert chunks[0].text.strip() == first_para, "First chunk should contain only first paragraph"
+ assert chunks[1].text.strip() == second_para, (
+ "Second chunk should contain only second paragraph"
+ )
+ assert chunks[0].chunk_index == 0, "First chunk should have index 0"
+ assert chunks[1].chunk_index == 1, "Second chunk should have index 1"
+
+
+@pytest.mark.asyncio
+async def test_small_paragraphs_batch_correctly(chunker_class, make_text_generator):
+ """Multiple small paragraphs should batch together with joiner spaces counted."""
+ paragraphs = [" ".join(["word"] * 12) for _ in range(40)]
+ text = " ".join(paragraphs)
+ document = Document(
+ id=uuid4(),
+ name="test_document",
+ raw_data_location="/test/path",
+ external_metadata=None,
+ mime_type="text/plain",
+ )
+ get_text = make_text_generator(text)
+ chunker = chunker_class(document, get_text, max_chunk_size=49)
+ chunks = await collect_chunks(chunker)
+
+ assert len(chunks) == 20, (
+ "Should batch paragraphs (2 per chunk: 12 words × 2 tokens = 24, 24 + 1 joiner + 24 = 49)"
+ )
+ assert all(c.chunk_index == i for i, c in enumerate(chunks)), (
+ "Chunk indices should be sequential"
+ )
+ all_text = " ".join(chunk.text.strip() for chunk in chunks)
+ expected_text = " ".join(paragraphs)
+ assert all_text == expected_text, "All paragraph text should be preserved with correct spacing"
+
+
+@pytest.mark.asyncio
+async def test_alternating_large_and_small_paragraphs_dont_batch(
+ chunker_class, make_text_generator
+):
+ """Alternating near-max and small paragraphs should each become separate chunks."""
+ large1 = "word" * 15 + "."
+ small1 = "Short."
+ large2 = "word" * 15 + "."
+ small2 = "Tiny."
+ text = large1 + " " + small1 + " " + large2 + " " + small2
+ document = Document(
+ id=uuid4(),
+ name="test_document",
+ raw_data_location="/test/path",
+ external_metadata=None,
+ mime_type="text/plain",
+ )
+ max_chunk_size = 10
+ get_text = make_text_generator(text)
+ chunker = chunker_class(document, get_text, max_chunk_size=max_chunk_size)
+ chunks = await collect_chunks(chunker)
+
+ assert len(chunks) == 4, "Should produce multiple chunks"
+ assert all(c.chunk_index == i for i, c in enumerate(chunks)), (
+ "Chunk indices should be sequential"
+ )
+ assert chunks[0].chunk_size > max_chunk_size, (
+ "First chunk should be oversized (large paragraph)"
+ )
+ assert chunks[1].chunk_size <= max_chunk_size, "Second chunk should be small (small paragraph)"
+ assert chunks[2].chunk_size > max_chunk_size, (
+ "Third chunk should be oversized (large paragraph)"
+ )
+ assert chunks[3].chunk_size <= max_chunk_size, "Fourth chunk should be small (small paragraph)"
+
+
+@pytest.mark.asyncio
+async def test_chunk_indices_and_ids_are_deterministic(chunker_class, make_text_generator):
+ """Running chunker twice on identical input should produce identical indices and IDs."""
+ sentence1 = "one " * 4 + ". "
+ sentence2 = "two " * 4 + ". "
+ sentence3 = "one " * 4 + ". "
+ sentence4 = "two " * 4 + ". "
+ text = sentence1 + sentence2 + sentence3 + sentence4
+ doc_id = uuid4()
+ max_chunk_size = 20
+
+ document1 = Document(
+ id=doc_id,
+ name="test_document",
+ raw_data_location="/test/path",
+ external_metadata=None,
+ mime_type="text/plain",
+ )
+ get_text1 = make_text_generator(text)
+ chunker1 = chunker_class(document1, get_text1, max_chunk_size=max_chunk_size)
+ chunks1 = await collect_chunks(chunker1)
+
+ document2 = Document(
+ id=doc_id,
+ name="test_document",
+ raw_data_location="/test/path",
+ external_metadata=None,
+ mime_type="text/plain",
+ )
+ get_text2 = make_text_generator(text)
+ chunker2 = chunker_class(document2, get_text2, max_chunk_size=max_chunk_size)
+ chunks2 = await collect_chunks(chunker2)
+
+ assert len(chunks1) == 2, "Should produce exactly 2 chunks (4 sentences, 2 per chunk)"
+ assert len(chunks2) == 2, "Should produce exactly 2 chunks (4 sentences, 2 per chunk)"
+ assert [c.chunk_index for c in chunks1] == [0, 1], "First run indices should be [0, 1]"
+ assert [c.chunk_index for c in chunks2] == [0, 1], "Second run indices should be [0, 1]"
+ assert chunks1[0].id == chunks2[0].id, "First chunk ID should be deterministic"
+ assert chunks1[1].id == chunks2[1].id, "Second chunk ID should be deterministic"
+ assert chunks1[0].id != chunks1[1].id, "Chunk IDs should be unique within a run"
diff --git a/cognee/tests/unit/modules/chunking/test_text_chunker_with_overlap.py b/cognee/tests/unit/modules/chunking/test_text_chunker_with_overlap.py
new file mode 100644
index 000000000..9d7be6936
--- /dev/null
+++ b/cognee/tests/unit/modules/chunking/test_text_chunker_with_overlap.py
@@ -0,0 +1,324 @@
+"""Unit tests for TextChunkerWithOverlap overlap behavior."""
+
+import sys
+import pytest
+from uuid import uuid4
+from unittest.mock import patch
+
+from cognee.modules.chunking.text_chunker_with_overlap import TextChunkerWithOverlap
+from cognee.modules.data.processing.document_types import Document
+from cognee.tasks.chunks import chunk_by_paragraph
+
+
+@pytest.fixture
+def make_text_generator():
+ """Factory for async text generators."""
+
+ def _factory(*texts):
+ async def gen():
+ for text in texts:
+ yield text
+
+ return gen
+
+ return _factory
+
+
+@pytest.fixture
+def make_controlled_chunk_data():
+ """Factory for controlled chunk_data generators."""
+
+ def _factory(*sentences, chunk_size_per_sentence=10):
+ def _chunk_data(text):
+ return [
+ {
+ "text": sentence,
+ "chunk_size": chunk_size_per_sentence,
+ "cut_type": "sentence",
+ "chunk_id": uuid4(),
+ }
+ for sentence in sentences
+ ]
+
+ return _chunk_data
+
+ return _factory
+
+
+@pytest.mark.asyncio
+async def test_half_overlap_preserves_content_across_chunks(
+ make_text_generator, make_controlled_chunk_data
+):
+ """With 50% overlap, consecutive chunks should share half their content."""
+ s1 = "one"
+ s2 = "two"
+ s3 = "three"
+ s4 = "four"
+ text = "dummy"
+ document = Document(
+ id=uuid4(),
+ name="test_document",
+ raw_data_location="/test/path",
+ external_metadata=None,
+ mime_type="text/plain",
+ )
+ get_text = make_text_generator(text)
+ get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, chunk_size_per_sentence=10)
+ chunker = TextChunkerWithOverlap(
+ document,
+ get_text,
+ max_chunk_size=20,
+ chunk_overlap_ratio=0.5,
+ get_chunk_data=get_chunk_data,
+ )
+ chunks = [chunk async for chunk in chunker.read()]
+
+ assert len(chunks) == 3, "Should produce exactly 3 chunks (s1+s2, s2+s3, s3+s4)"
+ assert [c.chunk_index for c in chunks] == [0, 1, 2], "Chunk indices should be [0, 1, 2]"
+ assert "one" in chunks[0].text and "two" in chunks[0].text, "Chunk 0 should contain s1 and s2"
+ assert "two" in chunks[1].text and "three" in chunks[1].text, (
+ "Chunk 1 should contain s2 (overlap) and s3"
+ )
+ assert "three" in chunks[2].text and "four" in chunks[2].text, (
+ "Chunk 2 should contain s3 (overlap) and s4"
+ )
+
+
+@pytest.mark.asyncio
+async def test_zero_overlap_produces_no_duplicate_content(
+ make_text_generator, make_controlled_chunk_data
+):
+ """With 0% overlap, no content should appear in multiple chunks."""
+ s1 = "one"
+ s2 = "two"
+ s3 = "three"
+ s4 = "four"
+ text = "dummy"
+ document = Document(
+ id=uuid4(),
+ name="test_document",
+ raw_data_location="/test/path",
+ external_metadata=None,
+ mime_type="text/plain",
+ )
+ get_text = make_text_generator(text)
+ get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, chunk_size_per_sentence=10)
+ chunker = TextChunkerWithOverlap(
+ document,
+ get_text,
+ max_chunk_size=20,
+ chunk_overlap_ratio=0.0,
+ get_chunk_data=get_chunk_data,
+ )
+ chunks = [chunk async for chunk in chunker.read()]
+
+ assert len(chunks) == 2, "Should produce exactly 2 chunks (s1+s2, s3+s4)"
+ assert "one" in chunks[0].text and "two" in chunks[0].text, (
+ "First chunk should contain s1 and s2"
+ )
+ assert "three" in chunks[1].text and "four" in chunks[1].text, (
+ "Second chunk should contain s3 and s4"
+ )
+ assert "two" not in chunks[1].text and "three" not in chunks[0].text, (
+ "No overlap: end of chunk 0 should not appear in chunk 1"
+ )
+
+
+@pytest.mark.asyncio
+async def test_small_overlap_ratio_creates_minimal_overlap(
+ make_text_generator, make_controlled_chunk_data
+):
+ """With 25% overlap ratio, chunks should have minimal overlap."""
+ s1 = "alpha"
+ s2 = "beta"
+ s3 = "gamma"
+ s4 = "delta"
+ s5 = "epsilon"
+ text = "dummy"
+ document = Document(
+ id=uuid4(),
+ name="test_document",
+ raw_data_location="/test/path",
+ external_metadata=None,
+ mime_type="text/plain",
+ )
+ get_text = make_text_generator(text)
+ get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, s5, chunk_size_per_sentence=10)
+ chunker = TextChunkerWithOverlap(
+ document,
+ get_text,
+ max_chunk_size=40,
+ chunk_overlap_ratio=0.25,
+ get_chunk_data=get_chunk_data,
+ )
+ chunks = [chunk async for chunk in chunker.read()]
+
+ assert len(chunks) == 2, "Should produce exactly 2 chunks"
+ assert [c.chunk_index for c in chunks] == [0, 1], "Chunk indices should be [0, 1]"
+ assert all(token in chunks[0].text for token in [s1, s2, s3, s4]), (
+ "Chunk 0 should contain s1 through s4"
+ )
+ assert s4 in chunks[1].text and s5 in chunks[1].text, (
+ "Chunk 1 should contain overlap s4 and new content s5"
+ )
+
+
+@pytest.mark.asyncio
+async def test_high_overlap_ratio_creates_significant_overlap(
+ make_text_generator, make_controlled_chunk_data
+):
+ """With 75% overlap ratio, consecutive chunks should share most content."""
+ s1 = "red"
+ s2 = "blue"
+ s3 = "green"
+ s4 = "yellow"
+ s5 = "purple"
+ text = "dummy"
+ document = Document(
+ id=uuid4(),
+ name="test_document",
+ raw_data_location="/test/path",
+ external_metadata=None,
+ mime_type="text/plain",
+ )
+ get_text = make_text_generator(text)
+ get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, s5, chunk_size_per_sentence=5)
+ chunker = TextChunkerWithOverlap(
+ document,
+ get_text,
+ max_chunk_size=20,
+ chunk_overlap_ratio=0.75,
+ get_chunk_data=get_chunk_data,
+ )
+ chunks = [chunk async for chunk in chunker.read()]
+
+ assert len(chunks) == 2, "Should produce exactly 2 chunks with 75% overlap"
+ assert [c.chunk_index for c in chunks] == [0, 1], "Chunk indices should be [0, 1]"
+ assert all(token in chunks[0].text for token in [s1, s2, s3, s4]), (
+ "Chunk 0 should contain s1, s2, s3, s4"
+ )
+ assert all(token in chunks[1].text for token in [s2, s3, s4, s5]), (
+ "Chunk 1 should contain s2, s3, s4 (overlap) and s5"
+ )
+
+
+@pytest.mark.asyncio
+async def test_single_chunk_no_dangling_overlap(make_text_generator, make_controlled_chunk_data):
+ """Text that fits in one chunk should produce exactly one chunk, no overlap artifact."""
+ s1 = "alpha"
+ s2 = "beta"
+ text = "dummy"
+ document = Document(
+ id=uuid4(),
+ name="test_document",
+ raw_data_location="/test/path",
+ external_metadata=None,
+ mime_type="text/plain",
+ )
+ get_text = make_text_generator(text)
+ get_chunk_data = make_controlled_chunk_data(s1, s2, chunk_size_per_sentence=10)
+ chunker = TextChunkerWithOverlap(
+ document,
+ get_text,
+ max_chunk_size=20,
+ chunk_overlap_ratio=0.5,
+ get_chunk_data=get_chunk_data,
+ )
+ chunks = [chunk async for chunk in chunker.read()]
+
+ assert len(chunks) == 1, (
+ "Should produce exactly 1 chunk when content fits within max_chunk_size"
+ )
+ assert chunks[0].chunk_index == 0, "Single chunk should have index 0"
+ assert "alpha" in chunks[0].text and "beta" in chunks[0].text, (
+ "Single chunk should contain all content"
+ )
+
+
+@pytest.mark.asyncio
+async def test_paragraph_chunking_with_overlap(make_text_generator):
+ """Test that chunk_by_paragraph integration produces 25% overlap between chunks."""
+
+ def mock_get_embedding_engine():
+ class MockEngine:
+ tokenizer = None
+
+ return MockEngine()
+
+ chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentence")
+
+ max_chunk_size = 20
+    overlap_ratio = 0.25  # 0.25 * max_chunk_size = 5-token overlap between consecutive chunks
+    paragraph_max_size = int(0.5 * overlap_ratio * max_chunk_size)  # int(2.5) = 2
+
+ text = (
+ "A0 A1. A2 A3. A4 A5. A6 A7. A8 A9. " # 10 tokens (0-9)
+ "B0 B1. B2 B3. B4 B5. B6 B7. B8 B9. " # 10 tokens (10-19)
+ "C0 C1. C2 C3. C4 C5. C6 C7. C8 C9. " # 10 tokens (20-29)
+ "D0 D1. D2 D3. D4 D5. D6 D7. D8 D9. " # 10 tokens (30-39)
+ "E0 E1. E2 E3. E4 E5. E6 E7. E8 E9." # 10 tokens (40-49)
+ )
+
+ document = Document(
+ id=uuid4(),
+ name="test_document",
+ raw_data_location="/test/path",
+ external_metadata=None,
+ mime_type="text/plain",
+ )
+
+ get_text = make_text_generator(text)
+
+ def get_chunk_data(text_input):
+ return chunk_by_paragraph(
+ text_input, max_chunk_size=paragraph_max_size, batch_paragraphs=True
+ )
+
+ with patch.object(
+ chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine
+ ):
+ chunker = TextChunkerWithOverlap(
+ document,
+ get_text,
+ max_chunk_size=max_chunk_size,
+ chunk_overlap_ratio=overlap_ratio,
+ get_chunk_data=get_chunk_data,
+ )
+ chunks = [chunk async for chunk in chunker.read()]
+
+ assert len(chunks) == 3, f"Should produce exactly 3 chunks, got {len(chunks)}"
+
+ assert chunks[0].chunk_index == 0, "First chunk should have index 0"
+ assert chunks[1].chunk_index == 1, "Second chunk should have index 1"
+ assert chunks[2].chunk_index == 2, "Third chunk should have index 2"
+
+ assert "A0" in chunks[0].text, "Chunk 0 should start with A0"
+ assert "A9" in chunks[0].text, "Chunk 0 should contain A9"
+ assert "B0" in chunks[0].text, "Chunk 0 should contain B0"
+ assert "B9" in chunks[0].text, "Chunk 0 should contain up to B9 (20 tokens)"
+
+ assert "B" in chunks[1].text, "Chunk 1 should have overlap from B section"
+ assert "C" in chunks[1].text, "Chunk 1 should contain C section"
+ assert "D" in chunks[1].text, "Chunk 1 should contain D section"
+
+ assert "D" in chunks[2].text, "Chunk 2 should have overlap from D section"
+ assert "E0" in chunks[2].text, "Chunk 2 should contain E0"
+ assert "E9" in chunks[2].text, "Chunk 2 should end with E9"
+
+ chunk_0_end_words = chunks[0].text.split()[-4:]
+ chunk_1_words = chunks[1].text.split()
+ overlap_0_1 = any(word in chunk_1_words for word in chunk_0_end_words)
+ assert overlap_0_1, (
+ f"No overlap detected between chunks 0 and 1. "
+ f"Chunk 0 ends with: {chunk_0_end_words}, "
+ f"Chunk 1 starts with: {chunk_1_words[:6]}"
+ )
+
+ chunk_1_end_words = chunks[1].text.split()[-4:]
+ chunk_2_words = chunks[2].text.split()
+ overlap_1_2 = any(word in chunk_2_words for word in chunk_1_end_words)
+ assert overlap_1_2, (
+ f"No overlap detected between chunks 1 and 2. "
+ f"Chunk 1 ends with: {chunk_1_end_words}, "
+ f"Chunk 2 starts with: {chunk_2_words[:6]}"
+ )
diff --git a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py
index 37ba113b5..1d2b79cf9 100644
--- a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py
+++ b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py
@@ -9,7 +9,7 @@ def test_node_initialization():
"""Test that a Node is initialized correctly."""
node = Node("node1", {"attr1": "value1"}, dimension=2)
assert node.id == "node1"
- assert node.attributes == {"attr1": "value1", "vector_distance": np.inf}
+ assert node.attributes == {"attr1": "value1", "vector_distance": 3.5}
assert len(node.status) == 2
assert np.all(node.status == 1)
@@ -96,7 +96,7 @@ def test_edge_initialization():
edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2)
assert edge.node1 == node1
assert edge.node2 == node2
- assert edge.attributes == {"vector_distance": np.inf, "weight": 10}
+ assert edge.attributes == {"vector_distance": 3.5, "weight": 10}
assert edge.directed is False
assert len(edge.status) == 2
assert np.all(edge.status == 1)
diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py
index 6888648c3..711479387 100644
--- a/cognee/tests/unit/modules/graph/cognee_graph_test.py
+++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py
@@ -1,4 +1,5 @@
import pytest
+from unittest.mock import AsyncMock
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
@@ -11,6 +12,30 @@ def setup_graph():
return CogneeGraph()
+@pytest.fixture
+def mock_adapter():
+ """Fixture to create a mock adapter for database operations."""
+ adapter = AsyncMock()
+ return adapter
+
+
+@pytest.fixture
+def mock_vector_engine():
+ """Fixture to create a mock vector engine."""
+ engine = AsyncMock()
+ engine.search = AsyncMock()
+ return engine
+
+
+class MockScoredResult:
+ """Mock class for vector search results."""
+
+ def __init__(self, id, score, payload=None):
+ self.id = id
+ self.score = score
+ self.payload = payload or {}
+
+
def test_add_node_success(setup_graph):
"""Test successful addition of a node."""
graph = setup_graph
@@ -73,3 +98,433 @@ def test_get_edges_nonexistent_node(setup_graph):
graph = setup_graph
with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."):
graph.get_edges_from_node("nonexistent")
+
+
+@pytest.mark.asyncio
+async def test_project_graph_from_db_full_graph(setup_graph, mock_adapter):
+ """Test projecting a full graph from database."""
+ graph = setup_graph
+
+ nodes_data = [
+ ("1", {"name": "Node1", "description": "First node"}),
+ ("2", {"name": "Node2", "description": "Second node"}),
+ ]
+ edges_data = [
+ ("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}),
+ ]
+
+ mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
+
+ await graph.project_graph_from_db(
+ adapter=mock_adapter,
+ node_properties_to_project=["name", "description"],
+ edge_properties_to_project=["relationship_name"],
+ )
+
+ assert len(graph.nodes) == 2
+ assert len(graph.edges) == 1
+ assert graph.get_node("1") is not None
+ assert graph.get_node("2") is not None
+ assert graph.edges[0].node1.id == "1"
+ assert graph.edges[0].node2.id == "2"
+
+
+@pytest.mark.asyncio
+async def test_project_graph_from_db_id_filtered(setup_graph, mock_adapter):
+ """Test projecting an ID-filtered graph from database."""
+ graph = setup_graph
+
+ nodes_data = [
+ ("1", {"name": "Node1"}),
+ ("2", {"name": "Node2"}),
+ ]
+ edges_data = [
+ ("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}),
+ ]
+
+ mock_adapter.get_id_filtered_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
+
+ await graph.project_graph_from_db(
+ adapter=mock_adapter,
+ node_properties_to_project=["name"],
+ edge_properties_to_project=["relationship_name"],
+ relevant_ids_to_filter=["1", "2"],
+ )
+
+ assert len(graph.nodes) == 2
+ assert len(graph.edges) == 1
+ mock_adapter.get_id_filtered_graph_data.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_project_graph_from_db_nodeset_subgraph(setup_graph, mock_adapter):
+ """Test projecting a nodeset subgraph filtered by node type and name."""
+ graph = setup_graph
+
+ nodes_data = [
+ ("1", {"name": "Alice", "type": "Person"}),
+ ("2", {"name": "Bob", "type": "Person"}),
+ ]
+ edges_data = [
+ ("1", "2", "KNOWS", {"relationship_name": "knows"}),
+ ]
+
+ mock_adapter.get_nodeset_subgraph = AsyncMock(return_value=(nodes_data, edges_data))
+
+ await graph.project_graph_from_db(
+ adapter=mock_adapter,
+ node_properties_to_project=["name", "type"],
+ edge_properties_to_project=["relationship_name"],
+ node_type="Person",
+ node_name=["Alice"],
+ )
+
+ assert len(graph.nodes) == 2
+ assert graph.get_node("1") is not None
+ assert len(graph.edges) == 1
+ mock_adapter.get_nodeset_subgraph.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter):
+ """Test projecting empty graph raises EntityNotFoundError."""
+ graph = setup_graph
+
+ mock_adapter.get_graph_data = AsyncMock(return_value=([], []))
+
+ with pytest.raises(EntityNotFoundError, match="Empty graph projected from the database."):
+ await graph.project_graph_from_db(
+ adapter=mock_adapter,
+ node_properties_to_project=["name"],
+ edge_properties_to_project=[],
+ )
+
+
+@pytest.mark.asyncio
+async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter):
+ """Test that edges referencing missing nodes raise error."""
+ graph = setup_graph
+
+ nodes_data = [
+ ("1", {"name": "Node1"}),
+ ]
+ edges_data = [
+ ("1", "999", "CONNECTS_TO", {"relationship_name": "connects"}),
+ ]
+
+ mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
+
+ with pytest.raises(EntityNotFoundError, match="Edge references nonexistent nodes"):
+ await graph.project_graph_from_db(
+ adapter=mock_adapter,
+ node_properties_to_project=["name"],
+ edge_properties_to_project=["relationship_name"],
+ )
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_to_graph_nodes(setup_graph):
+ """Test mapping vector distances to graph nodes."""
+ graph = setup_graph
+
+ node1 = Node("1", {"name": "Node1"})
+ node2 = Node("2", {"name": "Node2"})
+ graph.add_node(node1)
+ graph.add_node(node2)
+
+ node_distances = {
+ "Entity_name": [
+ MockScoredResult("1", 0.95),
+ MockScoredResult("2", 0.87),
+ ]
+ }
+
+ await graph.map_vector_distances_to_graph_nodes(node_distances)
+
+ assert graph.get_node("1").attributes.get("vector_distance") == 0.95
+ assert graph.get_node("2").attributes.get("vector_distance") == 0.87
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_partial_node_coverage(setup_graph):
+ """Test mapping vector distances when only some nodes have results."""
+ graph = setup_graph
+
+ node1 = Node("1", {"name": "Node1"})
+ node2 = Node("2", {"name": "Node2"})
+ node3 = Node("3", {"name": "Node3"})
+ graph.add_node(node1)
+ graph.add_node(node2)
+ graph.add_node(node3)
+
+ node_distances = {
+ "Entity_name": [
+ MockScoredResult("1", 0.95),
+ MockScoredResult("2", 0.87),
+ ]
+ }
+
+ await graph.map_vector_distances_to_graph_nodes(node_distances)
+
+ assert graph.get_node("1").attributes.get("vector_distance") == 0.95
+ assert graph.get_node("2").attributes.get("vector_distance") == 0.87
+ assert graph.get_node("3").attributes.get("vector_distance") == 3.5
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_multiple_categories(setup_graph):
+ """Test mapping vector distances from multiple collection categories."""
+ graph = setup_graph
+
+ # Create nodes
+ node1 = Node("1")
+ node2 = Node("2")
+ node3 = Node("3")
+ node4 = Node("4")
+ graph.add_node(node1)
+ graph.add_node(node2)
+ graph.add_node(node3)
+ graph.add_node(node4)
+
+ node_distances = {
+ "Entity_name": [
+ MockScoredResult("1", 0.95),
+ MockScoredResult("2", 0.87),
+ ],
+ "TextSummary_text": [
+ MockScoredResult("3", 0.92),
+ ],
+ }
+
+ await graph.map_vector_distances_to_graph_nodes(node_distances)
+
+ assert graph.get_node("1").attributes.get("vector_distance") == 0.95
+ assert graph.get_node("2").attributes.get("vector_distance") == 0.87
+ assert graph.get_node("3").attributes.get("vector_distance") == 0.92
+ assert graph.get_node("4").attributes.get("vector_distance") == 3.5
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, mock_vector_engine):
+ """Test mapping vector distances to edges when edge_distances provided."""
+ graph = setup_graph
+
+ node1 = Node("1")
+ node2 = Node("2")
+ graph.add_node(node1)
+ graph.add_node(node2)
+
+ edge = Edge(
+ node1,
+ node2,
+ attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
+ )
+ graph.add_edge(edge)
+
+ edge_distances = [
+ MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
+ ]
+
+ await graph.map_vector_distances_to_graph_edges(
+ vector_engine=mock_vector_engine,
+ query_vector=[0.1, 0.2, 0.3],
+ edge_distances=edge_distances,
+ )
+
+ assert graph.edges[0].attributes.get("vector_distance") == 0.92
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_to_graph_edges_search(setup_graph, mock_vector_engine):
+ """Test mapping edge distances when searching for them."""
+ graph = setup_graph
+
+ node1 = Node("1")
+ node2 = Node("2")
+ graph.add_node(node1)
+ graph.add_node(node2)
+
+ edge = Edge(
+ node1,
+ node2,
+ attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
+ )
+ graph.add_edge(edge)
+
+ mock_vector_engine.search.return_value = [
+ MockScoredResult("e1", 0.88, payload={"text": "CONNECTS_TO"}),
+ ]
+
+ await graph.map_vector_distances_to_graph_edges(
+ vector_engine=mock_vector_engine,
+ query_vector=[0.1, 0.2, 0.3],
+ edge_distances=None,
+ )
+
+ mock_vector_engine.search.assert_called_once()
+ assert graph.edges[0].attributes.get("vector_distance") == 0.88
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vector_engine):
+ """Test mapping edge distances when only some edges have results."""
+ graph = setup_graph
+
+ node1 = Node("1")
+ node2 = Node("2")
+ node3 = Node("3")
+ graph.add_node(node1)
+ graph.add_node(node2)
+ graph.add_node(node3)
+
+ edge1 = Edge(node1, node2, attributes={"edge_text": "CONNECTS_TO"})
+ edge2 = Edge(node2, node3, attributes={"edge_text": "DEPENDS_ON"})
+ graph.add_edge(edge1)
+ graph.add_edge(edge2)
+
+ edge_distances = [
+ MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
+ ]
+
+ await graph.map_vector_distances_to_graph_edges(
+ vector_engine=mock_vector_engine,
+ query_vector=[0.1, 0.2, 0.3],
+ edge_distances=edge_distances,
+ )
+
+ assert graph.edges[0].attributes.get("vector_distance") == 0.92
+ assert graph.edges[1].attributes.get("vector_distance") == 3.5
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_edges_fallback_to_relationship_type(
+ setup_graph, mock_vector_engine
+):
+ """Test that edge mapping falls back to relationship_type when edge_text is missing."""
+ graph = setup_graph
+
+ node1 = Node("1")
+ node2 = Node("2")
+ graph.add_node(node1)
+ graph.add_node(node2)
+
+ edge = Edge(
+ node1,
+ node2,
+ attributes={"relationship_type": "KNOWS"},
+ )
+ graph.add_edge(edge)
+
+ edge_distances = [
+ MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
+ ]
+
+ await graph.map_vector_distances_to_graph_edges(
+ vector_engine=mock_vector_engine,
+ query_vector=[0.1, 0.2, 0.3],
+ edge_distances=edge_distances,
+ )
+
+ assert graph.edges[0].attributes.get("vector_distance") == 0.85
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_engine):
+ """Test edge mapping when no edges match the distance results."""
+ graph = setup_graph
+
+ node1 = Node("1")
+ node2 = Node("2")
+ graph.add_node(node1)
+ graph.add_node(node2)
+
+ edge = Edge(
+ node1,
+ node2,
+ attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
+ )
+ graph.add_edge(edge)
+
+ edge_distances = [
+ MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
+ ]
+
+ await graph.map_vector_distances_to_graph_edges(
+ vector_engine=mock_vector_engine,
+ query_vector=[0.1, 0.2, 0.3],
+ edge_distances=edge_distances,
+ )
+
+ assert graph.edges[0].attributes.get("vector_distance") == 3.5
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_invalid_query_vector(setup_graph, mock_vector_engine):
+ """Test that invalid query vector raises error."""
+ graph = setup_graph
+
+ with pytest.raises(ValueError, match="Failed to generate query embedding"):
+ await graph.map_vector_distances_to_graph_edges(
+ vector_engine=mock_vector_engine,
+ query_vector=[],
+ edge_distances=None,
+ )
+
+
+@pytest.mark.asyncio
+async def test_calculate_top_triplet_importances(setup_graph):
+ """Test calculating top triplet importances by score."""
+ graph = setup_graph
+
+ node1 = Node("1")
+ node2 = Node("2")
+ node3 = Node("3")
+ node4 = Node("4")
+
+ node1.add_attribute("vector_distance", 0.9)
+ node2.add_attribute("vector_distance", 0.8)
+ node3.add_attribute("vector_distance", 0.7)
+ node4.add_attribute("vector_distance", 0.6)
+
+ graph.add_node(node1)
+ graph.add_node(node2)
+ graph.add_node(node3)
+ graph.add_node(node4)
+
+ edge1 = Edge(node1, node2)
+ edge2 = Edge(node2, node3)
+ edge3 = Edge(node3, node4)
+
+ edge1.add_attribute("vector_distance", 0.85)
+ edge2.add_attribute("vector_distance", 0.75)
+ edge3.add_attribute("vector_distance", 0.65)
+
+ graph.add_edge(edge1)
+ graph.add_edge(edge2)
+ graph.add_edge(edge3)
+
+ top_triplets = await graph.calculate_top_triplet_importances(k=2)
+
+ assert len(top_triplets) == 2
+
+ assert top_triplets[0] == edge3
+ assert top_triplets[1] == edge2
+
+
+@pytest.mark.asyncio
+async def test_calculate_top_triplet_importances_default_distances(setup_graph):
+ """Test calculating importances when nodes/edges have no vector distances."""
+ graph = setup_graph
+
+ node1 = Node("1")
+ node2 = Node("2")
+ graph.add_node(node1)
+ graph.add_node(node2)
+
+ edge = Edge(node1, node2)
+ graph.add_edge(edge)
+
+ top_triplets = await graph.calculate_top_triplet_importances(k=1)
+
+ assert len(top_triplets) == 1
+ assert top_triplets[0] == edge
diff --git a/cognee/tests/unit/modules/memify_tasks/test_cognify_session.py b/cognee/tests/unit/modules/memify_tasks/test_cognify_session.py
new file mode 100644
index 000000000..8c2448287
--- /dev/null
+++ b/cognee/tests/unit/modules/memify_tasks/test_cognify_session.py
@@ -0,0 +1,111 @@
+import pytest
+from unittest.mock import AsyncMock, patch
+
+from cognee.tasks.memify.cognify_session import cognify_session
+from cognee.exceptions import CogneeValidationError, CogneeSystemError
+
+
+@pytest.mark.asyncio
+async def test_cognify_session_success():
+ """Test successful cognification of session data."""
+ session_data = (
+ "Session ID: test_session\n\nQuestion: What is AI?\n\nAnswer: AI is artificial intelligence"
+ )
+
+ with (
+ patch("cognee.add", new_callable=AsyncMock) as mock_add,
+ patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify,
+ ):
+ await cognify_session(session_data, dataset_id="123")
+
+ mock_add.assert_called_once_with(
+ session_data, dataset_id="123", node_set=["user_sessions_from_cache"]
+ )
+ mock_cognify.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_cognify_session_empty_string():
+ """Test cognification fails with empty string."""
+ with pytest.raises(CogneeValidationError) as exc_info:
+ await cognify_session("")
+
+ assert "Session data cannot be empty" in str(exc_info.value)
+
+
+@pytest.mark.asyncio
+async def test_cognify_session_whitespace_string():
+ """Test cognification fails with whitespace-only string."""
+ with pytest.raises(CogneeValidationError) as exc_info:
+ await cognify_session(" \n\t ")
+
+ assert "Session data cannot be empty" in str(exc_info.value)
+
+
+@pytest.mark.asyncio
+async def test_cognify_session_none_data():
+ """Test cognification fails with None data."""
+ with pytest.raises(CogneeValidationError) as exc_info:
+ await cognify_session(None)
+
+ assert "Session data cannot be empty" in str(exc_info.value)
+
+
+@pytest.mark.asyncio
+async def test_cognify_session_add_failure():
+ """Test cognification handles cognee.add failure."""
+ session_data = "Session ID: test\n\nQuestion: test?"
+
+ with (
+ patch("cognee.add", new_callable=AsyncMock) as mock_add,
+ patch("cognee.cognify", new_callable=AsyncMock),
+ ):
+ mock_add.side_effect = Exception("Add operation failed")
+
+ with pytest.raises(CogneeSystemError) as exc_info:
+ await cognify_session(session_data)
+
+ assert "Failed to cognify session data" in str(exc_info.value)
+ assert "Add operation failed" in str(exc_info.value)
+
+
+@pytest.mark.asyncio
+async def test_cognify_session_cognify_failure():
+ """Test cognification handles cognify failure."""
+ session_data = "Session ID: test\n\nQuestion: test?"
+
+ with (
+ patch("cognee.add", new_callable=AsyncMock),
+ patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify,
+ ):
+ mock_cognify.side_effect = Exception("Cognify operation failed")
+
+ with pytest.raises(CogneeSystemError) as exc_info:
+ await cognify_session(session_data)
+
+ assert "Failed to cognify session data" in str(exc_info.value)
+ assert "Cognify operation failed" in str(exc_info.value)
+
+
+@pytest.mark.asyncio
+async def test_cognify_session_re_raises_validation_error():
+ """Test that CogneeValidationError is re-raised as-is."""
+ with pytest.raises(CogneeValidationError):
+ await cognify_session("")
+
+
+@pytest.mark.asyncio
+async def test_cognify_session_with_special_characters():
+ """Test cognification with special characters."""
+ session_data = "Session: test™ © Question: What's special? Answer: Cognee is special!"
+
+ with (
+ patch("cognee.add", new_callable=AsyncMock) as mock_add,
+ patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify,
+ ):
+ await cognify_session(session_data, dataset_id="123")
+
+ mock_add.assert_called_once_with(
+ session_data, dataset_id="123", node_set=["user_sessions_from_cache"]
+ )
+ mock_cognify.assert_called_once()
diff --git a/cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py b/cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py
new file mode 100644
index 000000000..8cb27fef3
--- /dev/null
+++ b/cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py
@@ -0,0 +1,175 @@
+import sys
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from cognee.tasks.memify.extract_user_sessions import extract_user_sessions
+from cognee.exceptions import CogneeSystemError
+from cognee.modules.users.models import User
+
+# Get the actual module object (not the function) for patching
+extract_user_sessions_module = sys.modules["cognee.tasks.memify.extract_user_sessions"]
+
+
+@pytest.fixture
+def mock_user():
+ """Create a mock user."""
+ user = MagicMock(spec=User)
+ user.id = "test-user-123"
+ return user
+
+
+@pytest.fixture
+def mock_qa_data():
+ """Create mock Q&A data."""
+ return [
+ {
+ "question": "What is cognee?",
+ "context": "context about cognee",
+ "answer": "Cognee is a knowledge graph solution",
+ "time": "2025-01-01T12:00:00",
+ },
+ {
+ "question": "How does it work?",
+ "context": "how it works context",
+ "answer": "It processes data and creates graphs",
+ "time": "2025-01-01T12:05:00",
+ },
+ ]
+
+
+@pytest.mark.asyncio
+async def test_extract_user_sessions_success(mock_user, mock_qa_data):
+ """Test successful extraction of sessions."""
+ mock_cache_engine = AsyncMock()
+ mock_cache_engine.get_all_qas.return_value = mock_qa_data
+
+ with (
+ patch.object(extract_user_sessions_module, "session_user") as mock_session_user,
+ patch.object(
+ extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine
+ ),
+ ):
+ mock_session_user.get.return_value = mock_user
+
+ sessions = []
+ async for session in extract_user_sessions([{}], session_ids=["test_session"]):
+ sessions.append(session)
+
+ assert len(sessions) == 1
+ assert "Session ID: test_session" in sessions[0]
+ assert "Question: What is cognee?" in sessions[0]
+ assert "Answer: Cognee is a knowledge graph solution" in sessions[0]
+ assert "Question: How does it work?" in sessions[0]
+ assert "Answer: It processes data and creates graphs" in sessions[0]
+
+
+@pytest.mark.asyncio
+async def test_extract_user_sessions_multiple_sessions(mock_user, mock_qa_data):
+ """Test extraction of multiple sessions."""
+ mock_cache_engine = AsyncMock()
+ mock_cache_engine.get_all_qas.return_value = mock_qa_data
+
+ with (
+ patch.object(extract_user_sessions_module, "session_user") as mock_session_user,
+ patch.object(
+ extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine
+ ),
+ ):
+ mock_session_user.get.return_value = mock_user
+
+ sessions = []
+ async for session in extract_user_sessions([{}], session_ids=["session1", "session2"]):
+ sessions.append(session)
+
+ assert len(sessions) == 2
+ assert mock_cache_engine.get_all_qas.call_count == 2
+
+
+@pytest.mark.asyncio
+async def test_extract_user_sessions_no_data(mock_user, mock_qa_data):
+ """Test extraction handles empty data parameter."""
+ mock_cache_engine = AsyncMock()
+ mock_cache_engine.get_all_qas.return_value = mock_qa_data
+
+ with (
+ patch.object(extract_user_sessions_module, "session_user") as mock_session_user,
+ patch.object(
+ extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine
+ ),
+ ):
+ mock_session_user.get.return_value = mock_user
+
+ sessions = []
+ async for session in extract_user_sessions(None, session_ids=["test_session"]):
+ sessions.append(session)
+
+ assert len(sessions) == 1
+
+
+@pytest.mark.asyncio
+async def test_extract_user_sessions_no_session_ids(mock_user):
+ """Test extraction handles no session IDs provided."""
+ mock_cache_engine = AsyncMock()
+
+ with (
+ patch.object(extract_user_sessions_module, "session_user") as mock_session_user,
+ patch.object(
+ extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine
+ ),
+ ):
+ mock_session_user.get.return_value = mock_user
+
+ sessions = []
+ async for session in extract_user_sessions([{}], session_ids=None):
+ sessions.append(session)
+
+ assert len(sessions) == 0
+ mock_cache_engine.get_all_qas.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_extract_user_sessions_empty_qa_data(mock_user):
+ """Test extraction handles empty Q&A data."""
+ mock_cache_engine = AsyncMock()
+ mock_cache_engine.get_all_qas.return_value = []
+
+ with (
+ patch.object(extract_user_sessions_module, "session_user") as mock_session_user,
+ patch.object(
+ extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine
+ ),
+ ):
+ mock_session_user.get.return_value = mock_user
+
+ sessions = []
+ async for session in extract_user_sessions([{}], session_ids=["empty_session"]):
+ sessions.append(session)
+
+ assert len(sessions) == 0
+
+
+@pytest.mark.asyncio
+async def test_extract_user_sessions_cache_error_handling(mock_user, mock_qa_data):
+ """Test extraction continues on cache error for specific session."""
+ mock_cache_engine = AsyncMock()
+ mock_cache_engine.get_all_qas.side_effect = [
+ mock_qa_data,
+ Exception("Cache error"),
+ mock_qa_data,
+ ]
+
+ with (
+ patch.object(extract_user_sessions_module, "session_user") as mock_session_user,
+ patch.object(
+ extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine
+ ),
+ ):
+ mock_session_user.get.return_value = mock_user
+
+ sessions = []
+ async for session in extract_user_sessions(
+ [{}], session_ids=["session1", "session2", "session3"]
+ ):
+ sessions.append(session)
+
+ assert len(sessions) == 2
diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py
index 7fcfe0d6b..206cfaf84 100644
--- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py
+++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py
@@ -2,7 +2,6 @@ import os
import pytest
import pathlib
from typing import Optional, Union
-from pydantic import BaseModel
import cognee
from cognee.low_level import setup, DataPoint
@@ -11,11 +10,6 @@ from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
-class TestAnswer(BaseModel):
- answer: str
- explanation: str
-
-
class TestGraphCompletionCoTRetriever:
@pytest.mark.asyncio
async def test_graph_completion_cot_context_simple(self):
@@ -174,48 +168,3 @@ class TestGraphCompletionCoTRetriever:
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
-
- @pytest.mark.asyncio
- async def test_get_structured_completion(self):
- system_directory_path = os.path.join(
- pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
- )
- cognee.config.system_root_directory(system_directory_path)
- data_directory_path = os.path.join(
- pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
- )
- cognee.config.data_root_directory(data_directory_path)
-
- await cognee.prune.prune_data()
- await cognee.prune.prune_system(metadata=True)
- await setup()
-
- class Company(DataPoint):
- name: str
-
- class Person(DataPoint):
- name: str
- works_for: Company
-
- company1 = Company(name="Figma")
- person1 = Person(name="Steve Rodger", works_for=company1)
-
- entities = [company1, person1]
- await add_data_points(entities)
-
- retriever = GraphCompletionCotRetriever()
-
- # Test with string response model (default)
- string_answer = await retriever.get_structured_completion("Who works at Figma?")
- assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
- assert string_answer.strip(), "Answer should not be empty"
-
- # Test with structured response model
- structured_answer = await retriever.get_structured_completion(
- "Who works at Figma?", response_model=TestAnswer
- )
- assert isinstance(structured_answer, TestAnswer), (
- f"Expected TestAnswer, got {type(structured_answer).__name__}"
- )
- assert structured_answer.answer.strip(), "Answer field should not be empty"
- assert structured_answer.explanation.strip(), "Explanation field should not be empty"
diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py
index 252af8352..9bfed68f3 100644
--- a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py
+++ b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py
@@ -3,6 +3,7 @@ from typing import List
import pytest
import pathlib
import cognee
+
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
diff --git a/cognee/tests/unit/modules/retrieval/structured_output_test.py b/cognee/tests/unit/modules/retrieval/structured_output_test.py
new file mode 100644
index 000000000..4ad3019ff
--- /dev/null
+++ b/cognee/tests/unit/modules/retrieval/structured_output_test.py
@@ -0,0 +1,204 @@
+import asyncio
+
+import pytest
+import cognee
+import pathlib
+import os
+
+from pydantic import BaseModel
+from cognee.low_level import setup, DataPoint
+from cognee.tasks.storage import add_data_points
+from cognee.modules.chunking.models import DocumentChunk
+from cognee.modules.data.processing.document_types import TextDocument
+from cognee.modules.engine.models import Entity, EntityType
+from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor
+from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider
+from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
+from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
+from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
+ GraphCompletionContextExtensionRetriever,
+)
+from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever
+from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
+from cognee.modules.retrieval.completion_retriever import CompletionRetriever
+
+
+class TestAnswer(BaseModel):
+ answer: str
+ explanation: str
+
+
+def _assert_string_answer(answer: list[str]):
+    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+    assert all(isinstance(item, str) and item.strip() for item in answer), "Items should be non-empty strings"
+ assert all(item.strip() for item in answer), "Items should not be empty"
+
+
+def _assert_structured_answer(answer: list[TestAnswer]):
+ assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+ assert all(isinstance(x, TestAnswer) for x in answer), "Items should be TestAnswer"
+ assert all(x.answer.strip() for x in answer), "Answer text should not be empty"
+ assert all(x.explanation.strip() for x in answer), "Explanation should not be empty"
+
+
+async def _test_get_structured_graph_completion_cot():
+ retriever = GraphCompletionCotRetriever()
+
+ # Test with string response model (default)
+ string_answer = await retriever.get_completion("Who works at Figma?")
+ _assert_string_answer(string_answer)
+
+ # Test with structured response model
+ structured_answer = await retriever.get_completion(
+ "Who works at Figma?", response_model=TestAnswer
+ )
+ _assert_structured_answer(structured_answer)
+
+
+async def _test_get_structured_graph_completion():
+ retriever = GraphCompletionRetriever()
+
+ # Test with string response model (default)
+ string_answer = await retriever.get_completion("Who works at Figma?")
+ _assert_string_answer(string_answer)
+
+ # Test with structured response model
+ structured_answer = await retriever.get_completion(
+ "Who works at Figma?", response_model=TestAnswer
+ )
+ _assert_structured_answer(structured_answer)
+
+
+async def _test_get_structured_graph_completion_temporal():
+ retriever = TemporalRetriever()
+
+ # Test with string response model (default)
+ string_answer = await retriever.get_completion("When did Steve start working at Figma?")
+ _assert_string_answer(string_answer)
+
+ # Test with structured response model
+ structured_answer = await retriever.get_completion(
+        "When did Steve start working at Figma?", response_model=TestAnswer
+ )
+ _assert_structured_answer(structured_answer)
+
+
+async def _test_get_structured_graph_completion_rag():
+ retriever = CompletionRetriever()
+
+ # Test with string response model (default)
+ string_answer = await retriever.get_completion("Where does Steve work?")
+ _assert_string_answer(string_answer)
+
+ # Test with structured response model
+ structured_answer = await retriever.get_completion(
+ "Where does Steve work?", response_model=TestAnswer
+ )
+ _assert_structured_answer(structured_answer)
+
+
+async def _test_get_structured_graph_completion_context_extension():
+ retriever = GraphCompletionContextExtensionRetriever()
+
+ # Test with string response model (default)
+ string_answer = await retriever.get_completion("Who works at Figma?")
+ _assert_string_answer(string_answer)
+
+ # Test with structured response model
+ structured_answer = await retriever.get_completion(
+ "Who works at Figma?", response_model=TestAnswer
+ )
+ _assert_structured_answer(structured_answer)
+
+
+async def _test_get_structured_entity_completion():
+ retriever = EntityCompletionRetriever(DummyEntityExtractor(), DummyContextProvider())
+
+ # Test with string response model (default)
+ string_answer = await retriever.get_completion("Who is Albert Einstein?")
+ _assert_string_answer(string_answer)
+
+ # Test with structured response model
+ structured_answer = await retriever.get_completion(
+ "Who is Albert Einstein?", response_model=TestAnswer
+ )
+ _assert_structured_answer(structured_answer)
+
+
+class TestStructuredOutputCompletion:
+ @pytest.mark.asyncio
+ async def test_get_structured_completion(self):
+ system_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
+ )
+ cognee.config.system_root_directory(system_directory_path)
+ data_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
+ )
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ class Company(DataPoint):
+ name: str
+
+ class Person(DataPoint):
+ name: str
+ works_for: Company
+ works_since: int
+
+ company1 = Company(name="Figma")
+ person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015)
+
+ entities = [company1, person1]
+ await add_data_points(entities)
+
+ document = TextDocument(
+ name="Steve Rodger's career",
+ raw_data_location="somewhere",
+ external_metadata="",
+ mime_type="text/plain",
+ )
+
+ chunk1 = DocumentChunk(
+ text="Steve Rodger",
+ chunk_size=2,
+ chunk_index=0,
+ cut_type="sentence_end",
+ is_part_of=document,
+ contains=[],
+ )
+ chunk2 = DocumentChunk(
+ text="Mike Broski",
+ chunk_size=2,
+ chunk_index=1,
+ cut_type="sentence_end",
+ is_part_of=document,
+ contains=[],
+ )
+ chunk3 = DocumentChunk(
+ text="Christina Mayer",
+ chunk_size=2,
+ chunk_index=2,
+ cut_type="sentence_end",
+ is_part_of=document,
+ contains=[],
+ )
+
+ entities = [chunk1, chunk2, chunk3]
+ await add_data_points(entities)
+
+ entity_type = EntityType(name="Person", description="A human individual")
+ entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")
+
+ entities = [entity]
+ await add_data_points(entities)
+
+ await _test_get_structured_graph_completion_cot()
+ await _test_get_structured_graph_completion()
+ await _test_get_structured_graph_completion_temporal()
+ await _test_get_structured_graph_completion_rag()
+ await _test_get_structured_graph_completion_context_extension()
+ await _test_get_structured_entity_completion()
diff --git a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py
index fc96081bf..5f4b93425 100644
--- a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py
+++ b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py
@@ -13,7 +13,7 @@ from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
-class TextSummariesRetriever:
+class TestSummariesRetriever:
@pytest.mark.asyncio
async def test_chunk_context(self):
system_directory_path = os.path.join(
diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py
index a322cb237..c3c6a47f6 100644
--- a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py
+++ b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py
@@ -1,4 +1,3 @@
-import asyncio
from types import SimpleNamespace
import pytest
diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py
new file mode 100644
index 000000000..5eb6fb105
--- /dev/null
+++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py
@@ -0,0 +1,582 @@
+import pytest
+from unittest.mock import AsyncMock, patch
+
+from cognee.modules.retrieval.utils.brute_force_triplet_search import (
+ brute_force_triplet_search,
+ get_memory_fragment,
+)
+from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
+from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
+
+
+class MockScoredResult:
+ """Mock class for vector search results."""
+
+ def __init__(self, id, score, payload=None):
+ self.id = id
+ self.score = score
+ self.payload = payload or {}
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_empty_query():
+ """Test that empty query raises ValueError."""
+ with pytest.raises(ValueError, match="The query must be a non-empty string."):
+ await brute_force_triplet_search(query="")
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_none_query():
+ """Test that None query raises ValueError."""
+ with pytest.raises(ValueError, match="The query must be a non-empty string."):
+ await brute_force_triplet_search(query=None)
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_negative_top_k():
+ """Test that negative top_k raises ValueError."""
+ with pytest.raises(ValueError, match="top_k must be a positive integer."):
+ await brute_force_triplet_search(query="test query", top_k=-1)
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_zero_top_k():
+ """Test that zero top_k raises ValueError."""
+ with pytest.raises(ValueError, match="top_k must be a positive integer."):
+ await brute_force_triplet_search(query="test query", top_k=0)
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_wide_search_limit_global_search():
+    """Test that wide_search_top_k is applied as the search limit for global search (node_name=None)."""
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[])
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ):
+ await brute_force_triplet_search(
+ query="test",
+ node_name=None, # Global search
+ wide_search_top_k=75,
+ )
+
+ for call in mock_vector_engine.search.call_args_list:
+ assert call[1]["limit"] == 75
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_wide_search_limit_filtered_search():
+    """Test that the search limit is None for filtered search (node_name provided), despite wide_search_top_k."""
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[])
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ):
+ await brute_force_triplet_search(
+ query="test",
+ node_name=["Node1"],
+ wide_search_top_k=50,
+ )
+
+ for call in mock_vector_engine.search.call_args_list:
+ assert call[1]["limit"] is None
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_wide_search_default():
+ """Test that wide_search_top_k defaults to 100."""
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[])
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ):
+ await brute_force_triplet_search(query="test", node_name=None)
+
+ for call in mock_vector_engine.search.call_args_list:
+ assert call[1]["limit"] == 100
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_default_collections():
+ """Test that default collections are used when none provided."""
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[])
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ):
+ await brute_force_triplet_search(query="test")
+
+ expected_collections = [
+ "Entity_name",
+ "TextSummary_text",
+ "EntityType_name",
+ "DocumentChunk_text",
+ ]
+
+ call_collections = [
+ call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
+ ]
+ assert call_collections == expected_collections
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_custom_collections():
+ """Test that custom collections are used when provided."""
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[])
+
+ custom_collections = ["CustomCol1", "CustomCol2"]
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ):
+ await brute_force_triplet_search(query="test", collections=custom_collections)
+
+ call_collections = [
+ call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
+ ]
+ assert call_collections == custom_collections
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_all_collections_empty():
+ """Test that empty list is returned when all collections return no results."""
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[])
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ):
+ results = await brute_force_triplet_search(query="test")
+ assert results == []
+
+
+# Tests for query embedding
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_embeds_query():
+ """Test that query is embedded before searching."""
+ query_text = "test query"
+ expected_vector = [0.1, 0.2, 0.3]
+
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector])
+ mock_vector_engine.search = AsyncMock(return_value=[])
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ):
+ await brute_force_triplet_search(query=query_text)
+
+ mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text])
+
+ for call in mock_vector_engine.search.call_args_list:
+ assert call[1]["query_vector"] == expected_vector
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_extracts_node_ids_global_search():
+ """Test that node IDs are extracted from search results for global search."""
+ scored_results = [
+ MockScoredResult("node1", 0.95),
+ MockScoredResult("node2", 0.87),
+ MockScoredResult("node3", 0.92),
+ ]
+
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=scored_results)
+
+ mock_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+ return_value=mock_fragment,
+ ) as mock_get_fragment_fn,
+ ):
+ await brute_force_triplet_search(query="test", node_name=None)
+
+ call_kwargs = mock_get_fragment_fn.call_args[1]
+ assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_reuses_provided_fragment():
+ """Test that provided memory fragment is reused instead of creating new one."""
+ provided_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment"
+ ) as mock_get_fragment,
+ ):
+ await brute_force_triplet_search(
+ query="test",
+ memory_fragment=provided_fragment,
+ node_name=["node"],
+ )
+
+ mock_get_fragment.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
+ """Test that memory fragment is created when not provided."""
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
+
+ mock_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+ return_value=mock_fragment,
+ ) as mock_get_fragment,
+ ):
+ await brute_force_triplet_search(query="test", node_name=["node"])
+
+ mock_get_fragment.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation():
+ """Test that custom top_k is passed to importance calculation."""
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
+
+ mock_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+ return_value=mock_fragment,
+ ),
+ ):
+ custom_top_k = 15
+ await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])
+
+ mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
+
+
+@pytest.mark.asyncio
+async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
+ """Test that get_memory_fragment returns empty graph when entity not found."""
+ mock_graph_engine = AsyncMock()
+ mock_graph_engine.project_graph_from_db = AsyncMock(
+ side_effect=EntityNotFoundError("Entity not found")
+ )
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
+ return_value=mock_graph_engine,
+ ):
+ fragment = await get_memory_fragment()
+
+ assert isinstance(fragment, CogneeGraph)
+ assert len(fragment.nodes) == 0
+
+
+@pytest.mark.asyncio
+async def test_get_memory_fragment_returns_empty_graph_on_error():
+ """Test that get_memory_fragment returns empty graph on generic error."""
+ mock_graph_engine = AsyncMock()
+ mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error"))
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
+ return_value=mock_graph_engine,
+ ):
+ fragment = await get_memory_fragment()
+
+ assert isinstance(fragment, CogneeGraph)
+ assert len(fragment.nodes) == 0
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_deduplicates_node_ids():
+ """Test that duplicate node IDs across collections are deduplicated."""
+
+ def search_side_effect(*args, **kwargs):
+ collection_name = kwargs.get("collection_name")
+ if collection_name == "Entity_name":
+ return [
+ MockScoredResult("node1", 0.95),
+ MockScoredResult("node2", 0.87),
+ ]
+ elif collection_name == "TextSummary_text":
+ return [
+ MockScoredResult("node1", 0.90),
+ MockScoredResult("node3", 0.92),
+ ]
+ else:
+ return []
+
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
+
+ mock_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+ return_value=mock_fragment,
+ ) as mock_get_fragment_fn,
+ ):
+ await brute_force_triplet_search(query="test", node_name=None)
+
+ call_kwargs = mock_get_fragment_fn.call_args[1]
+ assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
+ assert len(call_kwargs["relevant_ids_to_filter"]) == 3
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_excludes_edge_collection():
+ """Test that EdgeType_relationship_name collection is excluded from ID extraction."""
+
+ def search_side_effect(*args, **kwargs):
+ collection_name = kwargs.get("collection_name")
+ if collection_name == "Entity_name":
+ return [MockScoredResult("node1", 0.95)]
+ elif collection_name == "EdgeType_relationship_name":
+ return [MockScoredResult("edge1", 0.88)]
+ else:
+ return []
+
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
+
+ mock_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+ return_value=mock_fragment,
+ ) as mock_get_fragment_fn,
+ ):
+ await brute_force_triplet_search(
+ query="test",
+ node_name=None,
+ collections=["Entity_name", "EdgeType_relationship_name"],
+ )
+
+ call_kwargs = mock_get_fragment_fn.call_args[1]
+ assert call_kwargs["relevant_ids_to_filter"] == ["node1"]
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_skips_nodes_without_ids():
+ """Test that nodes without ID attribute are skipped."""
+
+ class ScoredResultNoId:
+ """Mock result without id attribute."""
+
+ def __init__(self, score):
+ self.score = score
+
+ def search_side_effect(*args, **kwargs):
+ collection_name = kwargs.get("collection_name")
+ if collection_name == "Entity_name":
+ return [
+ MockScoredResult("node1", 0.95),
+ ScoredResultNoId(0.90),
+ MockScoredResult("node2", 0.87),
+ ]
+ else:
+ return []
+
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
+
+ mock_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+ return_value=mock_fragment,
+ ) as mock_get_fragment_fn,
+ ):
+ await brute_force_triplet_search(query="test", node_name=None)
+
+ call_kwargs = mock_get_fragment_fn.call_args[1]
+ assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_handles_tuple_results():
+ """Test that both list and tuple results are handled correctly."""
+
+ def search_side_effect(*args, **kwargs):
+ collection_name = kwargs.get("collection_name")
+ if collection_name == "Entity_name":
+ return (
+ MockScoredResult("node1", 0.95),
+ MockScoredResult("node2", 0.87),
+ )
+ else:
+ return []
+
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
+
+ mock_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+ return_value=mock_fragment,
+ ) as mock_get_fragment_fn,
+ ):
+ await brute_force_triplet_search(query="test", node_name=None)
+
+ call_kwargs = mock_get_fragment_fn.call_args[1]
+ assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_mixed_empty_collections():
+ """Test ID extraction with mixed empty and non-empty collections."""
+
+ def search_side_effect(*args, **kwargs):
+ collection_name = kwargs.get("collection_name")
+ if collection_name == "Entity_name":
+ return [MockScoredResult("node1", 0.95)]
+ elif collection_name == "TextSummary_text":
+ return []
+ elif collection_name == "EntityType_name":
+ return [MockScoredResult("node2", 0.92)]
+ else:
+ return []
+
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
+
+ mock_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+ return_value=mock_fragment,
+ ) as mock_get_fragment_fn,
+ ):
+ await brute_force_triplet_search(query="test", node_name=None)
+
+ call_kwargs = mock_get_fragment_fn.call_args[1]
+ assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
diff --git a/cognee/tests/unit/modules/users/test_conditional_authentication.py b/cognee/tests/unit/modules/users/test_conditional_authentication.py
index c4368d796..6568c3cb0 100644
--- a/cognee/tests/unit/modules/users/test_conditional_authentication.py
+++ b/cognee/tests/unit/modules/users/test_conditional_authentication.py
@@ -107,29 +107,10 @@ class TestConditionalAuthenticationIntegration:
# REQUIRE_AUTHENTICATION should be a boolean
assert isinstance(REQUIRE_AUTHENTICATION, bool)
- # Currently should be False (optional authentication)
- assert not REQUIRE_AUTHENTICATION
-
class TestConditionalAuthenticationEnvironmentVariables:
"""Test environment variable handling."""
- def test_require_authentication_default_false(self):
- """Test that REQUIRE_AUTHENTICATION defaults to false when imported with no env vars."""
- with patch.dict(os.environ, {}, clear=True):
- # Remove module from cache to force fresh import
- module_name = "cognee.modules.users.methods.get_authenticated_user"
- if module_name in sys.modules:
- del sys.modules[module_name]
-
- # Import after patching environment - module will see empty environment
- from cognee.modules.users.methods.get_authenticated_user import (
- REQUIRE_AUTHENTICATION,
- )
-
- importlib.invalidate_caches()
- assert not REQUIRE_AUTHENTICATION
-
def test_require_authentication_true(self):
"""Test that REQUIRE_AUTHENTICATION=true is parsed correctly when imported."""
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "true"}):
@@ -145,50 +126,6 @@ class TestConditionalAuthenticationEnvironmentVariables:
assert REQUIRE_AUTHENTICATION
- def test_require_authentication_false_explicit(self):
- """Test that REQUIRE_AUTHENTICATION=false is parsed correctly when imported."""
- with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "false"}):
- # Remove module from cache to force fresh import
- module_name = "cognee.modules.users.methods.get_authenticated_user"
- if module_name in sys.modules:
- del sys.modules[module_name]
-
- # Import after patching environment - module will see REQUIRE_AUTHENTICATION=false
- from cognee.modules.users.methods.get_authenticated_user import (
- REQUIRE_AUTHENTICATION,
- )
-
- assert not REQUIRE_AUTHENTICATION
-
- def test_require_authentication_case_insensitive(self):
- """Test that environment variable parsing is case insensitive when imported."""
- test_cases = ["TRUE", "True", "tRuE", "FALSE", "False", "fAlSe"]
-
- for case in test_cases:
- with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": case}):
- # Remove module from cache to force fresh import
- module_name = "cognee.modules.users.methods.get_authenticated_user"
- if module_name in sys.modules:
- del sys.modules[module_name]
-
- # Import after patching environment
- from cognee.modules.users.methods.get_authenticated_user import (
- REQUIRE_AUTHENTICATION,
- )
-
- expected = case.lower() == "true"
- assert REQUIRE_AUTHENTICATION == expected, f"Failed for case: {case}"
-
- def test_current_require_authentication_value(self):
- """Test that the current REQUIRE_AUTHENTICATION module value is as expected."""
- from cognee.modules.users.methods.get_authenticated_user import (
- REQUIRE_AUTHENTICATION,
- )
-
- # The module-level variable should currently be False (set at import time)
- assert isinstance(REQUIRE_AUTHENTICATION, bool)
- assert not REQUIRE_AUTHENTICATION
-
class TestConditionalAuthenticationEdgeCases:
"""Test edge cases and error scenarios."""
diff --git a/cognee/tests/unit/processing/chunks/chunk_by_row_test.py b/cognee/tests/unit/processing/chunks/chunk_by_row_test.py
new file mode 100644
index 000000000..7d6a73a06
--- /dev/null
+++ b/cognee/tests/unit/processing/chunks/chunk_by_row_test.py
@@ -0,0 +1,52 @@
+from itertools import product
+
+import numpy as np
+import pytest
+
+from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
+from cognee.tasks.chunks import chunk_by_row
+
+INPUT_TEXTS = "name: John, age: 30, city: New York, country: USA"
+max_chunk_size_vals = [8, 32]
+
+
+@pytest.mark.parametrize(
+ "input_text,max_chunk_size",
+ list(product([INPUT_TEXTS], max_chunk_size_vals)),
+)
+def test_chunk_by_row_isomorphism(input_text, max_chunk_size):
+ chunks = chunk_by_row(input_text, max_chunk_size)
+ reconstructed_text = ", ".join([chunk["text"] for chunk in chunks])
+ assert reconstructed_text == input_text, (
+ f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
+ )
+
+
+@pytest.mark.parametrize(
+ "input_text,max_chunk_size",
+ list(product([INPUT_TEXTS], max_chunk_size_vals)),
+)
+def test_row_chunk_length(input_text, max_chunk_size):
+ chunks = list(chunk_by_row(data=input_text, max_chunk_size=max_chunk_size))
+ embedding_engine = get_embedding_engine()
+
+ chunk_lengths = np.array(
+ [embedding_engine.tokenizer.count_tokens(chunk["text"]) for chunk in chunks]
+ )
+
+ larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size]
+ assert np.all(chunk_lengths <= max_chunk_size), (
+ f"{max_chunk_size = }: {larger_chunks} are too large"
+ )
+
+
+@pytest.mark.parametrize(
+ "input_text,max_chunk_size",
+ list(product([INPUT_TEXTS], max_chunk_size_vals)),
+)
+def test_chunk_by_row_chunk_numbering(input_text, max_chunk_size):
+ chunks = chunk_by_row(data=input_text, max_chunk_size=max_chunk_size)
+ chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
+ assert np.all(chunk_indices == np.arange(len(chunk_indices))), (
+ f"{chunk_indices = } are not monotonically increasing"
+ )
diff --git a/docker-compose.yml b/docker-compose.yml
index 43d9b2607..472f24c21 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -13,7 +13,7 @@ services:
- DEBUG=false # Change to true if debugging
- HOST=0.0.0.0
- ENVIRONMENT=local
- - LOG_LEVEL=ERROR
+ - LOG_LEVEL=INFO
extra_hosts:
# Allows the container to reach your local machine using "host.docker.internal" instead of "localhost"
- "host.docker.internal:host-gateway"
diff --git a/entrypoint.sh b/entrypoint.sh
index bad9b7aa3..496825408 100755
--- a/entrypoint.sh
+++ b/entrypoint.sh
@@ -43,10 +43,10 @@ sleep 2
if [ "$ENVIRONMENT" = "dev" ] || [ "$ENVIRONMENT" = "local" ]; then
if [ "$DEBUG" = "true" ]; then
echo "Waiting for the debugger to attach..."
- debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload cognee.api.client:app
+ exec debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload --access-logfile - --error-logfile - cognee.api.client:app
else
- gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload cognee.api.client:app
+ exec gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload --access-logfile - --error-logfile - cognee.api.client:app
fi
else
- gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level error cognee.api.client:app
+ exec gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level error --access-logfile - --error-logfile - cognee.api.client:app
fi
diff --git a/examples/python/agentic_reasoning_procurement_example.py b/examples/python/agentic_reasoning_procurement_example.py
index 5aa3caa70..4e9d2d7e4 100644
--- a/examples/python/agentic_reasoning_procurement_example.py
+++ b/examples/python/agentic_reasoning_procurement_example.py
@@ -168,7 +168,7 @@ async def run_procurement_example():
for q in questions:
print(f"Question: \n{q}")
results = await procurement_system.search_memory(q, search_categories=[category])
- top_answer = results[category][0]
+ top_answer = results[category][0]["search_result"][0]
print(f"Answer: \n{top_answer.strip()}\n")
research_notes[category].append({"question": q, "answer": top_answer})
diff --git a/examples/python/code_graph_example.py b/examples/python/code_graph_example.py
index 431069050..1b476a2c3 100644
--- a/examples/python/code_graph_example.py
+++ b/examples/python/code_graph_example.py
@@ -1,5 +1,7 @@
import argparse
import asyncio
+import os
+
import cognee
from cognee import SearchType
from cognee.shared.logging_utils import setup_logging, ERROR
@@ -8,6 +10,9 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
async def main(repo_path, include_docs):
+ # Disable permissions feature for this example
+ os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
+
run_status = False
async for run_status in run_code_graph_pipeline(repo_path, include_docs=include_docs):
run_status = run_status
diff --git a/examples/python/conversation_session_persistence_example.py b/examples/python/conversation_session_persistence_example.py
new file mode 100644
index 000000000..5346f5012
--- /dev/null
+++ b/examples/python/conversation_session_persistence_example.py
@@ -0,0 +1,98 @@
+import asyncio
+
+import cognee
+from cognee import visualize_graph
+from cognee.memify_pipelines.persist_sessions_in_knowledge_graph import (
+ persist_sessions_in_knowledge_graph_pipeline,
+)
+from cognee.modules.search.types import SearchType
+from cognee.modules.users.methods import get_default_user
+from cognee.shared.logging_utils import get_logger
+
+logger = get_logger("conversation_session_persistence_example")
+
+
+async def main():
+ # NOTE: CACHING has to be enabled for this example to work
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+
+ text_1 = "Cognee is a solution that can build knowledge graph from text, creating an AI memory system"
+ text_2 = "Germany is a country located next to the Netherlands"
+
+ await cognee.add([text_1, text_2])
+ await cognee.cognify()
+
+ question = "What can I use to create a knowledge graph?"
+ search_results = await cognee.search(
+ query_type=SearchType.GRAPH_COMPLETION,
+ query_text=question,
+ )
+ print("\nSession ID: default_session")
+ print(f"Question: {question}")
+ print(f"Answer: {search_results}\n")
+
+ question = "You sure about that?"
+ search_results = await cognee.search(
+ query_type=SearchType.GRAPH_COMPLETION, query_text=question
+ )
+ print("\nSession ID: default_session")
+ print(f"Question: {question}")
+ print(f"Answer: {search_results}\n")
+
+ question = "This is awesome!"
+ search_results = await cognee.search(
+ query_type=SearchType.GRAPH_COMPLETION, query_text=question
+ )
+ print("\nSession ID: default_session")
+ print(f"Question: {question}")
+ print(f"Answer: {search_results}\n")
+
+ question = "Where is Germany?"
+ search_results = await cognee.search(
+ query_type=SearchType.GRAPH_COMPLETION,
+ query_text=question,
+ session_id="different_session",
+ )
+ print("\nSession ID: different_session")
+ print(f"Question: {question}")
+ print(f"Answer: {search_results}\n")
+
+ question = "Next to which country again?"
+ search_results = await cognee.search(
+ query_type=SearchType.GRAPH_COMPLETION,
+ query_text=question,
+ session_id="different_session",
+ )
+ print("\nSession ID: different_session")
+ print(f"Question: {question}")
+ print(f"Answer: {search_results}\n")
+
+ question = "So you remember everything I asked from you?"
+ search_results = await cognee.search(
+ query_type=SearchType.GRAPH_COMPLETION,
+ query_text=question,
+ session_id="different_session",
+ )
+ print("\nSession ID: different_session")
+ print(f"Question: {question}")
+ print(f"Answer: {search_results}\n")
+
+ session_ids_to_persist = ["default_session", "different_session"]
+ default_user = await get_default_user()
+
+ await persist_sessions_in_knowledge_graph_pipeline(
+ user=default_user,
+ session_ids=session_ids_to_persist,
+ )
+
+ await visualize_graph()
+
+
+if __name__ == "__main__":
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ try:
+ loop.run_until_complete(main())
+ finally:
+ loop.run_until_complete(loop.shutdown_asyncgens())
diff --git a/examples/python/feedback_enrichment_minimal_example.py b/examples/python/feedback_enrichment_minimal_example.py
index 11ef20830..8954bd5f6 100644
--- a/examples/python/feedback_enrichment_minimal_example.py
+++ b/examples/python/feedback_enrichment_minimal_example.py
@@ -67,7 +67,6 @@ async def run_feedback_enrichment_memify(last_n: int = 5):
extraction_tasks=extraction_tasks,
enrichment_tasks=enrichment_tasks,
data=[{}], # A placeholder to prevent fetching the entire graph
- dataset="feedback_enrichment_minimal",
)
diff --git a/examples/python/memify_coding_agent_example.py b/examples/python/memify_coding_agent_example.py
index 1fd3b1528..4a087ba61 100644
--- a/examples/python/memify_coding_agent_example.py
+++ b/examples/python/memify_coding_agent_example.py
@@ -89,7 +89,7 @@ async def main():
)
print("Coding rules created by memify:")
- for coding_rule in coding_rules:
+ for coding_rule in coding_rules[0]["search_result"][0]:
print("- " + coding_rule)
# Visualize new graph with added memify context
diff --git a/examples/python/permissions_example.py b/examples/python/permissions_example.py
index 4f51b660f..c0b104023 100644
--- a/examples/python/permissions_example.py
+++ b/examples/python/permissions_example.py
@@ -3,6 +3,7 @@ import cognee
import pathlib
from cognee.modules.users.exceptions import PermissionDeniedError
+from cognee.modules.users.tenants.methods import select_tenant
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import create_user
@@ -116,6 +117,7 @@ async def main():
print(
"\nOperation started as user_2 to give read permission to user_1 for the dataset owned by user_2"
)
+
await authorized_give_permission_on_datasets(
user_1.id,
[quantum_dataset_id],
@@ -142,6 +144,9 @@ async def main():
print("User 2 is creating CogneeLab tenant/organization")
tenant_id = await create_tenant("CogneeLab", user_2.id)
+ print("User 2 is selecting CogneeLab tenant/organization as active tenant")
+ await select_tenant(user_id=user_2.id, tenant_id=tenant_id)
+
print("\nUser 2 is creating Researcher role")
role_id = await create_role(role_name="Researcher", owner_id=user_2.id)
@@ -157,23 +162,59 @@ async def main():
)
await add_user_to_role(user_id=user_3.id, role_id=role_id, owner_id=user_2.id)
+ print("\nOperation as user_3 to select CogneeLab tenant/organization as active tenant")
+ await select_tenant(user_id=user_3.id, tenant_id=tenant_id)
+
print(
- "\nOperation started as user_2 to give read permission to Researcher role for the dataset owned by user_2"
+ "\nOperation started as user_2, with CogneeLab as its active tenant, to give read permission to Researcher role for the dataset QUANTUM owned by user_2"
+ )
+ # Even though the dataset owner is user_2, the dataset doesn't belong to the tenant/organization CogneeLab.
+ # So we can't assign permissions to it when we're acting in the CogneeLab tenant.
+ try:
+ await authorized_give_permission_on_datasets(
+ role_id,
+ [quantum_dataset_id],
+ "read",
+ user_2.id,
+ )
+ except PermissionDeniedError:
+ print(
+ "User 2 could not give permission to the role as the QUANTUM dataset is not part of the CogneeLab tenant"
+ )
+
+ print(
+ "We will now create a new QUANTUM dataset with the QUANTUM_COGNEE_LAB name in the CogneeLab tenant so that permissions can be assigned to the Researcher role inside the tenant/organization"
+ )
+ # We can re-create the QUANTUM dataset in the CogneeLab tenant. The old QUANTUM dataset is still owned by user_2 personally
+ # and can still be accessed by selecting the personal tenant for user_2.
+ from cognee.modules.users.methods import get_user
+
+ # Note: Re-fetch user_2 from the database so it reflects the updated tenant context
+ user_2 = await get_user(user_2.id)
+ await cognee.add([text], dataset_name="QUANTUM_COGNEE_LAB", user=user_2)
+ quantum_cognee_lab_cognify_result = await cognee.cognify(["QUANTUM_COGNEE_LAB"], user=user_2)
+
+ # The recreated Quantum dataset will now have a different dataset_id as it's a new dataset in a different organization
+ quantum_cognee_lab_dataset_id = extract_dataset_id_from_cognify(
+ quantum_cognee_lab_cognify_result
+ )
+ print(
+ "\nOperation started as user_2, with CogneeLab as its active tenant, to give read permission to Researcher role for the dataset QUANTUM owned by the CogneeLab tenant"
)
await authorized_give_permission_on_datasets(
role_id,
- [quantum_dataset_id],
+ [quantum_cognee_lab_dataset_id],
"read",
user_2.id,
)
# Now user_3 can read from QUANTUM dataset as part of the Researcher role after proper permissions have been assigned by the QUANTUM dataset owner, user_2.
- print("\nSearch result as user_3 on the dataset owned by user_2:")
+ print("\nSearch result as user_3 on the QUANTUM dataset owned by the CogneeLab organization:")
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text="What is in the document?",
- user=user_1,
- dataset_ids=[quantum_dataset_id],
+ user=user_3,
+ dataset_ids=[quantum_cognee_lab_dataset_id],
)
for result in search_results:
print(f"{result}\n")
diff --git a/examples/python/relational_database_migration_example.py b/examples/python/relational_database_migration_example.py
index 7e87347bc..98482cb4b 100644
--- a/examples/python/relational_database_migration_example.py
+++ b/examples/python/relational_database_migration_example.py
@@ -31,6 +31,9 @@ from cognee.infrastructure.databases.vector.pgvector import (
async def main():
+ # Disable backend access control to migrate relational data
+ os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
+
# Clean all data stored in Cognee
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
diff --git a/examples/python/run_custom_pipeline_example.py b/examples/python/run_custom_pipeline_example.py
new file mode 100644
index 000000000..1ca1b4402
--- /dev/null
+++ b/examples/python/run_custom_pipeline_example.py
@@ -0,0 +1,84 @@
+import asyncio
+import cognee
+from cognee.modules.engine.operations.setup import setup
+from cognee.modules.users.methods import get_default_user
+from cognee.shared.logging_utils import setup_logging, INFO
+from cognee.modules.pipelines import Task
+from cognee.api.v1.search import SearchType
+
+# Prerequisites:
+# 1. Copy `.env.template` and rename it to `.env`.
+# 2. Add your OpenAI API key to the `.env` file in the `LLM_API_KEY` field:
+# LLM_API_KEY = "your_key_here"
+
+
+async def main():
+ # Create a clean slate for cognee -- reset data and system state
+ print("Resetting cognee data...")
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ print("Data reset complete.\n")
+
+ # Create relational database and tables
+ await setup()
+
+ # cognee knowledge graph will be created based on this text
+ text = """
+ Natural language processing (NLP) is an interdisciplinary
+ subfield of computer science and information retrieval.
+ """
+
+ print("Adding text to cognee:")
+ print(text.strip())
+
+ # Let's recreate the cognee add pipeline through the custom pipeline framework
+ from cognee.tasks.ingestion import ingest_data, resolve_data_directories
+
+ user = await get_default_user()
+
+ # Values for tasks need to be filled before calling the pipeline
+ add_tasks = [
+ Task(resolve_data_directories, include_subdirectories=True),
+ Task(
+ ingest_data,
+ "main_dataset",
+ user,
+ ),
+ ]
+ # Forward tasks to custom pipeline along with data and user information
+ await cognee.run_custom_pipeline(
+ tasks=add_tasks, data=text, user=user, dataset="main_dataset", pipeline_name="add_pipeline"
+ )
+ print("Text added successfully.\n")
+
+ # Use LLMs and cognee to create knowledge graph
+ from cognee.api.v1.cognify.cognify import get_default_tasks
+
+ cognify_tasks = await get_default_tasks(user=user)
+ print("Recreating existing cognify pipeline in custom pipeline to create knowledge graph...\n")
+ await cognee.run_custom_pipeline(
+ tasks=cognify_tasks, user=user, dataset="main_dataset", pipeline_name="cognify_pipeline"
+ )
+ print("Cognify process complete.\n")
+
+ query_text = "Tell me about NLP"
+ print(f"Searching cognee for insights with query: '{query_text}'")
+ # Query cognee for insights on the added text
+ search_results = await cognee.search(
+ query_type=SearchType.GRAPH_COMPLETION, query_text=query_text
+ )
+
+ print("Search results:")
+ # Display results
+ for result_text in search_results:
+ print(result_text)
+
+
+if __name__ == "__main__":
+ logger = setup_logging(log_level=INFO)
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ try:
+ loop.run_until_complete(main())
+ finally:
+ loop.run_until_complete(loop.shutdown_asyncgens())
diff --git a/examples/python/simple_example.py b/examples/python/simple_example.py
index c13e48f85..237a8295e 100644
--- a/examples/python/simple_example.py
+++ b/examples/python/simple_example.py
@@ -59,14 +59,6 @@ async def main():
for result_text in search_results:
print(result_text)
- # Example output:
- # ({'id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'natural language processing', 'description': 'An interdisciplinary subfield of computer science and information retrieval.'}, {'relationship_name': 'is_a_subfield_of', 'source_node_id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'target_node_id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 15, 473137, tzinfo=datetime.timezone.utc)}, {'id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'computer science', 'description': 'The study of computation and information processing.'})
- # (...)
- # It represents nodes and relationships in the knowledge graph:
- # - The first element is the source node (e.g., 'natural language processing').
- # - The second element is the relationship between nodes (e.g., 'is_a_subfield_of').
- # - The third element is the target node (e.g., 'computer science').
-
if __name__ == "__main__":
logger = setup_logging(log_level=ERROR)
diff --git a/pyproject.toml b/pyproject.toml
index 3af0599f1..81a5da471 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "cognee"
-version = "0.4.1"
+version = "0.5.0.dev0"
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
authors = [
{ name = "Vasilije Markovic" },
@@ -58,6 +58,8 @@ dependencies = [
"websockets>=15.0.1,<16.0.0",
"mistralai>=1.9.10",
"tenacity>=9.0.0",
+ "fakeredis[lua]>=2.32.0",
+ "diskcache>=5.6.3",
]
[project.optional-dependencies]
@@ -155,7 +157,6 @@ Homepage = "https://www.cognee.ai"
Repository = "https://github.com/topoteretes/cognee"
[project.scripts]
-cognee = "cognee.cli._cognee:main"
cognee-cli = "cognee.cli._cognee:main"
[build-system]
@@ -168,7 +169,6 @@ exclude = [
"/dist",
"/.data",
"/.github",
- "/alembic",
"/deployment",
"/cognee-mcp",
"/cognee-frontend",
@@ -200,3 +200,8 @@ exclude = [
[tool.ruff.lint]
ignore = ["F401"]
+
+[dependency-groups]
+dev = [
+ "pytest-timeout>=2.4.0",
+]