merge dev

This commit is contained in:
Andrej Milicevic 2025-12-01 10:19:49 +01:00
commit f7e072f533
177 changed files with 6319 additions and 762 deletions

View file

@ -21,6 +21,10 @@ LLM_PROVIDER="openai"
LLM_ENDPOINT=""
LLM_API_VERSION=""
LLM_MAX_TOKENS="16384"
# Instructor's modes determine how structured data is requested from and extracted from LLM responses
# You can change this type (i.e. mode) via this env variable
# Each LLM has its own default value, e.g. gpt-5 models have "json_schema_mode"
LLM_INSTRUCTOR_MODE=""
EMBEDDING_PROVIDER="openai"
EMBEDDING_MODEL="openai/text-embedding-3-large"
@ -169,8 +173,9 @@ REQUIRE_AUTHENTICATION=False
# Vector: LanceDB
# Graph: KuzuDB
#
# It enforces LanceDB and KuzuDB use and uses them to create databases per Cognee user + dataset
ENABLE_BACKEND_ACCESS_CONTROL=False
# It enforces creation of databases per Cognee user + dataset. Does not work with some graph and database providers.
# Disable mode when using not supported graph/vector databases.
ENABLE_BACKEND_ACCESS_CONTROL=True
################################################################################
# ☁️ Cloud Sync Settings

View file

@ -42,3 +42,8 @@ runs:
done
fi
uv sync --extra api --extra docs --extra evals --extra codegraph --extra ollama --extra dev --extra neo4j --extra redis $EXTRA_ARGS
- name: Add telemetry identifier for telemetry test and in case telemetry is enabled by accident
shell: bash
run: |
echo "test-machine" > .anon_id

View file

@ -6,6 +6,14 @@ Please provide a clear, human-generated description of the changes in this PR.
DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning.
-->
## Acceptance Criteria
<!--
* Key requirements to the new feature or modification;
* Proof that the changes work and meet the requirements;
* Include instructions on how to verify the changes. Describe how to test it locally;
* Proof that it's sufficiently tested.
-->
## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)

View file

@ -75,6 +75,7 @@ jobs:
name: Run Unit Tests
runs-on: ubuntu-22.04
env:
ENV: 'dev'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@ -104,6 +105,7 @@ jobs:
name: Run Integration Tests
runs-on: ubuntu-22.04
env:
ENV: 'dev'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@ -132,6 +134,7 @@ jobs:
name: Run Simple Examples
runs-on: ubuntu-22.04
env:
ENV: 'dev'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@ -161,6 +164,7 @@ jobs:
name: Run Simple Examples BAML
runs-on: ubuntu-22.04
env:
ENV: 'dev'
STRUCTURED_OUTPUT_FRAMEWORK: "BAML"
BAML_LLM_PROVIDER: openai
BAML_LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
@ -198,6 +202,7 @@ jobs:
name: Run Basic Graph Tests
runs-on: ubuntu-22.04
env:
ENV: 'dev'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}

View file

@ -39,6 +39,7 @@ jobs:
name: CLI Unit Tests
runs-on: ubuntu-22.04
env:
ENV: 'dev'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@ -66,6 +67,7 @@ jobs:
name: CLI Integration Tests
runs-on: ubuntu-22.04
env:
ENV: 'dev'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@ -93,6 +95,7 @@ jobs:
name: CLI Functionality Tests
runs-on: ubuntu-22.04
env:
ENV: 'dev'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}

View file

@ -60,7 +60,7 @@ jobs:
- name: Run Neo4j Example
env:
ENV: dev
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
@ -95,7 +95,7 @@ jobs:
- name: Run Kuzu Example
env:
ENV: dev
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
@ -141,7 +141,7 @@ jobs:
- name: Run PGVector Example
env:
ENV: dev
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}

View file

@ -226,7 +226,7 @@ jobs:
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
- name: Run parallel databases test
- name: Run permissions test
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
@ -239,6 +239,31 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_permissions.py
test-multi-tenancy:
name: Test multi tenancy with different situations in Cognee
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Run multi tenancy test
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_multi_tenancy.py
test-graph-edges:
name: Test graph edge ingestion
runs-on: ubuntu-22.04
@ -308,7 +333,7 @@ jobs:
python-version: '3.11.x'
extra-dependencies: "postgres redis"
- name: Run Concurrent subprocess access test (Kuzu/Lancedb/Postgres)
- name: Run Concurrent subprocess access test (Kuzu/Lancedb/Postgres/Redis)
env:
ENV: dev
LLM_MODEL: ${{ secrets.LLM_MODEL }}
@ -321,6 +346,7 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: 'kuzu'
CACHING: true
CACHE_BACKEND: 'redis'
SHARED_KUZU_LOCK: true
DB_PROVIDER: 'postgres'
DB_NAME: 'cognee_db'
@ -386,8 +412,8 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_feedback_enrichment.py
run_conversation_sessions_test:
name: Conversation sessions test
run_conversation_sessions_test_redis:
name: Conversation sessions test (Redis)
runs-on: ubuntu-latest
defaults:
run:
@ -427,7 +453,60 @@ jobs:
python-version: '3.11.x'
extra-dependencies: "postgres redis"
- name: Run Conversation session tests
- name: Run Conversation session tests (Redis)
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: 'kuzu'
CACHING: true
CACHE_BACKEND: 'redis'
DB_PROVIDER: 'postgres'
DB_NAME: 'cognee_db'
DB_HOST: '127.0.0.1'
DB_PORT: 5432
DB_USERNAME: cognee
DB_PASSWORD: cognee
run: uv run python ./cognee/tests/test_conversation_history.py
run_conversation_sessions_test_fs:
name: Conversation sessions test (FS)
runs-on: ubuntu-latest
defaults:
run:
shell: bash
services:
postgres:
image: pgvector/pgvector:pg17
env:
POSTGRES_USER: cognee
POSTGRES_PASSWORD: cognee
POSTGRES_DB: cognee_db
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 5432:5432
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: "postgres"
- name: Run Conversation session tests (FS)
env:
ENV: dev
LLM_MODEL: ${{ secrets.LLM_MODEL }}
@ -440,6 +519,7 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: 'kuzu'
CACHING: true
CACHE_BACKEND: 'fs'
DB_PROVIDER: 'postgres'
DB_NAME: 'cognee_db'
DB_HOST: '127.0.0.1'

View file

@ -21,6 +21,7 @@ jobs:
- name: Run Multimedia Example
env:
ENV: 'dev'
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: uv run python ./examples/python/multimedia_example.py
@ -40,6 +41,7 @@ jobs:
- name: Run Evaluation Framework Example
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
@ -69,6 +71,7 @@ jobs:
- name: Run Descriptive Graph Metrics Example
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
@ -99,6 +102,7 @@ jobs:
- name: Run Dynamic Steps Tests
env:
ENV: 'dev'
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@ -124,6 +128,7 @@ jobs:
- name: Run Temporal Example
env:
ENV: 'dev'
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@ -149,6 +154,7 @@ jobs:
- name: Run Ontology Demo Example
env:
ENV: 'dev'
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@ -174,6 +180,7 @@ jobs:
- name: Run Agentic Reasoning Example
env:
ENV: 'dev'
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@ -199,6 +206,7 @@ jobs:
- name: Run Memify Tests
env:
ENV: 'dev'
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@ -210,6 +218,32 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./examples/python/memify_coding_agent_example.py
test-custom-pipeline:
name: Run Custom Pipeline Example
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Run Custom Pipeline Example
env:
ENV: 'dev'
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./examples/python/run_custom_pipeline_example.py
test-permissions-example:
name: Run Permissions Example
runs-on: ubuntu-22.04
@ -224,6 +258,7 @@ jobs:
- name: Run Memify Tests
env:
ENV: 'dev'
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@ -249,6 +284,7 @@ jobs:
- name: Run Docling Test
env:
ENV: 'dev'
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}

70
.github/workflows/load_tests.yml vendored Normal file
View file

@ -0,0 +1,70 @@
name: Load tests
permissions:
contents: read
on:
workflow_dispatch:
workflow_call:
secrets:
LLM_MODEL:
required: true
LLM_ENDPOINT:
required: true
LLM_API_KEY:
required: true
LLM_API_VERSION:
required: true
EMBEDDING_MODEL:
required: true
EMBEDDING_ENDPOINT:
required: true
EMBEDDING_API_KEY:
required: true
EMBEDDING_API_VERSION:
required: true
OPENAI_API_KEY:
required: true
AWS_ACCESS_KEY_ID:
required: true
AWS_SECRET_ACCESS_KEY:
required: true
jobs:
test-load:
name: Test Load
runs-on: ubuntu-22.04
timeout-minutes: 60
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: "aws"
- name: Verify File Descriptor Limit
run: ulimit -n
- name: Run Load Test
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: True
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
STORAGE_BACKEND: s3
AWS_REGION: eu-west-1
AWS_ENDPOINT_URL: https://s3-eu-west-1.amazonaws.com
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_DEV_USER_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_DEV_USER_SECRET_KEY }}
run: uv run python ./cognee/tests/test_load.py

17
.github/workflows/release_test.yml vendored Normal file
View file

@ -0,0 +1,17 @@
# Long-running, heavy and resource-consuming tests for release validation
name: Release Test Workflow
permissions:
contents: read
on:
workflow_dispatch:
pull_request:
branches:
- main
jobs:
load-tests:
name: Load Tests
uses: ./.github/workflows/load_tests.yml
secrets: inherit

View file

@ -84,6 +84,7 @@ jobs:
GRAPH_DATABASE_PROVIDER: 'neo4j'
VECTOR_DB_PROVIDER: 'lancedb'
DB_PROVIDER: 'sqlite'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
@ -135,6 +136,7 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: 'kuzu'
VECTOR_DB_PROVIDER: 'pgvector'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
DB_PROVIDER: 'postgres'
DB_NAME: 'cognee_db'
DB_HOST: '127.0.0.1'
@ -197,6 +199,7 @@ jobs:
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
DB_NAME: cognee_db
DB_HOST: 127.0.0.1
DB_PORT: 5432

View file

@ -10,6 +10,10 @@ on:
required: false
type: string
default: '["3.10.x", "3.12.x", "3.13.x"]'
os:
required: false
type: string
default: '["ubuntu-22.04", "macos-15", "windows-latest"]'
secrets:
LLM_PROVIDER:
required: true
@ -40,10 +44,11 @@ jobs:
run-unit-tests:
name: Unit tests ${{ matrix.python-version }} on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
timeout-minutes: 60
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ubuntu-22.04, macos-15, windows-latest]
os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out
@ -76,10 +81,11 @@ jobs:
run-integration-tests:
name: Integration tests ${{ matrix.python-version }} on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
timeout-minutes: 60
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-15, windows-latest ]
os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out
@ -112,10 +118,11 @@ jobs:
run-library-test:
name: Library test ${{ matrix.python-version }} on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
timeout-minutes: 60
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-15, windows-latest ]
os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out
@ -148,10 +155,11 @@ jobs:
run-build-test:
name: Build test ${{ matrix.python-version }} on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
timeout-minutes: 60
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-15, windows-latest ]
os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out
@ -177,10 +185,11 @@ jobs:
run-soft-deletion-test:
name: Soft Delete test ${{ matrix.python-version }} on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
timeout-minutes: 60
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-15, windows-latest ]
os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out
@ -214,10 +223,11 @@ jobs:
run-hard-deletion-test:
name: Hard Delete test ${{ matrix.python-version }} on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
timeout-minutes: 60
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-15, windows-latest ]
os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out

View file

@ -1,4 +1,6 @@
name: Test Suites
permissions:
contents: read
on:
push:
@ -80,12 +82,22 @@ jobs:
uses: ./.github/workflows/notebooks_tests.yml
secrets: inherit
different-operating-systems-tests:
name: Operating System and Python Tests
different-os-tests-basic:
name: OS and Python Tests Ubuntu
needs: [basic-tests, e2e-tests]
uses: ./.github/workflows/test_different_operating_systems.yml
with:
python-versions: '["3.10.x", "3.11.x", "3.12.x", "3.13.x"]'
os: '["ubuntu-22.04"]'
secrets: inherit
different-os-tests-extended:
name: OS and Python Tests Extended
needs: [basic-tests, e2e-tests]
uses: ./.github/workflows/test_different_operating_systems.yml
with:
python-versions: '["3.13.x"]'
os: '["macos-15", "windows-latest"]'
secrets: inherit
# Matrix-based vector database tests
@ -135,7 +147,8 @@ jobs:
e2e-tests,
graph-db-tests,
notebook-tests,
different-operating-systems-tests,
different-os-tests-basic,
different-os-tests-extended,
vector-db-tests,
example-tests,
llm-tests,
@ -155,7 +168,8 @@ jobs:
cli-tests,
graph-db-tests,
notebook-tests,
different-operating-systems-tests,
different-os-tests-basic,
different-os-tests-extended,
vector-db-tests,
example-tests,
db-examples-tests,
@ -176,7 +190,8 @@ jobs:
"${{ needs.cli-tests.result }}" == "success" &&
"${{ needs.graph-db-tests.result }}" == "success" &&
"${{ needs.notebook-tests.result }}" == "success" &&
"${{ needs.different-operating-systems-tests.result }}" == "success" &&
"${{ needs.different-os-tests-basic.result }}" == "success" &&
"${{ needs.different-os-tests-extended.result }}" == "success" &&
"${{ needs.vector-db-tests.result }}" == "success" &&
"${{ needs.example-tests.result }}" == "success" &&
"${{ needs.db-examples-tests.result }}" == "success" &&

View file

@ -2,7 +2,7 @@ name: Weighted Edges Tests
on:
push:
branches: [ main, weighted_edges ]
branches: [ main, dev, weighted_edges ]
paths:
- 'cognee/modules/graph/utils/get_graph_from_model.py'
- 'cognee/infrastructure/engine/models/Edge.py'
@ -10,7 +10,7 @@ on:
- 'examples/python/weighted_edges_example.py'
- '.github/workflows/weighted_edges_tests.yml'
pull_request:
branches: [ main ]
branches: [ main, dev ]
paths:
- 'cognee/modules/graph/utils/get_graph_from_model.py'
- 'cognee/infrastructure/engine/models/Edge.py'
@ -32,7 +32,7 @@ jobs:
env:
LLM_PROVIDER: openai
LLM_MODEL: gpt-5-mini
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
steps:
- name: Check out repository
@ -67,14 +67,13 @@ jobs:
env:
LLM_PROVIDER: openai
LLM_MODEL: gpt-5-mini
LLM_ENDPOINT: https://api.openai.com/v1/
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_ENDPOINT: https://api.openai.com/v1
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_API_VERSION: "2024-02-01"
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: text-embedding-3-small
EMBEDDING_ENDPOINT: https://api.openai.com/v1/
EMBEDDING_API_KEY: ${{ secrets.LLM_API_KEY }}
EMBEDDING_API_VERSION: "2024-02-01"
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
steps:
- name: Check out repository
uses: actions/checkout@v4
@ -108,14 +107,14 @@ jobs:
env:
LLM_PROVIDER: openai
LLM_MODEL: gpt-5-mini
LLM_ENDPOINT: https://api.openai.com/v1/
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_ENDPOINT: https://api.openai.com/v1
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_API_VERSION: "2024-02-01"
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: text-embedding-3-small
EMBEDDING_ENDPOINT: https://api.openai.com/v1/
EMBEDDING_API_KEY: ${{ secrets.LLM_API_KEY }}
EMBEDDING_API_VERSION: "2024-02-01"
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
steps:
- name: Check out repository
uses: actions/checkout@v4

View file

@ -87,11 +87,6 @@ db_engine = get_relational_engine()
print("Using database:", db_engine.db_uri)
if "sqlite" in db_engine.db_uri:
from cognee.infrastructure.utils.run_sync import run_sync
run_sync(db_engine.create_database())
config.set_section_option(
config.config_ini_section,
"SQLALCHEMY_DATABASE_URI",

View file

@ -10,6 +10,7 @@ from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
@ -26,7 +27,34 @@ def upgrade() -> None:
connection = op.get_bind()
inspector = sa.inspect(connection)
if op.get_context().dialect.name == "postgresql":
syncstatus_enum = postgresql.ENUM(
"STARTED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED", name="syncstatus"
)
syncstatus_enum.create(op.get_bind(), checkfirst=True)
if "sync_operations" not in inspector.get_table_names():
if op.get_context().dialect.name == "postgresql":
syncstatus = postgresql.ENUM(
"STARTED",
"IN_PROGRESS",
"COMPLETED",
"FAILED",
"CANCELLED",
name="syncstatus",
create_type=False,
)
else:
syncstatus = sa.Enum(
"STARTED",
"IN_PROGRESS",
"COMPLETED",
"FAILED",
"CANCELLED",
name="syncstatus",
create_type=False,
)
# Table doesn't exist, create it normally
op.create_table(
"sync_operations",
@ -34,15 +62,7 @@ def upgrade() -> None:
sa.Column("run_id", sa.Text(), nullable=True),
sa.Column(
"status",
sa.Enum(
"STARTED",
"IN_PROGRESS",
"COMPLETED",
"FAILED",
"CANCELLED",
name="syncstatus",
create_type=False,
),
syncstatus,
nullable=True,
),
sa.Column("progress_percentage", sa.Integer(), nullable=True),

View file

@ -23,11 +23,8 @@ depends_on: Union[str, Sequence[str], None] = "8057ae7329c2"
def upgrade() -> None:
try:
await_only(create_default_user())
except UserAlreadyExists:
pass # It's fine if the default user already exists
pass
def downgrade() -> None:
await_only(delete_user("default_user@example.com"))
pass

View file

@ -0,0 +1,98 @@
"""Expand dataset database for multi user
Revision ID: 76625596c5c3
Revises: 211ab850ef3d
Create Date: 2025-10-30 12:55:20.239562
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "76625596c5c3"
down_revision: Union[str, None] = "c946955da633"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def _get_column(inspector, table, name, schema=None):
for col in inspector.get_columns(table, schema=schema):
if col["name"] == name:
return col
return None
def upgrade() -> None:
conn = op.get_bind()
insp = sa.inspect(conn)
vector_database_provider_column = _get_column(
insp, "dataset_database", "vector_database_provider"
)
if not vector_database_provider_column:
op.add_column(
"dataset_database",
sa.Column(
"vector_database_provider",
sa.String(),
unique=False,
nullable=False,
server_default="lancedb",
),
)
graph_database_provider_column = _get_column(
insp, "dataset_database", "graph_database_provider"
)
if not graph_database_provider_column:
op.add_column(
"dataset_database",
sa.Column(
"graph_database_provider",
sa.String(),
unique=False,
nullable=False,
server_default="kuzu",
),
)
vector_database_url_column = _get_column(insp, "dataset_database", "vector_database_url")
if not vector_database_url_column:
op.add_column(
"dataset_database",
sa.Column("vector_database_url", sa.String(), unique=False, nullable=True),
)
graph_database_url_column = _get_column(insp, "dataset_database", "graph_database_url")
if not graph_database_url_column:
op.add_column(
"dataset_database",
sa.Column("graph_database_url", sa.String(), unique=False, nullable=True),
)
vector_database_key_column = _get_column(insp, "dataset_database", "vector_database_key")
if not vector_database_key_column:
op.add_column(
"dataset_database",
sa.Column("vector_database_key", sa.String(), unique=False, nullable=True),
)
graph_database_key_column = _get_column(insp, "dataset_database", "graph_database_key")
if not graph_database_key_column:
op.add_column(
"dataset_database",
sa.Column("graph_database_key", sa.String(), unique=False, nullable=True),
)
def downgrade() -> None:
op.drop_column("dataset_database", "vector_database_provider")
op.drop_column("dataset_database", "graph_database_provider")
op.drop_column("dataset_database", "vector_database_url")
op.drop_column("dataset_database", "graph_database_url")
op.drop_column("dataset_database", "vector_database_key")
op.drop_column("dataset_database", "graph_database_key")

View file

@ -18,11 +18,8 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
db_engine = get_relational_engine()
# we might want to delete this
await_only(db_engine.create_database())
pass
def downgrade() -> None:
db_engine = get_relational_engine()
await_only(db_engine.delete_database())
pass

View file

@ -144,44 +144,58 @@ def _create_data_permission(conn, user_id, data_id, permission_name):
)
def _get_column(inspector, table, name, schema=None):
for col in inspector.get_columns(table, schema=schema):
if col["name"] == name:
return col
return None
def upgrade() -> None:
conn = op.get_bind()
insp = sa.inspect(conn)
# Recreate ACLs table with default permissions set to datasets instead of documents
op.drop_table("acls")
dataset_id_column = _get_column(insp, "acls", "dataset_id")
if not dataset_id_column:
# Recreate ACLs table with default permissions set to datasets instead of documents
op.drop_table("acls")
acls_table = op.create_table(
"acls",
sa.Column("id", UUID, primary_key=True, default=uuid4),
sa.Column(
"created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
),
sa.Column(
"updated_at", sa.DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)
),
sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")),
sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")),
sa.Column("dataset_id", UUID, sa.ForeignKey("datasets.id", ondelete="CASCADE")),
)
acls_table = op.create_table(
"acls",
sa.Column("id", UUID, primary_key=True, default=uuid4),
sa.Column(
"created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
onupdate=lambda: datetime.now(timezone.utc),
),
sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")),
sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")),
sa.Column("dataset_id", UUID, sa.ForeignKey("datasets.id", ondelete="CASCADE")),
)
# Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
# definition or load what is in the database
dataset_table = _define_dataset_table()
datasets = conn.execute(sa.select(dataset_table)).fetchall()
# Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
# definition or load what is in the database
dataset_table = _define_dataset_table()
datasets = conn.execute(sa.select(dataset_table)).fetchall()
if not datasets:
return
if not datasets:
return
acl_list = []
acl_list = []
for dataset in datasets:
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "read"))
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "write"))
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "share"))
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "delete"))
for dataset in datasets:
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "read"))
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "write"))
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "share"))
acl_list.append(
_create_dataset_permission(conn, dataset.owner_id, dataset.id, "delete")
)
if acl_list:
op.bulk_insert(acls_table, acl_list)
if acl_list:
op.bulk_insert(acls_table, acl_list)
def downgrade() -> None:

View file

@ -0,0 +1,137 @@
"""Multi Tenant Support
Revision ID: c946955da633
Revises: 211ab850ef3d
Create Date: 2025-11-04 18:11:09.325158
"""
from typing import Sequence, Union
from datetime import datetime, timezone
from uuid import uuid4
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "c946955da633"
down_revision: Union[str, None] = "211ab850ef3d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def _now():
return datetime.now(timezone.utc)
def _define_user_table() -> sa.Table:
table = sa.Table(
"users",
sa.MetaData(),
sa.Column(
"id",
sa.UUID,
sa.ForeignKey("principals.id", ondelete="CASCADE"),
primary_key=True,
nullable=False,
),
sa.Column("tenant_id", sa.UUID, sa.ForeignKey("tenants.id"), index=True, nullable=True),
)
return table
def _define_dataset_table() -> sa.Table:
# Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
# definition or load what is in the database
table = sa.Table(
"datasets",
sa.MetaData(),
sa.Column("id", sa.UUID, primary_key=True, default=uuid4),
sa.Column("name", sa.Text),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
onupdate=lambda: datetime.now(timezone.utc),
),
sa.Column("owner_id", sa.UUID(), sa.ForeignKey("principals.id"), index=True),
sa.Column("tenant_id", sa.UUID(), sa.ForeignKey("tenants.id"), index=True, nullable=True),
)
return table
def _get_column(inspector, table, name, schema=None):
for col in inspector.get_columns(table, schema=schema):
if col["name"] == name:
return col
return None
def upgrade() -> None:
conn = op.get_bind()
insp = sa.inspect(conn)
dataset = _define_dataset_table()
user = _define_user_table()
if "user_tenants" not in insp.get_table_names():
# Define table with all necessary columns including primary key
user_tenants = op.create_table(
"user_tenants",
sa.Column("user_id", sa.UUID, sa.ForeignKey("users.id"), primary_key=True),
sa.Column("tenant_id", sa.UUID, sa.ForeignKey("tenants.id"), primary_key=True),
sa.Column(
"created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
),
)
# Get all users with their tenant_id
user_data = conn.execute(
sa.select(user.c.id, user.c.tenant_id).where(user.c.tenant_id.isnot(None))
).fetchall()
# Insert into user_tenants table
if user_data:
op.bulk_insert(
user_tenants,
[
{"user_id": user_id, "tenant_id": tenant_id, "created_at": _now()}
for user_id, tenant_id in user_data
],
)
tenant_id_column = _get_column(insp, "datasets", "tenant_id")
if not tenant_id_column:
op.add_column("datasets", sa.Column("tenant_id", sa.UUID(), nullable=True))
# Build subquery, select users.tenant_id for each dataset.owner_id
tenant_id_from_dataset_owner = (
sa.select(user.c.tenant_id).where(user.c.id == dataset.c.owner_id).scalar_subquery()
)
if op.get_context().dialect.name == "sqlite":
# If column doesn't exist create new original_extension column and update from values of extension column
with op.batch_alter_table("datasets") as batch_op:
batch_op.execute(
dataset.update().values(
tenant_id=tenant_id_from_dataset_owner,
)
)
else:
conn = op.get_bind()
conn.execute(dataset.update().values(tenant_id=tenant_id_from_dataset_owner))
op.create_index(op.f("ix_datasets_tenant_id"), "datasets", ["tenant_id"])
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("user_tenants")
op.drop_index(op.f("ix_datasets_tenant_id"), table_name="datasets")
op.drop_column("datasets", "tenant_id")
# ### end Alembic commands ###

View file

@ -194,7 +194,6 @@ async def cognify(
Prerequisites:
- **LLM_API_KEY**: Must be configured (required for entity extraction and graph generation)
- **Data Added**: Must have data previously added via `cognee.add()`
- **Vector Database**: Must be accessible for embeddings storage
- **Graph Database**: Must be accessible for relationship storage
@ -1096,6 +1095,10 @@ async def main():
# Skip migrations when in API mode (the API server handles its own database)
if not args.no_migration and not args.api_url:
from cognee.modules.engine.operations.setup import setup
await setup()
# Run Alembic migrations from the main cognee directory where alembic.ini is located
logger.info("Running database migrations...")
migration_result = subprocess.run(

View file

@ -19,6 +19,7 @@ from .api.v1.add import add
from .api.v1.delete import delete
from .api.v1.cognify import cognify
from .modules.memify import memify
from .modules.run_custom_pipeline import run_custom_pipeline
from .api.v1.update import update
from .api.v1.config.config import config
from .api.v1.datasets.datasets import datasets

View file

@ -23,6 +23,7 @@ from cognee.api.v1.settings.routers import get_settings_router
from cognee.api.v1.datasets.routers import get_datasets_router
from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router
from cognee.api.v1.search.routers import get_search_router
from cognee.api.v1.ontologies.routers.get_ontology_router import get_ontology_router
from cognee.api.v1.memify.routers import get_memify_router
from cognee.api.v1.add.routers import get_add_router
from cognee.api.v1.delete.routers import get_delete_router
@ -39,6 +40,8 @@ from cognee.api.v1.users.routers import (
)
from cognee.modules.users.methods.get_authenticated_user import REQUIRE_AUTHENTICATION
# Ensure application logging is configured for container stdout/stderr
setup_logging()
logger = get_logger()
if os.getenv("ENV", "prod") == "prod":
@ -74,6 +77,9 @@ async def lifespan(app: FastAPI):
await get_default_user()
# Emit a clear startup message for docker logs
logger.info("Backend server has started")
yield
@ -258,6 +264,8 @@ app.include_router(
app.include_router(get_datasets_router(), prefix="/api/v1/datasets", tags=["datasets"])
app.include_router(get_ontology_router(), prefix="/api/v1/ontologies", tags=["ontologies"])
app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["settings"])
app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"])

View file

@ -82,7 +82,9 @@ def get_add_router() -> APIRouter:
datasetName,
user=user,
dataset_id=datasetId,
node_set=node_set if node_set else None,
node_set=node_set
if node_set != [""]
else None, # Transform default node_set endpoint value to None
)
if isinstance(add_run, PipelineRunErrored):

View file

@ -41,6 +41,9 @@ class CognifyPayloadDTO(InDTO):
custom_prompt: Optional[str] = Field(
default="", description="Custom prompt for entity extraction and graph generation"
)
ontology_key: Optional[List[str]] = Field(
default=None, description="Reference to one or more previously uploaded ontologies"
)
def get_cognify_router() -> APIRouter:
@ -68,6 +71,7 @@ def get_cognify_router() -> APIRouter:
- **dataset_ids** (Optional[List[UUID]]): List of existing dataset UUIDs to process. UUIDs allow processing of datasets not owned by the user (if permitted).
- **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking).
- **custom_prompt** (Optional[str]): Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts for knowledge graph extraction.
- **ontology_key** (Optional[List[str]]): Reference to one or more previously uploaded ontology files to use for knowledge graph construction.
## Response
- **Blocking execution**: Complete pipeline run information with entity counts, processing duration, and success/failure status
@ -82,7 +86,8 @@ def get_cognify_router() -> APIRouter:
{
"datasets": ["research_papers", "documentation"],
"run_in_background": false,
"custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections."
"custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections.",
"ontology_key": ["medical_ontology_v1"]
}
```
@ -108,13 +113,35 @@ def get_cognify_router() -> APIRouter:
)
from cognee.api.v1.cognify import cognify as cognee_cognify
from cognee.api.v1.ontologies.ontologies import OntologyService
try:
datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets
config_to_use = None
if payload.ontology_key:
ontology_service = OntologyService()
ontology_contents = ontology_service.get_ontology_contents(
payload.ontology_key, user
)
from cognee.modules.ontology.ontology_config import Config
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import (
RDFLibOntologyResolver,
)
from io import StringIO
ontology_streams = [StringIO(content) for content in ontology_contents]
config_to_use: Config = {
"ontology_config": {
"ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_streams)
}
}
cognify_run = await cognee_cognify(
datasets,
user,
config=config_to_use,
run_in_background=payload.run_in_background,
custom_prompt=payload.custom_prompt,
)

View file

@ -0,0 +1,4 @@
from .ontologies import OntologyService
from .routers.get_ontology_router import get_ontology_router
__all__ = ["OntologyService", "get_ontology_router"]

View file

@ -0,0 +1,183 @@
import os
import json
import tempfile
from pathlib import Path
from datetime import datetime, timezone
from typing import Optional, List
from dataclasses import dataclass
@dataclass
class OntologyMetadata:
ontology_key: str
filename: str
size_bytes: int
uploaded_at: str
description: Optional[str] = None
class OntologyService:
def __init__(self):
pass
@property
def base_dir(self) -> Path:
return Path(tempfile.gettempdir()) / "ontologies"
def _get_user_dir(self, user_id: str) -> Path:
user_dir = self.base_dir / str(user_id)
user_dir.mkdir(parents=True, exist_ok=True)
return user_dir
def _get_metadata_path(self, user_dir: Path) -> Path:
return user_dir / "metadata.json"
def _load_metadata(self, user_dir: Path) -> dict:
metadata_path = self._get_metadata_path(user_dir)
if metadata_path.exists():
with open(metadata_path, "r") as f:
return json.load(f)
return {}
def _save_metadata(self, user_dir: Path, metadata: dict):
metadata_path = self._get_metadata_path(user_dir)
with open(metadata_path, "w") as f:
json.dump(metadata, f, indent=2)
async def upload_ontology(
self, ontology_key: str, file, user, description: Optional[str] = None
) -> OntologyMetadata:
if not file.filename.lower().endswith(".owl"):
raise ValueError("File must be in .owl format")
user_dir = self._get_user_dir(str(user.id))
metadata = self._load_metadata(user_dir)
if ontology_key in metadata:
raise ValueError(f"Ontology key '{ontology_key}' already exists")
content = await file.read()
if len(content) > 10 * 1024 * 1024:
raise ValueError("File size exceeds 10MB limit")
file_path = user_dir / f"{ontology_key}.owl"
with open(file_path, "wb") as f:
f.write(content)
ontology_metadata = {
"filename": file.filename,
"size_bytes": len(content),
"uploaded_at": datetime.now(timezone.utc).isoformat(),
"description": description,
}
metadata[ontology_key] = ontology_metadata
self._save_metadata(user_dir, metadata)
return OntologyMetadata(
ontology_key=ontology_key,
filename=file.filename,
size_bytes=len(content),
uploaded_at=ontology_metadata["uploaded_at"],
description=description,
)
async def upload_ontologies(
self, ontology_key: List[str], files: List, user, descriptions: Optional[List[str]] = None
) -> List[OntologyMetadata]:
"""
Upload ontology files with their respective keys.
Args:
ontology_key: List of unique keys for each ontology
files: List of UploadFile objects (same length as keys)
user: Authenticated user
descriptions: Optional list of descriptions for each file
Returns:
List of OntologyMetadata objects for uploaded files
Raises:
ValueError: If keys duplicate, file format invalid, or array lengths don't match
"""
if len(ontology_key) != len(files):
raise ValueError("Number of keys must match number of files")
if len(set(ontology_key)) != len(ontology_key):
raise ValueError("Duplicate ontology keys not allowed")
if descriptions and len(descriptions) != len(files):
raise ValueError("Number of descriptions must match number of files")
results = []
user_dir = self._get_user_dir(str(user.id))
metadata = self._load_metadata(user_dir)
for i, (key, file) in enumerate(zip(ontology_key, files)):
if key in metadata:
raise ValueError(f"Ontology key '{key}' already exists")
if not file.filename.lower().endswith(".owl"):
raise ValueError(f"File '{file.filename}' must be in .owl format")
content = await file.read()
if len(content) > 10 * 1024 * 1024:
raise ValueError(f"File '{file.filename}' exceeds 10MB limit")
file_path = user_dir / f"{key}.owl"
with open(file_path, "wb") as f:
f.write(content)
ontology_metadata = {
"filename": file.filename,
"size_bytes": len(content),
"uploaded_at": datetime.now(timezone.utc).isoformat(),
"description": descriptions[i] if descriptions else None,
}
metadata[key] = ontology_metadata
results.append(
OntologyMetadata(
ontology_key=key,
filename=file.filename,
size_bytes=len(content),
uploaded_at=ontology_metadata["uploaded_at"],
description=descriptions[i] if descriptions else None,
)
)
self._save_metadata(user_dir, metadata)
return results
def get_ontology_contents(self, ontology_key: List[str], user) -> List[str]:
"""
Retrieve ontology content for one or more keys.
Args:
ontology_key: List of ontology keys to retrieve (can contain single item)
user: Authenticated user
Returns:
List of ontology content strings
Raises:
ValueError: If any ontology key not found
"""
user_dir = self._get_user_dir(str(user.id))
metadata = self._load_metadata(user_dir)
contents = []
for key in ontology_key:
if key not in metadata:
raise ValueError(f"Ontology key '{key}' not found")
file_path = user_dir / f"{key}.owl"
if not file_path.exists():
raise ValueError(f"Ontology file for key '{key}' not found")
with open(file_path, "r", encoding="utf-8") as f:
contents.append(f.read())
return contents
def list_ontologies(self, user) -> dict:
user_dir = self._get_user_dir(str(user.id))
return self._load_metadata(user_dir)

View file

@ -0,0 +1,107 @@
from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException
from fastapi.responses import JSONResponse
from typing import Optional, List
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
from ..ontologies import OntologyService
def get_ontology_router() -> APIRouter:
router = APIRouter()
ontology_service = OntologyService()
@router.post("", response_model=dict)
async def upload_ontology(
ontology_key: str = Form(...),
ontology_file: List[UploadFile] = File(...),
descriptions: Optional[str] = Form(None),
user: User = Depends(get_authenticated_user),
):
"""
Upload ontology files with their respective keys for later use in cognify operations.
Supports both single and multiple file uploads:
- Single file: ontology_key=["key"], ontology_file=[file]
- Multiple files: ontology_key=["key1", "key2"], ontology_file=[file1, file2]
## Request Parameters
- **ontology_key** (str): JSON array string of user-defined identifiers for the ontologies
- **ontology_file** (List[UploadFile]): OWL format ontology files
- **descriptions** (Optional[str]): JSON array string of optional descriptions
## Response
Returns metadata about uploaded ontologies including keys, filenames, sizes, and upload timestamps.
## Error Codes
- **400 Bad Request**: Invalid file format, duplicate keys, array length mismatches, file size exceeded
- **500 Internal Server Error**: File system or processing errors
"""
send_telemetry(
"Ontology Upload API Endpoint Invoked",
user.id,
additional_properties={
"endpoint": "POST /api/v1/ontologies",
"cognee_version": cognee_version,
},
)
try:
import json
ontology_keys = json.loads(ontology_key)
description_list = json.loads(descriptions) if descriptions else None
if not isinstance(ontology_keys, list):
raise ValueError("ontology_key must be a JSON array")
results = await ontology_service.upload_ontologies(
ontology_keys, ontology_file, user, description_list
)
return {
"uploaded_ontologies": [
{
"ontology_key": result.ontology_key,
"filename": result.filename,
"size_bytes": result.size_bytes,
"uploaded_at": result.uploaded_at,
"description": result.description,
}
for result in results
]
}
except (json.JSONDecodeError, ValueError) as e:
return JSONResponse(status_code=400, content={"error": str(e)})
except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)})
@router.get("", response_model=dict)
async def list_ontologies(user: User = Depends(get_authenticated_user)):
"""
List all uploaded ontologies for the authenticated user.
## Response
Returns a dictionary mapping ontology keys to their metadata including filename, size, and upload timestamp.
## Error Codes
- **500 Internal Server Error**: File system or processing errors
"""
send_telemetry(
"Ontology List API Endpoint Invoked",
user.id,
additional_properties={
"endpoint": "GET /api/v1/ontologies",
"cognee_version": cognee_version,
},
)
try:
metadata = ontology_service.list_ontologies(user)
return metadata
except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)})
return router

View file

@ -1,15 +1,20 @@
from uuid import UUID
from typing import List
from typing import List, Union
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from cognee.modules.users.models import User
from cognee.api.DTO import InDTO
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
class SelectTenantDTO(InDTO):
tenant_id: UUID | None = None
def get_permissions_router() -> APIRouter:
permissions_router = APIRouter()
@ -226,4 +231,39 @@ def get_permissions_router() -> APIRouter:
status_code=200, content={"message": "Tenant created.", "tenant_id": str(tenant_id)}
)
@permissions_router.post("/tenants/select")
async def select_tenant(payload: SelectTenantDTO, user: User = Depends(get_authenticated_user)):
"""
Select current tenant.
This endpoint selects a tenant with the specified UUID. Tenants are used
to organize users and resources in multi-tenant environments, providing
isolation and access control between different groups or organizations.
Sending a null/None value as tenant_id selects his default single user tenant
## Request Parameters
- **tenant_id** (Union[UUID, None]): UUID of the tenant to select, If null/None is provided use the default single user tenant
## Response
Returns a success message along with selected tenant id.
"""
send_telemetry(
"Permissions API Endpoint Invoked",
user.id,
additional_properties={
"endpoint": f"POST /v1/permissions/tenants/{str(payload.tenant_id)}",
"tenant_id": str(payload.tenant_id),
},
)
from cognee.modules.users.tenants.methods import select_tenant as select_tenant_method
await select_tenant_method(user_id=user.id, tenant_id=payload.tenant_id)
return JSONResponse(
status_code=200,
content={"message": "Tenant selected.", "tenant_id": str(payload.tenant_id)},
)
return permissions_router

View file

@ -31,6 +31,8 @@ async def search(
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[List[SearchResult], CombinedSearchResult]:
"""
Search and query the knowledge graph for insights, information, and connections.
@ -200,6 +202,8 @@ async def search(
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
return filtered_search_results

View file

@ -22,7 +22,7 @@ relationships, and creates semantic connections for enhanced search and reasonin
Processing Pipeline:
1. **Document Classification**: Identifies document types and structures
2. **Permission Validation**: Ensures user has processing rights
2. **Permission Validation**: Ensures user has processing rights
3. **Text Chunking**: Breaks content into semantically meaningful segments
4. **Entity Extraction**: Identifies key concepts, people, places, organizations
5. **Relationship Detection**: Discovers connections between entities
@ -97,6 +97,13 @@ After successful cognify processing, use `cognee search` to query the knowledge
chunker_class = LangchainChunker
except ImportError:
fmt.warning("LangchainChunker not available, using TextChunker")
elif args.chunker == "CsvChunker":
try:
from cognee.modules.chunking.CsvChunker import CsvChunker
chunker_class = CsvChunker
except ImportError:
fmt.warning("CsvChunker not available, using TextChunker")
result = await cognee.cognify(
datasets=datasets,

View file

@ -26,7 +26,7 @@ SEARCH_TYPE_CHOICES = [
]
# Chunker choices
CHUNKER_CHOICES = ["TextChunker", "LangchainChunker"]
CHUNKER_CHOICES = ["TextChunker", "LangchainChunker", "CsvChunker"]
# Output format choices
OUTPUT_FORMAT_CHOICES = ["json", "pretty", "simple"]

View file

@ -4,6 +4,8 @@ from typing import Union
from uuid import UUID
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
from cognee.infrastructure.databases.graph.config import get_graph_context_config
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
from cognee.infrastructure.files.storage.config import file_storage_config
from cognee.modules.users.methods import get_user
@ -14,11 +16,40 @@ vector_db_config = ContextVar("vector_db_config", default=None)
graph_db_config = ContextVar("graph_db_config", default=None)
session_user = ContextVar("session_user", default=None)
VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
async def set_session_user_context_variable(user):
session_user.set(user)
def multi_user_support_possible():
graph_db_config = get_graph_context_config()
vector_db_config = get_vectordb_context_config()
return (
graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT
and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT
)
def backend_access_control_enabled():
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
if backend_access_control is None:
# If backend access control is not defined in environment variables,
# enable it by default if graph and vector DBs can support it, otherwise disable it
return multi_user_support_possible()
elif backend_access_control.lower() == "true":
# If enabled, ensure that the current graph and vector DBs can support it
multi_user_support = multi_user_support_possible()
if not multi_user_support:
raise EnvironmentError(
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
)
return True
return False
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
"""
If backend access control is enabled this function will ensure all datasets have their own databases,
@ -38,9 +69,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
"""
base_config = get_base_config()
if not os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
if not backend_access_control_enabled():
return
user = await get_user(user_id)
@ -48,6 +77,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
# To ensure permissions are enforced properly all datasets will have their own databases
dataset_database = await get_or_create_dataset_database(dataset, user)
base_config = get_base_config()
data_root_directory = os.path.join(
base_config.data_root_directory, str(user.tenant_id or user.id)
)
@ -57,15 +87,17 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
# Set vector and graph database configuration based on dataset database information
vector_config = {
"vector_db_url": os.path.join(
databases_directory_path, dataset_database.vector_database_name
),
"vector_db_key": "",
"vector_db_provider": "lancedb",
"vector_db_provider": dataset_database.vector_database_provider,
"vector_db_url": dataset_database.vector_database_url,
"vector_db_key": dataset_database.vector_database_key,
"vector_db_name": dataset_database.vector_database_name,
}
graph_config = {
"graph_database_provider": "kuzu",
"graph_database_provider": dataset_database.graph_database_provider,
"graph_database_url": dataset_database.graph_database_url,
"graph_database_name": dataset_database.graph_database_name,
"graph_database_key": dataset_database.graph_database_key,
"graph_file_path": os.path.join(
databases_directory_path, dataset_database.graph_database_name
),

View file

@ -0,0 +1,29 @@
FROM python:3.11-slim
# Set environment variables
ENV PIP_NO_CACHE_DIR=true
ENV PATH="${PATH}:/root/.poetry/bin"
ENV PYTHONPATH=/app
ENV SKIP_MIGRATIONS=true
# System dependencies
RUN apt-get update && apt-get install -y \
gcc \
libpq-dev \
git \
curl \
build-essential \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY pyproject.toml poetry.lock README.md /app/
RUN pip install poetry
RUN poetry config virtualenvs.create false
RUN poetry install --extras distributed --extras evals --extras deepeval --no-root
COPY cognee/ /app/cognee
COPY distributed/ /app/distributed

View file

@ -35,6 +35,16 @@ class AnswerGeneratorExecutor:
retrieval_context = await retriever.get_context(query_text)
search_results = await retriever.get_completion(query_text, retrieval_context)
############
#:TODO This is a quick fix until we don't structure retriever results properly but lets not leave it like this...this is needed now due to the changed combined retriever structure..
if isinstance(retrieval_context, list):
retrieval_context = await retriever.convert_retrieved_objects_to_context(
triplets=retrieval_context
)
if isinstance(search_results, str):
search_results = [search_results]
#############
answer = {
"question": query_text,
"answer": search_results[0],

View file

@ -35,7 +35,7 @@ async def create_and_insert_answers_table(questions_payload):
async def run_question_answering(
params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None
params: dict, system_prompt="answer_simple_question_benchmark.txt", top_k: Optional[int] = None
) -> List[dict]:
if params.get("answering_questions"):
logger.info("Question answering started...")

View file

@ -14,7 +14,7 @@ class EvalConfig(BaseSettings):
# Question answering params
answering_questions: bool = True
qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
qa_engine: str = "cognee_graph_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
# Evaluation params
evaluating_answers: bool = True
@ -25,7 +25,7 @@ class EvalConfig(BaseSettings):
"EM",
"f1",
] # Use only 'correctness' for DirectLLM
deepeval_model: str = "gpt-5-mini"
deepeval_model: str = "gpt-4o-mini"
# Metrics params
calculate_metrics: bool = True

View file

@ -2,7 +2,6 @@ import modal
import os
import asyncio
import datetime
import hashlib
import json
from cognee.shared.logging_utils import get_logger
from cognee.eval_framework.eval_config import EvalConfig
@ -10,6 +9,9 @@ from cognee.eval_framework.corpus_builder.run_corpus_builder import run_corpus_b
from cognee.eval_framework.answer_generation.run_question_answering_module import (
run_question_answering,
)
import pathlib
from os import path
from modal import Image
from cognee.eval_framework.evaluation.run_evaluation_module import run_evaluation
from cognee.eval_framework.metrics_dashboard import create_dashboard
@ -38,22 +40,19 @@ def read_and_combine_metrics(eval_params: dict) -> dict:
app = modal.App("modal-run-eval")
image = (
modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False)
.copy_local_file("pyproject.toml", "pyproject.toml")
.copy_local_file("poetry.lock", "poetry.lock")
.env(
{
"ENV": os.getenv("ENV"),
"LLM_API_KEY": os.getenv("LLM_API_KEY"),
"OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
}
)
.pip_install("protobuf", "h2", "deepeval", "gdown", "plotly")
image = Image.from_dockerfile(
path=pathlib.Path(path.join(path.dirname(__file__), "Dockerfile")).resolve(),
force_build=False,
).add_local_python_source("cognee")
@app.function(
image=image,
max_containers=10,
timeout=86400,
volumes={"/data": vol},
secrets=[modal.Secret.from_name("eval_secrets")],
)
@app.function(image=image, concurrency_limit=10, timeout=86400, volumes={"/data": vol})
async def modal_run_eval(eval_params=None):
"""Runs evaluation pipeline and returns combined metrics results."""
if eval_params is None:
@ -105,18 +104,7 @@ async def main():
configs = [
EvalConfig(
task_getter_type="Default",
number_of_samples_in_corpus=10,
benchmark="HotPotQA",
qa_engine="cognee_graph_completion",
building_corpus_from_scratch=True,
answering_questions=True,
evaluating_answers=True,
calculate_metrics=True,
dashboard=True,
),
EvalConfig(
task_getter_type="Default",
number_of_samples_in_corpus=10,
number_of_samples_in_corpus=25,
benchmark="TwoWikiMultiHop",
qa_engine="cognee_graph_completion",
building_corpus_from_scratch=True,
@ -127,7 +115,7 @@ async def main():
),
EvalConfig(
task_getter_type="Default",
number_of_samples_in_corpus=10,
number_of_samples_in_corpus=25,
benchmark="Musique",
qa_engine="cognee_graph_completion",
building_corpus_from_scratch=True,

View file

@ -1,6 +1,6 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from functools import lru_cache
from typing import Optional
from typing import Optional, Literal
class CacheConfig(BaseSettings):
@ -15,6 +15,7 @@ class CacheConfig(BaseSettings):
- agentic_lock_timeout: Maximum time (in seconds) to wait for the lock release.
"""
cache_backend: Literal["redis", "fs"] = "fs"
caching: bool = False
shared_kuzu_lock: bool = False
cache_host: str = "localhost"
@ -28,6 +29,7 @@ class CacheConfig(BaseSettings):
def to_dict(self) -> dict:
return {
"cache_backend": self.cache_backend,
"caching": self.caching,
"shared_kuzu_lock": self.shared_kuzu_lock,
"cache_host": self.cache_host,

View file

@ -0,0 +1,151 @@
import asyncio
import json
import os
from datetime import datetime
import time
import threading
import diskcache as dc
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
from cognee.infrastructure.databases.exceptions.exceptions import (
CacheConnectionError,
SharedKuzuLockRequiresRedisError,
)
from cognee.infrastructure.files.storage.get_storage_config import get_storage_config
from cognee.shared.logging_utils import get_logger
logger = get_logger("FSCacheAdapter")
class FSCacheAdapter(CacheDBInterface):
def __init__(self):
default_key = "sessions_db"
storage_config = get_storage_config()
data_root_directory = storage_config["data_root_directory"]
cache_directory = os.path.join(data_root_directory, ".cognee_fs_cache", default_key)
os.makedirs(cache_directory, exist_ok=True)
self.cache = dc.Cache(directory=cache_directory)
self.cache.expire()
logger.debug(f"FSCacheAdapter initialized with cache directory: {cache_directory}")
def acquire_lock(self):
"""Lock acquisition is not available for filesystem cache backend."""
message = "Shared Kuzu lock requires Redis cache backend."
logger.error(message)
raise SharedKuzuLockRequiresRedisError()
def release_lock(self):
"""Lock release is not available for filesystem cache backend."""
message = "Shared Kuzu lock requires Redis cache backend."
logger.error(message)
raise SharedKuzuLockRequiresRedisError()
async def add_qa(
self,
user_id: str,
session_id: str,
question: str,
context: str,
answer: str,
ttl: int | None = 86400,
):
try:
session_key = f"agent_sessions:{user_id}:{session_id}"
qa_entry = {
"time": datetime.utcnow().isoformat(),
"question": question,
"context": context,
"answer": answer,
}
existing_value = self.cache.get(session_key)
if existing_value is not None:
value: list = json.loads(existing_value)
value.append(qa_entry)
else:
value = [qa_entry]
self.cache.set(session_key, json.dumps(value), expire=ttl)
except Exception as e:
error_msg = f"Unexpected error while adding Q&A to diskcache: {str(e)}"
logger.error(error_msg)
raise CacheConnectionError(error_msg) from e
async def get_latest_qa(self, user_id: str, session_id: str, last_n: int = 5):
session_key = f"agent_sessions:{user_id}:{session_id}"
value = self.cache.get(session_key)
if value is None:
return None
entries = json.loads(value)
return entries[-last_n:] if len(entries) > last_n else entries
async def get_all_qas(self, user_id: str, session_id: str):
session_key = f"agent_sessions:{user_id}:{session_id}"
value = self.cache.get(session_key)
if value is None:
return None
return json.loads(value)
async def close(self):
if self.cache is not None:
self.cache.expire()
self.cache.close()
async def main():
adapter = FSCacheAdapter()
session_id = "demo_session"
user_id = "demo_user_id"
print("\nAdding sample Q/A pairs...")
await adapter.add_qa(
user_id,
session_id,
"What is Redis?",
"Basic DB context",
"Redis is an in-memory data store.",
)
await adapter.add_qa(
user_id,
session_id,
"Who created Redis?",
"Historical context",
"Salvatore Sanfilippo (antirez).",
)
print("\nLatest QA:")
latest = await adapter.get_latest_qa(user_id, session_id)
print(json.dumps(latest, indent=2))
print("\nLast 2 QAs:")
last_two = await adapter.get_latest_qa(user_id, session_id, last_n=2)
print(json.dumps(last_two, indent=2))
session_id = "session_expire_demo"
await adapter.add_qa(
user_id,
session_id,
"What is Redis?",
"Database context",
"Redis is an in-memory data store.",
)
await adapter.add_qa(
user_id,
session_id,
"Who created Redis?",
"History context",
"Salvatore Sanfilippo (antirez).",
)
print(await adapter.get_all_qas(user_id, session_id))
await adapter.close()
if __name__ == "__main__":
asyncio.run(main())

View file

@ -1,9 +1,11 @@
"""Factory to get the appropriate cache coordination engine (e.g., Redis)."""
from functools import lru_cache
import os
from typing import Optional
from cognee.infrastructure.databases.cache.config import get_cache_config
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
from cognee.infrastructure.databases.cache.fscache.FsCacheAdapter import FSCacheAdapter
config = get_cache_config()
@ -33,20 +35,28 @@ def create_cache_engine(
Returns:
--------
- CacheDBInterface: An instance of the appropriate cache adapter. :TODO: Now we support only Redis. later if we add more here we can split the logic
- CacheDBInterface: An instance of the appropriate cache adapter.
"""
if config.caching:
from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter
return RedisAdapter(
host=cache_host,
port=cache_port,
username=cache_username,
password=cache_password,
lock_name=lock_key,
timeout=agentic_lock_expire,
blocking_timeout=agentic_lock_timeout,
)
if config.cache_backend == "redis":
return RedisAdapter(
host=cache_host,
port=cache_port,
username=cache_username,
password=cache_password,
lock_name=lock_key,
timeout=agentic_lock_expire,
blocking_timeout=agentic_lock_timeout,
)
elif config.cache_backend == "fs":
return FSCacheAdapter()
else:
raise ValueError(
f"Unsupported cache backend: '{config.cache_backend}'. "
f"Supported backends are: 'redis', 'fs'"
)
else:
return None

View file

@ -148,3 +148,19 @@ class CacheConnectionError(CogneeConfigurationError):
status_code: int = status.HTTP_503_SERVICE_UNAVAILABLE,
):
super().__init__(message, name, status_code)
class SharedKuzuLockRequiresRedisError(CogneeConfigurationError):
"""
Raised when shared Kuzu locking is requested without configuring the Redis backend.
"""
def __init__(
self,
message: str = (
"Shared Kuzu lock requires Redis cache backend. Configure Redis to enable shared Kuzu locking."
),
name: str = "SharedKuzuLockRequiresRedisError",
status_code: int = status.HTTP_400_BAD_REQUEST,
):
super().__init__(message, name, status_code)

View file

@ -26,6 +26,7 @@ class GraphConfig(BaseSettings):
- graph_database_username
- graph_database_password
- graph_database_port
- graph_database_key
- graph_file_path
- graph_model
- graph_topology
@ -41,6 +42,7 @@ class GraphConfig(BaseSettings):
graph_database_username: str = ""
graph_database_password: str = ""
graph_database_port: int = 123
graph_database_key: str = ""
graph_file_path: str = ""
graph_filename: str = ""
graph_model: object = KnowledgeGraph
@ -90,6 +92,7 @@ class GraphConfig(BaseSettings):
"graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port,
"graph_database_key": self.graph_database_key,
"graph_file_path": self.graph_file_path,
"graph_model": self.graph_model,
"graph_topology": self.graph_topology,
@ -116,6 +119,7 @@ class GraphConfig(BaseSettings):
"graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port,
"graph_database_key": self.graph_database_key,
"graph_file_path": self.graph_file_path,
}

View file

@ -33,6 +33,7 @@ def create_graph_engine(
graph_database_username="",
graph_database_password="",
graph_database_port="",
graph_database_key="",
):
"""
Create a graph engine based on the specified provider type.
@ -69,6 +70,7 @@ def create_graph_engine(
graph_database_url=graph_database_url,
graph_database_username=graph_database_username,
graph_database_password=graph_database_password,
database_name=graph_database_name,
)
if graph_database_provider == "neo4j":

View file

@ -398,3 +398,18 @@ class GraphDBInterface(ABC):
- node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections.
"""
raise NotImplementedError
@abstractmethod
async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
) -> Tuple[List[Node], List[EdgeData]]:
"""
Retrieve nodes and edges filtered by the provided attribute criteria.
Parameters:
-----------
- attribute_filters: A list of dictionaries where keys are attribute names and values
are lists of attribute values to filter by.
"""
raise NotImplementedError

View file

@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Union, Optional, Tuple, Type
from cognee.exceptions import CogneeValidationError
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.infrastructure.files.storage import get_file_storage
@ -1186,6 +1187,11 @@ class KuzuAdapter(GraphDBInterface):
A tuple with two elements: a list of tuples of (node_id, properties) and a list of
tuples of (source_id, target_id, relationship_name, properties).
"""
import time
start_time = time.time()
try:
nodes_query = """
MATCH (n:Node)
@ -1249,6 +1255,11 @@ class KuzuAdapter(GraphDBInterface):
},
)
)
retrieval_time = time.time() - start_time
logger.info(
f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds"
)
return formatted_nodes, formatted_edges
except Exception as e:
logger.error(f"Failed to get graph data: {e}")
@ -1417,6 +1428,92 @@ class KuzuAdapter(GraphDBInterface):
formatted_edges.append((source_id, target_id, rel_type, props))
return formatted_nodes, formatted_edges
async def get_id_filtered_graph_data(self, target_ids: list[str]):
"""
Retrieve graph data filtered by specific node IDs, including their direct neighbors
and only edges where one endpoint matches those IDs.
Returns:
nodes: List[dict] -> Each dict includes "id" and all node properties
edges: List[dict] -> Each dict includes "source", "target", "type", "properties"
"""
import time
start_time = time.time()
try:
if not target_ids:
logger.warning("No target IDs provided for ID-filtered graph retrieval.")
return [], []
if not all(isinstance(x, str) for x in target_ids):
raise CogneeValidationError("target_ids must be a list of strings")
query = """
MATCH (n:Node)-[r]->(m:Node)
WHERE n.id IN $target_ids OR m.id IN $target_ids
RETURN n.id, {
name: n.name,
type: n.type,
properties: n.properties
}, m.id, {
name: m.name,
type: m.type,
properties: m.properties
}, r.relationship_name, r.properties
"""
result = await self.query(query, {"target_ids": target_ids})
if not result:
logger.info("No data returned for the supplied IDs")
return [], []
nodes_dict = {}
edges = []
for n_id, n_props, m_id, m_props, r_type, r_props_raw in result:
if n_props.get("properties"):
try:
additional_props = json.loads(n_props["properties"])
n_props.update(additional_props)
del n_props["properties"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse properties JSON for node {n_id}")
if m_props.get("properties"):
try:
additional_props = json.loads(m_props["properties"])
m_props.update(additional_props)
del m_props["properties"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse properties JSON for node {m_id}")
nodes_dict[n_id] = (n_id, n_props)
nodes_dict[m_id] = (m_id, m_props)
edge_props = {}
if r_props_raw:
try:
edge_props = json.loads(r_props_raw)
except (json.JSONDecodeError, TypeError):
logger.warning(f"Failed to parse edge properties for {n_id}->{m_id}")
source_id = edge_props.get("source_node_id", n_id)
target_id = edge_props.get("target_node_id", m_id)
edges.append((source_id, target_id, r_type, edge_props))
retrieval_time = time.time() - start_time
logger.info(
f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
)
return list(nodes_dict.values()), edges
except Exception as e:
logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
raise
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
"""
Get metrics on graph structure and connectivity.

View file

@ -964,6 +964,63 @@ class Neo4jAdapter(GraphDBInterface):
logger.error(f"Error during graph data retrieval: {str(e)}")
raise
async def get_id_filtered_graph_data(self, target_ids: list[str]):
"""
Retrieve graph data filtered by specific node IDs, including their direct neighbors
and only edges where one endpoint matches those IDs.
This version uses a single Cypher query for efficiency.
"""
import time
start_time = time.time()
try:
if not target_ids:
logger.warning("No target IDs provided for ID-filtered graph retrieval.")
return [], []
query = """
MATCH ()-[r]-()
WHERE startNode(r).id IN $target_ids
OR endNode(r).id IN $target_ids
WITH DISTINCT r, startNode(r) AS a, endNode(r) AS b
RETURN
properties(a) AS n_properties,
properties(b) AS m_properties,
type(r) AS type,
properties(r) AS properties
"""
result = await self.query(query, {"target_ids": target_ids})
nodes_dict = {}
edges = []
for record in result:
n_props = record["n_properties"]
m_props = record["m_properties"]
r_props = record["properties"]
r_type = record["type"]
nodes_dict[n_props["id"]] = (n_props["id"], n_props)
nodes_dict[m_props["id"]] = (m_props["id"], m_props)
source_id = r_props.get("source_node_id", n_props["id"])
target_id = r_props.get("target_node_id", m_props["id"])
edges.append((source_id, target_id, r_type, r_props))
retrieval_time = time.time() - start_time
logger.info(
f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
)
return list(nodes_dict.values()), edges
except Exception as e:
logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
raise
async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:

View file

@ -416,6 +416,15 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
self._client.query(f"MATCH (n :{self._VECTOR_NODE_LABEL}) DETACH DELETE n")
pass
async def is_empty(self) -> bool:
query = """
MATCH (n)
RETURN true
LIMIT 1;
"""
query_result = await self._client.query(query)
return len(query_result) == 0
@staticmethod
def _get_scored_result(
item: dict, with_vector: bool = False, with_score: bool = False

View file

@ -1,11 +1,15 @@
import os
from uuid import UUID
from typing import Union
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from cognee.modules.data.methods import create_dataset
from cognee.base_config import get_base_config
from cognee.modules.data.methods import create_dataset
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.modules.data.methods import get_unique_dataset_id
from cognee.modules.users.models import DatasetDatabase
from cognee.modules.users.models import User
@ -32,8 +36,32 @@ async def get_or_create_dataset_database(
dataset_id = await get_unique_dataset_id(dataset, user)
vector_db_name = f"{dataset_id}.lance.db"
graph_db_name = f"{dataset_id}.pkl"
vector_config = get_vectordb_config()
graph_config = get_graph_config()
# Note: for hybrid databases both graph and vector DB name have to be the same
if graph_config.graph_database_provider == "kuzu":
graph_db_name = f"{dataset_id}.pkl"
else:
graph_db_name = f"{dataset_id}"
if vector_config.vector_db_provider == "lancedb":
vector_db_name = f"{dataset_id}.lance.db"
else:
vector_db_name = f"{dataset_id}"
base_config = get_base_config()
databases_directory_path = os.path.join(
base_config.system_root_directory, "databases", str(user.id)
)
# Determine vector database URL
if vector_config.vector_db_provider == "lancedb":
vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name)
else:
vector_db_url = vector_config.vector_database_url
# Determine graph database URL
async with db_engine.get_async_session() as session:
# Create dataset if it doesn't exist
@ -55,6 +83,12 @@ async def get_or_create_dataset_database(
dataset_id=dataset_id,
vector_database_name=vector_db_name,
graph_database_name=graph_db_name,
vector_database_provider=vector_config.vector_db_provider,
graph_database_provider=graph_config.graph_database_provider,
vector_database_url=vector_db_url,
graph_database_url=graph_config.graph_database_url,
vector_database_key=vector_config.vector_db_key,
graph_database_key=graph_config.graph_database_key,
)
try:

View file

@ -18,12 +18,14 @@ class VectorConfig(BaseSettings):
Instance variables:
- vector_db_url: The URL of the vector database.
- vector_db_port: The port for the vector database.
- vector_db_name: The name of the vector database.
- vector_db_key: The key for accessing the vector database.
- vector_db_provider: The provider for the vector database.
"""
vector_db_url: str = ""
vector_db_port: int = 1234
vector_db_name: str = ""
vector_db_key: str = ""
vector_db_provider: str = "lancedb"
@ -58,6 +60,7 @@ class VectorConfig(BaseSettings):
return {
"vector_db_url": self.vector_db_url,
"vector_db_port": self.vector_db_port,
"vector_db_name": self.vector_db_name,
"vector_db_key": self.vector_db_key,
"vector_db_provider": self.vector_db_provider,
}

View file

@ -1,5 +1,6 @@
from .supported_databases import supported_databases
from .embeddings import get_embedding_engine
from cognee.infrastructure.databases.graph.config import get_graph_context_config
from functools import lru_cache
@ -8,6 +9,7 @@ from functools import lru_cache
def create_vector_engine(
vector_db_provider: str,
vector_db_url: str,
vector_db_name: str,
vector_db_port: str = "",
vector_db_key: str = "",
):
@ -27,6 +29,7 @@ def create_vector_engine(
- vector_db_url (str): The URL for the vector database instance.
- vector_db_port (str): The port for the vector database instance. Required for some
providers.
- vector_db_name (str): The name of the vector database instance.
- vector_db_key (str): The API key or access token for the vector database instance.
- vector_db_provider (str): The name of the vector database provider to use (e.g.,
'pgvector').
@ -45,6 +48,7 @@ def create_vector_engine(
url=vector_db_url,
api_key=vector_db_key,
embedding_engine=embedding_engine,
database_name=vector_db_name,
)
if vector_db_provider.lower() == "pgvector":
@ -133,6 +137,6 @@ def create_vector_engine(
else:
raise EnvironmentError(
f"Unsupported graph database provider: {vector_db_provider}. "
f"Unsupported vector database provider: {vector_db_provider}. "
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['LanceDB', 'PGVector', 'neptune_analytics', 'ChromaDB'])}"
)

View file

@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from typing import Optional, Any, Dict
@ -18,9 +18,21 @@ class Edge(BaseModel):
# Mixed usage
has_items: (Edge(weight=0.5, weights={"confidence": 0.9}), list[Item])
# With edge_text for rich embedding representation
contains: (Edge(relationship_type="contains", edge_text="relationship_name: contains; entity_description: Alice"), Entity)
"""
weight: Optional[float] = None
weights: Optional[Dict[str, float]] = None
relationship_type: Optional[str] = None
properties: Optional[Dict[str, Any]] = None
edge_text: Optional[str] = None
@field_validator("edge_text", mode="before")
@classmethod
def ensure_edge_text(cls, v, info):
"""Auto-populate edge_text from relationship_type if not explicitly provided."""
if v is None and info.data.get("relationship_type"):
return info.data["relationship_type"]
return v

View file

@ -55,6 +55,10 @@ def guess_file_type(file: BinaryIO, name: Optional[str] = None) -> filetype.Type
file_type = Type("text/plain", "txt")
return file_type
if ext in [".csv"]:
file_type = Type("text/csv", "csv")
return file_type
file_type = filetype.guess(file)
# If file type could not be determined consider it a plain text file as they don't have magic number encoding

View file

@ -38,6 +38,7 @@ class LLMConfig(BaseSettings):
"""
structured_output_framework: str = "instructor"
llm_instructor_mode: str = ""
llm_provider: str = "openai"
llm_model: str = "openai/gpt-5-mini"
llm_endpoint: str = ""
@ -181,6 +182,7 @@ class LLMConfig(BaseSettings):
instance.
"""
return {
"llm_instructor_mode": self.llm_instructor_mode.lower(),
"provider": self.llm_provider,
"model": self.llm_model,
"endpoint": self.llm_endpoint,

View file

@ -28,13 +28,16 @@ class AnthropicAdapter(LLMInterface):
name = "Anthropic"
model: str
default_instructor_mode = "anthropic_tools"
def __init__(self, max_completion_tokens: int, model: str = None):
def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None):
import anthropic
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.patch(
create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create,
mode=instructor.Mode.ANTHROPIC_TOOLS,
mode=instructor.Mode(self.instructor_mode),
)
self.model = model

View file

@ -41,6 +41,7 @@ class GeminiAdapter(LLMInterface):
name: str
model: str
api_key: str
default_instructor_mode = "json_mode"
def __init__(
self,
@ -49,6 +50,7 @@ class GeminiAdapter(LLMInterface):
model: str,
api_version: str,
max_completion_tokens: int,
instructor_mode: str = None,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
@ -63,7 +65,11 @@ class GeminiAdapter(LLMInterface):
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
)
@retry(
stop=stop_after_delay(128),

View file

@ -41,6 +41,7 @@ class GenericAPIAdapter(LLMInterface):
name: str
model: str
api_key: str
default_instructor_mode = "json_mode"
def __init__(
self,
@ -49,6 +50,7 @@ class GenericAPIAdapter(LLMInterface):
model: str,
name: str,
max_completion_tokens: int,
instructor_mode: str = None,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
@ -63,7 +65,11 @@ class GenericAPIAdapter(LLMInterface):
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
)
@retry(
stop=stop_after_delay(128),

View file

@ -81,6 +81,7 @@ def get_llm_client(raise_api_key_error: bool = True):
model=llm_config.llm_model,
transcription_model=llm_config.transcription_model,
max_completion_tokens=max_completion_tokens,
instructor_mode=llm_config.llm_instructor_mode.lower(),
streaming=llm_config.llm_streaming,
fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint,
@ -101,6 +102,7 @@ def get_llm_client(raise_api_key_error: bool = True):
llm_config.llm_model,
"Ollama",
max_completion_tokens=max_completion_tokens,
instructor_mode=llm_config.llm_instructor_mode.lower(),
)
elif provider == LLMProvider.ANTHROPIC:
@ -109,7 +111,9 @@ def get_llm_client(raise_api_key_error: bool = True):
)
return AnthropicAdapter(
max_completion_tokens=max_completion_tokens, model=llm_config.llm_model
max_completion_tokens=max_completion_tokens,
model=llm_config.llm_model,
instructor_mode=llm_config.llm_instructor_mode.lower(),
)
elif provider == LLMProvider.CUSTOM:
@ -126,6 +130,7 @@ def get_llm_client(raise_api_key_error: bool = True):
llm_config.llm_model,
"Custom",
max_completion_tokens=max_completion_tokens,
instructor_mode=llm_config.llm_instructor_mode.lower(),
fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint,
fallback_model=llm_config.fallback_model,
@ -145,6 +150,7 @@ def get_llm_client(raise_api_key_error: bool = True):
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
api_version=llm_config.llm_api_version,
instructor_mode=llm_config.llm_instructor_mode.lower(),
)
elif provider == LLMProvider.MISTRAL:
@ -160,21 +166,7 @@ def get_llm_client(raise_api_key_error: bool = True):
model=llm_config.llm_model,
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
)
elif provider == LLMProvider.MISTRAL:
if llm_config.llm_api_key is None:
raise LLMAPIKeyNotSetError()
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
MistralAdapter,
)
return MistralAdapter(
api_key=llm_config.llm_api_key,
model=llm_config.llm_model,
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
instructor_mode=llm_config.llm_instructor_mode.lower(),
)
else:

View file

@ -37,16 +37,26 @@ class MistralAdapter(LLMInterface):
model: str
api_key: str
max_completion_tokens: int
default_instructor_mode = "mistral_tools"
def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None):
def __init__(
self,
api_key: str,
model: str,
max_completion_tokens: int,
endpoint: str = None,
instructor_mode: str = None,
):
from mistralai import Mistral
self.model = model
self.max_completion_tokens = max_completion_tokens
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.from_litellm(
litellm.acompletion,
mode=instructor.Mode.MISTRAL_TOOLS,
mode=instructor.Mode(self.instructor_mode),
api_key=get_llm_config().llm_api_key,
)

View file

@ -42,8 +42,16 @@ class OllamaAPIAdapter(LLMInterface):
- aclient
"""
default_instructor_mode = "json_mode"
def __init__(
self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int
self,
endpoint: str,
api_key: str,
model: str,
name: str,
max_completion_tokens: int,
instructor_mode: str = None,
):
self.name = name
self.model = model
@ -51,8 +59,11 @@ class OllamaAPIAdapter(LLMInterface):
self.endpoint = endpoint
self.max_completion_tokens = max_completion_tokens
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.from_openai(
OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
OpenAI(base_url=self.endpoint, api_key=self.api_key),
mode=instructor.Mode(self.instructor_mode),
)
@retry(

View file

@ -56,6 +56,7 @@ class OpenAIAdapter(LLMInterface):
model: str
api_key: str
api_version: str
default_instructor_mode = "json_schema_mode"
MAX_RETRIES = 5
@ -69,19 +70,21 @@ class OpenAIAdapter(LLMInterface):
model: str,
transcription_model: str,
max_completion_tokens: int,
instructor_mode: str = None,
streaming: bool = False,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
):
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
# TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
# Make sure all new gpt models will work with this mode as well.
if "gpt-5" in model:
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
)
self.client = instructor.from_litellm(
litellm.completion, mode=instructor.Mode.JSON_SCHEMA
litellm.completion, mode=instructor.Mode(self.instructor_mode)
)
else:
self.aclient = instructor.from_litellm(litellm.acompletion)

View file

@ -31,6 +31,7 @@ class LoaderEngine:
"pypdf_loader",
"image_loader",
"audio_loader",
"csv_loader",
"unstructured_loader",
"advanced_pdf_loader",
]

View file

@ -3,5 +3,6 @@
from .text_loader import TextLoader
from .audio_loader import AudioLoader
from .image_loader import ImageLoader
from .csv_loader import CsvLoader
__all__ = ["TextLoader", "AudioLoader", "ImageLoader"]
__all__ = ["TextLoader", "AudioLoader", "ImageLoader", "CsvLoader"]

View file

@ -0,0 +1,93 @@
import os
from typing import List
import csv
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata
class CsvLoader(LoaderInterface):
"""
Core CSV file loader that handles basic CSV file formats.
"""
@property
def supported_extensions(self) -> List[str]:
"""Supported text file extensions."""
return [
"csv",
]
@property
def supported_mime_types(self) -> List[str]:
"""Supported MIME types for text content."""
return [
"text/csv",
]
@property
def loader_name(self) -> str:
"""Unique identifier for this loader."""
return "csv_loader"
def can_handle(self, extension: str, mime_type: str) -> bool:
"""
Check if this loader can handle the given file.
Args:
extension: File extension
mime_type: Optional MIME type
Returns:
True if file can be handled, False otherwise
"""
if extension in self.supported_extensions and mime_type in self.supported_mime_types:
return True
return False
async def load(self, file_path: str, encoding: str = "utf-8", **kwargs):
"""
Load and process the csv file.
Args:
file_path: Path to the file to load
encoding: Text encoding to use (default: utf-8)
**kwargs: Additional configuration (unused)
Returns:
LoaderResult containing the file content and metadata
Raises:
FileNotFoundError: If file doesn't exist
UnicodeDecodeError: If file cannot be decoded with specified encoding
OSError: If file cannot be read
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, "rb") as f:
file_metadata = await get_file_metadata(f)
# Name ingested file of current loader based on original file content hash
storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"
row_texts = []
row_index = 1
with open(file_path, "r", encoding=encoding, newline="") as file:
reader = csv.DictReader(file)
for row in reader:
pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()]
row_text = ", ".join(pairs)
row_texts.append(f"Row {row_index}:\n{row_text}\n")
row_index += 1
content = "\n".join(row_texts)
storage_config = get_storage_config()
data_root_directory = storage_config["data_root_directory"]
storage = get_file_storage(data_root_directory)
full_file_path = await storage.store(storage_file_name, content)
return full_file_path

View file

@ -16,7 +16,7 @@ class TextLoader(LoaderInterface):
@property
def supported_extensions(self) -> List[str]:
"""Supported text file extensions."""
return ["txt", "md", "csv", "json", "xml", "yaml", "yml", "log"]
return ["txt", "md", "json", "xml", "yaml", "yml", "log"]
@property
def supported_mime_types(self) -> List[str]:
@ -24,7 +24,6 @@ class TextLoader(LoaderInterface):
return [
"text/plain",
"text/markdown",
"text/csv",
"application/json",
"text/xml",
"application/xml",

View file

@ -227,12 +227,3 @@ class AdvancedPdfLoader(LoaderInterface):
if value is None:
return ""
return str(value).replace("\xa0", " ").strip()
if __name__ == "__main__":
loader = AdvancedPdfLoader()
asyncio.run(
loader.load(
"/Users/xiaotao/work/cognee/cognee/infrastructure/loaders/external/attention_is_all_you_need.pdf"
)
)

View file

@ -1,5 +1,5 @@
from cognee.infrastructure.loaders.external import PyPdfLoader
from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader
from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader, CsvLoader
# Registry for loader implementations
supported_loaders = {
@ -7,6 +7,7 @@ supported_loaders = {
TextLoader.loader_name: TextLoader,
ImageLoader.loader_name: ImageLoader,
AudioLoader.loader_name: AudioLoader,
CsvLoader.loader_name: CsvLoader,
}
# Try adding optional loaders

View file

@ -0,0 +1,55 @@
from typing import Optional, List
from cognee import memify
from cognee.context_global_variables import (
set_database_global_context_variables,
set_session_user_context_variable,
)
from cognee.exceptions import CogneeValidationError
from cognee.modules.data.methods import get_authorized_existing_datasets
from cognee.shared.logging_utils import get_logger
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.models import User
from cognee.tasks.memify import extract_user_sessions, cognify_session
logger = get_logger("persist_sessions_in_knowledge_graph")
async def persist_sessions_in_knowledge_graph_pipeline(
user: User,
session_ids: Optional[List[str]] = None,
dataset: str = "main_dataset",
run_in_background: bool = False,
):
await set_session_user_context_variable(user)
dataset_to_write = await get_authorized_existing_datasets(
user=user, datasets=[dataset], permission_type="write"
)
if not dataset_to_write:
raise CogneeValidationError(
message=f"User (id: {str(user.id)}) does not have write access to dataset: {dataset}",
log=False,
)
await set_database_global_context_variables(
dataset_to_write[0].id, dataset_to_write[0].owner_id
)
extraction_tasks = [Task(extract_user_sessions, session_ids=session_ids)]
enrichment_tasks = [
Task(cognify_session, dataset_id=dataset_to_write[0].id),
]
result = await memify(
extraction_tasks=extraction_tasks,
enrichment_tasks=enrichment_tasks,
dataset=dataset_to_write[0].id,
data=[{}],
run_in_background=run_in_background,
)
logger.info("Session persistence pipeline completed")
return result

View file

@ -0,0 +1,35 @@
from cognee.shared.logging_utils import get_logger
from cognee.tasks.chunks import chunk_by_row
from cognee.modules.chunking.Chunker import Chunker
from .models.DocumentChunk import DocumentChunk
logger = get_logger()
class CsvChunker(Chunker):
async def read(self):
async for content_text in self.get_text():
if content_text is None:
continue
for chunk_data in chunk_by_row(content_text, self.max_chunk_size):
if chunk_data["chunk_size"] <= self.max_chunk_size:
yield DocumentChunk(
id=chunk_data["chunk_id"],
text=chunk_data["text"],
chunk_size=chunk_data["chunk_size"],
is_part_of=self.document,
chunk_index=self.chunk_index,
cut_type=chunk_data["cut_type"],
contains=[],
metadata={
"index_fields": ["text"],
},
)
self.chunk_index += 1
else:
raise ValueError(
f"Chunk size is larger than the maximum chunk size {self.max_chunk_size}"
)

View file

@ -1,6 +1,7 @@
from typing import List, Union
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.Edge import Edge
from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity
from cognee.tasks.temporal_graph.models import Event
@ -31,6 +32,6 @@ class DocumentChunk(DataPoint):
chunk_index: int
cut_type: str
is_part_of: Document
contains: List[Union[Entity, Event]] = None
contains: List[Union[Entity, Event, tuple[Edge, Entity]]] = None
metadata: dict = {"index_fields": ["text"]}

View file

@ -0,0 +1,124 @@
from cognee.shared.logging_utils import get_logger
from uuid import NAMESPACE_OID, uuid5
from cognee.tasks.chunks import chunk_by_paragraph
from cognee.modules.chunking.Chunker import Chunker
from .models.DocumentChunk import DocumentChunk
logger = get_logger()
class TextChunkerWithOverlap(Chunker):
def __init__(
self,
document,
get_text: callable,
max_chunk_size: int,
chunk_overlap_ratio: float = 0.0,
get_chunk_data: callable = None,
):
super().__init__(document, get_text, max_chunk_size)
self._accumulated_chunk_data = []
self._accumulated_size = 0
self.chunk_overlap_ratio = chunk_overlap_ratio
self.chunk_overlap = int(max_chunk_size * chunk_overlap_ratio)
if get_chunk_data is not None:
self.get_chunk_data = get_chunk_data
elif chunk_overlap_ratio > 0:
paragraph_max_size = int(0.5 * chunk_overlap_ratio * max_chunk_size)
self.get_chunk_data = lambda text: chunk_by_paragraph(
text, paragraph_max_size, batch_paragraphs=True
)
else:
self.get_chunk_data = lambda text: chunk_by_paragraph(
text, self.max_chunk_size, batch_paragraphs=True
)
def _accumulation_overflows(self, chunk_data):
"""Check if adding chunk_data would exceed max_chunk_size."""
return self._accumulated_size + chunk_data["chunk_size"] > self.max_chunk_size
def _accumulate_chunk_data(self, chunk_data):
"""Add chunk_data to the current accumulation."""
self._accumulated_chunk_data.append(chunk_data)
self._accumulated_size += chunk_data["chunk_size"]
def _clear_accumulation(self):
"""Reset accumulation, keeping overlap chunk_data based on chunk_overlap_ratio."""
if self.chunk_overlap == 0:
self._accumulated_chunk_data = []
self._accumulated_size = 0
return
# Keep chunk_data from the end that fit in overlap
overlap_chunk_data = []
overlap_size = 0
for chunk_data in reversed(self._accumulated_chunk_data):
if overlap_size + chunk_data["chunk_size"] <= self.chunk_overlap:
overlap_chunk_data.insert(0, chunk_data)
overlap_size += chunk_data["chunk_size"]
else:
break
self._accumulated_chunk_data = overlap_chunk_data
self._accumulated_size = overlap_size
def _create_chunk(self, text, size, cut_type, chunk_id=None):
"""Create a DocumentChunk with standard metadata."""
try:
return DocumentChunk(
id=chunk_id or uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
text=text,
chunk_size=size,
is_part_of=self.document,
chunk_index=self.chunk_index,
cut_type=cut_type,
contains=[],
metadata={"index_fields": ["text"]},
)
except Exception as e:
logger.error(e)
raise e
def _create_chunk_from_accumulation(self):
"""Create a DocumentChunk from current accumulated chunk_data."""
chunk_text = " ".join(chunk["text"] for chunk in self._accumulated_chunk_data)
return self._create_chunk(
text=chunk_text,
size=self._accumulated_size,
cut_type=self._accumulated_chunk_data[-1]["cut_type"],
)
def _emit_chunk(self, chunk_data):
"""Emit a chunk when accumulation overflows."""
if len(self._accumulated_chunk_data) > 0:
chunk = self._create_chunk_from_accumulation()
self._clear_accumulation()
self._accumulate_chunk_data(chunk_data)
else:
# Handle single chunk_data exceeding max_chunk_size
chunk = self._create_chunk(
text=chunk_data["text"],
size=chunk_data["chunk_size"],
cut_type=chunk_data["cut_type"],
chunk_id=chunk_data["chunk_id"],
)
self.chunk_index += 1
return chunk
async def read(self):
async for content_text in self.get_text():
for chunk_data in self.get_chunk_data(content_text):
if not self._accumulation_overflows(chunk_data):
self._accumulate_chunk_data(chunk_data)
continue
yield self._emit_chunk(chunk_data)
if len(self._accumulated_chunk_data) == 0:
return
yield self._create_chunk_from_accumulation()

View file

@ -10,6 +10,7 @@ from .get_authorized_dataset import get_authorized_dataset
from .get_authorized_dataset_by_name import get_authorized_dataset_by_name
from .get_data import get_data
from .get_unique_dataset_id import get_unique_dataset_id
from .get_unique_data_id import get_unique_data_id
from .get_authorized_existing_datasets import get_authorized_existing_datasets
from .get_dataset_ids import get_dataset_ids

View file

@ -16,14 +16,16 @@ async def create_dataset(dataset_name: str, user: User, session: AsyncSession) -
.options(joinedload(Dataset.data))
.filter(Dataset.name == dataset_name)
.filter(Dataset.owner_id == owner_id)
.filter(Dataset.tenant_id == user.tenant_id)
)
).first()
if dataset is None:
# Dataset id should be generated based on dataset_name and owner_id/user so multiple users can use the same dataset_name
dataset_id = await get_unique_dataset_id(dataset_name=dataset_name, user=user)
dataset = Dataset(id=dataset_id, name=dataset_name, data=[])
dataset.owner_id = owner_id
dataset = Dataset(
id=dataset_id, name=dataset_name, data=[], owner_id=owner_id, tenant_id=user.tenant_id
)
session.add(dataset)

View file

@ -27,7 +27,11 @@ async def get_dataset_ids(datasets: Union[list[str], list[UUID]], user):
# Get all user owned dataset objects (If a user wants to write to a dataset he is not the owner of it must be provided through UUID.)
user_datasets = await get_datasets(user.id)
# Filter out non name mentioned datasets
dataset_ids = [dataset.id for dataset in user_datasets if dataset.name in datasets]
dataset_ids = [dataset for dataset in user_datasets if dataset.name in datasets]
# Filter out non current tenant datasets
dataset_ids = [
dataset.id for dataset in dataset_ids if dataset.tenant_id == user.tenant_id
]
else:
raise DatasetTypeError(
f"One or more of the provided dataset types is not handled: f{datasets}"

View file

@ -0,0 +1,68 @@
from uuid import uuid5, NAMESPACE_OID, UUID
from sqlalchemy import select
from cognee.modules.data.models.Data import Data
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.users.models import User
async def get_unique_data_id(data_identifier: str, user: User) -> UUID:
"""
Function returns a unique UUID for data based on data identifier, user id and tenant id.
If data with legacy ID exists, return that ID to maintain compatibility.
Args:
data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.)
user: User object adding the data
tenant_id: UUID of the tenant for which data is being added
Returns:
UUID: Unique identifier for the data
"""
def _get_deprecated_unique_data_id(data_identifier: str, user: User) -> UUID:
"""
Deprecated function, returns a unique UUID for data based on data identifier and user id.
Needed to support legacy data without tenant information.
Args:
data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.)
user: User object adding the data
Returns:
UUID: Unique identifier for the data
"""
# return UUID hash of file contents + owner id + tenant_id
return uuid5(NAMESPACE_OID, f"{data_identifier}{str(user.id)}")
def _get_modern_unique_data_id(data_identifier: str, user: User) -> UUID:
"""
Function returns a unique UUID for data based on data identifier, user id and tenant id.
Args:
data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.)
user: User object adding the data
tenant_id: UUID of the tenant for which data is being added
Returns:
UUID: Unique identifier for the data
"""
# return UUID hash of file contents + owner id + tenant_id
return uuid5(NAMESPACE_OID, f"{data_identifier}{str(user.id)}{str(user.tenant_id)}")
# Get all possible data_id values
data_id = {
"modern_data_id": _get_modern_unique_data_id(data_identifier=data_identifier, user=user),
"legacy_data_id": _get_deprecated_unique_data_id(
data_identifier=data_identifier, user=user
),
}
# Check if data item with legacy_data_id exists, if so use that one, else use modern_data_id
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
legacy_data_point = (
await session.execute(select(Data).filter(Data.id == data_id["legacy_data_id"]))
).scalar_one_or_none()
if not legacy_data_point:
return data_id["modern_data_id"]
return data_id["legacy_data_id"]

View file

@ -1,9 +1,71 @@
from uuid import UUID, uuid5, NAMESPACE_OID
from cognee.modules.users.models import User
from typing import Union
from sqlalchemy import select
from cognee.modules.data.models.Dataset import Dataset
from cognee.modules.users.models import User
from cognee.infrastructure.databases.relational import get_relational_engine
async def get_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID:
if isinstance(dataset_name, UUID):
return dataset_name
return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}")
"""
Function returns a unique UUID for dataset based on dataset name, user id and tenant id.
If dataset with legacy ID exists, return that ID to maintain compatibility.
Args:
dataset_name: string representing the dataset name
user: User object adding the dataset
tenant_id: UUID of the tenant for which dataset is being added
Returns:
UUID: Unique identifier for the dataset
"""
def _get_legacy_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID:
"""
Legacy function, returns a unique UUID for dataset based on dataset name and user id.
Needed to support legacy datasets without tenant information.
Args:
dataset_name: string representing the dataset name
user: Current User object adding the dataset
Returns:
UUID: Unique identifier for the dataset
"""
if isinstance(dataset_name, UUID):
return dataset_name
return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}")
def _get_modern_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID:
"""
Returns a unique UUID for dataset based on dataset name, user id and tenant_id.
Args:
dataset_name: string representing the dataset name
user: Current User object adding the dataset
tenant_id: UUID of the tenant for which dataset is being added
Returns:
UUID: Unique identifier for the dataset
"""
if isinstance(dataset_name, UUID):
return dataset_name
return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}{str(user.tenant_id)}")
# Get all possible dataset_id values
dataset_id = {
"modern_dataset_id": _get_modern_unique_dataset_id(dataset_name=dataset_name, user=user),
"legacy_dataset_id": _get_legacy_unique_dataset_id(dataset_name=dataset_name, user=user),
}
# Check if dataset with legacy_dataset_id exists, if so use that one, else use modern_dataset_id
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
legacy_dataset = (
await session.execute(
select(Dataset).filter(Dataset.id == dataset_id["legacy_dataset_id"])
)
).scalar_one_or_none()
if not legacy_dataset:
return dataset_id["modern_dataset_id"]
return dataset_id["legacy_dataset_id"]

View file

@ -18,6 +18,7 @@ class Dataset(Base):
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
owner_id = Column(UUID, index=True)
tenant_id = Column(UUID, index=True, nullable=True)
acls = relationship("ACL", back_populates="dataset", cascade="all, delete-orphan")
@ -36,5 +37,6 @@ class Dataset(Base):
"createdAt": self.created_at.isoformat(),
"updatedAt": self.updated_at.isoformat() if self.updated_at else None,
"ownerId": str(self.owner_id),
"tenantId": str(self.tenant_id),
"data": [data.to_json() for data in self.data],
}

View file

@ -0,0 +1,33 @@
import io
import csv
from typing import Type
from cognee.modules.chunking.Chunker import Chunker
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from .Document import Document
class CsvDocument(Document):
type: str = "csv"
mime_type: str = "text/csv"
async def read(self, chunker_cls: Type[Chunker], max_chunk_size: int):
async def get_text():
async with open_data_file(
self.raw_data_location, mode="r", encoding="utf-8", newline=""
) as file:
content = file.read()
file_like_obj = io.StringIO(content)
reader = csv.DictReader(file_like_obj)
for row in reader:
pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()]
row_text = ", ".join(pairs)
if not row_text.strip():
break
yield row_text
chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=get_text)
async for chunk in chunker.read():
yield chunk

View file

@ -4,3 +4,4 @@ from .TextDocument import TextDocument
from .ImageDocument import ImageDocument
from .AudioDocument import AudioDocument
from .UnstructuredDocument import UnstructuredDocument
from .CsvDocument import CsvDocument

View file

@ -56,6 +56,68 @@ class CogneeGraph(CogneeAbstractGraph):
def get_edges(self) -> List[Edge]:
return self.edges
async def _get_nodeset_subgraph(
self,
adapter,
node_type,
node_name,
):
"""Retrieve subgraph based on node type and name."""
logger.info("Retrieving graph filtered by node type and node name (NodeSet).")
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
node_type=node_type, node_name=node_name
)
if not nodes_data or not edges_data:
raise EntityNotFoundError(
message="Nodeset does not exist, or empty nodeset projected from the database."
)
return nodes_data, edges_data
async def _get_full_or_id_filtered_graph(
self,
adapter,
relevant_ids_to_filter,
):
"""Retrieve full or ID-filtered graph with fallback."""
if relevant_ids_to_filter is None:
logger.info("Retrieving full graph.")
nodes_data, edges_data = await adapter.get_graph_data()
if not nodes_data or not edges_data:
raise EntityNotFoundError(message="Empty graph projected from the database.")
return nodes_data, edges_data
get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data)
if getattr(adapter.__class__, "get_id_filtered_graph_data", None):
logger.info("Retrieving ID-filtered graph from database.")
nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter)
else:
logger.info("Retrieving full graph from database.")
nodes_data, edges_data = await get_graph_data_fn()
if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data):
logger.warning(
"Id filtered graph returned empty, falling back to full graph retrieval."
)
logger.info("Retrieving full graph")
nodes_data, edges_data = await adapter.get_graph_data()
if not nodes_data or not edges_data:
raise EntityNotFoundError("Empty graph projected from the database.")
return nodes_data, edges_data
async def _get_filtered_graph(
self,
adapter,
memory_fragment_filter,
):
"""Retrieve graph filtered by attributes."""
logger.info("Retrieving graph filtered by memory fragment")
nodes_data, edges_data = await adapter.get_filtered_graph_data(
attribute_filters=memory_fragment_filter
)
if not nodes_data or not edges_data:
raise EntityNotFoundError(message="Empty filtered graph projected from the database.")
return nodes_data, edges_data
async def project_graph_from_db(
self,
adapter: Union[GraphDBInterface],
@ -67,40 +129,39 @@ class CogneeGraph(CogneeAbstractGraph):
memory_fragment_filter=[],
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
relevant_ids_to_filter: Optional[List[str]] = None,
triplet_distance_penalty: float = 3.5,
) -> None:
if node_dimension < 1 or edge_dimension < 1:
raise InvalidDimensionsError()
try:
if node_type is not None and node_name not in [None, [], ""]:
nodes_data, edges_data = await self._get_nodeset_subgraph(
adapter, node_type, node_name
)
elif len(memory_fragment_filter) == 0:
nodes_data, edges_data = await self._get_full_or_id_filtered_graph(
adapter, relevant_ids_to_filter
)
else:
nodes_data, edges_data = await self._get_filtered_graph(
adapter, memory_fragment_filter
)
import time
start_time = time.time()
# Determine projection strategy
if node_type is not None and node_name not in [None, [], ""]:
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
node_type=node_type, node_name=node_name
)
if not nodes_data or not edges_data:
raise EntityNotFoundError(
message="Nodeset does not exist, or empty nodetes projected from the database."
)
elif len(memory_fragment_filter) == 0:
nodes_data, edges_data = await adapter.get_graph_data()
if not nodes_data or not edges_data:
raise EntityNotFoundError(message="Empty graph projected from the database.")
else:
nodes_data, edges_data = await adapter.get_filtered_graph_data(
attribute_filters=memory_fragment_filter
)
if not nodes_data or not edges_data:
raise EntityNotFoundError(
message="Empty filtered graph projected from the database."
)
# Process nodes
for node_id, properties in nodes_data:
node_attributes = {key: properties.get(key) for key in node_properties_to_project}
self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))
self.add_node(
Node(
str(node_id),
node_attributes,
dimension=node_dimension,
node_penalty=triplet_distance_penalty,
)
)
# Process edges
for source_id, target_id, relationship_type, properties in edges_data:
@ -118,6 +179,7 @@ class CogneeGraph(CogneeAbstractGraph):
attributes=edge_attributes,
directed=directed,
dimension=edge_dimension,
edge_penalty=triplet_distance_penalty,
)
self.add_edge(edge)
@ -171,8 +233,10 @@ class CogneeGraph(CogneeAbstractGraph):
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
for edge in self.edges:
relationship_type = edge.attributes.get("relationship_type")
distance = embedding_map.get(relationship_type, None)
edge_key = edge.attributes.get("edge_text") or edge.attributes.get(
"relationship_type"
)
distance = embedding_map.get(edge_key, None)
if distance is not None:
edge.attributes["vector_distance"] = distance

View file

@ -20,13 +20,17 @@ class Node:
status: np.ndarray
def __init__(
self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1
self,
node_id: str,
attributes: Optional[Dict[str, Any]] = None,
dimension: int = 1,
node_penalty: float = 3.5,
):
if dimension <= 0:
raise InvalidDimensionsError()
self.id = node_id
self.attributes = attributes if attributes is not None else {}
self.attributes["vector_distance"] = float("inf")
self.attributes["vector_distance"] = node_penalty
self.skeleton_neighbours = []
self.skeleton_edges = []
self.status = np.ones(dimension, dtype=int)
@ -105,13 +109,14 @@ class Edge:
attributes: Optional[Dict[str, Any]] = None,
directed: bool = True,
dimension: int = 1,
edge_penalty: float = 3.5,
):
if dimension <= 0:
raise InvalidDimensionsError()
self.node1 = node1
self.node2 = node2
self.attributes = attributes if attributes is not None else {}
self.attributes["vector_distance"] = float("inf")
self.attributes["vector_distance"] = edge_penalty
self.directed = directed
self.status = np.ones(dimension, dtype=int)

View file

@ -1,5 +1,6 @@
from typing import Optional
from cognee.infrastructure.engine.models.Edge import Edge
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.engine.utils import (
@ -243,10 +244,26 @@ def _process_graph_nodes(
ontology_relationships,
)
# Add entity to data chunk
if data_chunk.contains is None:
data_chunk.contains = []
data_chunk.contains.append(entity_node)
edge_text = "; ".join(
[
"relationship_name: contains",
f"entity_name: {entity_node.name}",
f"entity_description: {entity_node.description}",
]
)
data_chunk.contains.append(
(
Edge(
relationship_type="contains",
edge_text=edge_text,
),
entity_node,
)
)
def _process_graph_edges(

View file

@ -1,71 +1,70 @@
import string
from typing import List
from collections import Counter
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
def _get_top_n_frequent_words(
text: str, stop_words: set = None, top_n: int = 3, separator: str = ", "
) -> str:
"""Concatenates the top N frequent words in text."""
if stop_words is None:
stop_words = DEFAULT_STOP_WORDS
words = [word.lower().strip(string.punctuation) for word in text.split()]
words = [word for word in words if word and word not in stop_words]
top_words = [word for word, freq in Counter(words).most_common(top_n)]
return separator.join(top_words)
def _create_title_from_text(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
"""Creates a title by combining first words with most frequent words from the text."""
first_words = text.split()[:first_n_words]
top_words = _get_top_n_frequent_words(text, top_n=top_n_words)
return f"{' '.join(first_words)}... [{top_words}]"
def _extract_nodes_from_edges(retrieved_edges: List[Edge]) -> dict:
"""Creates a dictionary of nodes with their names and content."""
nodes = {}
for edge in retrieved_edges:
for node in (edge.node1, edge.node2):
if node.id in nodes:
continue
text = node.attributes.get("text")
if text:
name = _create_title_from_text(text)
content = text
else:
name = node.attributes.get("name", "Unnamed Node")
content = node.attributes.get("description", name)
nodes[node.id] = {"node": node, "name": name, "content": content}
return nodes
async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str:
"""
Converts retrieved graph edges into a human-readable string format.
"""Converts retrieved graph edges into a human-readable string format."""
nodes = _extract_nodes_from_edges(retrieved_edges)
Parameters:
-----------
- retrieved_edges (list): A list of edges retrieved from the graph.
Returns:
--------
- str: A formatted string representation of the nodes and their connections.
"""
def _get_nodes(retrieved_edges: List[Edge]) -> dict:
def _get_title(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
def _top_n_words(text, stop_words=None, top_n=3, separator=", "):
"""Concatenates the top N frequent words in text."""
if stop_words is None:
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
stop_words = DEFAULT_STOP_WORDS
import string
words = [word.lower().strip(string.punctuation) for word in text.split()]
if stop_words:
words = [word for word in words if word and word not in stop_words]
from collections import Counter
top_words = [word for word, freq in Counter(words).most_common(top_n)]
return separator.join(top_words)
"""Creates a title, by combining first words with most frequent words from the text."""
first_words = text.split()[:first_n_words]
top_words = _top_n_words(text, top_n=first_n_words)
return f"{' '.join(first_words)}... [{top_words}]"
"""Creates a dictionary of nodes with their names and content."""
nodes = {}
for edge in retrieved_edges:
for node in (edge.node1, edge.node2):
if node.id not in nodes:
text = node.attributes.get("text")
if text:
name = _get_title(text)
content = text
else:
name = node.attributes.get("name", "Unnamed Node")
content = node.attributes.get("description", name)
nodes[node.id] = {"node": node, "name": name, "content": content}
return nodes
nodes = _get_nodes(retrieved_edges)
node_section = "\n".join(
f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
for info in nodes.values()
)
connection_section = "\n".join(
f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
for edge in retrieved_edges
)
connections = []
for edge in retrieved_edges:
source_name = nodes[edge.node1.id]["name"]
target_name = nodes[edge.node2.id]["name"]
edge_label = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
connections.append(f"{source_name} --[{edge_label}]--> {target_name}")
connection_section = "\n".join(connections)
return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"

View file

@ -1,11 +1,11 @@
from uuid import uuid5, NAMESPACE_OID
from uuid import UUID
from .data_types import IngestionData
from cognee.modules.users.models import User
from cognee.modules.data.methods import get_unique_data_id
def identify(data: IngestionData, user: User) -> str:
async def identify(data: IngestionData, user: User) -> UUID:
data_content_hash: str = data.get_identifier()
# return UUID hash of file contents + owner id
return uuid5(NAMESPACE_OID, f"{data_content_hash}{user.id}")
return await get_unique_data_id(data_identifier=data_content_hash, user=user)

View file

@ -2,6 +2,8 @@ import io
import sys
import traceback
import cognee
def wrap_in_async_handler(user_code: str) -> str:
return (
@ -34,6 +36,7 @@ def run_in_local_sandbox(code, environment=None, loop=None):
environment["print"] = customPrintFunction
environment["running_loop"] = loop
environment["cognee"] = cognee
try:
exec(code, environment)

View file

@ -2,7 +2,7 @@ import os
import difflib
from cognee.shared.logging_utils import get_logger
from collections import deque
from typing import List, Tuple, Dict, Optional, Any, Union
from typing import List, Tuple, Dict, Optional, Any, Union, IO
from rdflib import Graph, URIRef, RDF, RDFS, OWL
from cognee.modules.ontology.exceptions import (
@ -26,44 +26,76 @@ class RDFLibOntologyResolver(BaseOntologyResolver):
def __init__(
self,
ontology_file: Optional[Union[str, List[str]]] = None,
ontology_file: Optional[Union[str, List[str], IO, List[IO]]] = None,
matching_strategy: Optional[MatchingStrategy] = None,
) -> None:
super().__init__(matching_strategy)
self.ontology_file = ontology_file
try:
files_to_load = []
self.graph = None
if ontology_file is not None:
if isinstance(ontology_file, str):
files_to_load = []
file_objects = []
if hasattr(ontology_file, "read"):
file_objects = [ontology_file]
elif isinstance(ontology_file, str):
files_to_load = [ontology_file]
elif isinstance(ontology_file, list):
files_to_load = ontology_file
if all(hasattr(item, "read") for item in ontology_file):
file_objects = ontology_file
else:
files_to_load = ontology_file
else:
raise ValueError(
f"ontology_file must be a string, list of strings, or None. Got: {type(ontology_file)}"
f"ontology_file must be a string, list of strings, file-like object, list of file-like objects, or None. Got: {type(ontology_file)}"
)
if files_to_load:
self.graph = Graph()
loaded_files = []
for file_path in files_to_load:
if os.path.exists(file_path):
self.graph.parse(file_path)
loaded_files.append(file_path)
logger.info("Ontology loaded successfully from file: %s", file_path)
else:
logger.warning(
"Ontology file '%s' not found. Skipping this file.",
file_path,
if file_objects:
self.graph = Graph()
loaded_objects = []
for file_obj in file_objects:
try:
content = file_obj.read()
self.graph.parse(data=content, format="xml")
loaded_objects.append(file_obj)
logger.info("Ontology loaded successfully from file object")
except Exception as e:
logger.warning("Failed to parse ontology file object: %s", str(e))
if not loaded_objects:
logger.info(
"No valid ontology file objects found. No owl ontology will be attached to the graph."
)
self.graph = None
else:
logger.info("Total ontology file objects loaded: %d", len(loaded_objects))
if not loaded_files:
logger.info(
"No valid ontology files found. No owl ontology will be attached to the graph."
)
self.graph = None
elif files_to_load:
self.graph = Graph()
loaded_files = []
for file_path in files_to_load:
if os.path.exists(file_path):
self.graph.parse(file_path)
loaded_files.append(file_path)
logger.info("Ontology loaded successfully from file: %s", file_path)
else:
logger.warning(
"Ontology file '%s' not found. Skipping this file.",
file_path,
)
if not loaded_files:
logger.info(
"No valid ontology files found. No owl ontology will be attached to the graph."
)
self.graph = None
else:
logger.info("Total ontology files loaded: %d", len(loaded_files))
else:
logger.info("Total ontology files loaded: %d", len(loaded_files))
logger.info(
"No ontology file provided. No owl ontology will be attached to the graph."
)
else:
logger.info(
"No ontology file provided. No owl ontology will be attached to the graph."

View file

@ -69,7 +69,7 @@ async def run_tasks_data_item_incremental(
async with open_data_file(file_path) as file:
classified_data = ingestion.classify(file)
# data_id is the hash of file contents + owner id to avoid duplicate data
data_id = ingestion.identify(classified_data, user)
data_id = await ingestion.identify(classified_data, user)
else:
# If data was already processed by Cognee get data id
data_id = data_item.id

View file

@ -1,5 +1,5 @@
import asyncio
from typing import Any, Optional, List
from typing import Any, Optional, List, Type
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
@ -85,8 +85,12 @@ class EntityCompletionRetriever(BaseRetriever):
return None
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> List[str]:
self,
query: str,
context: Optional[Any] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""
Generate completion using provided context or fetch new context.
@ -102,6 +106,7 @@ class EntityCompletionRetriever(BaseRetriever):
fetched if not provided. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@ -133,6 +138,7 @@ class EntityCompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@ -141,6 +147,7 @@ class EntityCompletionRetriever(BaseRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
response_model=response_model,
)
if session_save:

View file

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Any, List, Optional, Type
from abc import ABC, abstractmethod
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@ -14,7 +14,11 @@ class BaseGraphRetriever(ABC):
@abstractmethod
async def get_completion(
self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None
) -> str:
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""Generates a response using the query and optional context (triplets)."""
pass

