merge dev
This commit is contained in:
commit
f7e072f533
177 changed files with 6319 additions and 762 deletions
|
|
@ -21,6 +21,10 @@ LLM_PROVIDER="openai"
|
|||
LLM_ENDPOINT=""
|
||||
LLM_API_VERSION=""
|
||||
LLM_MAX_TOKENS="16384"
|
||||
# Instructor's modes determine how structured data is requested from and extracted from LLM responses
|
||||
# You can change this type (i.e. mode) via this env variable
|
||||
# Each LLM has its own default value, e.g. gpt-5 models have "json_schema_mode"
|
||||
LLM_INSTRUCTOR_MODE=""
|
||||
|
||||
EMBEDDING_PROVIDER="openai"
|
||||
EMBEDDING_MODEL="openai/text-embedding-3-large"
|
||||
|
|
@ -169,8 +173,9 @@ REQUIRE_AUTHENTICATION=False
|
|||
# Vector: LanceDB
|
||||
# Graph: KuzuDB
|
||||
#
|
||||
# It enforces LanceDB and KuzuDB use and uses them to create databases per Cognee user + dataset
|
||||
ENABLE_BACKEND_ACCESS_CONTROL=False
|
||||
# It enforces creation of databases per Cognee user + dataset. Does not work with some graph and database providers.
|
||||
# Disable mode when using not supported graph/vector databases.
|
||||
ENABLE_BACKEND_ACCESS_CONTROL=True
|
||||
|
||||
################################################################################
|
||||
# ☁️ Cloud Sync Settings
|
||||
|
|
|
|||
5
.github/actions/cognee_setup/action.yml
vendored
5
.github/actions/cognee_setup/action.yml
vendored
|
|
@ -42,3 +42,8 @@ runs:
|
|||
done
|
||||
fi
|
||||
uv sync --extra api --extra docs --extra evals --extra codegraph --extra ollama --extra dev --extra neo4j --extra redis $EXTRA_ARGS
|
||||
|
||||
- name: Add telemetry identifier for telemetry test and in case telemetry is enabled by accident
|
||||
shell: bash
|
||||
run: |
|
||||
echo "test-machine" > .anon_id
|
||||
|
|
|
|||
8
.github/pull_request_template.md
vendored
8
.github/pull_request_template.md
vendored
|
|
@ -6,6 +6,14 @@ Please provide a clear, human-generated description of the changes in this PR.
|
|||
DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning.
|
||||
-->
|
||||
|
||||
## Acceptance Criteria
|
||||
<!--
|
||||
* Key requirements to the new feature or modification;
|
||||
* Proof that the changes work and meet the requirements;
|
||||
* Include instructions on how to verify the changes. Describe how to test it locally;
|
||||
* Proof that it's sufficiently tested.
|
||||
-->
|
||||
|
||||
## Type of Change
|
||||
<!-- Please check the relevant option -->
|
||||
- [ ] Bug fix (non-breaking change that fixes an issue)
|
||||
|
|
|
|||
5
.github/workflows/basic_tests.yml
vendored
5
.github/workflows/basic_tests.yml
vendored
|
|
@ -75,6 +75,7 @@ jobs:
|
|||
name: Run Unit Tests
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_PROVIDER: openai
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
|
|
@ -104,6 +105,7 @@ jobs:
|
|||
name: Run Integration Tests
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_PROVIDER: openai
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
|
|
@ -132,6 +134,7 @@ jobs:
|
|||
name: Run Simple Examples
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_PROVIDER: openai
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
|
|
@ -161,6 +164,7 @@ jobs:
|
|||
name: Run Simple Examples BAML
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
ENV: 'dev'
|
||||
STRUCTURED_OUTPUT_FRAMEWORK: "BAML"
|
||||
BAML_LLM_PROVIDER: openai
|
||||
BAML_LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||
|
|
@ -198,6 +202,7 @@ jobs:
|
|||
name: Run Basic Graph Tests
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_PROVIDER: openai
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
|
|
|
|||
3
.github/workflows/cli_tests.yml
vendored
3
.github/workflows/cli_tests.yml
vendored
|
|
@ -39,6 +39,7 @@ jobs:
|
|||
name: CLI Unit Tests
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_PROVIDER: openai
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
|
|
@ -66,6 +67,7 @@ jobs:
|
|||
name: CLI Integration Tests
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_PROVIDER: openai
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
|
|
@ -93,6 +95,7 @@ jobs:
|
|||
name: CLI Functionality Tests
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_PROVIDER: openai
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
|
|
|
|||
6
.github/workflows/db_examples_tests.yml
vendored
6
.github/workflows/db_examples_tests.yml
vendored
|
|
@ -60,7 +60,7 @@ jobs:
|
|||
|
||||
- name: Run Neo4j Example
|
||||
env:
|
||||
ENV: dev
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
|
|
@ -95,7 +95,7 @@ jobs:
|
|||
|
||||
- name: Run Kuzu Example
|
||||
env:
|
||||
ENV: dev
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
|
|
@ -141,7 +141,7 @@ jobs:
|
|||
|
||||
- name: Run PGVector Example
|
||||
env:
|
||||
ENV: dev
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
|
|
|
|||
90
.github/workflows/e2e_tests.yml
vendored
90
.github/workflows/e2e_tests.yml
vendored
|
|
@ -226,7 +226,7 @@ jobs:
|
|||
- name: Dependencies already installed
|
||||
run: echo "Dependencies already installed in setup"
|
||||
|
||||
- name: Run parallel databases test
|
||||
- name: Run permissions test
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
|
|
@ -239,6 +239,31 @@ jobs:
|
|||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_permissions.py
|
||||
|
||||
test-multi-tenancy:
|
||||
name: Test multi tenancy with different situations in Cognee
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
|
||||
- name: Run multi tenancy test
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_multi_tenancy.py
|
||||
|
||||
test-graph-edges:
|
||||
name: Test graph edge ingestion
|
||||
runs-on: ubuntu-22.04
|
||||
|
|
@ -308,7 +333,7 @@ jobs:
|
|||
python-version: '3.11.x'
|
||||
extra-dependencies: "postgres redis"
|
||||
|
||||
- name: Run Concurrent subprocess access test (Kuzu/Lancedb/Postgres)
|
||||
- name: Run Concurrent subprocess access test (Kuzu/Lancedb/Postgres/Redis)
|
||||
env:
|
||||
ENV: dev
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
|
|
@ -321,6 +346,7 @@ jobs:
|
|||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPH_DATABASE_PROVIDER: 'kuzu'
|
||||
CACHING: true
|
||||
CACHE_BACKEND: 'redis'
|
||||
SHARED_KUZU_LOCK: true
|
||||
DB_PROVIDER: 'postgres'
|
||||
DB_NAME: 'cognee_db'
|
||||
|
|
@ -386,8 +412,8 @@ jobs:
|
|||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_feedback_enrichment.py
|
||||
|
||||
run_conversation_sessions_test:
|
||||
name: Conversation sessions test
|
||||
run_conversation_sessions_test_redis:
|
||||
name: Conversation sessions test (Redis)
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
|
|
@ -427,7 +453,60 @@ jobs:
|
|||
python-version: '3.11.x'
|
||||
extra-dependencies: "postgres redis"
|
||||
|
||||
- name: Run Conversation session tests
|
||||
- name: Run Conversation session tests (Redis)
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPH_DATABASE_PROVIDER: 'kuzu'
|
||||
CACHING: true
|
||||
CACHE_BACKEND: 'redis'
|
||||
DB_PROVIDER: 'postgres'
|
||||
DB_NAME: 'cognee_db'
|
||||
DB_HOST: '127.0.0.1'
|
||||
DB_PORT: 5432
|
||||
DB_USERNAME: cognee
|
||||
DB_PASSWORD: cognee
|
||||
run: uv run python ./cognee/tests/test_conversation_history.py
|
||||
|
||||
run_conversation_sessions_test_fs:
|
||||
name: Conversation sessions test (FS)
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
services:
|
||||
postgres:
|
||||
image: pgvector/pgvector:pg17
|
||||
env:
|
||||
POSTGRES_USER: cognee
|
||||
POSTGRES_PASSWORD: cognee
|
||||
POSTGRES_DB: cognee_db
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
ports:
|
||||
- 5432:5432
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
extra-dependencies: "postgres"
|
||||
|
||||
- name: Run Conversation session tests (FS)
|
||||
env:
|
||||
ENV: dev
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
|
|
@ -440,6 +519,7 @@ jobs:
|
|||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPH_DATABASE_PROVIDER: 'kuzu'
|
||||
CACHING: true
|
||||
CACHE_BACKEND: 'fs'
|
||||
DB_PROVIDER: 'postgres'
|
||||
DB_NAME: 'cognee_db'
|
||||
DB_HOST: '127.0.0.1'
|
||||
|
|
|
|||
36
.github/workflows/examples_tests.yml
vendored
36
.github/workflows/examples_tests.yml
vendored
|
|
@ -21,6 +21,7 @@ jobs:
|
|||
|
||||
- name: Run Multimedia Example
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
run: uv run python ./examples/python/multimedia_example.py
|
||||
|
|
@ -40,6 +41,7 @@ jobs:
|
|||
|
||||
- name: Run Evaluation Framework Example
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
|
|
@ -69,6 +71,7 @@ jobs:
|
|||
|
||||
- name: Run Descriptive Graph Metrics Example
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
|
|
@ -99,6 +102,7 @@ jobs:
|
|||
|
||||
- name: Run Dynamic Steps Tests
|
||||
env:
|
||||
ENV: 'dev'
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
|
|
@ -124,6 +128,7 @@ jobs:
|
|||
|
||||
- name: Run Temporal Example
|
||||
env:
|
||||
ENV: 'dev'
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
|
|
@ -149,6 +154,7 @@ jobs:
|
|||
|
||||
- name: Run Ontology Demo Example
|
||||
env:
|
||||
ENV: 'dev'
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
|
|
@ -174,6 +180,7 @@ jobs:
|
|||
|
||||
- name: Run Agentic Reasoning Example
|
||||
env:
|
||||
ENV: 'dev'
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
|
|
@ -199,6 +206,7 @@ jobs:
|
|||
|
||||
- name: Run Memify Tests
|
||||
env:
|
||||
ENV: 'dev'
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
|
|
@ -210,6 +218,32 @@ jobs:
|
|||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./examples/python/memify_coding_agent_example.py
|
||||
|
||||
test-custom-pipeline:
|
||||
name: Run Custom Pipeline Example
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
|
||||
- name: Run Custom Pipeline Example
|
||||
env:
|
||||
ENV: 'dev'
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./examples/python/run_custom_pipeline_example.py
|
||||
|
||||
test-permissions-example:
|
||||
name: Run Permissions Example
|
||||
runs-on: ubuntu-22.04
|
||||
|
|
@ -224,6 +258,7 @@ jobs:
|
|||
|
||||
- name: Run Memify Tests
|
||||
env:
|
||||
ENV: 'dev'
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
|
|
@ -249,6 +284,7 @@ jobs:
|
|||
|
||||
- name: Run Docling Test
|
||||
env:
|
||||
ENV: 'dev'
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
|
|
|
|||
70
.github/workflows/load_tests.yml
vendored
Normal file
70
.github/workflows/load_tests.yml
vendored
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
name: Load tests
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
secrets:
|
||||
LLM_MODEL:
|
||||
required: true
|
||||
LLM_ENDPOINT:
|
||||
required: true
|
||||
LLM_API_KEY:
|
||||
required: true
|
||||
LLM_API_VERSION:
|
||||
required: true
|
||||
EMBEDDING_MODEL:
|
||||
required: true
|
||||
EMBEDDING_ENDPOINT:
|
||||
required: true
|
||||
EMBEDDING_API_KEY:
|
||||
required: true
|
||||
EMBEDDING_API_VERSION:
|
||||
required: true
|
||||
OPENAI_API_KEY:
|
||||
required: true
|
||||
AWS_ACCESS_KEY_ID:
|
||||
required: true
|
||||
AWS_SECRET_ACCESS_KEY:
|
||||
required: true
|
||||
|
||||
jobs:
|
||||
test-load:
|
||||
name: Test Load
|
||||
runs-on: ubuntu-22.04
|
||||
timeout-minutes: 60
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
extra-dependencies: "aws"
|
||||
|
||||
- name: Verify File Descriptor Limit
|
||||
run: ulimit -n
|
||||
|
||||
- name: Run Load Test
|
||||
env:
|
||||
ENV: 'dev'
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: True
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
STORAGE_BACKEND: s3
|
||||
AWS_REGION: eu-west-1
|
||||
AWS_ENDPOINT_URL: https://s3-eu-west-1.amazonaws.com
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_DEV_USER_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_DEV_USER_SECRET_KEY }}
|
||||
run: uv run python ./cognee/tests/test_load.py
|
||||
|
||||
|
||||
17
.github/workflows/release_test.yml
vendored
Normal file
17
.github/workflows/release_test.yml
vendored
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
# Long-running, heavy and resource-consuming tests for release validation
|
||||
name: Release Test Workflow
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
load-tests:
|
||||
name: Load Tests
|
||||
uses: ./.github/workflows/load_tests.yml
|
||||
secrets: inherit
|
||||
3
.github/workflows/search_db_tests.yml
vendored
3
.github/workflows/search_db_tests.yml
vendored
|
|
@ -84,6 +84,7 @@ jobs:
|
|||
GRAPH_DATABASE_PROVIDER: 'neo4j'
|
||||
VECTOR_DB_PROVIDER: 'lancedb'
|
||||
DB_PROVIDER: 'sqlite'
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
|
||||
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
|
||||
|
|
@ -135,6 +136,7 @@ jobs:
|
|||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPH_DATABASE_PROVIDER: 'kuzu'
|
||||
VECTOR_DB_PROVIDER: 'pgvector'
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
DB_PROVIDER: 'postgres'
|
||||
DB_NAME: 'cognee_db'
|
||||
DB_HOST: '127.0.0.1'
|
||||
|
|
@ -197,6 +199,7 @@ jobs:
|
|||
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
|
||||
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
DB_NAME: cognee_db
|
||||
DB_HOST: 127.0.0.1
|
||||
DB_PORT: 5432
|
||||
|
|
|
|||
|
|
@ -10,6 +10,10 @@ on:
|
|||
required: false
|
||||
type: string
|
||||
default: '["3.10.x", "3.12.x", "3.13.x"]'
|
||||
os:
|
||||
required: false
|
||||
type: string
|
||||
default: '["ubuntu-22.04", "macos-15", "windows-latest"]'
|
||||
secrets:
|
||||
LLM_PROVIDER:
|
||||
required: true
|
||||
|
|
@ -40,10 +44,11 @@ jobs:
|
|||
run-unit-tests:
|
||||
name: Unit tests ${{ matrix.python-version }} on ${{ matrix.os }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 60
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ubuntu-22.04, macos-15, windows-latest]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -76,10 +81,11 @@ jobs:
|
|||
run-integration-tests:
|
||||
name: Integration tests ${{ matrix.python-version }} on ${{ matrix.os }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 60
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -112,10 +118,11 @@ jobs:
|
|||
run-library-test:
|
||||
name: Library test ${{ matrix.python-version }} on ${{ matrix.os }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 60
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -148,10 +155,11 @@ jobs:
|
|||
run-build-test:
|
||||
name: Build test ${{ matrix.python-version }} on ${{ matrix.os }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 60
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -177,10 +185,11 @@ jobs:
|
|||
run-soft-deletion-test:
|
||||
name: Soft Delete test ${{ matrix.python-version }} on ${{ matrix.os }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 60
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -214,10 +223,11 @@ jobs:
|
|||
run-hard-deletion-test:
|
||||
name: Hard Delete test ${{ matrix.python-version }} on ${{ matrix.os }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 60
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
|
|||
25
.github/workflows/test_suites.yml
vendored
25
.github/workflows/test_suites.yml
vendored
|
|
@ -1,4 +1,6 @@
|
|||
name: Test Suites
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
on:
|
||||
push:
|
||||
|
|
@ -80,12 +82,22 @@ jobs:
|
|||
uses: ./.github/workflows/notebooks_tests.yml
|
||||
secrets: inherit
|
||||
|
||||
different-operating-systems-tests:
|
||||
name: Operating System and Python Tests
|
||||
different-os-tests-basic:
|
||||
name: OS and Python Tests Ubuntu
|
||||
needs: [basic-tests, e2e-tests]
|
||||
uses: ./.github/workflows/test_different_operating_systems.yml
|
||||
with:
|
||||
python-versions: '["3.10.x", "3.11.x", "3.12.x", "3.13.x"]'
|
||||
os: '["ubuntu-22.04"]'
|
||||
secrets: inherit
|
||||
|
||||
different-os-tests-extended:
|
||||
name: OS and Python Tests Extended
|
||||
needs: [basic-tests, e2e-tests]
|
||||
uses: ./.github/workflows/test_different_operating_systems.yml
|
||||
with:
|
||||
python-versions: '["3.13.x"]'
|
||||
os: '["macos-15", "windows-latest"]'
|
||||
secrets: inherit
|
||||
|
||||
# Matrix-based vector database tests
|
||||
|
|
@ -135,7 +147,8 @@ jobs:
|
|||
e2e-tests,
|
||||
graph-db-tests,
|
||||
notebook-tests,
|
||||
different-operating-systems-tests,
|
||||
different-os-tests-basic,
|
||||
different-os-tests-extended,
|
||||
vector-db-tests,
|
||||
example-tests,
|
||||
llm-tests,
|
||||
|
|
@ -155,7 +168,8 @@ jobs:
|
|||
cli-tests,
|
||||
graph-db-tests,
|
||||
notebook-tests,
|
||||
different-operating-systems-tests,
|
||||
different-os-tests-basic,
|
||||
different-os-tests-extended,
|
||||
vector-db-tests,
|
||||
example-tests,
|
||||
db-examples-tests,
|
||||
|
|
@ -176,7 +190,8 @@ jobs:
|
|||
"${{ needs.cli-tests.result }}" == "success" &&
|
||||
"${{ needs.graph-db-tests.result }}" == "success" &&
|
||||
"${{ needs.notebook-tests.result }}" == "success" &&
|
||||
"${{ needs.different-operating-systems-tests.result }}" == "success" &&
|
||||
"${{ needs.different-os-tests-basic.result }}" == "success" &&
|
||||
"${{ needs.different-os-tests-extended.result }}" == "success" &&
|
||||
"${{ needs.vector-db-tests.result }}" == "success" &&
|
||||
"${{ needs.example-tests.result }}" == "success" &&
|
||||
"${{ needs.db-examples-tests.result }}" == "success" &&
|
||||
|
|
|
|||
33
.github/workflows/weighted_edges_tests.yml
vendored
33
.github/workflows/weighted_edges_tests.yml
vendored
|
|
@ -2,7 +2,7 @@ name: Weighted Edges Tests
|
|||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, weighted_edges ]
|
||||
branches: [ main, dev, weighted_edges ]
|
||||
paths:
|
||||
- 'cognee/modules/graph/utils/get_graph_from_model.py'
|
||||
- 'cognee/infrastructure/engine/models/Edge.py'
|
||||
|
|
@ -10,7 +10,7 @@ on:
|
|||
- 'examples/python/weighted_edges_example.py'
|
||||
- '.github/workflows/weighted_edges_tests.yml'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches: [ main, dev ]
|
||||
paths:
|
||||
- 'cognee/modules/graph/utils/get_graph_from_model.py'
|
||||
- 'cognee/infrastructure/engine/models/Edge.py'
|
||||
|
|
@ -32,7 +32,7 @@ jobs:
|
|||
env:
|
||||
LLM_PROVIDER: openai
|
||||
LLM_MODEL: gpt-5-mini
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
steps:
|
||||
- name: Check out repository
|
||||
|
|
@ -67,14 +67,13 @@ jobs:
|
|||
env:
|
||||
LLM_PROVIDER: openai
|
||||
LLM_MODEL: gpt-5-mini
|
||||
LLM_ENDPOINT: https://api.openai.com/v1/
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_ENDPOINT: https://api.openai.com/v1
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_API_VERSION: "2024-02-01"
|
||||
EMBEDDING_PROVIDER: openai
|
||||
EMBEDDING_MODEL: text-embedding-3-small
|
||||
EMBEDDING_ENDPOINT: https://api.openai.com/v1/
|
||||
EMBEDDING_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
EMBEDDING_API_VERSION: "2024-02-01"
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
|
@ -108,14 +107,14 @@ jobs:
|
|||
env:
|
||||
LLM_PROVIDER: openai
|
||||
LLM_MODEL: gpt-5-mini
|
||||
LLM_ENDPOINT: https://api.openai.com/v1/
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_ENDPOINT: https://api.openai.com/v1
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_API_VERSION: "2024-02-01"
|
||||
EMBEDDING_PROVIDER: openai
|
||||
EMBEDDING_MODEL: text-embedding-3-small
|
||||
EMBEDDING_ENDPOINT: https://api.openai.com/v1/
|
||||
EMBEDDING_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
EMBEDDING_API_VERSION: "2024-02-01"
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
|
|
|||
|
|
@ -87,11 +87,6 @@ db_engine = get_relational_engine()
|
|||
|
||||
print("Using database:", db_engine.db_uri)
|
||||
|
||||
if "sqlite" in db_engine.db_uri:
|
||||
from cognee.infrastructure.utils.run_sync import run_sync
|
||||
|
||||
run_sync(db_engine.create_database())
|
||||
|
||||
config.set_section_option(
|
||||
config.config_ini_section,
|
||||
"SQLALCHEMY_DATABASE_URI",
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import Sequence, Union
|
|||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
|
@ -26,7 +27,34 @@ def upgrade() -> None:
|
|||
connection = op.get_bind()
|
||||
inspector = sa.inspect(connection)
|
||||
|
||||
if op.get_context().dialect.name == "postgresql":
|
||||
syncstatus_enum = postgresql.ENUM(
|
||||
"STARTED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED", name="syncstatus"
|
||||
)
|
||||
syncstatus_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
if "sync_operations" not in inspector.get_table_names():
|
||||
if op.get_context().dialect.name == "postgresql":
|
||||
syncstatus = postgresql.ENUM(
|
||||
"STARTED",
|
||||
"IN_PROGRESS",
|
||||
"COMPLETED",
|
||||
"FAILED",
|
||||
"CANCELLED",
|
||||
name="syncstatus",
|
||||
create_type=False,
|
||||
)
|
||||
else:
|
||||
syncstatus = sa.Enum(
|
||||
"STARTED",
|
||||
"IN_PROGRESS",
|
||||
"COMPLETED",
|
||||
"FAILED",
|
||||
"CANCELLED",
|
||||
name="syncstatus",
|
||||
create_type=False,
|
||||
)
|
||||
|
||||
# Table doesn't exist, create it normally
|
||||
op.create_table(
|
||||
"sync_operations",
|
||||
|
|
@ -34,15 +62,7 @@ def upgrade() -> None:
|
|||
sa.Column("run_id", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum(
|
||||
"STARTED",
|
||||
"IN_PROGRESS",
|
||||
"COMPLETED",
|
||||
"FAILED",
|
||||
"CANCELLED",
|
||||
name="syncstatus",
|
||||
create_type=False,
|
||||
),
|
||||
syncstatus,
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("progress_percentage", sa.Integer(), nullable=True),
|
||||
|
|
|
|||
|
|
@ -23,11 +23,8 @@ depends_on: Union[str, Sequence[str], None] = "8057ae7329c2"
|
|||
|
||||
|
||||
def upgrade() -> None:
|
||||
try:
|
||||
await_only(create_default_user())
|
||||
except UserAlreadyExists:
|
||||
pass # It's fine if the default user already exists
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
await_only(delete_user("default_user@example.com"))
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -0,0 +1,98 @@
|
|||
"""Expand dataset database for multi user
|
||||
|
||||
Revision ID: 76625596c5c3
|
||||
Revises: 211ab850ef3d
|
||||
Create Date: 2025-10-30 12:55:20.239562
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "76625596c5c3"
|
||||
down_revision: Union[str, None] = "c946955da633"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _get_column(inspector, table, name, schema=None):
|
||||
for col in inspector.get_columns(table, schema=schema):
|
||||
if col["name"] == name:
|
||||
return col
|
||||
return None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
insp = sa.inspect(conn)
|
||||
|
||||
vector_database_provider_column = _get_column(
|
||||
insp, "dataset_database", "vector_database_provider"
|
||||
)
|
||||
if not vector_database_provider_column:
|
||||
op.add_column(
|
||||
"dataset_database",
|
||||
sa.Column(
|
||||
"vector_database_provider",
|
||||
sa.String(),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
server_default="lancedb",
|
||||
),
|
||||
)
|
||||
|
||||
graph_database_provider_column = _get_column(
|
||||
insp, "dataset_database", "graph_database_provider"
|
||||
)
|
||||
if not graph_database_provider_column:
|
||||
op.add_column(
|
||||
"dataset_database",
|
||||
sa.Column(
|
||||
"graph_database_provider",
|
||||
sa.String(),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
server_default="kuzu",
|
||||
),
|
||||
)
|
||||
|
||||
vector_database_url_column = _get_column(insp, "dataset_database", "vector_database_url")
|
||||
if not vector_database_url_column:
|
||||
op.add_column(
|
||||
"dataset_database",
|
||||
sa.Column("vector_database_url", sa.String(), unique=False, nullable=True),
|
||||
)
|
||||
|
||||
graph_database_url_column = _get_column(insp, "dataset_database", "graph_database_url")
|
||||
if not graph_database_url_column:
|
||||
op.add_column(
|
||||
"dataset_database",
|
||||
sa.Column("graph_database_url", sa.String(), unique=False, nullable=True),
|
||||
)
|
||||
|
||||
vector_database_key_column = _get_column(insp, "dataset_database", "vector_database_key")
|
||||
if not vector_database_key_column:
|
||||
op.add_column(
|
||||
"dataset_database",
|
||||
sa.Column("vector_database_key", sa.String(), unique=False, nullable=True),
|
||||
)
|
||||
|
||||
graph_database_key_column = _get_column(insp, "dataset_database", "graph_database_key")
|
||||
if not graph_database_key_column:
|
||||
op.add_column(
|
||||
"dataset_database",
|
||||
sa.Column("graph_database_key", sa.String(), unique=False, nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("dataset_database", "vector_database_provider")
|
||||
op.drop_column("dataset_database", "graph_database_provider")
|
||||
op.drop_column("dataset_database", "vector_database_url")
|
||||
op.drop_column("dataset_database", "graph_database_url")
|
||||
op.drop_column("dataset_database", "vector_database_key")
|
||||
op.drop_column("dataset_database", "graph_database_key")
|
||||
|
|
@ -18,11 +18,8 @@ depends_on: Union[str, Sequence[str], None] = None
|
|||
|
||||
|
||||
def upgrade() -> None:
|
||||
db_engine = get_relational_engine()
|
||||
# we might want to delete this
|
||||
await_only(db_engine.create_database())
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
db_engine = get_relational_engine()
|
||||
await_only(db_engine.delete_database())
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -144,44 +144,58 @@ def _create_data_permission(conn, user_id, data_id, permission_name):
|
|||
)
|
||||
|
||||
|
||||
def _get_column(inspector, table, name, schema=None):
|
||||
for col in inspector.get_columns(table, schema=schema):
|
||||
if col["name"] == name:
|
||||
return col
|
||||
return None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
insp = sa.inspect(conn)
|
||||
|
||||
# Recreate ACLs table with default permissions set to datasets instead of documents
|
||||
op.drop_table("acls")
|
||||
dataset_id_column = _get_column(insp, "acls", "dataset_id")
|
||||
if not dataset_id_column:
|
||||
# Recreate ACLs table with default permissions set to datasets instead of documents
|
||||
op.drop_table("acls")
|
||||
|
||||
acls_table = op.create_table(
|
||||
"acls",
|
||||
sa.Column("id", UUID, primary_key=True, default=uuid4),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at", sa.DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)
|
||||
),
|
||||
sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")),
|
||||
sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")),
|
||||
sa.Column("dataset_id", UUID, sa.ForeignKey("datasets.id", ondelete="CASCADE")),
|
||||
)
|
||||
acls_table = op.create_table(
|
||||
"acls",
|
||||
sa.Column("id", UUID, primary_key=True, default=uuid4),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
),
|
||||
sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")),
|
||||
sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")),
|
||||
sa.Column("dataset_id", UUID, sa.ForeignKey("datasets.id", ondelete="CASCADE")),
|
||||
)
|
||||
|
||||
# Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
|
||||
# definition or load what is in the database
|
||||
dataset_table = _define_dataset_table()
|
||||
datasets = conn.execute(sa.select(dataset_table)).fetchall()
|
||||
# Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
|
||||
# definition or load what is in the database
|
||||
dataset_table = _define_dataset_table()
|
||||
datasets = conn.execute(sa.select(dataset_table)).fetchall()
|
||||
|
||||
if not datasets:
|
||||
return
|
||||
if not datasets:
|
||||
return
|
||||
|
||||
acl_list = []
|
||||
acl_list = []
|
||||
|
||||
for dataset in datasets:
|
||||
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "read"))
|
||||
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "write"))
|
||||
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "share"))
|
||||
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "delete"))
|
||||
for dataset in datasets:
|
||||
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "read"))
|
||||
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "write"))
|
||||
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "share"))
|
||||
acl_list.append(
|
||||
_create_dataset_permission(conn, dataset.owner_id, dataset.id, "delete")
|
||||
)
|
||||
|
||||
if acl_list:
|
||||
op.bulk_insert(acls_table, acl_list)
|
||||
if acl_list:
|
||||
op.bulk_insert(acls_table, acl_list)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
|
|
|||
137
alembic/versions/c946955da633_multi_tenant_support.py
Normal file
137
alembic/versions/c946955da633_multi_tenant_support.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
"""Multi Tenant Support
|
||||
|
||||
Revision ID: c946955da633
|
||||
Revises: 211ab850ef3d
|
||||
Create Date: 2025-11-04 18:11:09.325158
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c946955da633"
|
||||
down_revision: Union[str, None] = "211ab850ef3d"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _now():
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _define_user_table() -> sa.Table:
|
||||
table = sa.Table(
|
||||
"users",
|
||||
sa.MetaData(),
|
||||
sa.Column(
|
||||
"id",
|
||||
sa.UUID,
|
||||
sa.ForeignKey("principals.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("tenant_id", sa.UUID, sa.ForeignKey("tenants.id"), index=True, nullable=True),
|
||||
)
|
||||
return table
|
||||
|
||||
|
||||
def _define_dataset_table() -> sa.Table:
|
||||
# Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
|
||||
# definition or load what is in the database
|
||||
table = sa.Table(
|
||||
"datasets",
|
||||
sa.MetaData(),
|
||||
sa.Column("id", sa.UUID, primary_key=True, default=uuid4),
|
||||
sa.Column("name", sa.Text),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
),
|
||||
sa.Column("owner_id", sa.UUID(), sa.ForeignKey("principals.id"), index=True),
|
||||
sa.Column("tenant_id", sa.UUID(), sa.ForeignKey("tenants.id"), index=True, nullable=True),
|
||||
)
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def _get_column(inspector, table, name, schema=None):
|
||||
for col in inspector.get_columns(table, schema=schema):
|
||||
if col["name"] == name:
|
||||
return col
|
||||
return None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
insp = sa.inspect(conn)
|
||||
|
||||
dataset = _define_dataset_table()
|
||||
user = _define_user_table()
|
||||
|
||||
if "user_tenants" not in insp.get_table_names():
|
||||
# Define table with all necessary columns including primary key
|
||||
user_tenants = op.create_table(
|
||||
"user_tenants",
|
||||
sa.Column("user_id", sa.UUID, sa.ForeignKey("users.id"), primary_key=True),
|
||||
sa.Column("tenant_id", sa.UUID, sa.ForeignKey("tenants.id"), primary_key=True),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
),
|
||||
)
|
||||
|
||||
# Get all users with their tenant_id
|
||||
user_data = conn.execute(
|
||||
sa.select(user.c.id, user.c.tenant_id).where(user.c.tenant_id.isnot(None))
|
||||
).fetchall()
|
||||
|
||||
# Insert into user_tenants table
|
||||
if user_data:
|
||||
op.bulk_insert(
|
||||
user_tenants,
|
||||
[
|
||||
{"user_id": user_id, "tenant_id": tenant_id, "created_at": _now()}
|
||||
for user_id, tenant_id in user_data
|
||||
],
|
||||
)
|
||||
|
||||
tenant_id_column = _get_column(insp, "datasets", "tenant_id")
|
||||
if not tenant_id_column:
|
||||
op.add_column("datasets", sa.Column("tenant_id", sa.UUID(), nullable=True))
|
||||
|
||||
# Build subquery, select users.tenant_id for each dataset.owner_id
|
||||
tenant_id_from_dataset_owner = (
|
||||
sa.select(user.c.tenant_id).where(user.c.id == dataset.c.owner_id).scalar_subquery()
|
||||
)
|
||||
|
||||
if op.get_context().dialect.name == "sqlite":
|
||||
# If column doesn't exist create new original_extension column and update from values of extension column
|
||||
with op.batch_alter_table("datasets") as batch_op:
|
||||
batch_op.execute(
|
||||
dataset.update().values(
|
||||
tenant_id=tenant_id_from_dataset_owner,
|
||||
)
|
||||
)
|
||||
else:
|
||||
conn = op.get_bind()
|
||||
conn.execute(dataset.update().values(tenant_id=tenant_id_from_dataset_owner))
|
||||
|
||||
op.create_index(op.f("ix_datasets_tenant_id"), "datasets", ["tenant_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("user_tenants")
|
||||
op.drop_index(op.f("ix_datasets_tenant_id"), table_name="datasets")
|
||||
op.drop_column("datasets", "tenant_id")
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -194,7 +194,6 @@ async def cognify(
|
|||
|
||||
Prerequisites:
|
||||
- **LLM_API_KEY**: Must be configured (required for entity extraction and graph generation)
|
||||
- **Data Added**: Must have data previously added via `cognee.add()`
|
||||
- **Vector Database**: Must be accessible for embeddings storage
|
||||
- **Graph Database**: Must be accessible for relationship storage
|
||||
|
||||
|
|
@ -1096,6 +1095,10 @@ async def main():
|
|||
|
||||
# Skip migrations when in API mode (the API server handles its own database)
|
||||
if not args.no_migration and not args.api_url:
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
|
||||
await setup()
|
||||
|
||||
# Run Alembic migrations from the main cognee directory where alembic.ini is located
|
||||
logger.info("Running database migrations...")
|
||||
migration_result = subprocess.run(
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from .api.v1.add import add
|
|||
from .api.v1.delete import delete
|
||||
from .api.v1.cognify import cognify
|
||||
from .modules.memify import memify
|
||||
from .modules.run_custom_pipeline import run_custom_pipeline
|
||||
from .api.v1.update import update
|
||||
from .api.v1.config.config import config
|
||||
from .api.v1.datasets.datasets import datasets
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from cognee.api.v1.settings.routers import get_settings_router
|
|||
from cognee.api.v1.datasets.routers import get_datasets_router
|
||||
from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router
|
||||
from cognee.api.v1.search.routers import get_search_router
|
||||
from cognee.api.v1.ontologies.routers.get_ontology_router import get_ontology_router
|
||||
from cognee.api.v1.memify.routers import get_memify_router
|
||||
from cognee.api.v1.add.routers import get_add_router
|
||||
from cognee.api.v1.delete.routers import get_delete_router
|
||||
|
|
@ -39,6 +40,8 @@ from cognee.api.v1.users.routers import (
|
|||
)
|
||||
from cognee.modules.users.methods.get_authenticated_user import REQUIRE_AUTHENTICATION
|
||||
|
||||
# Ensure application logging is configured for container stdout/stderr
|
||||
setup_logging()
|
||||
logger = get_logger()
|
||||
|
||||
if os.getenv("ENV", "prod") == "prod":
|
||||
|
|
@ -74,6 +77,9 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
await get_default_user()
|
||||
|
||||
# Emit a clear startup message for docker logs
|
||||
logger.info("Backend server has started")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
|
|
@ -258,6 +264,8 @@ app.include_router(
|
|||
|
||||
app.include_router(get_datasets_router(), prefix="/api/v1/datasets", tags=["datasets"])
|
||||
|
||||
app.include_router(get_ontology_router(), prefix="/api/v1/ontologies", tags=["ontologies"])
|
||||
|
||||
app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["settings"])
|
||||
|
||||
app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"])
|
||||
|
|
|
|||
|
|
@ -82,7 +82,9 @@ def get_add_router() -> APIRouter:
|
|||
datasetName,
|
||||
user=user,
|
||||
dataset_id=datasetId,
|
||||
node_set=node_set if node_set else None,
|
||||
node_set=node_set
|
||||
if node_set != [""]
|
||||
else None, # Transform default node_set endpoint value to None
|
||||
)
|
||||
|
||||
if isinstance(add_run, PipelineRunErrored):
|
||||
|
|
|
|||
|
|
@ -41,6 +41,9 @@ class CognifyPayloadDTO(InDTO):
|
|||
custom_prompt: Optional[str] = Field(
|
||||
default="", description="Custom prompt for entity extraction and graph generation"
|
||||
)
|
||||
ontology_key: Optional[List[str]] = Field(
|
||||
default=None, description="Reference to one or more previously uploaded ontologies"
|
||||
)
|
||||
|
||||
|
||||
def get_cognify_router() -> APIRouter:
|
||||
|
|
@ -68,6 +71,7 @@ def get_cognify_router() -> APIRouter:
|
|||
- **dataset_ids** (Optional[List[UUID]]): List of existing dataset UUIDs to process. UUIDs allow processing of datasets not owned by the user (if permitted).
|
||||
- **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking).
|
||||
- **custom_prompt** (Optional[str]): Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts for knowledge graph extraction.
|
||||
- **ontology_key** (Optional[List[str]]): Reference to one or more previously uploaded ontology files to use for knowledge graph construction.
|
||||
|
||||
## Response
|
||||
- **Blocking execution**: Complete pipeline run information with entity counts, processing duration, and success/failure status
|
||||
|
|
@ -82,7 +86,8 @@ def get_cognify_router() -> APIRouter:
|
|||
{
|
||||
"datasets": ["research_papers", "documentation"],
|
||||
"run_in_background": false,
|
||||
"custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections."
|
||||
"custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections.",
|
||||
"ontology_key": ["medical_ontology_v1"]
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -108,13 +113,35 @@ def get_cognify_router() -> APIRouter:
|
|||
)
|
||||
|
||||
from cognee.api.v1.cognify import cognify as cognee_cognify
|
||||
from cognee.api.v1.ontologies.ontologies import OntologyService
|
||||
|
||||
try:
|
||||
datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets
|
||||
config_to_use = None
|
||||
|
||||
if payload.ontology_key:
|
||||
ontology_service = OntologyService()
|
||||
ontology_contents = ontology_service.get_ontology_contents(
|
||||
payload.ontology_key, user
|
||||
)
|
||||
|
||||
from cognee.modules.ontology.ontology_config import Config
|
||||
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import (
|
||||
RDFLibOntologyResolver,
|
||||
)
|
||||
from io import StringIO
|
||||
|
||||
ontology_streams = [StringIO(content) for content in ontology_contents]
|
||||
config_to_use: Config = {
|
||||
"ontology_config": {
|
||||
"ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_streams)
|
||||
}
|
||||
}
|
||||
|
||||
cognify_run = await cognee_cognify(
|
||||
datasets,
|
||||
user,
|
||||
config=config_to_use,
|
||||
run_in_background=payload.run_in_background,
|
||||
custom_prompt=payload.custom_prompt,
|
||||
)
|
||||
|
|
|
|||
4
cognee/api/v1/ontologies/__init__.py
Normal file
4
cognee/api/v1/ontologies/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from .ontologies import OntologyService
|
||||
from .routers.get_ontology_router import get_ontology_router
|
||||
|
||||
__all__ = ["OntologyService", "get_ontology_router"]
|
||||
183
cognee/api/v1/ontologies/ontologies.py
Normal file
183
cognee/api/v1/ontologies/ontologies.py
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
import os
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, List
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class OntologyMetadata:
|
||||
ontology_key: str
|
||||
filename: str
|
||||
size_bytes: int
|
||||
uploaded_at: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class OntologyService:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def base_dir(self) -> Path:
|
||||
return Path(tempfile.gettempdir()) / "ontologies"
|
||||
|
||||
def _get_user_dir(self, user_id: str) -> Path:
|
||||
user_dir = self.base_dir / str(user_id)
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
return user_dir
|
||||
|
||||
def _get_metadata_path(self, user_dir: Path) -> Path:
|
||||
return user_dir / "metadata.json"
|
||||
|
||||
def _load_metadata(self, user_dir: Path) -> dict:
|
||||
metadata_path = self._get_metadata_path(user_dir)
|
||||
if metadata_path.exists():
|
||||
with open(metadata_path, "r") as f:
|
||||
return json.load(f)
|
||||
return {}
|
||||
|
||||
def _save_metadata(self, user_dir: Path, metadata: dict):
|
||||
metadata_path = self._get_metadata_path(user_dir)
|
||||
with open(metadata_path, "w") as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
async def upload_ontology(
|
||||
self, ontology_key: str, file, user, description: Optional[str] = None
|
||||
) -> OntologyMetadata:
|
||||
if not file.filename.lower().endswith(".owl"):
|
||||
raise ValueError("File must be in .owl format")
|
||||
|
||||
user_dir = self._get_user_dir(str(user.id))
|
||||
metadata = self._load_metadata(user_dir)
|
||||
|
||||
if ontology_key in metadata:
|
||||
raise ValueError(f"Ontology key '{ontology_key}' already exists")
|
||||
|
||||
content = await file.read()
|
||||
if len(content) > 10 * 1024 * 1024:
|
||||
raise ValueError("File size exceeds 10MB limit")
|
||||
|
||||
file_path = user_dir / f"{ontology_key}.owl"
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
ontology_metadata = {
|
||||
"filename": file.filename,
|
||||
"size_bytes": len(content),
|
||||
"uploaded_at": datetime.now(timezone.utc).isoformat(),
|
||||
"description": description,
|
||||
}
|
||||
metadata[ontology_key] = ontology_metadata
|
||||
self._save_metadata(user_dir, metadata)
|
||||
|
||||
return OntologyMetadata(
|
||||
ontology_key=ontology_key,
|
||||
filename=file.filename,
|
||||
size_bytes=len(content),
|
||||
uploaded_at=ontology_metadata["uploaded_at"],
|
||||
description=description,
|
||||
)
|
||||
|
||||
async def upload_ontologies(
|
||||
self, ontology_key: List[str], files: List, user, descriptions: Optional[List[str]] = None
|
||||
) -> List[OntologyMetadata]:
|
||||
"""
|
||||
Upload ontology files with their respective keys.
|
||||
|
||||
Args:
|
||||
ontology_key: List of unique keys for each ontology
|
||||
files: List of UploadFile objects (same length as keys)
|
||||
user: Authenticated user
|
||||
descriptions: Optional list of descriptions for each file
|
||||
|
||||
Returns:
|
||||
List of OntologyMetadata objects for uploaded files
|
||||
|
||||
Raises:
|
||||
ValueError: If keys duplicate, file format invalid, or array lengths don't match
|
||||
"""
|
||||
if len(ontology_key) != len(files):
|
||||
raise ValueError("Number of keys must match number of files")
|
||||
|
||||
if len(set(ontology_key)) != len(ontology_key):
|
||||
raise ValueError("Duplicate ontology keys not allowed")
|
||||
|
||||
if descriptions and len(descriptions) != len(files):
|
||||
raise ValueError("Number of descriptions must match number of files")
|
||||
|
||||
results = []
|
||||
user_dir = self._get_user_dir(str(user.id))
|
||||
metadata = self._load_metadata(user_dir)
|
||||
|
||||
for i, (key, file) in enumerate(zip(ontology_key, files)):
|
||||
if key in metadata:
|
||||
raise ValueError(f"Ontology key '{key}' already exists")
|
||||
|
||||
if not file.filename.lower().endswith(".owl"):
|
||||
raise ValueError(f"File '{file.filename}' must be in .owl format")
|
||||
|
||||
content = await file.read()
|
||||
if len(content) > 10 * 1024 * 1024:
|
||||
raise ValueError(f"File '{file.filename}' exceeds 10MB limit")
|
||||
|
||||
file_path = user_dir / f"{key}.owl"
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
ontology_metadata = {
|
||||
"filename": file.filename,
|
||||
"size_bytes": len(content),
|
||||
"uploaded_at": datetime.now(timezone.utc).isoformat(),
|
||||
"description": descriptions[i] if descriptions else None,
|
||||
}
|
||||
metadata[key] = ontology_metadata
|
||||
|
||||
results.append(
|
||||
OntologyMetadata(
|
||||
ontology_key=key,
|
||||
filename=file.filename,
|
||||
size_bytes=len(content),
|
||||
uploaded_at=ontology_metadata["uploaded_at"],
|
||||
description=descriptions[i] if descriptions else None,
|
||||
)
|
||||
)
|
||||
|
||||
self._save_metadata(user_dir, metadata)
|
||||
return results
|
||||
|
||||
def get_ontology_contents(self, ontology_key: List[str], user) -> List[str]:
|
||||
"""
|
||||
Retrieve ontology content for one or more keys.
|
||||
|
||||
Args:
|
||||
ontology_key: List of ontology keys to retrieve (can contain single item)
|
||||
user: Authenticated user
|
||||
|
||||
Returns:
|
||||
List of ontology content strings
|
||||
|
||||
Raises:
|
||||
ValueError: If any ontology key not found
|
||||
"""
|
||||
user_dir = self._get_user_dir(str(user.id))
|
||||
metadata = self._load_metadata(user_dir)
|
||||
|
||||
contents = []
|
||||
for key in ontology_key:
|
||||
if key not in metadata:
|
||||
raise ValueError(f"Ontology key '{key}' not found")
|
||||
|
||||
file_path = user_dir / f"{key}.owl"
|
||||
if not file_path.exists():
|
||||
raise ValueError(f"Ontology file for key '{key}' not found")
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
contents.append(f.read())
|
||||
return contents
|
||||
|
||||
def list_ontologies(self, user) -> dict:
|
||||
user_dir = self._get_user_dir(str(user.id))
|
||||
return self._load_metadata(user_dir)
|
||||
0
cognee/api/v1/ontologies/routers/__init__.py
Normal file
0
cognee/api/v1/ontologies/routers/__init__.py
Normal file
107
cognee/api/v1/ontologies/routers/get_ontology_router.py
Normal file
107
cognee/api/v1/ontologies/routers/get_ontology_router.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import Optional, List
|
||||
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee import __version__ as cognee_version
|
||||
from ..ontologies import OntologyService
|
||||
|
||||
|
||||
def get_ontology_router() -> APIRouter:
|
||||
router = APIRouter()
|
||||
ontology_service = OntologyService()
|
||||
|
||||
@router.post("", response_model=dict)
|
||||
async def upload_ontology(
|
||||
ontology_key: str = Form(...),
|
||||
ontology_file: List[UploadFile] = File(...),
|
||||
descriptions: Optional[str] = Form(None),
|
||||
user: User = Depends(get_authenticated_user),
|
||||
):
|
||||
"""
|
||||
Upload ontology files with their respective keys for later use in cognify operations.
|
||||
|
||||
Supports both single and multiple file uploads:
|
||||
- Single file: ontology_key=["key"], ontology_file=[file]
|
||||
- Multiple files: ontology_key=["key1", "key2"], ontology_file=[file1, file2]
|
||||
|
||||
## Request Parameters
|
||||
- **ontology_key** (str): JSON array string of user-defined identifiers for the ontologies
|
||||
- **ontology_file** (List[UploadFile]): OWL format ontology files
|
||||
- **descriptions** (Optional[str]): JSON array string of optional descriptions
|
||||
|
||||
## Response
|
||||
Returns metadata about uploaded ontologies including keys, filenames, sizes, and upload timestamps.
|
||||
|
||||
## Error Codes
|
||||
- **400 Bad Request**: Invalid file format, duplicate keys, array length mismatches, file size exceeded
|
||||
- **500 Internal Server Error**: File system or processing errors
|
||||
"""
|
||||
send_telemetry(
|
||||
"Ontology Upload API Endpoint Invoked",
|
||||
user.id,
|
||||
additional_properties={
|
||||
"endpoint": "POST /api/v1/ontologies",
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
import json
|
||||
|
||||
ontology_keys = json.loads(ontology_key)
|
||||
description_list = json.loads(descriptions) if descriptions else None
|
||||
|
||||
if not isinstance(ontology_keys, list):
|
||||
raise ValueError("ontology_key must be a JSON array")
|
||||
|
||||
results = await ontology_service.upload_ontologies(
|
||||
ontology_keys, ontology_file, user, description_list
|
||||
)
|
||||
|
||||
return {
|
||||
"uploaded_ontologies": [
|
||||
{
|
||||
"ontology_key": result.ontology_key,
|
||||
"filename": result.filename,
|
||||
"size_bytes": result.size_bytes,
|
||||
"uploaded_at": result.uploaded_at,
|
||||
"description": result.description,
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
}
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
return JSONResponse(status_code=400, content={"error": str(e)})
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=500, content={"error": str(e)})
|
||||
|
||||
@router.get("", response_model=dict)
|
||||
async def list_ontologies(user: User = Depends(get_authenticated_user)):
|
||||
"""
|
||||
List all uploaded ontologies for the authenticated user.
|
||||
|
||||
## Response
|
||||
Returns a dictionary mapping ontology keys to their metadata including filename, size, and upload timestamp.
|
||||
|
||||
## Error Codes
|
||||
- **500 Internal Server Error**: File system or processing errors
|
||||
"""
|
||||
send_telemetry(
|
||||
"Ontology List API Endpoint Invoked",
|
||||
user.id,
|
||||
additional_properties={
|
||||
"endpoint": "GET /api/v1/ontologies",
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
metadata = ontology_service.list_ontologies(user)
|
||||
return metadata
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=500, content={"error": str(e)})
|
||||
|
||||
return router
|
||||
|
|
@ -1,15 +1,20 @@
|
|||
from uuid import UUID
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.api.DTO import InDTO
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee import __version__ as cognee_version
|
||||
|
||||
|
||||
class SelectTenantDTO(InDTO):
|
||||
tenant_id: UUID | None = None
|
||||
|
||||
|
||||
def get_permissions_router() -> APIRouter:
|
||||
permissions_router = APIRouter()
|
||||
|
||||
|
|
@ -226,4 +231,39 @@ def get_permissions_router() -> APIRouter:
|
|||
status_code=200, content={"message": "Tenant created.", "tenant_id": str(tenant_id)}
|
||||
)
|
||||
|
||||
@permissions_router.post("/tenants/select")
|
||||
async def select_tenant(payload: SelectTenantDTO, user: User = Depends(get_authenticated_user)):
|
||||
"""
|
||||
Select current tenant.
|
||||
|
||||
This endpoint selects a tenant with the specified UUID. Tenants are used
|
||||
to organize users and resources in multi-tenant environments, providing
|
||||
isolation and access control between different groups or organizations.
|
||||
|
||||
Sending a null/None value as tenant_id selects his default single user tenant
|
||||
|
||||
## Request Parameters
|
||||
- **tenant_id** (Union[UUID, None]): UUID of the tenant to select, If null/None is provided use the default single user tenant
|
||||
|
||||
## Response
|
||||
Returns a success message along with selected tenant id.
|
||||
"""
|
||||
send_telemetry(
|
||||
"Permissions API Endpoint Invoked",
|
||||
user.id,
|
||||
additional_properties={
|
||||
"endpoint": f"POST /v1/permissions/tenants/{str(payload.tenant_id)}",
|
||||
"tenant_id": str(payload.tenant_id),
|
||||
},
|
||||
)
|
||||
|
||||
from cognee.modules.users.tenants.methods import select_tenant as select_tenant_method
|
||||
|
||||
await select_tenant_method(user_id=user.id, tenant_id=payload.tenant_id)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={"message": "Tenant selected.", "tenant_id": str(payload.tenant_id)},
|
||||
)
|
||||
|
||||
return permissions_router
|
||||
|
|
|
|||
|
|
@ -31,6 +31,8 @@ async def search(
|
|||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> Union[List[SearchResult], CombinedSearchResult]:
|
||||
"""
|
||||
Search and query the knowledge graph for insights, information, and connections.
|
||||
|
|
@ -200,6 +202,8 @@ async def search(
|
|||
only_context=only_context,
|
||||
use_combined_context=use_combined_context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
return filtered_search_results
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ relationships, and creates semantic connections for enhanced search and reasonin
|
|||
|
||||
Processing Pipeline:
|
||||
1. **Document Classification**: Identifies document types and structures
|
||||
2. **Permission Validation**: Ensures user has processing rights
|
||||
2. **Permission Validation**: Ensures user has processing rights
|
||||
3. **Text Chunking**: Breaks content into semantically meaningful segments
|
||||
4. **Entity Extraction**: Identifies key concepts, people, places, organizations
|
||||
5. **Relationship Detection**: Discovers connections between entities
|
||||
|
|
@ -97,6 +97,13 @@ After successful cognify processing, use `cognee search` to query the knowledge
|
|||
chunker_class = LangchainChunker
|
||||
except ImportError:
|
||||
fmt.warning("LangchainChunker not available, using TextChunker")
|
||||
elif args.chunker == "CsvChunker":
|
||||
try:
|
||||
from cognee.modules.chunking.CsvChunker import CsvChunker
|
||||
|
||||
chunker_class = CsvChunker
|
||||
except ImportError:
|
||||
fmt.warning("CsvChunker not available, using TextChunker")
|
||||
|
||||
result = await cognee.cognify(
|
||||
datasets=datasets,
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ SEARCH_TYPE_CHOICES = [
|
|||
]
|
||||
|
||||
# Chunker choices
|
||||
CHUNKER_CHOICES = ["TextChunker", "LangchainChunker"]
|
||||
CHUNKER_CHOICES = ["TextChunker", "LangchainChunker", "CsvChunker"]
|
||||
|
||||
# Output format choices
|
||||
OUTPUT_FORMAT_CHOICES = ["json", "pretty", "simple"]
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ from typing import Union
|
|||
from uuid import UUID
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_context_config
|
||||
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
|
||||
from cognee.infrastructure.files.storage.config import file_storage_config
|
||||
from cognee.modules.users.methods import get_user
|
||||
|
|
@ -14,11 +16,40 @@ vector_db_config = ContextVar("vector_db_config", default=None)
|
|||
graph_db_config = ContextVar("graph_db_config", default=None)
|
||||
session_user = ContextVar("session_user", default=None)
|
||||
|
||||
VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
|
||||
GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
|
||||
|
||||
|
||||
async def set_session_user_context_variable(user):
|
||||
session_user.set(user)
|
||||
|
||||
|
||||
def multi_user_support_possible():
|
||||
graph_db_config = get_graph_context_config()
|
||||
vector_db_config = get_vectordb_context_config()
|
||||
return (
|
||||
graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT
|
||||
and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT
|
||||
)
|
||||
|
||||
|
||||
def backend_access_control_enabled():
|
||||
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
|
||||
if backend_access_control is None:
|
||||
# If backend access control is not defined in environment variables,
|
||||
# enable it by default if graph and vector DBs can support it, otherwise disable it
|
||||
return multi_user_support_possible()
|
||||
elif backend_access_control.lower() == "true":
|
||||
# If enabled, ensure that the current graph and vector DBs can support it
|
||||
multi_user_support = multi_user_support_possible()
|
||||
if not multi_user_support:
|
||||
raise EnvironmentError(
|
||||
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
|
||||
"""
|
||||
If backend access control is enabled this function will ensure all datasets have their own databases,
|
||||
|
|
@ -38,9 +69,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
|
||||
"""
|
||||
|
||||
base_config = get_base_config()
|
||||
|
||||
if not os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||
if not backend_access_control_enabled():
|
||||
return
|
||||
|
||||
user = await get_user(user_id)
|
||||
|
|
@ -48,6 +77,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
# To ensure permissions are enforced properly all datasets will have their own databases
|
||||
dataset_database = await get_or_create_dataset_database(dataset, user)
|
||||
|
||||
base_config = get_base_config()
|
||||
data_root_directory = os.path.join(
|
||||
base_config.data_root_directory, str(user.tenant_id or user.id)
|
||||
)
|
||||
|
|
@ -57,15 +87,17 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
|
||||
# Set vector and graph database configuration based on dataset database information
|
||||
vector_config = {
|
||||
"vector_db_url": os.path.join(
|
||||
databases_directory_path, dataset_database.vector_database_name
|
||||
),
|
||||
"vector_db_key": "",
|
||||
"vector_db_provider": "lancedb",
|
||||
"vector_db_provider": dataset_database.vector_database_provider,
|
||||
"vector_db_url": dataset_database.vector_database_url,
|
||||
"vector_db_key": dataset_database.vector_database_key,
|
||||
"vector_db_name": dataset_database.vector_database_name,
|
||||
}
|
||||
|
||||
graph_config = {
|
||||
"graph_database_provider": "kuzu",
|
||||
"graph_database_provider": dataset_database.graph_database_provider,
|
||||
"graph_database_url": dataset_database.graph_database_url,
|
||||
"graph_database_name": dataset_database.graph_database_name,
|
||||
"graph_database_key": dataset_database.graph_database_key,
|
||||
"graph_file_path": os.path.join(
|
||||
databases_directory_path, dataset_database.graph_database_name
|
||||
),
|
||||
|
|
|
|||
29
cognee/eval_framework/Dockerfile
Normal file
29
cognee/eval_framework/Dockerfile
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
FROM python:3.11-slim
|
||||
|
||||
# Set environment variables
|
||||
ENV PIP_NO_CACHE_DIR=true
|
||||
ENV PATH="${PATH}:/root/.poetry/bin"
|
||||
ENV PYTHONPATH=/app
|
||||
ENV SKIP_MIGRATIONS=true
|
||||
|
||||
# System dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
libpq-dev \
|
||||
git \
|
||||
curl \
|
||||
build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY pyproject.toml poetry.lock README.md /app/
|
||||
|
||||
RUN pip install poetry
|
||||
|
||||
RUN poetry config virtualenvs.create false
|
||||
|
||||
RUN poetry install --extras distributed --extras evals --extras deepeval --no-root
|
||||
|
||||
COPY cognee/ /app/cognee
|
||||
COPY distributed/ /app/distributed
|
||||
|
|
@ -35,6 +35,16 @@ class AnswerGeneratorExecutor:
|
|||
retrieval_context = await retriever.get_context(query_text)
|
||||
search_results = await retriever.get_completion(query_text, retrieval_context)
|
||||
|
||||
############
|
||||
#:TODO This is a quick fix until we don't structure retriever results properly but lets not leave it like this...this is needed now due to the changed combined retriever structure..
|
||||
if isinstance(retrieval_context, list):
|
||||
retrieval_context = await retriever.convert_retrieved_objects_to_context(
|
||||
triplets=retrieval_context
|
||||
)
|
||||
|
||||
if isinstance(search_results, str):
|
||||
search_results = [search_results]
|
||||
#############
|
||||
answer = {
|
||||
"question": query_text,
|
||||
"answer": search_results[0],
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ async def create_and_insert_answers_table(questions_payload):
|
|||
|
||||
|
||||
async def run_question_answering(
|
||||
params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None
|
||||
params: dict, system_prompt="answer_simple_question_benchmark.txt", top_k: Optional[int] = None
|
||||
) -> List[dict]:
|
||||
if params.get("answering_questions"):
|
||||
logger.info("Question answering started...")
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ class EvalConfig(BaseSettings):
|
|||
|
||||
# Question answering params
|
||||
answering_questions: bool = True
|
||||
qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
|
||||
qa_engine: str = "cognee_graph_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
|
||||
|
||||
# Evaluation params
|
||||
evaluating_answers: bool = True
|
||||
|
|
@ -25,7 +25,7 @@ class EvalConfig(BaseSettings):
|
|||
"EM",
|
||||
"f1",
|
||||
] # Use only 'correctness' for DirectLLM
|
||||
deepeval_model: str = "gpt-5-mini"
|
||||
deepeval_model: str = "gpt-4o-mini"
|
||||
|
||||
# Metrics params
|
||||
calculate_metrics: bool = True
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import modal
|
|||
import os
|
||||
import asyncio
|
||||
import datetime
|
||||
import hashlib
|
||||
import json
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.eval_framework.eval_config import EvalConfig
|
||||
|
|
@ -10,6 +9,9 @@ from cognee.eval_framework.corpus_builder.run_corpus_builder import run_corpus_b
|
|||
from cognee.eval_framework.answer_generation.run_question_answering_module import (
|
||||
run_question_answering,
|
||||
)
|
||||
import pathlib
|
||||
from os import path
|
||||
from modal import Image
|
||||
from cognee.eval_framework.evaluation.run_evaluation_module import run_evaluation
|
||||
from cognee.eval_framework.metrics_dashboard import create_dashboard
|
||||
|
||||
|
|
@ -38,22 +40,19 @@ def read_and_combine_metrics(eval_params: dict) -> dict:
|
|||
|
||||
app = modal.App("modal-run-eval")
|
||||
|
||||
image = (
|
||||
modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False)
|
||||
.copy_local_file("pyproject.toml", "pyproject.toml")
|
||||
.copy_local_file("poetry.lock", "poetry.lock")
|
||||
.env(
|
||||
{
|
||||
"ENV": os.getenv("ENV"),
|
||||
"LLM_API_KEY": os.getenv("LLM_API_KEY"),
|
||||
"OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
|
||||
}
|
||||
)
|
||||
.pip_install("protobuf", "h2", "deepeval", "gdown", "plotly")
|
||||
image = Image.from_dockerfile(
|
||||
path=pathlib.Path(path.join(path.dirname(__file__), "Dockerfile")).resolve(),
|
||||
force_build=False,
|
||||
).add_local_python_source("cognee")
|
||||
|
||||
|
||||
@app.function(
|
||||
image=image,
|
||||
max_containers=10,
|
||||
timeout=86400,
|
||||
volumes={"/data": vol},
|
||||
secrets=[modal.Secret.from_name("eval_secrets")],
|
||||
)
|
||||
|
||||
|
||||
@app.function(image=image, concurrency_limit=10, timeout=86400, volumes={"/data": vol})
|
||||
async def modal_run_eval(eval_params=None):
|
||||
"""Runs evaluation pipeline and returns combined metrics results."""
|
||||
if eval_params is None:
|
||||
|
|
@ -105,18 +104,7 @@ async def main():
|
|||
configs = [
|
||||
EvalConfig(
|
||||
task_getter_type="Default",
|
||||
number_of_samples_in_corpus=10,
|
||||
benchmark="HotPotQA",
|
||||
qa_engine="cognee_graph_completion",
|
||||
building_corpus_from_scratch=True,
|
||||
answering_questions=True,
|
||||
evaluating_answers=True,
|
||||
calculate_metrics=True,
|
||||
dashboard=True,
|
||||
),
|
||||
EvalConfig(
|
||||
task_getter_type="Default",
|
||||
number_of_samples_in_corpus=10,
|
||||
number_of_samples_in_corpus=25,
|
||||
benchmark="TwoWikiMultiHop",
|
||||
qa_engine="cognee_graph_completion",
|
||||
building_corpus_from_scratch=True,
|
||||
|
|
@ -127,7 +115,7 @@ async def main():
|
|||
),
|
||||
EvalConfig(
|
||||
task_getter_type="Default",
|
||||
number_of_samples_in_corpus=10,
|
||||
number_of_samples_in_corpus=25,
|
||||
benchmark="Musique",
|
||||
qa_engine="cognee_graph_completion",
|
||||
building_corpus_from_scratch=True,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
from typing import Optional, Literal
|
||||
|
||||
|
||||
class CacheConfig(BaseSettings):
|
||||
|
|
@ -15,6 +15,7 @@ class CacheConfig(BaseSettings):
|
|||
- agentic_lock_timeout: Maximum time (in seconds) to wait for the lock release.
|
||||
"""
|
||||
|
||||
cache_backend: Literal["redis", "fs"] = "fs"
|
||||
caching: bool = False
|
||||
shared_kuzu_lock: bool = False
|
||||
cache_host: str = "localhost"
|
||||
|
|
@ -28,6 +29,7 @@ class CacheConfig(BaseSettings):
|
|||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"cache_backend": self.cache_backend,
|
||||
"caching": self.caching,
|
||||
"shared_kuzu_lock": self.shared_kuzu_lock,
|
||||
"cache_host": self.cache_host,
|
||||
|
|
|
|||
151
cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py
vendored
Normal file
151
cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py
vendored
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
import time
|
||||
import threading
|
||||
import diskcache as dc
|
||||
|
||||
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
|
||||
from cognee.infrastructure.databases.exceptions.exceptions import (
|
||||
CacheConnectionError,
|
||||
SharedKuzuLockRequiresRedisError,
|
||||
)
|
||||
from cognee.infrastructure.files.storage.get_storage_config import get_storage_config
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger("FSCacheAdapter")
|
||||
|
||||
|
||||
class FSCacheAdapter(CacheDBInterface):
|
||||
def __init__(self):
|
||||
default_key = "sessions_db"
|
||||
|
||||
storage_config = get_storage_config()
|
||||
data_root_directory = storage_config["data_root_directory"]
|
||||
cache_directory = os.path.join(data_root_directory, ".cognee_fs_cache", default_key)
|
||||
os.makedirs(cache_directory, exist_ok=True)
|
||||
self.cache = dc.Cache(directory=cache_directory)
|
||||
self.cache.expire()
|
||||
|
||||
logger.debug(f"FSCacheAdapter initialized with cache directory: {cache_directory}")
|
||||
|
||||
def acquire_lock(self):
|
||||
"""Lock acquisition is not available for filesystem cache backend."""
|
||||
message = "Shared Kuzu lock requires Redis cache backend."
|
||||
logger.error(message)
|
||||
raise SharedKuzuLockRequiresRedisError()
|
||||
|
||||
def release_lock(self):
|
||||
"""Lock release is not available for filesystem cache backend."""
|
||||
message = "Shared Kuzu lock requires Redis cache backend."
|
||||
logger.error(message)
|
||||
raise SharedKuzuLockRequiresRedisError()
|
||||
|
||||
async def add_qa(
|
||||
self,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
question: str,
|
||||
context: str,
|
||||
answer: str,
|
||||
ttl: int | None = 86400,
|
||||
):
|
||||
try:
|
||||
session_key = f"agent_sessions:{user_id}:{session_id}"
|
||||
|
||||
qa_entry = {
|
||||
"time": datetime.utcnow().isoformat(),
|
||||
"question": question,
|
||||
"context": context,
|
||||
"answer": answer,
|
||||
}
|
||||
|
||||
existing_value = self.cache.get(session_key)
|
||||
if existing_value is not None:
|
||||
value: list = json.loads(existing_value)
|
||||
value.append(qa_entry)
|
||||
else:
|
||||
value = [qa_entry]
|
||||
|
||||
self.cache.set(session_key, json.dumps(value), expire=ttl)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error while adding Q&A to diskcache: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise CacheConnectionError(error_msg) from e
|
||||
|
||||
async def get_latest_qa(self, user_id: str, session_id: str, last_n: int = 5):
|
||||
session_key = f"agent_sessions:{user_id}:{session_id}"
|
||||
value = self.cache.get(session_key)
|
||||
if value is None:
|
||||
return None
|
||||
entries = json.loads(value)
|
||||
return entries[-last_n:] if len(entries) > last_n else entries
|
||||
|
||||
async def get_all_qas(self, user_id: str, session_id: str):
|
||||
session_key = f"agent_sessions:{user_id}:{session_id}"
|
||||
value = self.cache.get(session_key)
|
||||
if value is None:
|
||||
return None
|
||||
return json.loads(value)
|
||||
|
||||
async def close(self):
|
||||
if self.cache is not None:
|
||||
self.cache.expire()
|
||||
self.cache.close()
|
||||
|
||||
|
||||
async def main():
|
||||
adapter = FSCacheAdapter()
|
||||
session_id = "demo_session"
|
||||
user_id = "demo_user_id"
|
||||
|
||||
print("\nAdding sample Q/A pairs...")
|
||||
await adapter.add_qa(
|
||||
user_id,
|
||||
session_id,
|
||||
"What is Redis?",
|
||||
"Basic DB context",
|
||||
"Redis is an in-memory data store.",
|
||||
)
|
||||
await adapter.add_qa(
|
||||
user_id,
|
||||
session_id,
|
||||
"Who created Redis?",
|
||||
"Historical context",
|
||||
"Salvatore Sanfilippo (antirez).",
|
||||
)
|
||||
|
||||
print("\nLatest QA:")
|
||||
latest = await adapter.get_latest_qa(user_id, session_id)
|
||||
print(json.dumps(latest, indent=2))
|
||||
|
||||
print("\nLast 2 QAs:")
|
||||
last_two = await adapter.get_latest_qa(user_id, session_id, last_n=2)
|
||||
print(json.dumps(last_two, indent=2))
|
||||
|
||||
session_id = "session_expire_demo"
|
||||
|
||||
await adapter.add_qa(
|
||||
user_id,
|
||||
session_id,
|
||||
"What is Redis?",
|
||||
"Database context",
|
||||
"Redis is an in-memory data store.",
|
||||
)
|
||||
|
||||
await adapter.add_qa(
|
||||
user_id,
|
||||
session_id,
|
||||
"Who created Redis?",
|
||||
"History context",
|
||||
"Salvatore Sanfilippo (antirez).",
|
||||
)
|
||||
|
||||
print(await adapter.get_all_qas(user_id, session_id))
|
||||
|
||||
await adapter.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -1,9 +1,11 @@
|
|||
"""Factory to get the appropriate cache coordination engine (e.g., Redis)."""
|
||||
|
||||
from functools import lru_cache
|
||||
import os
|
||||
from typing import Optional
|
||||
from cognee.infrastructure.databases.cache.config import get_cache_config
|
||||
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
|
||||
from cognee.infrastructure.databases.cache.fscache.FsCacheAdapter import FSCacheAdapter
|
||||
|
||||
config = get_cache_config()
|
||||
|
||||
|
|
@ -33,20 +35,28 @@ def create_cache_engine(
|
|||
|
||||
Returns:
|
||||
--------
|
||||
- CacheDBInterface: An instance of the appropriate cache adapter. :TODO: Now we support only Redis. later if we add more here we can split the logic
|
||||
- CacheDBInterface: An instance of the appropriate cache adapter.
|
||||
"""
|
||||
if config.caching:
|
||||
from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter
|
||||
|
||||
return RedisAdapter(
|
||||
host=cache_host,
|
||||
port=cache_port,
|
||||
username=cache_username,
|
||||
password=cache_password,
|
||||
lock_name=lock_key,
|
||||
timeout=agentic_lock_expire,
|
||||
blocking_timeout=agentic_lock_timeout,
|
||||
)
|
||||
if config.cache_backend == "redis":
|
||||
return RedisAdapter(
|
||||
host=cache_host,
|
||||
port=cache_port,
|
||||
username=cache_username,
|
||||
password=cache_password,
|
||||
lock_name=lock_key,
|
||||
timeout=agentic_lock_expire,
|
||||
blocking_timeout=agentic_lock_timeout,
|
||||
)
|
||||
elif config.cache_backend == "fs":
|
||||
return FSCacheAdapter()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported cache backend: '{config.cache_backend}'. "
|
||||
f"Supported backends are: 'redis', 'fs'"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -148,3 +148,19 @@ class CacheConnectionError(CogneeConfigurationError):
|
|||
status_code: int = status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class SharedKuzuLockRequiresRedisError(CogneeConfigurationError):
|
||||
"""
|
||||
Raised when shared Kuzu locking is requested without configuring the Redis backend.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = (
|
||||
"Shared Kuzu lock requires Redis cache backend. Configure Redis to enable shared Kuzu locking."
|
||||
),
|
||||
name: str = "SharedKuzuLockRequiresRedisError",
|
||||
status_code: int = status.HTTP_400_BAD_REQUEST,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ class GraphConfig(BaseSettings):
|
|||
- graph_database_username
|
||||
- graph_database_password
|
||||
- graph_database_port
|
||||
- graph_database_key
|
||||
- graph_file_path
|
||||
- graph_model
|
||||
- graph_topology
|
||||
|
|
@ -41,6 +42,7 @@ class GraphConfig(BaseSettings):
|
|||
graph_database_username: str = ""
|
||||
graph_database_password: str = ""
|
||||
graph_database_port: int = 123
|
||||
graph_database_key: str = ""
|
||||
graph_file_path: str = ""
|
||||
graph_filename: str = ""
|
||||
graph_model: object = KnowledgeGraph
|
||||
|
|
@ -90,6 +92,7 @@ class GraphConfig(BaseSettings):
|
|||
"graph_database_username": self.graph_database_username,
|
||||
"graph_database_password": self.graph_database_password,
|
||||
"graph_database_port": self.graph_database_port,
|
||||
"graph_database_key": self.graph_database_key,
|
||||
"graph_file_path": self.graph_file_path,
|
||||
"graph_model": self.graph_model,
|
||||
"graph_topology": self.graph_topology,
|
||||
|
|
@ -116,6 +119,7 @@ class GraphConfig(BaseSettings):
|
|||
"graph_database_username": self.graph_database_username,
|
||||
"graph_database_password": self.graph_database_password,
|
||||
"graph_database_port": self.graph_database_port,
|
||||
"graph_database_key": self.graph_database_key,
|
||||
"graph_file_path": self.graph_file_path,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ def create_graph_engine(
|
|||
graph_database_username="",
|
||||
graph_database_password="",
|
||||
graph_database_port="",
|
||||
graph_database_key="",
|
||||
):
|
||||
"""
|
||||
Create a graph engine based on the specified provider type.
|
||||
|
|
@ -69,6 +70,7 @@ def create_graph_engine(
|
|||
graph_database_url=graph_database_url,
|
||||
graph_database_username=graph_database_username,
|
||||
graph_database_password=graph_database_password,
|
||||
database_name=graph_database_name,
|
||||
)
|
||||
|
||||
if graph_database_provider == "neo4j":
|
||||
|
|
|
|||
|
|
@ -398,3 +398,18 @@ class GraphDBInterface(ABC):
|
|||
- node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_filtered_graph_data(
|
||||
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
|
||||
) -> Tuple[List[Node], List[EdgeData]]:
|
||||
"""
|
||||
Retrieve nodes and edges filtered by the provided attribute criteria.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- attribute_filters: A list of dictionaries where keys are attribute names and values
|
||||
are lists of attribute values to filter by.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, Any, List, Union, Optional, Tuple, Type
|
||||
|
||||
from cognee.exceptions import CogneeValidationError
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.utils.run_sync import run_sync
|
||||
from cognee.infrastructure.files.storage import get_file_storage
|
||||
|
|
@ -1186,6 +1187,11 @@ class KuzuAdapter(GraphDBInterface):
|
|||
A tuple with two elements: a list of tuples of (node_id, properties) and a list of
|
||||
tuples of (source_id, target_id, relationship_name, properties).
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
nodes_query = """
|
||||
MATCH (n:Node)
|
||||
|
|
@ -1249,6 +1255,11 @@ class KuzuAdapter(GraphDBInterface):
|
|||
},
|
||||
)
|
||||
)
|
||||
|
||||
retrieval_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds"
|
||||
)
|
||||
return formatted_nodes, formatted_edges
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get graph data: {e}")
|
||||
|
|
@ -1417,6 +1428,92 @@ class KuzuAdapter(GraphDBInterface):
|
|||
formatted_edges.append((source_id, target_id, rel_type, props))
|
||||
return formatted_nodes, formatted_edges
|
||||
|
||||
async def get_id_filtered_graph_data(self, target_ids: list[str]):
|
||||
"""
|
||||
Retrieve graph data filtered by specific node IDs, including their direct neighbors
|
||||
and only edges where one endpoint matches those IDs.
|
||||
|
||||
Returns:
|
||||
nodes: List[dict] -> Each dict includes "id" and all node properties
|
||||
edges: List[dict] -> Each dict includes "source", "target", "type", "properties"
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
if not target_ids:
|
||||
logger.warning("No target IDs provided for ID-filtered graph retrieval.")
|
||||
return [], []
|
||||
|
||||
if not all(isinstance(x, str) for x in target_ids):
|
||||
raise CogneeValidationError("target_ids must be a list of strings")
|
||||
|
||||
query = """
|
||||
MATCH (n:Node)-[r]->(m:Node)
|
||||
WHERE n.id IN $target_ids OR m.id IN $target_ids
|
||||
RETURN n.id, {
|
||||
name: n.name,
|
||||
type: n.type,
|
||||
properties: n.properties
|
||||
}, m.id, {
|
||||
name: m.name,
|
||||
type: m.type,
|
||||
properties: m.properties
|
||||
}, r.relationship_name, r.properties
|
||||
"""
|
||||
|
||||
result = await self.query(query, {"target_ids": target_ids})
|
||||
|
||||
if not result:
|
||||
logger.info("No data returned for the supplied IDs")
|
||||
return [], []
|
||||
|
||||
nodes_dict = {}
|
||||
edges = []
|
||||
|
||||
for n_id, n_props, m_id, m_props, r_type, r_props_raw in result:
|
||||
if n_props.get("properties"):
|
||||
try:
|
||||
additional_props = json.loads(n_props["properties"])
|
||||
n_props.update(additional_props)
|
||||
del n_props["properties"]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse properties JSON for node {n_id}")
|
||||
|
||||
if m_props.get("properties"):
|
||||
try:
|
||||
additional_props = json.loads(m_props["properties"])
|
||||
m_props.update(additional_props)
|
||||
del m_props["properties"]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse properties JSON for node {m_id}")
|
||||
|
||||
nodes_dict[n_id] = (n_id, n_props)
|
||||
nodes_dict[m_id] = (m_id, m_props)
|
||||
|
||||
edge_props = {}
|
||||
if r_props_raw:
|
||||
try:
|
||||
edge_props = json.loads(r_props_raw)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(f"Failed to parse edge properties for {n_id}->{m_id}")
|
||||
|
||||
source_id = edge_props.get("source_node_id", n_id)
|
||||
target_id = edge_props.get("target_node_id", m_id)
|
||||
edges.append((source_id, target_id, r_type, edge_props))
|
||||
|
||||
retrieval_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
|
||||
)
|
||||
|
||||
return list(nodes_dict.values()), edges
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
|
||||
"""
|
||||
Get metrics on graph structure and connectivity.
|
||||
|
|
|
|||
|
|
@ -964,6 +964,63 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
logger.error(f"Error during graph data retrieval: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_id_filtered_graph_data(self, target_ids: list[str]):
|
||||
"""
|
||||
Retrieve graph data filtered by specific node IDs, including their direct neighbors
|
||||
and only edges where one endpoint matches those IDs.
|
||||
|
||||
This version uses a single Cypher query for efficiency.
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
if not target_ids:
|
||||
logger.warning("No target IDs provided for ID-filtered graph retrieval.")
|
||||
return [], []
|
||||
|
||||
query = """
|
||||
MATCH ()-[r]-()
|
||||
WHERE startNode(r).id IN $target_ids
|
||||
OR endNode(r).id IN $target_ids
|
||||
WITH DISTINCT r, startNode(r) AS a, endNode(r) AS b
|
||||
RETURN
|
||||
properties(a) AS n_properties,
|
||||
properties(b) AS m_properties,
|
||||
type(r) AS type,
|
||||
properties(r) AS properties
|
||||
"""
|
||||
|
||||
result = await self.query(query, {"target_ids": target_ids})
|
||||
|
||||
nodes_dict = {}
|
||||
edges = []
|
||||
|
||||
for record in result:
|
||||
n_props = record["n_properties"]
|
||||
m_props = record["m_properties"]
|
||||
r_props = record["properties"]
|
||||
r_type = record["type"]
|
||||
|
||||
nodes_dict[n_props["id"]] = (n_props["id"], n_props)
|
||||
nodes_dict[m_props["id"]] = (m_props["id"], m_props)
|
||||
|
||||
source_id = r_props.get("source_node_id", n_props["id"])
|
||||
target_id = r_props.get("target_node_id", m_props["id"])
|
||||
edges.append((source_id, target_id, r_type, r_props))
|
||||
|
||||
retrieval_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
|
||||
)
|
||||
|
||||
return list(nodes_dict.values()), edges
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_nodeset_subgraph(
|
||||
self, node_type: Type[Any], node_name: List[str]
|
||||
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
|
||||
|
|
|
|||
|
|
@ -416,6 +416,15 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
|||
self._client.query(f"MATCH (n :{self._VECTOR_NODE_LABEL}) DETACH DELETE n")
|
||||
pass
|
||||
|
||||
async def is_empty(self) -> bool:
|
||||
query = """
|
||||
MATCH (n)
|
||||
RETURN true
|
||||
LIMIT 1;
|
||||
"""
|
||||
query_result = await self._client.query(query)
|
||||
return len(query_result) == 0
|
||||
|
||||
@staticmethod
|
||||
def _get_scored_result(
|
||||
item: dict, with_vector: bool = False, with_score: bool = False
|
||||
|
|
|
|||
|
|
@ -1,11 +1,15 @@
|
|||
import os
|
||||
from uuid import UUID
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from cognee.modules.data.methods import create_dataset
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.modules.data.methods import create_dataset
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||
from cognee.modules.data.methods import get_unique_dataset_id
|
||||
from cognee.modules.users.models import DatasetDatabase
|
||||
from cognee.modules.users.models import User
|
||||
|
|
@ -32,8 +36,32 @@ async def get_or_create_dataset_database(
|
|||
|
||||
dataset_id = await get_unique_dataset_id(dataset, user)
|
||||
|
||||
vector_db_name = f"{dataset_id}.lance.db"
|
||||
graph_db_name = f"{dataset_id}.pkl"
|
||||
vector_config = get_vectordb_config()
|
||||
graph_config = get_graph_config()
|
||||
|
||||
# Note: for hybrid databases both graph and vector DB name have to be the same
|
||||
if graph_config.graph_database_provider == "kuzu":
|
||||
graph_db_name = f"{dataset_id}.pkl"
|
||||
else:
|
||||
graph_db_name = f"{dataset_id}"
|
||||
|
||||
if vector_config.vector_db_provider == "lancedb":
|
||||
vector_db_name = f"{dataset_id}.lance.db"
|
||||
else:
|
||||
vector_db_name = f"{dataset_id}"
|
||||
|
||||
base_config = get_base_config()
|
||||
databases_directory_path = os.path.join(
|
||||
base_config.system_root_directory, "databases", str(user.id)
|
||||
)
|
||||
|
||||
# Determine vector database URL
|
||||
if vector_config.vector_db_provider == "lancedb":
|
||||
vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name)
|
||||
else:
|
||||
vector_db_url = vector_config.vector_database_url
|
||||
|
||||
# Determine graph database URL
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
# Create dataset if it doesn't exist
|
||||
|
|
@ -55,6 +83,12 @@ async def get_or_create_dataset_database(
|
|||
dataset_id=dataset_id,
|
||||
vector_database_name=vector_db_name,
|
||||
graph_database_name=graph_db_name,
|
||||
vector_database_provider=vector_config.vector_db_provider,
|
||||
graph_database_provider=graph_config.graph_database_provider,
|
||||
vector_database_url=vector_db_url,
|
||||
graph_database_url=graph_config.graph_database_url,
|
||||
vector_database_key=vector_config.vector_db_key,
|
||||
graph_database_key=graph_config.graph_database_key,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -18,12 +18,14 @@ class VectorConfig(BaseSettings):
|
|||
Instance variables:
|
||||
- vector_db_url: The URL of the vector database.
|
||||
- vector_db_port: The port for the vector database.
|
||||
- vector_db_name: The name of the vector database.
|
||||
- vector_db_key: The key for accessing the vector database.
|
||||
- vector_db_provider: The provider for the vector database.
|
||||
"""
|
||||
|
||||
vector_db_url: str = ""
|
||||
vector_db_port: int = 1234
|
||||
vector_db_name: str = ""
|
||||
vector_db_key: str = ""
|
||||
vector_db_provider: str = "lancedb"
|
||||
|
||||
|
|
@ -58,6 +60,7 @@ class VectorConfig(BaseSettings):
|
|||
return {
|
||||
"vector_db_url": self.vector_db_url,
|
||||
"vector_db_port": self.vector_db_port,
|
||||
"vector_db_name": self.vector_db_name,
|
||||
"vector_db_key": self.vector_db_key,
|
||||
"vector_db_provider": self.vector_db_provider,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from .supported_databases import supported_databases
|
||||
from .embeddings import get_embedding_engine
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_context_config
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
|
|
@ -8,6 +9,7 @@ from functools import lru_cache
|
|||
def create_vector_engine(
|
||||
vector_db_provider: str,
|
||||
vector_db_url: str,
|
||||
vector_db_name: str,
|
||||
vector_db_port: str = "",
|
||||
vector_db_key: str = "",
|
||||
):
|
||||
|
|
@ -27,6 +29,7 @@ def create_vector_engine(
|
|||
- vector_db_url (str): The URL for the vector database instance.
|
||||
- vector_db_port (str): The port for the vector database instance. Required for some
|
||||
providers.
|
||||
- vector_db_name (str): The name of the vector database instance.
|
||||
- vector_db_key (str): The API key or access token for the vector database instance.
|
||||
- vector_db_provider (str): The name of the vector database provider to use (e.g.,
|
||||
'pgvector').
|
||||
|
|
@ -45,6 +48,7 @@ def create_vector_engine(
|
|||
url=vector_db_url,
|
||||
api_key=vector_db_key,
|
||||
embedding_engine=embedding_engine,
|
||||
database_name=vector_db_name,
|
||||
)
|
||||
|
||||
if vector_db_provider.lower() == "pgvector":
|
||||
|
|
@ -133,6 +137,6 @@ def create_vector_engine(
|
|||
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Unsupported graph database provider: {vector_db_provider}. "
|
||||
f"Unsupported vector database provider: {vector_db_provider}. "
|
||||
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['LanceDB', 'PGVector', 'neptune_analytics', 'ChromaDB'])}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, field_validator
|
||||
from typing import Optional, Any, Dict
|
||||
|
||||
|
||||
|
|
@ -18,9 +18,21 @@ class Edge(BaseModel):
|
|||
|
||||
# Mixed usage
|
||||
has_items: (Edge(weight=0.5, weights={"confidence": 0.9}), list[Item])
|
||||
|
||||
# With edge_text for rich embedding representation
|
||||
contains: (Edge(relationship_type="contains", edge_text="relationship_name: contains; entity_description: Alice"), Entity)
|
||||
"""
|
||||
|
||||
weight: Optional[float] = None
|
||||
weights: Optional[Dict[str, float]] = None
|
||||
relationship_type: Optional[str] = None
|
||||
properties: Optional[Dict[str, Any]] = None
|
||||
edge_text: Optional[str] = None
|
||||
|
||||
@field_validator("edge_text", mode="before")
|
||||
@classmethod
|
||||
def ensure_edge_text(cls, v, info):
|
||||
"""Auto-populate edge_text from relationship_type if not explicitly provided."""
|
||||
if v is None and info.data.get("relationship_type"):
|
||||
return info.data["relationship_type"]
|
||||
return v
|
||||
|
|
|
|||
|
|
@ -55,6 +55,10 @@ def guess_file_type(file: BinaryIO, name: Optional[str] = None) -> filetype.Type
|
|||
file_type = Type("text/plain", "txt")
|
||||
return file_type
|
||||
|
||||
if ext in [".csv"]:
|
||||
file_type = Type("text/csv", "csv")
|
||||
return file_type
|
||||
|
||||
file_type = filetype.guess(file)
|
||||
|
||||
# If file type could not be determined consider it a plain text file as they don't have magic number encoding
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ class LLMConfig(BaseSettings):
|
|||
"""
|
||||
|
||||
structured_output_framework: str = "instructor"
|
||||
llm_instructor_mode: str = ""
|
||||
llm_provider: str = "openai"
|
||||
llm_model: str = "openai/gpt-5-mini"
|
||||
llm_endpoint: str = ""
|
||||
|
|
@ -181,6 +182,7 @@ class LLMConfig(BaseSettings):
|
|||
instance.
|
||||
"""
|
||||
return {
|
||||
"llm_instructor_mode": self.llm_instructor_mode.lower(),
|
||||
"provider": self.llm_provider,
|
||||
"model": self.llm_model,
|
||||
"endpoint": self.llm_endpoint,
|
||||
|
|
|
|||
|
|
@ -28,13 +28,16 @@ class AnthropicAdapter(LLMInterface):
|
|||
|
||||
name = "Anthropic"
|
||||
model: str
|
||||
default_instructor_mode = "anthropic_tools"
|
||||
|
||||
def __init__(self, max_completion_tokens: int, model: str = None):
|
||||
def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None):
|
||||
import anthropic
|
||||
|
||||
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||
|
||||
self.aclient = instructor.patch(
|
||||
create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create,
|
||||
mode=instructor.Mode.ANTHROPIC_TOOLS,
|
||||
mode=instructor.Mode(self.instructor_mode),
|
||||
)
|
||||
|
||||
self.model = model
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ class GeminiAdapter(LLMInterface):
|
|||
name: str
|
||||
model: str
|
||||
api_key: str
|
||||
default_instructor_mode = "json_mode"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -49,6 +50,7 @@ class GeminiAdapter(LLMInterface):
|
|||
model: str,
|
||||
api_version: str,
|
||||
max_completion_tokens: int,
|
||||
instructor_mode: str = None,
|
||||
fallback_model: str = None,
|
||||
fallback_api_key: str = None,
|
||||
fallback_endpoint: str = None,
|
||||
|
|
@ -63,7 +65,11 @@ class GeminiAdapter(LLMInterface):
|
|||
self.fallback_api_key = fallback_api_key
|
||||
self.fallback_endpoint = fallback_endpoint
|
||||
|
||||
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
|
||||
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||
|
||||
self.aclient = instructor.from_litellm(
|
||||
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
|
||||
)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_delay(128),
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ class GenericAPIAdapter(LLMInterface):
|
|||
name: str
|
||||
model: str
|
||||
api_key: str
|
||||
default_instructor_mode = "json_mode"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -49,6 +50,7 @@ class GenericAPIAdapter(LLMInterface):
|
|||
model: str,
|
||||
name: str,
|
||||
max_completion_tokens: int,
|
||||
instructor_mode: str = None,
|
||||
fallback_model: str = None,
|
||||
fallback_api_key: str = None,
|
||||
fallback_endpoint: str = None,
|
||||
|
|
@ -63,7 +65,11 @@ class GenericAPIAdapter(LLMInterface):
|
|||
self.fallback_api_key = fallback_api_key
|
||||
self.fallback_endpoint = fallback_endpoint
|
||||
|
||||
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
|
||||
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||
|
||||
self.aclient = instructor.from_litellm(
|
||||
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
|
||||
)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_delay(128),
|
||||
|
|
|
|||
|
|
@ -81,6 +81,7 @@ def get_llm_client(raise_api_key_error: bool = True):
|
|||
model=llm_config.llm_model,
|
||||
transcription_model=llm_config.transcription_model,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||
streaming=llm_config.llm_streaming,
|
||||
fallback_api_key=llm_config.fallback_api_key,
|
||||
fallback_endpoint=llm_config.fallback_endpoint,
|
||||
|
|
@ -101,6 +102,7 @@ def get_llm_client(raise_api_key_error: bool = True):
|
|||
llm_config.llm_model,
|
||||
"Ollama",
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||
)
|
||||
|
||||
elif provider == LLMProvider.ANTHROPIC:
|
||||
|
|
@ -109,7 +111,9 @@ def get_llm_client(raise_api_key_error: bool = True):
|
|||
)
|
||||
|
||||
return AnthropicAdapter(
|
||||
max_completion_tokens=max_completion_tokens, model=llm_config.llm_model
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
model=llm_config.llm_model,
|
||||
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||
)
|
||||
|
||||
elif provider == LLMProvider.CUSTOM:
|
||||
|
|
@ -126,6 +130,7 @@ def get_llm_client(raise_api_key_error: bool = True):
|
|||
llm_config.llm_model,
|
||||
"Custom",
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||
fallback_api_key=llm_config.fallback_api_key,
|
||||
fallback_endpoint=llm_config.fallback_endpoint,
|
||||
fallback_model=llm_config.fallback_model,
|
||||
|
|
@ -145,6 +150,7 @@ def get_llm_client(raise_api_key_error: bool = True):
|
|||
max_completion_tokens=max_completion_tokens,
|
||||
endpoint=llm_config.llm_endpoint,
|
||||
api_version=llm_config.llm_api_version,
|
||||
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||
)
|
||||
|
||||
elif provider == LLMProvider.MISTRAL:
|
||||
|
|
@ -160,21 +166,7 @@ def get_llm_client(raise_api_key_error: bool = True):
|
|||
model=llm_config.llm_model,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
endpoint=llm_config.llm_endpoint,
|
||||
)
|
||||
|
||||
elif provider == LLMProvider.MISTRAL:
|
||||
if llm_config.llm_api_key is None:
|
||||
raise LLMAPIKeyNotSetError()
|
||||
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
|
||||
MistralAdapter,
|
||||
)
|
||||
|
||||
return MistralAdapter(
|
||||
api_key=llm_config.llm_api_key,
|
||||
model=llm_config.llm_model,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
endpoint=llm_config.llm_endpoint,
|
||||
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||
)
|
||||
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -37,16 +37,26 @@ class MistralAdapter(LLMInterface):
|
|||
model: str
|
||||
api_key: str
|
||||
max_completion_tokens: int
|
||||
default_instructor_mode = "mistral_tools"
|
||||
|
||||
def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
model: str,
|
||||
max_completion_tokens: int,
|
||||
endpoint: str = None,
|
||||
instructor_mode: str = None,
|
||||
):
|
||||
from mistralai import Mistral
|
||||
|
||||
self.model = model
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
|
||||
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||
|
||||
self.aclient = instructor.from_litellm(
|
||||
litellm.acompletion,
|
||||
mode=instructor.Mode.MISTRAL_TOOLS,
|
||||
mode=instructor.Mode(self.instructor_mode),
|
||||
api_key=get_llm_config().llm_api_key,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -42,8 +42,16 @@ class OllamaAPIAdapter(LLMInterface):
|
|||
- aclient
|
||||
"""
|
||||
|
||||
default_instructor_mode = "json_mode"
|
||||
|
||||
def __init__(
|
||||
self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int
|
||||
self,
|
||||
endpoint: str,
|
||||
api_key: str,
|
||||
model: str,
|
||||
name: str,
|
||||
max_completion_tokens: int,
|
||||
instructor_mode: str = None,
|
||||
):
|
||||
self.name = name
|
||||
self.model = model
|
||||
|
|
@ -51,8 +59,11 @@ class OllamaAPIAdapter(LLMInterface):
|
|||
self.endpoint = endpoint
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
|
||||
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||
|
||||
self.aclient = instructor.from_openai(
|
||||
OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
|
||||
OpenAI(base_url=self.endpoint, api_key=self.api_key),
|
||||
mode=instructor.Mode(self.instructor_mode),
|
||||
)
|
||||
|
||||
@retry(
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
model: str
|
||||
api_key: str
|
||||
api_version: str
|
||||
default_instructor_mode = "json_schema_mode"
|
||||
|
||||
MAX_RETRIES = 5
|
||||
|
||||
|
|
@ -69,19 +70,21 @@ class OpenAIAdapter(LLMInterface):
|
|||
model: str,
|
||||
transcription_model: str,
|
||||
max_completion_tokens: int,
|
||||
instructor_mode: str = None,
|
||||
streaming: bool = False,
|
||||
fallback_model: str = None,
|
||||
fallback_api_key: str = None,
|
||||
fallback_endpoint: str = None,
|
||||
):
|
||||
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||
# TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
|
||||
# Make sure all new gpt models will work with this mode as well.
|
||||
if "gpt-5" in model:
|
||||
self.aclient = instructor.from_litellm(
|
||||
litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA
|
||||
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
|
||||
)
|
||||
self.client = instructor.from_litellm(
|
||||
litellm.completion, mode=instructor.Mode.JSON_SCHEMA
|
||||
litellm.completion, mode=instructor.Mode(self.instructor_mode)
|
||||
)
|
||||
else:
|
||||
self.aclient = instructor.from_litellm(litellm.acompletion)
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ class LoaderEngine:
|
|||
"pypdf_loader",
|
||||
"image_loader",
|
||||
"audio_loader",
|
||||
"csv_loader",
|
||||
"unstructured_loader",
|
||||
"advanced_pdf_loader",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,5 +3,6 @@
|
|||
from .text_loader import TextLoader
|
||||
from .audio_loader import AudioLoader
|
||||
from .image_loader import ImageLoader
|
||||
from .csv_loader import CsvLoader
|
||||
|
||||
__all__ = ["TextLoader", "AudioLoader", "ImageLoader"]
|
||||
__all__ = ["TextLoader", "AudioLoader", "ImageLoader", "CsvLoader"]
|
||||
|
|
|
|||
93
cognee/infrastructure/loaders/core/csv_loader.py
Normal file
93
cognee/infrastructure/loaders/core/csv_loader.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
import os
|
||||
from typing import List
|
||||
import csv
|
||||
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
|
||||
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
|
||||
from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata
|
||||
|
||||
|
||||
class CsvLoader(LoaderInterface):
|
||||
"""
|
||||
Core CSV file loader that handles basic CSV file formats.
|
||||
"""
|
||||
|
||||
@property
|
||||
def supported_extensions(self) -> List[str]:
|
||||
"""Supported text file extensions."""
|
||||
return [
|
||||
"csv",
|
||||
]
|
||||
|
||||
@property
|
||||
def supported_mime_types(self) -> List[str]:
|
||||
"""Supported MIME types for text content."""
|
||||
return [
|
||||
"text/csv",
|
||||
]
|
||||
|
||||
@property
|
||||
def loader_name(self) -> str:
|
||||
"""Unique identifier for this loader."""
|
||||
return "csv_loader"
|
||||
|
||||
def can_handle(self, extension: str, mime_type: str) -> bool:
|
||||
"""
|
||||
Check if this loader can handle the given file.
|
||||
|
||||
Args:
|
||||
extension: File extension
|
||||
mime_type: Optional MIME type
|
||||
|
||||
Returns:
|
||||
True if file can be handled, False otherwise
|
||||
"""
|
||||
if extension in self.supported_extensions and mime_type in self.supported_mime_types:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def load(self, file_path: str, encoding: str = "utf-8", **kwargs):
|
||||
"""
|
||||
Load and process the csv file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load
|
||||
encoding: Text encoding to use (default: utf-8)
|
||||
**kwargs: Additional configuration (unused)
|
||||
|
||||
Returns:
|
||||
LoaderResult containing the file content and metadata
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If file doesn't exist
|
||||
UnicodeDecodeError: If file cannot be decoded with specified encoding
|
||||
OSError: If file cannot be read
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
file_metadata = await get_file_metadata(f)
|
||||
# Name ingested file of current loader based on original file content hash
|
||||
storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"
|
||||
|
||||
row_texts = []
|
||||
row_index = 1
|
||||
|
||||
with open(file_path, "r", encoding=encoding, newline="") as file:
|
||||
reader = csv.DictReader(file)
|
||||
for row in reader:
|
||||
pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()]
|
||||
row_text = ", ".join(pairs)
|
||||
row_texts.append(f"Row {row_index}:\n{row_text}\n")
|
||||
row_index += 1
|
||||
|
||||
content = "\n".join(row_texts)
|
||||
|
||||
storage_config = get_storage_config()
|
||||
data_root_directory = storage_config["data_root_directory"]
|
||||
storage = get_file_storage(data_root_directory)
|
||||
|
||||
full_file_path = await storage.store(storage_file_name, content)
|
||||
|
||||
return full_file_path
|
||||
|
|
@ -16,7 +16,7 @@ class TextLoader(LoaderInterface):
|
|||
@property
|
||||
def supported_extensions(self) -> List[str]:
|
||||
"""Supported text file extensions."""
|
||||
return ["txt", "md", "csv", "json", "xml", "yaml", "yml", "log"]
|
||||
return ["txt", "md", "json", "xml", "yaml", "yml", "log"]
|
||||
|
||||
@property
|
||||
def supported_mime_types(self) -> List[str]:
|
||||
|
|
@ -24,7 +24,6 @@ class TextLoader(LoaderInterface):
|
|||
return [
|
||||
"text/plain",
|
||||
"text/markdown",
|
||||
"text/csv",
|
||||
"application/json",
|
||||
"text/xml",
|
||||
"application/xml",
|
||||
|
|
|
|||
|
|
@ -227,12 +227,3 @@ class AdvancedPdfLoader(LoaderInterface):
|
|||
if value is None:
|
||||
return ""
|
||||
return str(value).replace("\xa0", " ").strip()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
loader = AdvancedPdfLoader()
|
||||
asyncio.run(
|
||||
loader.load(
|
||||
"/Users/xiaotao/work/cognee/cognee/infrastructure/loaders/external/attention_is_all_you_need.pdf"
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from cognee.infrastructure.loaders.external import PyPdfLoader
|
||||
from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader
|
||||
from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader, CsvLoader
|
||||
|
||||
# Registry for loader implementations
|
||||
supported_loaders = {
|
||||
|
|
@ -7,6 +7,7 @@ supported_loaders = {
|
|||
TextLoader.loader_name: TextLoader,
|
||||
ImageLoader.loader_name: ImageLoader,
|
||||
AudioLoader.loader_name: AudioLoader,
|
||||
CsvLoader.loader_name: CsvLoader,
|
||||
}
|
||||
|
||||
# Try adding optional loaders
|
||||
|
|
|
|||
|
|
@ -0,0 +1,55 @@
|
|||
from typing import Optional, List
|
||||
|
||||
from cognee import memify
|
||||
from cognee.context_global_variables import (
|
||||
set_database_global_context_variables,
|
||||
set_session_user_context_variable,
|
||||
)
|
||||
from cognee.exceptions import CogneeValidationError
|
||||
from cognee.modules.data.methods import get_authorized_existing_datasets
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.tasks.memify import extract_user_sessions, cognify_session
|
||||
|
||||
|
||||
logger = get_logger("persist_sessions_in_knowledge_graph")
|
||||
|
||||
|
||||
async def persist_sessions_in_knowledge_graph_pipeline(
|
||||
user: User,
|
||||
session_ids: Optional[List[str]] = None,
|
||||
dataset: str = "main_dataset",
|
||||
run_in_background: bool = False,
|
||||
):
|
||||
await set_session_user_context_variable(user)
|
||||
dataset_to_write = await get_authorized_existing_datasets(
|
||||
user=user, datasets=[dataset], permission_type="write"
|
||||
)
|
||||
|
||||
if not dataset_to_write:
|
||||
raise CogneeValidationError(
|
||||
message=f"User (id: {str(user.id)}) does not have write access to dataset: {dataset}",
|
||||
log=False,
|
||||
)
|
||||
|
||||
await set_database_global_context_variables(
|
||||
dataset_to_write[0].id, dataset_to_write[0].owner_id
|
||||
)
|
||||
|
||||
extraction_tasks = [Task(extract_user_sessions, session_ids=session_ids)]
|
||||
|
||||
enrichment_tasks = [
|
||||
Task(cognify_session, dataset_id=dataset_to_write[0].id),
|
||||
]
|
||||
|
||||
result = await memify(
|
||||
extraction_tasks=extraction_tasks,
|
||||
enrichment_tasks=enrichment_tasks,
|
||||
dataset=dataset_to_write[0].id,
|
||||
data=[{}],
|
||||
run_in_background=run_in_background,
|
||||
)
|
||||
|
||||
logger.info("Session persistence pipeline completed")
|
||||
return result
|
||||
35
cognee/modules/chunking/CsvChunker.py
Normal file
35
cognee/modules/chunking/CsvChunker.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
|
||||
from cognee.tasks.chunks import chunk_by_row
|
||||
from cognee.modules.chunking.Chunker import Chunker
|
||||
from .models.DocumentChunk import DocumentChunk
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class CsvChunker(Chunker):
|
||||
async def read(self):
|
||||
async for content_text in self.get_text():
|
||||
if content_text is None:
|
||||
continue
|
||||
|
||||
for chunk_data in chunk_by_row(content_text, self.max_chunk_size):
|
||||
if chunk_data["chunk_size"] <= self.max_chunk_size:
|
||||
yield DocumentChunk(
|
||||
id=chunk_data["chunk_id"],
|
||||
text=chunk_data["text"],
|
||||
chunk_size=chunk_data["chunk_size"],
|
||||
is_part_of=self.document,
|
||||
chunk_index=self.chunk_index,
|
||||
cut_type=chunk_data["cut_type"],
|
||||
contains=[],
|
||||
metadata={
|
||||
"index_fields": ["text"],
|
||||
},
|
||||
)
|
||||
self.chunk_index += 1
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Chunk size is larger than the maximum chunk size {self.max_chunk_size}"
|
||||
)
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import List, Union
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.engine.models.Edge import Edge
|
||||
from cognee.modules.data.processing.document_types import Document
|
||||
from cognee.modules.engine.models import Entity
|
||||
from cognee.tasks.temporal_graph.models import Event
|
||||
|
|
@ -31,6 +32,6 @@ class DocumentChunk(DataPoint):
|
|||
chunk_index: int
|
||||
cut_type: str
|
||||
is_part_of: Document
|
||||
contains: List[Union[Entity, Event]] = None
|
||||
contains: List[Union[Entity, Event, tuple[Edge, Entity]]] = None
|
||||
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
|
|
|||
124
cognee/modules/chunking/text_chunker_with_overlap.py
Normal file
124
cognee/modules/chunking/text_chunker_with_overlap.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
|
||||
from cognee.tasks.chunks import chunk_by_paragraph
|
||||
from cognee.modules.chunking.Chunker import Chunker
|
||||
from .models.DocumentChunk import DocumentChunk
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class TextChunkerWithOverlap(Chunker):
|
||||
def __init__(
|
||||
self,
|
||||
document,
|
||||
get_text: callable,
|
||||
max_chunk_size: int,
|
||||
chunk_overlap_ratio: float = 0.0,
|
||||
get_chunk_data: callable = None,
|
||||
):
|
||||
super().__init__(document, get_text, max_chunk_size)
|
||||
self._accumulated_chunk_data = []
|
||||
self._accumulated_size = 0
|
||||
self.chunk_overlap_ratio = chunk_overlap_ratio
|
||||
self.chunk_overlap = int(max_chunk_size * chunk_overlap_ratio)
|
||||
|
||||
if get_chunk_data is not None:
|
||||
self.get_chunk_data = get_chunk_data
|
||||
elif chunk_overlap_ratio > 0:
|
||||
paragraph_max_size = int(0.5 * chunk_overlap_ratio * max_chunk_size)
|
||||
self.get_chunk_data = lambda text: chunk_by_paragraph(
|
||||
text, paragraph_max_size, batch_paragraphs=True
|
||||
)
|
||||
else:
|
||||
self.get_chunk_data = lambda text: chunk_by_paragraph(
|
||||
text, self.max_chunk_size, batch_paragraphs=True
|
||||
)
|
||||
|
||||
def _accumulation_overflows(self, chunk_data):
|
||||
"""Check if adding chunk_data would exceed max_chunk_size."""
|
||||
return self._accumulated_size + chunk_data["chunk_size"] > self.max_chunk_size
|
||||
|
||||
def _accumulate_chunk_data(self, chunk_data):
|
||||
"""Add chunk_data to the current accumulation."""
|
||||
self._accumulated_chunk_data.append(chunk_data)
|
||||
self._accumulated_size += chunk_data["chunk_size"]
|
||||
|
||||
def _clear_accumulation(self):
|
||||
"""Reset accumulation, keeping overlap chunk_data based on chunk_overlap_ratio."""
|
||||
if self.chunk_overlap == 0:
|
||||
self._accumulated_chunk_data = []
|
||||
self._accumulated_size = 0
|
||||
return
|
||||
|
||||
# Keep chunk_data from the end that fit in overlap
|
||||
overlap_chunk_data = []
|
||||
overlap_size = 0
|
||||
|
||||
for chunk_data in reversed(self._accumulated_chunk_data):
|
||||
if overlap_size + chunk_data["chunk_size"] <= self.chunk_overlap:
|
||||
overlap_chunk_data.insert(0, chunk_data)
|
||||
overlap_size += chunk_data["chunk_size"]
|
||||
else:
|
||||
break
|
||||
|
||||
self._accumulated_chunk_data = overlap_chunk_data
|
||||
self._accumulated_size = overlap_size
|
||||
|
||||
def _create_chunk(self, text, size, cut_type, chunk_id=None):
|
||||
"""Create a DocumentChunk with standard metadata."""
|
||||
try:
|
||||
return DocumentChunk(
|
||||
id=chunk_id or uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
|
||||
text=text,
|
||||
chunk_size=size,
|
||||
is_part_of=self.document,
|
||||
chunk_index=self.chunk_index,
|
||||
cut_type=cut_type,
|
||||
contains=[],
|
||||
metadata={"index_fields": ["text"]},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise e
|
||||
|
||||
def _create_chunk_from_accumulation(self):
|
||||
"""Create a DocumentChunk from current accumulated chunk_data."""
|
||||
chunk_text = " ".join(chunk["text"] for chunk in self._accumulated_chunk_data)
|
||||
return self._create_chunk(
|
||||
text=chunk_text,
|
||||
size=self._accumulated_size,
|
||||
cut_type=self._accumulated_chunk_data[-1]["cut_type"],
|
||||
)
|
||||
|
||||
def _emit_chunk(self, chunk_data):
|
||||
"""Emit a chunk when accumulation overflows."""
|
||||
if len(self._accumulated_chunk_data) > 0:
|
||||
chunk = self._create_chunk_from_accumulation()
|
||||
self._clear_accumulation()
|
||||
self._accumulate_chunk_data(chunk_data)
|
||||
else:
|
||||
# Handle single chunk_data exceeding max_chunk_size
|
||||
chunk = self._create_chunk(
|
||||
text=chunk_data["text"],
|
||||
size=chunk_data["chunk_size"],
|
||||
cut_type=chunk_data["cut_type"],
|
||||
chunk_id=chunk_data["chunk_id"],
|
||||
)
|
||||
|
||||
self.chunk_index += 1
|
||||
return chunk
|
||||
|
||||
async def read(self):
|
||||
async for content_text in self.get_text():
|
||||
for chunk_data in self.get_chunk_data(content_text):
|
||||
if not self._accumulation_overflows(chunk_data):
|
||||
self._accumulate_chunk_data(chunk_data)
|
||||
continue
|
||||
|
||||
yield self._emit_chunk(chunk_data)
|
||||
|
||||
if len(self._accumulated_chunk_data) == 0:
|
||||
return
|
||||
|
||||
yield self._create_chunk_from_accumulation()
|
||||
|
|
@ -10,6 +10,7 @@ from .get_authorized_dataset import get_authorized_dataset
|
|||
from .get_authorized_dataset_by_name import get_authorized_dataset_by_name
|
||||
from .get_data import get_data
|
||||
from .get_unique_dataset_id import get_unique_dataset_id
|
||||
from .get_unique_data_id import get_unique_data_id
|
||||
from .get_authorized_existing_datasets import get_authorized_existing_datasets
|
||||
from .get_dataset_ids import get_dataset_ids
|
||||
|
||||
|
|
|
|||
|
|
@ -16,14 +16,16 @@ async def create_dataset(dataset_name: str, user: User, session: AsyncSession) -
|
|||
.options(joinedload(Dataset.data))
|
||||
.filter(Dataset.name == dataset_name)
|
||||
.filter(Dataset.owner_id == owner_id)
|
||||
.filter(Dataset.tenant_id == user.tenant_id)
|
||||
)
|
||||
).first()
|
||||
|
||||
if dataset is None:
|
||||
# Dataset id should be generated based on dataset_name and owner_id/user so multiple users can use the same dataset_name
|
||||
dataset_id = await get_unique_dataset_id(dataset_name=dataset_name, user=user)
|
||||
dataset = Dataset(id=dataset_id, name=dataset_name, data=[])
|
||||
dataset.owner_id = owner_id
|
||||
dataset = Dataset(
|
||||
id=dataset_id, name=dataset_name, data=[], owner_id=owner_id, tenant_id=user.tenant_id
|
||||
)
|
||||
|
||||
session.add(dataset)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,11 @@ async def get_dataset_ids(datasets: Union[list[str], list[UUID]], user):
|
|||
# Get all user owned dataset objects (If a user wants to write to a dataset he is not the owner of it must be provided through UUID.)
|
||||
user_datasets = await get_datasets(user.id)
|
||||
# Filter out non name mentioned datasets
|
||||
dataset_ids = [dataset.id for dataset in user_datasets if dataset.name in datasets]
|
||||
dataset_ids = [dataset for dataset in user_datasets if dataset.name in datasets]
|
||||
# Filter out non current tenant datasets
|
||||
dataset_ids = [
|
||||
dataset.id for dataset in dataset_ids if dataset.tenant_id == user.tenant_id
|
||||
]
|
||||
else:
|
||||
raise DatasetTypeError(
|
||||
f"One or more of the provided dataset types is not handled: f{datasets}"
|
||||
|
|
|
|||
68
cognee/modules/data/methods/get_unique_data_id.py
Normal file
68
cognee/modules/data/methods/get_unique_data_id.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
from uuid import uuid5, NAMESPACE_OID, UUID
|
||||
from sqlalchemy import select
|
||||
|
||||
from cognee.modules.data.models.Data import Data
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
|
||||
async def get_unique_data_id(data_identifier: str, user: User) -> UUID:
|
||||
"""
|
||||
Function returns a unique UUID for data based on data identifier, user id and tenant id.
|
||||
If data with legacy ID exists, return that ID to maintain compatibility.
|
||||
|
||||
Args:
|
||||
data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.)
|
||||
user: User object adding the data
|
||||
tenant_id: UUID of the tenant for which data is being added
|
||||
|
||||
Returns:
|
||||
UUID: Unique identifier for the data
|
||||
"""
|
||||
|
||||
def _get_deprecated_unique_data_id(data_identifier: str, user: User) -> UUID:
|
||||
"""
|
||||
Deprecated function, returns a unique UUID for data based on data identifier and user id.
|
||||
Needed to support legacy data without tenant information.
|
||||
Args:
|
||||
data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.)
|
||||
user: User object adding the data
|
||||
|
||||
Returns:
|
||||
UUID: Unique identifier for the data
|
||||
"""
|
||||
# return UUID hash of file contents + owner id + tenant_id
|
||||
return uuid5(NAMESPACE_OID, f"{data_identifier}{str(user.id)}")
|
||||
|
||||
def _get_modern_unique_data_id(data_identifier: str, user: User) -> UUID:
|
||||
"""
|
||||
Function returns a unique UUID for data based on data identifier, user id and tenant id.
|
||||
Args:
|
||||
data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.)
|
||||
user: User object adding the data
|
||||
tenant_id: UUID of the tenant for which data is being added
|
||||
|
||||
Returns:
|
||||
UUID: Unique identifier for the data
|
||||
"""
|
||||
# return UUID hash of file contents + owner id + tenant_id
|
||||
return uuid5(NAMESPACE_OID, f"{data_identifier}{str(user.id)}{str(user.tenant_id)}")
|
||||
|
||||
# Get all possible data_id values
|
||||
data_id = {
|
||||
"modern_data_id": _get_modern_unique_data_id(data_identifier=data_identifier, user=user),
|
||||
"legacy_data_id": _get_deprecated_unique_data_id(
|
||||
data_identifier=data_identifier, user=user
|
||||
),
|
||||
}
|
||||
|
||||
# Check if data item with legacy_data_id exists, if so use that one, else use modern_data_id
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
legacy_data_point = (
|
||||
await session.execute(select(Data).filter(Data.id == data_id["legacy_data_id"]))
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not legacy_data_point:
|
||||
return data_id["modern_data_id"]
|
||||
return data_id["legacy_data_id"]
|
||||
|
|
@ -1,9 +1,71 @@
|
|||
from uuid import UUID, uuid5, NAMESPACE_OID
|
||||
from cognee.modules.users.models import User
|
||||
from typing import Union
|
||||
from sqlalchemy import select
|
||||
|
||||
from cognee.modules.data.models.Dataset import Dataset
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
|
||||
async def get_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID:
|
||||
if isinstance(dataset_name, UUID):
|
||||
return dataset_name
|
||||
return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}")
|
||||
"""
|
||||
Function returns a unique UUID for dataset based on dataset name, user id and tenant id.
|
||||
If dataset with legacy ID exists, return that ID to maintain compatibility.
|
||||
|
||||
Args:
|
||||
dataset_name: string representing the dataset name
|
||||
user: User object adding the dataset
|
||||
tenant_id: UUID of the tenant for which dataset is being added
|
||||
|
||||
Returns:
|
||||
UUID: Unique identifier for the dataset
|
||||
"""
|
||||
|
||||
def _get_legacy_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID:
|
||||
"""
|
||||
Legacy function, returns a unique UUID for dataset based on dataset name and user id.
|
||||
Needed to support legacy datasets without tenant information.
|
||||
Args:
|
||||
dataset_name: string representing the dataset name
|
||||
user: Current User object adding the dataset
|
||||
|
||||
Returns:
|
||||
UUID: Unique identifier for the dataset
|
||||
"""
|
||||
if isinstance(dataset_name, UUID):
|
||||
return dataset_name
|
||||
return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}")
|
||||
|
||||
def _get_modern_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID:
|
||||
"""
|
||||
Returns a unique UUID for dataset based on dataset name, user id and tenant_id.
|
||||
Args:
|
||||
dataset_name: string representing the dataset name
|
||||
user: Current User object adding the dataset
|
||||
tenant_id: UUID of the tenant for which dataset is being added
|
||||
|
||||
Returns:
|
||||
UUID: Unique identifier for the dataset
|
||||
"""
|
||||
if isinstance(dataset_name, UUID):
|
||||
return dataset_name
|
||||
return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}{str(user.tenant_id)}")
|
||||
|
||||
# Get all possible dataset_id values
|
||||
dataset_id = {
|
||||
"modern_dataset_id": _get_modern_unique_dataset_id(dataset_name=dataset_name, user=user),
|
||||
"legacy_dataset_id": _get_legacy_unique_dataset_id(dataset_name=dataset_name, user=user),
|
||||
}
|
||||
|
||||
# Check if dataset with legacy_dataset_id exists, if so use that one, else use modern_dataset_id
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
legacy_dataset = (
|
||||
await session.execute(
|
||||
select(Dataset).filter(Dataset.id == dataset_id["legacy_dataset_id"])
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not legacy_dataset:
|
||||
return dataset_id["modern_dataset_id"]
|
||||
return dataset_id["legacy_dataset_id"]
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ class Dataset(Base):
|
|||
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
owner_id = Column(UUID, index=True)
|
||||
tenant_id = Column(UUID, index=True, nullable=True)
|
||||
|
||||
acls = relationship("ACL", back_populates="dataset", cascade="all, delete-orphan")
|
||||
|
||||
|
|
@ -36,5 +37,6 @@ class Dataset(Base):
|
|||
"createdAt": self.created_at.isoformat(),
|
||||
"updatedAt": self.updated_at.isoformat() if self.updated_at else None,
|
||||
"ownerId": str(self.owner_id),
|
||||
"tenantId": str(self.tenant_id),
|
||||
"data": [data.to_json() for data in self.data],
|
||||
}
|
||||
|
|
|
|||
33
cognee/modules/data/processing/document_types/CsvDocument.py
Normal file
33
cognee/modules/data/processing/document_types/CsvDocument.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
import io
|
||||
import csv
|
||||
from typing import Type
|
||||
|
||||
from cognee.modules.chunking.Chunker import Chunker
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from .Document import Document
|
||||
|
||||
|
||||
class CsvDocument(Document):
|
||||
type: str = "csv"
|
||||
mime_type: str = "text/csv"
|
||||
|
||||
async def read(self, chunker_cls: Type[Chunker], max_chunk_size: int):
|
||||
async def get_text():
|
||||
async with open_data_file(
|
||||
self.raw_data_location, mode="r", encoding="utf-8", newline=""
|
||||
) as file:
|
||||
content = file.read()
|
||||
file_like_obj = io.StringIO(content)
|
||||
reader = csv.DictReader(file_like_obj)
|
||||
|
||||
for row in reader:
|
||||
pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()]
|
||||
row_text = ", ".join(pairs)
|
||||
if not row_text.strip():
|
||||
break
|
||||
yield row_text
|
||||
|
||||
chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=get_text)
|
||||
|
||||
async for chunk in chunker.read():
|
||||
yield chunk
|
||||
|
|
@ -4,3 +4,4 @@ from .TextDocument import TextDocument
|
|||
from .ImageDocument import ImageDocument
|
||||
from .AudioDocument import AudioDocument
|
||||
from .UnstructuredDocument import UnstructuredDocument
|
||||
from .CsvDocument import CsvDocument
|
||||
|
|
|
|||
|
|
@ -56,6 +56,68 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
def get_edges(self) -> List[Edge]:
|
||||
return self.edges
|
||||
|
||||
async def _get_nodeset_subgraph(
|
||||
self,
|
||||
adapter,
|
||||
node_type,
|
||||
node_name,
|
||||
):
|
||||
"""Retrieve subgraph based on node type and name."""
|
||||
logger.info("Retrieving graph filtered by node type and node name (NodeSet).")
|
||||
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
|
||||
node_type=node_type, node_name=node_name
|
||||
)
|
||||
if not nodes_data or not edges_data:
|
||||
raise EntityNotFoundError(
|
||||
message="Nodeset does not exist, or empty nodeset projected from the database."
|
||||
)
|
||||
return nodes_data, edges_data
|
||||
|
||||
async def _get_full_or_id_filtered_graph(
|
||||
self,
|
||||
adapter,
|
||||
relevant_ids_to_filter,
|
||||
):
|
||||
"""Retrieve full or ID-filtered graph with fallback."""
|
||||
if relevant_ids_to_filter is None:
|
||||
logger.info("Retrieving full graph.")
|
||||
nodes_data, edges_data = await adapter.get_graph_data()
|
||||
if not nodes_data or not edges_data:
|
||||
raise EntityNotFoundError(message="Empty graph projected from the database.")
|
||||
return nodes_data, edges_data
|
||||
|
||||
get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data)
|
||||
if getattr(adapter.__class__, "get_id_filtered_graph_data", None):
|
||||
logger.info("Retrieving ID-filtered graph from database.")
|
||||
nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter)
|
||||
else:
|
||||
logger.info("Retrieving full graph from database.")
|
||||
nodes_data, edges_data = await get_graph_data_fn()
|
||||
if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data):
|
||||
logger.warning(
|
||||
"Id filtered graph returned empty, falling back to full graph retrieval."
|
||||
)
|
||||
logger.info("Retrieving full graph")
|
||||
nodes_data, edges_data = await adapter.get_graph_data()
|
||||
|
||||
if not nodes_data or not edges_data:
|
||||
raise EntityNotFoundError("Empty graph projected from the database.")
|
||||
return nodes_data, edges_data
|
||||
|
||||
async def _get_filtered_graph(
|
||||
self,
|
||||
adapter,
|
||||
memory_fragment_filter,
|
||||
):
|
||||
"""Retrieve graph filtered by attributes."""
|
||||
logger.info("Retrieving graph filtered by memory fragment")
|
||||
nodes_data, edges_data = await adapter.get_filtered_graph_data(
|
||||
attribute_filters=memory_fragment_filter
|
||||
)
|
||||
if not nodes_data or not edges_data:
|
||||
raise EntityNotFoundError(message="Empty filtered graph projected from the database.")
|
||||
return nodes_data, edges_data
|
||||
|
||||
async def project_graph_from_db(
|
||||
self,
|
||||
adapter: Union[GraphDBInterface],
|
||||
|
|
@ -67,40 +129,39 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
memory_fragment_filter=[],
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
relevant_ids_to_filter: Optional[List[str]] = None,
|
||||
triplet_distance_penalty: float = 3.5,
|
||||
) -> None:
|
||||
if node_dimension < 1 or edge_dimension < 1:
|
||||
raise InvalidDimensionsError()
|
||||
try:
|
||||
if node_type is not None and node_name not in [None, [], ""]:
|
||||
nodes_data, edges_data = await self._get_nodeset_subgraph(
|
||||
adapter, node_type, node_name
|
||||
)
|
||||
elif len(memory_fragment_filter) == 0:
|
||||
nodes_data, edges_data = await self._get_full_or_id_filtered_graph(
|
||||
adapter, relevant_ids_to_filter
|
||||
)
|
||||
else:
|
||||
nodes_data, edges_data = await self._get_filtered_graph(
|
||||
adapter, memory_fragment_filter
|
||||
)
|
||||
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Determine projection strategy
|
||||
if node_type is not None and node_name not in [None, [], ""]:
|
||||
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
|
||||
node_type=node_type, node_name=node_name
|
||||
)
|
||||
if not nodes_data or not edges_data:
|
||||
raise EntityNotFoundError(
|
||||
message="Nodeset does not exist, or empty nodetes projected from the database."
|
||||
)
|
||||
elif len(memory_fragment_filter) == 0:
|
||||
nodes_data, edges_data = await adapter.get_graph_data()
|
||||
if not nodes_data or not edges_data:
|
||||
raise EntityNotFoundError(message="Empty graph projected from the database.")
|
||||
else:
|
||||
nodes_data, edges_data = await adapter.get_filtered_graph_data(
|
||||
attribute_filters=memory_fragment_filter
|
||||
)
|
||||
if not nodes_data or not edges_data:
|
||||
raise EntityNotFoundError(
|
||||
message="Empty filtered graph projected from the database."
|
||||
)
|
||||
|
||||
# Process nodes
|
||||
for node_id, properties in nodes_data:
|
||||
node_attributes = {key: properties.get(key) for key in node_properties_to_project}
|
||||
self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))
|
||||
self.add_node(
|
||||
Node(
|
||||
str(node_id),
|
||||
node_attributes,
|
||||
dimension=node_dimension,
|
||||
node_penalty=triplet_distance_penalty,
|
||||
)
|
||||
)
|
||||
|
||||
# Process edges
|
||||
for source_id, target_id, relationship_type, properties in edges_data:
|
||||
|
|
@ -118,6 +179,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
attributes=edge_attributes,
|
||||
directed=directed,
|
||||
dimension=edge_dimension,
|
||||
edge_penalty=triplet_distance_penalty,
|
||||
)
|
||||
self.add_edge(edge)
|
||||
|
||||
|
|
@ -171,8 +233,10 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
|
||||
|
||||
for edge in self.edges:
|
||||
relationship_type = edge.attributes.get("relationship_type")
|
||||
distance = embedding_map.get(relationship_type, None)
|
||||
edge_key = edge.attributes.get("edge_text") or edge.attributes.get(
|
||||
"relationship_type"
|
||||
)
|
||||
distance = embedding_map.get(edge_key, None)
|
||||
if distance is not None:
|
||||
edge.attributes["vector_distance"] = distance
|
||||
|
||||
|
|
|
|||
|
|
@ -20,13 +20,17 @@ class Node:
|
|||
status: np.ndarray
|
||||
|
||||
def __init__(
|
||||
self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1
|
||||
self,
|
||||
node_id: str,
|
||||
attributes: Optional[Dict[str, Any]] = None,
|
||||
dimension: int = 1,
|
||||
node_penalty: float = 3.5,
|
||||
):
|
||||
if dimension <= 0:
|
||||
raise InvalidDimensionsError()
|
||||
self.id = node_id
|
||||
self.attributes = attributes if attributes is not None else {}
|
||||
self.attributes["vector_distance"] = float("inf")
|
||||
self.attributes["vector_distance"] = node_penalty
|
||||
self.skeleton_neighbours = []
|
||||
self.skeleton_edges = []
|
||||
self.status = np.ones(dimension, dtype=int)
|
||||
|
|
@ -105,13 +109,14 @@ class Edge:
|
|||
attributes: Optional[Dict[str, Any]] = None,
|
||||
directed: bool = True,
|
||||
dimension: int = 1,
|
||||
edge_penalty: float = 3.5,
|
||||
):
|
||||
if dimension <= 0:
|
||||
raise InvalidDimensionsError()
|
||||
self.node1 = node1
|
||||
self.node2 = node2
|
||||
self.attributes = attributes if attributes is not None else {}
|
||||
self.attributes["vector_distance"] = float("inf")
|
||||
self.attributes["vector_distance"] = edge_penalty
|
||||
self.directed = directed
|
||||
self.status = np.ones(dimension, dtype=int)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Optional
|
||||
|
||||
from cognee.infrastructure.engine.models.Edge import Edge
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.modules.engine.models import Entity, EntityType
|
||||
from cognee.modules.engine.utils import (
|
||||
|
|
@ -243,10 +244,26 @@ def _process_graph_nodes(
|
|||
ontology_relationships,
|
||||
)
|
||||
|
||||
# Add entity to data chunk
|
||||
if data_chunk.contains is None:
|
||||
data_chunk.contains = []
|
||||
data_chunk.contains.append(entity_node)
|
||||
|
||||
edge_text = "; ".join(
|
||||
[
|
||||
"relationship_name: contains",
|
||||
f"entity_name: {entity_node.name}",
|
||||
f"entity_description: {entity_node.description}",
|
||||
]
|
||||
)
|
||||
|
||||
data_chunk.contains.append(
|
||||
(
|
||||
Edge(
|
||||
relationship_type="contains",
|
||||
edge_text=edge_text,
|
||||
),
|
||||
entity_node,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _process_graph_edges(
|
||||
|
|
|
|||
|
|
@ -1,71 +1,70 @@
|
|||
import string
|
||||
from typing import List
|
||||
from collections import Counter
|
||||
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
||||
|
||||
|
||||
def _get_top_n_frequent_words(
|
||||
text: str, stop_words: set = None, top_n: int = 3, separator: str = ", "
|
||||
) -> str:
|
||||
"""Concatenates the top N frequent words in text."""
|
||||
if stop_words is None:
|
||||
stop_words = DEFAULT_STOP_WORDS
|
||||
|
||||
words = [word.lower().strip(string.punctuation) for word in text.split()]
|
||||
words = [word for word in words if word and word not in stop_words]
|
||||
|
||||
top_words = [word for word, freq in Counter(words).most_common(top_n)]
|
||||
return separator.join(top_words)
|
||||
|
||||
|
||||
def _create_title_from_text(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
|
||||
"""Creates a title by combining first words with most frequent words from the text."""
|
||||
first_words = text.split()[:first_n_words]
|
||||
top_words = _get_top_n_frequent_words(text, top_n=top_n_words)
|
||||
return f"{' '.join(first_words)}... [{top_words}]"
|
||||
|
||||
|
||||
def _extract_nodes_from_edges(retrieved_edges: List[Edge]) -> dict:
|
||||
"""Creates a dictionary of nodes with their names and content."""
|
||||
nodes = {}
|
||||
|
||||
for edge in retrieved_edges:
|
||||
for node in (edge.node1, edge.node2):
|
||||
if node.id in nodes:
|
||||
continue
|
||||
|
||||
text = node.attributes.get("text")
|
||||
if text:
|
||||
name = _create_title_from_text(text)
|
||||
content = text
|
||||
else:
|
||||
name = node.attributes.get("name", "Unnamed Node")
|
||||
content = node.attributes.get("description", name)
|
||||
|
||||
nodes[node.id] = {"node": node, "name": name, "content": content}
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str:
|
||||
"""
|
||||
Converts retrieved graph edges into a human-readable string format.
|
||||
"""Converts retrieved graph edges into a human-readable string format."""
|
||||
nodes = _extract_nodes_from_edges(retrieved_edges)
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- retrieved_edges (list): A list of edges retrieved from the graph.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- str: A formatted string representation of the nodes and their connections.
|
||||
"""
|
||||
|
||||
def _get_nodes(retrieved_edges: List[Edge]) -> dict:
|
||||
def _get_title(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
|
||||
def _top_n_words(text, stop_words=None, top_n=3, separator=", "):
|
||||
"""Concatenates the top N frequent words in text."""
|
||||
if stop_words is None:
|
||||
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
||||
|
||||
stop_words = DEFAULT_STOP_WORDS
|
||||
|
||||
import string
|
||||
|
||||
words = [word.lower().strip(string.punctuation) for word in text.split()]
|
||||
|
||||
if stop_words:
|
||||
words = [word for word in words if word and word not in stop_words]
|
||||
|
||||
from collections import Counter
|
||||
|
||||
top_words = [word for word, freq in Counter(words).most_common(top_n)]
|
||||
|
||||
return separator.join(top_words)
|
||||
|
||||
"""Creates a title, by combining first words with most frequent words from the text."""
|
||||
first_words = text.split()[:first_n_words]
|
||||
top_words = _top_n_words(text, top_n=first_n_words)
|
||||
return f"{' '.join(first_words)}... [{top_words}]"
|
||||
|
||||
"""Creates a dictionary of nodes with their names and content."""
|
||||
nodes = {}
|
||||
for edge in retrieved_edges:
|
||||
for node in (edge.node1, edge.node2):
|
||||
if node.id not in nodes:
|
||||
text = node.attributes.get("text")
|
||||
if text:
|
||||
name = _get_title(text)
|
||||
content = text
|
||||
else:
|
||||
name = node.attributes.get("name", "Unnamed Node")
|
||||
content = node.attributes.get("description", name)
|
||||
nodes[node.id] = {"node": node, "name": name, "content": content}
|
||||
return nodes
|
||||
|
||||
nodes = _get_nodes(retrieved_edges)
|
||||
node_section = "\n".join(
|
||||
f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
|
||||
for info in nodes.values()
|
||||
)
|
||||
connection_section = "\n".join(
|
||||
f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
|
||||
for edge in retrieved_edges
|
||||
)
|
||||
|
||||
connections = []
|
||||
for edge in retrieved_edges:
|
||||
source_name = nodes[edge.node1.id]["name"]
|
||||
target_name = nodes[edge.node2.id]["name"]
|
||||
edge_label = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
|
||||
connections.append(f"{source_name} --[{edge_label}]--> {target_name}")
|
||||
|
||||
connection_section = "\n".join(connections)
|
||||
|
||||
return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from uuid import uuid5, NAMESPACE_OID
|
||||
from uuid import UUID
|
||||
from .data_types import IngestionData
|
||||
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.data.methods import get_unique_data_id
|
||||
|
||||
|
||||
def identify(data: IngestionData, user: User) -> str:
|
||||
async def identify(data: IngestionData, user: User) -> UUID:
|
||||
data_content_hash: str = data.get_identifier()
|
||||
|
||||
# return UUID hash of file contents + owner id
|
||||
return uuid5(NAMESPACE_OID, f"{data_content_hash}{user.id}")
|
||||
return await get_unique_data_id(data_identifier=data_content_hash, user=user)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ import io
|
|||
import sys
|
||||
import traceback
|
||||
|
||||
import cognee
|
||||
|
||||
|
||||
def wrap_in_async_handler(user_code: str) -> str:
|
||||
return (
|
||||
|
|
@ -34,6 +36,7 @@ def run_in_local_sandbox(code, environment=None, loop=None):
|
|||
|
||||
environment["print"] = customPrintFunction
|
||||
environment["running_loop"] = loop
|
||||
environment["cognee"] = cognee
|
||||
|
||||
try:
|
||||
exec(code, environment)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import os
|
|||
import difflib
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from collections import deque
|
||||
from typing import List, Tuple, Dict, Optional, Any, Union
|
||||
from typing import List, Tuple, Dict, Optional, Any, Union, IO
|
||||
from rdflib import Graph, URIRef, RDF, RDFS, OWL
|
||||
|
||||
from cognee.modules.ontology.exceptions import (
|
||||
|
|
@ -26,44 +26,76 @@ class RDFLibOntologyResolver(BaseOntologyResolver):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
ontology_file: Optional[Union[str, List[str]]] = None,
|
||||
ontology_file: Optional[Union[str, List[str], IO, List[IO]]] = None,
|
||||
matching_strategy: Optional[MatchingStrategy] = None,
|
||||
) -> None:
|
||||
super().__init__(matching_strategy)
|
||||
self.ontology_file = ontology_file
|
||||
try:
|
||||
files_to_load = []
|
||||
self.graph = None
|
||||
if ontology_file is not None:
|
||||
if isinstance(ontology_file, str):
|
||||
files_to_load = []
|
||||
file_objects = []
|
||||
|
||||
if hasattr(ontology_file, "read"):
|
||||
file_objects = [ontology_file]
|
||||
elif isinstance(ontology_file, str):
|
||||
files_to_load = [ontology_file]
|
||||
elif isinstance(ontology_file, list):
|
||||
files_to_load = ontology_file
|
||||
if all(hasattr(item, "read") for item in ontology_file):
|
||||
file_objects = ontology_file
|
||||
else:
|
||||
files_to_load = ontology_file
|
||||
else:
|
||||
raise ValueError(
|
||||
f"ontology_file must be a string, list of strings, or None. Got: {type(ontology_file)}"
|
||||
f"ontology_file must be a string, list of strings, file-like object, list of file-like objects, or None. Got: {type(ontology_file)}"
|
||||
)
|
||||
|
||||
if files_to_load:
|
||||
self.graph = Graph()
|
||||
loaded_files = []
|
||||
for file_path in files_to_load:
|
||||
if os.path.exists(file_path):
|
||||
self.graph.parse(file_path)
|
||||
loaded_files.append(file_path)
|
||||
logger.info("Ontology loaded successfully from file: %s", file_path)
|
||||
else:
|
||||
logger.warning(
|
||||
"Ontology file '%s' not found. Skipping this file.",
|
||||
file_path,
|
||||
if file_objects:
|
||||
self.graph = Graph()
|
||||
loaded_objects = []
|
||||
for file_obj in file_objects:
|
||||
try:
|
||||
content = file_obj.read()
|
||||
self.graph.parse(data=content, format="xml")
|
||||
loaded_objects.append(file_obj)
|
||||
logger.info("Ontology loaded successfully from file object")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse ontology file object: %s", str(e))
|
||||
|
||||
if not loaded_objects:
|
||||
logger.info(
|
||||
"No valid ontology file objects found. No owl ontology will be attached to the graph."
|
||||
)
|
||||
self.graph = None
|
||||
else:
|
||||
logger.info("Total ontology file objects loaded: %d", len(loaded_objects))
|
||||
|
||||
if not loaded_files:
|
||||
logger.info(
|
||||
"No valid ontology files found. No owl ontology will be attached to the graph."
|
||||
)
|
||||
self.graph = None
|
||||
elif files_to_load:
|
||||
self.graph = Graph()
|
||||
loaded_files = []
|
||||
for file_path in files_to_load:
|
||||
if os.path.exists(file_path):
|
||||
self.graph.parse(file_path)
|
||||
loaded_files.append(file_path)
|
||||
logger.info("Ontology loaded successfully from file: %s", file_path)
|
||||
else:
|
||||
logger.warning(
|
||||
"Ontology file '%s' not found. Skipping this file.",
|
||||
file_path,
|
||||
)
|
||||
|
||||
if not loaded_files:
|
||||
logger.info(
|
||||
"No valid ontology files found. No owl ontology will be attached to the graph."
|
||||
)
|
||||
self.graph = None
|
||||
else:
|
||||
logger.info("Total ontology files loaded: %d", len(loaded_files))
|
||||
else:
|
||||
logger.info("Total ontology files loaded: %d", len(loaded_files))
|
||||
logger.info(
|
||||
"No ontology file provided. No owl ontology will be attached to the graph."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"No ontology file provided. No owl ontology will be attached to the graph."
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ async def run_tasks_data_item_incremental(
|
|||
async with open_data_file(file_path) as file:
|
||||
classified_data = ingestion.classify(file)
|
||||
# data_id is the hash of file contents + owner id to avoid duplicate data
|
||||
data_id = ingestion.identify(classified_data, user)
|
||||
data_id = await ingestion.identify(classified_data, user)
|
||||
else:
|
||||
# If data was already processed by Cognee get data id
|
||||
data_id = data_item.id
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import asyncio
|
||||
from typing import Any, Optional, List
|
||||
from typing import Any, Optional, List, Type
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
|
||||
|
|
@ -85,8 +85,12 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
return None
|
||||
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> List[str]:
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[Any] = None,
|
||||
session_id: Optional[str] = None,
|
||||
response_model: Type = str,
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Generate completion using provided context or fetch new context.
|
||||
|
||||
|
|
@ -102,6 +106,7 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
fetched if not provided. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
- response_model (Type): The Pydantic model type for structured output. (default str)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
@ -133,6 +138,7 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
conversation_history=conversation_history,
|
||||
response_model=response_model,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
@ -141,6 +147,7 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
if session_save:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Optional
|
||||
from typing import Any, List, Optional, Type
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
|
|
@ -14,7 +14,11 @@ class BaseGraphRetriever(ABC):
|
|||
|
||||
@abstractmethod
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None
|
||||
) -> str:
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[List[Edge]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
response_model: Type = str,
|
||||
) -> List[Any]:
|
||||
"""Generates a response using the query and optional context (triplets)."""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Type, List
|
||||
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
|
|
@ -12,7 +12,11 @@ class BaseRetriever(ABC):
|
|||
|
||||
@abstractmethod
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> Any:
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[Any] = None,
|
||||
session_id: Optional[str] = None,
|
||||
response_model: Type = str,
|
||||
) -> List[Any]:
|
||||
"""Generates a response using the query and optional context."""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import asyncio
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Type, List
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
|
@ -75,8 +75,12 @@ class CompletionRetriever(BaseRetriever):
|
|||
raise NoDataError("No data found in the system, please add data first.") from error
|
||||
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> str:
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[Any] = None,
|
||||
session_id: Optional[str] = None,
|
||||
response_model: Type = str,
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Generates an LLM completion using the context.
|
||||
|
||||
|
|
@ -91,6 +95,7 @@ class CompletionRetriever(BaseRetriever):
|
|||
completion; if None, it retrieves the context for the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
- response_model (Type): The Pydantic model type for structured output. (default str)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
@ -118,6 +123,7 @@ class CompletionRetriever(BaseRetriever):
|
|||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
conversation_history=conversation_history,
|
||||
response_model=response_model,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
@ -127,6 +133,7 @@ class CompletionRetriever(BaseRetriever):
|
|||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
if session_save:
|
||||
|
|
@ -137,4 +144,4 @@ class CompletionRetriever(BaseRetriever):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
return completion
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import asyncio
|
||||
from typing import Optional, List, Type
|
||||
from typing import Optional, List, Type, Any
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
|
|
@ -39,6 +39,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
):
|
||||
super().__init__(
|
||||
user_prompt_path=user_prompt_path,
|
||||
|
|
@ -48,6 +50,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
async def get_completion(
|
||||
|
|
@ -56,7 +60,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
context: Optional[List[Edge]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
context_extension_rounds=4,
|
||||
) -> List[str]:
|
||||
response_model: Type = str,
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Extends the context for a given query by retrieving related triplets and generating new
|
||||
completions based on them.
|
||||
|
|
@ -76,6 +81,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
defaults to 'default_session'. (default None)
|
||||
- context_extension_rounds: The maximum number of rounds to extend the context with
|
||||
new triplets before halting. (default 4)
|
||||
- response_model (Type): The Pydantic model type for structured output. (default str)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
@ -143,6 +149,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
conversation_history=conversation_history,
|
||||
response_model=response_model,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
@ -152,6 +159,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
if self.save_interaction and context_text and triplets and completion:
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from cognee.shared.logging_utils import get_logger
|
|||
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.retrieval.utils.completion import (
|
||||
generate_structured_completion,
|
||||
generate_completion,
|
||||
summarize_text,
|
||||
)
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
|
|
@ -44,7 +44,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
questions based on reasoning. The public methods are:
|
||||
|
||||
- get_completion
|
||||
- get_structured_completion
|
||||
|
||||
Instance variables include:
|
||||
- validation_system_prompt_path
|
||||
|
|
@ -66,6 +65,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
):
|
||||
super().__init__(
|
||||
user_prompt_path=user_prompt_path,
|
||||
|
|
@ -75,6 +76,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
self.validation_system_prompt_path = validation_system_prompt_path
|
||||
self.validation_user_prompt_path = validation_user_prompt_path
|
||||
|
|
@ -121,7 +124,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
triplets += await self.get_context(followup_question)
|
||||
context_text = await self.resolve_edges_to_text(list(set(triplets)))
|
||||
|
||||
completion = await generate_structured_completion(
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
|
|
@ -165,24 +168,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
|
||||
return completion, context_text, triplets
|
||||
|
||||
async def get_structured_completion(
|
||||
async def get_completion(
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[List[Edge]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
max_iter: int = 4,
|
||||
max_iter=4,
|
||||
response_model: Type = str,
|
||||
) -> Any:
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Generate structured completion responses based on a user query and contextual information.
|
||||
Generate completion responses based on a user query and contextual information.
|
||||
|
||||
This method applies the same chain-of-thought logic as get_completion but returns
|
||||
This method interacts with a language model client to retrieve a structured response,
|
||||
using a series of iterations to refine the answers and generate follow-up questions
|
||||
based on reasoning derived from previous outputs. It raises exceptions if the context
|
||||
retrieval fails or if the model encounters issues in generating outputs. It returns
|
||||
structured output using the provided response model.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The user's query to be processed and answered.
|
||||
- context (Optional[List[Edge]]): Optional context that may assist in answering the query.
|
||||
- context (Optional[Any]): Optional context that may assist in answering the query.
|
||||
If not provided, it will be fetched based on the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
|
@ -192,7 +199,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
|
||||
Returns:
|
||||
--------
|
||||
- Any: The generated structured completion based on the response model.
|
||||
|
||||
- List[str]: A list containing the generated answer to the user's query.
|
||||
"""
|
||||
# Check if session saving is enabled
|
||||
cache_config = CacheConfig()
|
||||
|
|
@ -228,45 +236,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
return completion
|
||||
|
||||
async def get_completion(
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[List[Edge]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
max_iter=4,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate completion responses based on a user query and contextual information.
|
||||
|
||||
This method interacts with a language model client to retrieve a structured response,
|
||||
using a series of iterations to refine the answers and generate follow-up questions
|
||||
based on reasoning derived from previous outputs. It raises exceptions if the context
|
||||
retrieval fails or if the model encounters issues in generating outputs.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The user's query to be processed and answered.
|
||||
- context (Optional[Any]): Optional context that may assist in answering the query.
|
||||
If not provided, it will be fetched based on the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
- max_iter: The maximum number of iterations to refine the answer and generate
|
||||
follow-up questions. (default 4)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- List[str]: A list containing the generated answer to the user's query.
|
||||
"""
|
||||
completion = await self.get_structured_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
session_id=session_id,
|
||||
max_iter=max_iter,
|
||||
response_model=str,
|
||||
)
|
||||
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
):
|
||||
"""Initialize retriever with prompt paths and search parameters."""
|
||||
self.save_interaction = save_interaction
|
||||
|
|
@ -54,8 +56,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
self.system_prompt_path = system_prompt_path
|
||||
self.system_prompt = system_prompt
|
||||
self.top_k = top_k if top_k is not None else 5
|
||||
self.wide_search_top_k = wide_search_top_k
|
||||
self.node_type = node_type
|
||||
self.node_name = node_name
|
||||
self.triplet_distance_penalty = triplet_distance_penalty
|
||||
|
||||
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
|
||||
"""
|
||||
|
|
@ -105,6 +109,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
collections=vector_index_collections or None,
|
||||
node_type=self.node_type,
|
||||
node_name=self.node_name,
|
||||
wide_search_top_k=self.wide_search_top_k,
|
||||
triplet_distance_penalty=self.triplet_distance_penalty,
|
||||
)
|
||||
|
||||
return found_triplets
|
||||
|
|
@ -141,12 +147,17 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
|
||||
return triplets
|
||||
|
||||
async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):
|
||||
context = await self.resolve_edges_to_text(triplets)
|
||||
return context
|
||||
|
||||
async def get_completion(
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[List[Edge]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
response_model: Type = str,
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Generates a completion using graph connections context based on a query.
|
||||
|
||||
|
|
@ -188,6 +199,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
conversation_history=conversation_history,
|
||||
response_model=response_model,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
@ -197,6 +209,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
if self.save_interaction and context and triplets and completion:
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
|||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
):
|
||||
"""Initialize retriever with default prompt paths and search parameters."""
|
||||
super().__init__(
|
||||
|
|
@ -36,6 +38,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
self.summarize_prompt_path = summarize_prompt_path
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
top_k: Optional[int] = 5,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
):
|
||||
super().__init__(
|
||||
user_prompt_path=user_prompt_path,
|
||||
|
|
@ -54,6 +56,8 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
self.user_prompt_path = user_prompt_path
|
||||
self.system_prompt_path = system_prompt_path
|
||||
|
|
@ -146,8 +150,12 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
return self.descriptions_to_string(top_k_events)
|
||||
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[str] = None, session_id: Optional[str] = None
|
||||
) -> List[str]:
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
response_model: Type = str,
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Generates a response using the query and optional context.
|
||||
|
||||
|
|
@ -159,6 +167,7 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
retrieved based on the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
- response_model (Type): The Pydantic model type for structured output. (default str)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
@ -186,6 +195,7 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
conversation_history=conversation_history,
|
||||
response_model=response_model,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
@ -194,6 +204,7 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
if session_save:
|
||||
|
|
|
|||
|
|
@ -58,6 +58,8 @@ async def get_memory_fragment(
|
|||
properties_to_project: Optional[List[str]] = None,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
relevant_ids_to_filter: Optional[List[str]] = None,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> CogneeGraph:
|
||||
"""Creates and initializes a CogneeGraph memory fragment with optional property projections."""
|
||||
if properties_to_project is None:
|
||||
|
|
@ -71,9 +73,11 @@ async def get_memory_fragment(
|
|||
await memory_fragment.project_graph_from_db(
|
||||
graph_engine,
|
||||
node_properties_to_project=properties_to_project,
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
edge_properties_to_project=["relationship_name", "edge_text"],
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
relevant_ids_to_filter=relevant_ids_to_filter,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
except EntityNotFoundError:
|
||||
|
|
@ -95,6 +99,8 @@ async def brute_force_triplet_search(
|
|||
memory_fragment: Optional[CogneeGraph] = None,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> List[Edge]:
|
||||
"""
|
||||
Performs a brute force search to retrieve the top triplets from the graph.
|
||||
|
|
@ -107,6 +113,8 @@ async def brute_force_triplet_search(
|
|||
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
|
||||
node_type: node type to filter
|
||||
node_name: node name to filter
|
||||
wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
|
||||
triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection
|
||||
|
||||
Returns:
|
||||
list: The top triplet results.
|
||||
|
|
@ -116,10 +124,10 @@ async def brute_force_triplet_search(
|
|||
if top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer.")
|
||||
|
||||
if memory_fragment is None:
|
||||
memory_fragment = await get_memory_fragment(
|
||||
properties_to_project, node_type=node_type, node_name=node_name
|
||||
)
|
||||
# Setting wide search limit based on the parameters
|
||||
non_global_search = node_name is None
|
||||
|
||||
wide_search_limit = wide_search_top_k if non_global_search else None
|
||||
|
||||
if collections is None:
|
||||
collections = [
|
||||
|
|
@ -140,7 +148,7 @@ async def brute_force_triplet_search(
|
|||
async def search_in_collection(collection_name: str):
|
||||
try:
|
||||
return await vector_engine.search(
|
||||
collection_name=collection_name, query_vector=query_vector, limit=None
|
||||
collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
|
||||
)
|
||||
except CollectionNotFoundError:
|
||||
return []
|
||||
|
|
@ -156,15 +164,38 @@ async def brute_force_triplet_search(
|
|||
return []
|
||||
|
||||
# Final statistics
|
||||
projection_time = time.time() - start_time
|
||||
vector_collection_search_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s"
|
||||
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
|
||||
)
|
||||
|
||||
node_distances = {collection: result for collection, result in zip(collections, results)}
|
||||
|
||||
edge_distances = node_distances.get("EdgeType_relationship_name", None)
|
||||
|
||||
if wide_search_limit is not None:
|
||||
relevant_ids_to_filter = list(
|
||||
{
|
||||
str(getattr(scored_node, "id"))
|
||||
for collection_name, score_collection in node_distances.items()
|
||||
if collection_name != "EdgeType_relationship_name"
|
||||
and isinstance(score_collection, (list, tuple))
|
||||
for scored_node in score_collection
|
||||
if getattr(scored_node, "id", None)
|
||||
}
|
||||
)
|
||||
else:
|
||||
relevant_ids_to_filter = None
|
||||
|
||||
if memory_fragment is None:
|
||||
memory_fragment = await get_memory_fragment(
|
||||
properties_to_project=properties_to_project,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
relevant_ids_to_filter=relevant_ids_to_filter,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
|
||||
await memory_fragment.map_vector_distances_to_graph_edges(
|
||||
vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
|||
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
||||
|
||||
|
||||
async def generate_structured_completion(
|
||||
async def generate_completion(
|
||||
query: str,
|
||||
context: str,
|
||||
user_prompt_path: str,
|
||||
|
|
@ -12,7 +12,7 @@ async def generate_structured_completion(
|
|||
conversation_history: Optional[str] = None,
|
||||
response_model: Type = str,
|
||||
) -> Any:
|
||||
"""Generates a structured completion using LLM with given context and prompts."""
|
||||
"""Generates a completion using LLM with given context and prompts."""
|
||||
args = {"question": query, "context": context}
|
||||
user_prompt = render_prompt(user_prompt_path, args)
|
||||
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
|
||||
|
|
@ -28,26 +28,6 @@ async def generate_structured_completion(
|
|||
)
|
||||
|
||||
|
||||
async def generate_completion(
|
||||
query: str,
|
||||
context: str,
|
||||
user_prompt_path: str,
|
||||
system_prompt_path: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
conversation_history: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Generates a completion using LLM with given context and prompts."""
|
||||
return await generate_structured_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=user_prompt_path,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
conversation_history=conversation_history,
|
||||
response_model=str,
|
||||
)
|
||||
|
||||
|
||||
async def summarize_text(
|
||||
text: str,
|
||||
system_prompt_path: str = "summarize_search_results.txt",
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue