diff --git a/.env.template b/.env.template index 7dcd4f346..61853b983 100644 --- a/.env.template +++ b/.env.template @@ -21,6 +21,10 @@ LLM_PROVIDER="openai" LLM_ENDPOINT="" LLM_API_VERSION="" LLM_MAX_TOKENS="16384" +# Instructor's modes determine how structured data is requested from and extracted from LLM responses +# You can change this type (i.e. mode) via this env variable +# Each LLM has its own default value, e.g. gpt-5 models have "json_schema_mode" +LLM_INSTRUCTOR_MODE="" EMBEDDING_PROVIDER="openai" EMBEDDING_MODEL="openai/text-embedding-3-large" @@ -169,8 +173,9 @@ REQUIRE_AUTHENTICATION=False # Vector: LanceDB # Graph: KuzuDB # -# It enforces LanceDB and KuzuDB use and uses them to create databases per Cognee user + dataset -ENABLE_BACKEND_ACCESS_CONTROL=False +# It enforces creation of databases per Cognee user + dataset. Does not work with some graph and vector database providers. +# Disable this mode when using unsupported graph/vector databases. +ENABLE_BACKEND_ACCESS_CONTROL=True ################################################################################ # ☁️ Cloud Sync Settings diff --git a/.github/actions/cognee_setup/action.yml b/.github/actions/cognee_setup/action.yml index 4017d524b..3f5726015 100644 --- a/.github/actions/cognee_setup/action.yml +++ b/.github/actions/cognee_setup/action.yml @@ -42,3 +42,8 @@ runs: done fi uv sync --extra api --extra docs --extra evals --extra codegraph --extra ollama --extra dev --extra neo4j --extra redis $EXTRA_ARGS + + - name: Add telemetry identifier for telemetry test and in case telemetry is enabled by accident + shell: bash + run: | + echo "test-machine" > .anon_id diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 0e6f74188..be9d219c1 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -6,6 +6,14 @@ Please provide a clear, human-generated description of the changes in this PR. DO NOT use AI-generated descriptions. 
We want to understand your thought process and reasoning. --> +## Acceptance Criteria + + ## Type of Change - [ ] Bug fix (non-breaking change that fixes an issue) diff --git a/.github/workflows/basic_tests.yml b/.github/workflows/basic_tests.yml index b7f324310..98ced21dc 100644 --- a/.github/workflows/basic_tests.yml +++ b/.github/workflows/basic_tests.yml @@ -75,6 +75,7 @@ jobs: name: Run Unit Tests runs-on: ubuntu-22.04 env: + ENV: 'dev' LLM_PROVIDER: openai LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} @@ -104,6 +105,7 @@ jobs: name: Run Integration Tests runs-on: ubuntu-22.04 env: + ENV: 'dev' LLM_PROVIDER: openai LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} @@ -132,6 +134,7 @@ jobs: name: Run Simple Examples runs-on: ubuntu-22.04 env: + ENV: 'dev' LLM_PROVIDER: openai LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} @@ -161,6 +164,7 @@ jobs: name: Run Simple Examples BAML runs-on: ubuntu-22.04 env: + ENV: 'dev' STRUCTURED_OUTPUT_FRAMEWORK: "BAML" BAML_LLM_PROVIDER: openai BAML_LLM_MODEL: ${{ secrets.OPENAI_MODEL }} @@ -198,6 +202,7 @@ jobs: name: Run Basic Graph Tests runs-on: ubuntu-22.04 env: + ENV: 'dev' LLM_PROVIDER: openai LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} diff --git a/.github/workflows/cli_tests.yml b/.github/workflows/cli_tests.yml index 958d341ae..d4f8e5ac0 100644 --- a/.github/workflows/cli_tests.yml +++ b/.github/workflows/cli_tests.yml @@ -39,6 +39,7 @@ jobs: name: CLI Unit Tests runs-on: ubuntu-22.04 env: + ENV: 'dev' LLM_PROVIDER: openai LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} @@ -66,6 +67,7 @@ jobs: name: CLI Integration Tests runs-on: ubuntu-22.04 env: + ENV: 'dev' LLM_PROVIDER: openai LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} @@ -93,6 +95,7 @@ jobs: name: CLI Functionality Tests runs-on: ubuntu-22.04 env: + ENV: 'dev' LLM_PROVIDER: 
openai LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} diff --git a/.github/workflows/db_examples_tests.yml b/.github/workflows/db_examples_tests.yml index 51ac9a82a..c58bc48ef 100644 --- a/.github/workflows/db_examples_tests.yml +++ b/.github/workflows/db_examples_tests.yml @@ -60,7 +60,7 @@ jobs: - name: Run Neo4j Example env: - ENV: dev + ENV: 'dev' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} @@ -95,7 +95,7 @@ jobs: - name: Run Kuzu Example env: - ENV: dev + ENV: 'dev' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} @@ -141,7 +141,7 @@ jobs: - name: Run PGVector Example env: - ENV: dev + ENV: 'dev' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml index 70a4b56e6..3dea2548c 100644 --- a/.github/workflows/e2e_tests.yml +++ b/.github/workflows/e2e_tests.yml @@ -226,7 +226,7 @@ jobs: - name: Dependencies already installed run: echo "Dependencies already installed in setup" - - name: Run parallel databases test + - name: Run permissions test env: ENV: 'dev' LLM_MODEL: ${{ secrets.LLM_MODEL }} @@ -239,6 +239,31 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: uv run python ./cognee/tests/test_permissions.py + test-multi-tenancy: + name: Test multi tenancy with different situations in Cognee + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run multi tenancy test + env: + ENV: 'dev' + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ 
secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/test_multi_tenancy.py + test-graph-edges: name: Test graph edge ingestion runs-on: ubuntu-22.04 @@ -308,7 +333,7 @@ jobs: python-version: '3.11.x' extra-dependencies: "postgres redis" - - name: Run Concurrent subprocess access test (Kuzu/Lancedb/Postgres) + - name: Run Concurrent subprocess access test (Kuzu/Lancedb/Postgres/Redis) env: ENV: dev LLM_MODEL: ${{ secrets.LLM_MODEL }} @@ -321,6 +346,7 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} GRAPH_DATABASE_PROVIDER: 'kuzu' CACHING: true + CACHE_BACKEND: 'redis' SHARED_KUZU_LOCK: true DB_PROVIDER: 'postgres' DB_NAME: 'cognee_db' @@ -386,8 +412,8 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: uv run python ./cognee/tests/test_feedback_enrichment.py - run_conversation_sessions_test: - name: Conversation sessions test + run_conversation_sessions_test_redis: + name: Conversation sessions test (Redis) runs-on: ubuntu-latest defaults: run: @@ -427,7 +453,60 @@ jobs: python-version: '3.11.x' extra-dependencies: "postgres redis" - - name: Run Conversation session tests + - name: Run Conversation session tests (Redis) + env: + ENV: 'dev' + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + GRAPH_DATABASE_PROVIDER: 'kuzu' + CACHING: true + CACHE_BACKEND: 'redis' + DB_PROVIDER: 'postgres' + DB_NAME: 'cognee_db' + DB_HOST: '127.0.0.1' + DB_PORT: 5432 + DB_USERNAME: cognee + DB_PASSWORD: cognee + run: uv run 
python ./cognee/tests/test_conversation_history.py + + run_conversation_sessions_test_fs: + name: Conversation sessions test (FS) + runs-on: ubuntu-latest + defaults: + run: + shell: bash + services: + postgres: + image: pgvector/pgvector:pg17 + env: + POSTGRES_USER: cognee + POSTGRES_PASSWORD: cognee + POSTGRES_DB: cognee_db + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + extra-dependencies: "postgres" + + - name: Run Conversation session tests (FS) env: ENV: dev LLM_MODEL: ${{ secrets.LLM_MODEL }} @@ -440,6 +519,7 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} GRAPH_DATABASE_PROVIDER: 'kuzu' CACHING: true + CACHE_BACKEND: 'fs' DB_PROVIDER: 'postgres' DB_NAME: 'cognee_db' DB_HOST: '127.0.0.1' diff --git a/.github/workflows/examples_tests.yml b/.github/workflows/examples_tests.yml index 57bc88157..f7cc278cb 100644 --- a/.github/workflows/examples_tests.yml +++ b/.github/workflows/examples_tests.yml @@ -21,6 +21,7 @@ jobs: - name: Run Multimedia Example env: + ENV: 'dev' LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: uv run python ./examples/python/multimedia_example.py @@ -40,6 +41,7 @@ jobs: - name: Run Evaluation Framework Example env: + ENV: 'dev' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} @@ -69,6 +71,7 @@ jobs: - name: Run Descriptive Graph Metrics Example env: + ENV: 'dev' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} @@ -99,6 +102,7 @@ jobs: - name: Run Dynamic Steps Tests env: + ENV: 'dev' OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ 
secrets.LLM_ENDPOINT }} @@ -124,6 +128,7 @@ jobs: - name: Run Temporal Example env: + ENV: 'dev' OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} @@ -149,6 +154,7 @@ jobs: - name: Run Ontology Demo Example env: + ENV: 'dev' OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} @@ -174,6 +180,7 @@ jobs: - name: Run Agentic Reasoning Example env: + ENV: 'dev' OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} @@ -199,6 +206,7 @@ jobs: - name: Run Memify Tests env: + ENV: 'dev' OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} @@ -210,6 +218,32 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: uv run python ./examples/python/memify_coding_agent_example.py + test-custom-pipeline: + name: Run Custom Pipeline Example + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run Custom Pipeline Example + env: + ENV: 'dev' + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./examples/python/run_custom_pipeline_example.py + test-permissions-example: name: Run Permissions Example runs-on: ubuntu-22.04 @@ -224,6 +258,7 @@ jobs: - name: Run Memify Tests env: + ENV: 'dev' OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 
LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} @@ -249,6 +284,7 @@ jobs: - name: Run Docling Test env: + ENV: 'dev' OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} diff --git a/.github/workflows/load_tests.yml b/.github/workflows/load_tests.yml new file mode 100644 index 000000000..f5b64d8ce --- /dev/null +++ b/.github/workflows/load_tests.yml @@ -0,0 +1,70 @@ +name: Load tests + +permissions: + contents: read + +on: + workflow_dispatch: + workflow_call: + secrets: + LLM_MODEL: + required: true + LLM_ENDPOINT: + required: true + LLM_API_KEY: + required: true + LLM_API_VERSION: + required: true + EMBEDDING_MODEL: + required: true + EMBEDDING_ENDPOINT: + required: true + EMBEDDING_API_KEY: + required: true + EMBEDDING_API_VERSION: + required: true + OPENAI_API_KEY: + required: true + AWS_ACCESS_KEY_ID: + required: true + AWS_SECRET_ACCESS_KEY: + required: true + +jobs: + test-load: + name: Test Load + runs-on: ubuntu-22.04 + timeout-minutes: 60 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + extra-dependencies: "aws" + + - name: Verify File Descriptor Limit + run: ulimit -n + + - name: Run Load Test + env: + ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: True + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + STORAGE_BACKEND: s3 + AWS_REGION: eu-west-1 + AWS_ENDPOINT_URL: https://s3-eu-west-1.amazonaws.com + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_DEV_USER_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ 
secrets.AWS_S3_DEV_USER_SECRET_KEY }} + run: uv run python ./cognee/tests/test_load.py + + diff --git a/.github/workflows/release_test.yml b/.github/workflows/release_test.yml new file mode 100644 index 000000000..6ac3ca515 --- /dev/null +++ b/.github/workflows/release_test.yml @@ -0,0 +1,17 @@ +# Long-running, heavy and resource-consuming tests for release validation +name: Release Test Workflow + +permissions: + contents: read + +on: + workflow_dispatch: + pull_request: + branches: + - main + +jobs: + load-tests: + name: Load Tests + uses: ./.github/workflows/load_tests.yml + secrets: inherit \ No newline at end of file diff --git a/.github/workflows/search_db_tests.yml b/.github/workflows/search_db_tests.yml index e3e46dd97..118c1c06c 100644 --- a/.github/workflows/search_db_tests.yml +++ b/.github/workflows/search_db_tests.yml @@ -84,6 +84,7 @@ jobs: GRAPH_DATABASE_PROVIDER: 'neo4j' VECTOR_DB_PROVIDER: 'lancedb' DB_PROVIDER: 'sqlite' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }} GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }} GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }} @@ -135,6 +136,7 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} GRAPH_DATABASE_PROVIDER: 'kuzu' VECTOR_DB_PROVIDER: 'pgvector' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' DB_PROVIDER: 'postgres' DB_NAME: 'cognee_db' DB_HOST: '127.0.0.1' @@ -197,6 +199,7 @@ jobs: GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }} GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }} GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }} + ENABLE_BACKEND_ACCESS_CONTROL: 'false' DB_NAME: cognee_db DB_HOST: 127.0.0.1 DB_PORT: 5432 diff --git a/.github/workflows/test_different_operating_systems.yml b/.github/workflows/test_different_operating_systems.yml index 64f1a14f9..02651b474 100644 --- a/.github/workflows/test_different_operating_systems.yml +++ 
b/.github/workflows/test_different_operating_systems.yml @@ -10,6 +10,10 @@ on: required: false type: string default: '["3.10.x", "3.12.x", "3.13.x"]' + os: + required: false + type: string + default: '["ubuntu-22.04", "macos-15", "windows-latest"]' secrets: LLM_PROVIDER: required: true @@ -40,10 +44,11 @@ jobs: run-unit-tests: name: Unit tests ${{ matrix.python-version }} on ${{ matrix.os }} runs-on: ${{ matrix.os }} + timeout-minutes: 60 strategy: matrix: python-version: ${{ fromJSON(inputs.python-versions) }} - os: [ubuntu-22.04, macos-15, windows-latest] + os: ${{ fromJSON(inputs.os) }} fail-fast: false steps: - name: Check out @@ -76,10 +81,11 @@ jobs: run-integration-tests: name: Integration tests ${{ matrix.python-version }} on ${{ matrix.os }} runs-on: ${{ matrix.os }} + timeout-minutes: 60 strategy: matrix: python-version: ${{ fromJSON(inputs.python-versions) }} - os: [ ubuntu-22.04, macos-15, windows-latest ] + os: ${{ fromJSON(inputs.os) }} fail-fast: false steps: - name: Check out @@ -112,10 +118,11 @@ jobs: run-library-test: name: Library test ${{ matrix.python-version }} on ${{ matrix.os }} runs-on: ${{ matrix.os }} + timeout-minutes: 60 strategy: matrix: python-version: ${{ fromJSON(inputs.python-versions) }} - os: [ ubuntu-22.04, macos-15, windows-latest ] + os: ${{ fromJSON(inputs.os) }} fail-fast: false steps: - name: Check out @@ -148,10 +155,11 @@ jobs: run-build-test: name: Build test ${{ matrix.python-version }} on ${{ matrix.os }} runs-on: ${{ matrix.os }} + timeout-minutes: 60 strategy: matrix: python-version: ${{ fromJSON(inputs.python-versions) }} - os: [ ubuntu-22.04, macos-15, windows-latest ] + os: ${{ fromJSON(inputs.os) }} fail-fast: false steps: - name: Check out @@ -177,10 +185,11 @@ jobs: run-soft-deletion-test: name: Soft Delete test ${{ matrix.python-version }} on ${{ matrix.os }} runs-on: ${{ matrix.os }} + timeout-minutes: 60 strategy: matrix: python-version: ${{ fromJSON(inputs.python-versions) }} - os: [ ubuntu-22.04, 
macos-15, windows-latest ] + os: ${{ fromJSON(inputs.os) }} fail-fast: false steps: - name: Check out @@ -214,10 +223,11 @@ jobs: run-hard-deletion-test: name: Hard Delete test ${{ matrix.python-version }} on ${{ matrix.os }} runs-on: ${{ matrix.os }} + timeout-minutes: 60 strategy: matrix: python-version: ${{ fromJSON(inputs.python-versions) }} - os: [ ubuntu-22.04, macos-15, windows-latest ] + os: ${{ fromJSON(inputs.os) }} fail-fast: false steps: - name: Check out diff --git a/.github/workflows/test_suites.yml b/.github/workflows/test_suites.yml index 5c1597a93..be1e354fc 100644 --- a/.github/workflows/test_suites.yml +++ b/.github/workflows/test_suites.yml @@ -1,4 +1,6 @@ name: Test Suites +permissions: + contents: read on: push: @@ -80,12 +82,22 @@ jobs: uses: ./.github/workflows/notebooks_tests.yml secrets: inherit - different-operating-systems-tests: - name: Operating System and Python Tests + different-os-tests-basic: + name: OS and Python Tests Ubuntu needs: [basic-tests, e2e-tests] uses: ./.github/workflows/test_different_operating_systems.yml with: python-versions: '["3.10.x", "3.11.x", "3.12.x", "3.13.x"]' + os: '["ubuntu-22.04"]' + secrets: inherit + + different-os-tests-extended: + name: OS and Python Tests Extended + needs: [basic-tests, e2e-tests] + uses: ./.github/workflows/test_different_operating_systems.yml + with: + python-versions: '["3.13.x"]' + os: '["macos-15", "windows-latest"]' secrets: inherit # Matrix-based vector database tests @@ -135,7 +147,8 @@ jobs: e2e-tests, graph-db-tests, notebook-tests, - different-operating-systems-tests, + different-os-tests-basic, + different-os-tests-extended, vector-db-tests, example-tests, llm-tests, @@ -155,7 +168,8 @@ jobs: cli-tests, graph-db-tests, notebook-tests, - different-operating-systems-tests, + different-os-tests-basic, + different-os-tests-extended, vector-db-tests, example-tests, db-examples-tests, @@ -176,7 +190,8 @@ jobs: "${{ needs.cli-tests.result }}" == "success" && "${{ 
needs.graph-db-tests.result }}" == "success" && "${{ needs.notebook-tests.result }}" == "success" && - "${{ needs.different-operating-systems-tests.result }}" == "success" && + "${{ needs.different-os-tests-basic.result }}" == "success" && + "${{ needs.different-os-tests-extended.result }}" == "success" && "${{ needs.vector-db-tests.result }}" == "success" && "${{ needs.example-tests.result }}" == "success" && "${{ needs.db-examples-tests.result }}" == "success" && diff --git a/.github/workflows/weighted_edges_tests.yml b/.github/workflows/weighted_edges_tests.yml index 874ef6ea4..2b4a043bf 100644 --- a/.github/workflows/weighted_edges_tests.yml +++ b/.github/workflows/weighted_edges_tests.yml @@ -2,7 +2,7 @@ name: Weighted Edges Tests on: push: - branches: [ main, weighted_edges ] + branches: [ main, dev, weighted_edges ] paths: - 'cognee/modules/graph/utils/get_graph_from_model.py' - 'cognee/infrastructure/engine/models/Edge.py' @@ -10,7 +10,7 @@ on: - 'examples/python/weighted_edges_example.py' - '.github/workflows/weighted_edges_tests.yml' pull_request: - branches: [ main ] + branches: [ main, dev ] paths: - 'cognee/modules/graph/utils/get_graph_from_model.py' - 'cognee/infrastructure/engine/models/Edge.py' @@ -32,7 +32,7 @@ jobs: env: LLM_PROVIDER: openai LLM_MODEL: gpt-5-mini - LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} steps: - name: Check out repository @@ -67,14 +67,13 @@ jobs: env: LLM_PROVIDER: openai LLM_MODEL: gpt-5-mini - LLM_ENDPOINT: https://api.openai.com/v1/ - LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_ENDPOINT: https://api.openai.com/v1 + LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} LLM_API_VERSION: "2024-02-01" - EMBEDDING_PROVIDER: openai - EMBEDDING_MODEL: text-embedding-3-small - EMBEDDING_ENDPOINT: https://api.openai.com/v1/ - EMBEDDING_API_KEY: ${{ secrets.LLM_API_KEY }} - EMBEDDING_API_VERSION: "2024-02-01" + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ 
secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} steps: - name: Check out repository uses: actions/checkout@v4 @@ -108,14 +107,14 @@ jobs: env: LLM_PROVIDER: openai LLM_MODEL: gpt-5-mini - LLM_ENDPOINT: https://api.openai.com/v1/ - LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_ENDPOINT: https://api.openai.com/v1 + LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} LLM_API_VERSION: "2024-02-01" - EMBEDDING_PROVIDER: openai - EMBEDDING_MODEL: text-embedding-3-small - EMBEDDING_ENDPOINT: https://api.openai.com/v1/ - EMBEDDING_API_KEY: ${{ secrets.LLM_API_KEY }} - EMBEDDING_API_VERSION: "2024-02-01" + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + steps: - name: Check out repository uses: actions/checkout@v4 diff --git a/alembic/env.py b/alembic/env.py index 1cbef65f7..8ca09968d 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -87,11 +87,6 @@ db_engine = get_relational_engine() print("Using database:", db_engine.db_uri) -if "sqlite" in db_engine.db_uri: - from cognee.infrastructure.utils.run_sync import run_sync - - run_sync(db_engine.create_database()) - config.set_section_option( config.config_ini_section, "SQLALCHEMY_DATABASE_URI", diff --git a/alembic/versions/211ab850ef3d_add_sync_operations_table.py b/alembic/versions/211ab850ef3d_add_sync_operations_table.py index 370aab1a4..30049b44b 100644 --- a/alembic/versions/211ab850ef3d_add_sync_operations_table.py +++ b/alembic/versions/211ab850ef3d_add_sync_operations_table.py @@ -10,6 +10,7 @@ from typing import Sequence, Union from alembic import op import sqlalchemy as sa +from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
@@ -26,7 +27,34 @@ def upgrade() -> None: connection = op.get_bind() inspector = sa.inspect(connection) + if op.get_context().dialect.name == "postgresql": + syncstatus_enum = postgresql.ENUM( + "STARTED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED", name="syncstatus" + ) + syncstatus_enum.create(op.get_bind(), checkfirst=True) + if "sync_operations" not in inspector.get_table_names(): + if op.get_context().dialect.name == "postgresql": + syncstatus = postgresql.ENUM( + "STARTED", + "IN_PROGRESS", + "COMPLETED", + "FAILED", + "CANCELLED", + name="syncstatus", + create_type=False, + ) + else: + syncstatus = sa.Enum( + "STARTED", + "IN_PROGRESS", + "COMPLETED", + "FAILED", + "CANCELLED", + name="syncstatus", + create_type=False, + ) + # Table doesn't exist, create it normally op.create_table( "sync_operations", @@ -34,15 +62,7 @@ def upgrade() -> None: sa.Column("run_id", sa.Text(), nullable=True), sa.Column( "status", - sa.Enum( - "STARTED", - "IN_PROGRESS", - "COMPLETED", - "FAILED", - "CANCELLED", - name="syncstatus", - create_type=False, - ), + syncstatus, nullable=True, ), sa.Column("progress_percentage", sa.Integer(), nullable=True), diff --git a/alembic/versions/482cd6517ce4_add_default_user.py b/alembic/versions/482cd6517ce4_add_default_user.py index d85f0f146..c8a3dc5d5 100644 --- a/alembic/versions/482cd6517ce4_add_default_user.py +++ b/alembic/versions/482cd6517ce4_add_default_user.py @@ -23,11 +23,8 @@ depends_on: Union[str, Sequence[str], None] = "8057ae7329c2" def upgrade() -> None: - try: - await_only(create_default_user()) - except UserAlreadyExists: - pass # It's fine if the default user already exists + pass def downgrade() -> None: - await_only(delete_user("default_user@example.com")) + pass diff --git a/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py b/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py new file mode 100644 index 000000000..7e13898ae --- /dev/null +++ 
b/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py @@ -0,0 +1,98 @@ +"""Expand dataset database for multi user + +Revision ID: 76625596c5c3 +Revises: c946955da633 +Create Date: 2025-10-30 12:55:20.239562 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "76625596c5c3" +down_revision: Union[str, None] = "c946955da633" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def _get_column(inspector, table, name, schema=None): + for col in inspector.get_columns(table, schema=schema): + if col["name"] == name: + return col + return None + + +def upgrade() -> None: + conn = op.get_bind() + insp = sa.inspect(conn) + + vector_database_provider_column = _get_column( + insp, "dataset_database", "vector_database_provider" + ) + if not vector_database_provider_column: + op.add_column( + "dataset_database", + sa.Column( + "vector_database_provider", + sa.String(), + unique=False, + nullable=False, + server_default="lancedb", + ), + ) + + graph_database_provider_column = _get_column( + insp, "dataset_database", "graph_database_provider" + ) + if not graph_database_provider_column: + op.add_column( + "dataset_database", + sa.Column( + "graph_database_provider", + sa.String(), + unique=False, + nullable=False, + server_default="kuzu", + ), + ) + + vector_database_url_column = _get_column(insp, "dataset_database", "vector_database_url") + if not vector_database_url_column: + op.add_column( + "dataset_database", + sa.Column("vector_database_url", sa.String(), unique=False, nullable=True), + ) + + graph_database_url_column = _get_column(insp, "dataset_database", "graph_database_url") + if not graph_database_url_column: + op.add_column( + "dataset_database", + sa.Column("graph_database_url", sa.String(), unique=False, nullable=True), + ) + + vector_database_key_column = _get_column(insp, 
"dataset_database", "vector_database_key") + if not vector_database_key_column: + op.add_column( + "dataset_database", + sa.Column("vector_database_key", sa.String(), unique=False, nullable=True), + ) + + graph_database_key_column = _get_column(insp, "dataset_database", "graph_database_key") + if not graph_database_key_column: + op.add_column( + "dataset_database", + sa.Column("graph_database_key", sa.String(), unique=False, nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("dataset_database", "vector_database_provider") + op.drop_column("dataset_database", "graph_database_provider") + op.drop_column("dataset_database", "vector_database_url") + op.drop_column("dataset_database", "graph_database_url") + op.drop_column("dataset_database", "vector_database_key") + op.drop_column("dataset_database", "graph_database_key") diff --git a/alembic/versions/8057ae7329c2_initial_migration.py b/alembic/versions/8057ae7329c2_initial_migration.py index aa0ecd4b8..42e9904a8 100644 --- a/alembic/versions/8057ae7329c2_initial_migration.py +++ b/alembic/versions/8057ae7329c2_initial_migration.py @@ -18,11 +18,8 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: - db_engine = get_relational_engine() - # we might want to delete this - await_only(db_engine.create_database()) + pass def downgrade() -> None: - db_engine = get_relational_engine() - await_only(db_engine.delete_database()) + pass diff --git a/alembic/versions/ab7e313804ae_permission_system_rework.py b/alembic/versions/ab7e313804ae_permission_system_rework.py index bd69b9b41..d83f946a6 100644 --- a/alembic/versions/ab7e313804ae_permission_system_rework.py +++ b/alembic/versions/ab7e313804ae_permission_system_rework.py @@ -144,44 +144,58 @@ def _create_data_permission(conn, user_id, data_id, permission_name): ) +def _get_column(inspector, table, name, schema=None): + for col in inspector.get_columns(table, schema=schema): + if col["name"] == name: + return col + return None + + def 
upgrade() -> None: conn = op.get_bind() + insp = sa.inspect(conn) - # Recreate ACLs table with default permissions set to datasets instead of documents - op.drop_table("acls") + dataset_id_column = _get_column(insp, "acls", "dataset_id") + if not dataset_id_column: + # Recreate ACLs table with default permissions set to datasets instead of documents + op.drop_table("acls") - acls_table = op.create_table( - "acls", - sa.Column("id", UUID, primary_key=True, default=uuid4), - sa.Column( - "created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) - ), - sa.Column( - "updated_at", sa.DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc) - ), - sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")), - sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")), - sa.Column("dataset_id", UUID, sa.ForeignKey("datasets.id", ondelete="CASCADE")), - ) + acls_table = op.create_table( + "acls", + sa.Column("id", UUID, primary_key=True, default=uuid4), + sa.Column( + "created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + onupdate=lambda: datetime.now(timezone.utc), + ), + sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")), + sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")), + sa.Column("dataset_id", UUID, sa.ForeignKey("datasets.id", ondelete="CASCADE")), + ) - # Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table - # definition or load what is in the database - dataset_table = _define_dataset_table() - datasets = conn.execute(sa.select(dataset_table)).fetchall() + # Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table + # definition or load what is in the database + dataset_table = _define_dataset_table() + datasets = conn.execute(sa.select(dataset_table)).fetchall() - if 
not datasets: - return + if not datasets: + return - acl_list = [] + acl_list = [] - for dataset in datasets: - acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "read")) - acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "write")) - acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "share")) - acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "delete")) + for dataset in datasets: + acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "read")) + acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "write")) + acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "share")) + acl_list.append( + _create_dataset_permission(conn, dataset.owner_id, dataset.id, "delete") + ) - if acl_list: - op.bulk_insert(acls_table, acl_list) + if acl_list: + op.bulk_insert(acls_table, acl_list) def downgrade() -> None: diff --git a/alembic/versions/c946955da633_multi_tenant_support.py b/alembic/versions/c946955da633_multi_tenant_support.py new file mode 100644 index 000000000..d8fccdfbf --- /dev/null +++ b/alembic/versions/c946955da633_multi_tenant_support.py @@ -0,0 +1,137 @@ +"""Multi Tenant Support + +Revision ID: c946955da633 +Revises: 211ab850ef3d +Create Date: 2025-11-04 18:11:09.325158 + +""" + +from typing import Sequence, Union +from datetime import datetime, timezone +from uuid import uuid4 + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
def upgrade() -> None:
    """Add multi-tenant support to the schema.

    Two independent, inspection-guarded steps are performed:

    * ``user_tenants`` link table: created when it does not exist yet and
      seeded with one ``(user_id, tenant_id)`` row for every user whose
      ``users.tenant_id`` is not NULL.
    * ``datasets.tenant_id`` column: added when it does not exist yet,
      backfilled with the ``tenant_id`` of each dataset's owner and indexed as
      ``ix_datasets_tenant_id``.

    Because both steps first check the live schema, running the migration
    against a database that already has the table / column skips that step.
    """
    conn = op.get_bind()
    insp = sa.inspect(conn)

    # Local, minimal table definitions: the migration must not rely on the
    # application's models, which may change after this revision was written.
    dataset = _define_dataset_table()
    user = _define_user_table()

    if "user_tenants" not in insp.get_table_names():
        # Composite primary key (user_id, tenant_id): one row per membership.
        user_tenants = op.create_table(
            "user_tenants",
            sa.Column("user_id", sa.UUID, sa.ForeignKey("users.id"), primary_key=True),
            sa.Column("tenant_id", sa.UUID, sa.ForeignKey("tenants.id"), primary_key=True),
            sa.Column("created_at", sa.DateTime(timezone=True), default=_now),
        )

        # Carry over the existing single-tenant assignment of every user.
        user_rows = conn.execute(
            sa.select(user.c.id, user.c.tenant_id).where(user.c.tenant_id.isnot(None))
        ).fetchall()

        if user_rows:
            op.bulk_insert(
                user_tenants,
                [
                    {"user_id": user_id, "tenant_id": tenant_id, "created_at": _now()}
                    for user_id, tenant_id in user_rows
                ],
            )

    if not _get_column(insp, "datasets", "tenant_id"):
        op.add_column("datasets", sa.Column("tenant_id", sa.UUID(), nullable=True))

        # Correlated scalar subquery: users.tenant_id of the row whose id equals
        # the dataset's owner_id.
        tenant_id_from_dataset_owner = (
            sa.select(user.c.tenant_id).where(user.c.id == dataset.c.owner_id).scalar_subquery()
        )
        backfill = dataset.update().values(tenant_id=tenant_id_from_dataset_owner)

        if op.get_context().dialect.name == "sqlite":
            # On SQLite the backfill is issued through a batch operation on
            # "datasets".
            # NOTE(review): presumably kept in batch mode because of SQLite's
            # limited ALTER TABLE support - confirm a plain execute would not do.
            with op.batch_alter_table("datasets") as batch_op:
                batch_op.execute(backfill)
        else:
            conn.execute(backfill)

        op.create_index(op.f("ix_datasets_tenant_id"), "datasets", ["tenant_id"])
### + op.drop_table("user_tenants") + op.drop_index(op.f("ix_datasets_tenant_id"), table_name="datasets") + op.drop_column("datasets", "tenant_id") + # ### end Alembic commands ### diff --git a/cognee-mcp/src/server.py b/cognee-mcp/src/server.py index ce6dad88a..4131be988 100755 --- a/cognee-mcp/src/server.py +++ b/cognee-mcp/src/server.py @@ -194,7 +194,6 @@ async def cognify( Prerequisites: - **LLM_API_KEY**: Must be configured (required for entity extraction and graph generation) - - **Data Added**: Must have data previously added via `cognee.add()` - **Vector Database**: Must be accessible for embeddings storage - **Graph Database**: Must be accessible for relationship storage @@ -1096,6 +1095,10 @@ async def main(): # Skip migrations when in API mode (the API server handles its own database) if not args.no_migration and not args.api_url: + from cognee.modules.engine.operations.setup import setup + + await setup() + # Run Alembic migrations from the main cognee directory where alembic.ini is located logger.info("Running database migrations...") migration_result = subprocess.run( diff --git a/cognee/__init__.py b/cognee/__init__.py index 6e4d2a903..4d150ce4e 100644 --- a/cognee/__init__.py +++ b/cognee/__init__.py @@ -19,6 +19,7 @@ from .api.v1.add import add from .api.v1.delete import delete from .api.v1.cognify import cognify from .modules.memify import memify +from .modules.run_custom_pipeline import run_custom_pipeline from .api.v1.update import update from .api.v1.config.config import config from .api.v1.datasets.datasets import datasets diff --git a/cognee/api/client.py b/cognee/api/client.py index 6766c12de..1a08aed56 100644 --- a/cognee/api/client.py +++ b/cognee/api/client.py @@ -23,6 +23,7 @@ from cognee.api.v1.settings.routers import get_settings_router from cognee.api.v1.datasets.routers import get_datasets_router from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router from cognee.api.v1.search.routers import 
get_search_router +from cognee.api.v1.ontologies.routers.get_ontology_router import get_ontology_router from cognee.api.v1.memify.routers import get_memify_router from cognee.api.v1.add.routers import get_add_router from cognee.api.v1.delete.routers import get_delete_router @@ -39,6 +40,8 @@ from cognee.api.v1.users.routers import ( ) from cognee.modules.users.methods.get_authenticated_user import REQUIRE_AUTHENTICATION +# Ensure application logging is configured for container stdout/stderr +setup_logging() logger = get_logger() if os.getenv("ENV", "prod") == "prod": @@ -74,6 +77,9 @@ async def lifespan(app: FastAPI): await get_default_user() + # Emit a clear startup message for docker logs + logger.info("Backend server has started") + yield @@ -258,6 +264,8 @@ app.include_router( app.include_router(get_datasets_router(), prefix="/api/v1/datasets", tags=["datasets"]) +app.include_router(get_ontology_router(), prefix="/api/v1/ontologies", tags=["ontologies"]) + app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["settings"]) app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"]) diff --git a/cognee/api/v1/add/routers/get_add_router.py b/cognee/api/v1/add/routers/get_add_router.py index b2e7068b0..39dc1a3e6 100644 --- a/cognee/api/v1/add/routers/get_add_router.py +++ b/cognee/api/v1/add/routers/get_add_router.py @@ -82,7 +82,9 @@ def get_add_router() -> APIRouter: datasetName, user=user, dataset_id=datasetId, - node_set=node_set if node_set else None, + node_set=node_set + if node_set != [""] + else None, # Transform default node_set endpoint value to None ) if isinstance(add_run, PipelineRunErrored): diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py index 231bbcd11..4f1497e3c 100644 --- a/cognee/api/v1/cognify/routers/get_cognify_router.py +++ b/cognee/api/v1/cognify/routers/get_cognify_router.py @@ -41,6 +41,9 @@ class 
CognifyPayloadDTO(InDTO): custom_prompt: Optional[str] = Field( default="", description="Custom prompt for entity extraction and graph generation" ) + ontology_key: Optional[List[str]] = Field( + default=None, description="Reference to one or more previously uploaded ontologies" + ) def get_cognify_router() -> APIRouter: @@ -68,6 +71,7 @@ def get_cognify_router() -> APIRouter: - **dataset_ids** (Optional[List[UUID]]): List of existing dataset UUIDs to process. UUIDs allow processing of datasets not owned by the user (if permitted). - **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking). - **custom_prompt** (Optional[str]): Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts for knowledge graph extraction. + - **ontology_key** (Optional[List[str]]): Reference to one or more previously uploaded ontology files to use for knowledge graph construction. ## Response - **Blocking execution**: Complete pipeline run information with entity counts, processing duration, and success/failure status @@ -82,7 +86,8 @@ def get_cognify_router() -> APIRouter: { "datasets": ["research_papers", "documentation"], "run_in_background": false, - "custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections." + "custom_prompt": "Extract entities focusing on technical concepts and their relationships. 
@dataclass
class OntologyMetadata:
    """Metadata recorded for one uploaded ontology file."""

    # User-chosen identifier the ontology is stored and looked up under.
    ontology_key: str
    # File name as supplied by the uploading client.
    filename: str
    # Size of the uploaded content in bytes.
    size_bytes: int
    # ISO-8601 UTC timestamp taken when the upload was stored.
    uploaded_at: str
    # Optional free-text description supplied with the upload.
    description: Optional[str] = None


class OntologyService:
    """Per-user storage of uploaded ``.owl`` ontology files on the local filesystem.

    Files live under ``<system temp dir>/ontologies/<user id>/<ontology_key>.owl``;
    a ``metadata.json`` file in the same directory maps every key to its
    filename, size, upload timestamp and description.

    All validation failures are reported as ``ValueError``.
    """

    # Upper bound for a single uploaded ontology file (10 MiB).
    MAX_FILE_SIZE_BYTES = 10 * 1024 * 1024

    def __init__(self):
        pass

    @property
    def base_dir(self) -> Path:
        """Root directory below which every user's ontology directory is created."""
        return Path(tempfile.gettempdir()) / "ontologies"

    def _get_user_dir(self, user_id: str) -> Path:
        """Return the user's ontology directory, creating it if necessary."""
        user_dir = self.base_dir / str(user_id)
        user_dir.mkdir(parents=True, exist_ok=True)
        return user_dir

    def _get_metadata_path(self, user_dir: Path) -> Path:
        """Return the path of the metadata index inside ``user_dir``."""
        return user_dir / "metadata.json"

    def _load_metadata(self, user_dir: Path) -> dict:
        """Load the key -> metadata mapping; an absent index yields ``{}``."""
        metadata_path = self._get_metadata_path(user_dir)
        if metadata_path.exists():
            with open(metadata_path, "r", encoding="utf-8") as f:
                return json.load(f)
        return {}

    def _save_metadata(self, user_dir: Path, metadata: dict):
        """Overwrite the metadata index of ``user_dir`` with ``metadata``."""
        metadata_path = self._get_metadata_path(user_dir)
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

    @staticmethod
    def _validate_key(ontology_key) -> None:
        """Reject keys that could escape the user's directory when used as a file name.

        The key is caller-supplied and is interpolated into ``<key>.owl``, so it
        must be a non-empty string without path separators, NUL bytes or the
        special names ``.`` / ``..``.

        Raises:
            ValueError: If the key is not a safe file-name component.
        """
        if (
            not isinstance(ontology_key, str)
            or ontology_key in ("", ".", "..")
            or any(forbidden in ontology_key for forbidden in ("/", "\\", "\x00"))
        ):
            raise ValueError(f"Invalid ontology key '{ontology_key}'")

    def _store(
        self,
        user_dir: Path,
        metadata: dict,
        key: str,
        filename: str,
        content: bytes,
        description: Optional[str],
    ) -> OntologyMetadata:
        """Write one validated ontology to disk and register it in ``metadata``.

        ``metadata`` is mutated in place; persisting it is left to the caller so
        a batch upload saves the index only once.
        """
        with open(user_dir / f"{key}.owl", "wb") as f:
            f.write(content)

        entry = {
            "filename": filename,
            "size_bytes": len(content),
            "uploaded_at": datetime.now(timezone.utc).isoformat(),
            "description": description,
        }
        metadata[key] = entry

        return OntologyMetadata(
            ontology_key=key,
            filename=filename,
            size_bytes=len(content),
            uploaded_at=entry["uploaded_at"],
            description=description,
        )

    async def upload_ontology(
        self, ontology_key: str, file, user, description: Optional[str] = None
    ) -> OntologyMetadata:
        """
        Upload a single ontology file under the given key.

        Args:
            ontology_key: Unique key for the ontology (used as the stored file name)
            file: Upload object exposing ``filename`` and an awaitable ``read()``
            user: Authenticated user (``user.id`` selects the storage directory)
            description: Optional description of the ontology

        Returns:
            OntologyMetadata describing the stored file

        Raises:
            ValueError: If the file is not ``.owl``, the key is invalid or already
                exists, or the content exceeds the 10MB limit
        """
        # ``filename`` may be None on some upload objects; treat it as "no extension".
        if not (file.filename or "").lower().endswith(".owl"):
            raise ValueError("File must be in .owl format")

        self._validate_key(ontology_key)

        user_dir = self._get_user_dir(str(user.id))
        metadata = self._load_metadata(user_dir)

        if ontology_key in metadata:
            raise ValueError(f"Ontology key '{ontology_key}' already exists")

        content = await file.read()
        if len(content) > self.MAX_FILE_SIZE_BYTES:
            raise ValueError("File size exceeds 10MB limit")

        result = self._store(user_dir, metadata, ontology_key, file.filename, content, description)
        self._save_metadata(user_dir, metadata)
        return result

    async def upload_ontologies(
        self, ontology_key: List[str], files: List, user, descriptions: Optional[List[str]] = None
    ) -> List[OntologyMetadata]:
        """
        Upload ontology files with their respective keys.

        The whole batch is validated before anything is written, so a rejected
        batch leaves neither files nor metadata entries behind.

        Args:
            ontology_key: List of unique keys for each ontology
            files: List of UploadFile objects (same length as keys)
            user: Authenticated user
            descriptions: Optional list of descriptions for each file

        Returns:
            List of OntologyMetadata objects for uploaded files

        Raises:
            ValueError: If keys are invalid or duplicate, file format invalid, a file
                exceeds the 10MB limit, or array lengths don't match
        """
        if len(ontology_key) != len(files):
            raise ValueError("Number of keys must match number of files")

        if len(set(ontology_key)) != len(ontology_key):
            raise ValueError("Duplicate ontology keys not allowed")

        if descriptions and len(descriptions) != len(files):
            raise ValueError("Number of descriptions must match number of files")

        user_dir = self._get_user_dir(str(user.id))
        metadata = self._load_metadata(user_dir)

        # Phase 1: validate and read everything; nothing touches the disk yet.
        # Holds at most len(files) * 10MB in memory.
        pending = []
        for i, (key, file) in enumerate(zip(ontology_key, files)):
            self._validate_key(key)

            if key in metadata:
                raise ValueError(f"Ontology key '{key}' already exists")

            if not (file.filename or "").lower().endswith(".owl"):
                raise ValueError(f"File '{file.filename}' must be in .owl format")

            content = await file.read()
            if len(content) > self.MAX_FILE_SIZE_BYTES:
                raise ValueError(f"File '{file.filename}' exceeds 10MB limit")

            pending.append(
                (key, file.filename, content, descriptions[i] if descriptions else None)
            )

        # Phase 2: the batch is known to be valid, store it and save the index once.
        results = [
            self._store(user_dir, metadata, key, filename, content, description)
            for key, filename, content, description in pending
        ]

        self._save_metadata(user_dir, metadata)
        return results

    def get_ontology_contents(self, ontology_key: List[str], user) -> List[str]:
        """
        Retrieve ontology content for one or more keys.

        Args:
            ontology_key: List of ontology keys to retrieve (can contain single item)
            user: Authenticated user

        Returns:
            List of ontology content strings, in the order of the requested keys

        Raises:
            ValueError: If any ontology key is not found or is not a safe file name
        """
        user_dir = self._get_user_dir(str(user.id))
        metadata = self._load_metadata(user_dir)

        contents = []
        for key in ontology_key:
            if key not in metadata:
                raise ValueError(f"Ontology key '{key}' not found")

            # Guards against unsafe keys that an older metadata index may contain.
            self._validate_key(key)

            file_path = user_dir / f"{key}.owl"
            if not file_path.exists():
                raise ValueError(f"Ontology file for key '{key}' not found")

            with open(file_path, "r", encoding="utf-8") as f:
                contents.append(f.read())
        return contents

    def list_ontologies(self, user) -> dict:
        """Return the key -> metadata mapping of every ontology the user uploaded."""
        user_dir = self._get_user_dir(str(user.id))
        return self._load_metadata(user_dir)