View file

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import Any, Optional, Type, List
class BaseRetriever(ABC):
@ -12,7 +12,11 @@ class BaseRetriever(ABC):
@abstractmethod
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
self,
query: str,
context: Optional[Any] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""Generates a response using the query and optional context."""
pass

View file

@ -1,5 +1,5 @@
import asyncio
from typing import Any, Optional
from typing import Any, Optional, Type, List
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
@ -75,8 +75,12 @@ class CompletionRetriever(BaseRetriever):
raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> str:
self,
query: str,
context: Optional[Any] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""
Generates an LLM completion using the context.
@ -91,6 +95,7 @@ class CompletionRetriever(BaseRetriever):
completion; if None, it retrieves the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@ -118,6 +123,7 @@ class CompletionRetriever(BaseRetriever):
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@ -127,6 +133,7 @@ class CompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if session_save:
@ -137,4 +144,4 @@ class CompletionRetriever(BaseRetriever):
session_id=session_id,
)
return completion
return [completion]

View file

@ -1,5 +1,5 @@
import asyncio
from typing import Optional, List, Type
from typing import Optional, List, Type, Any
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@ -39,6 +39,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@ -48,6 +50,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
async def get_completion(
@ -56,7 +60,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
context_extension_rounds=4,
) -> List[str]:
response_model: Type = str,
) -> List[Any]:
"""
Extends the context for a given query by retrieving related triplets and generating new
completions based on them.
@ -76,6 +81,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
defaults to 'default_session'. (default None)
- context_extension_rounds: The maximum number of rounds to extend the context with
new triplets before halting. (default 4)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@ -143,6 +149,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@ -152,6 +159,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if self.save_interaction and context_text and triplets and completion:

View file

@ -7,7 +7,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import (
generate_structured_completion,
generate_completion,
summarize_text,
)
from cognee.modules.retrieval.utils.session_cache import (
@ -44,7 +44,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
questions based on reasoning. The public methods are:
- get_completion
- get_structured_completion
Instance variables include:
- validation_system_prompt_path
@ -66,6 +65,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@ -75,6 +76,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
self.validation_system_prompt_path = validation_system_prompt_path
self.validation_user_prompt_path = validation_user_prompt_path
@ -121,7 +124,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
triplets += await self.get_context(followup_question)
context_text = await self.resolve_edges_to_text(list(set(triplets)))
completion = await generate_structured_completion(
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
@ -165,24 +168,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
return completion, context_text, triplets
async def get_structured_completion(
async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
max_iter: int = 4,
max_iter=4,
response_model: Type = str,
) -> Any:
) -> List[Any]:
"""
Generate structured completion responses based on a user query and contextual information.
Generate completion responses based on a user query and contextual information.
This method applies the same chain-of-thought logic as get_completion but returns
This method interacts with a language model client to retrieve a structured response,
using a series of iterations to refine the answers and generate follow-up questions
based on reasoning derived from previous outputs. It raises exceptions if the context
retrieval fails or if the model encounters issues in generating outputs. It returns
structured output using the provided response model.
Parameters:
-----------
- query (str): The user's query to be processed and answered.
- context (Optional[List[Edge]]): Optional context that may assist in answering the query.
- context (Optional[Any]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
@ -192,7 +199,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
Returns:
--------
- Any: The generated structured completion based on the response model.
- List[str]: A list containing the generated answer to the user's query.
"""
# Check if session saving is enabled
cache_config = CacheConfig()
@ -228,45 +236,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
session_id=session_id,
)
return completion
async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
max_iter=4,
) -> List[str]:
"""
Generate completion responses based on a user query and contextual information.
This method interacts with a language model client to retrieve a structured response,
using a series of iterations to refine the answers and generate follow-up questions
based on reasoning derived from previous outputs. It raises exceptions if the context
retrieval fails or if the model encounters issues in generating outputs.
Parameters:
-----------
- query (str): The user's query to be processed and answered.
- context (Optional[Any]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- max_iter: The maximum number of iterations to refine the answer and generate
follow-up questions. (default 4)
Returns:
--------
- List[str]: A list containing the generated answer to the user's query.
"""
completion = await self.get_structured_completion(
query=query,
context=context,
session_id=session_id,
max_iter=max_iter,
response_model=str,
)
return [completion]