Form(...), + ontology_file: List[UploadFile] = File(...), + descriptions: Optional[str] = Form(None), + user: User = Depends(get_authenticated_user), + ): + """ + Upload ontology files with their respective keys for later use in cognify operations. + + Supports both single and multiple file uploads: + - Single file: ontology_key=["key"], ontology_file=[file] + - Multiple files: ontology_key=["key1", "key2"], ontology_file=[file1, file2] + + ## Request Parameters + - **ontology_key** (str): JSON array string of user-defined identifiers for the ontologies + - **ontology_file** (List[UploadFile]): OWL format ontology files + - **descriptions** (Optional[str]): JSON array string of optional descriptions + + ## Response + Returns metadata about uploaded ontologies including keys, filenames, sizes, and upload timestamps. + + ## Error Codes + - **400 Bad Request**: Invalid file format, duplicate keys, array length mismatches, file size exceeded + - **500 Internal Server Error**: File system or processing errors + """ + send_telemetry( + "Ontology Upload API Endpoint Invoked", + user.id, + additional_properties={ + "endpoint": "POST /api/v1/ontologies", + "cognee_version": cognee_version, + }, + ) + + try: + import json + + ontology_keys = json.loads(ontology_key) + description_list = json.loads(descriptions) if descriptions else None + + if not isinstance(ontology_keys, list): + raise ValueError("ontology_key must be a JSON array") + + results = await ontology_service.upload_ontologies( + ontology_keys, ontology_file, user, description_list + ) + + return { + "uploaded_ontologies": [ + { + "ontology_key": result.ontology_key, + "filename": result.filename, + "size_bytes": result.size_bytes, + "uploaded_at": result.uploaded_at, + "description": result.description, + } + for result in results + ] + } + except (json.JSONDecodeError, ValueError) as e: + return JSONResponse(status_code=400, content={"error": str(e)}) + except Exception as e: + return 
JSONResponse(status_code=500, content={"error": str(e)}) + + @router.get("", response_model=dict) + async def list_ontologies(user: User = Depends(get_authenticated_user)): + """ + List all uploaded ontologies for the authenticated user. + + ## Response + Returns a dictionary mapping ontology keys to their metadata including filename, size, and upload timestamp. + + ## Error Codes + - **500 Internal Server Error**: File system or processing errors + """ + send_telemetry( + "Ontology List API Endpoint Invoked", + user.id, + additional_properties={ + "endpoint": "GET /api/v1/ontologies", + "cognee_version": cognee_version, + }, + ) + + try: + metadata = ontology_service.list_ontologies(user) + return metadata + except Exception as e: + return JSONResponse(status_code=500, content={"error": str(e)}) + + return router diff --git a/cognee/api/v1/permissions/routers/get_permissions_router.py b/cognee/api/v1/permissions/routers/get_permissions_router.py index 565e95732..63de97eaa 100644 --- a/cognee/api/v1/permissions/routers/get_permissions_router.py +++ b/cognee/api/v1/permissions/routers/get_permissions_router.py @@ -1,15 +1,20 @@ from uuid import UUID -from typing import List +from typing import List, Union from fastapi import APIRouter, Depends from fastapi.responses import JSONResponse from cognee.modules.users.models import User +from cognee.api.DTO import InDTO from cognee.modules.users.methods import get_authenticated_user from cognee.shared.utils import send_telemetry from cognee import __version__ as cognee_version +class SelectTenantDTO(InDTO): + tenant_id: UUID | None = None + + def get_permissions_router() -> APIRouter: permissions_router = APIRouter() @@ -226,4 +231,39 @@ def get_permissions_router() -> APIRouter: status_code=200, content={"message": "Tenant created.", "tenant_id": str(tenant_id)} ) + @permissions_router.post("/tenants/select") + async def select_tenant(payload: SelectTenantDTO, user: User = Depends(get_authenticated_user)): + """ + Select 
current tenant. + + This endpoint selects a tenant with the specified UUID. Tenants are used + to organize users and resources in multi-tenant environments, providing + isolation and access control between different groups or organizations. + + Sending a null/None value as tenant_id selects his default single user tenant + + ## Request Parameters + - **tenant_id** (Union[UUID, None]): UUID of the tenant to select, If null/None is provided use the default single user tenant + + ## Response + Returns a success message along with selected tenant id. + """ + send_telemetry( + "Permissions API Endpoint Invoked", + user.id, + additional_properties={ + "endpoint": f"POST /v1/permissions/tenants/{str(payload.tenant_id)}", + "tenant_id": str(payload.tenant_id), + }, + ) + + from cognee.modules.users.tenants.methods import select_tenant as select_tenant_method + + await select_tenant_method(user_id=user.id, tenant_id=payload.tenant_id) + + return JSONResponse( + status_code=200, + content={"message": "Tenant selected.", "tenant_id": str(payload.tenant_id)}, + ) + return permissions_router diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index d4e5fbbe6..354331c57 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -31,6 +31,8 @@ async def search( only_context: bool = False, use_combined_context: bool = False, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Union[List[SearchResult], CombinedSearchResult]: """ Search and query the knowledge graph for insights, information, and connections. 
@@ -200,6 +202,8 @@ async def search( only_context=only_context, use_combined_context=use_combined_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) return filtered_search_results diff --git a/cognee/cli/commands/cognify_command.py b/cognee/cli/commands/cognify_command.py index 16eaf0454..b89c1f70e 100644 --- a/cognee/cli/commands/cognify_command.py +++ b/cognee/cli/commands/cognify_command.py @@ -22,7 +22,7 @@ relationships, and creates semantic connections for enhanced search and reasonin Processing Pipeline: 1. **Document Classification**: Identifies document types and structures -2. **Permission Validation**: Ensures user has processing rights +2. **Permission Validation**: Ensures user has processing rights 3. **Text Chunking**: Breaks content into semantically meaningful segments 4. **Entity Extraction**: Identifies key concepts, people, places, organizations 5. **Relationship Detection**: Discovers connections between entities @@ -97,6 +97,13 @@ After successful cognify processing, use `cognee search` to query the knowledge chunker_class = LangchainChunker except ImportError: fmt.warning("LangchainChunker not available, using TextChunker") + elif args.chunker == "CsvChunker": + try: + from cognee.modules.chunking.CsvChunker import CsvChunker + + chunker_class = CsvChunker + except ImportError: + fmt.warning("CsvChunker not available, using TextChunker") result = await cognee.cognify( datasets=datasets, diff --git a/cognee/cli/config.py b/cognee/cli/config.py index d016608c1..082adbaec 100644 --- a/cognee/cli/config.py +++ b/cognee/cli/config.py @@ -26,7 +26,7 @@ SEARCH_TYPE_CHOICES = [ ] # Chunker choices -CHUNKER_CHOICES = ["TextChunker", "LangchainChunker"] +CHUNKER_CHOICES = ["TextChunker", "LangchainChunker", "CsvChunker"] # Output format choices OUTPUT_FORMAT_CHOICES = ["json", "pretty", "simple"] diff --git a/cognee/context_global_variables.py 
def backend_access_control_enabled():
    """Decide whether per-dataset backend access control is active.

    The decision is driven by the ``ENABLE_BACKEND_ACCESS_CONTROL`` environment
    variable:

    * unset: enabled exactly when ``multi_user_support_possible()`` is true;
    * ``"true"`` (any letter case): enabled, but only if
      ``multi_user_support_possible()`` is true, otherwise an error is raised;
    * any other value: disabled.

    Returns:
        bool: True when backend access control should be enforced.

    Raises:
        EnvironmentError: If the variable requests access control while
            ``multi_user_support_possible()`` reports it cannot be supported.
    """
    raw_setting = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL")

    # Not configured at all: follow the capabilities of the configured databases.
    if raw_setting is None:
        return multi_user_support_possible()

    # Anything other than an explicit "true" switches the feature off.
    if raw_setting.lower() != "true":
        return False

    # Explicitly requested: refuse to continue on unsupported databases.
    if multi_user_support_possible():
        return True

    raise EnvironmentError(
        "ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
    )
# Image for the evaluation framework: Python 3.11 with the project's
# dependencies (extras: distributed, evals, deepeval) and the cognee /
# distributed sources under /app.
FROM python:3.11-slim

# Set environment variables
# PIP_NO_CACHE_DIR: do not keep pip's download cache in the image.
ENV PIP_NO_CACHE_DIR=true
# NOTE(review): Poetry is installed with pip below, so this extra PATH entry
# looks unnecessary - confirm before removing.
ENV PATH="${PATH}:/root/.poetry/bin"
# Sources are copied to /app (see the COPY lines at the end) and imported from there.
ENV PYTHONPATH=/app
# NOTE(review): presumably read by the application at start-up to skip
# database migrations; the consumer is not visible here - confirm.
ENV SKIP_MIGRATIONS=true

# System dependencies
# Compiler toolchain, libpq headers, git and curl; the apt package lists are
# removed in the same layer.
RUN apt-get update && apt-get install -y \
    gcc \
    libpq-dev \
    git \
    curl \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Only the dependency manifests are copied before the install step; the
# sources follow after it.
COPY pyproject.toml poetry.lock README.md /app/

RUN pip install poetry

# Install into the image's interpreter instead of a Poetry-managed virtualenv.
RUN poetry config virtualenvs.create false

# --no-root: install dependencies only; the project itself is provided by the
# COPY lines below together with PYTHONPATH=/app.
RUN poetry install --extras distributed --extras evals --extras deepeval --no-root

COPY cognee/ /app/cognee
COPY distributed/ /app/distributed
+ if isinstance(retrieval_context, list): + retrieval_context = await retriever.convert_retrieved_objects_to_context( + triplets=retrieval_context + ) + + if isinstance(search_results, str): + search_results = [search_results] + ############# answer = { "question": query_text, "answer": search_results[0], diff --git a/cognee/eval_framework/answer_generation/run_question_answering_module.py b/cognee/eval_framework/answer_generation/run_question_answering_module.py index d0a2ebe1e..6b55d84b2 100644 --- a/cognee/eval_framework/answer_generation/run_question_answering_module.py +++ b/cognee/eval_framework/answer_generation/run_question_answering_module.py @@ -35,7 +35,7 @@ async def create_and_insert_answers_table(questions_payload): async def run_question_answering( - params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None + params: dict, system_prompt="answer_simple_question_benchmark.txt", top_k: Optional[int] = None ) -> List[dict]: if params.get("answering_questions"): logger.info("Question answering started...") diff --git a/cognee/eval_framework/eval_config.py b/cognee/eval_framework/eval_config.py index 6edcc0454..9e6f26688 100644 --- a/cognee/eval_framework/eval_config.py +++ b/cognee/eval_framework/eval_config.py @@ -14,7 +14,7 @@ class EvalConfig(BaseSettings): # Question answering params answering_questions: bool = True - qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension' + qa_engine: str = "cognee_graph_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension' # Evaluation params evaluating_answers: bool = True @@ -25,7 +25,7 @@ class EvalConfig(BaseSettings): "EM", "f1", ] # Use only 'correctness' for DirectLLM - deepeval_model: str = "gpt-5-mini" + deepeval_model: str = "gpt-4o-mini" # Metrics params 
calculate_metrics: bool = True diff --git a/cognee/eval_framework/modal_run_eval.py b/cognee/eval_framework/modal_run_eval.py index aca2686a5..bc2ff77c5 100644 --- a/cognee/eval_framework/modal_run_eval.py +++ b/cognee/eval_framework/modal_run_eval.py @@ -2,7 +2,6 @@ import modal import os import asyncio import datetime -import hashlib import json from cognee.shared.logging_utils import get_logger from cognee.eval_framework.eval_config import EvalConfig @@ -10,6 +9,9 @@ from cognee.eval_framework.corpus_builder.run_corpus_builder import run_corpus_b from cognee.eval_framework.answer_generation.run_question_answering_module import ( run_question_answering, ) +import pathlib +from os import path +from modal import Image from cognee.eval_framework.evaluation.run_evaluation_module import run_evaluation from cognee.eval_framework.metrics_dashboard import create_dashboard @@ -38,22 +40,19 @@ def read_and_combine_metrics(eval_params: dict) -> dict: app = modal.App("modal-run-eval") -image = ( - modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False) - .copy_local_file("pyproject.toml", "pyproject.toml") - .copy_local_file("poetry.lock", "poetry.lock") - .env( - { - "ENV": os.getenv("ENV"), - "LLM_API_KEY": os.getenv("LLM_API_KEY"), - "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"), - } - ) - .pip_install("protobuf", "h2", "deepeval", "gdown", "plotly") +image = Image.from_dockerfile( + path=pathlib.Path(path.join(path.dirname(__file__), "Dockerfile")).resolve(), + force_build=False, +).add_local_python_source("cognee") + + +@app.function( + image=image, + max_containers=10, + timeout=86400, + volumes={"/data": vol}, + secrets=[modal.Secret.from_name("eval_secrets")], ) - - -@app.function(image=image, concurrency_limit=10, timeout=86400, volumes={"/data": vol}) async def modal_run_eval(eval_params=None): """Runs evaluation pipeline and returns combined metrics results.""" if eval_params is None: @@ -105,18 +104,7 @@ async def main(): configs = [ EvalConfig( 
task_getter_type="Default", - number_of_samples_in_corpus=10, - benchmark="HotPotQA", - qa_engine="cognee_graph_completion", - building_corpus_from_scratch=True, - answering_questions=True, - evaluating_answers=True, - calculate_metrics=True, - dashboard=True, - ), - EvalConfig( - task_getter_type="Default", - number_of_samples_in_corpus=10, + number_of_samples_in_corpus=25, benchmark="TwoWikiMultiHop", qa_engine="cognee_graph_completion", building_corpus_from_scratch=True, @@ -127,7 +115,7 @@ async def main(): ), EvalConfig( task_getter_type="Default", - number_of_samples_in_corpus=10, + number_of_samples_in_corpus=25, benchmark="Musique", qa_engine="cognee_graph_completion", building_corpus_from_scratch=True, diff --git a/cognee/infrastructure/databases/cache/config.py b/cognee/infrastructure/databases/cache/config.py index 3a28827fe..88ac05885 100644 --- a/cognee/infrastructure/databases/cache/config.py +++ b/cognee/infrastructure/databases/cache/config.py @@ -1,6 +1,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict from functools import lru_cache -from typing import Optional +from typing import Optional, Literal class CacheConfig(BaseSettings): @@ -15,6 +15,7 @@ class CacheConfig(BaseSettings): - agentic_lock_timeout: Maximum time (in seconds) to wait for the lock release. 
""" + cache_backend: Literal["redis", "fs"] = "fs" caching: bool = False shared_kuzu_lock: bool = False cache_host: str = "localhost" @@ -28,6 +29,7 @@ class CacheConfig(BaseSettings): def to_dict(self) -> dict: return { + "cache_backend": self.cache_backend, "caching": self.caching, "shared_kuzu_lock": self.shared_kuzu_lock, "cache_host": self.cache_host, diff --git a/cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py b/cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py new file mode 100644 index 000000000..497e6afec --- /dev/null +++ b/cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py @@ -0,0 +1,151 @@ +import asyncio +import json +import os +from datetime import datetime +import time +import threading +import diskcache as dc + +from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface +from cognee.infrastructure.databases.exceptions.exceptions import ( + CacheConnectionError, + SharedKuzuLockRequiresRedisError, +) +from cognee.infrastructure.files.storage.get_storage_config import get_storage_config +from cognee.shared.logging_utils import get_logger + +logger = get_logger("FSCacheAdapter") + + +class FSCacheAdapter(CacheDBInterface): + def __init__(self): + default_key = "sessions_db" + + storage_config = get_storage_config() + data_root_directory = storage_config["data_root_directory"] + cache_directory = os.path.join(data_root_directory, ".cognee_fs_cache", default_key) + os.makedirs(cache_directory, exist_ok=True) + self.cache = dc.Cache(directory=cache_directory) + self.cache.expire() + + logger.debug(f"FSCacheAdapter initialized with cache directory: {cache_directory}") + + def acquire_lock(self): + """Lock acquisition is not available for filesystem cache backend.""" + message = "Shared Kuzu lock requires Redis cache backend." 
+ logger.error(message) + raise SharedKuzuLockRequiresRedisError() + + def release_lock(self): + """Lock release is not available for filesystem cache backend.""" + message = "Shared Kuzu lock requires Redis cache backend." + logger.error(message) + raise SharedKuzuLockRequiresRedisError() + + async def add_qa( + self, + user_id: str, + session_id: str, + question: str, + context: str, + answer: str, + ttl: int | None = 86400, + ): + try: + session_key = f"agent_sessions:{user_id}:{session_id}" + + qa_entry = { + "time": datetime.utcnow().isoformat(), + "question": question, + "context": context, + "answer": answer, + } + + existing_value = self.cache.get(session_key) + if existing_value is not None: + value: list = json.loads(existing_value) + value.append(qa_entry) + else: + value = [qa_entry] + + self.cache.set(session_key, json.dumps(value), expire=ttl) + except Exception as e: + error_msg = f"Unexpected error while adding Q&A to diskcache: {str(e)}" + logger.error(error_msg) + raise CacheConnectionError(error_msg) from e + + async def get_latest_qa(self, user_id: str, session_id: str, last_n: int = 5): + session_key = f"agent_sessions:{user_id}:{session_id}" + value = self.cache.get(session_key) + if value is None: + return None + entries = json.loads(value) + return entries[-last_n:] if len(entries) > last_n else entries + + async def get_all_qas(self, user_id: str, session_id: str): + session_key = f"agent_sessions:{user_id}:{session_id}" + value = self.cache.get(session_key) + if value is None: + return None + return json.loads(value) + + async def close(self): + if self.cache is not None: + self.cache.expire() + self.cache.close() + + +async def main(): + adapter = FSCacheAdapter() + session_id = "demo_session" + user_id = "demo_user_id" + + print("\nAdding sample Q/A pairs...") + await adapter.add_qa( + user_id, + session_id, + "What is Redis?", + "Basic DB context", + "Redis is an in-memory data store.", + ) + await adapter.add_qa( + user_id, + 
session_id, + "Who created Redis?", + "Historical context", + "Salvatore Sanfilippo (antirez).", + ) + + print("\nLatest QA:") + latest = await adapter.get_latest_qa(user_id, session_id) + print(json.dumps(latest, indent=2)) + + print("\nLast 2 QAs:") + last_two = await adapter.get_latest_qa(user_id, session_id, last_n=2) + print(json.dumps(last_two, indent=2)) + + session_id = "session_expire_demo" + + await adapter.add_qa( + user_id, + session_id, + "What is Redis?", + "Database context", + "Redis is an in-memory data store.", + ) + + await adapter.add_qa( + user_id, + session_id, + "Who created Redis?", + "History context", + "Salvatore Sanfilippo (antirez).", + ) + + print(await adapter.get_all_qas(user_id, session_id)) + + await adapter.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/cognee/infrastructure/databases/cache/get_cache_engine.py b/cognee/infrastructure/databases/cache/get_cache_engine.py index c1fa3311c..f70358607 100644 --- a/cognee/infrastructure/databases/cache/get_cache_engine.py +++ b/cognee/infrastructure/databases/cache/get_cache_engine.py @@ -1,9 +1,11 @@ """Factory to get the appropriate cache coordination engine (e.g., Redis).""" from functools import lru_cache +import os from typing import Optional from cognee.infrastructure.databases.cache.config import get_cache_config from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface +from cognee.infrastructure.databases.cache.fscache.FsCacheAdapter import FSCacheAdapter config = get_cache_config() @@ -33,20 +35,28 @@ def create_cache_engine( Returns: -------- - - CacheDBInterface: An instance of the appropriate cache adapter. :TODO: Now we support only Redis. later if we add more here we can split the logic + - CacheDBInterface: An instance of the appropriate cache adapter. 
""" if config.caching: from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter - return RedisAdapter( - host=cache_host, - port=cache_port, - username=cache_username, - password=cache_password, - lock_name=lock_key, - timeout=agentic_lock_expire, - blocking_timeout=agentic_lock_timeout, - ) + if config.cache_backend == "redis": + return RedisAdapter( + host=cache_host, + port=cache_port, + username=cache_username, + password=cache_password, + lock_name=lock_key, + timeout=agentic_lock_expire, + blocking_timeout=agentic_lock_timeout, + ) + elif config.cache_backend == "fs": + return FSCacheAdapter() + else: + raise ValueError( + f"Unsupported cache backend: '{config.cache_backend}'. " + f"Supported backends are: 'redis', 'fs'" + ) else: return None diff --git a/cognee/infrastructure/databases/exceptions/exceptions.py b/cognee/infrastructure/databases/exceptions/exceptions.py index 72b13e3a2..d8dd99c17 100644 --- a/cognee/infrastructure/databases/exceptions/exceptions.py +++ b/cognee/infrastructure/databases/exceptions/exceptions.py @@ -148,3 +148,19 @@ class CacheConnectionError(CogneeConfigurationError): status_code: int = status.HTTP_503_SERVICE_UNAVAILABLE, ): super().__init__(message, name, status_code) + + +class SharedKuzuLockRequiresRedisError(CogneeConfigurationError): + """ + Raised when shared Kuzu locking is requested without configuring the Redis backend. + """ + + def __init__( + self, + message: str = ( + "Shared Kuzu lock requires Redis cache backend. Configure Redis to enable shared Kuzu locking." 
+ ), + name: str = "SharedKuzuLockRequiresRedisError", + status_code: int = status.HTTP_400_BAD_REQUEST, + ): + super().__init__(message, name, status_code) diff --git a/cognee/infrastructure/databases/graph/config.py b/cognee/infrastructure/databases/graph/config.py index b7907313c..23687b359 100644 --- a/cognee/infrastructure/databases/graph/config.py +++ b/cognee/infrastructure/databases/graph/config.py @@ -26,6 +26,7 @@ class GraphConfig(BaseSettings): - graph_database_username - graph_database_password - graph_database_port + - graph_database_key - graph_file_path - graph_model - graph_topology @@ -41,6 +42,7 @@ class GraphConfig(BaseSettings): graph_database_username: str = "" graph_database_password: str = "" graph_database_port: int = 123 + graph_database_key: str = "" graph_file_path: str = "" graph_filename: str = "" graph_model: object = KnowledgeGraph @@ -90,6 +92,7 @@ class GraphConfig(BaseSettings): "graph_database_username": self.graph_database_username, "graph_database_password": self.graph_database_password, "graph_database_port": self.graph_database_port, + "graph_database_key": self.graph_database_key, "graph_file_path": self.graph_file_path, "graph_model": self.graph_model, "graph_topology": self.graph_topology, @@ -116,6 +119,7 @@ class GraphConfig(BaseSettings): "graph_database_username": self.graph_database_username, "graph_database_password": self.graph_database_password, "graph_database_port": self.graph_database_port, + "graph_database_key": self.graph_database_key, "graph_file_path": self.graph_file_path, } diff --git a/cognee/infrastructure/databases/graph/get_graph_engine.py b/cognee/infrastructure/databases/graph/get_graph_engine.py index 1ea61d29f..82e3cad6e 100644 --- a/cognee/infrastructure/databases/graph/get_graph_engine.py +++ b/cognee/infrastructure/databases/graph/get_graph_engine.py @@ -33,6 +33,7 @@ def create_graph_engine( graph_database_username="", graph_database_password="", graph_database_port="", + 
graph_database_key="", ): """ Create a graph engine based on the specified provider type. @@ -69,6 +70,7 @@ def create_graph_engine( graph_database_url=graph_database_url, graph_database_username=graph_database_username, graph_database_password=graph_database_password, + database_name=graph_database_name, ) if graph_database_provider == "neo4j": diff --git a/cognee/infrastructure/databases/graph/graph_db_interface.py b/cognee/infrastructure/databases/graph/graph_db_interface.py index 67df1a27c..8f8c96e79 100644 --- a/cognee/infrastructure/databases/graph/graph_db_interface.py +++ b/cognee/infrastructure/databases/graph/graph_db_interface.py @@ -398,3 +398,18 @@ class GraphDBInterface(ABC): - node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections. """ raise NotImplementedError + + @abstractmethod + async def get_filtered_graph_data( + self, attribute_filters: List[Dict[str, List[Union[str, int]]]] + ) -> Tuple[List[Node], List[EdgeData]]: + """ + Retrieve nodes and edges filtered by the provided attribute criteria. + + Parameters: + ----------- + + - attribute_filters: A list of dictionaries where keys are attribute names and values + are lists of attribute values to filter by. 
+ """ + raise NotImplementedError diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py index 8dd160665..9dbc9c1bc 100644 --- a/cognee/infrastructure/databases/graph/kuzu/adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py @@ -12,6 +12,7 @@ from contextlib import asynccontextmanager from concurrent.futures import ThreadPoolExecutor from typing import Dict, Any, List, Union, Optional, Tuple, Type +from cognee.exceptions import CogneeValidationError from cognee.shared.logging_utils import get_logger from cognee.infrastructure.utils.run_sync import run_sync from cognee.infrastructure.files.storage import get_file_storage @@ -1186,6 +1187,11 @@ class KuzuAdapter(GraphDBInterface): A tuple with two elements: a list of tuples of (node_id, properties) and a list of tuples of (source_id, target_id, relationship_name, properties). """ + + import time + + start_time = time.time() + try: nodes_query = """ MATCH (n:Node) @@ -1249,6 +1255,11 @@ class KuzuAdapter(GraphDBInterface): }, ) ) + + retrieval_time = time.time() - start_time + logger.info( + f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds" + ) return formatted_nodes, formatted_edges except Exception as e: logger.error(f"Failed to get graph data: {e}") @@ -1417,6 +1428,92 @@ class KuzuAdapter(GraphDBInterface): formatted_edges.append((source_id, target_id, rel_type, props)) return formatted_nodes, formatted_edges + async def get_id_filtered_graph_data(self, target_ids: list[str]): + """ + Retrieve graph data filtered by specific node IDs, including their direct neighbors + and only edges where one endpoint matches those IDs. 
+ + Returns: + nodes: List[dict] -> Each dict includes "id" and all node properties + edges: List[dict] -> Each dict includes "source", "target", "type", "properties" + """ + import time + + start_time = time.time() + + try: + if not target_ids: + logger.warning("No target IDs provided for ID-filtered graph retrieval.") + return [], [] + + if not all(isinstance(x, str) for x in target_ids): + raise CogneeValidationError("target_ids must be a list of strings") + + query = """ + MATCH (n:Node)-[r]->(m:Node) + WHERE n.id IN $target_ids OR m.id IN $target_ids + RETURN n.id, { + name: n.name, + type: n.type, + properties: n.properties + }, m.id, { + name: m.name, + type: m.type, + properties: m.properties + }, r.relationship_name, r.properties + """ + + result = await self.query(query, {"target_ids": target_ids}) + + if not result: + logger.info("No data returned for the supplied IDs") + return [], [] + + nodes_dict = {} + edges = [] + + for n_id, n_props, m_id, m_props, r_type, r_props_raw in result: + if n_props.get("properties"): + try: + additional_props = json.loads(n_props["properties"]) + n_props.update(additional_props) + del n_props["properties"] + except json.JSONDecodeError: + logger.warning(f"Failed to parse properties JSON for node {n_id}") + + if m_props.get("properties"): + try: + additional_props = json.loads(m_props["properties"]) + m_props.update(additional_props) + del m_props["properties"] + except json.JSONDecodeError: + logger.warning(f"Failed to parse properties JSON for node {m_id}") + + nodes_dict[n_id] = (n_id, n_props) + nodes_dict[m_id] = (m_id, m_props) + + edge_props = {} + if r_props_raw: + try: + edge_props = json.loads(r_props_raw) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse edge properties for {n_id}->{m_id}") + + source_id = edge_props.get("source_node_id", n_id) + target_id = edge_props.get("target_node_id", m_id) + edges.append((source_id, target_id, r_type, edge_props)) + + retrieval_time = 
time.time() - start_time + logger.info( + f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s" + ) + + return list(nodes_dict.values()), edges + + except Exception as e: + logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}") + raise + async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]: """ Get metrics on graph structure and connectivity. diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 6216e107e..f3bb8e173 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -964,6 +964,63 @@ class Neo4jAdapter(GraphDBInterface): logger.error(f"Error during graph data retrieval: {str(e)}") raise + async def get_id_filtered_graph_data(self, target_ids: list[str]): + """ + Retrieve graph data filtered by specific node IDs, including their direct neighbors + and only edges where one endpoint matches those IDs. + + This version uses a single Cypher query for efficiency. 
+ """ + import time + + start_time = time.time() + + try: + if not target_ids: + logger.warning("No target IDs provided for ID-filtered graph retrieval.") + return [], [] + + query = """ + MATCH ()-[r]-() + WHERE startNode(r).id IN $target_ids + OR endNode(r).id IN $target_ids + WITH DISTINCT r, startNode(r) AS a, endNode(r) AS b + RETURN + properties(a) AS n_properties, + properties(b) AS m_properties, + type(r) AS type, + properties(r) AS properties + """ + + result = await self.query(query, {"target_ids": target_ids}) + + nodes_dict = {} + edges = [] + + for record in result: + n_props = record["n_properties"] + m_props = record["m_properties"] + r_props = record["properties"] + r_type = record["type"] + + nodes_dict[n_props["id"]] = (n_props["id"], n_props) + nodes_dict[m_props["id"]] = (m_props["id"], m_props) + + source_id = r_props.get("source_node_id", n_props["id"]) + target_id = r_props.get("target_node_id", m_props["id"]) + edges.append((source_id, target_id, r_type, r_props)) + + retrieval_time = time.time() - start_time + logger.info( + f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s" + ) + + return list(nodes_dict.values()), edges + + except Exception as e: + logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}") + raise + async def get_nodeset_subgraph( self, node_type: Type[Any], node_name: List[str] ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: diff --git a/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py b/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py index 5357f3d7c..1e16642b5 100644 --- a/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +++ b/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py @@ -416,6 +416,15 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): self._client.query(f"MATCH (n 
:{self._VECTOR_NODE_LABEL}) DETACH DELETE n") pass + async def is_empty(self) -> bool: + query = """ + MATCH (n) + RETURN true + LIMIT 1; + """ + query_result = await self._client.query(query) + return len(query_result) == 0 + @staticmethod def _get_scored_result( item: dict, with_vector: bool = False, with_score: bool = False diff --git a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py index 29156025d..3684bb100 100644 --- a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +++ b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py @@ -1,11 +1,15 @@ +import os from uuid import UUID from typing import Union from sqlalchemy import select from sqlalchemy.exc import IntegrityError -from cognee.modules.data.methods import create_dataset +from cognee.base_config import get_base_config +from cognee.modules.data.methods import create_dataset from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.infrastructure.databases.vector import get_vectordb_config +from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.modules.data.methods import get_unique_dataset_id from cognee.modules.users.models import DatasetDatabase from cognee.modules.users.models import User @@ -32,8 +36,32 @@ async def get_or_create_dataset_database( dataset_id = await get_unique_dataset_id(dataset, user) - vector_db_name = f"{dataset_id}.lance.db" - graph_db_name = f"{dataset_id}.pkl" + vector_config = get_vectordb_config() + graph_config = get_graph_config() + + # Note: for hybrid databases both graph and vector DB name have to be the same + if graph_config.graph_database_provider == "kuzu": + graph_db_name = f"{dataset_id}.pkl" + else: + graph_db_name = f"{dataset_id}" + + if vector_config.vector_db_provider == "lancedb": + vector_db_name = f"{dataset_id}.lance.db" + else: + vector_db_name = 
f"{dataset_id}" + + base_config = get_base_config() + databases_directory_path = os.path.join( + base_config.system_root_directory, "databases", str(user.id) + ) + + # Determine vector database URL + if vector_config.vector_db_provider == "lancedb": + vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name) + else: + vector_db_url = vector_config.vector_database_url + + # Determine graph database URL async with db_engine.get_async_session() as session: # Create dataset if it doesn't exist @@ -55,6 +83,12 @@ async def get_or_create_dataset_database( dataset_id=dataset_id, vector_database_name=vector_db_name, graph_database_name=graph_db_name, + vector_database_provider=vector_config.vector_db_provider, + graph_database_provider=graph_config.graph_database_provider, + vector_database_url=vector_db_url, + graph_database_url=graph_config.graph_database_url, + vector_database_key=vector_config.vector_db_key, + graph_database_key=graph_config.graph_database_key, ) try: diff --git a/cognee/infrastructure/databases/vector/config.py b/cognee/infrastructure/databases/vector/config.py index b6d3ae644..7d28f1668 100644 --- a/cognee/infrastructure/databases/vector/config.py +++ b/cognee/infrastructure/databases/vector/config.py @@ -18,12 +18,14 @@ class VectorConfig(BaseSettings): Instance variables: - vector_db_url: The URL of the vector database. - vector_db_port: The port for the vector database. + - vector_db_name: The name of the vector database. - vector_db_key: The key for accessing the vector database. - vector_db_provider: The provider for the vector database. 
""" vector_db_url: str = "" vector_db_port: int = 1234 + vector_db_name: str = "" vector_db_key: str = "" vector_db_provider: str = "lancedb" @@ -58,6 +60,7 @@ class VectorConfig(BaseSettings): return { "vector_db_url": self.vector_db_url, "vector_db_port": self.vector_db_port, + "vector_db_name": self.vector_db_name, "vector_db_key": self.vector_db_key, "vector_db_provider": self.vector_db_provider, } diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py index d1cf855d7..b182f084b 100644 --- a/cognee/infrastructure/databases/vector/create_vector_engine.py +++ b/cognee/infrastructure/databases/vector/create_vector_engine.py @@ -1,5 +1,6 @@ from .supported_databases import supported_databases from .embeddings import get_embedding_engine +from cognee.infrastructure.databases.graph.config import get_graph_context_config from functools import lru_cache @@ -8,6 +9,7 @@ from functools import lru_cache def create_vector_engine( vector_db_provider: str, vector_db_url: str, + vector_db_name: str, vector_db_port: str = "", vector_db_key: str = "", ): @@ -27,6 +29,7 @@ def create_vector_engine( - vector_db_url (str): The URL for the vector database instance. - vector_db_port (str): The port for the vector database instance. Required for some providers. + - vector_db_name (str): The name of the vector database instance. - vector_db_key (str): The API key or access token for the vector database instance. - vector_db_provider (str): The name of the vector database provider to use (e.g., 'pgvector'). @@ -45,6 +48,7 @@ def create_vector_engine( url=vector_db_url, api_key=vector_db_key, embedding_engine=embedding_engine, + database_name=vector_db_name, ) if vector_db_provider.lower() == "pgvector": @@ -133,6 +137,6 @@ def create_vector_engine( else: raise EnvironmentError( - f"Unsupported graph database provider: {vector_db_provider}. 
" + f"Unsupported vector database provider: {vector_db_provider}. " f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['LanceDB', 'PGVector', 'neptune_analytics', 'ChromaDB'])}" ) diff --git a/cognee/infrastructure/engine/models/Edge.py b/cognee/infrastructure/engine/models/Edge.py index 5ad9c84dd..59f01a9ab 100644 --- a/cognee/infrastructure/engine/models/Edge.py +++ b/cognee/infrastructure/engine/models/Edge.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, field_validator from typing import Optional, Any, Dict @@ -18,9 +18,21 @@ class Edge(BaseModel): # Mixed usage has_items: (Edge(weight=0.5, weights={"confidence": 0.9}), list[Item]) + + # With edge_text for rich embedding representation + contains: (Edge(relationship_type="contains", edge_text="relationship_name: contains; entity_description: Alice"), Entity) """ weight: Optional[float] = None weights: Optional[Dict[str, float]] = None relationship_type: Optional[str] = None properties: Optional[Dict[str, Any]] = None + edge_text: Optional[str] = None + + @field_validator("edge_text", mode="before") + @classmethod + def ensure_edge_text(cls, v, info): + """Auto-populate edge_text from relationship_type if not explicitly provided.""" + if v is None and info.data.get("relationship_type"): + return info.data["relationship_type"] + return v diff --git a/cognee/infrastructure/files/utils/guess_file_type.py b/cognee/infrastructure/files/utils/guess_file_type.py index 78b20c93d..4bc96fe80 100644 --- a/cognee/infrastructure/files/utils/guess_file_type.py +++ b/cognee/infrastructure/files/utils/guess_file_type.py @@ -55,6 +55,10 @@ def guess_file_type(file: BinaryIO, name: Optional[str] = None) -> filetype.Type file_type = Type("text/plain", "txt") return file_type + if ext in [".csv"]: + file_type = Type("text/csv", "csv") + return file_type + file_type = filetype.guess(file) # If file type could not be determined consider it a plain text file as they don't have 
magic number encoding diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py index 8fd196eaf..2e300dc0c 100644 --- a/cognee/infrastructure/llm/config.py +++ b/cognee/infrastructure/llm/config.py @@ -38,6 +38,7 @@ class LLMConfig(BaseSettings): """ structured_output_framework: str = "instructor" + llm_instructor_mode: str = "" llm_provider: str = "openai" llm_model: str = "openai/gpt-5-mini" llm_endpoint: str = "" @@ -181,6 +182,7 @@ class LLMConfig(BaseSettings): instance. """ return { + "llm_instructor_mode": self.llm_instructor_mode.lower(), "provider": self.llm_provider, "model": self.llm_model, "endpoint": self.llm_endpoint, diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py index bf19d6e86..dbf0dfbea 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py @@ -28,13 +28,16 @@ class AnthropicAdapter(LLMInterface): name = "Anthropic" model: str + default_instructor_mode = "anthropic_tools" - def __init__(self, max_completion_tokens: int, model: str = None): + def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None): import anthropic + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode + self.aclient = instructor.patch( create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create, - mode=instructor.Mode.ANTHROPIC_TOOLS, + mode=instructor.Mode(self.instructor_mode), ) self.model = model diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py index 1187e0cad..226f291d7 
100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py @@ -41,6 +41,7 @@ class GeminiAdapter(LLMInterface): name: str model: str api_key: str + default_instructor_mode = "json_mode" def __init__( self, @@ -49,6 +50,7 @@ class GeminiAdapter(LLMInterface): model: str, api_version: str, max_completion_tokens: int, + instructor_mode: str = None, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, @@ -63,7 +65,11 @@ class GeminiAdapter(LLMInterface): self.fallback_api_key = fallback_api_key self.fallback_endpoint = fallback_endpoint - self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON) + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode + + self.aclient = instructor.from_litellm( + litellm.acompletion, mode=instructor.Mode(self.instructor_mode) + ) @retry( stop=stop_after_delay(128), diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py index 8bbbaa2cc..9d7f25fc5 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py @@ -41,6 +41,7 @@ class GenericAPIAdapter(LLMInterface): name: str model: str api_key: str + default_instructor_mode = "json_mode" def __init__( self, @@ -49,6 +50,7 @@ class GenericAPIAdapter(LLMInterface): model: str, name: str, max_completion_tokens: int, + instructor_mode: str = None, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, @@ -63,7 +65,11 @@ class GenericAPIAdapter(LLMInterface): self.fallback_api_key = 
fallback_api_key self.fallback_endpoint = fallback_endpoint - self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON) + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode + + self.aclient = instructor.from_litellm( + litellm.acompletion, mode=instructor.Mode(self.instructor_mode) + ) @retry( stop=stop_after_delay(128), diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py index c7dcecc56..39558f36d 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py @@ -81,6 +81,7 @@ def get_llm_client(raise_api_key_error: bool = True): model=llm_config.llm_model, transcription_model=llm_config.transcription_model, max_completion_tokens=max_completion_tokens, + instructor_mode=llm_config.llm_instructor_mode.lower(), streaming=llm_config.llm_streaming, fallback_api_key=llm_config.fallback_api_key, fallback_endpoint=llm_config.fallback_endpoint, @@ -101,6 +102,7 @@ def get_llm_client(raise_api_key_error: bool = True): llm_config.llm_model, "Ollama", max_completion_tokens=max_completion_tokens, + instructor_mode=llm_config.llm_instructor_mode.lower(), ) elif provider == LLMProvider.ANTHROPIC: @@ -109,7 +111,9 @@ def get_llm_client(raise_api_key_error: bool = True): ) return AnthropicAdapter( - max_completion_tokens=max_completion_tokens, model=llm_config.llm_model + max_completion_tokens=max_completion_tokens, + model=llm_config.llm_model, + instructor_mode=llm_config.llm_instructor_mode.lower(), ) elif provider == LLMProvider.CUSTOM: @@ -126,6 +130,7 @@ def get_llm_client(raise_api_key_error: bool = True): llm_config.llm_model, "Custom", max_completion_tokens=max_completion_tokens, + 
instructor_mode=llm_config.llm_instructor_mode.lower(), fallback_api_key=llm_config.fallback_api_key, fallback_endpoint=llm_config.fallback_endpoint, fallback_model=llm_config.fallback_model, @@ -145,6 +150,7 @@ def get_llm_client(raise_api_key_error: bool = True): max_completion_tokens=max_completion_tokens, endpoint=llm_config.llm_endpoint, api_version=llm_config.llm_api_version, + instructor_mode=llm_config.llm_instructor_mode.lower(), ) elif provider == LLMProvider.MISTRAL: @@ -160,21 +166,7 @@ def get_llm_client(raise_api_key_error: bool = True): model=llm_config.llm_model, max_completion_tokens=max_completion_tokens, endpoint=llm_config.llm_endpoint, - ) - - elif provider == LLMProvider.MISTRAL: - if llm_config.llm_api_key is None: - raise LLMAPIKeyNotSetError() - - from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import ( - MistralAdapter, - ) - - return MistralAdapter( - api_key=llm_config.llm_api_key, - model=llm_config.llm_model, - max_completion_tokens=max_completion_tokens, - endpoint=llm_config.llm_endpoint, + instructor_mode=llm_config.llm_instructor_mode.lower(), ) else: diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py index 78a3cbff5..355cdae0b 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py @@ -37,16 +37,26 @@ class MistralAdapter(LLMInterface): model: str api_key: str max_completion_tokens: int + default_instructor_mode = "mistral_tools" - def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None): + def __init__( + self, + api_key: str, + model: str, + max_completion_tokens: int, + endpoint: str = None, + instructor_mode: str = None, + ): from 
mistralai import Mistral self.model = model self.max_completion_tokens = max_completion_tokens + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode + self.aclient = instructor.from_litellm( litellm.acompletion, - mode=instructor.Mode.MISTRAL_TOOLS, + mode=instructor.Mode(self.instructor_mode), api_key=get_llm_config().llm_api_key, ) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py index 9c3d185aa..aabd19867 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py @@ -42,8 +42,16 @@ class OllamaAPIAdapter(LLMInterface): - aclient """ + default_instructor_mode = "json_mode" + def __init__( - self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int + self, + endpoint: str, + api_key: str, + model: str, + name: str, + max_completion_tokens: int, + instructor_mode: str = None, ): self.name = name self.model = model @@ -51,8 +59,11 @@ class OllamaAPIAdapter(LLMInterface): self.endpoint = endpoint self.max_completion_tokens = max_completion_tokens + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode + self.aclient = instructor.from_openai( - OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON + OpenAI(base_url=self.endpoint, api_key=self.api_key), + mode=instructor.Mode(self.instructor_mode), ) @retry( diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py index 305b426b8..778c8eec7 100644 --- 
a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py @@ -56,6 +56,7 @@ class OpenAIAdapter(LLMInterface): model: str api_key: str api_version: str + default_instructor_mode = "json_schema_mode" MAX_RETRIES = 5 @@ -69,19 +70,21 @@ class OpenAIAdapter(LLMInterface): model: str, transcription_model: str, max_completion_tokens: int, + instructor_mode: str = None, streaming: bool = False, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, ): + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode # TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs. # Make sure all new gpt models will work with this mode as well. if "gpt-5" in model: self.aclient = instructor.from_litellm( - litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA + litellm.acompletion, mode=instructor.Mode(self.instructor_mode) ) self.client = instructor.from_litellm( - litellm.completion, mode=instructor.Mode.JSON_SCHEMA + litellm.completion, mode=instructor.Mode(self.instructor_mode) ) else: self.aclient = instructor.from_litellm(litellm.acompletion) diff --git a/cognee/infrastructure/loaders/LoaderEngine.py b/cognee/infrastructure/loaders/LoaderEngine.py index f9511e7c5..4a363a0e6 100644 --- a/cognee/infrastructure/loaders/LoaderEngine.py +++ b/cognee/infrastructure/loaders/LoaderEngine.py @@ -31,6 +31,7 @@ class LoaderEngine: "pypdf_loader", "image_loader", "audio_loader", + "csv_loader", "unstructured_loader", "advanced_pdf_loader", ] diff --git a/cognee/infrastructure/loaders/core/__init__.py b/cognee/infrastructure/loaders/core/__init__.py index 8a2df80f9..09819fbd2 100644 --- a/cognee/infrastructure/loaders/core/__init__.py +++ b/cognee/infrastructure/loaders/core/__init__.py @@ -3,5 +3,6 @@ from .text_loader import TextLoader from 
.audio_loader import AudioLoader from .image_loader import ImageLoader +from .csv_loader import CsvLoader -__all__ = ["TextLoader", "AudioLoader", "ImageLoader"] +__all__ = ["TextLoader", "AudioLoader", "ImageLoader", "CsvLoader"] diff --git a/cognee/infrastructure/loaders/core/csv_loader.py b/cognee/infrastructure/loaders/core/csv_loader.py new file mode 100644 index 000000000..a314a7a24 --- /dev/null +++ b/cognee/infrastructure/loaders/core/csv_loader.py @@ -0,0 +1,93 @@ +import os +from typing import List +import csv +from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface +from cognee.infrastructure.files.storage import get_file_storage, get_storage_config +from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata + + +class CsvLoader(LoaderInterface): + """ + Core CSV file loader that handles basic CSV file formats. + """ + + @property + def supported_extensions(self) -> List[str]: + """Supported text file extensions.""" + return [ + "csv", + ] + + @property + def supported_mime_types(self) -> List[str]: + """Supported MIME types for text content.""" + return [ + "text/csv", + ] + + @property + def loader_name(self) -> str: + """Unique identifier for this loader.""" + return "csv_loader" + + def can_handle(self, extension: str, mime_type: str) -> bool: + """ + Check if this loader can handle the given file. + + Args: + extension: File extension + mime_type: Optional MIME type + + Returns: + True if file can be handled, False otherwise + """ + if extension in self.supported_extensions and mime_type in self.supported_mime_types: + return True + + return False + + async def load(self, file_path: str, encoding: str = "utf-8", **kwargs): + """ + Load and process the csv file. 
+ + Args: + file_path: Path to the file to load + encoding: Text encoding to use (default: utf-8) + **kwargs: Additional configuration (unused) + + Returns: + LoaderResult containing the file content and metadata + + Raises: + FileNotFoundError: If file doesn't exist + UnicodeDecodeError: If file cannot be decoded with specified encoding + OSError: If file cannot be read + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "rb") as f: + file_metadata = await get_file_metadata(f) + # Name ingested file of current loader based on original file content hash + storage_file_name = "text_" + file_metadata["content_hash"] + ".txt" + + row_texts = [] + row_index = 1 + + with open(file_path, "r", encoding=encoding, newline="") as file: + reader = csv.DictReader(file) + for row in reader: + pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()] + row_text = ", ".join(pairs) + row_texts.append(f"Row {row_index}:\n{row_text}\n") + row_index += 1 + + content = "\n".join(row_texts) + + storage_config = get_storage_config() + data_root_directory = storage_config["data_root_directory"] + storage = get_file_storage(data_root_directory) + + full_file_path = await storage.store(storage_file_name, content) + + return full_file_path diff --git a/cognee/infrastructure/loaders/core/text_loader.py b/cognee/infrastructure/loaders/core/text_loader.py index a6f94be9b..e478edb22 100644 --- a/cognee/infrastructure/loaders/core/text_loader.py +++ b/cognee/infrastructure/loaders/core/text_loader.py @@ -16,7 +16,7 @@ class TextLoader(LoaderInterface): @property def supported_extensions(self) -> List[str]: """Supported text file extensions.""" - return ["txt", "md", "csv", "json", "xml", "yaml", "yml", "log"] + return ["txt", "md", "json", "xml", "yaml", "yml", "log"] @property def supported_mime_types(self) -> List[str]: @@ -24,7 +24,6 @@ class TextLoader(LoaderInterface): return [ "text/plain", "text/markdown", - 
"text/csv", "application/json", "text/xml", "application/xml", diff --git a/cognee/infrastructure/loaders/external/advanced_pdf_loader.py b/cognee/infrastructure/loaders/external/advanced_pdf_loader.py index 6d1412b77..4b3ba296a 100644 --- a/cognee/infrastructure/loaders/external/advanced_pdf_loader.py +++ b/cognee/infrastructure/loaders/external/advanced_pdf_loader.py @@ -227,12 +227,3 @@ class AdvancedPdfLoader(LoaderInterface): if value is None: return "" return str(value).replace("\xa0", " ").strip() - - -if __name__ == "__main__": - loader = AdvancedPdfLoader() - asyncio.run( - loader.load( - "/Users/xiaotao/work/cognee/cognee/infrastructure/loaders/external/attention_is_all_you_need.pdf" - ) - ) diff --git a/cognee/infrastructure/loaders/supported_loaders.py b/cognee/infrastructure/loaders/supported_loaders.py index 156253b53..2b8c3e0b4 100644 --- a/cognee/infrastructure/loaders/supported_loaders.py +++ b/cognee/infrastructure/loaders/supported_loaders.py @@ -1,5 +1,5 @@ from cognee.infrastructure.loaders.external import PyPdfLoader -from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader +from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader, CsvLoader # Registry for loader implementations supported_loaders = { @@ -7,6 +7,7 @@ supported_loaders = { TextLoader.loader_name: TextLoader, ImageLoader.loader_name: ImageLoader, AudioLoader.loader_name: AudioLoader, + CsvLoader.loader_name: CsvLoader, } # Try adding optional loaders diff --git a/cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py b/cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py new file mode 100644 index 000000000..92d64c156 --- /dev/null +++ b/cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py @@ -0,0 +1,55 @@ +from typing import Optional, List + +from cognee import memify +from cognee.context_global_variables import ( + set_database_global_context_variables, + set_session_user_context_variable, +) 
+from cognee.exceptions import CogneeValidationError +from cognee.modules.data.methods import get_authorized_existing_datasets +from cognee.shared.logging_utils import get_logger +from cognee.modules.pipelines.tasks.task import Task +from cognee.modules.users.models import User +from cognee.tasks.memify import extract_user_sessions, cognify_session + + +logger = get_logger("persist_sessions_in_knowledge_graph") + + +async def persist_sessions_in_knowledge_graph_pipeline( + user: User, + session_ids: Optional[List[str]] = None, + dataset: str = "main_dataset", + run_in_background: bool = False, +): + await set_session_user_context_variable(user) + dataset_to_write = await get_authorized_existing_datasets( + user=user, datasets=[dataset], permission_type="write" + ) + + if not dataset_to_write: + raise CogneeValidationError( + message=f"User (id: {str(user.id)}) does not have write access to dataset: {dataset}", + log=False, + ) + + await set_database_global_context_variables( + dataset_to_write[0].id, dataset_to_write[0].owner_id + ) + + extraction_tasks = [Task(extract_user_sessions, session_ids=session_ids)] + + enrichment_tasks = [ + Task(cognify_session, dataset_id=dataset_to_write[0].id), + ] + + result = await memify( + extraction_tasks=extraction_tasks, + enrichment_tasks=enrichment_tasks, + dataset=dataset_to_write[0].id, + data=[{}], + run_in_background=run_in_background, + ) + + logger.info("Session persistence pipeline completed") + return result diff --git a/cognee/modules/chunking/CsvChunker.py b/cognee/modules/chunking/CsvChunker.py new file mode 100644 index 000000000..4ba4a969e --- /dev/null +++ b/cognee/modules/chunking/CsvChunker.py @@ -0,0 +1,35 @@ +from cognee.shared.logging_utils import get_logger + + +from cognee.tasks.chunks import chunk_by_row +from cognee.modules.chunking.Chunker import Chunker +from .models.DocumentChunk import DocumentChunk + +logger = get_logger() + + +class CsvChunker(Chunker): + async def read(self): + async for 
content_text in self.get_text(): + if content_text is None: + continue + + for chunk_data in chunk_by_row(content_text, self.max_chunk_size): + if chunk_data["chunk_size"] <= self.max_chunk_size: + yield DocumentChunk( + id=chunk_data["chunk_id"], + text=chunk_data["text"], + chunk_size=chunk_data["chunk_size"], + is_part_of=self.document, + chunk_index=self.chunk_index, + cut_type=chunk_data["cut_type"], + contains=[], + metadata={ + "index_fields": ["text"], + }, + ) + self.chunk_index += 1 + else: + raise ValueError( + f"Chunk size is larger than the maximum chunk size {self.max_chunk_size}" + ) diff --git a/cognee/modules/chunking/models/DocumentChunk.py b/cognee/modules/chunking/models/DocumentChunk.py index 9f8c57486..e024bf00b 100644 --- a/cognee/modules/chunking/models/DocumentChunk.py +++ b/cognee/modules/chunking/models/DocumentChunk.py @@ -1,6 +1,7 @@ from typing import List, Union from cognee.infrastructure.engine import DataPoint +from cognee.infrastructure.engine.models.Edge import Edge from cognee.modules.data.processing.document_types import Document from cognee.modules.engine.models import Entity from cognee.tasks.temporal_graph.models import Event @@ -31,6 +32,6 @@ class DocumentChunk(DataPoint): chunk_index: int cut_type: str is_part_of: Document - contains: List[Union[Entity, Event]] = None + contains: List[Union[Entity, Event, tuple[Edge, Entity]]] = None metadata: dict = {"index_fields": ["text"]} diff --git a/cognee/modules/chunking/text_chunker_with_overlap.py b/cognee/modules/chunking/text_chunker_with_overlap.py new file mode 100644 index 000000000..4b9c23079 --- /dev/null +++ b/cognee/modules/chunking/text_chunker_with_overlap.py @@ -0,0 +1,124 @@ +from cognee.shared.logging_utils import get_logger +from uuid import NAMESPACE_OID, uuid5 + +from cognee.tasks.chunks import chunk_by_paragraph +from cognee.modules.chunking.Chunker import Chunker +from .models.DocumentChunk import DocumentChunk + +logger = get_logger() + + +class 
TextChunkerWithOverlap(Chunker): + def __init__( + self, + document, + get_text: callable, + max_chunk_size: int, + chunk_overlap_ratio: float = 0.0, + get_chunk_data: callable = None, + ): + super().__init__(document, get_text, max_chunk_size) + self._accumulated_chunk_data = [] + self._accumulated_size = 0 + self.chunk_overlap_ratio = chunk_overlap_ratio + self.chunk_overlap = int(max_chunk_size * chunk_overlap_ratio) + + if get_chunk_data is not None: + self.get_chunk_data = get_chunk_data + elif chunk_overlap_ratio > 0: + paragraph_max_size = int(0.5 * chunk_overlap_ratio * max_chunk_size) + self.get_chunk_data = lambda text: chunk_by_paragraph( + text, paragraph_max_size, batch_paragraphs=True + ) + else: + self.get_chunk_data = lambda text: chunk_by_paragraph( + text, self.max_chunk_size, batch_paragraphs=True + ) + + def _accumulation_overflows(self, chunk_data): + """Check if adding chunk_data would exceed max_chunk_size.""" + return self._accumulated_size + chunk_data["chunk_size"] > self.max_chunk_size + + def _accumulate_chunk_data(self, chunk_data): + """Add chunk_data to the current accumulation.""" + self._accumulated_chunk_data.append(chunk_data) + self._accumulated_size += chunk_data["chunk_size"] + + def _clear_accumulation(self): + """Reset accumulation, keeping overlap chunk_data based on chunk_overlap_ratio.""" + if self.chunk_overlap == 0: + self._accumulated_chunk_data = [] + self._accumulated_size = 0 + return + + # Keep chunk_data from the end that fit in overlap + overlap_chunk_data = [] + overlap_size = 0 + + for chunk_data in reversed(self._accumulated_chunk_data): + if overlap_size + chunk_data["chunk_size"] <= self.chunk_overlap: + overlap_chunk_data.insert(0, chunk_data) + overlap_size += chunk_data["chunk_size"] + else: + break + + self._accumulated_chunk_data = overlap_chunk_data + self._accumulated_size = overlap_size + + def _create_chunk(self, text, size, cut_type, chunk_id=None): + """Create a DocumentChunk with standard 
metadata.""" + try: + return DocumentChunk( + id=chunk_id or uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"), + text=text, + chunk_size=size, + is_part_of=self.document, + chunk_index=self.chunk_index, + cut_type=cut_type, + contains=[], + metadata={"index_fields": ["text"]}, + ) + except Exception as e: + logger.error(e) + raise e + + def _create_chunk_from_accumulation(self): + """Create a DocumentChunk from current accumulated chunk_data.""" + chunk_text = " ".join(chunk["text"] for chunk in self._accumulated_chunk_data) + return self._create_chunk( + text=chunk_text, + size=self._accumulated_size, + cut_type=self._accumulated_chunk_data[-1]["cut_type"], + ) + + def _emit_chunk(self, chunk_data): + """Emit a chunk when accumulation overflows.""" + if len(self._accumulated_chunk_data) > 0: + chunk = self._create_chunk_from_accumulation() + self._clear_accumulation() + self._accumulate_chunk_data(chunk_data) + else: + # Handle single chunk_data exceeding max_chunk_size + chunk = self._create_chunk( + text=chunk_data["text"], + size=chunk_data["chunk_size"], + cut_type=chunk_data["cut_type"], + chunk_id=chunk_data["chunk_id"], + ) + + self.chunk_index += 1 + return chunk + + async def read(self): + async for content_text in self.get_text(): + for chunk_data in self.get_chunk_data(content_text): + if not self._accumulation_overflows(chunk_data): + self._accumulate_chunk_data(chunk_data) + continue + + yield self._emit_chunk(chunk_data) + + if len(self._accumulated_chunk_data) == 0: + return + + yield self._create_chunk_from_accumulation() diff --git a/cognee/modules/data/methods/__init__.py b/cognee/modules/data/methods/__init__.py index 83913085c..7936a9afd 100644 --- a/cognee/modules/data/methods/__init__.py +++ b/cognee/modules/data/methods/__init__.py @@ -10,6 +10,7 @@ from .get_authorized_dataset import get_authorized_dataset from .get_authorized_dataset_by_name import get_authorized_dataset_by_name from .get_data import get_data from 
.get_unique_dataset_id import get_unique_dataset_id +from .get_unique_data_id import get_unique_data_id from .get_authorized_existing_datasets import get_authorized_existing_datasets from .get_dataset_ids import get_dataset_ids diff --git a/cognee/modules/data/methods/create_dataset.py b/cognee/modules/data/methods/create_dataset.py index c080de0e8..7e28a8255 100644 --- a/cognee/modules/data/methods/create_dataset.py +++ b/cognee/modules/data/methods/create_dataset.py @@ -16,14 +16,16 @@ async def create_dataset(dataset_name: str, user: User, session: AsyncSession) - .options(joinedload(Dataset.data)) .filter(Dataset.name == dataset_name) .filter(Dataset.owner_id == owner_id) + .filter(Dataset.tenant_id == user.tenant_id) ) ).first() if dataset is None: # Dataset id should be generated based on dataset_name and owner_id/user so multiple users can use the same dataset_name dataset_id = await get_unique_dataset_id(dataset_name=dataset_name, user=user) - dataset = Dataset(id=dataset_id, name=dataset_name, data=[]) - dataset.owner_id = owner_id + dataset = Dataset( + id=dataset_id, name=dataset_name, data=[], owner_id=owner_id, tenant_id=user.tenant_id + ) session.add(dataset) diff --git a/cognee/modules/data/methods/get_dataset_ids.py b/cognee/modules/data/methods/get_dataset_ids.py index d4402ff36..a61e85310 100644 --- a/cognee/modules/data/methods/get_dataset_ids.py +++ b/cognee/modules/data/methods/get_dataset_ids.py @@ -27,7 +27,11 @@ async def get_dataset_ids(datasets: Union[list[str], list[UUID]], user): # Get all user owned dataset objects (If a user wants to write to a dataset he is not the owner of it must be provided through UUID.) 
user_datasets = await get_datasets(user.id) # Filter out non name mentioned datasets - dataset_ids = [dataset.id for dataset in user_datasets if dataset.name in datasets] + dataset_ids = [dataset for dataset in user_datasets if dataset.name in datasets] + # Filter out non current tenant datasets + dataset_ids = [ + dataset.id for dataset in dataset_ids if dataset.tenant_id == user.tenant_id + ] else: raise DatasetTypeError( f"One or more of the provided dataset types is not handled: f{datasets}" diff --git a/cognee/modules/data/methods/get_unique_data_id.py b/cognee/modules/data/methods/get_unique_data_id.py new file mode 100644 index 000000000..877b5930c --- /dev/null +++ b/cognee/modules/data/methods/get_unique_data_id.py @@ -0,0 +1,68 @@ +from uuid import uuid5, NAMESPACE_OID, UUID +from sqlalchemy import select + +from cognee.modules.data.models.Data import Data +from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.modules.users.models import User + + +async def get_unique_data_id(data_identifier: str, user: User) -> UUID: + """ + Function returns a unique UUID for data based on data identifier, user id and tenant id. + If data with legacy ID exists, return that ID to maintain compatibility. + + Args: + data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.) + user: User object adding the data + tenant_id: UUID of the tenant for which data is being added + + Returns: + UUID: Unique identifier for the data + """ + + def _get_deprecated_unique_data_id(data_identifier: str, user: User) -> UUID: + """ + Deprecated function, returns a unique UUID for data based on data identifier and user id. + Needed to support legacy data without tenant information. + Args: + data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.) 
+ user: User object adding the data + + Returns: + UUID: Unique identifier for the data + """ + # return UUID hash of file contents + owner id + tenant_id + return uuid5(NAMESPACE_OID, f"{data_identifier}{str(user.id)}") + + def _get_modern_unique_data_id(data_identifier: str, user: User) -> UUID: + """ + Function returns a unique UUID for data based on data identifier, user id and tenant id. + Args: + data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.) + user: User object adding the data + tenant_id: UUID of the tenant for which data is being added + + Returns: + UUID: Unique identifier for the data + """ + # return UUID hash of file contents + owner id + tenant_id + return uuid5(NAMESPACE_OID, f"{data_identifier}{str(user.id)}{str(user.tenant_id)}") + + # Get all possible data_id values + data_id = { + "modern_data_id": _get_modern_unique_data_id(data_identifier=data_identifier, user=user), + "legacy_data_id": _get_deprecated_unique_data_id( + data_identifier=data_identifier, user=user + ), + } + + # Check if data item with legacy_data_id exists, if so use that one, else use modern_data_id + db_engine = get_relational_engine() + async with db_engine.get_async_session() as session: + legacy_data_point = ( + await session.execute(select(Data).filter(Data.id == data_id["legacy_data_id"])) + ).scalar_one_or_none() + + if not legacy_data_point: + return data_id["modern_data_id"] + return data_id["legacy_data_id"] diff --git a/cognee/modules/data/methods/get_unique_dataset_id.py b/cognee/modules/data/methods/get_unique_dataset_id.py index 2caf5fb55..2b765ec78 100644 --- a/cognee/modules/data/methods/get_unique_dataset_id.py +++ b/cognee/modules/data/methods/get_unique_dataset_id.py @@ -1,9 +1,71 @@ from uuid import UUID, uuid5, NAMESPACE_OID -from cognee.modules.users.models import User from typing import Union +from sqlalchemy import select + +from cognee.modules.data.models.Dataset import Dataset +from cognee.modules.users.models import 
User +from cognee.infrastructure.databases.relational import get_relational_engine async def get_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID: - if isinstance(dataset_name, UUID): - return dataset_name - return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}") + """ + Function returns a unique UUID for dataset based on dataset name, user id and tenant id. + If dataset with legacy ID exists, return that ID to maintain compatibility. + + Args: + dataset_name: string representing the dataset name + user: User object adding the dataset + tenant_id: UUID of the tenant for which dataset is being added + + Returns: + UUID: Unique identifier for the dataset + """ + + def _get_legacy_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID: + """ + Legacy function, returns a unique UUID for dataset based on dataset name and user id. + Needed to support legacy datasets without tenant information. + Args: + dataset_name: string representing the dataset name + user: Current User object adding the dataset + + Returns: + UUID: Unique identifier for the dataset + """ + if isinstance(dataset_name, UUID): + return dataset_name + return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}") + + def _get_modern_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID: + """ + Returns a unique UUID for dataset based on dataset name, user id and tenant_id. 
+ Args: + dataset_name: string representing the dataset name + user: Current User object adding the dataset + tenant_id: UUID of the tenant for which dataset is being added + + Returns: + UUID: Unique identifier for the dataset + """ + if isinstance(dataset_name, UUID): + return dataset_name + return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}{str(user.tenant_id)}") + + # Get all possible dataset_id values + dataset_id = { + "modern_dataset_id": _get_modern_unique_dataset_id(dataset_name=dataset_name, user=user), + "legacy_dataset_id": _get_legacy_unique_dataset_id(dataset_name=dataset_name, user=user), + } + + # Check if dataset with legacy_dataset_id exists, if so use that one, else use modern_dataset_id + db_engine = get_relational_engine() + async with db_engine.get_async_session() as session: + legacy_dataset = ( + await session.execute( + select(Dataset).filter(Dataset.id == dataset_id["legacy_dataset_id"]) + ) + ).scalar_one_or_none() + + if not legacy_dataset: + return dataset_id["modern_dataset_id"] + return dataset_id["legacy_dataset_id"] diff --git a/cognee/modules/data/models/Dataset.py b/cognee/modules/data/models/Dataset.py index 797401d5a..fba065253 100644 --- a/cognee/modules/data/models/Dataset.py +++ b/cognee/modules/data/models/Dataset.py @@ -18,6 +18,7 @@ class Dataset(Base): updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)) owner_id = Column(UUID, index=True) + tenant_id = Column(UUID, index=True, nullable=True) acls = relationship("ACL", back_populates="dataset", cascade="all, delete-orphan") @@ -36,5 +37,6 @@ class Dataset(Base): "createdAt": self.created_at.isoformat(), "updatedAt": self.updated_at.isoformat() if self.updated_at else None, "ownerId": str(self.owner_id), + "tenantId": str(self.tenant_id), "data": [data.to_json() for data in self.data], } diff --git a/cognee/modules/data/processing/document_types/CsvDocument.py b/cognee/modules/data/processing/document_types/CsvDocument.py new 
file mode 100644 index 000000000..3381275bd --- /dev/null +++ b/cognee/modules/data/processing/document_types/CsvDocument.py @@ -0,0 +1,33 @@ +import io +import csv +from typing import Type + +from cognee.modules.chunking.Chunker import Chunker +from cognee.infrastructure.files.utils.open_data_file import open_data_file +from .Document import Document + + +class CsvDocument(Document): + type: str = "csv" + mime_type: str = "text/csv" + + async def read(self, chunker_cls: Type[Chunker], max_chunk_size: int): + async def get_text(): + async with open_data_file( + self.raw_data_location, mode="r", encoding="utf-8", newline="" + ) as file: + content = file.read() + file_like_obj = io.StringIO(content) + reader = csv.DictReader(file_like_obj) + + for row in reader: + pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()] + row_text = ", ".join(pairs) + if not row_text.strip(): + break + yield row_text + + chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=get_text) + + async for chunk in chunker.read(): + yield chunk diff --git a/cognee/modules/data/processing/document_types/__init__.py b/cognee/modules/data/processing/document_types/__init__.py index 2e862f4ba..133dd53f8 100644 --- a/cognee/modules/data/processing/document_types/__init__.py +++ b/cognee/modules/data/processing/document_types/__init__.py @@ -4,3 +4,4 @@ from .TextDocument import TextDocument from .ImageDocument import ImageDocument from .AudioDocument import AudioDocument from .UnstructuredDocument import UnstructuredDocument +from .CsvDocument import CsvDocument diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index 9703928f0..2e0b82e8d 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -56,6 +56,68 @@ class CogneeGraph(CogneeAbstractGraph): def get_edges(self) -> List[Edge]: return self.edges + async def _get_nodeset_subgraph( + self, + adapter, + 
node_type, + node_name, + ): + """Retrieve subgraph based on node type and name.""" + logger.info("Retrieving graph filtered by node type and node name (NodeSet).") + nodes_data, edges_data = await adapter.get_nodeset_subgraph( + node_type=node_type, node_name=node_name + ) + if not nodes_data or not edges_data: + raise EntityNotFoundError( + message="Nodeset does not exist, or empty nodeset projected from the database." + ) + return nodes_data, edges_data + + async def _get_full_or_id_filtered_graph( + self, + adapter, + relevant_ids_to_filter, + ): + """Retrieve full or ID-filtered graph with fallback.""" + if relevant_ids_to_filter is None: + logger.info("Retrieving full graph.") + nodes_data, edges_data = await adapter.get_graph_data() + if not nodes_data or not edges_data: + raise EntityNotFoundError(message="Empty graph projected from the database.") + return nodes_data, edges_data + + get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data) + if getattr(adapter.__class__, "get_id_filtered_graph_data", None): + logger.info("Retrieving ID-filtered graph from database.") + nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter) + else: + logger.info("Retrieving full graph from database.") + nodes_data, edges_data = await get_graph_data_fn() + if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data): + logger.warning( + "Id filtered graph returned empty, falling back to full graph retrieval." 
+ ) + logger.info("Retrieving full graph") + nodes_data, edges_data = await adapter.get_graph_data() + + if not nodes_data or not edges_data: + raise EntityNotFoundError("Empty graph projected from the database.") + return nodes_data, edges_data + + async def _get_filtered_graph( + self, + adapter, + memory_fragment_filter, + ): + """Retrieve graph filtered by attributes.""" + logger.info("Retrieving graph filtered by memory fragment") + nodes_data, edges_data = await adapter.get_filtered_graph_data( + attribute_filters=memory_fragment_filter + ) + if not nodes_data or not edges_data: + raise EntityNotFoundError(message="Empty filtered graph projected from the database.") + return nodes_data, edges_data + async def project_graph_from_db( self, adapter: Union[GraphDBInterface], @@ -67,40 +129,39 @@ class CogneeGraph(CogneeAbstractGraph): memory_fragment_filter=[], node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + relevant_ids_to_filter: Optional[List[str]] = None, + triplet_distance_penalty: float = 3.5, ) -> None: if node_dimension < 1 or edge_dimension < 1: raise InvalidDimensionsError() try: + if node_type is not None and node_name not in [None, [], ""]: + nodes_data, edges_data = await self._get_nodeset_subgraph( + adapter, node_type, node_name + ) + elif len(memory_fragment_filter) == 0: + nodes_data, edges_data = await self._get_full_or_id_filtered_graph( + adapter, relevant_ids_to_filter + ) + else: + nodes_data, edges_data = await self._get_filtered_graph( + adapter, memory_fragment_filter + ) + import time start_time = time.time() - - # Determine projection strategy - if node_type is not None and node_name not in [None, [], ""]: - nodes_data, edges_data = await adapter.get_nodeset_subgraph( - node_type=node_type, node_name=node_name - ) - if not nodes_data or not edges_data: - raise EntityNotFoundError( - message="Nodeset does not exist, or empty nodetes projected from the database." 
- ) - elif len(memory_fragment_filter) == 0: - nodes_data, edges_data = await adapter.get_graph_data() - if not nodes_data or not edges_data: - raise EntityNotFoundError(message="Empty graph projected from the database.") - else: - nodes_data, edges_data = await adapter.get_filtered_graph_data( - attribute_filters=memory_fragment_filter - ) - if not nodes_data or not edges_data: - raise EntityNotFoundError( - message="Empty filtered graph projected from the database." - ) - # Process nodes for node_id, properties in nodes_data: node_attributes = {key: properties.get(key) for key in node_properties_to_project} - self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension)) + self.add_node( + Node( + str(node_id), + node_attributes, + dimension=node_dimension, + node_penalty=triplet_distance_penalty, + ) + ) # Process edges for source_id, target_id, relationship_type, properties in edges_data: @@ -118,6 +179,7 @@ class CogneeGraph(CogneeAbstractGraph): attributes=edge_attributes, directed=directed, dimension=edge_dimension, + edge_penalty=triplet_distance_penalty, ) self.add_edge(edge) @@ -171,8 +233,10 @@ class CogneeGraph(CogneeAbstractGraph): embedding_map = {result.payload["text"]: result.score for result in edge_distances} for edge in self.edges: - relationship_type = edge.attributes.get("relationship_type") - distance = embedding_map.get(relationship_type, None) + edge_key = edge.attributes.get("edge_text") or edge.attributes.get( + "relationship_type" + ) + distance = embedding_map.get(edge_key, None) if distance is not None: edge.attributes["vector_distance"] = distance diff --git a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py index 0ca9c4fb9..62ef8d9fd 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py @@ -20,13 +20,17 @@ class Node: status: np.ndarray def __init__( - self, node_id: str, 
attributes: Optional[Dict[str, Any]] = None, dimension: int = 1 + self, + node_id: str, + attributes: Optional[Dict[str, Any]] = None, + dimension: int = 1, + node_penalty: float = 3.5, ): if dimension <= 0: raise InvalidDimensionsError() self.id = node_id self.attributes = attributes if attributes is not None else {} - self.attributes["vector_distance"] = float("inf") + self.attributes["vector_distance"] = node_penalty self.skeleton_neighbours = [] self.skeleton_edges = [] self.status = np.ones(dimension, dtype=int) @@ -105,13 +109,14 @@ class Edge: attributes: Optional[Dict[str, Any]] = None, directed: bool = True, dimension: int = 1, + edge_penalty: float = 3.5, ): if dimension <= 0: raise InvalidDimensionsError() self.node1 = node1 self.node2 = node2 self.attributes = attributes if attributes is not None else {} - self.attributes["vector_distance"] = float("inf") + self.attributes["vector_distance"] = edge_penalty self.directed = directed self.status = np.ones(dimension, dtype=int) diff --git a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py index 3b01f5af4..c68eb494d 100644 --- a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py +++ b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py @@ -1,5 +1,6 @@ from typing import Optional +from cognee.infrastructure.engine.models.Edge import Edge from cognee.modules.chunking.models import DocumentChunk from cognee.modules.engine.models import Entity, EntityType from cognee.modules.engine.utils import ( @@ -243,10 +244,26 @@ def _process_graph_nodes( ontology_relationships, ) - # Add entity to data chunk if data_chunk.contains is None: data_chunk.contains = [] - data_chunk.contains.append(entity_node) + + edge_text = "; ".join( + [ + "relationship_name: contains", + f"entity_name: {entity_node.name}", + f"entity_description: {entity_node.description}", + ] + ) + + data_chunk.contains.append( + ( + Edge( + relationship_type="contains", + 
edge_text=edge_text, + ), + entity_node, + ) + ) def _process_graph_edges( diff --git a/cognee/modules/graph/utils/resolve_edges_to_text.py b/cognee/modules/graph/utils/resolve_edges_to_text.py index eb5bedd2c..5deb13ba8 100644 --- a/cognee/modules/graph/utils/resolve_edges_to_text.py +++ b/cognee/modules/graph/utils/resolve_edges_to_text.py @@ -1,71 +1,70 @@ +import string from typing import List +from collections import Counter + from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS + + +def _get_top_n_frequent_words( + text: str, stop_words: set = None, top_n: int = 3, separator: str = ", " +) -> str: + """Concatenates the top N frequent words in text.""" + if stop_words is None: + stop_words = DEFAULT_STOP_WORDS + + words = [word.lower().strip(string.punctuation) for word in text.split()] + words = [word for word in words if word and word not in stop_words] + + top_words = [word for word, freq in Counter(words).most_common(top_n)] + return separator.join(top_words) + + +def _create_title_from_text(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str: + """Creates a title by combining first words with most frequent words from the text.""" + first_words = text.split()[:first_n_words] + top_words = _get_top_n_frequent_words(text, top_n=top_n_words) + return f"{' '.join(first_words)}... 
[{top_words}]" + + +def _extract_nodes_from_edges(retrieved_edges: List[Edge]) -> dict: + """Creates a dictionary of nodes with their names and content.""" + nodes = {} + + for edge in retrieved_edges: + for node in (edge.node1, edge.node2): + if node.id in nodes: + continue + + text = node.attributes.get("text") + if text: + name = _create_title_from_text(text) + content = text + else: + name = node.attributes.get("name", "Unnamed Node") + content = node.attributes.get("description", name) + + nodes[node.id] = {"node": node, "name": name, "content": content} + + return nodes async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str: - """ - Converts retrieved graph edges into a human-readable string format. + """Converts retrieved graph edges into a human-readable string format.""" + nodes = _extract_nodes_from_edges(retrieved_edges) - Parameters: - ----------- - - - retrieved_edges (list): A list of edges retrieved from the graph. - - Returns: - -------- - - - str: A formatted string representation of the nodes and their connections. 
- """ - - def _get_nodes(retrieved_edges: List[Edge]) -> dict: - def _get_title(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str: - def _top_n_words(text, stop_words=None, top_n=3, separator=", "): - """Concatenates the top N frequent words in text.""" - if stop_words is None: - from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS - - stop_words = DEFAULT_STOP_WORDS - - import string - - words = [word.lower().strip(string.punctuation) for word in text.split()] - - if stop_words: - words = [word for word in words if word and word not in stop_words] - - from collections import Counter - - top_words = [word for word, freq in Counter(words).most_common(top_n)] - - return separator.join(top_words) - - """Creates a title, by combining first words with most frequent words from the text.""" - first_words = text.split()[:first_n_words] - top_words = _top_n_words(text, top_n=first_n_words) - return f"{' '.join(first_words)}... [{top_words}]" - - """Creates a dictionary of nodes with their names and content.""" - nodes = {} - for edge in retrieved_edges: - for node in (edge.node1, edge.node2): - if node.id not in nodes: - text = node.attributes.get("text") - if text: - name = _get_title(text) - content = text - else: - name = node.attributes.get("name", "Unnamed Node") - content = node.attributes.get("description", name) - nodes[node.id] = {"node": node, "name": name, "content": content} - return nodes - - nodes = _get_nodes(retrieved_edges) node_section = "\n".join( f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n" for info in nodes.values() ) - connection_section = "\n".join( - f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}" - for edge in retrieved_edges - ) + + connections = [] + for edge in retrieved_edges: + source_name = nodes[edge.node1.id]["name"] + target_name = nodes[edge.node2.id]["name"] + edge_label = 
edge.attributes.get("edge_text") or edge.attributes.get("relationship_type") + connections.append(f"{source_name} --[{edge_label}]--> {target_name}") + + connection_section = "\n".join(connections) + return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}" diff --git a/cognee/modules/ingestion/identify.py b/cognee/modules/ingestion/identify.py index 977ff3f0b..640fce4a2 100644 --- a/cognee/modules/ingestion/identify.py +++ b/cognee/modules/ingestion/identify.py @@ -1,11 +1,11 @@ -from uuid import uuid5, NAMESPACE_OID +from uuid import UUID from .data_types import IngestionData from cognee.modules.users.models import User +from cognee.modules.data.methods import get_unique_data_id -def identify(data: IngestionData, user: User) -> str: +async def identify(data: IngestionData, user: User) -> UUID: data_content_hash: str = data.get_identifier() - # return UUID hash of file contents + owner id - return uuid5(NAMESPACE_OID, f"{data_content_hash}{user.id}") + return await get_unique_data_id(data_identifier=data_content_hash, user=user) diff --git a/cognee/modules/notebooks/operations/run_in_local_sandbox.py b/cognee/modules/notebooks/operations/run_in_local_sandbox.py index 071deafb7..46499186e 100644 --- a/cognee/modules/notebooks/operations/run_in_local_sandbox.py +++ b/cognee/modules/notebooks/operations/run_in_local_sandbox.py @@ -2,6 +2,8 @@ import io import sys import traceback +import cognee + def wrap_in_async_handler(user_code: str) -> str: return ( @@ -34,6 +36,7 @@ def run_in_local_sandbox(code, environment=None, loop=None): environment["print"] = customPrintFunction environment["running_loop"] = loop + environment["cognee"] = cognee try: exec(code, environment) diff --git a/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py b/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py index 45e32936a..34d7a946a 100644 --- a/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py +++ 
b/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py @@ -2,7 +2,7 @@ import os import difflib from cognee.shared.logging_utils import get_logger from collections import deque -from typing import List, Tuple, Dict, Optional, Any, Union +from typing import List, Tuple, Dict, Optional, Any, Union, IO from rdflib import Graph, URIRef, RDF, RDFS, OWL from cognee.modules.ontology.exceptions import ( @@ -26,44 +26,76 @@ class RDFLibOntologyResolver(BaseOntologyResolver): def __init__( self, - ontology_file: Optional[Union[str, List[str]]] = None, + ontology_file: Optional[Union[str, List[str], IO, List[IO]]] = None, matching_strategy: Optional[MatchingStrategy] = None, ) -> None: super().__init__(matching_strategy) self.ontology_file = ontology_file try: - files_to_load = [] + self.graph = None if ontology_file is not None: - if isinstance(ontology_file, str): + files_to_load = [] + file_objects = [] + + if hasattr(ontology_file, "read"): + file_objects = [ontology_file] + elif isinstance(ontology_file, str): files_to_load = [ontology_file] elif isinstance(ontology_file, list): - files_to_load = ontology_file + if all(hasattr(item, "read") for item in ontology_file): + file_objects = ontology_file + else: + files_to_load = ontology_file else: raise ValueError( - f"ontology_file must be a string, list of strings, or None. Got: {type(ontology_file)}" + f"ontology_file must be a string, list of strings, file-like object, list of file-like objects, or None. Got: {type(ontology_file)}" ) - if files_to_load: - self.graph = Graph() - loaded_files = [] - for file_path in files_to_load: - if os.path.exists(file_path): - self.graph.parse(file_path) - loaded_files.append(file_path) - logger.info("Ontology loaded successfully from file: %s", file_path) - else: - logger.warning( - "Ontology file '%s' not found. 
Skipping this file.", - file_path, + if file_objects: + self.graph = Graph() + loaded_objects = [] + for file_obj in file_objects: + try: + content = file_obj.read() + self.graph.parse(data=content, format="xml") + loaded_objects.append(file_obj) + logger.info("Ontology loaded successfully from file object") + except Exception as e: + logger.warning("Failed to parse ontology file object: %s", str(e)) + + if not loaded_objects: + logger.info( + "No valid ontology file objects found. No owl ontology will be attached to the graph." ) + self.graph = None + else: + logger.info("Total ontology file objects loaded: %d", len(loaded_objects)) - if not loaded_files: - logger.info( - "No valid ontology files found. No owl ontology will be attached to the graph." - ) - self.graph = None + elif files_to_load: + self.graph = Graph() + loaded_files = [] + for file_path in files_to_load: + if os.path.exists(file_path): + self.graph.parse(file_path) + loaded_files.append(file_path) + logger.info("Ontology loaded successfully from file: %s", file_path) + else: + logger.warning( + "Ontology file '%s' not found. Skipping this file.", + file_path, + ) + + if not loaded_files: + logger.info( + "No valid ontology files found. No owl ontology will be attached to the graph." + ) + self.graph = None + else: + logger.info("Total ontology files loaded: %d", len(loaded_files)) else: - logger.info("Total ontology files loaded: %d", len(loaded_files)) + logger.info( + "No ontology file provided. No owl ontology will be attached to the graph." + ) else: logger.info( "No ontology file provided. No owl ontology will be attached to the graph." 
diff --git a/cognee/modules/pipelines/operations/run_tasks_data_item.py b/cognee/modules/pipelines/operations/run_tasks_data_item.py index 152e72d7f..2cc449df6 100644 --- a/cognee/modules/pipelines/operations/run_tasks_data_item.py +++ b/cognee/modules/pipelines/operations/run_tasks_data_item.py @@ -69,7 +69,7 @@ async def run_tasks_data_item_incremental( async with open_data_file(file_path) as file: classified_data = ingestion.classify(file) # data_id is the hash of file contents + owner id to avoid duplicate data - data_id = ingestion.identify(classified_data, user) + data_id = await ingestion.identify(classified_data, user) else: # If data was already processed by Cognee get data id data_id = data_item.id diff --git a/cognee/modules/retrieval/EntityCompletionRetriever.py b/cognee/modules/retrieval/EntityCompletionRetriever.py index 6086977ce..14996f902 100644 --- a/cognee/modules/retrieval/EntityCompletionRetriever.py +++ b/cognee/modules/retrieval/EntityCompletionRetriever.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Optional, List +from typing import Any, Optional, List, Type from cognee.shared.logging_utils import get_logger from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor @@ -85,8 +85,12 @@ class EntityCompletionRetriever(BaseRetriever): return None async def get_completion( - self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None - ) -> List[str]: + self, + query: str, + context: Optional[Any] = None, + session_id: Optional[str] = None, + response_model: Type = str, + ) -> List[Any]: """ Generate completion using provided context or fetch new context. @@ -102,6 +106,7 @@ class EntityCompletionRetriever(BaseRetriever): fetched if not provided. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) + - response_model (Type): The Pydantic model type for structured output. 
(default str) Returns: -------- @@ -133,6 +138,7 @@ class EntityCompletionRetriever(BaseRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -141,6 +147,7 @@ class EntityCompletionRetriever(BaseRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + response_model=response_model, ) if session_save: diff --git a/cognee/modules/retrieval/base_graph_retriever.py b/cognee/modules/retrieval/base_graph_retriever.py index b0abc2991..b203309ba 100644 --- a/cognee/modules/retrieval/base_graph_retriever.py +++ b/cognee/modules/retrieval/base_graph_retriever.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, List, Optional, Type from abc import ABC, abstractmethod from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge @@ -14,7 +14,11 @@ class BaseGraphRetriever(ABC): @abstractmethod async def get_completion( - self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None - ) -> str: + self, + query: str, + context: Optional[List[Edge]] = None, + session_id: Optional[str] = None, + response_model: Type = str, + ) -> List[Any]: """Generates a response using the query and optional context (triplets).""" pass diff --git a/cognee/modules/retrieval/base_retriever.py b/cognee/modules/retrieval/base_retriever.py index 1533dd44f..b88c741b8 100644 --- a/cognee/modules/retrieval/base_retriever.py +++ b/cognee/modules/retrieval/base_retriever.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any, Optional, Type, List class BaseRetriever(ABC): @@ -12,7 +12,11 @@ class BaseRetriever(ABC): @abstractmethod async def get_completion( - self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None - ) -> Any: + self, + query: str, + context: Optional[Any] = 
None, + session_id: Optional[str] = None, + response_model: Type = str, + ) -> List[Any]: """Generates a response using the query and optional context.""" pass diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index bb568924d..126ebcab8 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Optional +from typing import Any, Optional, Type, List from cognee.shared.logging_utils import get_logger from cognee.infrastructure.databases.vector import get_vector_engine @@ -75,8 +75,12 @@ class CompletionRetriever(BaseRetriever): raise NoDataError("No data found in the system, please add data first.") from error async def get_completion( - self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None - ) -> str: + self, + query: str, + context: Optional[Any] = None, + session_id: Optional[str] = None, + response_model: Type = str, + ) -> List[Any]: """ Generates an LLM completion using the context. @@ -91,6 +95,7 @@ class CompletionRetriever(BaseRetriever): completion; if None, it retrieves the context for the query. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) + - response_model (Type): The Pydantic model type for structured output. 
(default str) Returns: -------- @@ -118,6 +123,7 @@ class CompletionRetriever(BaseRetriever): system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -127,6 +133,7 @@ class CompletionRetriever(BaseRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, + response_model=response_model, ) if session_save: @@ -137,4 +144,4 @@ class CompletionRetriever(BaseRetriever): session_id=session_id, ) - return completion + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 58b6b586f..fc49a139b 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -1,5 +1,5 @@ import asyncio -from typing import Optional, List, Type +from typing import Optional, List, Type, Any from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever @@ -39,6 +39,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): super().__init__( user_prompt_path=user_prompt_path, @@ -48,6 +50,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) async def get_completion( @@ -56,7 +60,8 @@ class 
GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): context: Optional[List[Edge]] = None, session_id: Optional[str] = None, context_extension_rounds=4, - ) -> List[str]: + response_model: Type = str, + ) -> List[Any]: """ Extends the context for a given query by retrieving related triplets and generating new completions based on them. @@ -76,6 +81,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): defaults to 'default_session'. (default None) - context_extension_rounds: The maximum number of rounds to extend the context with new triplets before halting. (default 4) + - response_model (Type): The Pydantic model type for structured output. (default str) Returns: -------- @@ -143,6 +149,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -152,6 +159,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, + response_model=response_model, ) if self.save_interaction and context_text and triplets and completion: diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 299db6855..70fcb6cdb 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -7,7 +7,7 @@ from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.utils.completion import ( - generate_structured_completion, + generate_completion, summarize_text, ) from cognee.modules.retrieval.utils.session_cache import ( @@ -44,7 +44,6 @@ class 
GraphCompletionCotRetriever(GraphCompletionRetriever): questions based on reasoning. The public methods are: - get_completion - - get_structured_completion Instance variables include: - validation_system_prompt_path @@ -66,6 +65,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): super().__init__( user_prompt_path=user_prompt_path, @@ -75,6 +76,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): node_type=node_type, node_name=node_name, save_interaction=save_interaction, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) self.validation_system_prompt_path = validation_system_prompt_path self.validation_user_prompt_path = validation_user_prompt_path @@ -121,7 +124,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): triplets += await self.get_context(followup_question) context_text = await self.resolve_edges_to_text(list(set(triplets))) - completion = await generate_structured_completion( + completion = await generate_completion( query=query, context=context_text, user_prompt_path=self.user_prompt_path, @@ -165,24 +168,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): return completion, context_text, triplets - async def get_structured_completion( + async def get_completion( self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None, - max_iter: int = 4, + max_iter=4, response_model: Type = str, - ) -> Any: + ) -> List[Any]: """ - Generate structured completion responses based on a user query and contextual information. + Generate completion responses based on a user query and contextual information. 
- This method applies the same chain-of-thought logic as get_completion but returns + This method interacts with a language model client to retrieve a structured response, + using a series of iterations to refine the answers and generate follow-up questions + based on reasoning derived from previous outputs. It raises exceptions if the context + retrieval fails or if the model encounters issues in generating outputs. It returns structured output using the provided response model. Parameters: ----------- + - query (str): The user's query to be processed and answered. - - context (Optional[List[Edge]]): Optional context that may assist in answering the query. + - context (Optional[Any]): Optional context that may assist in answering the query. If not provided, it will be fetched based on the query. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) @@ -192,7 +199,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): Returns: -------- - - Any: The generated structured completion based on the response model. + + - List[str]: A list containing the generated answer to the user's query. """ # Check if session saving is enabled cache_config = CacheConfig() @@ -228,45 +236,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): session_id=session_id, ) - return completion - - async def get_completion( - self, - query: str, - context: Optional[List[Edge]] = None, - session_id: Optional[str] = None, - max_iter=4, - ) -> List[str]: - """ - Generate completion responses based on a user query and contextual information. - - This method interacts with a language model client to retrieve a structured response, - using a series of iterations to refine the answers and generate follow-up questions - based on reasoning derived from previous outputs. It raises exceptions if the context - retrieval fails or if the model encounters issues in generating outputs. 
- - Parameters: - ----------- - - - query (str): The user's query to be processed and answered. - - context (Optional[Any]): Optional context that may assist in answering the query. - If not provided, it will be fetched based on the query. (default None) - - session_id (Optional[str]): Optional session identifier for caching. If None, - defaults to 'default_session'. (default None) - - max_iter: The maximum number of iterations to refine the answer and generate - follow-up questions. (default 4) - - Returns: - -------- - - - List[str]: A list containing the generated answer to the user's query. - """ - completion = await self.get_structured_completion( - query=query, - context=context, - session_id=session_id, - max_iter=max_iter, - response_model=str, - ) - return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index b7ab4edae..89e9e47ce 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -47,6 +47,8 @@ class GraphCompletionRetriever(BaseGraphRetriever): node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): """Initialize retriever with prompt paths and search parameters.""" self.save_interaction = save_interaction @@ -54,8 +56,10 @@ class GraphCompletionRetriever(BaseGraphRetriever): self.system_prompt_path = system_prompt_path self.system_prompt = system_prompt self.top_k = top_k if top_k is not None else 5 + self.wide_search_top_k = wide_search_top_k self.node_type = node_type self.node_name = node_name + self.triplet_distance_penalty = triplet_distance_penalty async def resolve_edges_to_text(self, retrieved_edges: list) -> str: """ @@ -105,6 +109,8 @@ class GraphCompletionRetriever(BaseGraphRetriever): collections=vector_index_collections or None, 
node_type=self.node_type, node_name=self.node_name, + wide_search_top_k=self.wide_search_top_k, + triplet_distance_penalty=self.triplet_distance_penalty, ) return found_triplets @@ -141,12 +147,17 @@ class GraphCompletionRetriever(BaseGraphRetriever): return triplets + async def convert_retrieved_objects_to_context(self, triplets: List[Edge]): + context = await self.resolve_edges_to_text(triplets) + return context + async def get_completion( self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None, - ) -> List[str]: + response_model: Type = str, + ) -> List[Any]: """ Generates a completion using graph connections context based on a query. @@ -188,6 +199,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -197,6 +209,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, + response_model=response_model, ) if self.save_interaction and context and triplets and completion: diff --git a/cognee/modules/retrieval/graph_summary_completion_retriever.py b/cognee/modules/retrieval/graph_summary_completion_retriever.py index 051f39b22..e31ad126e 100644 --- a/cognee/modules/retrieval/graph_summary_completion_retriever.py +++ b/cognee/modules/retrieval/graph_summary_completion_retriever.py @@ -26,6 +26,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): """Initialize retriever with default prompt paths and search parameters.""" super().__init__( @@ -36,6 +38,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): 
node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) self.summarize_prompt_path = summarize_prompt_path diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py index ec68d37bb..87d2ab009 100644 --- a/cognee/modules/retrieval/temporal_retriever.py +++ b/cognee/modules/retrieval/temporal_retriever.py @@ -47,6 +47,8 @@ class TemporalRetriever(GraphCompletionRetriever): top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): super().__init__( user_prompt_path=user_prompt_path, @@ -54,6 +56,8 @@ class TemporalRetriever(GraphCompletionRetriever): top_k=top_k, node_type=node_type, node_name=node_name, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path @@ -146,8 +150,12 @@ class TemporalRetriever(GraphCompletionRetriever): return self.descriptions_to_string(top_k_events) async def get_completion( - self, query: str, context: Optional[str] = None, session_id: Optional[str] = None - ) -> List[str]: + self, + query: str, + context: Optional[str] = None, + session_id: Optional[str] = None, + response_model: Type = str, + ) -> List[Any]: """ Generates a response using the query and optional context. @@ -159,6 +167,7 @@ class TemporalRetriever(GraphCompletionRetriever): retrieved based on the query. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) + - response_model (Type): The Pydantic model type for structured output. 
(default str) Returns: -------- @@ -186,6 +195,7 @@ class TemporalRetriever(GraphCompletionRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -194,6 +204,7 @@ class TemporalRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + response_model=response_model, ) if session_save: diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index 1ef7545c2..2f8a545f7 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -58,6 +58,8 @@ async def get_memory_fragment( properties_to_project: Optional[List[str]] = None, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + relevant_ids_to_filter: Optional[List[str]] = None, + triplet_distance_penalty: Optional[float] = 3.5, ) -> CogneeGraph: """Creates and initializes a CogneeGraph memory fragment with optional property projections.""" if properties_to_project is None: @@ -71,9 +73,11 @@ async def get_memory_fragment( await memory_fragment.project_graph_from_db( graph_engine, node_properties_to_project=properties_to_project, - edge_properties_to_project=["relationship_name"], + edge_properties_to_project=["relationship_name", "edge_text"], node_type=node_type, node_name=node_name, + relevant_ids_to_filter=relevant_ids_to_filter, + triplet_distance_penalty=triplet_distance_penalty, ) except EntityNotFoundError: @@ -95,6 +99,8 @@ async def brute_force_triplet_search( memory_fragment: Optional[CogneeGraph] = None, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> List[Edge]: """ Performs a brute force search 
to retrieve the top triplets from the graph. @@ -107,6 +113,8 @@ async def brute_force_triplet_search( memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse. node_type: node type to filter node_name: node name to filter + wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections + triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection Returns: list: The top triplet results. @@ -116,10 +124,10 @@ async def brute_force_triplet_search( if top_k <= 0: raise ValueError("top_k must be a positive integer.") - if memory_fragment is None: - memory_fragment = await get_memory_fragment( - properties_to_project, node_type=node_type, node_name=node_name - ) + # Setting wide search limit based on the parameters + non_global_search = node_name is None + + wide_search_limit = wide_search_top_k if non_global_search else None if collections is None: collections = [ @@ -140,7 +148,7 @@ async def brute_force_triplet_search( async def search_in_collection(collection_name: str): try: return await vector_engine.search( - collection_name=collection_name, query_vector=query_vector, limit=None + collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit ) except CollectionNotFoundError: return [] @@ -156,15 +164,38 @@ async def brute_force_triplet_search( return [] # Final statistics - projection_time = time.time() - start_time + vector_collection_search_time = time.time() - start_time logger.info( - f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s" + f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s" ) node_distances = {collection: result for collection, result in zip(collections, results)} edge_distances = node_distances.get("EdgeType_relationship_name", None) + if wide_search_limit is 
not None: + relevant_ids_to_filter = list( + { + str(getattr(scored_node, "id")) + for collection_name, score_collection in node_distances.items() + if collection_name != "EdgeType_relationship_name" + and isinstance(score_collection, (list, tuple)) + for scored_node in score_collection + if getattr(scored_node, "id", None) + } + ) + else: + relevant_ids_to_filter = None + + if memory_fragment is None: + memory_fragment = await get_memory_fragment( + properties_to_project=properties_to_project, + node_type=node_type, + node_name=node_name, + relevant_ids_to_filter=relevant_ids_to_filter, + triplet_distance_penalty=triplet_distance_penalty, + ) + await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances) await memory_fragment.map_vector_distances_to_graph_edges( vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py index db7a10252..c90ce77f4 100644 --- a/cognee/modules/retrieval/utils/completion.py +++ b/cognee/modules/retrieval/utils/completion.py @@ -3,7 +3,7 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt -async def generate_structured_completion( +async def generate_completion( query: str, context: str, user_prompt_path: str, @@ -12,7 +12,7 @@ async def generate_structured_completion( conversation_history: Optional[str] = None, response_model: Type = str, ) -> Any: - """Generates a structured completion using LLM with given context and prompts.""" + """Generates a completion using LLM with given context and prompts.""" args = {"question": query, "context": context} user_prompt = render_prompt(user_prompt_path, args) system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path) @@ -28,26 +28,6 @@ async def generate_structured_completion( ) -async def generate_completion( - query: 
str, - context: str, - user_prompt_path: str, - system_prompt_path: str, - system_prompt: Optional[str] = None, - conversation_history: Optional[str] = None, -) -> str: - """Generates a completion using LLM with given context and prompts.""" - return await generate_structured_completion( - query=query, - context=context, - user_prompt_path=user_prompt_path, - system_prompt_path=system_prompt_path, - system_prompt=system_prompt, - conversation_history=conversation_history, - response_model=str, - ) - - async def summarize_text( text: str, system_prompt_path: str = "summarize_search_results.txt", diff --git a/cognee/modules/run_custom_pipeline/__init__.py b/cognee/modules/run_custom_pipeline/__init__.py new file mode 100644 index 000000000..2d30e2e0c --- /dev/null +++ b/cognee/modules/run_custom_pipeline/__init__.py @@ -0,0 +1 @@ +from .run_custom_pipeline import run_custom_pipeline diff --git a/cognee/modules/run_custom_pipeline/run_custom_pipeline.py b/cognee/modules/run_custom_pipeline/run_custom_pipeline.py new file mode 100644 index 000000000..d3df1c060 --- /dev/null +++ b/cognee/modules/run_custom_pipeline/run_custom_pipeline.py @@ -0,0 +1,69 @@ +from typing import Union, Optional, List, Type, Any +from uuid import UUID + +from cognee.shared.logging_utils import get_logger + +from cognee.modules.pipelines import run_pipeline +from cognee.modules.pipelines.tasks.task import Task +from cognee.modules.users.models import User +from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor + +logger = get_logger() + + +async def run_custom_pipeline( + tasks: Union[List[Task], List[str]] = None, + data: Any = None, + dataset: Union[str, UUID] = "main_dataset", + user: User = None, + vector_db_config: Optional[dict] = None, + graph_db_config: Optional[dict] = None, + data_per_batch: int = 20, + run_in_background: bool = False, + pipeline_name: str = "custom_pipeline", +): + """ + Custom pipeline in Cognee, can work with already built 
graphs. Data needs to be provided which can be processed + with provided tasks. + + Provided tasks and data will be arranged to run the Cognee pipeline and execute graph enrichment/creation. + + This is the core processing step in Cognee that converts raw text and documents + into an intelligent knowledge graph. It analyzes content, extracts entities and + relationships, and creates semantic connections for enhanced search and reasoning. + + Args: + tasks: List of Cognee Tasks to execute. + data: The data to ingest. Can be anything when custom extraction and enrichment tasks are used. + Data provided here will be forwarded to the first extraction task in the pipeline as input. + dataset: Dataset name or dataset uuid to process. + user: User context for authentication and data access. Uses default if None. + vector_db_config: Custom vector database configuration for embeddings storage. + graph_db_config: Custom graph database configuration for relationship storage. + data_per_batch: Number of data items to be processed in parallel. + run_in_background: If True, starts processing asynchronously and returns immediately. + If False, waits for completion before returning. + Background mode recommended for large datasets (>100MB). + Use pipeline_run_id from return value to monitor progress. 
+ """ + + custom_tasks = [ + *tasks, + ] + + # By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for + pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background) + + # Run the run_pipeline in the background or blocking based on executor + return await pipeline_executor_func( + pipeline=run_pipeline, + tasks=custom_tasks, + user=user, + data=data, + datasets=dataset, + vector_db_config=vector_db_config, + graph_db_config=graph_db_config, + incremental_loading=False, + data_per_batch=data_per_batch, + pipeline_name=pipeline_name, + ) diff --git a/cognee/modules/search/methods/get_search_type_tools.py b/cognee/modules/search/methods/get_search_type_tools.py index 72e2db89a..165ec379b 100644 --- a/cognee/modules/search/methods/get_search_type_tools.py +++ b/cognee/modules/search/methods/get_search_type_tools.py @@ -37,6 +37,8 @@ async def get_search_type_tools( node_name: Optional[List[str]] = None, save_interaction: bool = False, last_k: Optional[int] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> list: search_tasks: dict[SearchType, List[Callable]] = { SearchType.SUMMARIES: [ @@ -67,6 +69,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_completion, GraphCompletionRetriever( system_prompt_path=system_prompt_path, @@ -75,6 +79,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_context, ], SearchType.GRAPH_COMPLETION_COT: [ @@ -85,6 +91,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + 
wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_completion, GraphCompletionCotRetriever( system_prompt_path=system_prompt_path, @@ -93,6 +101,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_context, ], SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [ @@ -103,6 +113,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_completion, GraphCompletionContextExtensionRetriever( system_prompt_path=system_prompt_path, @@ -111,6 +123,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_context, ], SearchType.GRAPH_SUMMARY_COMPLETION: [ @@ -121,6 +135,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_completion, GraphSummaryCompletionRetriever( system_prompt_path=system_prompt_path, @@ -129,6 +145,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_context, ], SearchType.CODE: [ @@ -145,8 +163,16 @@ async def get_search_type_tools( ], SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback], SearchType.TEMPORAL: [ - TemporalRetriever(top_k=top_k).get_completion, - TemporalRetriever(top_k=top_k).get_context, + TemporalRetriever( + top_k=top_k, + wide_search_top_k=wide_search_top_k, + 
triplet_distance_penalty=triplet_distance_penalty, + ).get_completion, + TemporalRetriever( + top_k=top_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, + ).get_context, ], SearchType.CHUNKS_LEXICAL: ( lambda _r=JaccardChunksRetriever(top_k=top_k): [ diff --git a/cognee/modules/search/methods/no_access_control_search.py b/cognee/modules/search/methods/no_access_control_search.py index fcb02da46..3a703bbc9 100644 --- a/cognee/modules/search/methods/no_access_control_search.py +++ b/cognee/modules/search/methods/no_access_control_search.py @@ -24,6 +24,8 @@ async def no_access_control_search( last_k: Optional[int] = None, only_context: bool = False, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]: search_tools = await get_search_type_tools( query_type=query_type, @@ -35,6 +37,8 @@ async def no_access_control_search( node_name=node_name, save_interaction=save_interaction, last_k=last_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) graph_engine = await get_graph_engine() is_empty = await graph_engine.is_empty() diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index aab004924..9f180d607 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -1,4 +1,3 @@ -import os import json import asyncio from uuid import UUID @@ -9,6 +8,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine from cognee.shared.logging_utils import get_logger from cognee.shared.utils import send_telemetry from cognee.context_global_variables import set_database_global_context_variables +from cognee.context_global_variables import backend_access_control_enabled from cognee.modules.engine.models.node_set import NodeSet from 
cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge @@ -47,6 +47,8 @@ async def search( only_context: bool = False, use_combined_context: bool = False, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Union[CombinedSearchResult, List[SearchResult]]: """ @@ -74,7 +76,7 @@ async def search( ) # Use search function filtered by permissions if access control is enabled - if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": + if backend_access_control_enabled(): search_results = await authorized_search( query_type=query_type, query_text=query_text, @@ -90,6 +92,8 @@ async def search( only_context=only_context, use_combined_context=use_combined_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) else: search_results = [ @@ -105,6 +109,8 @@ async def search( last_k=last_k, only_context=only_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) ] @@ -156,7 +162,7 @@ async def search( ) else: # This is for maintaining backwards compatibility - if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": + if backend_access_control_enabled(): return_value = [] for search_result in search_results: prepared_search_results = await prepare_search_result(search_result) @@ -172,6 +178,7 @@ async def search( "search_result": [context] if context else None, "dataset_id": datasets[0].id, "dataset_name": datasets[0].name, + "dataset_tenant_id": datasets[0].tenant_id, "graphs": graphs, } ) @@ -181,6 +188,7 @@ async def search( "search_result": [result] if result else None, "dataset_id": datasets[0].id, "dataset_name": datasets[0].name, + "dataset_tenant_id": datasets[0].tenant_id, "graphs": graphs, } ) @@ -217,6 +225,8 @@ async def authorized_search( only_context: bool = False, use_combined_context: bool = False, 
session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Union[ Tuple[Any, Union[List[Edge], str], List[Dataset]], List[Tuple[Any, Union[List[Edge], str], List[Dataset]]], @@ -244,6 +254,8 @@ async def authorized_search( last_k=last_k, only_context=True, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) context = {} @@ -265,6 +277,8 @@ async def authorized_search( node_name=node_name, save_interaction=save_interaction, last_k=last_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) search_tools = specific_search_tools if len(search_tools) == 2: @@ -304,6 +318,7 @@ async def authorized_search( last_k=last_k, only_context=only_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, ) return search_results @@ -323,6 +338,8 @@ async def search_in_datasets_context( only_context: bool = False, context: Optional[Any] = None, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]: """ Searches all provided datasets and handles setting up of appropriate database context based on permissions. 
@@ -343,6 +360,8 @@ async def search_in_datasets_context( only_context: bool = False, context: Optional[Any] = None, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]: # Set database configuration in async context for each dataset user has access for await set_database_global_context_variables(dataset.id, dataset.owner_id) @@ -376,6 +395,8 @@ async def search_in_datasets_context( node_name=node_name, save_interaction=save_interaction, last_k=last_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) search_tools = specific_search_tools if len(search_tools) == 2: @@ -411,6 +432,8 @@ async def search_in_datasets_context( only_context=only_context, context=context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) ) diff --git a/cognee/modules/users/methods/create_user.py b/cognee/modules/users/methods/create_user.py index 1b303bd36..ef325fb6f 100644 --- a/cognee/modules/users/methods/create_user.py +++ b/cognee/modules/users/methods/create_user.py @@ -18,7 +18,6 @@ from typing import Optional async def create_user( email: str, password: str, - tenant_id: Optional[str] = None, is_superuser: bool = False, is_active: bool = True, is_verified: bool = False, @@ -30,37 +29,23 @@ async def create_user( async with relational_engine.get_async_session() as session: async with get_user_db_context(session) as user_db: async with get_user_manager_context(user_db) as user_manager: - if tenant_id: - # Check if the tenant already exists - result = await session.execute(select(Tenant).where(Tenant.id == tenant_id)) - tenant = result.scalars().first() - if not tenant: - raise TenantNotFoundError - - user = await user_manager.create( - UserCreate( - email=email, - password=password, - tenant_id=tenant.id, - is_superuser=is_superuser, - 
is_active=is_active, - is_verified=is_verified, - ) - ) - else: - user = await user_manager.create( - UserCreate( - email=email, - password=password, - is_superuser=is_superuser, - is_active=is_active, - is_verified=is_verified, - ) + user = await user_manager.create( + UserCreate( + email=email, + password=password, + is_superuser=is_superuser, + is_active=is_active, + is_verified=is_verified, ) + ) if auto_login: await session.refresh(user) + # Update tenants and roles information for User object + _ = await user.awaitable_attrs.tenants + _ = await user.awaitable_attrs.roles + return user except UserAlreadyExists as error: print(f"User {email} already exists") diff --git a/cognee/modules/users/methods/get_authenticated_user.py b/cognee/modules/users/methods/get_authenticated_user.py index d78215892..d6d701737 100644 --- a/cognee/modules/users/methods/get_authenticated_user.py +++ b/cognee/modules/users/methods/get_authenticated_user.py @@ -5,6 +5,7 @@ from ..models import User from ..get_fastapi_users import get_fastapi_users from .get_default_user import get_default_user from cognee.shared.logging_utils import get_logger +from cognee.context_global_variables import backend_access_control_enabled logger = get_logger("get_authenticated_user") @@ -12,7 +13,7 @@ logger = get_logger("get_authenticated_user") # Check environment variable to determine authentication requirement REQUIRE_AUTHENTICATION = ( os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true" - or os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true" + or backend_access_control_enabled() ) fastapi_users = get_fastapi_users() diff --git a/cognee/modules/users/methods/get_default_user.py b/cognee/modules/users/methods/get_default_user.py index 773545f8e..8dc364f32 100644 --- a/cognee/modules/users/methods/get_default_user.py +++ b/cognee/modules/users/methods/get_default_user.py @@ -10,7 +10,7 @@ from cognee.infrastructure.databases.relational import get_relational_engine from 
cognee.modules.users.methods.create_default_user import create_default_user -async def get_default_user() -> SimpleNamespace: +async def get_default_user() -> User: db_engine = get_relational_engine() base_config = get_base_config() default_email = base_config.default_user_email or "default_user@example.com" @@ -18,7 +18,9 @@ async def get_default_user() -> SimpleNamespace: try: async with db_engine.get_async_session() as session: query = ( - select(User).options(selectinload(User.roles)).where(User.email == default_email) + select(User) + .options(selectinload(User.roles), selectinload(User.tenants)) + .where(User.email == default_email) ) result = await session.execute(query) diff --git a/cognee/modules/users/methods/get_user.py b/cognee/modules/users/methods/get_user.py index 2678a5a01..a1c87aab7 100644 --- a/cognee/modules/users/methods/get_user.py +++ b/cognee/modules/users/methods/get_user.py @@ -14,7 +14,7 @@ async def get_user(user_id: UUID): user = ( await session.execute( select(User) - .options(selectinload(User.roles), selectinload(User.tenant)) + .options(selectinload(User.roles), selectinload(User.tenants)) .where(User.id == user_id) ) ).scalar() diff --git a/cognee/modules/users/methods/get_user_by_email.py b/cognee/modules/users/methods/get_user_by_email.py index c4bd5b48e..6df989251 100644 --- a/cognee/modules/users/methods/get_user_by_email.py +++ b/cognee/modules/users/methods/get_user_by_email.py @@ -13,7 +13,7 @@ async def get_user_by_email(user_email: str): user = ( await session.execute( select(User) - .options(joinedload(User.roles), joinedload(User.tenant)) + .options(joinedload(User.roles), joinedload(User.tenants)) .where(User.email == user_email) ) ).scalar() diff --git a/cognee/modules/users/models/DatasetDatabase.py b/cognee/modules/users/models/DatasetDatabase.py index 0d71d8413..25d610ab9 100644 --- a/cognee/modules/users/models/DatasetDatabase.py +++ b/cognee/modules/users/models/DatasetDatabase.py @@ -15,5 +15,14 @@ class 
DatasetDatabase(Base): vector_database_name = Column(String, unique=True, nullable=False) graph_database_name = Column(String, unique=True, nullable=False) + vector_database_provider = Column(String, unique=False, nullable=False) + graph_database_provider = Column(String, unique=False, nullable=False) + + vector_database_url = Column(String, unique=False, nullable=True) + graph_database_url = Column(String, unique=False, nullable=True) + + vector_database_key = Column(String, unique=False, nullable=True) + graph_database_key = Column(String, unique=False, nullable=True) + created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)) diff --git a/cognee/modules/users/models/Tenant.py b/cognee/modules/users/models/Tenant.py index 95023a6ee..b8fa158c5 100644 --- a/cognee/modules/users/models/Tenant.py +++ b/cognee/modules/users/models/Tenant.py @@ -1,7 +1,7 @@ -from sqlalchemy.orm import relationship +from sqlalchemy.orm import relationship, Mapped from sqlalchemy import Column, String, ForeignKey, UUID from .Principal import Principal -from .User import User +from .UserTenant import UserTenant from .Role import Role @@ -13,14 +13,13 @@ class Tenant(Principal): owner_id = Column(UUID, index=True) - # One-to-Many relationship with User; specify the join via User.tenant_id - users = relationship( + users: Mapped[list["User"]] = relationship( # noqa: F821 "User", - back_populates="tenant", - foreign_keys=lambda: [User.tenant_id], + secondary=UserTenant.__tablename__, + back_populates="tenants", ) - # One-to-Many relationship with Role (if needed; similar fix) + # One-to-Many relationship with Role roles = relationship( "Role", back_populates="tenant", diff --git a/cognee/modules/users/models/User.py b/cognee/modules/users/models/User.py index 8972a5932..a98abd3bc 100644 --- a/cognee/modules/users/models/User.py +++ b/cognee/modules/users/models/User.py @@ 
-6,8 +6,10 @@ from sqlalchemy import ForeignKey, Column, UUID from sqlalchemy.orm import relationship, Mapped from .Principal import Principal +from .UserTenant import UserTenant from .UserRole import UserRole from .Role import Role +from .Tenant import Tenant class User(SQLAlchemyBaseUserTableUUID, Principal): @@ -15,7 +17,7 @@ class User(SQLAlchemyBaseUserTableUUID, Principal): id = Column(UUID, ForeignKey("principals.id", ondelete="CASCADE"), primary_key=True) - # Foreign key to Tenant (Many-to-One relationship) + # Foreign key to current Tenant (Many-to-One relationship) tenant_id = Column(UUID, ForeignKey("tenants.id")) # Many-to-Many Relationship with Roles @@ -25,11 +27,11 @@ class User(SQLAlchemyBaseUserTableUUID, Principal): back_populates="users", ) - # Relationship to Tenant - tenant = relationship( + # Many-to-Many Relationship with Tenants user is a part of + tenants: Mapped[list["Tenant"]] = relationship( "Tenant", + secondary=UserTenant.__tablename__, back_populates="users", - foreign_keys=[tenant_id], ) # ACL Relationship (One-to-Many) @@ -46,7 +48,6 @@ class UserRead(schemas.BaseUser[uuid_UUID]): class UserCreate(schemas.BaseUserCreate): - tenant_id: Optional[uuid_UUID] = None is_verified: bool = True diff --git a/cognee/modules/users/models/UserTenant.py b/cognee/modules/users/models/UserTenant.py new file mode 100644 index 000000000..bfb852aa5 --- /dev/null +++ b/cognee/modules/users/models/UserTenant.py @@ -0,0 +1,12 @@ +from datetime import datetime, timezone +from sqlalchemy import Column, ForeignKey, DateTime, UUID +from cognee.infrastructure.databases.relational import Base + + +class UserTenant(Base): + __tablename__ = "user_tenants" + + created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) + + user_id = Column(UUID, ForeignKey("users.id"), primary_key=True) + tenant_id = Column(UUID, ForeignKey("tenants.id"), primary_key=True) diff --git a/cognee/modules/users/models/__init__.py 
b/cognee/modules/users/models/__init__.py index ba2f40e49..5114cc45a 100644 --- a/cognee/modules/users/models/__init__.py +++ b/cognee/modules/users/models/__init__.py @@ -1,6 +1,7 @@ from .User import User from .Role import Role from .UserRole import UserRole +from .UserTenant import UserTenant from .DatasetDatabase import DatasetDatabase from .RoleDefaultPermissions import RoleDefaultPermissions from .UserDefaultPermissions import UserDefaultPermissions diff --git a/cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py b/cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py index 1185dd7ad..5eed992db 100644 --- a/cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py +++ b/cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py @@ -1,11 +1,8 @@ -from types import SimpleNamespace - from cognee.shared.logging_utils import get_logger from ...models.User import User from cognee.modules.data.models.Dataset import Dataset from cognee.modules.users.permissions.methods import get_principal_datasets -from cognee.modules.users.permissions.methods import get_role, get_tenant logger = get_logger() @@ -25,17 +22,14 @@ async def get_all_user_permission_datasets(user: User, permission_type: str) -> # Get all datasets User has explicit access to datasets.extend(await get_principal_datasets(user, permission_type)) - if user.tenant_id: - # Get all datasets all tenants have access to - tenant = await get_tenant(user.tenant_id) + # Get all tenants user is a part of + tenants = await user.awaitable_attrs.tenants + for tenant in tenants: + # Get all datasets all tenant members have access to datasets.extend(await get_principal_datasets(tenant, permission_type)) - # Get all datasets Users roles have access to - if isinstance(user, SimpleNamespace): - # If simple namespace use roles defined in user - roles = user.roles - else: - roles = await user.awaitable_attrs.roles + # Get all datasets 
accessible by roles user is a part of + roles = await user.awaitable_attrs.roles for role in roles: datasets.extend(await get_principal_datasets(role, permission_type)) @@ -45,4 +39,10 @@ async def get_all_user_permission_datasets(user: User, permission_type: str) -> # If the dataset id key already exists, leave the dictionary unchanged. unique.setdefault(dataset.id, dataset) - return list(unique.values()) + # Filter out dataset that aren't part of the selected user's tenant + filtered_datasets = [] + for dataset in list(unique.values()): + if dataset.tenant_id == user.tenant_id: + filtered_datasets.append(dataset) + + return filtered_datasets diff --git a/cognee/modules/users/roles/methods/add_user_to_role.py b/cognee/modules/users/roles/methods/add_user_to_role.py index de5e47775..23bb947f0 100644 --- a/cognee/modules/users/roles/methods/add_user_to_role.py +++ b/cognee/modules/users/roles/methods/add_user_to_role.py @@ -42,11 +42,13 @@ async def add_user_to_role(user_id: UUID, role_id: UUID, owner_id: UUID): .first() ) + user_tenants = await user.awaitable_attrs.tenants + if not user: raise UserNotFoundError elif not role: raise RoleNotFoundError - elif user.tenant_id != role.tenant_id: + elif role.tenant_id not in [tenant.id for tenant in user_tenants]: raise TenantNotFoundError( message="User tenant does not match role tenant. User cannot be added to role." 
) diff --git a/cognee/modules/users/tenants/methods/__init__.py b/cognee/modules/users/tenants/methods/__init__.py index 9a052e9c6..39e2b31bb 100644 --- a/cognee/modules/users/tenants/methods/__init__.py +++ b/cognee/modules/users/tenants/methods/__init__.py @@ -1,2 +1,3 @@ from .create_tenant import create_tenant from .add_user_to_tenant import add_user_to_tenant +from .select_tenant import select_tenant diff --git a/cognee/modules/users/tenants/methods/add_user_to_tenant.py b/cognee/modules/users/tenants/methods/add_user_to_tenant.py index 1374067a7..eecc49f6f 100644 --- a/cognee/modules/users/tenants/methods/add_user_to_tenant.py +++ b/cognee/modules/users/tenants/methods/add_user_to_tenant.py @@ -1,8 +1,11 @@ +from typing import Optional from uuid import UUID from sqlalchemy.exc import IntegrityError +from sqlalchemy import insert from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.modules.users.models.UserTenant import UserTenant from cognee.modules.users.methods import get_user from cognee.modules.users.permissions.methods import get_tenant from cognee.modules.users.exceptions import ( @@ -12,14 +15,19 @@ from cognee.modules.users.exceptions import ( ) -async def add_user_to_tenant(user_id: UUID, tenant_id: UUID, owner_id: UUID): +async def add_user_to_tenant( + user_id: UUID, tenant_id: UUID, owner_id: UUID, set_as_active_tenant: Optional[bool] = False +): """ Add a user with the given id to the tenant with the given id. This can only be successful if the request owner with the given id is the tenant owner. + + If set_as_active_tenant is true it will automatically set the users active tenant to provided tenant. Args: user_id: Id of the user. tenant_id: Id of the tenant. owner_id: Id of the request owner. + set_as_active_tenant: If set_as_active_tenant is true it will automatically set the users active tenant to provided tenant. 
Returns: None @@ -40,17 +48,18 @@ async def add_user_to_tenant(user_id: UUID, tenant_id: UUID, owner_id: UUID): message="Only tenant owner can add other users to organization." ) - try: - if user.tenant_id is None: - user.tenant_id = tenant_id - elif user.tenant_id == tenant_id: - return - else: - raise IntegrityError - + if set_as_active_tenant: + user.tenant_id = tenant_id await session.merge(user) await session.commit() - except IntegrityError: - raise EntityAlreadyExistsError( - message="User is already part of a tenant. Only one tenant can be assigned to user." + + try: + # Add association directly to the association table + create_user_tenant_statement = insert(UserTenant).values( + user_id=user_id, tenant_id=tenant_id ) + await session.execute(create_user_tenant_statement) + await session.commit() + + except IntegrityError: + raise EntityAlreadyExistsError(message="User is already part of group.") diff --git a/cognee/modules/users/tenants/methods/create_tenant.py b/cognee/modules/users/tenants/methods/create_tenant.py index bfd23e08f..32baa05fd 100644 --- a/cognee/modules/users/tenants/methods/create_tenant.py +++ b/cognee/modules/users/tenants/methods/create_tenant.py @@ -1,19 +1,25 @@ from uuid import UUID +from sqlalchemy import insert from sqlalchemy.exc import IntegrityError +from typing import Optional +from cognee.modules.users.models.UserTenant import UserTenant from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError from cognee.infrastructure.databases.relational import get_relational_engine from cognee.modules.users.models import Tenant from cognee.modules.users.methods import get_user -async def create_tenant(tenant_name: str, user_id: UUID) -> UUID: +async def create_tenant( + tenant_name: str, user_id: UUID, set_as_active_tenant: Optional[bool] = True +) -> UUID: """ Create a new tenant with the given name, for the user with the given id. This user is the owner of the tenant. Args: tenant_name: Name of the new tenant. 
user_id: Id of the user. + set_as_active_tenant: If true, set the newly created tenant as the active tenant for the user. Returns: None @@ -22,18 +28,26 @@ async def create_tenant(tenant_name: str, user_id: UUID) -> UUID: async with db_engine.get_async_session() as session: try: user = await get_user(user_id) - if user.tenant_id: - raise EntityAlreadyExistsError( - message="User already has a tenant. New tenant cannot be created." - ) tenant = Tenant(name=tenant_name, owner_id=user_id) session.add(tenant) await session.flush() - user.tenant_id = tenant.id - await session.merge(user) - await session.commit() + if set_as_active_tenant: + user.tenant_id = tenant.id + await session.merge(user) + await session.commit() + + try: + # Add association directly to the association table + create_user_tenant_statement = insert(UserTenant).values( + user_id=user_id, tenant_id=tenant.id + ) + await session.execute(create_user_tenant_statement) + await session.commit() + except IntegrityError: + raise EntityAlreadyExistsError(message="User is already part of tenant.") + return tenant.id except IntegrityError as e: raise EntityAlreadyExistsError(message="Tenant already exists.") from e diff --git a/cognee/modules/users/tenants/methods/select_tenant.py b/cognee/modules/users/tenants/methods/select_tenant.py new file mode 100644 index 000000000..83c11dc91 --- /dev/null +++ b/cognee/modules/users/tenants/methods/select_tenant.py @@ -0,0 +1,62 @@ +from uuid import UUID +from typing import Union + +import sqlalchemy.exc +from sqlalchemy import select + +from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.modules.users.methods.get_user import get_user +from cognee.modules.users.models.UserTenant import UserTenant +from cognee.modules.users.models.User import User +from cognee.modules.users.permissions.methods import get_tenant +from cognee.modules.users.exceptions import UserNotFoundError, TenantNotFoundError + + +async def select_tenant(user_id: 
UUID, tenant_id: Union[UUID, None]) -> User: + """ + Set the user's active tenant to the provided tenant. + + If tenant_id is None, the user's active tenant is reset to None (the default single-user tenant). + Args: + user_id: UUID of the user. + tenant_id: Id of the tenant to activate, or None to reset. + + Returns: + The User with the updated active tenant. + + """ + db_engine = get_relational_engine() + async with db_engine.get_async_session() as session: + user = await get_user(user_id) + if tenant_id is None: + # If no tenant_id is provided set current Tenant to the single user-tenant + user.tenant_id = None + await session.merge(user) + await session.commit() + return user + + tenant = await get_tenant(tenant_id) + + if not user: + raise UserNotFoundError + elif not tenant: + raise TenantNotFoundError + + # Check if User is part of Tenant + result = await session.execute( + select(UserTenant) + .where(UserTenant.user_id == user.id) + .where(UserTenant.tenant_id == tenant_id) + ) + + try: + result = result.scalar_one() + except sqlalchemy.exc.NoResultFound as e: + raise TenantNotFoundError("User is not part of the tenant.") from e + + if result: + # If user is part of tenant update current tenant of user + user.tenant_id = tenant_id + await session.merge(user) + await session.commit() + return user diff --git a/cognee/shared/logging_utils.py b/cognee/shared/logging_utils.py index 0e5120b1d..e8efde72c 100644 --- a/cognee/shared/logging_utils.py +++ b/cognee/shared/logging_utils.py @@ -450,6 +450,8 @@ def setup_logging(log_level=None, name=None): try: msg = self.format(record) stream = self.stream + if hasattr(stream, "closed") and stream.closed: + return stream.write("\n" + msg + self.terminator) self.flush() except Exception: diff --git a/cognee/tasks/chunks/__init__.py b/cognee/tasks/chunks/__init__.py index 22ce96be8..37d4de73e 100644 --- a/cognee/tasks/chunks/__init__.py +++ b/cognee/tasks/chunks/__init__.py @@ -1,4 +1,5 @@ from .chunk_by_word import chunk_by_word from .chunk_by_sentence import chunk_by_sentence from .chunk_by_paragraph
import chunk_by_paragraph +from .chunk_by_row import chunk_by_row from .remove_disconnected_chunks import remove_disconnected_chunks diff --git a/cognee/tasks/chunks/chunk_by_row.py b/cognee/tasks/chunks/chunk_by_row.py new file mode 100644 index 000000000..8daf13689 --- /dev/null +++ b/cognee/tasks/chunks/chunk_by_row.py @@ -0,0 +1,94 @@ +from typing import Any, Dict, Iterator +from uuid import NAMESPACE_OID, uuid5 + +from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine + + +def _get_pair_size(pair_text: str) -> int: + """ + Calculate the size of a given text in terms of tokens. + + If an embedding engine's tokenizer is available, count the tokens of the provided pair text. + If the tokenizer is not available, assume the pair counts as three tokens. + + Parameters: + ----------- + + - pair_text (str): The key:value pair text for which the token size is to be calculated. + + Returns: + -------- + + - int: The number of tokens representing the text, typically an integer, depending + on the tokenizer's output. + """ + embedding_engine = get_embedding_engine() + if embedding_engine.tokenizer: + return embedding_engine.tokenizer.count_tokens(pair_text) + else: + return 3 + + +def chunk_by_row( + data: str, + max_chunk_size, +) -> Iterator[Dict[str, Any]]: + """ + Chunk the input text by row, splitting each row into comma-separated pairs. + + Rows are delimited by blank lines and the pairs within a row by a comma and a space. + Pairs are accumulated until adding the next one would exceed the specified maximum + chunk size; a single pair larger than the limit is still emitted on its own. The + separators between consecutive chunks are not part of any chunk's text. Token counting + uses the embedding engine's tokenizer when one is available. + + Parameters: + ----------- + + - data (str): The input text to be chunked. + - max_chunk_size: The maximum allowed size for each chunk, in terms of tokens or
+ """ + current_chunk_list = [] + chunk_index = 0 + current_chunk_size = 0 + + lines = data.split("\n\n") + for line in lines: + pairs_text = line.split(", ") + + for pair_text in pairs_text: + pair_size = _get_pair_size(pair_text) + if current_chunk_size > 0 and (current_chunk_size + pair_size > max_chunk_size): + # Yield current cut chunk + current_chunk = ", ".join(current_chunk_list) + chunk_dict = { + "text": current_chunk, + "chunk_size": current_chunk_size, + "chunk_id": uuid5(NAMESPACE_OID, current_chunk), + "chunk_index": chunk_index, + "cut_type": "row_cut", + } + + yield chunk_dict + + # Start new chunk with current pair text + current_chunk_list = [] + current_chunk_size = 0 + chunk_index += 1 + + current_chunk_list.append(pair_text) + current_chunk_size += pair_size + + # Yield row chunk + current_chunk = ", ".join(current_chunk_list) + if current_chunk: + chunk_dict = { + "text": current_chunk, + "chunk_size": current_chunk_size, + "chunk_id": uuid5(NAMESPACE_OID, current_chunk), + "chunk_index": chunk_index, + "cut_type": "row_end", + } + + yield chunk_dict diff --git a/cognee/tasks/documents/classify_documents.py b/cognee/tasks/documents/classify_documents.py index 9fa512906..e4f13ebd1 100644 --- a/cognee/tasks/documents/classify_documents.py +++ b/cognee/tasks/documents/classify_documents.py @@ -7,6 +7,7 @@ from cognee.modules.data.processing.document_types import ( ImageDocument, TextDocument, UnstructuredDocument, + CsvDocument, ) from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.engine.utils.generate_node_id import generate_node_id @@ -15,6 +16,7 @@ from cognee.tasks.documents.exceptions import WrongDataDocumentInputError EXTENSION_TO_DOCUMENT_CLASS = { "pdf": PdfDocument, # Text documents "txt": TextDocument, + "csv": CsvDocument, "docx": UnstructuredDocument, "doc": UnstructuredDocument, "odt": UnstructuredDocument, diff --git a/cognee/tasks/feedback/generate_improved_answers.py 
b/cognee/tasks/feedback/generate_improved_answers.py index e439cf9e5..d2b143d29 100644 --- a/cognee/tasks/feedback/generate_improved_answers.py +++ b/cognee/tasks/feedback/generate_improved_answers.py @@ -61,7 +61,7 @@ async def _generate_improved_answer_for_single_interaction( ) retrieved_context = await retriever.get_context(query_text) - completion = await retriever.get_structured_completion( + completion = await retriever.get_completion( query=query_text, context=retrieved_context, response_model=ImprovedAnswerResponse, @@ -70,9 +70,9 @@ async def _generate_improved_answer_for_single_interaction( new_context_text = await retriever.resolve_edges_to_text(retrieved_context) if completion: - enrichment.improved_answer = completion.answer + enrichment.improved_answer = completion[0].answer enrichment.new_context = new_context_text - enrichment.explanation = completion.explanation + enrichment.explanation = completion[0].explanation return enrichment else: logger.warning( diff --git a/cognee/tasks/ingestion/ingest_data.py b/cognee/tasks/ingestion/ingest_data.py index 0572d0f1e..5987f38d5 100644 --- a/cognee/tasks/ingestion/ingest_data.py +++ b/cognee/tasks/ingestion/ingest_data.py @@ -99,7 +99,7 @@ async def ingest_data( # data_id is the hash of original file contents + owner id to avoid duplicate data - data_id = ingestion.identify(classified_data, user) + data_id = await ingestion.identify(classified_data, user) original_file_metadata = classified_data.get_metadata() # Find metadata from Cognee data storage text file diff --git a/cognee/tasks/memify/__init__.py b/cognee/tasks/memify/__init__.py index 692bac443..7e590ed47 100644 --- a/cognee/tasks/memify/__init__.py +++ b/cognee/tasks/memify/__init__.py @@ -1,2 +1,4 @@ from .extract_subgraph import extract_subgraph from .extract_subgraph_chunks import extract_subgraph_chunks +from .cognify_session import cognify_session +from .extract_user_sessions import extract_user_sessions diff --git 
a/cognee/tasks/memify/cognify_session.py b/cognee/tasks/memify/cognify_session.py new file mode 100644 index 000000000..f53f9afb1 --- /dev/null +++ b/cognee/tasks/memify/cognify_session.py @@ -0,0 +1,41 @@ +import cognee + +from cognee.exceptions import CogneeValidationError, CogneeSystemError +from cognee.shared.logging_utils import get_logger + +logger = get_logger("cognify_session") + + +async def cognify_session(data, dataset_id=None): + """ + Process and cognify session data into the knowledge graph. + + Adds session content to cognee with a dedicated "user_sessions_from_cache" node set, + then triggers the cognify pipeline to extract entities and relationships + from the session data. + + Args: + data: Session string containing Question, Context, and Answer information. + dataset_id: Id of the dataset the session data is added to and cognified. + + Raises: + CogneeValidationError: If data is None or empty. + CogneeSystemError: If cognee operations fail. + """ + try: + if not data or (isinstance(data, str) and not data.strip()): + logger.warning("Empty session data provided to cognify_session task, skipping") + raise CogneeValidationError(message="Session data cannot be empty", log=False) + + logger.info("Processing session data for cognification") + + await cognee.add(data, dataset_id=dataset_id, node_set=["user_sessions_from_cache"]) + logger.debug("Session data added to cognee with node_set: user_sessions") + await cognee.cognify(datasets=[dataset_id]) + logger.info("Session data successfully cognified") + + except CogneeValidationError: + raise + except Exception as e: + logger.error(f"Error cognifying session data: {str(e)}") + raise CogneeSystemError(message=f"Failed to cognify session data: {str(e)}", log=False) diff --git a/cognee/tasks/memify/extract_user_sessions.py b/cognee/tasks/memify/extract_user_sessions.py new file mode 100644 index 000000000..9779a363e --- /dev/null +++ b/cognee/tasks/memify/extract_user_sessions.py @@ -0,0 +1,73 @@ +from typing import Optional, List + +from
cognee.context_global_variables import session_user +from cognee.exceptions import CogneeSystemError +from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine +from cognee.shared.logging_utils import get_logger +from cognee.modules.users.models import User + +logger = get_logger("extract_user_sessions") + + +async def extract_user_sessions( + data, + session_ids: Optional[List[str]] = None, +): + """ + Extract Q&A sessions for the current user from cache. + + Retrieves the cached Q&A entries of each specified session ID and yields one + formatted string per session combining its questions and answers. + + Args: + data: Data passed from memify. If empty or [{}], no external data is provided. + session_ids: Optional list of specific session IDs to extract. + + Yields: + String containing session ID and all Q&A pairs formatted. + + Raises: + CogneeSystemError: If no user is in context, the cache engine is unavailable, or extraction fails. + """ + try: + if not data or data == [{}]: + logger.info("Fetching session metadata for current user") + + user: User = session_user.get() + if not user: + raise CogneeSystemError(message="No authenticated user found in context", log=False) + + user_id = str(user.id) + + cache_engine = get_cache_engine() + if cache_engine is None: + raise CogneeSystemError( + message="Cache engine not available for session extraction, please enable caching in order to have sessions to save", + log=False, + ) + + if session_ids: + for session_id in session_ids: + try: + qa_data = await cache_engine.get_all_qas(user_id, session_id) + if qa_data: + logger.info(f"Extracted session {session_id} with {len(qa_data)} Q&A pairs") + session_string = f"Session ID: {session_id}\n\n" + for qa_pair in qa_data: + question = qa_pair.get("question", "") + answer = qa_pair.get("answer", "") + session_string += f"Question: {question}\n\nAnswer: {answer}\n\n" + yield session_string + except Exception as e: + logger.warning(f"Failed to extract session {session_id}: {str(e)}") +
continue + else: + logger.info( + "No specific session_ids provided. Please specify which sessions to extract." + ) + + except CogneeSystemError: + raise + except Exception as e: + logger.error(f"Error extracting user sessions: {str(e)}") + raise CogneeSystemError(message=f"Failed to extract user sessions: {str(e)}", log=False) diff --git a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py index 902789c80..b0ec3a5b4 100644 --- a/cognee/tasks/storage/index_data_points.py +++ b/cognee/tasks/storage/index_data_points.py @@ -8,47 +8,58 @@ logger = get_logger("index_data_points") async def index_data_points(data_points: list[DataPoint]): - created_indexes = {} - index_points = {} + """Index data points in the vector engine by creating embeddings for specified fields. + + Process: + 1. Groups data points into a nested dict: {type_name: {field_name: [points]}} + 2. Creates vector indexes for each (type, field) combination on first encounter + 3. Batches points per (type, field) and creates async indexing tasks + 4. Executes all indexing tasks in parallel for efficient embedding generation + + Args: + data_points: List of DataPoint objects to index. Each DataPoint's metadata must + contain an 'index_fields' list specifying which fields to embed. + + Returns: + The original data_points list. 
+ """ + data_points_by_type = {} vector_engine = get_vector_engine() for data_point in data_points: data_point_type = type(data_point) + type_name = data_point_type.__name__ for field_name in data_point.metadata["index_fields"]: if getattr(data_point, field_name, None) is None: continue - index_name = f"{data_point_type.__name__}_{field_name}" + if type_name not in data_points_by_type: + data_points_by_type[type_name] = {} - if index_name not in created_indexes: - await vector_engine.create_vector_index(data_point_type.__name__, field_name) - created_indexes[index_name] = True - - if index_name not in index_points: - index_points[index_name] = [] + if field_name not in data_points_by_type[type_name]: + await vector_engine.create_vector_index(type_name, field_name) + data_points_by_type[type_name][field_name] = [] indexed_data_point = data_point.model_copy() indexed_data_point.metadata["index_fields"] = [field_name] - index_points[index_name].append(indexed_data_point) + data_points_by_type[type_name][field_name].append(indexed_data_point) - tasks: list[asyncio.Task] = [] batch_size = vector_engine.embedding_engine.get_batch_size() - for index_name_and_field, points in index_points.items(): - first = index_name_and_field.index("_") - index_name = index_name_and_field[:first] - field_name = index_name_and_field[first + 1 :] + batches = ( + (type_name, field_name, points[i : i + batch_size]) + for type_name, fields in data_points_by_type.items() + for field_name, points in fields.items() + for i in range(0, len(points), batch_size) + ) - # Create embedding requests per batch to run in parallel later - for i in range(0, len(points), batch_size): - batch = points[i : i + batch_size] - tasks.append( - asyncio.create_task(vector_engine.index_data_points(index_name, field_name, batch)) - ) + tasks = [ + asyncio.create_task(vector_engine.index_data_points(type_name, field_name, batch_points)) + for type_name, field_name, batch_points in batches + ] - # Run all embedding 
requests in parallel await asyncio.gather(*tasks) return data_points diff --git a/cognee/tasks/storage/index_graph_edges.py b/cognee/tasks/storage/index_graph_edges.py index 4fa8cfc75..03b5a25a5 100644 --- a/cognee/tasks/storage/index_graph_edges.py +++ b/cognee/tasks/storage/index_graph_edges.py @@ -1,17 +1,44 @@ -import asyncio +from collections import Counter +from typing import Optional, Dict, Any, List, Tuple, Union from cognee.modules.engine.utils.generate_edge_id import generate_edge_id from cognee.shared.logging_utils import get_logger -from collections import Counter -from typing import Optional, Dict, Any, List, Tuple, Union -from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.graph import get_graph_engine from cognee.modules.graph.models.EdgeType import EdgeType from cognee.infrastructure.databases.graph.graph_db_interface import EdgeData +from cognee.tasks.storage.index_data_points import index_data_points logger = get_logger() +def _get_edge_text(item: dict) -> str: + """Extract edge text for embedding - prefers edge_text field with fallback.""" + if "edge_text" in item: + return item["edge_text"] + + if "relationship_name" in item: + return item["relationship_name"] + + return "" + + +def create_edge_type_datapoints(edges_data) -> list[EdgeType]: + """Transform raw edge data into EdgeType datapoints.""" + edge_texts = [ + _get_edge_text(item) + for edge in edges_data + for item in edge + if isinstance(item, dict) and "relationship_name" in item + ] + + edge_types = Counter(edge_texts) + + return [ + EdgeType(id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count) + for text, count in edge_types.items() + ] + + async def index_graph_edges( edges_data: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]] = None, ): @@ -23,24 +50,17 @@ async def index_graph_edges( the `relationship_name` field. Steps: - 1. Initialize the vector engine and graph engine. 
- 2. Retrieve graph edge data and count relationship types (`relationship_name`). - 3. Create vector indexes for `relationship_name` if they don't exist. - 4. Transform the counted relationships into `EdgeType` objects. - 5. Index the transformed data points in the vector engine. + 1. Initialize the graph engine if needed and retrieve edge data. + 2. Transform edge data into EdgeType datapoints. + 3. Index the EdgeType datapoints using the standard indexing function. Raises: - RuntimeError: If initialization of the vector engine or graph engine fails. + RuntimeError: If initialization of the graph engine fails. Returns: None """ try: - created_indexes = {} - index_points = {} - - vector_engine = get_vector_engine() - if edges_data is None: graph_engine = await get_graph_engine() _, edges_data = await graph_engine.get_graph_data() @@ -51,47 +71,7 @@ async def index_graph_edges( logger.error("Failed to initialize engines: %s", e) raise RuntimeError("Initialization error") from e - edge_types = Counter( - item.get("relationship_name") - for edge in edges_data - for item in edge - if isinstance(item, dict) and "relationship_name" in item - ) - - for text, count in edge_types.items(): - edge = EdgeType( - id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count - ) - data_point_type = type(edge) - - for field_name in edge.metadata["index_fields"]: - index_name = f"{data_point_type.__name__}.{field_name}" - - if index_name not in created_indexes: - await vector_engine.create_vector_index(data_point_type.__name__, field_name) - created_indexes[index_name] = True - - if index_name not in index_points: - index_points[index_name] = [] - - indexed_data_point = edge.model_copy() - indexed_data_point.metadata["index_fields"] = [field_name] - index_points[index_name].append(indexed_data_point) - - # Get maximum batch size for embedding model - batch_size = vector_engine.embedding_engine.get_batch_size() - tasks: list[asyncio.Task] = [] - - for index_name, 
indexable_points in index_points.items(): - index_name, field_name = index_name.split(".") - - # Create embedding tasks to run in parallel later - for start in range(0, len(indexable_points), batch_size): - batch = indexable_points[start : start + batch_size] - - tasks.append(vector_engine.index_data_points(index_name, field_name, batch)) - - # Start all embedding tasks and wait for completion - await asyncio.gather(*tasks) + edge_type_datapoints = create_edge_type_datapoints(edges_data) + await index_data_points(edge_type_datapoints) return None diff --git a/cognee/tests/integration/documents/CsvDocument_test.py b/cognee/tests/integration/documents/CsvDocument_test.py new file mode 100644 index 000000000..421bb81bd --- /dev/null +++ b/cognee/tests/integration/documents/CsvDocument_test.py @@ -0,0 +1,70 @@ +import os +import sys +import uuid +import pytest +import pathlib +from unittest.mock import patch + +from cognee.modules.chunking.CsvChunker import CsvChunker +from cognee.modules.data.processing.document_types.CsvDocument import CsvDocument +from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine +from cognee.tests.integration.documents.async_gen_zip import async_gen_zip + +chunk_by_row_module = sys.modules.get("cognee.tasks.chunks.chunk_by_row") + + +GROUND_TRUTH = { + "chunk_size_10": [ + {"token_count": 9, "len_text": 26, "cut_type": "row_cut", "chunk_index": 0}, + {"token_count": 6, "len_text": 29, "cut_type": "row_end", "chunk_index": 1}, + {"token_count": 9, "len_text": 25, "cut_type": "row_cut", "chunk_index": 2}, + {"token_count": 6, "len_text": 30, "cut_type": "row_end", "chunk_index": 3}, + ], + "chunk_size_128": [ + {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 0}, + {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 1}, + ], +} + + +@pytest.mark.parametrize( + "input_file,chunk_size", + [("example_with_header.csv", 10), ("example_with_header.csv", 128)], +) 
+@patch.object(chunk_by_row_module, "get_embedding_engine", side_effect=mock_get_embedding_engine) +@pytest.mark.asyncio +async def test_CsvDocument(mock_engine, input_file, chunk_size): + # Define file paths of test data + csv_file_path = os.path.join( + pathlib.Path(__file__).parent.parent.parent, + "test_data", + input_file, + ) + + # Define test documents + csv_document = CsvDocument( + id=uuid.uuid4(), + name="example_with_header.csv", + raw_data_location=csv_file_path, + external_metadata="", + mime_type="text/csv", + ) + + # TEST CSV + ground_truth_key = f"chunk_size_{chunk_size}" + async for ground_truth, row_data in async_gen_zip( + GROUND_TRUTH[ground_truth_key], + csv_document.read(chunker_cls=CsvChunker, max_chunk_size=chunk_size), + ): + assert ground_truth["token_count"] == row_data.chunk_size, ( + f'{ground_truth["token_count"] = } != {row_data.chunk_size = }' + ) + assert ground_truth["len_text"] == len(row_data.text), ( + f'{ground_truth["len_text"] = } != {len(row_data.text) = }' + ) + assert ground_truth["cut_type"] == row_data.cut_type, ( + f'{ground_truth["cut_type"] = } != {row_data.cut_type = }' + ) + assert ground_truth["chunk_index"] == row_data.chunk_index, ( + f'{ground_truth["chunk_index"] = } != {row_data.chunk_index = }' + ) diff --git a/cognee/tests/tasks/entity_extraction/entity_extraction_test.py b/cognee/tests/tasks/entity_extraction/entity_extraction_test.py index 39e883e09..41a9254ca 100644 --- a/cognee/tests/tasks/entity_extraction/entity_extraction_test.py +++ b/cognee/tests/tasks/entity_extraction/entity_extraction_test.py @@ -55,7 +55,7 @@ async def main(): classified_data = ingestion.classify(file) # data_id is the hash of original file contents + owner id to avoid duplicate data - data_id = ingestion.identify(classified_data, await get_default_user()) + data_id = await ingestion.identify(classified_data, await get_default_user()) await cognee.add(file_path) diff --git a/cognee/tests/test_add_docling_document.py 
b/cognee/tests/test_add_docling_document.py index 2c82af66f..c5aa4e9d1 100644 --- a/cognee/tests/test_add_docling_document.py +++ b/cognee/tests/test_add_docling_document.py @@ -39,12 +39,12 @@ async def main(): answer = await cognee.search("Do programmers change light bulbs?") assert len(answer) != 0 - lowercase_answer = answer[0].lower() + lowercase_answer = answer[0]["search_result"][0].lower() assert ("no" in lowercase_answer) or ("none" in lowercase_answer) answer = await cognee.search("What colours are there in the presentation table?") assert len(answer) != 0 - lowercase_answer = answer[0].lower() + lowercase_answer = answer[0]["search_result"][0].lower() assert ( ("red" in lowercase_answer) and ("blue" in lowercase_answer) diff --git a/cognee/tests/test_cognee_server_start.py b/cognee/tests/test_cognee_server_start.py index ab68a8ef1..ddffe53a4 100644 --- a/cognee/tests/test_cognee_server_start.py +++ b/cognee/tests/test_cognee_server_start.py @@ -7,6 +7,7 @@ import requests from pathlib import Path import sys import uuid +import json class TestCogneeServerStart(unittest.TestCase): @@ -90,12 +91,71 @@ class TestCogneeServerStart(unittest.TestCase): ) } - payload = {"datasets": [dataset_name]} + ontology_key = f"test_ontology_{uuid.uuid4().hex[:8]}" + payload = {"datasets": [dataset_name], "ontology_key": [ontology_key]} add_response = requests.post(url, headers=headers, data=form_data, files=file, timeout=50) if add_response.status_code not in [200, 201]: add_response.raise_for_status() + ontology_content = b""" + + + + + + + + + + + + + + + + A failure caused by physical components. + + + + + An error caused by software logic or configuration. + + + + A human being or individual. 
+ + + + + Programmers + + + + Light Bulb + + + + Hardware Problem + + + """ + + ontology_response = requests.post( + "http://127.0.0.1:8000/api/v1/ontologies", + headers=headers, + files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))], + data={ + "ontology_key": json.dumps([ontology_key]), + "description": json.dumps(["Test ontology"]), + }, + ) + self.assertEqual(ontology_response.status_code, 200) + # Cognify request url = "http://127.0.0.1:8000/api/v1/cognify" headers = { @@ -107,6 +167,29 @@ class TestCogneeServerStart(unittest.TestCase): if cognify_response.status_code not in [200, 201]: cognify_response.raise_for_status() + datasets_response = requests.get("http://127.0.0.1:8000/api/v1/datasets", headers=headers) + + datasets = datasets_response.json() + dataset_id = None + for dataset in datasets: + if dataset["name"] == dataset_name: + dataset_id = dataset["id"] + break + + graph_response = requests.get( + f"http://127.0.0.1:8000/api/v1/datasets/{dataset_id}/graph", headers=headers + ) + self.assertEqual(graph_response.status_code, 200) + + graph_data = graph_response.json() + ontology_nodes = [ + node for node in graph_data.get("nodes") if node.get("properties").get("ontology_valid") + ] + + self.assertGreater( + len(ontology_nodes), 0, "No ontology nodes found - ontology was not integrated" + ) + # TODO: Add test to verify cognify pipeline is complete before testing search # Search request diff --git a/cognee/tests/test_conversation_history.py b/cognee/tests/test_conversation_history.py index 30bb54ef1..783baf563 100644 --- a/cognee/tests/test_conversation_history.py +++ b/cognee/tests/test_conversation_history.py @@ -16,9 +16,11 @@ import cognee import pathlib from cognee.infrastructure.databases.cache import get_cache_engine +from cognee.infrastructure.databases.graph import get_graph_engine from cognee.modules.search.types import SearchType from cognee.shared.logging_utils import get_logger from cognee.modules.users.methods 
import get_default_user +from collections import Counter logger = get_logger() @@ -54,10 +56,10 @@ async def main(): """DataCo is a data analytics company. They help businesses make sense of their data.""" ) - await cognee.add(text_1, dataset_name) - await cognee.add(text_2, dataset_name) + await cognee.add(data=text_1, dataset_name=dataset_name) + await cognee.add(data=text_2, dataset_name=dataset_name) - await cognee.cognify([dataset_name]) + await cognee.cognify(datasets=[dataset_name]) user = await get_default_user() @@ -188,7 +190,6 @@ async def main(): f"GRAPH_SUMMARY_COMPLETION should return non-empty list, got: {result_summary!r}" ) - # Verify saved history_summary = await cache_engine.get_latest_qa(str(user.id), session_id_summary, last_n=10) our_qa_summary = [ h for h in history_summary if h["question"] == "What are the key points about TechCorp?" @@ -228,6 +229,46 @@ async def main(): assert "CONTEXT:" in formatted_history, "Formatted history should contain 'CONTEXT:' prefix" assert "ANSWER:" in formatted_history, "Formatted history should contain 'ANSWER:' prefix" + from cognee.memify_pipelines.persist_sessions_in_knowledge_graph import ( + persist_sessions_in_knowledge_graph_pipeline, + ) + + logger.info("Starting persist_sessions_in_knowledge_graph tests") + + await persist_sessions_in_knowledge_graph_pipeline( + user=user, + session_ids=[session_id_1, session_id_2], + dataset=dataset_name, + run_in_background=False, + ) + + graph_engine = await get_graph_engine() + graph = await graph_engine.get_graph_data() + + type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0]) + + "Tests the correct number of NodeSet nodes after session persistence" + assert type_counts.get("NodeSet", 0) == 1, ( + f"Number of NodeSets in the graph is incorrect, found {type_counts.get('NodeSet', 0)} but there should be exactly 1." 
+ ) + + "Tests the correct number of DocumentChunk nodes after session persistence" + assert type_counts.get("DocumentChunk", 0) == 4, ( + f"Number of DocumentChunk ndoes in the graph is incorrect, found {type_counts.get('DocumentChunk', 0)} but there should be exactly 4 (2 original documents, 2 sessions)." + ) + + from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine + + vector_engine = get_vector_engine() + collection_size = await vector_engine.search( + collection_name="DocumentChunk_text", + query_text="test", + limit=1000, + ) + assert len(collection_size) == 4, ( + f"DocumentChunk_text collection should have exactly 4 embeddings, found {len(collection_size)}" + ) + await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) diff --git a/cognee/tests/test_data/example_with_header.csv b/cognee/tests/test_data/example_with_header.csv new file mode 100644 index 000000000..dc900e5ef --- /dev/null +++ b/cognee/tests/test_data/example_with_header.csv @@ -0,0 +1,3 @@ +id,name,age,city,country +1,Eric,30,Beijing,China +2,Joe,35,Berlin,Germany diff --git a/cognee/tests/test_edge_ingestion.py b/cognee/tests/test_edge_ingestion.py index 5b23f7819..0d1407fab 100755 --- a/cognee/tests/test_edge_ingestion.py +++ b/cognee/tests/test_edge_ingestion.py @@ -52,6 +52,33 @@ async def test_edge_ingestion(): edge_type_counts = Counter(edge_type[2] for edge_type in graph[1]) + "Tests edge_text presence and format" + contains_edges = [edge for edge in graph[1] if edge[2] == "contains"] + assert len(contains_edges) > 0, "Expected at least one contains edge for edge_text verification" + + edge_properties = contains_edges[0][3] + assert "edge_text" in edge_properties, "Expected edge_text in edge properties" + + edge_text = edge_properties["edge_text"] + assert "relationship_name: contains" in edge_text, ( + f"Expected 'relationship_name: contains' in edge_text, got: {edge_text}" + ) + assert "entity_name:" in edge_text, f"Expected 
'entity_name:' in edge_text, got: {edge_text}" + assert "entity_description:" in edge_text, ( + f"Expected 'entity_description:' in edge_text, got: {edge_text}" + ) + + all_edge_texts = [ + edge[3].get("edge_text", "") for edge in contains_edges if "edge_text" in edge[3] + ] + expected_entities = ["dave", "ana", "bob", "dexter", "apples", "cognee"] + found_entity = any( + any(entity in text.lower() for entity in expected_entities) for text in all_edge_texts + ) + assert found_entity, ( + f"Expected to find at least one entity name in edge_text: {all_edge_texts[:3]}" + ) + "Tests the presence of basic nested edges" for basic_nested_edge in basic_nested_edges: assert edge_type_counts.get(basic_nested_edge, 0) >= 1, ( diff --git a/cognee/tests/test_feedback_enrichment.py b/cognee/tests/test_feedback_enrichment.py index 02d90db32..378cb0e45 100644 --- a/cognee/tests/test_feedback_enrichment.py +++ b/cognee/tests/test_feedback_enrichment.py @@ -133,7 +133,7 @@ async def main(): extraction_tasks=extraction_tasks, enrichment_tasks=enrichment_tasks, data=[{}], - dataset="feedback_enrichment_test_memify", + dataset=dataset_name, ) nodes_after, edges_after = await graph_engine.get_graph_data() diff --git a/cognee/tests/test_library.py b/cognee/tests/test_library.py index 81f81ee61..893b836c0 100755 --- a/cognee/tests/test_library.py +++ b/cognee/tests/test_library.py @@ -90,15 +90,17 @@ async def main(): ) search_results = await cognee.search( - query_type=SearchType.GRAPH_COMPLETION, query_text="What information do you contain?" 
+ query_type=SearchType.GRAPH_COMPLETION, + query_text="What information do you contain?", + dataset_ids=[pipeline_run_obj.dataset_id], ) - assert "Mark" in search_results[0], ( + assert "Mark" in search_results[0]["search_result"][0], ( "Failed to update document, no mention of Mark in search results" ) - assert "Cindy" in search_results[0], ( + assert "Cindy" in search_results[0]["search_result"][0], ( "Failed to update document, no mention of Cindy in search results" ) - assert "Artificial intelligence" not in search_results[0], ( + assert "Artificial intelligence" not in search_results[0]["search_result"][0], ( "Failed to update document, Artificial intelligence still mentioned in search results" ) diff --git a/cognee/tests/test_load.py b/cognee/tests/test_load.py new file mode 100644 index 000000000..b38466bc7 --- /dev/null +++ b/cognee/tests/test_load.py @@ -0,0 +1,62 @@ +import os +import pathlib +import asyncio +import time + +import cognee +from cognee.modules.search.types import SearchType +from cognee.shared.logging_utils import get_logger + +logger = get_logger() + + +async def process_and_search(num_of_searches): + start_time = time.time() + + await cognee.cognify() + + await asyncio.gather( + *[ + cognee.search( + query_text="Tell me about the document", query_type=SearchType.GRAPH_COMPLETION + ) + for _ in range(num_of_searches) + ] + ) + + end_time = time.time() + + return end_time - start_time + + +async def main(): + data_directory_path = os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_load") + cognee.config.data_root_directory(data_directory_path) + + cognee_directory_path = os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_load") + cognee.config.system_root_directory(cognee_directory_path) + + num_of_pdfs = 10 + num_of_reps = 5 + upper_boundary_minutes = 10 + average_minutes = 8 + + recorded_times = [] + for _ in range(num_of_reps): + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) 
+ + s3_input = "s3://cognee-test-load-s3-bucket" + await cognee.add(s3_input) + + recorded_times.append(await process_and_search(num_of_pdfs)) + + average_recorded_time = sum(recorded_times) / len(recorded_times) + + assert average_recorded_time <= average_minutes * 60 + + assert all(rec_time <= upper_boundary_minutes * 60 for rec_time in recorded_times) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/cognee/tests/test_multi_tenancy.py b/cognee/tests/test_multi_tenancy.py new file mode 100644 index 000000000..7cdcda8d8 --- /dev/null +++ b/cognee/tests/test_multi_tenancy.py @@ -0,0 +1,165 @@ +import cognee +import pytest + +from cognee.modules.users.exceptions import PermissionDeniedError +from cognee.modules.users.tenants.methods import select_tenant +from cognee.modules.users.methods import get_user +from cognee.shared.logging_utils import get_logger +from cognee.modules.search.types import SearchType +from cognee.modules.users.methods import create_user +from cognee.modules.users.permissions.methods import authorized_give_permission_on_datasets +from cognee.modules.users.roles.methods import add_user_to_role +from cognee.modules.users.roles.methods import create_role +from cognee.modules.users.tenants.methods import create_tenant +from cognee.modules.users.tenants.methods import add_user_to_tenant +from cognee.modules.engine.operations.setup import setup +from cognee.shared.logging_utils import setup_logging, CRITICAL + +logger = get_logger() + + +async def main(): + # Create a clean slate for cognee -- reset data and system state + print("Resetting cognee data...") + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + print("Data reset complete.\n") + + # Set up the necessary databases and tables for user management. + await setup() + + # Add document for user_1, add it under dataset name AI + text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena. 
+ At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages + this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the + preparation and manipulation of quantum state""" + + print("Creating user_1: user_1@example.com") + user_1 = await create_user("user_1@example.com", "example") + await cognee.add([text], dataset_name="AI", user=user_1) + + print("\nCreating user_2: user_2@example.com") + user_2 = await create_user("user_2@example.com", "example") + + # Run cognify for both datasets as the appropriate user/owner + print("\nCreating different datasets for user_1 (AI dataset) and user_2 (QUANTUM dataset)") + ai_cognify_result = await cognee.cognify(["AI"], user=user_1) + + # Extract dataset_ids from cognify results + def extract_dataset_id_from_cognify(cognify_result): + """Extract dataset_id from cognify output dictionary""" + for dataset_id, pipeline_result in cognify_result.items(): + return dataset_id # Return the first dataset_id + return None + + # Get dataset IDs from cognify results + # Note: When we want to work with datasets from other users (search, add, cognify and etc.) 
we must supply dataset + # information through dataset_id, because using the dataset name only looks for datasets owned by the current user + ai_dataset_id = extract_dataset_id_from_cognify(ai_cognify_result) + + # We can see here that user_1 can read his own dataset (AI dataset) + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text="What is in the document?", + user=user_1, + datasets=[ai_dataset_id], + ) + + # Verify that user_2 cannot access user_1's dataset without permission + with pytest.raises(PermissionDeniedError): + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text="What is in the document?", + user=user_2, + datasets=[ai_dataset_id], + ) + + # Create new tenant and role, add user_2 to tenant and role + tenant_id = await create_tenant("CogneeLab", user_1.id) + await select_tenant(user_id=user_1.id, tenant_id=tenant_id) + role_id = await create_role(role_name="Researcher", owner_id=user_1.id) + await add_user_to_tenant( + user_id=user_2.id, tenant_id=tenant_id, owner_id=user_1.id, set_as_active_tenant=True + ) + await add_user_to_role(user_id=user_2.id, role_id=role_id, owner_id=user_1.id) + + # Assert that user_1 cannot give permissions on his dataset to role before switching to the correct tenant + # AI dataset was made with default tenant and not CogneeLab tenant + with pytest.raises(PermissionDeniedError): + await authorized_give_permission_on_datasets( + role_id, + [ai_dataset_id], + "read", + user_1.id, + ) + + # We need to refresh the user object with changes made when switching tenants + user_1 = await get_user(user_1.id) + await cognee.add([text], dataset_name="AI_COGNEE_LAB", user=user_1) + ai_cognee_lab_cognify_result = await cognee.cognify(["AI_COGNEE_LAB"], user=user_1) + + ai_cognee_lab_dataset_id = extract_dataset_id_from_cognify(ai_cognee_lab_cognify_result) + + await authorized_give_permission_on_datasets( + role_id, + [ai_cognee_lab_dataset_id], + "read", + user_1.id, + )
+ + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text="What is in the document?", + user=user_2, + dataset_ids=[ai_cognee_lab_dataset_id], + ) + for result in search_results: + print(f"{result}\n") + + # Let's test changing tenants + tenant_id = await create_tenant("CogneeLab2", user_1.id) + await select_tenant(user_id=user_1.id, tenant_id=tenant_id) + + user_1 = await get_user(user_1.id) + await cognee.add([text], dataset_name="AI_COGNEE_LAB", user=user_1) + await cognee.cognify(["AI_COGNEE_LAB"], user=user_1) + + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text="What is in the document?", + user=user_1, + ) + + # Assert only AI_COGNEE_LAB dataset from CogneeLab2 tenant is visible as the currently selected tenant + assert len(search_results) == 1, ( + f"Search results must only contain one dataset from current tenant: {search_results}" + ) + assert search_results[0]["dataset_name"] == "AI_COGNEE_LAB", ( + f"Dict must contain dataset name 'AI_COGNEE_LAB': {search_results[0]}" + ) + assert search_results[0]["dataset_tenant_id"] == user_1.tenant_id, ( + f"Dataset tenant_id must be same as user_1 tenant_id: {search_results[0]}" + ) + + # Switch back to no tenant (default tenant) + await select_tenant(user_id=user_1.id, tenant_id=None) + # Refresh user_1 object + user_1 = await get_user(user_1.id) + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text="What is in the document?", + user=user_1, + ) + assert len(search_results) == 1, ( + f"Search results must only contain one dataset from default tenant: {search_results}" + ) + assert search_results[0]["dataset_name"] == "AI", ( + f"Dict must contain dataset name 'AI': {search_results[0]}" + ) + + +if __name__ == "__main__": + import asyncio + + logger = setup_logging(log_level=CRITICAL) + asyncio.run(main()) diff --git a/cognee/tests/test_parallel_databases.py 
b/cognee/tests/test_parallel_databases.py index 9a590921a..3164206ed 100755 --- a/cognee/tests/test_parallel_databases.py +++ b/cognee/tests/test_parallel_databases.py @@ -33,11 +33,13 @@ async def main(): "vector_db_url": "cognee1.test", "vector_db_key": "", "vector_db_provider": "lancedb", + "vector_db_name": "", } task_2_config = { "vector_db_url": "cognee2.test", "vector_db_key": "", "vector_db_provider": "lancedb", + "vector_db_name": "", } task_1_graph_config = { diff --git a/cognee/tests/test_relational_db_migration.py b/cognee/tests/test_relational_db_migration.py index 2b69ce854..ae06e7c5d 100644 --- a/cognee/tests/test_relational_db_migration.py +++ b/cognee/tests/test_relational_db_migration.py @@ -1,6 +1,5 @@ import pathlib import os -from typing import List from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.relational import ( get_migration_relational_engine, @@ -10,7 +9,7 @@ from cognee.infrastructure.databases.vector.pgvector import ( create_db_and_tables as create_pgvector_db_and_tables, ) from cognee.tasks.ingestion import migrate_relational_database -from cognee.modules.search.types import SearchResult, SearchType +from cognee.modules.search.types import SearchType import cognee @@ -27,6 +26,9 @@ def normalize_node_name(node_name: str) -> str: async def setup_test_db(): + # Disable backend access control to migrate relational data + os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false" + await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) @@ -271,6 +273,55 @@ async def test_schema_only_migration(): print(f"Edge counts: {edge_counts}") +async def test_search_result_quality(): + from cognee.infrastructure.databases.relational import ( + get_migration_relational_engine, + ) + + # Get relational database with original data + migration_engine = get_migration_relational_engine() + from sqlalchemy import text + + async with migration_engine.engine.connect() as conn: + result 
= await conn.execute( + text(""" + SELECT + c.CustomerId, + c.FirstName, + c.LastName, + GROUP_CONCAT(i.InvoiceId, ',') AS invoice_ids + FROM Customer AS c + LEFT JOIN Invoice AS i ON c.CustomerId = i.CustomerId + GROUP BY c.CustomerId, c.FirstName, c.LastName + """) + ) + + for row in result: + # Get expected invoice IDs from relational DB for each Customer + customer_id = row.CustomerId + invoice_ids = row.invoice_ids.split(",") if row.invoice_ids else [] + print(f"Relational DB Customer {customer_id}: {invoice_ids}") + + # Use Cognee search to get invoice IDs for the same Customer but by providing Customer name + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text=f"List me all the invoices of Customer:{row.FirstName} {row.LastName}.", + top_k=50, + system_prompt="Just return me the invoiceID as a number without any text. This is an example output: ['1', '2', '3']. Where 1, 2, 3 are invoiceIDs of an invoice", + ) + print(f"Cognee search result: {search_results}") + + import ast + + lst = ast.literal_eval(search_results[0]) # converts string -> Python list + # Transform both lists to int for comparison, sorting and type consistency + lst = sorted([int(x) for x in lst]) + invoice_ids = sorted([int(x) for x in invoice_ids]) + assert lst == invoice_ids, ( + f"Search results {lst} do not match expected invoice IDs {invoice_ids} for Customer:{customer_id}" + ) + + async def test_migration_sqlite(): database_to_migrate_path = os.path.join(pathlib.Path(__file__).parent, "test_data/") @@ -283,6 +334,7 @@ async def test_migration_sqlite(): ) await relational_db_migration() + await test_search_result_quality() await test_schema_only_migration() diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index e24abd0f5..bd11dc62e 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -146,7 +146,13 @@ async def main(): assert len(search_results) == 1, ( f"{name}: expected single-element
list, got {len(search_results)}" ) - text = search_results[0] + + from cognee.context_global_variables import backend_access_control_enabled + + if backend_access_control_enabled(): + text = search_results[0]["search_result"][0] + else: + text = search_results[0] assert isinstance(text, str), f"{name}: element should be a string" assert text.strip(), f"{name}: string should not be empty" assert "netherlands" in text.lower(), ( diff --git a/cognee/tests/unit/api/test_conditional_authentication_endpoints.py b/cognee/tests/unit/api/test_conditional_authentication_endpoints.py index 2eabee91a..6cc37ef38 100644 --- a/cognee/tests/unit/api/test_conditional_authentication_endpoints.py +++ b/cognee/tests/unit/api/test_conditional_authentication_endpoints.py @@ -1,3 +1,4 @@ +import os import pytest from unittest.mock import patch, AsyncMock, MagicMock from uuid import uuid4 @@ -5,8 +6,6 @@ from fastapi.testclient import TestClient from types import SimpleNamespace import importlib -from cognee.api.client import app - # Fixtures for reuse across test classes @pytest.fixture @@ -32,6 +31,10 @@ def mock_authenticated_user(): ) +# To turn off authentication we need to set the environment variable before importing the module +# Also both require_authentication and backend access control must be false +os.environ["REQUIRE_AUTHENTICATION"] = "false" +os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false" gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user") @@ -40,6 +43,8 @@ class TestConditionalAuthenticationEndpoints: @pytest.fixture def client(self): + from cognee.api.client import app + """Create a test client.""" return TestClient(app) @@ -133,6 +138,8 @@ class TestConditionalAuthenticationBehavior: @pytest.fixture def client(self): + from cognee.api.client import app + return TestClient(app) @pytest.mark.parametrize( @@ -209,6 +216,8 @@ class TestConditionalAuthenticationErrorHandling: @pytest.fixture def client(self): + from 
cognee.api.client import app + return TestClient(app) @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) @@ -232,7 +241,7 @@ class TestConditionalAuthenticationErrorHandling: # The exact error message may vary depending on the actual database connection # The important thing is that we get a 500 error when user creation fails - def test_current_environment_configuration(self): + def test_current_environment_configuration(self, client): """Test that current environment configuration is working properly.""" # This tests the actual module state without trying to change it from cognee.modules.users.methods.get_authenticated_user import ( diff --git a/cognee/tests/unit/api/test_ontology_endpoint.py b/cognee/tests/unit/api/test_ontology_endpoint.py new file mode 100644 index 000000000..af3a4d90e --- /dev/null +++ b/cognee/tests/unit/api/test_ontology_endpoint.py @@ -0,0 +1,272 @@ +import pytest +import uuid +from fastapi.testclient import TestClient +from unittest.mock import patch, Mock, AsyncMock +from types import SimpleNamespace +import importlib +from cognee.api.client import app + +gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user") + + +@pytest.fixture +def client(): + return TestClient(app) + + +@pytest.fixture +def mock_user(): + user = Mock() + user.id = "test-user-123" + return user + + +@pytest.fixture +def mock_default_user(): + """Mock default user for testing.""" + return SimpleNamespace( + id=str(uuid.uuid4()), + email="default@example.com", + is_active=True, + tenant_id=str(uuid.uuid4()), + ) + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_ontology_success(mock_get_default_user, client, mock_default_user): + """Test successful ontology upload""" + import json + + mock_get_default_user.return_value = mock_default_user + ontology_content = ( + b"" + ) + unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" + + response = client.post( + "/api/v1/ontologies", + 
files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))], + data={"ontology_key": json.dumps([unique_key]), "description": json.dumps(["Test"])}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key + assert "uploaded_at" in data["uploaded_ontologies"][0] + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_default_user): + """Test 400 response for non-.owl files""" + mock_get_default_user.return_value = mock_default_user + unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" + response = client.post( + "/api/v1/ontologies", + files={"ontology_file": ("test.txt", b"not xml")}, + data={"ontology_key": unique_key}, + ) + assert response.status_code == 400 + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_ontology_missing_data(mock_get_default_user, client, mock_default_user): + """Test 400 response for missing file or key""" + import json + + mock_get_default_user.return_value = mock_default_user + # Missing file + response = client.post("/api/v1/ontologies", data={"ontology_key": json.dumps(["test"])}) + assert response.status_code == 400 + + # Missing key + response = client.post( + "/api/v1/ontologies", files=[("ontology_file", ("test.owl", b"xml", "application/xml"))] + ) + assert response.status_code == 400 + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_default_user): + """Test behavior when default user is provided (no explicit authentication)""" + import json + + unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" + mock_get_default_user.return_value = mock_default_user + response = client.post( + "/api/v1/ontologies", + files=[("ontology_file", ("test.owl", b"", "application/xml"))], + data={"ontology_key": 
json.dumps([unique_key])}, + ) + + # The current system provides a default user when no explicit authentication is given + # This test verifies the system works with conditional authentication + assert response.status_code == 200 + data = response.json() + assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key + assert "uploaded_at" in data["uploaded_ontologies"][0] + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_user): + """Test uploading multiple ontology files in single request""" + import io + + mock_get_default_user.return_value = mock_default_user + # Create mock files + file1_content = b"" + file2_content = b"" + + files = [ + ("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")), + ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")), + ] + data = { + "ontology_key": '["vehicles", "manufacturers"]', + "descriptions": '["Base vehicles", "Car manufacturers"]', + } + + response = client.post("/api/v1/ontologies", files=files, data=data) + + assert response.status_code == 200 + result = response.json() + assert "uploaded_ontologies" in result + assert len(result["uploaded_ontologies"]) == 2 + assert result["uploaded_ontologies"][0]["ontology_key"] == "vehicles" + assert result["uploaded_ontologies"][1]["ontology_key"] == "manufacturers" + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_default_user): + """Test that upload endpoint accepts array parameters""" + import io + import json + + mock_get_default_user.return_value = mock_default_user + file_content = b"" + + files = [("ontology_file", ("single.owl", io.BytesIO(file_content), "application/xml"))] + data = { + "ontology_key": json.dumps(["single_key"]), + "descriptions": json.dumps(["Single ontology"]), + } + + response = 
client.post("/api/v1/ontologies", files=files, data=data) + + assert response.status_code == 200 + result = response.json() + assert result["uploaded_ontologies"][0]["ontology_key"] == "single_key" + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_default_user): + """Test cognify endpoint accepts multiple ontology keys""" + payload = { + "datasets": ["test_dataset"], + "ontology_key": ["ontology1", "ontology2"], # Array instead of string + "run_in_background": False, + } + + response = client.post("/api/v1/cognify", json=payload) + + # Should not fail due to ontology_key type + assert response.status_code in [200, 400, 409] # May fail for other reasons, not type + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_complete_multifile_workflow(mock_get_default_user, client, mock_default_user): + """Test complete workflow: upload multiple ontologies → cognify with multiple keys""" + import io + import json + + mock_get_default_user.return_value = mock_default_user + # Step 1: Upload multiple ontologies + file1_content = b""" + + + """ + + file2_content = b""" + + + """ + + files = [ + ("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")), + ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")), + ] + data = { + "ontology_key": json.dumps(["vehicles", "manufacturers"]), + "descriptions": json.dumps(["Vehicle ontology", "Manufacturer ontology"]), + } + + upload_response = client.post("/api/v1/ontologies", files=files, data=data) + assert upload_response.status_code == 200 + + # Step 2: Verify ontologies are listed + list_response = client.get("/api/v1/ontologies") + assert list_response.status_code == 200 + ontologies = list_response.json() + assert "vehicles" in ontologies + assert "manufacturers" in ontologies + + # Step 3: Test cognify with multiple ontologies + cognify_payload = 
{ + "datasets": ["test_dataset"], + "ontology_key": ["vehicles", "manufacturers"], + "run_in_background": False, + } + + cognify_response = client.post("/api/v1/cognify", json=cognify_payload) + # Should not fail due to ontology handling (may fail for dataset reasons) + assert cognify_response.status_code != 400 # Not a validation error + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_multifile_error_handling(mock_get_default_user, client, mock_default_user): + """Test error handling for invalid multifile uploads""" + import io + import json + + # Test mismatched array lengths + file_content = b"" + files = [("ontology_file", ("test.owl", io.BytesIO(file_content), "application/xml"))] + data = { + "ontology_key": json.dumps(["key1", "key2"]), # 2 keys, 1 file + "descriptions": json.dumps(["desc1"]), + } + + response = client.post("/api/v1/ontologies", files=files, data=data) + assert response.status_code == 400 + assert "Number of keys must match number of files" in response.json()["error"] + + # Test duplicate keys + files = [ + ("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml")), + ("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml")), + ] + data = { + "ontology_key": json.dumps(["duplicate", "duplicate"]), + "descriptions": json.dumps(["desc1", "desc2"]), + } + + response = client.post("/api/v1/ontologies", files=files, data=data) + assert response.status_code == 400 + assert "Duplicate ontology keys not allowed" in response.json()["error"] + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_cognify_missing_ontology_key(mock_get_default_user, client, mock_default_user): + """Test cognify with non-existent ontology key""" + mock_get_default_user.return_value = mock_default_user + + payload = { + "datasets": ["test_dataset"], + "ontology_key": ["nonexistent_key"], + "run_in_background": False, + } + + response = client.post("/api/v1/cognify", 
json=payload) + assert response.status_code == 409 + assert "Ontology key 'nonexistent_key' not found" in response.json()["error"] diff --git a/cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py b/cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py index a8d3bda82..837a9955c 100644 --- a/cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +++ b/cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py @@ -8,6 +8,7 @@ def test_cache_config_defaults(): """Test that CacheConfig has the correct default values.""" config = CacheConfig() + assert config.cache_backend == "fs" assert config.caching is False assert config.shared_kuzu_lock is False assert config.cache_host == "localhost" @@ -19,6 +20,7 @@ def test_cache_config_defaults(): def test_cache_config_custom_values(): """Test that CacheConfig accepts custom values.""" config = CacheConfig( + cache_backend="redis", caching=True, shared_kuzu_lock=True, cache_host="redis.example.com", @@ -27,6 +29,7 @@ def test_cache_config_custom_values(): agentic_lock_timeout=180, ) + assert config.cache_backend == "redis" assert config.caching is True assert config.shared_kuzu_lock is True assert config.cache_host == "redis.example.com" @@ -38,6 +41,7 @@ def test_cache_config_custom_values(): def test_cache_config_to_dict(): """Test the to_dict method returns all configuration values.""" config = CacheConfig( + cache_backend="fs", caching=True, shared_kuzu_lock=True, cache_host="test-host", @@ -49,6 +53,7 @@ def test_cache_config_to_dict(): config_dict = config.to_dict() assert config_dict == { + "cache_backend": "fs", "caching": True, "shared_kuzu_lock": True, "cache_host": "test-host", diff --git a/cognee/tests/unit/infrastructure/databases/test_index_data_points.py b/cognee/tests/unit/infrastructure/databases/test_index_data_points.py new file mode 100644 index 000000000..21a5695de --- /dev/null +++ 
b/cognee/tests/unit/infrastructure/databases/test_index_data_points.py @@ -0,0 +1,27 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from cognee.tasks.storage.index_data_points import index_data_points +from cognee.infrastructure.engine import DataPoint + + +class TestDataPoint(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + +@pytest.mark.asyncio +async def test_index_data_points_calls_vector_engine(): + """Test that index_data_points creates vector index and indexes data.""" + data_points = [TestDataPoint(name="test1")] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine.get_batch_size = MagicMock(return_value=100) + + with patch.dict( + index_data_points.__globals__, + {"get_vector_engine": lambda: mock_vector_engine}, + ): + await index_data_points(data_points) + + assert mock_vector_engine.create_vector_index.await_count >= 1 + assert mock_vector_engine.index_data_points.await_count >= 1 diff --git a/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py b/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py index 48bbc53e3..cee0896c2 100644 --- a/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py +++ b/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py @@ -5,8 +5,7 @@ from cognee.tasks.storage.index_graph_edges import index_graph_edges @pytest.mark.asyncio async def test_index_graph_edges_success(): - """Test that index_graph_edges uses the index datapoints and creates vector index.""" - # Create the mocks for the graph and vector engines. 
+ """Test that index_graph_edges retrieves edges and delegates to index_data_points.""" mock_graph_engine = AsyncMock() mock_graph_engine.get_graph_data.return_value = ( None, @@ -15,26 +14,23 @@ async def test_index_graph_edges_success(): [{"relationship_name": "rel2"}], ], ) - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine.get_batch_size = MagicMock(return_value=100) + mock_index_data_points = AsyncMock() - # Patch the globals of the function so that when it does: - # vector_engine = get_vector_engine() - # graph_engine = await get_graph_engine() - # it uses the mocked versions. with patch.dict( index_graph_edges.__globals__, { "get_graph_engine": AsyncMock(return_value=mock_graph_engine), - "get_vector_engine": lambda: mock_vector_engine, + "index_data_points": mock_index_data_points, }, ): await index_graph_edges() - # Assertions on the mock calls. mock_graph_engine.get_graph_data.assert_awaited_once() - assert mock_vector_engine.create_vector_index.await_count == 1 - assert mock_vector_engine.index_data_points.await_count == 1 + mock_index_data_points.assert_awaited_once() + + call_args = mock_index_data_points.call_args[0][0] + assert len(call_args) == 2 + assert all(hasattr(item, "relationship_name") for item in call_args) @pytest.mark.asyncio @@ -42,20 +38,22 @@ async def test_index_graph_edges_no_relationships(): """Test that index_graph_edges handles empty relationships correctly.""" mock_graph_engine = AsyncMock() mock_graph_engine.get_graph_data.return_value = (None, []) - mock_vector_engine = AsyncMock() + mock_index_data_points = AsyncMock() with patch.dict( index_graph_edges.__globals__, { "get_graph_engine": AsyncMock(return_value=mock_graph_engine), - "get_vector_engine": lambda: mock_vector_engine, + "index_data_points": mock_index_data_points, }, ): await index_graph_edges() mock_graph_engine.get_graph_data.assert_awaited_once() - mock_vector_engine.create_vector_index.assert_not_awaited() - 
mock_vector_engine.index_data_points.assert_not_awaited() + mock_index_data_points.assert_awaited_once() + + call_args = mock_index_data_points.call_args[0][0] + assert len(call_args) == 0 @pytest.mark.asyncio diff --git a/cognee/tests/unit/modules/chunking/test_text_chunker.py b/cognee/tests/unit/modules/chunking/test_text_chunker.py new file mode 100644 index 000000000..d535f74b0 --- /dev/null +++ b/cognee/tests/unit/modules/chunking/test_text_chunker.py @@ -0,0 +1,248 @@ +"""Unit tests for TextChunker and TextChunkerWithOverlap behavioral equivalence.""" + +import pytest +from uuid import uuid4 + +from cognee.modules.chunking.TextChunker import TextChunker +from cognee.modules.chunking.text_chunker_with_overlap import TextChunkerWithOverlap +from cognee.modules.data.processing.document_types import Document + + +@pytest.fixture(params=["TextChunker", "TextChunkerWithOverlap"]) +def chunker_class(request): + """Parametrize tests to run against both implementations.""" + return TextChunker if request.param == "TextChunker" else TextChunkerWithOverlap + + +@pytest.fixture +def make_text_generator(): + """Factory for async text generators.""" + + def _factory(*texts): + async def gen(): + for text in texts: + yield text + + return gen + + return _factory + + +async def collect_chunks(chunker): + """Consume async generator and return list of chunks.""" + chunks = [] + async for chunk in chunker.read(): + chunks.append(chunk) + return chunks + + +@pytest.mark.asyncio +async def test_empty_input_produces_no_chunks(chunker_class, make_text_generator): + """Empty input should yield no chunks.""" + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator("") + chunker = chunker_class(document, get_text, max_chunk_size=512) + chunks = await collect_chunks(chunker) + + assert len(chunks) == 0, "Empty input should produce no chunks" + + 
+@pytest.mark.asyncio +async def test_whitespace_only_input_emits_single_chunk(chunker_class, make_text_generator): + """Whitespace-only input should produce exactly one chunk with unchanged text.""" + whitespace_text = " \n\t \r\n " + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(whitespace_text) + chunker = chunker_class(document, get_text, max_chunk_size=512) + chunks = await collect_chunks(chunker) + + assert len(chunks) == 1, "Whitespace-only input should produce exactly one chunk" + assert chunks[0].text == whitespace_text, "Chunk text should equal input (whitespace preserved)" + assert chunks[0].chunk_index == 0, "First chunk should have index 0" + + +@pytest.mark.asyncio +async def test_single_paragraph_below_limit_emits_one_chunk(chunker_class, make_text_generator): + """Single paragraph below limit should emit exactly one chunk.""" + text = "This is a short paragraph." + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + chunker = chunker_class(document, get_text, max_chunk_size=512) + chunks = await collect_chunks(chunker) + + assert len(chunks) == 1, "Single short paragraph should produce exactly one chunk" + assert chunks[0].text == text, "Chunk text should match input" + assert chunks[0].chunk_index == 0, "First chunk should have index 0" + assert chunks[0].chunk_size > 0, "Chunk should have positive size" + + +@pytest.mark.asyncio +async def test_oversized_paragraph_gets_emitted_as_a_single_chunk( + chunker_class, make_text_generator +): + """Oversized paragraph from chunk_by_paragraph should be emitted as single chunk.""" + text = ("A" * 1500) + ". Next sentence." 
+ document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + chunker = chunker_class(document, get_text, max_chunk_size=50) + chunks = await collect_chunks(chunker) + + assert len(chunks) == 2, "Should produce 2 chunks (oversized paragraph + next sentence)" + assert chunks[0].chunk_size > 50, "First chunk should be oversized" + assert chunks[0].chunk_index == 0, "First chunk should have index 0" + assert chunks[1].chunk_index == 1, "Second chunk should have index 1" + + +@pytest.mark.asyncio +async def test_overflow_on_next_paragraph_emits_separate_chunk(chunker_class, make_text_generator): + """First paragraph near limit plus small paragraph should produce two separate chunks.""" + first_para = " ".join(["word"] * 5) + second_para = "Short text." + text = first_para + " " + second_para + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + chunker = chunker_class(document, get_text, max_chunk_size=10) + chunks = await collect_chunks(chunker) + + assert len(chunks) == 2, "Should produce 2 chunks due to overflow" + assert chunks[0].text.strip() == first_para, "First chunk should contain only first paragraph" + assert chunks[1].text.strip() == second_para, ( + "Second chunk should contain only second paragraph" + ) + assert chunks[0].chunk_index == 0, "First chunk should have index 0" + assert chunks[1].chunk_index == 1, "Second chunk should have index 1" + + +@pytest.mark.asyncio +async def test_small_paragraphs_batch_correctly(chunker_class, make_text_generator): + """Multiple small paragraphs should batch together with joiner spaces counted.""" + paragraphs = [" ".join(["word"] * 12) for _ in range(40)] + text = " ".join(paragraphs) + document = Document( + id=uuid4(), + name="test_document", + 
raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + chunker = chunker_class(document, get_text, max_chunk_size=49) + chunks = await collect_chunks(chunker) + + assert len(chunks) == 20, ( + "Should batch paragraphs (2 per chunk: 12 words × 2 tokens = 24, 24 + 1 joiner + 24 = 49)" + ) + assert all(c.chunk_index == i for i, c in enumerate(chunks)), ( + "Chunk indices should be sequential" + ) + all_text = " ".join(chunk.text.strip() for chunk in chunks) + expected_text = " ".join(paragraphs) + assert all_text == expected_text, "All paragraph text should be preserved with correct spacing" + + +@pytest.mark.asyncio +async def test_alternating_large_and_small_paragraphs_dont_batch( + chunker_class, make_text_generator +): + """Alternating near-max and small paragraphs should each become separate chunks.""" + large1 = "word" * 15 + "." + small1 = "Short." + large2 = "word" * 15 + "." + small2 = "Tiny." + text = large1 + " " + small1 + " " + large2 + " " + small2 + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + max_chunk_size = 10 + get_text = make_text_generator(text) + chunker = chunker_class(document, get_text, max_chunk_size=max_chunk_size) + chunks = await collect_chunks(chunker) + + assert len(chunks) == 4, "Should produce multiple chunks" + assert all(c.chunk_index == i for i, c in enumerate(chunks)), ( + "Chunk indices should be sequential" + ) + assert chunks[0].chunk_size > max_chunk_size, ( + "First chunk should be oversized (large paragraph)" + ) + assert chunks[1].chunk_size <= max_chunk_size, "Second chunk should be small (small paragraph)" + assert chunks[2].chunk_size > max_chunk_size, ( + "Third chunk should be oversized (large paragraph)" + ) + assert chunks[3].chunk_size <= max_chunk_size, "Fourth chunk should be small (small paragraph)" + + +@pytest.mark.asyncio +async def 
test_chunk_indices_and_ids_are_deterministic(chunker_class, make_text_generator): + """Running chunker twice on identical input should produce identical indices and IDs.""" + sentence1 = "one " * 4 + ". " + sentence2 = "two " * 4 + ". " + sentence3 = "one " * 4 + ". " + sentence4 = "two " * 4 + ". " + text = sentence1 + sentence2 + sentence3 + sentence4 + doc_id = uuid4() + max_chunk_size = 20 + + document1 = Document( + id=doc_id, + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text1 = make_text_generator(text) + chunker1 = chunker_class(document1, get_text1, max_chunk_size=max_chunk_size) + chunks1 = await collect_chunks(chunker1) + + document2 = Document( + id=doc_id, + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text2 = make_text_generator(text) + chunker2 = chunker_class(document2, get_text2, max_chunk_size=max_chunk_size) + chunks2 = await collect_chunks(chunker2) + + assert len(chunks1) == 2, "Should produce exactly 2 chunks (4 sentences, 2 per chunk)" + assert len(chunks2) == 2, "Should produce exactly 2 chunks (4 sentences, 2 per chunk)" + assert [c.chunk_index for c in chunks1] == [0, 1], "First run indices should be [0, 1]" + assert [c.chunk_index for c in chunks2] == [0, 1], "Second run indices should be [0, 1]" + assert chunks1[0].id == chunks2[0].id, "First chunk ID should be deterministic" + assert chunks1[1].id == chunks2[1].id, "Second chunk ID should be deterministic" + assert chunks1[0].id != chunks1[1].id, "Chunk IDs should be unique within a run" diff --git a/cognee/tests/unit/modules/chunking/test_text_chunker_with_overlap.py b/cognee/tests/unit/modules/chunking/test_text_chunker_with_overlap.py new file mode 100644 index 000000000..9d7be6936 --- /dev/null +++ b/cognee/tests/unit/modules/chunking/test_text_chunker_with_overlap.py @@ -0,0 +1,324 @@ +"""Unit tests for TextChunkerWithOverlap overlap 
behavior.""" + +import sys +import pytest +from uuid import uuid4 +from unittest.mock import patch + +from cognee.modules.chunking.text_chunker_with_overlap import TextChunkerWithOverlap +from cognee.modules.data.processing.document_types import Document +from cognee.tasks.chunks import chunk_by_paragraph + + +@pytest.fixture +def make_text_generator(): + """Factory for async text generators.""" + + def _factory(*texts): + async def gen(): + for text in texts: + yield text + + return gen + + return _factory + + +@pytest.fixture +def make_controlled_chunk_data(): + """Factory for controlled chunk_data generators.""" + + def _factory(*sentences, chunk_size_per_sentence=10): + def _chunk_data(text): + return [ + { + "text": sentence, + "chunk_size": chunk_size_per_sentence, + "cut_type": "sentence", + "chunk_id": uuid4(), + } + for sentence in sentences + ] + + return _chunk_data + + return _factory + + +@pytest.mark.asyncio +async def test_half_overlap_preserves_content_across_chunks( + make_text_generator, make_controlled_chunk_data +): + """With 50% overlap, consecutive chunks should share half their content.""" + s1 = "one" + s2 = "two" + s3 = "three" + s4 = "four" + text = "dummy" + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, chunk_size_per_sentence=10) + chunker = TextChunkerWithOverlap( + document, + get_text, + max_chunk_size=20, + chunk_overlap_ratio=0.5, + get_chunk_data=get_chunk_data, + ) + chunks = [chunk async for chunk in chunker.read()] + + assert len(chunks) == 3, "Should produce exactly 3 chunks (s1+s2, s2+s3, s3+s4)" + assert [c.chunk_index for c in chunks] == [0, 1, 2], "Chunk indices should be [0, 1, 2]" + assert "one" in chunks[0].text and "two" in chunks[0].text, "Chunk 0 should contain s1 and s2" + assert "two" in chunks[1].text and "three" 
in chunks[1].text, ( + "Chunk 1 should contain s2 (overlap) and s3" + ) + assert "three" in chunks[2].text and "four" in chunks[2].text, ( + "Chunk 2 should contain s3 (overlap) and s4" + ) + + +@pytest.mark.asyncio +async def test_zero_overlap_produces_no_duplicate_content( + make_text_generator, make_controlled_chunk_data +): + """With 0% overlap, no content should appear in multiple chunks.""" + s1 = "one" + s2 = "two" + s3 = "three" + s4 = "four" + text = "dummy" + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, chunk_size_per_sentence=10) + chunker = TextChunkerWithOverlap( + document, + get_text, + max_chunk_size=20, + chunk_overlap_ratio=0.0, + get_chunk_data=get_chunk_data, + ) + chunks = [chunk async for chunk in chunker.read()] + + assert len(chunks) == 2, "Should produce exactly 2 chunks (s1+s2, s3+s4)" + assert "one" in chunks[0].text and "two" in chunks[0].text, ( + "First chunk should contain s1 and s2" + ) + assert "three" in chunks[1].text and "four" in chunks[1].text, ( + "Second chunk should contain s3 and s4" + ) + assert "two" not in chunks[1].text and "three" not in chunks[0].text, ( + "No overlap: end of chunk 0 should not appear in chunk 1" + ) + + +@pytest.mark.asyncio +async def test_small_overlap_ratio_creates_minimal_overlap( + make_text_generator, make_controlled_chunk_data +): + """With 25% overlap ratio, chunks should have minimal overlap.""" + s1 = "alpha" + s2 = "beta" + s3 = "gamma" + s4 = "delta" + s5 = "epsilon" + text = "dummy" + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, s5, chunk_size_per_sentence=10) + chunker = 
TextChunkerWithOverlap( + document, + get_text, + max_chunk_size=40, + chunk_overlap_ratio=0.25, + get_chunk_data=get_chunk_data, + ) + chunks = [chunk async for chunk in chunker.read()] + + assert len(chunks) == 2, "Should produce exactly 2 chunks" + assert [c.chunk_index for c in chunks] == [0, 1], "Chunk indices should be [0, 1]" + assert all(token in chunks[0].text for token in [s1, s2, s3, s4]), ( + "Chunk 0 should contain s1 through s4" + ) + assert s4 in chunks[1].text and s5 in chunks[1].text, ( + "Chunk 1 should contain overlap s4 and new content s5" + ) + + +@pytest.mark.asyncio +async def test_high_overlap_ratio_creates_significant_overlap( + make_text_generator, make_controlled_chunk_data +): + """With 75% overlap ratio, consecutive chunks should share most content.""" + s1 = "red" + s2 = "blue" + s3 = "green" + s4 = "yellow" + s5 = "purple" + text = "dummy" + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, s5, chunk_size_per_sentence=5) + chunker = TextChunkerWithOverlap( + document, + get_text, + max_chunk_size=20, + chunk_overlap_ratio=0.75, + get_chunk_data=get_chunk_data, + ) + chunks = [chunk async for chunk in chunker.read()] + + assert len(chunks) == 2, "Should produce exactly 2 chunks with 75% overlap" + assert [c.chunk_index for c in chunks] == [0, 1], "Chunk indices should be [0, 1]" + assert all(token in chunks[0].text for token in [s1, s2, s3, s4]), ( + "Chunk 0 should contain s1, s2, s3, s4" + ) + assert all(token in chunks[1].text for token in [s2, s3, s4, s5]), ( + "Chunk 1 should contain s2, s3, s4 (overlap) and s5" + ) + + +@pytest.mark.asyncio +async def test_single_chunk_no_dangling_overlap(make_text_generator, make_controlled_chunk_data): + """Text that fits in one chunk should produce exactly one chunk, no overlap artifact.""" + s1 = 
"alpha" + s2 = "beta" + text = "dummy" + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + get_chunk_data = make_controlled_chunk_data(s1, s2, chunk_size_per_sentence=10) + chunker = TextChunkerWithOverlap( + document, + get_text, + max_chunk_size=20, + chunk_overlap_ratio=0.5, + get_chunk_data=get_chunk_data, + ) + chunks = [chunk async for chunk in chunker.read()] + + assert len(chunks) == 1, ( + "Should produce exactly 1 chunk when content fits within max_chunk_size" + ) + assert chunks[0].chunk_index == 0, "Single chunk should have index 0" + assert "alpha" in chunks[0].text and "beta" in chunks[0].text, ( + "Single chunk should contain all content" + ) + + +@pytest.mark.asyncio +async def test_paragraph_chunking_with_overlap(make_text_generator): + """Test that chunk_by_paragraph integration produces 25% overlap between chunks.""" + + def mock_get_embedding_engine(): + class MockEngine: + tokenizer = None + + return MockEngine() + + chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentence") + + max_chunk_size = 20 + overlap_ratio = 0.25 # 5 token overlap + paragraph_max_size = int(0.5 * overlap_ratio * max_chunk_size) # = 2 + + text = ( + "A0 A1. A2 A3. A4 A5. A6 A7. A8 A9. " # 10 tokens (0-9) + "B0 B1. B2 B3. B4 B5. B6 B7. B8 B9. " # 10 tokens (10-19) + "C0 C1. C2 C3. C4 C5. C6 C7. C8 C9. " # 10 tokens (20-29) + "D0 D1. D2 D3. D4 D5. D6 D7. D8 D9. " # 10 tokens (30-39) + "E0 E1. E2 E3. E4 E5. E6 E7. E8 E9." 
# 10 tokens (40-49) + ) + + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + + get_text = make_text_generator(text) + + def get_chunk_data(text_input): + return chunk_by_paragraph( + text_input, max_chunk_size=paragraph_max_size, batch_paragraphs=True + ) + + with patch.object( + chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine + ): + chunker = TextChunkerWithOverlap( + document, + get_text, + max_chunk_size=max_chunk_size, + chunk_overlap_ratio=overlap_ratio, + get_chunk_data=get_chunk_data, + ) + chunks = [chunk async for chunk in chunker.read()] + + assert len(chunks) == 3, f"Should produce exactly 3 chunks, got {len(chunks)}" + + assert chunks[0].chunk_index == 0, "First chunk should have index 0" + assert chunks[1].chunk_index == 1, "Second chunk should have index 1" + assert chunks[2].chunk_index == 2, "Third chunk should have index 2" + + assert "A0" in chunks[0].text, "Chunk 0 should start with A0" + assert "A9" in chunks[0].text, "Chunk 0 should contain A9" + assert "B0" in chunks[0].text, "Chunk 0 should contain B0" + assert "B9" in chunks[0].text, "Chunk 0 should contain up to B9 (20 tokens)" + + assert "B" in chunks[1].text, "Chunk 1 should have overlap from B section" + assert "C" in chunks[1].text, "Chunk 1 should contain C section" + assert "D" in chunks[1].text, "Chunk 1 should contain D section" + + assert "D" in chunks[2].text, "Chunk 2 should have overlap from D section" + assert "E0" in chunks[2].text, "Chunk 2 should contain E0" + assert "E9" in chunks[2].text, "Chunk 2 should end with E9" + + chunk_0_end_words = chunks[0].text.split()[-4:] + chunk_1_words = chunks[1].text.split() + overlap_0_1 = any(word in chunk_1_words for word in chunk_0_end_words) + assert overlap_0_1, ( + f"No overlap detected between chunks 0 and 1. 
" + f"Chunk 0 ends with: {chunk_0_end_words}, " + f"Chunk 1 starts with: {chunk_1_words[:6]}" + ) + + chunk_1_end_words = chunks[1].text.split()[-4:] + chunk_2_words = chunks[2].text.split() + overlap_1_2 = any(word in chunk_2_words for word in chunk_1_end_words) + assert overlap_1_2, ( + f"No overlap detected between chunks 1 and 2. " + f"Chunk 1 ends with: {chunk_1_end_words}, " + f"Chunk 2 starts with: {chunk_2_words[:6]}" + ) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py index 37ba113b5..1d2b79cf9 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py @@ -9,7 +9,7 @@ def test_node_initialization(): """Test that a Node is initialized correctly.""" node = Node("node1", {"attr1": "value1"}, dimension=2) assert node.id == "node1" - assert node.attributes == {"attr1": "value1", "vector_distance": np.inf} + assert node.attributes == {"attr1": "value1", "vector_distance": 3.5} assert len(node.status) == 2 assert np.all(node.status == 1) @@ -96,7 +96,7 @@ def test_edge_initialization(): edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2) assert edge.node1 == node1 assert edge.node2 == node2 - assert edge.attributes == {"vector_distance": np.inf, "weight": 10} + assert edge.attributes == {"vector_distance": 3.5, "weight": 10} assert edge.directed is False assert len(edge.status) == 2 assert np.all(edge.status == 1) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index 6888648c3..711479387 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -1,4 +1,5 @@ import pytest +from unittest.mock import AsyncMock from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError from 
cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph @@ -11,6 +12,30 @@ def setup_graph(): return CogneeGraph() +@pytest.fixture +def mock_adapter(): + """Fixture to create a mock adapter for database operations.""" + adapter = AsyncMock() + return adapter + + +@pytest.fixture +def mock_vector_engine(): + """Fixture to create a mock vector engine.""" + engine = AsyncMock() + engine.search = AsyncMock() + return engine + + +class MockScoredResult: + """Mock class for vector search results.""" + + def __init__(self, id, score, payload=None): + self.id = id + self.score = score + self.payload = payload or {} + + def test_add_node_success(setup_graph): """Test successful addition of a node.""" graph = setup_graph @@ -73,3 +98,433 @@ def test_get_edges_nonexistent_node(setup_graph): graph = setup_graph with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."): graph.get_edges_from_node("nonexistent") + + +@pytest.mark.asyncio +async def test_project_graph_from_db_full_graph(setup_graph, mock_adapter): + """Test projecting a full graph from database.""" + graph = setup_graph + + nodes_data = [ + ("1", {"name": "Node1", "description": "First node"}), + ("2", {"name": "Node2", "description": "Second node"}), + ] + edges_data = [ + ("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}), + ] + + mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data)) + + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name", "description"], + edge_properties_to_project=["relationship_name"], + ) + + assert len(graph.nodes) == 2 + assert len(graph.edges) == 1 + assert graph.get_node("1") is not None + assert graph.get_node("2") is not None + assert graph.edges[0].node1.id == "1" + assert graph.edges[0].node2.id == "2" + + +@pytest.mark.asyncio +async def test_project_graph_from_db_id_filtered(setup_graph, mock_adapter): + """Test projecting an ID-filtered graph from database.""" 
+ graph = setup_graph + + nodes_data = [ + ("1", {"name": "Node1"}), + ("2", {"name": "Node2"}), + ] + edges_data = [ + ("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}), + ] + + mock_adapter.get_id_filtered_graph_data = AsyncMock(return_value=(nodes_data, edges_data)) + + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=["relationship_name"], + relevant_ids_to_filter=["1", "2"], + ) + + assert len(graph.nodes) == 2 + assert len(graph.edges) == 1 + mock_adapter.get_id_filtered_graph_data.assert_called_once() + + +@pytest.mark.asyncio +async def test_project_graph_from_db_nodeset_subgraph(setup_graph, mock_adapter): + """Test projecting a nodeset subgraph filtered by node type and name.""" + graph = setup_graph + + nodes_data = [ + ("1", {"name": "Alice", "type": "Person"}), + ("2", {"name": "Bob", "type": "Person"}), + ] + edges_data = [ + ("1", "2", "KNOWS", {"relationship_name": "knows"}), + ] + + mock_adapter.get_nodeset_subgraph = AsyncMock(return_value=(nodes_data, edges_data)) + + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name", "type"], + edge_properties_to_project=["relationship_name"], + node_type="Person", + node_name=["Alice"], + ) + + assert len(graph.nodes) == 2 + assert graph.get_node("1") is not None + assert len(graph.edges) == 1 + mock_adapter.get_nodeset_subgraph.assert_called_once() + + +@pytest.mark.asyncio +async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter): + """Test projecting empty graph raises EntityNotFoundError.""" + graph = setup_graph + + mock_adapter.get_graph_data = AsyncMock(return_value=([], [])) + + with pytest.raises(EntityNotFoundError, match="Empty graph projected from the database."): + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=[], + ) + + +@pytest.mark.asyncio +async def 
test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter): + """Test that edges referencing missing nodes raise error.""" + graph = setup_graph + + nodes_data = [ + ("1", {"name": "Node1"}), + ] + edges_data = [ + ("1", "999", "CONNECTS_TO", {"relationship_name": "connects"}), + ] + + mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data)) + + with pytest.raises(EntityNotFoundError, match="Edge references nonexistent nodes"): + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=["relationship_name"], + ) + + +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_nodes(setup_graph): + """Test mapping vector distances to graph nodes.""" + graph = setup_graph + + node1 = Node("1", {"name": "Node1"}) + node2 = Node("2", {"name": "Node2"}) + graph.add_node(node1) + graph.add_node(node2) + + node_distances = { + "Entity_name": [ + MockScoredResult("1", 0.95), + MockScoredResult("2", 0.87), + ] + } + + await graph.map_vector_distances_to_graph_nodes(node_distances) + + assert graph.get_node("1").attributes.get("vector_distance") == 0.95 + assert graph.get_node("2").attributes.get("vector_distance") == 0.87 + + +@pytest.mark.asyncio +async def test_map_vector_distances_partial_node_coverage(setup_graph): + """Test mapping vector distances when only some nodes have results.""" + graph = setup_graph + + node1 = Node("1", {"name": "Node1"}) + node2 = Node("2", {"name": "Node2"}) + node3 = Node("3", {"name": "Node3"}) + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + node_distances = { + "Entity_name": [ + MockScoredResult("1", 0.95), + MockScoredResult("2", 0.87), + ] + } + + await graph.map_vector_distances_to_graph_nodes(node_distances) + + assert graph.get_node("1").attributes.get("vector_distance") == 0.95 + assert graph.get_node("2").attributes.get("vector_distance") == 0.87 + assert 
graph.get_node("3").attributes.get("vector_distance") == 3.5 + + +@pytest.mark.asyncio +async def test_map_vector_distances_multiple_categories(setup_graph): + """Test mapping vector distances from multiple collection categories.""" + graph = setup_graph + + # Create nodes + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + node4 = Node("4") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + graph.add_node(node4) + + node_distances = { + "Entity_name": [ + MockScoredResult("1", 0.95), + MockScoredResult("2", 0.87), + ], + "TextSummary_text": [ + MockScoredResult("3", 0.92), + ], + } + + await graph.map_vector_distances_to_graph_nodes(node_distances) + + assert graph.get_node("1").attributes.get("vector_distance") == 0.95 + assert graph.get_node("2").attributes.get("vector_distance") == 0.87 + assert graph.get_node("3").attributes.get("vector_distance") == 0.92 + assert graph.get_node("4").attributes.get("vector_distance") == 3.5 + + +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, mock_vector_engine): + """Test mapping vector distances to edges when edge_distances provided.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge( + node1, + node2, + attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"}, + ) + graph.add_edge(edge) + + edge_distances = [ + MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=edge_distances, + ) + + assert graph.edges[0].attributes.get("vector_distance") == 0.92 + + +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_edges_search(setup_graph, mock_vector_engine): + """Test mapping edge distances when searching for them.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + 
graph.add_node(node1) + graph.add_node(node2) + + edge = Edge( + node1, + node2, + attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"}, + ) + graph.add_edge(edge) + + mock_vector_engine.search.return_value = [ + MockScoredResult("e1", 0.88, payload={"text": "CONNECTS_TO"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=None, + ) + + mock_vector_engine.search.assert_called_once() + assert graph.edges[0].attributes.get("vector_distance") == 0.88 + + +@pytest.mark.asyncio +async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vector_engine): + """Test mapping edge distances when only some edges have results.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + edge1 = Edge(node1, node2, attributes={"edge_text": "CONNECTS_TO"}) + edge2 = Edge(node2, node3, attributes={"edge_text": "DEPENDS_ON"}) + graph.add_edge(edge1) + graph.add_edge(edge2) + + edge_distances = [ + MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=edge_distances, + ) + + assert graph.edges[0].attributes.get("vector_distance") == 0.92 + assert graph.edges[1].attributes.get("vector_distance") == 3.5 + + +@pytest.mark.asyncio +async def test_map_vector_distances_edges_fallback_to_relationship_type( + setup_graph, mock_vector_engine +): + """Test that edge mapping falls back to relationship_type when edge_text is missing.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge( + node1, + node2, + attributes={"relationship_type": "KNOWS"}, + ) + graph.add_edge(edge) + + edge_distances = [ + MockScoredResult("e1", 0.85, 
payload={"text": "KNOWS"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=edge_distances, + ) + + assert graph.edges[0].attributes.get("vector_distance") == 0.85 + + +@pytest.mark.asyncio +async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_engine): + """Test edge mapping when no edges match the distance results.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge( + node1, + node2, + attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"}, + ) + graph.add_edge(edge) + + edge_distances = [ + MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=edge_distances, + ) + + assert graph.edges[0].attributes.get("vector_distance") == 3.5 + + +@pytest.mark.asyncio +async def test_map_vector_distances_invalid_query_vector(setup_graph, mock_vector_engine): + """Test that invalid query vector raises error.""" + graph = setup_graph + + with pytest.raises(ValueError, match="Failed to generate query embedding"): + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[], + edge_distances=None, + ) + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances(setup_graph): + """Test calculating top triplet importances by score.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + node4 = Node("4") + + node1.add_attribute("vector_distance", 0.9) + node2.add_attribute("vector_distance", 0.8) + node3.add_attribute("vector_distance", 0.7) + node4.add_attribute("vector_distance", 0.6) + + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + graph.add_node(node4) + + edge1 = Edge(node1, node2) + edge2 = 
Edge(node2, node3) + edge3 = Edge(node3, node4) + + edge1.add_attribute("vector_distance", 0.85) + edge2.add_attribute("vector_distance", 0.75) + edge3.add_attribute("vector_distance", 0.65) + + graph.add_edge(edge1) + graph.add_edge(edge2) + graph.add_edge(edge3) + + top_triplets = await graph.calculate_top_triplet_importances(k=2) + + assert len(top_triplets) == 2 + + assert top_triplets[0] == edge3 + assert top_triplets[1] == edge2 + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances_default_distances(setup_graph): + """Test calculating importances when nodes/edges have no vector distances.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge(node1, node2) + graph.add_edge(edge) + + top_triplets = await graph.calculate_top_triplet_importances(k=1) + + assert len(top_triplets) == 1 + assert top_triplets[0] == edge diff --git a/cognee/tests/unit/modules/memify_tasks/test_cognify_session.py b/cognee/tests/unit/modules/memify_tasks/test_cognify_session.py new file mode 100644 index 000000000..8c2448287 --- /dev/null +++ b/cognee/tests/unit/modules/memify_tasks/test_cognify_session.py @@ -0,0 +1,111 @@ +import pytest +from unittest.mock import AsyncMock, patch + +from cognee.tasks.memify.cognify_session import cognify_session +from cognee.exceptions import CogneeValidationError, CogneeSystemError + + +@pytest.mark.asyncio +async def test_cognify_session_success(): + """Test successful cognification of session data.""" + session_data = ( + "Session ID: test_session\n\nQuestion: What is AI?\n\nAnswer: AI is artificial intelligence" + ) + + with ( + patch("cognee.add", new_callable=AsyncMock) as mock_add, + patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify, + ): + await cognify_session(session_data, dataset_id="123") + + mock_add.assert_called_once_with( + session_data, dataset_id="123", node_set=["user_sessions_from_cache"] + ) + 
mock_cognify.assert_called_once() + + +@pytest.mark.asyncio +async def test_cognify_session_empty_string(): + """Test cognification fails with empty string.""" + with pytest.raises(CogneeValidationError) as exc_info: + await cognify_session("") + + assert "Session data cannot be empty" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_cognify_session_whitespace_string(): + """Test cognification fails with whitespace-only string.""" + with pytest.raises(CogneeValidationError) as exc_info: + await cognify_session(" \n\t ") + + assert "Session data cannot be empty" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_cognify_session_none_data(): + """Test cognification fails with None data.""" + with pytest.raises(CogneeValidationError) as exc_info: + await cognify_session(None) + + assert "Session data cannot be empty" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_cognify_session_add_failure(): + """Test cognification handles cognee.add failure.""" + session_data = "Session ID: test\n\nQuestion: test?" + + with ( + patch("cognee.add", new_callable=AsyncMock) as mock_add, + patch("cognee.cognify", new_callable=AsyncMock), + ): + mock_add.side_effect = Exception("Add operation failed") + + with pytest.raises(CogneeSystemError) as exc_info: + await cognify_session(session_data) + + assert "Failed to cognify session data" in str(exc_info.value) + assert "Add operation failed" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_cognify_session_cognify_failure(): + """Test cognification handles cognify failure.""" + session_data = "Session ID: test\n\nQuestion: test?" 
+ + with ( + patch("cognee.add", new_callable=AsyncMock), + patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify, + ): + mock_cognify.side_effect = Exception("Cognify operation failed") + + with pytest.raises(CogneeSystemError) as exc_info: + await cognify_session(session_data) + + assert "Failed to cognify session data" in str(exc_info.value) + assert "Cognify operation failed" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_cognify_session_re_raises_validation_error(): + """Test that CogneeValidationError is re-raised as-is.""" + with pytest.raises(CogneeValidationError): + await cognify_session("") + + +@pytest.mark.asyncio +async def test_cognify_session_with_special_characters(): + """Test cognification with special characters.""" + session_data = "Session: test™ © Question: What's special? Answer: Cognee is special!" + + with ( + patch("cognee.add", new_callable=AsyncMock) as mock_add, + patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify, + ): + await cognify_session(session_data, dataset_id="123") + + mock_add.assert_called_once_with( + session_data, dataset_id="123", node_set=["user_sessions_from_cache"] + ) + mock_cognify.assert_called_once() diff --git a/cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py b/cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py new file mode 100644 index 000000000..8cb27fef3 --- /dev/null +++ b/cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py @@ -0,0 +1,175 @@ +import sys +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from cognee.tasks.memify.extract_user_sessions import extract_user_sessions +from cognee.exceptions import CogneeSystemError +from cognee.modules.users.models import User + +# Get the actual module object (not the function) for patching +extract_user_sessions_module = sys.modules["cognee.tasks.memify.extract_user_sessions"] + + +@pytest.fixture +def mock_user(): + """Create a mock user.""" + 
user = MagicMock(spec=User) + user.id = "test-user-123" + return user + + +@pytest.fixture +def mock_qa_data(): + """Create mock Q&A data.""" + return [ + { + "question": "What is cognee?", + "context": "context about cognee", + "answer": "Cognee is a knowledge graph solution", + "time": "2025-01-01T12:00:00", + }, + { + "question": "How does it work?", + "context": "how it works context", + "answer": "It processes data and creates graphs", + "time": "2025-01-01T12:05:00", + }, + ] + + +@pytest.mark.asyncio +async def test_extract_user_sessions_success(mock_user, mock_qa_data): + """Test successful extraction of sessions.""" + mock_cache_engine = AsyncMock() + mock_cache_engine.get_all_qas.return_value = mock_qa_data + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions([{}], session_ids=["test_session"]): + sessions.append(session) + + assert len(sessions) == 1 + assert "Session ID: test_session" in sessions[0] + assert "Question: What is cognee?" in sessions[0] + assert "Answer: Cognee is a knowledge graph solution" in sessions[0] + assert "Question: How does it work?" 
in sessions[0] + assert "Answer: It processes data and creates graphs" in sessions[0] + + +@pytest.mark.asyncio +async def test_extract_user_sessions_multiple_sessions(mock_user, mock_qa_data): + """Test extraction of multiple sessions.""" + mock_cache_engine = AsyncMock() + mock_cache_engine.get_all_qas.return_value = mock_qa_data + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions([{}], session_ids=["session1", "session2"]): + sessions.append(session) + + assert len(sessions) == 2 + assert mock_cache_engine.get_all_qas.call_count == 2 + + +@pytest.mark.asyncio +async def test_extract_user_sessions_no_data(mock_user, mock_qa_data): + """Test extraction handles empty data parameter.""" + mock_cache_engine = AsyncMock() + mock_cache_engine.get_all_qas.return_value = mock_qa_data + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions(None, session_ids=["test_session"]): + sessions.append(session) + + assert len(sessions) == 1 + + +@pytest.mark.asyncio +async def test_extract_user_sessions_no_session_ids(mock_user): + """Test extraction handles no session IDs provided.""" + mock_cache_engine = AsyncMock() + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions([{}], session_ids=None): + 
sessions.append(session) + + assert len(sessions) == 0 + mock_cache_engine.get_all_qas.assert_not_called() + + +@pytest.mark.asyncio +async def test_extract_user_sessions_empty_qa_data(mock_user): + """Test extraction handles empty Q&A data.""" + mock_cache_engine = AsyncMock() + mock_cache_engine.get_all_qas.return_value = [] + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions([{}], session_ids=["empty_session"]): + sessions.append(session) + + assert len(sessions) == 0 + + +@pytest.mark.asyncio +async def test_extract_user_sessions_cache_error_handling(mock_user, mock_qa_data): + """Test extraction continues on cache error for specific session.""" + mock_cache_engine = AsyncMock() + mock_cache_engine.get_all_qas.side_effect = [ + mock_qa_data, + Exception("Cache error"), + mock_qa_data, + ] + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions( + [{}], session_ids=["session1", "session2", "session3"] + ): + sessions.append(session) + + assert len(sessions) == 2 diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 7fcfe0d6b..206cfaf84 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -2,7 +2,6 @@ import os import pytest import pathlib from typing import Optional, Union -from pydantic import BaseModel import cognee 
from cognee.low_level import setup, DataPoint @@ -11,11 +10,6 @@ from cognee.tasks.storage import add_data_points from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever -class TestAnswer(BaseModel): - answer: str - explanation: str - - class TestGraphCompletionCoTRetriever: @pytest.mark.asyncio async def test_graph_completion_cot_context_simple(self): @@ -174,48 +168,3 @@ class TestGraphCompletionCoTRetriever: assert all(isinstance(item, str) and item.strip() for item in answer), ( "Answer must contain only non-empty strings" ) - - @pytest.mark.asyncio - async def test_get_structured_completion(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - - class Person(DataPoint): - name: str - works_for: Company - - company1 = Company(name="Figma") - person1 = Person(name="Steve Rodger", works_for=company1) - - entities = [company1, person1] - await add_data_points(entities) - - retriever = GraphCompletionCotRetriever() - - # Test with string response model (default) - string_answer = await retriever.get_structured_completion("Who works at Figma?") - assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}" - assert string_answer.strip(), "Answer should not be empty" - - # Test with structured response model - structured_answer = await retriever.get_structured_completion( - "Who works at Figma?", response_model=TestAnswer - ) - assert isinstance(structured_answer, TestAnswer), ( - f"Expected TestAnswer, got {type(structured_answer).__name__}" - 
) - assert structured_answer.answer.strip(), "Answer field should not be empty" - assert structured_answer.explanation.strip(), "Explanation field should not be empty" diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py index 252af8352..9bfed68f3 100644 --- a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py @@ -3,6 +3,7 @@ from typing import List import pytest import pathlib import cognee + from cognee.low_level import setup from cognee.tasks.storage import add_data_points from cognee.infrastructure.databases.vector import get_vector_engine diff --git a/cognee/tests/unit/modules/retrieval/structured_output_test.py b/cognee/tests/unit/modules/retrieval/structured_output_test.py new file mode 100644 index 000000000..4ad3019ff --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/structured_output_test.py @@ -0,0 +1,204 @@ +import asyncio + +import pytest +import cognee +import pathlib +import os + +from pydantic import BaseModel +from cognee.low_level import setup, DataPoint +from cognee.tasks.storage import add_data_points +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.engine.models import Entity, EntityType +from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor +from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider +from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) +from cognee.modules.retrieval.EntityCompletionRetriever 
import EntityCompletionRetriever +from cognee.modules.retrieval.temporal_retriever import TemporalRetriever +from cognee.modules.retrieval.completion_retriever import CompletionRetriever + + +class TestAnswer(BaseModel): + answer: str + explanation: str + + +def _assert_string_answer(answer: list[str]): + assert isinstance(answer, list), f"Expected str, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), "Items should be strings" + assert all(item.strip() for item in answer), "Items should not be empty" + + +def _assert_structured_answer(answer: list[TestAnswer]): + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(x, TestAnswer) for x in answer), "Items should be TestAnswer" + assert all(x.answer.strip() for x in answer), "Answer text should not be empty" + assert all(x.explanation.strip() for x in answer), "Explanation should not be empty" + + +async def _test_get_structured_graph_completion_cot(): + retriever = GraphCompletionCotRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion(): + retriever = GraphCompletionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion_temporal(): + retriever = TemporalRetriever() + + # Test with string 
response model (default) + string_answer = await retriever.get_completion("When did Steve start working at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "When did Steve start working at Figma??", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion_rag(): + retriever = CompletionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Where does Steve work?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Where does Steve work?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion_context_extension(): + retriever = GraphCompletionContextExtensionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_entity_completion(): + retriever = EntityCompletionRetriever(DummyEntityExtractor(), DummyContextProvider()) + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who is Albert Einstein?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who is Albert Einstein?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +class TestStructuredOutputCompletion: + @pytest.mark.asyncio + async def test_get_structured_completion(self): + system_directory_path = os.path.join( + 
pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + works_since: int + + company1 = Company(name="Figma") + person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015) + + entities = [company1, person1] + await add_data_points(entities) + + document = TextDocument( + name="Steve Rodger's career", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3] + await add_data_points(entities) + + entity_type = EntityType(name="Person", description="A human individual") + entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist") + + entities = [entity] + await add_data_points(entities) + + await _test_get_structured_graph_completion_cot() + await _test_get_structured_graph_completion() + await _test_get_structured_graph_completion_temporal() + await _test_get_structured_graph_completion_rag() + await _test_get_structured_graph_completion_context_extension() + await _test_get_structured_entity_completion() diff --git 
a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py index fc96081bf..5f4b93425 100644 --- a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py @@ -13,7 +13,7 @@ from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.summaries_retriever import SummariesRetriever -class TextSummariesRetriever: +class TestSummariesRetriever: @pytest.mark.asyncio async def test_chunk_context(self): system_directory_path = os.path.join( diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py index a322cb237..c3c6a47f6 100644 --- a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py @@ -1,4 +1,3 @@ -import asyncio from types import SimpleNamespace import pytest diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py new file mode 100644 index 000000000..5eb6fb105 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py @@ -0,0 +1,582 @@ +import pytest +from unittest.mock import AsyncMock, patch + +from cognee.modules.retrieval.utils.brute_force_triplet_search import ( + brute_force_triplet_search, + get_memory_fragment, +) +from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph +from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError + + +class MockScoredResult: + """Mock class for vector search results.""" + + def __init__(self, id, score, payload=None): + self.id = id + self.score = score + self.payload = payload or {} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_empty_query(): + """Test that empty query raises ValueError.""" + 
with pytest.raises(ValueError, match="The query must be a non-empty string."): + await brute_force_triplet_search(query="") + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_none_query(): + """Test that None query raises ValueError.""" + with pytest.raises(ValueError, match="The query must be a non-empty string."): + await brute_force_triplet_search(query=None) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_negative_top_k(): + """Test that negative top_k raises ValueError.""" + with pytest.raises(ValueError, match="top_k must be a positive integer."): + await brute_force_triplet_search(query="test query", top_k=-1) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_zero_top_k(): + """Test that zero top_k raises ValueError.""" + with pytest.raises(ValueError, match="top_k must be a positive integer."): + await brute_force_triplet_search(query="test query", top_k=0) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_limit_global_search(): + """Test that wide_search_limit is applied for global search (node_name=None).""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search( + query="test", + node_name=None, # Global search + wide_search_top_k=75, + ) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] == 75 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_limit_filtered_search(): + """Test that wide_search_limit is None for filtered search (node_name provided).""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + 
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search( + query="test", + node_name=["Node1"], + wide_search_top_k=50, + ) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] is None + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_default(): + """Test that wide_search_top_k defaults to 100.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test", node_name=None) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] == 100 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_default_collections(): + """Test that default collections are used when none provided.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test") + + expected_collections = [ + "Entity_name", + "TextSummary_text", + "EntityType_name", + "DocumentChunk_text", + ] + + call_collections = [ + call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list + ] + assert call_collections == 
expected_collections + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_custom_collections(): + """Test that custom collections are used when provided.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + custom_collections = ["CustomCol1", "CustomCol2"] + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test", collections=custom_collections) + + call_collections = [ + call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list + ] + assert call_collections == custom_collections + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_all_collections_empty(): + """Test that empty list is returned when all collections return no results.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + results = await brute_force_triplet_search(query="test") + assert results == [] + + +# Tests for query embedding + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_embeds_query(): + """Test that query is embedded before searching.""" + query_text = "test query" + expected_vector = [0.1, 0.2, 0.3] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + 
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query=query_text) + + mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text]) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["query_vector"] == expected_vector + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_extracts_node_ids_global_search(): + """Test that node IDs are extracted from search results for global search.""" + scored_results = [ + MockScoredResult("node1", 0.95), + MockScoredResult("node2", 0.87), + MockScoredResult("node3", 0.92), + ] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=scored_results) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_reuses_provided_fragment(): + """Test that provided memory fragment is reused instead of creating new one.""" + provided_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + 
calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment" + ) as mock_get_fragment, + ): + await brute_force_triplet_search( + query="test", + memory_fragment=provided_fragment, + node_name=["node"], + ) + + mock_get_fragment.assert_not_called() + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_creates_fragment_when_not_provided(): + """Test that memory fragment is created when not provided.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment, + ): + await brute_force_triplet_search(query="test", node_name=["node"]) + + mock_get_fragment.assert_called_once() + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation(): + """Test that custom top_k is passed to importance calculation.""" + mock_vector_engine = 
AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ), + ): + custom_top_k = 15 + await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"]) + + mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k) + + +@pytest.mark.asyncio +async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found(): + """Test that get_memory_fragment returns empty graph when entity not found.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.project_graph_from_db = AsyncMock( + side_effect=EntityNotFoundError("Entity not found") + ) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", + return_value=mock_graph_engine, + ): + fragment = await get_memory_fragment() + + assert isinstance(fragment, CogneeGraph) + assert len(fragment.nodes) == 0 + + +@pytest.mark.asyncio +async def test_get_memory_fragment_returns_empty_graph_on_error(): + """Test that get_memory_fragment returns empty graph on generic error.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error")) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", + return_value=mock_graph_engine, + ): + fragment = await get_memory_fragment() + + assert 
isinstance(fragment, CogneeGraph) + assert len(fragment.nodes) == 0 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_deduplicates_node_ids(): + """Test that duplicate node IDs across collections are deduplicated.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [ + MockScoredResult("node1", 0.95), + MockScoredResult("node2", 0.87), + ] + elif collection_name == "TextSummary_text": + return [ + MockScoredResult("node1", 0.90), + MockScoredResult("node3", 0.92), + ] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} + assert len(call_kwargs["relevant_ids_to_filter"]) == 3 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_excludes_edge_collection(): + """Test that EdgeType_relationship_name collection is excluded from ID extraction.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [MockScoredResult("node1", 0.95)] + elif collection_name 
== "EdgeType_relationship_name": + return [MockScoredResult("edge1", 0.88)] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search( + query="test", + node_name=None, + collections=["Entity_name", "EdgeType_relationship_name"], + ) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert call_kwargs["relevant_ids_to_filter"] == ["node1"] + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_skips_nodes_without_ids(): + """Test that nodes without ID attribute are skipped.""" + + class ScoredResultNoId: + """Mock result without id attribute.""" + + def __init__(self, score): + self.score = score + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [ + MockScoredResult("node1", 0.95), + ScoredResultNoId(0.90), + MockScoredResult("node2", 0.87), + ] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + 
map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_handles_tuple_results(): + """Test that both list and tuple results are handled correctly.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return ( + MockScoredResult("node1", 0.95), + MockScoredResult("node2", 0.87), + ) + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} + + +@pytest.mark.asyncio +async def 
test_brute_force_triplet_search_mixed_empty_collections(): + """Test ID extraction with mixed empty and non-empty collections.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [MockScoredResult("node1", 0.95)] + elif collection_name == "TextSummary_text": + return [] + elif collection_name == "EntityType_name": + return [MockScoredResult("node2", 0.92)] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} diff --git a/cognee/tests/unit/modules/users/test_conditional_authentication.py b/cognee/tests/unit/modules/users/test_conditional_authentication.py index c4368d796..6568c3cb0 100644 --- a/cognee/tests/unit/modules/users/test_conditional_authentication.py +++ b/cognee/tests/unit/modules/users/test_conditional_authentication.py @@ -107,29 +107,10 @@ class TestConditionalAuthenticationIntegration: # REQUIRE_AUTHENTICATION should be a boolean assert isinstance(REQUIRE_AUTHENTICATION, bool) - # Currently should be False (optional authentication) - assert not REQUIRE_AUTHENTICATION 
- class TestConditionalAuthenticationEnvironmentVariables: """Test environment variable handling.""" - def test_require_authentication_default_false(self): - """Test that REQUIRE_AUTHENTICATION defaults to false when imported with no env vars.""" - with patch.dict(os.environ, {}, clear=True): - # Remove module from cache to force fresh import - module_name = "cognee.modules.users.methods.get_authenticated_user" - if module_name in sys.modules: - del sys.modules[module_name] - - # Import after patching environment - module will see empty environment - from cognee.modules.users.methods.get_authenticated_user import ( - REQUIRE_AUTHENTICATION, - ) - - importlib.invalidate_caches() - assert not REQUIRE_AUTHENTICATION - def test_require_authentication_true(self): """Test that REQUIRE_AUTHENTICATION=true is parsed correctly when imported.""" with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "true"}): @@ -145,50 +126,6 @@ class TestConditionalAuthenticationEnvironmentVariables: assert REQUIRE_AUTHENTICATION - def test_require_authentication_false_explicit(self): - """Test that REQUIRE_AUTHENTICATION=false is parsed correctly when imported.""" - with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "false"}): - # Remove module from cache to force fresh import - module_name = "cognee.modules.users.methods.get_authenticated_user" - if module_name in sys.modules: - del sys.modules[module_name] - - # Import after patching environment - module will see REQUIRE_AUTHENTICATION=false - from cognee.modules.users.methods.get_authenticated_user import ( - REQUIRE_AUTHENTICATION, - ) - - assert not REQUIRE_AUTHENTICATION - - def test_require_authentication_case_insensitive(self): - """Test that environment variable parsing is case insensitive when imported.""" - test_cases = ["TRUE", "True", "tRuE", "FALSE", "False", "fAlSe"] - - for case in test_cases: - with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": case}): - # Remove module from cache to force fresh import - 
module_name = "cognee.modules.users.methods.get_authenticated_user" - if module_name in sys.modules: - del sys.modules[module_name] - - # Import after patching environment - from cognee.modules.users.methods.get_authenticated_user import ( - REQUIRE_AUTHENTICATION, - ) - - expected = case.lower() == "true" - assert REQUIRE_AUTHENTICATION == expected, f"Failed for case: {case}" - - def test_current_require_authentication_value(self): - """Test that the current REQUIRE_AUTHENTICATION module value is as expected.""" - from cognee.modules.users.methods.get_authenticated_user import ( - REQUIRE_AUTHENTICATION, - ) - - # The module-level variable should currently be False (set at import time) - assert isinstance(REQUIRE_AUTHENTICATION, bool) - assert not REQUIRE_AUTHENTICATION - class TestConditionalAuthenticationEdgeCases: """Test edge cases and error scenarios.""" diff --git a/cognee/tests/unit/processing/chunks/chunk_by_row_test.py b/cognee/tests/unit/processing/chunks/chunk_by_row_test.py new file mode 100644 index 000000000..7d6a73a06 --- /dev/null +++ b/cognee/tests/unit/processing/chunks/chunk_by_row_test.py @@ -0,0 +1,52 @@ +from itertools import product + +import numpy as np +import pytest + +from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine +from cognee.tasks.chunks import chunk_by_row + +INPUT_TEXTS = "name: John, age: 30, city: New York, country: USA" +max_chunk_size_vals = [8, 32] + + +@pytest.mark.parametrize( + "input_text,max_chunk_size", + list(product([INPUT_TEXTS], max_chunk_size_vals)), +) +def test_chunk_by_row_isomorphism(input_text, max_chunk_size): + chunks = chunk_by_row(input_text, max_chunk_size) + reconstructed_text = ", ".join([chunk["text"] for chunk in chunks]) + assert reconstructed_text == input_text, ( + f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" + ) + + +@pytest.mark.parametrize( + "input_text,max_chunk_size", + list(product([INPUT_TEXTS], max_chunk_size_vals)), +) 
+def test_row_chunk_length(input_text, max_chunk_size): + chunks = list(chunk_by_row(data=input_text, max_chunk_size=max_chunk_size)) + embedding_engine = get_embedding_engine() + + chunk_lengths = np.array( + [embedding_engine.tokenizer.count_tokens(chunk["text"]) for chunk in chunks] + ) + + larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size] + assert np.all(chunk_lengths <= max_chunk_size), ( + f"{max_chunk_size = }: {larger_chunks} are too large" + ) + + +@pytest.mark.parametrize( + "input_text,max_chunk_size", + list(product([INPUT_TEXTS], max_chunk_size_vals)), +) +def test_chunk_by_row_chunk_numbering(input_text, max_chunk_size): + chunks = chunk_by_row(data=input_text, max_chunk_size=max_chunk_size) + chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks]) + assert np.all(chunk_indices == np.arange(len(chunk_indices))), ( + f"{chunk_indices = } are not monotonically increasing" + ) diff --git a/docker-compose.yml b/docker-compose.yml index 43d9b2607..472f24c21 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -13,7 +13,7 @@ services: - DEBUG=false # Change to true if debugging - HOST=0.0.0.0 - ENVIRONMENT=local - - LOG_LEVEL=ERROR + - LOG_LEVEL=INFO extra_hosts: # Allows the container to reach your local machine using "host.docker.internal" instead of "localhost" - "host.docker.internal:host-gateway" diff --git a/entrypoint.sh b/entrypoint.sh index bad9b7aa3..496825408 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -43,10 +43,10 @@ sleep 2 if [ "$ENVIRONMENT" = "dev" ] || [ "$ENVIRONMENT" = "local" ]; then if [ "$DEBUG" = "true" ]; then echo "Waiting for the debugger to attach..." 
- debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload cognee.api.client:app + exec debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload --access-logfile - --error-logfile - cognee.api.client:app else - gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload cognee.api.client:app + exec gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload --access-logfile - --error-logfile - cognee.api.client:app fi else - gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level error cognee.api.client:app + exec gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level error --access-logfile - --error-logfile - cognee.api.client:app fi diff --git a/examples/python/agentic_reasoning_procurement_example.py b/examples/python/agentic_reasoning_procurement_example.py index 5aa3caa70..4e9d2d7e4 100644 --- a/examples/python/agentic_reasoning_procurement_example.py +++ b/examples/python/agentic_reasoning_procurement_example.py @@ -168,7 +168,7 @@ async def run_procurement_example(): for q in questions: print(f"Question: \n{q}") results = await procurement_system.search_memory(q, search_categories=[category]) - top_answer = results[category][0] + top_answer = results[category][0]["search_result"][0] print(f"Answer: \n{top_answer.strip()}\n") research_notes[category].append({"question": q, "answer": top_answer}) diff --git a/examples/python/code_graph_example.py b/examples/python/code_graph_example.py index 431069050..1b476a2c3 100644 --- a/examples/python/code_graph_example.py +++ b/examples/python/code_graph_example.py @@ -1,5 +1,7 @@ import argparse import asyncio +import os + import 
cognee from cognee import SearchType from cognee.shared.logging_utils import setup_logging, ERROR @@ -8,6 +10,9 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline async def main(repo_path, include_docs): + # Disable permissions feature for this example + os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false" + run_status = False async for run_status in run_code_graph_pipeline(repo_path, include_docs=include_docs): run_status = run_status diff --git a/examples/python/conversation_session_persistence_example.py b/examples/python/conversation_session_persistence_example.py new file mode 100644 index 000000000..5346f5012 --- /dev/null +++ b/examples/python/conversation_session_persistence_example.py @@ -0,0 +1,98 @@ +import asyncio + +import cognee +from cognee import visualize_graph +from cognee.memify_pipelines.persist_sessions_in_knowledge_graph import ( + persist_sessions_in_knowledge_graph_pipeline, +) +from cognee.modules.search.types import SearchType +from cognee.modules.users.methods import get_default_user +from cognee.shared.logging_utils import get_logger + +logger = get_logger("conversation_session_persistence_example") + + +async def main(): + # NOTE: CACHING has to be enabled for this example to work + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + text_1 = "Cognee is a solution that can build knowledge graph from text, creating an AI memory system" + text_2 = "Germany is a country located next to the Netherlands" + + await cognee.add([text_1, text_2]) + await cognee.cognify() + + question = "What can I use to create a knowledge graph?" + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text=question, + ) + print("\nSession ID: default_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + question = "You sure about that?" 
+ search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, query_text=question + ) + print("\nSession ID: default_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + question = "This is awesome!" + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, query_text=question + ) + print("\nSession ID: default_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + question = "Where is Germany?" + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text=question, + session_id="different_session", + ) + print("\nSession ID: different_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + question = "Next to which country again?" + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text=question, + session_id="different_session", + ) + print("\nSession ID: different_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + question = "So you remember everything I asked from you?" 
+ search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text=question, + session_id="different_session", + ) + print("\nSession ID: different_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + session_ids_to_persist = ["default_session", "different_session"] + default_user = await get_default_user() + + await persist_sessions_in_knowledge_graph_pipeline( + user=default_user, + session_ids=session_ids_to_persist, + ) + + await visualize_graph() + + +if __name__ == "__main__": + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(main()) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) diff --git a/examples/python/feedback_enrichment_minimal_example.py b/examples/python/feedback_enrichment_minimal_example.py index 11ef20830..8954bd5f6 100644 --- a/examples/python/feedback_enrichment_minimal_example.py +++ b/examples/python/feedback_enrichment_minimal_example.py @@ -67,7 +67,6 @@ async def run_feedback_enrichment_memify(last_n: int = 5): extraction_tasks=extraction_tasks, enrichment_tasks=enrichment_tasks, data=[{}], # A placeholder to prevent fetching the entire graph - dataset="feedback_enrichment_minimal", ) diff --git a/examples/python/memify_coding_agent_example.py b/examples/python/memify_coding_agent_example.py index 1fd3b1528..4a087ba61 100644 --- a/examples/python/memify_coding_agent_example.py +++ b/examples/python/memify_coding_agent_example.py @@ -89,7 +89,7 @@ async def main(): ) print("Coding rules created by memify:") - for coding_rule in coding_rules: + for coding_rule in coding_rules[0]["search_result"][0]: print("- " + coding_rule) # Visualize new graph with added memify context diff --git a/examples/python/permissions_example.py b/examples/python/permissions_example.py index 4f51b660f..c0b104023 100644 --- a/examples/python/permissions_example.py +++ b/examples/python/permissions_example.py @@ -3,6 +3,7 @@ import cognee 
import pathlib from cognee.modules.users.exceptions import PermissionDeniedError +from cognee.modules.users.tenants.methods import select_tenant from cognee.shared.logging_utils import get_logger from cognee.modules.search.types import SearchType from cognee.modules.users.methods import create_user @@ -116,6 +117,7 @@ async def main(): print( "\nOperation started as user_2 to give read permission to user_1 for the dataset owned by user_2" ) + await authorized_give_permission_on_datasets( user_1.id, [quantum_dataset_id], @@ -142,6 +144,9 @@ async def main(): print("User 2 is creating CogneeLab tenant/organization") tenant_id = await create_tenant("CogneeLab", user_2.id) + print("User 2 is selecting CogneeLab tenant/organization as active tenant") + await select_tenant(user_id=user_2.id, tenant_id=tenant_id) + print("\nUser 2 is creating Researcher role") role_id = await create_role(role_name="Researcher", owner_id=user_2.id) @@ -157,23 +162,59 @@ async def main(): ) await add_user_to_role(user_id=user_3.id, role_id=role_id, owner_id=user_2.id) + print("\nOperation as user_3 to select CogneeLab tenant/organization as active tenant") + await select_tenant(user_id=user_3.id, tenant_id=tenant_id) + print( - "\nOperation started as user_2 to give read permission to Researcher role for the dataset owned by user_2" + "\nOperation started as user_2, with CogneeLab as its active tenant, to give read permission to Researcher role for the dataset QUANTUM owned by user_2" + ) + # Even though the dataset owner is user_2, the dataset doesn't belong to the tenant/organization CogneeLab. + # So we can't assign permissions to it when we're acting in the CogneeLab tenant. 
+ try: + await authorized_give_permission_on_datasets( + role_id, + [quantum_dataset_id], + "read", + user_2.id, + ) + except PermissionDeniedError: + print( + "User 2 could not give permission to the role as the QUANTUM dataset is not part of the CogneeLab tenant" + ) + + print( + "We will now create a new QUANTUM dataset with the QUANTUM_COGNEE_LAB name in the CogneeLab tenant so that permissions can be assigned to the Researcher role inside the tenant/organization" + ) + # We can re-create the QUANTUM dataset in the CogneeLab tenant. The old QUANTUM dataset is still owned by user_2 personally + # and can still be accessed by selecting the personal tenant for user 2. + from cognee.modules.users.methods import get_user + + # Note: We need to update user_2 from the database to refresh its tenant context changes + user_2 = await get_user(user_2.id) + await cognee.add([text], dataset_name="QUANTUM_COGNEE_LAB", user=user_2) + quantum_cognee_lab_cognify_result = await cognee.cognify(["QUANTUM_COGNEE_LAB"], user=user_2) + + # The recreated Quantum dataset will now have a different dataset_id as it's a new dataset in a different organization + quantum_cognee_lab_dataset_id = extract_dataset_id_from_cognify( + quantum_cognee_lab_cognify_result + ) + print( + "\nOperation started as user_2, with CogneeLab as its active tenant, to give read permission to Researcher role for the dataset QUANTUM owned by the CogneeLab tenant" ) await authorized_give_permission_on_datasets( role_id, - [quantum_dataset_id], + [quantum_cognee_lab_dataset_id], "read", user_2.id, ) # Now user_3 can read from QUANTUM dataset as part of the Researcher role after proper permissions have been assigned by the QUANTUM dataset owner, user_2. 
- print("\nSearch result as user_3 on the dataset owned by user_2:") + print("\nSearch result as user_3 on the QUANTUM dataset owned by the CogneeLab organization:") search_results = await cognee.search( query_type=SearchType.GRAPH_COMPLETION, query_text="What is in the document?", - user=user_1, - dataset_ids=[quantum_dataset_id], + user=user_3, + dataset_ids=[quantum_cognee_lab_dataset_id], ) for result in search_results: print(f"{result}\n") diff --git a/examples/python/relational_database_migration_example.py b/examples/python/relational_database_migration_example.py index 7e87347bc..98482cb4b 100644 --- a/examples/python/relational_database_migration_example.py +++ b/examples/python/relational_database_migration_example.py @@ -31,6 +31,9 @@ from cognee.infrastructure.databases.vector.pgvector import ( async def main(): + # Disable backend access control to migrate relational data + os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false" + # Clean all data stored in Cognee await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) diff --git a/examples/python/run_custom_pipeline_example.py b/examples/python/run_custom_pipeline_example.py new file mode 100644 index 000000000..1ca1b4402 --- /dev/null +++ b/examples/python/run_custom_pipeline_example.py @@ -0,0 +1,84 @@ +import asyncio +import cognee +from cognee.modules.engine.operations.setup import setup +from cognee.modules.users.methods import get_default_user +from cognee.shared.logging_utils import setup_logging, INFO +from cognee.modules.pipelines import Task +from cognee.api.v1.search import SearchType + +# Prerequisites: +# 1. Copy `.env.template` and rename it to `.env`. +# 2. 
Add your OpenAI API key to the `.env` file in the `LLM_API_KEY` field: +# LLM_API_KEY = "your_key_here" + + +async def main(): + # Create a clean slate for cognee -- reset data and system state + print("Resetting cognee data...") + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + print("Data reset complete.\n") + + # Create relational database and tables + await setup() + + # cognee knowledge graph will be created based on this text + text = """ + Natural language processing (NLP) is an interdisciplinary + subfield of computer science and information retrieval. + """ + + print("Adding text to cognee:") + print(text.strip()) + + # Let's recreate the cognee add pipeline through the custom pipeline framework + from cognee.tasks.ingestion import ingest_data, resolve_data_directories + + user = await get_default_user() + + # Values for tasks need to be filled before calling the pipeline + add_tasks = [ + Task(resolve_data_directories, include_subdirectories=True), + Task( + ingest_data, + "main_dataset", + user, + ), + ] + # Forward tasks to custom pipeline along with data and user information + await cognee.run_custom_pipeline( + tasks=add_tasks, data=text, user=user, dataset="main_dataset", pipeline_name="add_pipeline" + ) + print("Text added successfully.\n") + + # Use LLMs and cognee to create knowledge graph + from cognee.api.v1.cognify.cognify import get_default_tasks + + cognify_tasks = await get_default_tasks(user=user) + print("Recreating existing cognify pipeline in custom pipeline to create knowledge graph...\n") + await cognee.run_custom_pipeline( + tasks=cognify_tasks, user=user, dataset="main_dataset", pipeline_name="cognify_pipeline" + ) + print("Cognify process complete.\n") + + query_text = "Tell me about NLP" + print(f"Searching cognee for insights with query: '{query_text}'") + # Query cognee for insights on the added text + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, 
query_text=query_text + ) + + print("Search results:") + # Display results + for result_text in search_results: + print(result_text) + + +if __name__ == "__main__": + logger = setup_logging(log_level=INFO) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(main()) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) diff --git a/examples/python/simple_example.py b/examples/python/simple_example.py index c13e48f85..237a8295e 100644 --- a/examples/python/simple_example.py +++ b/examples/python/simple_example.py @@ -59,14 +59,6 @@ async def main(): for result_text in search_results: print(result_text) - # Example output: - # ({'id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'natural language processing', 'description': 'An interdisciplinary subfield of computer science and information retrieval.'}, {'relationship_name': 'is_a_subfield_of', 'source_node_id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'target_node_id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 15, 473137, tzinfo=datetime.timezone.utc)}, {'id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'computer science', 'description': 'The study of computation and information processing.'}) - # (...) - # It represents nodes and relationships in the knowledge graph: - # - The first element is the source node (e.g., 'natural language processing'). - # - The second element is the relationship between nodes (e.g., 'is_a_subfield_of'). - # - The third element is the target node (e.g., 'computer science'). 
- if __name__ == "__main__": logger = setup_logging(log_level=ERROR) diff --git a/pyproject.toml b/pyproject.toml index 3af0599f1..81a5da471 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "cognee" -version = "0.4.1" +version = "0.5.0.dev0" description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning." authors = [ { name = "Vasilije Markovic" }, @@ -58,6 +58,8 @@ dependencies = [ "websockets>=15.0.1,<16.0.0", "mistralai>=1.9.10", "tenacity>=9.0.0", + "fakeredis[lua]>=2.32.0", + "diskcache>=5.6.3", ] [project.optional-dependencies] @@ -155,7 +157,6 @@ Homepage = "https://www.cognee.ai" Repository = "https://github.com/topoteretes/cognee" [project.scripts] -cognee = "cognee.cli._cognee:main" cognee-cli = "cognee.cli._cognee:main" [build-system] @@ -168,7 +169,6 @@ exclude = [ "/dist", "/.data", "/.github", - "/alembic", "/deployment", "/cognee-mcp", "/cognee-frontend", @@ -200,3 +200,8 @@ exclude = [ [tool.ruff.lint] ignore = ["F401"] + +[dependency-groups] +dev = [ + "pytest-timeout>=2.4.0", +]