View file

@ -47,6 +47,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
):
"""Initialize retriever with prompt paths and search parameters."""
self.save_interaction = save_interaction
@ -54,8 +56,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):
self.system_prompt_path = system_prompt_path
self.system_prompt = system_prompt
self.top_k = top_k if top_k is not None else 5
self.wide_search_top_k = wide_search_top_k
self.node_type = node_type
self.node_name = node_name
self.triplet_distance_penalty = triplet_distance_penalty
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
"""
@ -105,6 +109,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
collections=vector_index_collections or None,
node_type=self.node_type,
node_name=self.node_name,
wide_search_top_k=self.wide_search_top_k,
triplet_distance_penalty=self.triplet_distance_penalty,
)
return found_triplets
@ -141,12 +147,17 @@ class GraphCompletionRetriever(BaseGraphRetriever):
return triplets
async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):
context = await self.resolve_edges_to_text(triplets)
return context
async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
) -> List[str]:
response_model: Type = str,
) -> List[Any]:
"""
Generates a completion using graph connections context based on a query.
@ -188,6 +199,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@ -197,6 +209,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if self.save_interaction and context and triplets and completion:

View file

@ -26,6 +26,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
):
"""Initialize retriever with default prompt paths and search parameters."""
super().__init__(
@ -36,6 +38,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
self.summarize_prompt_path = summarize_prompt_path

View file

@ -47,6 +47,8 @@ class TemporalRetriever(GraphCompletionRetriever):
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@ -54,6 +56,8 @@ class TemporalRetriever(GraphCompletionRetriever):
top_k=top_k,
node_type=node_type,
node_name=node_name,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
@ -146,8 +150,12 @@ class TemporalRetriever(GraphCompletionRetriever):
return self.descriptions_to_string(top_k_events)
async def get_completion(
self, query: str, context: Optional[str] = None, session_id: Optional[str] = None
) -> List[str]:
self,
query: str,
context: Optional[str] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""
Generates a response using the query and optional context.
@ -159,6 +167,7 @@ class TemporalRetriever(GraphCompletionRetriever):
retrieved based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@ -186,6 +195,7 @@ class TemporalRetriever(GraphCompletionRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@ -194,6 +204,7 @@ class TemporalRetriever(GraphCompletionRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
response_model=response_model,
)
if session_save:

View file

@ -58,6 +58,8 @@ async def get_memory_fragment(
properties_to_project: Optional[List[str]] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
relevant_ids_to_filter: Optional[List[str]] = None,
triplet_distance_penalty: Optional[float] = 3.5,
) -> CogneeGraph:
"""Creates and initializes a CogneeGraph memory fragment with optional property projections."""
if properties_to_project is None:
@ -71,9 +73,11 @@ async def get_memory_fragment(
await memory_fragment.project_graph_from_db(
graph_engine,
node_properties_to_project=properties_to_project,
edge_properties_to_project=["relationship_name"],
edge_properties_to_project=["relationship_name", "edge_text"],
node_type=node_type,
node_name=node_name,
relevant_ids_to_filter=relevant_ids_to_filter,
triplet_distance_penalty=triplet_distance_penalty,
)
except EntityNotFoundError:
@ -95,6 +99,8 @@ async def brute_force_triplet_search(
memory_fragment: Optional[CogneeGraph] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> List[Edge]:
"""
Performs a brute force search to retrieve the top triplets from the graph.
@ -107,6 +113,8 @@ async def brute_force_triplet_search(
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
node_type: node type to filter
node_name: node name to filter
wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection
Returns:
list: The top triplet results.
@ -116,10 +124,10 @@ async def brute_force_triplet_search(
if top_k <= 0:
raise ValueError("top_k must be a positive integer.")
if memory_fragment is None:
memory_fragment = await get_memory_fragment(
properties_to_project, node_type=node_type, node_name=node_name
)
# Setting wide search limit based on the parameters
non_global_search = node_name is None
wide_search_limit = wide_search_top_k if non_global_search else None
if collections is None:
collections = [
@ -140,7 +148,7 @@ async def brute_force_triplet_search(
async def search_in_collection(collection_name: str):
try:
return await vector_engine.search(
collection_name=collection_name, query_vector=query_vector, limit=None
collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
)
except CollectionNotFoundError:
return []
@ -156,15 +164,38 @@ async def brute_force_triplet_search(
return []
# Final statistics
projection_time = time.time() - start_time
vector_collection_search_time = time.time() - start_time
logger.info(
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s"
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
)
node_distances = {collection: result for collection, result in zip(collections, results)}
edge_distances = node_distances.get("EdgeType_relationship_name", None)
if wide_search_limit is not None:
relevant_ids_to_filter = list(
{
str(getattr(scored_node, "id"))
for collection_name, score_collection in node_distances.items()
if collection_name != "EdgeType_relationship_name"
and isinstance(score_collection, (list, tuple))
for scored_node in score_collection
if getattr(scored_node, "id", None)
}
)
else:
relevant_ids_to_filter = None
if memory_fragment is None:
memory_fragment = await get_memory_fragment(
properties_to_project=properties_to_project,
node_type=node_type,
node_name=node_name,
relevant_ids_to_filter=relevant_ids_to_filter,
triplet_distance_penalty=triplet_distance_penalty,
)
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
await memory_fragment.map_vector_distances_to_graph_edges(
vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances

View file

@ -3,7 +3,7 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
async def generate_structured_completion(
async def generate_completion(
query: str,
context: str,
user_prompt_path: str,
@ -12,7 +12,7 @@ async def generate_structured_completion(
conversation_history: Optional[str] = None,
response_model: Type = str,
) -> Any:
"""Generates a structured completion using LLM with given context and prompts."""
"""Generates a completion using LLM with given context and prompts."""
args = {"question": query, "context": context}
user_prompt = render_prompt(user_prompt_path, args)
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
@ -28,26 +28,6 @@ async def generate_structured_completion(
)
async def generate_completion(
query: str,
context: str,
user_prompt_path: str,
system_prompt_path: str,
system_prompt: Optional[str] = None,
conversation_history: Optional[str] = None,
) -> str:
"""Generates a completion using LLM with given context and prompts."""
return await generate_structured_completion(
query=query,
context=context,
user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
conversation_history=conversation_history,
response_model=str,
)
async def summarize_text(
text: str,
system_prompt_path: str = "summarize_search_results.txt",

Some files were not shown because too many files have changed in this diff Show more