Merge branch 'dev' into batch-document-handling
This commit is contained in:
commit
c1d633fb75
199 changed files with 16428 additions and 22653 deletions
|
|
@ -16,7 +16,7 @@
|
|||
STRUCTURED_OUTPUT_FRAMEWORK="instructor"
|
||||
|
||||
LLM_API_KEY="your_api_key"
|
||||
LLM_MODEL="openai/gpt-4o-mini"
|
||||
LLM_MODEL="openai/gpt-5-mini"
|
||||
LLM_PROVIDER="openai"
|
||||
LLM_ENDPOINT=""
|
||||
LLM_API_VERSION=""
|
||||
|
|
@ -30,10 +30,13 @@ EMBEDDING_DIMENSIONS=3072
|
|||
EMBEDDING_MAX_TOKENS=8191
|
||||
# If embedding key is not provided same key set for LLM_API_KEY will be used
|
||||
#EMBEDDING_API_KEY="your_api_key"
|
||||
# Note: OpenAI support up to 2048 elements and Gemini supports a maximum of 100 elements in an embedding batch,
|
||||
# Cognee sets the optimal batch size for OpenAI and Gemini, but a custom size can be defined if necessary for other models
|
||||
#EMBEDDING_BATCH_SIZE=2048
|
||||
|
||||
# If using BAML structured output these env variables will be used
|
||||
BAML_LLM_PROVIDER=openai
|
||||
BAML_LLM_MODEL="gpt-4o-mini"
|
||||
BAML_LLM_MODEL="gpt-5-mini"
|
||||
BAML_LLM_ENDPOINT=""
|
||||
BAML_LLM_API_KEY="your_api_key"
|
||||
BAML_LLM_API_VERSION=""
|
||||
|
|
@ -52,18 +55,18 @@ BAML_LLM_API_VERSION=""
|
|||
################################################################################
|
||||
# Configure storage backend (local filesystem or S3)
|
||||
# STORAGE_BACKEND="local" # Default: uses local filesystem
|
||||
#
|
||||
#
|
||||
# -- To switch to S3 storage, uncomment and fill these: ---------------------
|
||||
# STORAGE_BACKEND="s3"
|
||||
# STORAGE_BUCKET_NAME="your-bucket-name"
|
||||
# AWS_REGION="us-east-1"
|
||||
# AWS_ACCESS_KEY_ID="your-access-key"
|
||||
# AWS_SECRET_ACCESS_KEY="your-secret-key"
|
||||
#
|
||||
#
|
||||
# -- S3 Root Directories (optional) -----------------------------------------
|
||||
# DATA_ROOT_DIRECTORY="s3://your-bucket/cognee/data"
|
||||
# SYSTEM_ROOT_DIRECTORY="s3://your-bucket/cognee/system"
|
||||
#
|
||||
#
|
||||
# -- Cache Directory (auto-configured for S3) -------------------------------
|
||||
# When STORAGE_BACKEND=s3, cache automatically uses S3: s3://BUCKET/cognee/cache
|
||||
# To override the automatic S3 cache location, uncomment:
|
||||
|
|
@ -203,6 +206,16 @@ LITELLM_LOG="ERROR"
|
|||
# DEFAULT_USER_EMAIL=""
|
||||
# DEFAULT_USER_PASSWORD=""
|
||||
|
||||
################################################################################
|
||||
# 📂 AWS Settings
|
||||
################################################################################
|
||||
|
||||
#AWS_REGION=""
|
||||
#AWS_ENDPOINT_URL=""
|
||||
#AWS_ACCESS_KEY_ID=""
|
||||
#AWS_SECRET_ACCESS_KEY=""
|
||||
#AWS_SESSION_TOKEN=""
|
||||
|
||||
------------------------------- END OF POSSIBLE SETTINGS -------------------------------
|
||||
|
||||
|
||||
|
|
|
|||
2
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
2
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
|
|
@ -58,7 +58,7 @@ body:
|
|||
- Python version: [e.g. 3.9.0]
|
||||
- Cognee version: [e.g. 0.1.0]
|
||||
- LLM Provider: [e.g. OpenAI, Ollama]
|
||||
- Database: [e.g. Neo4j, FalkorDB]
|
||||
- Database: [e.g. Neo4j]
|
||||
validations:
|
||||
required: true
|
||||
|
||||
|
|
|
|||
2
.github/actions/cognee_setup/action.yml
vendored
2
.github/actions/cognee_setup/action.yml
vendored
|
|
@ -41,4 +41,4 @@ runs:
|
|||
EXTRA_ARGS="$EXTRA_ARGS --extra $extra"
|
||||
done
|
||||
fi
|
||||
uv sync --extra api --extra docs --extra evals --extra gemini --extra codegraph --extra ollama --extra dev --extra neo4j $EXTRA_ARGS
|
||||
uv sync --extra api --extra docs --extra evals --extra codegraph --extra ollama --extra dev --extra neo4j $EXTRA_ARGS
|
||||
|
|
|
|||
19
.github/pull_request_template.md
vendored
19
.github/pull_request_template.md
vendored
|
|
@ -1,8 +1,8 @@
|
|||
<!-- .github/pull_request_template.md -->
|
||||
|
||||
## Description
|
||||
<!--
|
||||
Please provide a clear, human-generated description of the changes in this PR.
|
||||
<!--
|
||||
Please provide a clear, human-generated description of the changes in this PR.
|
||||
DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning.
|
||||
-->
|
||||
|
||||
|
|
@ -16,15 +16,6 @@ DO NOT use AI-generated descriptions. We want to understand your thought process
|
|||
- [ ] Performance improvement
|
||||
- [ ] Other (please specify):
|
||||
|
||||
## Changes Made
|
||||
<!-- List the specific changes made in this PR -->
|
||||
-
|
||||
-
|
||||
-
|
||||
|
||||
## Testing
|
||||
<!-- Describe how you tested your changes -->
|
||||
|
||||
## Screenshots/Videos (if applicable)
|
||||
<!-- Add screenshots or videos to help explain your changes -->
|
||||
|
||||
|
|
@ -40,11 +31,5 @@ DO NOT use AI-generated descriptions. We want to understand your thought process
|
|||
- [ ] I have linked any relevant issues in the description
|
||||
- [ ] My commits have clear and descriptive messages
|
||||
|
||||
## Related Issues
|
||||
<!-- Link any related issues using "Fixes #issue_number" or "Relates to #issue_number" -->
|
||||
|
||||
## Additional Notes
|
||||
<!-- Add any additional notes, concerns, or context for reviewers -->
|
||||
|
||||
## DCO Affirmation
|
||||
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
|
||||
|
|
|
|||
10
.github/workflows/db_examples_tests.yml
vendored
10
.github/workflows/db_examples_tests.yml
vendored
|
|
@ -54,6 +54,10 @@ jobs:
|
|||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- name: Setup Neo4j with GDS
|
||||
uses: ./.github/actions/setup_neo4j
|
||||
id: neo4j
|
||||
|
||||
- name: Run Neo4j Example
|
||||
env:
|
||||
ENV: dev
|
||||
|
|
@ -66,9 +70,9 @@ jobs:
|
|||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPH_DATABASE_PROVIDER: "neo4j"
|
||||
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
|
||||
GRAPH_DATABASE_USERNAME: "neo4j"
|
||||
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
|
||||
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
|
||||
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
|
||||
run: |
|
||||
uv run python examples/database_examples/neo4j_example.py
|
||||
|
||||
|
|
|
|||
73
.github/workflows/distributed_test.yml
vendored
Normal file
73
.github/workflows/distributed_test.yml
vendored
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
name: Distributed Cognee test with modal
|
||||
permissions:
|
||||
contents: read
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
python-version:
|
||||
required: false
|
||||
type: string
|
||||
default: '3.11.x'
|
||||
secrets:
|
||||
LLM_MODEL:
|
||||
required: true
|
||||
LLM_ENDPOINT:
|
||||
required: true
|
||||
LLM_API_KEY:
|
||||
required: true
|
||||
LLM_API_VERSION:
|
||||
required: true
|
||||
EMBEDDING_MODEL:
|
||||
required: true
|
||||
EMBEDDING_ENDPOINT:
|
||||
required: true
|
||||
EMBEDDING_API_KEY:
|
||||
required: true
|
||||
EMBEDDING_API_VERSION:
|
||||
required: true
|
||||
OPENAI_API_KEY:
|
||||
required: true
|
||||
|
||||
jobs:
|
||||
run-server-start-test:
|
||||
name: Distributed Cognee test (Modal)
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
extra-dependencies: "distributed postgres"
|
||||
|
||||
- name: Run Distributed Cognee (Modal)
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
MODAL_SECRET_NAME: ${{ secrets.MODAL_SECRET_NAME }}
|
||||
GRAPH_DATABASE_PROVIDER: "neo4j"
|
||||
GRAPH_DATABASE_URL: ${{ secrets.AZURE_NEO4j_URL }}
|
||||
GRAPH_DATABASE_USERNAME: ${{ secrets.AZURE_NEO4J_USERNAME }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ secrets.AZURE_NEO4J_PW }}
|
||||
DB_PROVIDER: "postgres"
|
||||
DB_NAME: ${{ secrets.AZURE_POSTGRES_DB_NAME }}
|
||||
DB_HOST: ${{ secrets.AZURE_POSTGRES_HOST }}
|
||||
DB_PORT: ${{ secrets.AZURE_POSTGRES_PORT }}
|
||||
DB_USERNAME: ${{ secrets.AZURE_POSTGRES_USERNAME }}
|
||||
DB_PASSWORD: ${{ secrets.AZURE_POSTGRES_PW }}
|
||||
VECTOR_DB_PROVIDER: "pgvector"
|
||||
COGNEE_DISTRIBUTED: "true"
|
||||
run: uv run modal run ./distributed/entrypoint.py
|
||||
53
.github/workflows/examples_tests.yml
vendored
53
.github/workflows/examples_tests.yml
vendored
|
|
@ -1,5 +1,8 @@
|
|||
name: Reusable Examples Tests
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
|
||||
|
|
@ -131,3 +134,53 @@ jobs:
|
|||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./examples/python/memify_coding_agent_example.py
|
||||
|
||||
test-permissions-example:
|
||||
name: Run Permissions Example
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
|
||||
- name: Run Memify Tests
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./examples/python/permissions_example.py
|
||||
test_docling_add:
|
||||
name: Run Add with Docling Test
|
||||
runs-on: macos-15
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
extra-dependencies: 'docling'
|
||||
|
||||
- name: Run Docling Test
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_add_docling_document.py
|
||||
|
|
|
|||
10
.github/workflows/graph_db_tests.yml
vendored
10
.github/workflows/graph_db_tests.yml
vendored
|
|
@ -71,6 +71,10 @@ jobs:
|
|||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- name: Setup Neo4j with GDS
|
||||
uses: ./.github/actions/setup_neo4j
|
||||
id: neo4j
|
||||
|
||||
- name: Run default Neo4j
|
||||
env:
|
||||
ENV: 'dev'
|
||||
|
|
@ -83,9 +87,9 @@ jobs:
|
|||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPH_DATABASE_PROVIDER: "neo4j"
|
||||
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
|
||||
GRAPH_DATABASE_USERNAME: "neo4j"
|
||||
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
|
||||
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
|
||||
run: uv run python ./cognee/tests/test_neo4j.py
|
||||
|
||||
- name: Run Weighted Edges Tests with Neo4j
|
||||
|
|
|
|||
|
|
@ -186,6 +186,10 @@ jobs:
|
|||
python-version: '3.11.x'
|
||||
extra-dependencies: "postgres"
|
||||
|
||||
- name: Setup Neo4j with GDS
|
||||
uses: ./.github/actions/setup_neo4j
|
||||
id: neo4j
|
||||
|
||||
- name: Install specific db dependency
|
||||
run: echo "Dependencies already installed in setup"
|
||||
|
||||
|
|
@ -206,9 +210,9 @@ jobs:
|
|||
env:
|
||||
ENV: 'dev'
|
||||
GRAPH_DATABASE_PROVIDER: "neo4j"
|
||||
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
|
||||
GRAPH_DATABASE_USERNAME: "neo4j"
|
||||
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
|
||||
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
|
||||
|
||||
LLM_PROVIDER: openai
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
|
|
|
|||
47
.github/workflows/search_db_tests.yml
vendored
47
.github/workflows/search_db_tests.yml
vendored
|
|
@ -51,20 +51,6 @@ jobs:
|
|||
name: Search test for Neo4j/LanceDB/Sqlite
|
||||
runs-on: ubuntu-22.04
|
||||
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }}
|
||||
services:
|
||||
neo4j:
|
||||
image: neo4j:5.11
|
||||
env:
|
||||
NEO4J_AUTH: neo4j/pleaseletmein
|
||||
NEO4J_PLUGINS: '["apoc","graph-data-science"]'
|
||||
ports:
|
||||
- 7474:7474
|
||||
- 7687:7687
|
||||
options: >-
|
||||
--health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
|
||||
--health-interval=10s
|
||||
--health-timeout=5s
|
||||
--health-retries=5
|
||||
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -77,6 +63,10 @@ jobs:
|
|||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- name: Setup Neo4j with GDS
|
||||
uses: ./.github/actions/setup_neo4j
|
||||
id: neo4j
|
||||
|
||||
- name: Dependencies already installed
|
||||
run: echo "Dependencies already installed in setup"
|
||||
|
||||
|
|
@ -94,9 +84,9 @@ jobs:
|
|||
GRAPH_DATABASE_PROVIDER: 'neo4j'
|
||||
VECTOR_DB_PROVIDER: 'lancedb'
|
||||
DB_PROVIDER: 'sqlite'
|
||||
GRAPH_DATABASE_URL: bolt://localhost:7687
|
||||
GRAPH_DATABASE_USERNAME: neo4j
|
||||
GRAPH_DATABASE_PASSWORD: pleaseletmein
|
||||
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
|
||||
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
|
||||
run: uv run python ./cognee/tests/test_search_db.py
|
||||
|
||||
run-kuzu-pgvector-postgres-search-tests:
|
||||
|
|
@ -158,19 +148,6 @@ jobs:
|
|||
runs-on: ubuntu-22.04
|
||||
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/pgvector/postgres') }}
|
||||
services:
|
||||
neo4j:
|
||||
image: neo4j:5.11
|
||||
env:
|
||||
NEO4J_AUTH: neo4j/pleaseletmein
|
||||
NEO4J_PLUGINS: '["apoc","graph-data-science"]'
|
||||
ports:
|
||||
- 7474:7474
|
||||
- 7687:7687
|
||||
options: >-
|
||||
--health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
|
||||
--health-interval=10s
|
||||
--health-timeout=5s
|
||||
--health-retries=5
|
||||
postgres:
|
||||
image: pgvector/pgvector:pg17
|
||||
env:
|
||||
|
|
@ -196,6 +173,10 @@ jobs:
|
|||
python-version: ${{ inputs.python-version }}
|
||||
extra-dependencies: "postgres"
|
||||
|
||||
- name: Setup Neo4j with GDS
|
||||
uses: ./.github/actions/setup_neo4j
|
||||
id: neo4j
|
||||
|
||||
- name: Dependencies already installed
|
||||
run: echo "Dependencies already installed in setup"
|
||||
|
||||
|
|
@ -213,9 +194,9 @@ jobs:
|
|||
GRAPH_DATABASE_PROVIDER: 'neo4j'
|
||||
VECTOR_DB_PROVIDER: 'pgvector'
|
||||
DB_PROVIDER: 'postgres'
|
||||
GRAPH_DATABASE_URL: bolt://localhost:7687
|
||||
GRAPH_DATABASE_USERNAME: neo4j
|
||||
GRAPH_DATABASE_PASSWORD: pleaseletmein
|
||||
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
|
||||
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
|
||||
DB_NAME: cognee_db
|
||||
DB_HOST: 127.0.0.1
|
||||
DB_PORT: 5432
|
||||
|
|
|
|||
24
.github/workflows/temporal_graph_tests.yml
vendored
24
.github/workflows/temporal_graph_tests.yml
vendored
|
|
@ -51,20 +51,6 @@ jobs:
|
|||
name: Temporal Graph test Neo4j (lancedb + sqlite)
|
||||
runs-on: ubuntu-22.04
|
||||
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }}
|
||||
services:
|
||||
neo4j:
|
||||
image: neo4j:5.11
|
||||
env:
|
||||
NEO4J_AUTH: neo4j/pleaseletmein
|
||||
NEO4J_PLUGINS: '["apoc","graph-data-science"]'
|
||||
ports:
|
||||
- 7474:7474
|
||||
- 7687:7687
|
||||
options: >-
|
||||
--health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
|
||||
--health-interval=10s
|
||||
--health-timeout=5s
|
||||
--health-retries=5
|
||||
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -77,6 +63,10 @@ jobs:
|
|||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- name: Setup Neo4j with GDS
|
||||
uses: ./.github/actions/setup_neo4j
|
||||
id: neo4j
|
||||
|
||||
- name: Dependencies already installed
|
||||
run: echo "Dependencies already installed in setup"
|
||||
|
||||
|
|
@ -94,9 +84,9 @@ jobs:
|
|||
GRAPH_DATABASE_PROVIDER: 'neo4j'
|
||||
VECTOR_DB_PROVIDER: 'lancedb'
|
||||
DB_PROVIDER: 'sqlite'
|
||||
GRAPH_DATABASE_URL: bolt://localhost:7687
|
||||
GRAPH_DATABASE_USERNAME: neo4j
|
||||
GRAPH_DATABASE_PASSWORD: pleaseletmein
|
||||
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
|
||||
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
|
||||
run: uv run python ./cognee/tests/test_temporal_graph.py
|
||||
|
||||
run_temporal_graph_kuzu_postgres_pgvector:
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ubuntu-22.04, macos-13, macos-15, windows-latest]
|
||||
os: [ubuntu-22.04, macos-15, windows-latest]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -79,7 +79,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -115,7 +115,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -151,7 +151,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -180,7 +180,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -210,7 +210,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
|
|||
4
.github/workflows/test_llms.yml
vendored
4
.github/workflows/test_llms.yml
vendored
|
|
@ -27,7 +27,7 @@ jobs:
|
|||
env:
|
||||
LLM_PROVIDER: "gemini"
|
||||
LLM_API_KEY: ${{ secrets.GEMINI_API_KEY }}
|
||||
LLM_MODEL: "gemini/gemini-1.5-flash"
|
||||
LLM_MODEL: "gemini/gemini-2.0-flash"
|
||||
EMBEDDING_PROVIDER: "gemini"
|
||||
EMBEDDING_API_KEY: ${{ secrets.GEMINI_API_KEY }}
|
||||
EMBEDDING_MODEL: "gemini/text-embedding-004"
|
||||
|
|
@ -83,4 +83,4 @@ jobs:
|
|||
EMBEDDING_MODEL: "openai/text-embedding-3-large"
|
||||
EMBEDDING_DIMENSIONS: "3072"
|
||||
EMBEDDING_MAX_TOKENS: "8191"
|
||||
run: uv run python ./examples/python/simple_example.py
|
||||
run: uv run python ./examples/python/simple_example.py
|
||||
|
|
|
|||
6
.github/workflows/test_s3_file_storage.yml
vendored
6
.github/workflows/test_s3_file_storage.yml
vendored
|
|
@ -6,8 +6,12 @@ on:
|
|||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
RUNTIME__LOG_LEVEL: ERROR
|
||||
ENV: 'dev'
|
||||
|
||||
jobs:
|
||||
test-gemini:
|
||||
test-s3-storage:
|
||||
name: Run S3 File Storage Test
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
|
|
|
|||
8
.github/workflows/test_suites.yml
vendored
8
.github/workflows/test_suites.yml
vendored
|
|
@ -27,6 +27,12 @@ jobs:
|
|||
uses: ./.github/workflows/e2e_tests.yml
|
||||
secrets: inherit
|
||||
|
||||
distributed-tests:
|
||||
name: Distributed Cognee Test
|
||||
needs: [ basic-tests, e2e-tests, graph-db-tests ]
|
||||
uses: ./.github/workflows/distributed_test.yml
|
||||
secrets: inherit
|
||||
|
||||
cli-tests:
|
||||
name: CLI Tests
|
||||
uses: ./.github/workflows/cli_tests.yml
|
||||
|
|
@ -104,7 +110,7 @@ jobs:
|
|||
|
||||
db-examples-tests:
|
||||
name: DB Examples Tests
|
||||
needs: [vector-db-tests, graph-db-tests, relational-db-migration-tests]
|
||||
needs: [vector-db-tests, graph-db-tests, relational-db-migration-tests, distributed-tests]
|
||||
uses: ./.github/workflows/db_examples_tests.yml
|
||||
secrets: inherit
|
||||
|
||||
|
|
|
|||
27
.github/workflows/vector_db_tests.yml
vendored
27
.github/workflows/vector_db_tests.yml
vendored
|
|
@ -101,3 +101,30 @@ jobs:
|
|||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_pgvector.py
|
||||
|
||||
run-lancedb-tests:
|
||||
name: LanceDB Tests
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- name: Run LanceDB Tests
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_lancedb.py
|
||||
7
.github/workflows/weighted_edges_tests.yml
vendored
7
.github/workflows/weighted_edges_tests.yml
vendored
|
|
@ -86,12 +86,19 @@ jobs:
|
|||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Setup Neo4j with GDS
|
||||
uses: ./.github/actions/setup_neo4j
|
||||
id: neo4j
|
||||
|
||||
- name: Dependencies already installed
|
||||
run: echo "Dependencies already installed in setup"
|
||||
|
||||
- name: Run Weighted Edges Tests
|
||||
env:
|
||||
GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }}
|
||||
GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }}
|
||||
GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-password || '' }}
|
||||
run: |
|
||||
uv run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v --tb=short
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ RUN apt-get update && apt-get install -y \
|
|||
libpq-dev \
|
||||
git \
|
||||
curl \
|
||||
cmake \
|
||||
clang \
|
||||
build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
|
@ -31,7 +32,7 @@ COPY README.md pyproject.toml uv.lock entrypoint.sh ./
|
|||
|
||||
# Install the project's dependencies using the lockfile and settings
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra gemini --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-install-project --no-dev --no-editable
|
||||
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-install-project --no-dev --no-editable
|
||||
|
||||
# Copy Alembic configuration
|
||||
COPY alembic.ini /app/alembic.ini
|
||||
|
|
@ -42,7 +43,7 @@ COPY alembic/ /app/alembic
|
|||
COPY ./cognee /app/cognee
|
||||
COPY ./distributed /app/distributed
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra gemini --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-dev --no-editable
|
||||
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-dev --no-editable
|
||||
|
||||
FROM python:3.12-slim-bookworm
|
||||
|
||||
|
|
|
|||
133
README.md
133
README.md
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
<br />
|
||||
|
||||
cognee - Memory for AI Agents in 5 lines of code
|
||||
cognee - Memory for AI Agents in 6 lines of code
|
||||
|
||||
<p align="center">
|
||||
<a href="https://www.youtube.com/watch?v=1bezuvLwJmw&t=2s">Demo</a>
|
||||
|
|
@ -43,12 +43,10 @@
|
|||
|
||||
|
||||
|
||||
**🚀 We launched Cogwit beta (Fully-hosted AI Memory): Sign up [here](https://platform.cognee.ai/)! 🚀**
|
||||
|
||||
|
||||
Build dynamic memory for Agents and replace RAG using scalable, modular ECL (Extract, Cognify, Load) pipelines.
|
||||
|
||||
More on [use-cases](https://docs.cognee.ai/use-cases) and [evals](https://github.com/topoteretes/cognee/tree/main/evals)
|
||||
|
||||
<p align="center">
|
||||
🌐 Available Languages
|
||||
:
|
||||
|
|
@ -70,53 +68,50 @@ More on [use-cases](https://docs.cognee.ai/use-cases) and [evals](https://github
|
|||
</div>
|
||||
|
||||
|
||||
## Features
|
||||
|
||||
- Interconnect and retrieve your past conversations, documents, images and audio transcriptions
|
||||
- Replaces RAG systems and reduces developer effort, and cost.
|
||||
- Load data to graph and vector databases using only Pydantic
|
||||
- Manipulate your data while ingesting from 30+ data sources
|
||||
|
||||
## Get Started
|
||||
|
||||
Get started quickly with a Google Colab <a href="https://colab.research.google.com/drive/1jHbWVypDgCLwjE71GSXhRL3YxYhCZzG1?usp=sharing">notebook</a> , <a href="https://deepnote.com/workspace/cognee-382213d0-0444-4c89-8265-13770e333c02/project/cognee-demo-78ffacb9-5832-4611-bb1a-560386068b30/notebook/Notebook-1-75b24cda566d4c24ab348f7150792601?utm_source=share-modal&utm_medium=product-shared-content&utm_campaign=notebook&utm_content=78ffacb9-5832-4611-bb1a-560386068b30">Deepnote notebook</a> or <a href="https://github.com/topoteretes/cognee/tree/main/cognee-starter-kit">starter repo</a>
|
||||
|
||||
|
||||
## About cognee
|
||||
|
||||
cognee works locally and stores your data on your device.
|
||||
Our hosted solution is just our deployment of OSS cognee on Modal, with the goal of making development and productionization easier.
|
||||
|
||||
## Contributing
|
||||
Your contributions are at the core of making this a true open source project. Any contributions you make are **greatly appreciated**. See [`CONTRIBUTING.md`](CONTRIBUTING.md) for more information.
|
||||
Self-hosted package:
|
||||
|
||||
- Interconnects any kind of documents: past conversations, files, images, and audio transcriptions
|
||||
- Replaces RAG systems with a memory layer based on graphs and vectors
|
||||
- Reduces developer effort and cost, while increasing quality and precision
|
||||
- Provides Pythonic data pipelines that manage data ingestion from 30+ data sources
|
||||
- Is highly customizable with custom tasks, pipelines, and a set of built-in search endpoints
|
||||
|
||||
Hosted platform:
|
||||
- Includes a managed UI and a [hosted solution](https://www.cognee.ai)
|
||||
|
||||
|
||||
|
||||
## Self-Hosted (Open Source)
|
||||
|
||||
|
||||
## 📦 Installation
|
||||
### 📦 Installation
|
||||
|
||||
You can install Cognee using either **pip**, **poetry**, **uv** or any other python package manager.
|
||||
|
||||
Cognee supports Python 3.10 to 3.13
|
||||
Cognee supports Python 3.10 to 3.12
|
||||
|
||||
### With pip
|
||||
#### With uv
|
||||
|
||||
```bash
|
||||
pip install cognee
|
||||
uv pip install cognee
|
||||
```
|
||||
|
||||
## Local Cognee installation
|
||||
Detailed instructions can be found in our [docs](https://docs.cognee.ai/getting-started/installation#environment-configuration)
|
||||
|
||||
You can install the local Cognee repo using **uv**, **pip** and **poetry**.
|
||||
For local pip installation please make sure your pip version is above version 21.3.
|
||||
### 💻 Basic Usage
|
||||
|
||||
### with UV with all optional dependencies
|
||||
|
||||
```bash
|
||||
uv sync --all-extras
|
||||
```
|
||||
|
||||
## 💻 Basic Usage
|
||||
|
||||
### Setup
|
||||
#### Setup
|
||||
|
||||
```
|
||||
import os
|
||||
|
|
@ -125,10 +120,14 @@ os.environ["LLM_API_KEY"] = "YOUR OPENAI_API_KEY"
|
|||
```
|
||||
|
||||
You can also set the variables by creating .env file, using our <a href="https://github.com/topoteretes/cognee/blob/main/.env.template">template.</a>
|
||||
To use different LLM providers, for more info check out our <a href="https://docs.cognee.ai">documentation</a>
|
||||
To use different LLM providers, for more info check out our <a href="https://docs.cognee.ai/setup-configuration/llm-providers">documentation</a>
|
||||
|
||||
|
||||
### Simple example
|
||||
#### Simple example
|
||||
|
||||
|
||||
|
||||
##### Python
|
||||
|
||||
This script will run the default pipeline:
|
||||
|
||||
|
|
@ -139,13 +138,16 @@ import asyncio
|
|||
|
||||
async def main():
|
||||
# Add text to cognee
|
||||
await cognee.add("Natural language processing (NLP) is an interdisciplinary subfield of computer science and information retrieval.")
|
||||
await cognee.add("Cognee turns documents into AI memory.")
|
||||
|
||||
# Generate the knowledge graph
|
||||
await cognee.cognify()
|
||||
|
||||
# Add memory algorithms to the graph
|
||||
await cognee.memify()
|
||||
|
||||
# Query the knowledge graph
|
||||
results = await cognee.search("Tell me about NLP")
|
||||
results = await cognee.search("What does cognee do?")
|
||||
|
||||
# Display the results
|
||||
for result in results:
|
||||
|
|
@ -158,33 +160,38 @@ if __name__ == '__main__':
|
|||
```
|
||||
Example output:
|
||||
```
|
||||
Natural Language Processing (NLP) is a cross-disciplinary and interdisciplinary field that involves computer science and information retrieval. It focuses on the interaction between computers and human language, enabling machines to understand and process natural language.
|
||||
Cognee turns documents into AI memory.
|
||||
|
||||
```
|
||||
##### Via CLI
|
||||
|
||||
## Our paper is out! <a href="https://arxiv.org/abs/2505.24478" target="_blank" rel="noopener noreferrer">Read here</a>
|
||||
Let's get the basics covered
|
||||
|
||||
```
|
||||
cognee-cli add "Cognee turns documents into AI memory."
|
||||
|
||||
cognee-cli cognify
|
||||
|
||||
cognee-cli search "What does cognee do?"
|
||||
cognee-cli delete --all
|
||||
|
||||
```
|
||||
or run
|
||||
```
|
||||
cognee-cli -ui
|
||||
```
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="assets/cognee-paper.png" alt="cognee paper" width="100%" />
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
## Cognee UI
|
||||
|
||||
You can also cognify your files and query using cognee UI.
|
||||
### Hosted Platform
|
||||
|
||||
<img src="assets/cognee-new-ui.webp" width="100%" alt="Cognee UI 2"></a>
|
||||
Get up and running in minutes with automatic updates, analytics, and enterprise security.
|
||||
|
||||
### Running the UI
|
||||
1. Sign up on [cogwit](https://www.cognee.ai)
|
||||
2. Add your API key to local UI and sync your data to Cogwit
|
||||
|
||||
Try cognee UI by setting LLM_API_KEY and running ``` cognee-cli -ui ``` command on your terminal.
|
||||
|
||||
## Understand our architecture
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="assets/cognee_diagram.png" alt="cognee concept diagram" width="100%" />
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
|
|
@ -203,22 +210,26 @@ Try cognee UI by setting LLM_API_KEY and running ``` cognee-cli -ui ``` command
|
|||
[cognee with local models](https://github.com/user-attachments/assets/8621d3e8-ecb8-4860-afb2-5594f2ee17db)
|
||||
|
||||
|
||||
## Contributing
|
||||
Your contributions are at the core of making this a true open source project. Any contributions you make are **greatly appreciated**. See [`CONTRIBUTING.md`](CONTRIBUTING.md) for more information.
|
||||
|
||||
|
||||
## Code of Conduct
|
||||
|
||||
We are committed to making open source an enjoyable and respectful experience for our community. See <a href="https://github.com/topoteretes/cognee/blob/main/CODE_OF_CONDUCT.md"><code>CODE_OF_CONDUCT</code></a> for more information.
|
||||
|
||||
## 💫 Contributors
|
||||
## Citation
|
||||
|
||||
<a href="https://github.com/topoteretes/cognee/graphs/contributors">
|
||||
<img alt="contributors" src="https://contrib.rocks/image?repo=topoteretes/cognee"/>
|
||||
</a>
|
||||
We now have a paper you can cite:
|
||||
|
||||
## Sponsors
|
||||
|
||||
Thanks to the following companies for sponsoring the ongoing development of cognee.
|
||||
|
||||
- [GitHub's Secure Open Source Fund](https://resources.github.com/github-secure-open-source-fund/)
|
||||
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#topoteretes/cognee&Date)
|
||||
```bibtex
|
||||
@misc{markovic2025optimizinginterfaceknowledgegraphs,
|
||||
title={Optimizing the Interface Between Knowledge Graphs and LLMs for Complex Reasoning},
|
||||
author={Vasilije Markovic and Lazar Obradovic and Laszlo Hajdu and Jovan Pavlovic},
|
||||
year={2025},
|
||||
eprint={2505.24478},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.AI},
|
||||
url={https://arxiv.org/abs/2505.24478},
|
||||
}
|
||||
```
|
||||
|
|
|
|||
|
|
@ -3,10 +3,18 @@
|
|||
import classNames from "classnames";
|
||||
import { MutableRefObject, useEffect, useImperativeHandle, useRef, useState, useCallback } from "react";
|
||||
import { forceCollide, forceManyBody } from "d3-force-3d";
|
||||
import ForceGraph, { ForceGraphMethods, GraphData, LinkObject, NodeObject } from "react-force-graph-2d";
|
||||
import dynamic from "next/dynamic";
|
||||
import { GraphControlsAPI } from "./GraphControls";
|
||||
import getColorForNodeType from "./getColorForNodeType";
|
||||
|
||||
// Dynamically import ForceGraph to prevent SSR issues
|
||||
const ForceGraph = dynamic(() => import("react-force-graph-2d"), {
|
||||
ssr: false,
|
||||
loading: () => <div className="w-full h-full flex items-center justify-center">Loading graph...</div>
|
||||
});
|
||||
|
||||
import type { ForceGraphMethods, GraphData, LinkObject, NodeObject } from "react-force-graph-2d";
|
||||
|
||||
interface GraphVisuzaliationProps {
|
||||
ref: MutableRefObject<GraphVisualizationAPI>;
|
||||
data?: GraphData<NodeObject, LinkObject>;
|
||||
|
|
@ -200,7 +208,7 @@ export default function GraphVisualization({ ref, data, graphControls, className
|
|||
const graphRef = useRef<ForceGraphMethods>();
|
||||
|
||||
useEffect(() => {
|
||||
if (typeof window !== "undefined" && data && graphRef.current) {
|
||||
if (data && graphRef.current) {
|
||||
// add collision force
|
||||
graphRef.current.d3Force("collision", forceCollide(nodeSize * 1.5));
|
||||
graphRef.current.d3Force("charge", forceManyBody().strength(-10).distanceMin(10).distanceMax(50));
|
||||
|
|
@ -209,63 +217,55 @@ export default function GraphVisualization({ ref, data, graphControls, className
|
|||
|
||||
const [graphShape, setGraphShape] = useState<string>();
|
||||
|
||||
const zoomToFit: ForceGraphMethods["zoomToFit"] = (
|
||||
durationMs?: number,
|
||||
padding?: number,
|
||||
nodeFilter?: (node: NodeObject) => boolean
|
||||
) => {
|
||||
if (!graphRef.current) {
|
||||
console.warn("GraphVisualization: graphRef not ready yet");
|
||||
return undefined as any;
|
||||
}
|
||||
|
||||
return graphRef.current.zoomToFit?.(durationMs, padding, nodeFilter);
|
||||
};
|
||||
|
||||
useImperativeHandle(ref, () => ({
|
||||
zoomToFit: graphRef.current!.zoomToFit,
|
||||
setGraphShape: setGraphShape,
|
||||
zoomToFit,
|
||||
setGraphShape,
|
||||
}));
|
||||
|
||||
|
||||
return (
|
||||
<div ref={containerRef} className={classNames("w-full h-full", className)} id="graph-container">
|
||||
{(data && typeof window !== "undefined") ? (
|
||||
<ForceGraph
|
||||
ref={graphRef}
|
||||
width={dimensions.width}
|
||||
height={dimensions.height}
|
||||
dagMode={graphShape as unknown as undefined}
|
||||
dagLevelDistance={300}
|
||||
onDagError={handleDagError}
|
||||
graphData={data}
|
||||
<ForceGraph
|
||||
ref={graphRef}
|
||||
width={dimensions.width}
|
||||
height={dimensions.height}
|
||||
dagMode={graphShape as unknown as undefined}
|
||||
dagLevelDistance={data ? 300 : 100}
|
||||
onDagError={handleDagError}
|
||||
graphData={data || {
|
||||
nodes: [{ id: 1, label: "Add" }, { id: 2, label: "Cognify" }, { id: 3, label: "Search" }],
|
||||
links: [{ source: 1, target: 2, label: "but don't forget to" }, { source: 2, target: 3, label: "and after that you can" }],
|
||||
}}
|
||||
|
||||
nodeLabel="label"
|
||||
nodeRelSize={nodeSize}
|
||||
nodeCanvasObject={renderNode}
|
||||
nodeCanvasObjectMode={() => "replace"}
|
||||
nodeLabel="label"
|
||||
nodeRelSize={data ? nodeSize : 20}
|
||||
nodeCanvasObject={data ? renderNode : renderInitialNode}
|
||||
nodeCanvasObjectMode={() => data ? "replace" : "after"}
|
||||
nodeAutoColorBy={data ? undefined : "type"}
|
||||
|
||||
linkLabel="label"
|
||||
linkCanvasObject={renderLink}
|
||||
linkCanvasObjectMode={() => "after"}
|
||||
linkDirectionalArrowLength={3.5}
|
||||
linkDirectionalArrowRelPos={1}
|
||||
linkLabel="label"
|
||||
linkCanvasObject={renderLink}
|
||||
linkCanvasObjectMode={() => "after"}
|
||||
linkDirectionalArrowLength={3.5}
|
||||
linkDirectionalArrowRelPos={1}
|
||||
|
||||
onNodeClick={handleNodeClick}
|
||||
onBackgroundClick={handleBackgroundClick}
|
||||
d3VelocityDecay={0.3}
|
||||
/>
|
||||
) : (
|
||||
<ForceGraph
|
||||
ref={graphRef}
|
||||
width={dimensions.width}
|
||||
height={dimensions.height}
|
||||
dagMode={graphShape as unknown as undefined}
|
||||
dagLevelDistance={100}
|
||||
graphData={{
|
||||
nodes: [{ id: 1, label: "Add" }, { id: 2, label: "Cognify" }, { id: 3, label: "Search" }],
|
||||
links: [{ source: 1, target: 2, label: "but don't forget to" }, { source: 2, target: 3, label: "and after that you can" }],
|
||||
}}
|
||||
|
||||
nodeLabel="label"
|
||||
nodeRelSize={20}
|
||||
nodeCanvasObject={renderInitialNode}
|
||||
nodeCanvasObjectMode={() => "after"}
|
||||
nodeAutoColorBy="type"
|
||||
|
||||
linkLabel="label"
|
||||
linkCanvasObject={renderLink}
|
||||
linkCanvasObjectMode={() => "after"}
|
||||
linkDirectionalArrowLength={3.5}
|
||||
linkDirectionalArrowRelPos={1}
|
||||
/>
|
||||
)}
|
||||
onNodeClick={handleNodeClick}
|
||||
onBackgroundClick={handleBackgroundClick}
|
||||
d3VelocityDecay={data ? 0.3 : undefined}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -89,15 +89,6 @@ export default function useChat(dataset: Dataset) {
|
|||
}
|
||||
|
||||
|
||||
interface Node {
|
||||
name: string;
|
||||
}
|
||||
|
||||
interface Relationship {
|
||||
relationship_name: string;
|
||||
}
|
||||
|
||||
type InsightMessage = [Node, Relationship, Node];
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
function convertToSearchTypeOutput(systemMessage: any[] | any, searchType: string): string {
|
||||
|
|
@ -106,14 +97,6 @@ function convertToSearchTypeOutput(systemMessage: any[] | any, searchType: strin
|
|||
}
|
||||
|
||||
switch (searchType) {
|
||||
case "INSIGHTS":
|
||||
return systemMessage.map((message: InsightMessage) => {
|
||||
const [node1, relationship, node2] = message;
|
||||
if (node1.name && node2.name) {
|
||||
return `${node1.name} ${relationship.relationship_name} ${node2.name}.`;
|
||||
}
|
||||
return "";
|
||||
}).join("\n");
|
||||
case "SUMMARIES":
|
||||
return systemMessage.map((message: { text: string }) => message.text).join("\n");
|
||||
case "CHUNKS":
|
||||
|
|
|
|||
|
|
@ -2,10 +2,11 @@
|
|||
|
||||
import Link from "next/link";
|
||||
import Image from "next/image";
|
||||
import { useBoolean } from "@/utils";
|
||||
import { useEffect } from "react";
|
||||
import { useBoolean, fetch } from "@/utils";
|
||||
|
||||
import { CloseIcon, CloudIcon, CogneeIcon } from "../Icons";
|
||||
import { CTAButton, GhostButton, IconButton, Modal } from "../elements";
|
||||
import { CTAButton, GhostButton, IconButton, Modal, StatusDot } from "../elements";
|
||||
import syncData from "@/modules/cloud/syncData";
|
||||
|
||||
interface HeaderProps {
|
||||
|
|
@ -23,6 +24,12 @@ export default function Header({ user }: HeaderProps) {
|
|||
setFalse: closeSyncModal,
|
||||
} = useBoolean(false);
|
||||
|
||||
const {
|
||||
value: isMCPConnected,
|
||||
setTrue: setMCPConnected,
|
||||
setFalse: setMCPDisconnected,
|
||||
} = useBoolean(false);
|
||||
|
||||
const handleDataSyncConfirm = () => {
|
||||
syncData()
|
||||
.finally(() => {
|
||||
|
|
@ -30,6 +37,19 @@ export default function Header({ user }: HeaderProps) {
|
|||
});
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const checkMCPConnection = () => {
|
||||
fetch.checkMCPHealth()
|
||||
.then(() => setMCPConnected())
|
||||
.catch(() => setMCPDisconnected());
|
||||
};
|
||||
|
||||
checkMCPConnection();
|
||||
const interval = setInterval(checkMCPConnection, 30000);
|
||||
|
||||
return () => clearInterval(interval);
|
||||
}, [setMCPConnected, setMCPDisconnected]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<header className="relative flex flex-row h-14 min-h-14 px-5 items-center justify-between w-full max-w-[1920px] mx-auto">
|
||||
|
|
@ -39,6 +59,10 @@ export default function Header({ user }: HeaderProps) {
|
|||
</div>
|
||||
|
||||
<div className="flex flex-row items-center gap-2.5">
|
||||
<Link href="/mcp-status" className="!text-indigo-600 pl-4 pr-4">
|
||||
<StatusDot className="mr-2" isActive={isMCPConnected} />
|
||||
{ isMCPConnected ? "MCP connected" : "MCP disconnected" }
|
||||
</Link>
|
||||
<GhostButton onClick={openSyncModal} className="text-indigo-600 gap-3 pl-4 pr-4">
|
||||
<CloudIcon />
|
||||
<div>Sync</div>
|
||||
|
|
|
|||
13
cognee-frontend/src/ui/elements/StatusDot.tsx
Normal file
13
cognee-frontend/src/ui/elements/StatusDot.tsx
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
import React from "react";
|
||||
|
||||
const StatusDot = ({ isActive, className }: { isActive: boolean, className?: string }) => {
|
||||
return (
|
||||
<span
|
||||
className={`inline-block w-3 h-3 rounded-full ${className} ${
|
||||
isActive ? "bg-green-500" : "bg-red-500"
|
||||
}`}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default StatusDot;
|
||||
|
|
@ -8,5 +8,6 @@ export { default as IconButton } from "./IconButton";
|
|||
export { default as GhostButton } from "./GhostButton";
|
||||
export { default as NeutralButton } from "./NeutralButton";
|
||||
export { default as StatusIndicator } from "./StatusIndicator";
|
||||
export { default as StatusDot } from "./StatusDot";
|
||||
export { default as Accordion } from "./Accordion";
|
||||
export { default as Notebook } from "./Notebook";
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ const backendApiUrl = process.env.NEXT_PUBLIC_BACKEND_API_URL || "http://localho
|
|||
|
||||
const cloudApiUrl = process.env.NEXT_PUBLIC_CLOUD_API_URL || "http://localhost:8001";
|
||||
|
||||
const mcpApiUrl = process.env.NEXT_PUBLIC_MCP_API_URL || "http://localhost:8001";
|
||||
|
||||
let apiKey: string | null = process.env.NEXT_PUBLIC_COGWIT_API_KEY || null;
|
||||
let accessToken: string | null = null;
|
||||
|
||||
|
|
@ -49,6 +51,13 @@ export default async function fetch(url: string, options: RequestInit = {}, useC
|
|||
)
|
||||
.then((response) => handleServerErrors(response, retry, useCloud))
|
||||
.catch((error) => {
|
||||
// Handle network errors more gracefully
|
||||
if (error.name === 'TypeError' && error.message.includes('fetch')) {
|
||||
return Promise.reject(
|
||||
new Error("Backend server is not responding. Please check if the server is running.")
|
||||
);
|
||||
}
|
||||
|
||||
if (error.detail === undefined) {
|
||||
return Promise.reject(
|
||||
new Error("No connection to the server.")
|
||||
|
|
@ -62,8 +71,31 @@ export default async function fetch(url: string, options: RequestInit = {}, useC
|
|||
});
|
||||
}
|
||||
|
||||
fetch.checkHealth = () => {
|
||||
return global.fetch(`${backendApiUrl.replace("/api", "")}/health`);
|
||||
fetch.checkHealth = async () => {
|
||||
const maxRetries = 5;
|
||||
const retryDelay = 1000; // 1 second
|
||||
|
||||
for (let i = 0; i < maxRetries; i++) {
|
||||
try {
|
||||
const response = await global.fetch(`${backendApiUrl.replace("/api", "")}/health`);
|
||||
if (response.ok) {
|
||||
return response;
|
||||
}
|
||||
} catch (error) {
|
||||
// If this is the last retry, throw the error
|
||||
if (i === maxRetries - 1) {
|
||||
throw error;
|
||||
}
|
||||
// Wait before retrying
|
||||
await new Promise(resolve => setTimeout(resolve, retryDelay));
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error("Backend server is not responding after multiple attempts");
|
||||
};
|
||||
|
||||
fetch.checkMCPHealth = () => {
|
||||
return global.fetch(`${mcpApiUrl.replace("/api", "")}/health`);
|
||||
};
|
||||
|
||||
fetch.setApiKey = (newApiKey: string) => {
|
||||
|
|
|
|||
153
cognee-gui.py
153
cognee-gui.py
|
|
@ -1,153 +0,0 @@
|
|||
import sys
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
import cognee
|
||||
from PySide6.QtWidgets import (
|
||||
QApplication,
|
||||
QWidget,
|
||||
QPushButton,
|
||||
QLineEdit,
|
||||
QFileDialog,
|
||||
QVBoxLayout,
|
||||
QHBoxLayout,
|
||||
QLabel,
|
||||
QMessageBox,
|
||||
QTextEdit,
|
||||
QProgressDialog,
|
||||
)
|
||||
from PySide6.QtCore import Qt
|
||||
|
||||
from qasync import QEventLoop # Import QEventLoop from qasync
|
||||
except ImportError as e:
|
||||
print(
|
||||
"\nPlease install Cognee with optional gui dependencies or manually install missing dependencies.\n"
|
||||
)
|
||||
print("\nTo install with poetry use:")
|
||||
print("\npoetry install -E gui\n")
|
||||
print("\nOr to install with poetry and all dependencies use:")
|
||||
print("\npoetry install --all-extras\n")
|
||||
print("\nTo install with pip use: ")
|
||||
print('\npip install ".[gui]"\n')
|
||||
raise e
|
||||
|
||||
|
||||
class FileSearchApp(QWidget):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.selected_file = None
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
# Horizontal layout for file upload and visualization buttons
|
||||
button_layout = QHBoxLayout()
|
||||
|
||||
# Button to open file dialog
|
||||
self.file_button = QPushButton("Upload File to Cognee", parent=self)
|
||||
self.file_button.clicked.connect(self.open_file_dialog)
|
||||
button_layout.addWidget(self.file_button)
|
||||
|
||||
# Button to visualize data
|
||||
self.visualize_button = QPushButton("Visualize Data", parent=self)
|
||||
self.visualize_button.clicked.connect(lambda: asyncio.ensure_future(self.visualize_data()))
|
||||
button_layout.addWidget(self.visualize_button)
|
||||
|
||||
# Label to display selected file path
|
||||
self.file_label = QLabel("No file selected", parent=self)
|
||||
|
||||
# Line edit for search input
|
||||
self.search_input = QLineEdit(parent=self)
|
||||
self.search_input.setPlaceholderText("Enter text to search...")
|
||||
|
||||
# Button to perform search; schedule the async search on click
|
||||
self.search_button = QPushButton("Cognee Search", parent=self)
|
||||
self.search_button.clicked.connect(lambda: asyncio.ensure_future(self._cognee_search()))
|
||||
|
||||
# Text output area for search results
|
||||
self.result_output = QTextEdit(parent=self)
|
||||
self.result_output.setReadOnly(True)
|
||||
self.result_output.setPlaceholderText("Search results will appear here...")
|
||||
|
||||
# Progress dialog
|
||||
self.progress_dialog = QProgressDialog("Processing..", None, 0, 0, parent=self)
|
||||
self.progress_dialog.setWindowModality(Qt.WindowModal)
|
||||
self.progress_dialog.setCancelButton(None) # Remove the cancel button
|
||||
self.progress_dialog.close()
|
||||
|
||||
# Layout setup
|
||||
layout = QVBoxLayout()
|
||||
layout.addLayout(button_layout)
|
||||
layout.addWidget(self.file_label)
|
||||
layout.addWidget(self.search_input)
|
||||
layout.addWidget(self.search_button)
|
||||
layout.addWidget(self.result_output)
|
||||
|
||||
self.setLayout(layout)
|
||||
self.setWindowTitle("Cognee")
|
||||
self.resize(500, 300)
|
||||
|
||||
def open_file_dialog(self):
|
||||
file_path, _ = QFileDialog.getOpenFileName(
|
||||
self, "Select a File", "", "All Files (*.*);;Text Files (*.txt)"
|
||||
)
|
||||
if file_path:
|
||||
self.selected_file = file_path
|
||||
self.file_label.setText(f"Selected: {file_path}")
|
||||
asyncio.ensure_future(self.process_file_async())
|
||||
|
||||
async def process_file_async(self):
|
||||
"""Asynchronously add and process the selected file."""
|
||||
# Disable the entire window
|
||||
self.progress_dialog.show()
|
||||
self.setEnabled(False)
|
||||
try:
|
||||
await cognee.add(self.selected_file)
|
||||
await cognee.cognify()
|
||||
except Exception as e:
|
||||
QMessageBox.critical(self, "Error", f"File processing failed: {str(e)}")
|
||||
# Once finished, re-enable the window
|
||||
self.setEnabled(True)
|
||||
self.progress_dialog.close()
|
||||
|
||||
async def _cognee_search(self):
|
||||
"""Performs an async search and updates the result output."""
|
||||
# Disable the entire window
|
||||
self.setEnabled(False)
|
||||
self.progress_dialog.show()
|
||||
|
||||
try:
|
||||
search_text = self.search_input.text().strip()
|
||||
result = await cognee.search(query_text=search_text)
|
||||
print(result)
|
||||
# Assuming result is a list-like object; adjust if necessary
|
||||
self.result_output.setText(result[0])
|
||||
except Exception as e:
|
||||
QMessageBox.critical(self, "Error", f"Search failed: {str(e)}")
|
||||
|
||||
# Once finished, re-enable the window
|
||||
self.setEnabled(True)
|
||||
self.progress_dialog.close()
|
||||
|
||||
async def visualize_data(self):
|
||||
"""Async slot for handling visualize data button press."""
|
||||
import webbrowser
|
||||
from cognee.api.v1.visualize.visualize import visualize_graph
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
html_file = os.path.join(pathlib.Path(__file__).parent, ".data", "graph_visualization.html")
|
||||
await visualize_graph(html_file)
|
||||
webbrowser.open(f"file://{html_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = QApplication(sys.argv)
|
||||
# Create a qasync event loop and set it as the current event loop
|
||||
loop = QEventLoop(app)
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
window = FileSearchApp()
|
||||
window.show()
|
||||
|
||||
with loop:
|
||||
loop.run_forever()
|
||||
|
|
@ -266,7 +266,7 @@ The MCP server exposes its functionality through tools. Call them from any MCP c
|
|||
|
||||
- **codify**: Analyse a code repository, build a code graph, stores it in memory
|
||||
|
||||
- **search**: Query memory – supports GRAPH_COMPLETION, RAG_COMPLETION, CODE, CHUNKS, INSIGHTS
|
||||
- **search**: Query memory – supports GRAPH_COMPLETION, RAG_COMPLETION, CODE, CHUNKS
|
||||
|
||||
- **list_data**: List all datasets and their data items with IDs for deletion operations
|
||||
|
||||
|
|
|
|||
|
|
@ -48,27 +48,27 @@ if [ "$ENVIRONMENT" = "dev" ] || [ "$ENVIRONMENT" = "local" ]; then
|
|||
if [ "$DEBUG" = "true" ]; then
|
||||
echo "Waiting for the debugger to attach..."
|
||||
if [ "$TRANSPORT_MODE" = "sse" ]; then
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
elif [ "$TRANSPORT_MODE" = "http" ]; then
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
else
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport stdio --no-migration
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport stdio --no-migration
|
||||
fi
|
||||
else
|
||||
if [ "$TRANSPORT_MODE" = "sse" ]; then
|
||||
exec cognee --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
elif [ "$TRANSPORT_MODE" = "http" ]; then
|
||||
exec cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
else
|
||||
exec cognee --transport stdio --no-migration
|
||||
exec cognee-mcp --transport stdio --no-migration
|
||||
fi
|
||||
fi
|
||||
else
|
||||
if [ "$TRANSPORT_MODE" = "sse" ]; then
|
||||
exec cognee --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
elif [ "$TRANSPORT_MODE" = "http" ]; then
|
||||
exec cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
else
|
||||
exec cognee --transport stdio --no-migration
|
||||
exec cognee-mcp --transport stdio --no-migration
|
||||
fi
|
||||
fi
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ requires-python = ">=3.10"
|
|||
dependencies = [
|
||||
# For local cognee repo usage remove comment bellow and add absolute path to cognee. Then run `uv sync --reinstall` in the mcp folder on local cognee changes.
|
||||
#"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j] @ file:/Users/igorilic/Desktop/cognee",
|
||||
"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j]==0.3.2",
|
||||
# TODO: Remove gemini from optional dependecnies for new Cognee version after 0.3.4
|
||||
"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j]==0.3.4",
|
||||
"fastmcp>=2.10.0,<3.0.0",
|
||||
"mcp>=1.12.0,<2.0.0",
|
||||
"uv>=0.6.3,<1.0.0",
|
||||
|
|
@ -36,4 +37,4 @@ dev = [
|
|||
allow-direct-references = true
|
||||
|
||||
[project.scripts]
|
||||
cognee = "src:main"
|
||||
cognee-mcp = "src:main"
|
||||
|
|
@ -1,8 +1,57 @@
|
|||
from .server import main as server_main
|
||||
import warnings
|
||||
import sys
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the package."""
|
||||
"""Deprecated main entry point for the package."""
|
||||
import asyncio
|
||||
|
||||
deprecation_notice = """
|
||||
DEPRECATION NOTICE
|
||||
The CLI entry-point used to start the Cognee MCP service has been renamed from
|
||||
"cognee" to "cognee-mcp". Calling the old entry-point will stop working in a
|
||||
future release.
|
||||
|
||||
WHAT YOU NEED TO DO:
|
||||
Locate every place where you launch the MCP process and replace the final
|
||||
argument cognee → cognee-mcp.
|
||||
|
||||
For the example mcpServers block from Cursor shown below the change is:
|
||||
{
|
||||
"mcpServers": {
|
||||
"Cognee": {
|
||||
"command": "uv",
|
||||
"args": [
|
||||
"--directory",
|
||||
"/path/to/cognee-mcp",
|
||||
"run",
|
||||
"cognee" // <-- CHANGE THIS to "cognee-mcp"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Continuing to use the old "cognee" entry-point will result in failures once it
|
||||
is removed, so please update your configuration and any shell scripts as soon
|
||||
as possible.
|
||||
"""
|
||||
|
||||
warnings.warn(
|
||||
"The 'cognee' command for cognee-mcp is deprecated and will be removed in a future version. "
|
||||
"Please use 'cognee-mcp' instead to avoid conflicts with the main cognee library.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
print("⚠️ DEPRECATION WARNING", file=sys.stderr)
|
||||
print(deprecation_notice, file=sys.stderr)
|
||||
|
||||
asyncio.run(server_main())
|
||||
|
||||
|
||||
def main_mcp():
|
||||
"""Clean main entry point for cognee-mcp command."""
|
||||
import asyncio
|
||||
|
||||
asyncio.run(server_main())
|
||||
|
|
|
|||
|
|
@ -117,5 +117,4 @@ async def add_rule_associations(data: str, rules_nodeset_name: str):
|
|||
|
||||
if len(edges_to_save) > 0:
|
||||
await graph_engine.add_edges(edges_to_save)
|
||||
|
||||
await index_graph_edges()
|
||||
await index_graph_edges(edges_to_save)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,10 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
|
|||
from cognee.modules.search.types import SearchType
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
import uvicorn
|
||||
|
||||
|
||||
try:
|
||||
|
|
@ -38,6 +42,53 @@ mcp = FastMCP("Cognee")
|
|||
logger = get_logger()
|
||||
|
||||
|
||||
async def run_sse_with_cors():
|
||||
"""Custom SSE transport with CORS middleware."""
|
||||
sse_app = mcp.sse_app()
|
||||
sse_app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:3000"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
config = uvicorn.Config(
|
||||
sse_app,
|
||||
host=mcp.settings.host,
|
||||
port=mcp.settings.port,
|
||||
log_level=mcp.settings.log_level.lower(),
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
|
||||
async def run_http_with_cors():
|
||||
"""Custom HTTP transport with CORS middleware."""
|
||||
http_app = mcp.streamable_http_app()
|
||||
http_app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:3000"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
config = uvicorn.Config(
|
||||
http_app,
|
||||
host=mcp.settings.host,
|
||||
port=mcp.settings.port,
|
||||
log_level=mcp.settings.log_level.lower(),
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
|
||||
@mcp.custom_route("/health", methods=["GET"])
|
||||
async def health_check(request):
|
||||
return JSONResponse({"status": "ok"})
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def cognee_add_developer_rules(
|
||||
base_path: str = ".", graph_model_file: str = None, graph_model_name: str = None
|
||||
|
|
@ -204,7 +255,7 @@ async def cognify(
|
|||
# 2. Get entity relationships and connections
|
||||
relationships = await cognee.search(
|
||||
"connections between concepts",
|
||||
query_type=SearchType.INSIGHTS
|
||||
query_type=SearchType.GRAPH_COMPLETION
|
||||
)
|
||||
|
||||
# 3. Find relevant document chunks
|
||||
|
|
@ -427,11 +478,6 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
Best for: Direct document retrieval, specific fact-finding.
|
||||
Returns: LLM responses based on relevant text chunks.
|
||||
|
||||
**INSIGHTS**:
|
||||
Structured entity relationships and semantic connections.
|
||||
Best for: Understanding concept relationships, knowledge mapping.
|
||||
Returns: Formatted relationship data and entity connections.
|
||||
|
||||
**CHUNKS**:
|
||||
Raw text segments that match the query semantically.
|
||||
Best for: Finding specific passages, citations, exact content.
|
||||
|
|
@ -473,7 +519,6 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
- "RAG_COMPLETION": Returns an LLM response based on the search query and standard RAG data
|
||||
- "CODE": Returns code-related knowledge in JSON format
|
||||
- "CHUNKS": Returns raw text chunks from the knowledge graph
|
||||
- "INSIGHTS": Returns relationships between nodes in readable format
|
||||
- "SUMMARIES": Returns pre-generated hierarchical summaries
|
||||
- "CYPHER": Direct graph database queries
|
||||
- "FEELING_LUCKY": Automatically selects best search type
|
||||
|
|
@ -486,7 +531,6 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
A list containing a single TextContent object with the search results.
|
||||
The format of the result depends on the search_type:
|
||||
- **GRAPH_COMPLETION/RAG_COMPLETION**: Conversational AI response strings
|
||||
- **INSIGHTS**: Formatted relationship descriptions and entity connections
|
||||
- **CHUNKS**: Relevant text passages with source metadata
|
||||
- **SUMMARIES**: Hierarchical summaries from general to specific
|
||||
- **CODE**: Structured code information with context
|
||||
|
|
@ -496,7 +540,6 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
Performance & Optimization:
|
||||
- **GRAPH_COMPLETION**: Slower but most intelligent, uses LLM + graph context
|
||||
- **RAG_COMPLETION**: Medium speed, uses LLM + document chunks (no graph traversal)
|
||||
- **INSIGHTS**: Fast, returns structured relationships without LLM processing
|
||||
- **CHUNKS**: Fastest, pure vector similarity search without LLM
|
||||
- **SUMMARIES**: Fast, returns pre-computed summaries
|
||||
- **CODE**: Medium speed, specialized for code understanding
|
||||
|
|
@ -535,9 +578,6 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
return str(search_results[0])
|
||||
elif search_type.upper() == "CHUNKS":
|
||||
return str(search_results)
|
||||
elif search_type.upper() == "INSIGHTS":
|
||||
results = retrieved_edges_to_string(search_results)
|
||||
return results
|
||||
else:
|
||||
return str(search_results)
|
||||
|
||||
|
|
@ -975,12 +1015,12 @@ async def main():
|
|||
await mcp.run_stdio_async()
|
||||
elif args.transport == "sse":
|
||||
logger.info(f"Running MCP server with SSE transport on {args.host}:{args.port}")
|
||||
await mcp.run_sse_async()
|
||||
await run_sse_with_cors()
|
||||
elif args.transport == "http":
|
||||
logger.info(
|
||||
f"Running MCP server with Streamable HTTP transport on {args.host}:{args.port}{args.path}"
|
||||
)
|
||||
await mcp.run_streamable_http_async()
|
||||
await run_http_with_cors()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
3529
cognee-mcp/uv.lock
generated
3529
cognee-mcp/uv.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -19,6 +19,7 @@ from .api.v1.add import add
|
|||
from .api.v1.delete import delete
|
||||
from .api.v1.cognify import cognify
|
||||
from .modules.memify import memify
|
||||
from .api.v1.update import update
|
||||
from .api.v1.config.config import config
|
||||
from .api.v1.datasets.datasets import datasets
|
||||
from .api.v1.prune import prune
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from cognee.api.v1.add.routers import get_add_router
|
|||
from cognee.api.v1.delete.routers import get_delete_router
|
||||
from cognee.api.v1.responses.routers import get_responses_router
|
||||
from cognee.api.v1.sync.routers import get_sync_router
|
||||
from cognee.api.v1.update.routers import get_update_router
|
||||
from cognee.api.v1.users.routers import (
|
||||
get_auth_router,
|
||||
get_register_router,
|
||||
|
|
@ -263,6 +264,8 @@ app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["vi
|
|||
|
||||
app.include_router(get_delete_router(), prefix="/api/v1/delete", tags=["delete"])
|
||||
|
||||
app.include_router(get_update_router(), prefix="/api/v1/update", tags=["update"])
|
||||
|
||||
app.include_router(get_responses_router(), prefix="/api/v1/responses", tags=["responses"])
|
||||
|
||||
app.include_router(get_sync_router(), prefix="/api/v1/sync", tags=["sync"])
|
||||
|
|
|
|||
|
|
@ -189,12 +189,12 @@ class HealthChecker:
|
|||
start_time = time.time()
|
||||
try:
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
|
||||
config = get_llm_config()
|
||||
|
||||
# Test actual API connection with minimal request
|
||||
LLMGateway.show_prompt("test", "test")
|
||||
from cognee.infrastructure.llm.utils import test_llm_connection
|
||||
|
||||
await test_llm_connection()
|
||||
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
return ComponentHealth(
|
||||
|
|
@ -217,13 +217,9 @@ class HealthChecker:
|
|||
"""Check embedding service health (non-critical)."""
|
||||
start_time = time.time()
|
||||
try:
|
||||
from cognee.infrastructure.databases.vector.embeddings.get_embedding_engine import (
|
||||
get_embedding_engine,
|
||||
)
|
||||
from cognee.infrastructure.llm.utils import test_embedding_connection
|
||||
|
||||
# Test actual embedding generation with minimal text
|
||||
engine = get_embedding_engine()
|
||||
await engine.embed_text(["test"])
|
||||
await test_embedding_connection()
|
||||
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
return ComponentHealth(
|
||||
|
|
@ -245,16 +241,6 @@ class HealthChecker:
|
|||
"""Get comprehensive health status."""
|
||||
components = {}
|
||||
|
||||
# Critical services
|
||||
critical_components = [
|
||||
"relational_db",
|
||||
"vector_db",
|
||||
"graph_db",
|
||||
"file_storage",
|
||||
"llm_provider",
|
||||
"embedding_service",
|
||||
]
|
||||
|
||||
critical_checks = [
|
||||
("relational_db", self.check_relational_db()),
|
||||
("vector_db", self.check_vector_db()),
|
||||
|
|
@ -300,11 +286,11 @@ class HealthChecker:
|
|||
else:
|
||||
components[name] = result
|
||||
|
||||
critical_comps = [check[0] for check in critical_checks]
|
||||
# Determine overall status
|
||||
critical_unhealthy = any(
|
||||
comp.status == HealthStatus.UNHEALTHY
|
||||
comp.status == HealthStatus.UNHEALTHY and name in critical_comps
|
||||
for name, comp in components.items()
|
||||
if name in critical_components
|
||||
)
|
||||
|
||||
has_degraded = any(comp.status == HealthStatus.DEGRADED for comp in components.values())
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from uuid import UUID
|
||||
from typing import Union, BinaryIO, List, Optional
|
||||
|
||||
import os
|
||||
from typing import Union, BinaryIO, List, Optional, Dict, Any
|
||||
from pydantic import BaseModel
|
||||
from urllib.parse import urlparse
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.pipelines import Task, run_pipeline
|
||||
from cognee.modules.pipelines.layers.resolve_authorized_user_dataset import (
|
||||
|
|
@ -11,6 +13,19 @@ from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
|
|||
)
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.tasks.ingestion import ingest_data, resolve_data_directories
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
try:
|
||||
from cognee.tasks.web_scraper.config import TavilyConfig, SoupCrawlerConfig
|
||||
from cognee.context_global_variables import (
|
||||
tavily_config as tavily,
|
||||
soup_crawler_config as soup_crawler,
|
||||
)
|
||||
except ImportError:
|
||||
logger.debug(f"Unable to import {str(ImportError)}")
|
||||
pass
|
||||
|
||||
|
||||
async def add(
|
||||
|
|
@ -23,12 +38,15 @@ async def add(
|
|||
dataset_id: Optional[UUID] = None,
|
||||
preferred_loaders: List[str] = None,
|
||||
incremental_loading: bool = True,
|
||||
extraction_rules: Optional[Dict[str, Any]] = None,
|
||||
tavily_config: Optional[BaseModel] = None,
|
||||
soup_crawler_config: Optional[BaseModel] = None,
|
||||
):
|
||||
"""
|
||||
Add data to Cognee for knowledge graph processing.
|
||||
|
||||
This is the first step in the Cognee workflow - it ingests raw data and prepares it
|
||||
for processing. The function accepts various data formats including text, files, and
|
||||
for processing. The function accepts various data formats including text, files, urls and
|
||||
binary streams, then stores them in a specified dataset for further processing.
|
||||
|
||||
Prerequisites:
|
||||
|
|
@ -68,6 +86,7 @@ async def add(
|
|||
- S3 path: "s3://my-bucket/documents/file.pdf"
|
||||
- List of mixed types: ["text content", "/path/file.pdf", "file://doc.txt", file_handle]
|
||||
- Binary file object: open("file.txt", "rb")
|
||||
- url: A web link url (https or http)
|
||||
dataset_name: Name of the dataset to store data in. Defaults to "main_dataset".
|
||||
Create separate datasets to organize different knowledge domains.
|
||||
user: User object for authentication and permissions. Uses default user if None.
|
||||
|
|
@ -78,6 +97,9 @@ async def add(
|
|||
vector_db_config: Optional configuration for vector database (for custom setups).
|
||||
graph_db_config: Optional configuration for graph database (for custom setups).
|
||||
dataset_id: Optional specific dataset UUID to use instead of dataset_name.
|
||||
extraction_rules: Optional dictionary of rules (e.g., CSS selectors, XPath) for extracting specific content from web pages using BeautifulSoup
|
||||
tavily_config: Optional configuration for Tavily API, including API key and extraction settings
|
||||
soup_crawler_config: Optional configuration for BeautifulSoup crawler, specifying concurrency, crawl delay, and extraction rules.
|
||||
|
||||
Returns:
|
||||
PipelineRunInfo: Information about the ingestion pipeline execution including:
|
||||
|
|
@ -126,6 +148,21 @@ async def add(
|
|||
|
||||
# Add a single file
|
||||
await cognee.add("/home/user/documents/analysis.pdf")
|
||||
|
||||
# Add a single url and bs4 extract ingestion method
|
||||
extraction_rules = {
|
||||
"title": "h1",
|
||||
"description": "p",
|
||||
"more_info": "a[href*='more-info']"
|
||||
}
|
||||
await cognee.add("https://example.com",extraction_rules=extraction_rules)
|
||||
|
||||
# Add a single url and tavily extract ingestion method
|
||||
Make sure to set TAVILY_API_KEY = YOUR_TAVILY_API_KEY as a environment variable
|
||||
await cognee.add("https://example.com")
|
||||
|
||||
# Add multiple urls
|
||||
await cognee.add(["https://example.com","https://books.toscrape.com"])
|
||||
```
|
||||
|
||||
Environment Variables:
|
||||
|
|
@ -133,22 +170,55 @@ async def add(
|
|||
- LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)
|
||||
|
||||
Optional:
|
||||
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama"
|
||||
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral"
|
||||
- LLM_MODEL: Model name (default: "gpt-5-mini")
|
||||
- DEFAULT_USER_EMAIL: Custom default user email
|
||||
- DEFAULT_USER_PASSWORD: Custom default user password
|
||||
- VECTOR_DB_PROVIDER: "lancedb" (default), "chromadb", "pgvector"
|
||||
- GRAPH_DATABASE_PROVIDER: "kuzu" (default), "neo4j"
|
||||
- TAVILY_API_KEY: YOUR_TAVILY_API_KEY
|
||||
|
||||
"""
|
||||
|
||||
try:
|
||||
if not soup_crawler_config and extraction_rules:
|
||||
soup_crawler_config = SoupCrawlerConfig(extraction_rules=extraction_rules)
|
||||
if not tavily_config and os.getenv("TAVILY_API_KEY"):
|
||||
tavily_config = TavilyConfig(api_key=os.getenv("TAVILY_API_KEY"))
|
||||
|
||||
soup_crawler.set(soup_crawler_config)
|
||||
tavily.set(tavily_config)
|
||||
|
||||
http_schemes = {"http", "https"}
|
||||
|
||||
def _is_http_url(item: Union[str, BinaryIO]) -> bool:
|
||||
return isinstance(item, str) and urlparse(item).scheme in http_schemes
|
||||
|
||||
if _is_http_url(data):
|
||||
node_set = ["web_content"] if not node_set else node_set + ["web_content"]
|
||||
elif isinstance(data, list) and any(_is_http_url(item) for item in data):
|
||||
node_set = ["web_content"] if not node_set else node_set + ["web_content"]
|
||||
except NameError:
|
||||
logger.debug(f"Unable to import {str(ImportError)}")
|
||||
pass
|
||||
|
||||
tasks = [
|
||||
Task(resolve_data_directories, include_subdirectories=True),
|
||||
Task(ingest_data, dataset_name, user, node_set, dataset_id, preferred_loaders),
|
||||
Task(
|
||||
ingest_data,
|
||||
dataset_name,
|
||||
user,
|
||||
node_set,
|
||||
dataset_id,
|
||||
preferred_loaders,
|
||||
),
|
||||
]
|
||||
|
||||
await setup()
|
||||
|
||||
user, authorized_dataset = await resolve_authorized_user_dataset(dataset_id, dataset_name, user)
|
||||
user, authorized_dataset = await resolve_authorized_user_dataset(
|
||||
dataset_name=dataset_name, dataset_id=dataset_id, user=user
|
||||
)
|
||||
|
||||
await reset_dataset_pipeline_run_status(
|
||||
authorized_dataset.id, user, pipeline_names=["add_pipeline", "cognify_pipeline"]
|
||||
|
|
|
|||
|
|
@ -73,7 +73,11 @@ def get_add_router() -> APIRouter:
|
|||
|
||||
try:
|
||||
add_run = await cognee_add(
|
||||
data, datasetName, user=user, dataset_id=datasetId, node_set=node_set
|
||||
data,
|
||||
datasetName,
|
||||
user=user,
|
||||
dataset_id=datasetId,
|
||||
node_set=node_set if node_set else None,
|
||||
)
|
||||
|
||||
if isinstance(add_run, PipelineRunErrored):
|
||||
|
|
|
|||
|
|
@ -148,7 +148,7 @@ async def cognify(
|
|||
# 2. Get entity relationships and connections
|
||||
relationships = await cognee.search(
|
||||
"connections between concepts",
|
||||
query_type=SearchType.INSIGHTS
|
||||
query_type=SearchType.GRAPH_COMPLETION
|
||||
)
|
||||
|
||||
# 3. Find relevant document chunks
|
||||
|
|
|
|||
|
|
@ -94,9 +94,11 @@ def get_permissions_router() -> APIRouter:
|
|||
|
||||
from cognee.modules.users.roles.methods import create_role as create_role_method
|
||||
|
||||
await create_role_method(role_name=role_name, owner_id=user.id)
|
||||
role_id = await create_role_method(role_name=role_name, owner_id=user.id)
|
||||
|
||||
return JSONResponse(status_code=200, content={"message": "Role created for tenant"})
|
||||
return JSONResponse(
|
||||
status_code=200, content={"message": "Role created for tenant", "role_id": str(role_id)}
|
||||
)
|
||||
|
||||
@permissions_router.post("/users/{user_id}/roles")
|
||||
async def add_user_to_role(
|
||||
|
|
@ -212,8 +214,10 @@ def get_permissions_router() -> APIRouter:
|
|||
|
||||
from cognee.modules.users.tenants.methods import create_tenant as create_tenant_method
|
||||
|
||||
await create_tenant_method(tenant_name=tenant_name, user_id=user.id)
|
||||
tenant_id = await create_tenant_method(tenant_name=tenant_name, user_id=user.id)
|
||||
|
||||
return JSONResponse(status_code=200, content={"message": "Tenant created."})
|
||||
return JSONResponse(
|
||||
status_code=200, content={"message": "Tenant created.", "tenant_id": str(tenant_id)}
|
||||
)
|
||||
|
||||
return permissions_router
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ DEFAULT_TOOLS = [
|
|||
"type": "string",
|
||||
"description": "Type of search to perform",
|
||||
"enum": [
|
||||
"INSIGHTS",
|
||||
"CODE",
|
||||
"GRAPH_COMPLETION",
|
||||
"NATURAL_LANGUAGE",
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ async def handle_search(arguments: Dict[str, Any], user) -> list:
|
|||
valid_search_types = (
|
||||
search_tool["parameters"]["properties"]["search_type"]["enum"]
|
||||
if search_tool
|
||||
else ["INSIGHTS", "CODE", "GRAPH_COMPLETION", "NATURAL_LANGUAGE"]
|
||||
else ["CODE", "GRAPH_COMPLETION", "NATURAL_LANGUAGE"]
|
||||
)
|
||||
|
||||
if search_type_str not in valid_search_types:
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ DEFAULT_TOOLS = [
|
|||
"type": "string",
|
||||
"description": "Type of search to perform",
|
||||
"enum": [
|
||||
"INSIGHTS",
|
||||
"CODE",
|
||||
"GRAPH_COMPLETION",
|
||||
"NATURAL_LANGUAGE",
|
||||
|
|
|
|||
|
|
@ -52,11 +52,6 @@ async def search(
|
|||
Best for: Direct document retrieval, specific fact-finding.
|
||||
Returns: LLM responses based on relevant text chunks.
|
||||
|
||||
**INSIGHTS**:
|
||||
Structured entity relationships and semantic connections.
|
||||
Best for: Understanding concept relationships, knowledge mapping.
|
||||
Returns: Formatted relationship data and entity connections.
|
||||
|
||||
**CHUNKS**:
|
||||
Raw text segments that match the query semantically.
|
||||
Best for: Finding specific passages, citations, exact content.
|
||||
|
|
@ -124,9 +119,6 @@ async def search(
|
|||
**GRAPH_COMPLETION/RAG_COMPLETION**:
|
||||
[List of conversational AI response strings]
|
||||
|
||||
**INSIGHTS**:
|
||||
[List of formatted relationship descriptions and entity connections]
|
||||
|
||||
**CHUNKS**:
|
||||
[List of relevant text passages with source metadata]
|
||||
|
||||
|
|
@ -146,7 +138,6 @@ async def search(
|
|||
Performance & Optimization:
|
||||
- **GRAPH_COMPLETION**: Slower but most intelligent, uses LLM + graph context
|
||||
- **RAG_COMPLETION**: Medium speed, uses LLM + document chunks (no graph traversal)
|
||||
- **INSIGHTS**: Fast, returns structured relationships without LLM processing
|
||||
- **CHUNKS**: Fastest, pure vector similarity search without LLM
|
||||
- **SUMMARIES**: Fast, returns pre-computed summaries
|
||||
- **CODE**: Medium speed, specialized for code understanding
|
||||
|
|
|
|||
|
|
@ -21,7 +21,13 @@ class SettingsDTO(OutDTO):
|
|||
|
||||
|
||||
class LLMConfigInputDTO(InDTO):
|
||||
provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"], Literal["gemini"]]
|
||||
provider: Union[
|
||||
Literal["openai"],
|
||||
Literal["ollama"],
|
||||
Literal["anthropic"],
|
||||
Literal["gemini"],
|
||||
Literal["mistral"],
|
||||
]
|
||||
model: str
|
||||
api_key: str
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from .ui import start_ui, stop_ui, ui
|
||||
from .ui import start_ui
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import os
|
||||
import platform
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
|
|
@ -7,7 +9,7 @@ import webbrowser
|
|||
import zipfile
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Tuple
|
||||
from typing import Callable, Optional, Tuple, List
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
|
|
@ -17,6 +19,80 @@ from cognee.version import get_cognee_version
|
|||
logger = get_logger()
|
||||
|
||||
|
||||
def _stream_process_output(
|
||||
process: subprocess.Popen, stream_name: str, prefix: str, color_code: str = ""
|
||||
) -> threading.Thread:
|
||||
"""
|
||||
Stream output from a process with a prefix to identify the source.
|
||||
|
||||
Args:
|
||||
process: The subprocess to monitor
|
||||
stream_name: 'stdout' or 'stderr'
|
||||
prefix: Text prefix for each line (e.g., '[BACKEND]', '[FRONTEND]')
|
||||
color_code: ANSI color code for the prefix (optional)
|
||||
|
||||
Returns:
|
||||
Thread that handles the streaming
|
||||
"""
|
||||
|
||||
def stream_reader():
|
||||
stream = getattr(process, stream_name)
|
||||
if stream is None:
|
||||
return
|
||||
|
||||
reset_code = "\033[0m" if color_code else ""
|
||||
|
||||
try:
|
||||
for line in iter(stream.readline, b""):
|
||||
if line:
|
||||
line_text = line.decode("utf-8").rstrip()
|
||||
if line_text:
|
||||
print(f"{color_code}{prefix}{reset_code} {line_text}", flush=True)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
if stream:
|
||||
stream.close()
|
||||
|
||||
thread = threading.Thread(target=stream_reader, daemon=True)
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
|
||||
def _is_port_available(port: int) -> bool:
|
||||
"""
|
||||
Check if a port is available on localhost.
|
||||
Returns True if the port is available, False otherwise.
|
||||
"""
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.settimeout(1) # 1 second timeout
|
||||
result = sock.connect_ex(("localhost", port))
|
||||
return result != 0 # Port is available if connection fails
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _check_required_ports(ports_to_check: List[Tuple[int, str]]) -> Tuple[bool, List[str]]:
|
||||
"""
|
||||
Check if all required ports are available on localhost.
|
||||
|
||||
Args:
|
||||
ports_to_check: List of (port, service_name) tuples
|
||||
|
||||
Returns:
|
||||
Tuple of (all_available: bool, unavailable_services: List[str])
|
||||
"""
|
||||
unavailable = []
|
||||
|
||||
for port, service_name in ports_to_check:
|
||||
if not _is_port_available(port):
|
||||
unavailable.append(f"{service_name} (port {port})")
|
||||
logger.error(f"Port {port} is already in use for {service_name}")
|
||||
|
||||
return len(unavailable) == 0, unavailable
|
||||
|
||||
|
||||
def normalize_version_for_comparison(version: str) -> str:
|
||||
"""
|
||||
Normalize version string for comparison.
|
||||
|
|
@ -214,6 +290,7 @@ def check_node_npm() -> tuple[bool, str]:
|
|||
Check if Node.js and npm are available.
|
||||
Returns (is_available, error_message)
|
||||
"""
|
||||
|
||||
try:
|
||||
# Check Node.js
|
||||
result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=10)
|
||||
|
|
@ -223,8 +300,17 @@ def check_node_npm() -> tuple[bool, str]:
|
|||
node_version = result.stdout.strip()
|
||||
logger.debug(f"Found Node.js version: {node_version}")
|
||||
|
||||
# Check npm
|
||||
result = subprocess.run(["npm", "--version"], capture_output=True, text=True, timeout=10)
|
||||
# Check npm - handle Windows PowerShell scripts
|
||||
if platform.system() == "Windows":
|
||||
# On Windows, npm might be a PowerShell script, so we need to use shell=True
|
||||
result = subprocess.run(
|
||||
["npm", "--version"], capture_output=True, text=True, timeout=10, shell=True
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["npm", "--version"], capture_output=True, text=True, timeout=10
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
return False, "npm is not installed or not in PATH"
|
||||
|
||||
|
|
@ -246,6 +332,7 @@ def install_frontend_dependencies(frontend_path: Path) -> bool:
|
|||
Install frontend dependencies if node_modules doesn't exist.
|
||||
This is needed for both development and downloaded frontends since both use npm run dev.
|
||||
"""
|
||||
|
||||
node_modules = frontend_path / "node_modules"
|
||||
if node_modules.exists():
|
||||
logger.debug("Frontend dependencies already installed")
|
||||
|
|
@ -254,13 +341,24 @@ def install_frontend_dependencies(frontend_path: Path) -> bool:
|
|||
logger.info("Installing frontend dependencies (this may take a few minutes)...")
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["npm", "install"],
|
||||
cwd=frontend_path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5 minutes timeout
|
||||
)
|
||||
# Use shell=True on Windows for npm commands
|
||||
if platform.system() == "Windows":
|
||||
result = subprocess.run(
|
||||
["npm", "install"],
|
||||
cwd=frontend_path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5 minutes timeout
|
||||
shell=True,
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["npm", "install"],
|
||||
cwd=frontend_path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5 minutes timeout
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info("Frontend dependencies installed successfully")
|
||||
|
|
@ -327,55 +425,113 @@ def prompt_user_for_download() -> bool:
|
|||
|
||||
def start_ui(
|
||||
pid_callback: Callable[[int], None],
|
||||
host: str = "localhost",
|
||||
port: int = 3000,
|
||||
open_browser: bool = True,
|
||||
auto_download: bool = False,
|
||||
start_backend: bool = False,
|
||||
backend_host: str = "localhost",
|
||||
backend_port: int = 8000,
|
||||
start_mcp: bool = False,
|
||||
mcp_port: int = 8001,
|
||||
) -> Optional[subprocess.Popen]:
|
||||
"""
|
||||
Start the cognee frontend UI server, optionally with the backend API server.
|
||||
Start the cognee frontend UI server, optionally with the backend API server and MCP server.
|
||||
|
||||
This function will:
|
||||
1. Optionally start the cognee backend API server
|
||||
2. Find the cognee-frontend directory (development) or download it (pip install)
|
||||
3. Check if Node.js and npm are available (for development mode)
|
||||
4. Install dependencies if needed (development mode)
|
||||
5. Start the frontend server
|
||||
6. Optionally open the browser
|
||||
2. Optionally start the cognee MCP server
|
||||
3. Find the cognee-frontend directory (development) or download it (pip install)
|
||||
4. Check if Node.js and npm are available (for development mode)
|
||||
5. Install dependencies if needed (development mode)
|
||||
6. Start the frontend server
|
||||
7. Optionally open the browser
|
||||
|
||||
Args:
|
||||
pid_callback: Callback to notify with PID of each spawned process
|
||||
host: Host to bind the frontend server to (default: localhost)
|
||||
port: Port to run the frontend server on (default: 3000)
|
||||
open_browser: Whether to open the browser automatically (default: True)
|
||||
auto_download: If True, download frontend without prompting (default: False)
|
||||
start_backend: If True, also start the cognee API backend server (default: False)
|
||||
backend_host: Host to bind the backend server to (default: localhost)
|
||||
backend_port: Port to run the backend server on (default: 8000)
|
||||
start_mcp: If True, also start the cognee MCP server (default: False)
|
||||
mcp_port: Port to run the MCP server on (default: 8001)
|
||||
|
||||
Returns:
|
||||
subprocess.Popen object representing the running frontend server, or None if failed
|
||||
Note: If backend is started, it runs in a separate process that will be cleaned up
|
||||
when the frontend process is terminated.
|
||||
Note: If backend and/or MCP server are started, they run in separate processes
|
||||
that will be cleaned up when the frontend process is terminated.
|
||||
|
||||
Example:
|
||||
>>> import cognee
|
||||
>>> def dummy_callback(pid): pass
|
||||
>>> # Start just the frontend
|
||||
>>> server = cognee.start_ui()
|
||||
>>> server = cognee.start_ui(dummy_callback)
|
||||
>>>
|
||||
>>> # Start both frontend and backend
|
||||
>>> server = cognee.start_ui(start_backend=True)
|
||||
>>> server = cognee.start_ui(dummy_callback, start_backend=True)
|
||||
>>> # UI will be available at http://localhost:3000
|
||||
>>> # API will be available at http://localhost:8000
|
||||
>>> # To stop both servers later:
|
||||
>>>
|
||||
>>> # Start frontend with MCP server
|
||||
>>> server = cognee.start_ui(dummy_callback, start_mcp=True)
|
||||
>>> # UI will be available at http://localhost:3000
|
||||
>>> # MCP server will be available at http://127.0.0.1:8001/sse
|
||||
>>> # To stop all servers later:
|
||||
>>> server.terminate()
|
||||
"""
|
||||
logger.info("Starting cognee UI...")
|
||||
|
||||
ports_to_check = [(port, "Frontend UI")]
|
||||
|
||||
if start_backend:
|
||||
ports_to_check.append((backend_port, "Backend API"))
|
||||
|
||||
if start_mcp:
|
||||
ports_to_check.append((mcp_port, "MCP Server"))
|
||||
|
||||
logger.info("Checking port availability...")
|
||||
all_ports_available, unavailable_services = _check_required_ports(ports_to_check)
|
||||
|
||||
if not all_ports_available:
|
||||
error_msg = f"Cannot start cognee UI: The following services have ports already in use: {', '.join(unavailable_services)}"
|
||||
logger.error(error_msg)
|
||||
logger.error("Please stop the conflicting services or change the port configuration.")
|
||||
return None
|
||||
|
||||
logger.info("✓ All required ports are available")
|
||||
backend_process = None
|
||||
|
||||
if start_mcp:
|
||||
logger.info("Starting Cognee MCP server with Docker...")
|
||||
cwd = os.getcwd()
|
||||
env_file = os.path.join(cwd, ".env")
|
||||
try:
|
||||
image = "cognee/cognee-mcp:main"
|
||||
subprocess.run(["docker", "pull", image], check=True)
|
||||
mcp_process = subprocess.Popen(
|
||||
[
|
||||
"docker",
|
||||
"run",
|
||||
"-p",
|
||||
f"{mcp_port}:8000",
|
||||
"--rm",
|
||||
"--env-file",
|
||||
env_file,
|
||||
"-e",
|
||||
"TRANSPORT_MODE=sse",
|
||||
"cognee/cognee-mcp:main",
|
||||
],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
||||
)
|
||||
|
||||
_stream_process_output(mcp_process, "stdout", "[MCP]", "\033[34m") # Blue
|
||||
_stream_process_output(mcp_process, "stderr", "[MCP]", "\033[34m") # Blue
|
||||
|
||||
pid_callback(mcp_process.pid)
|
||||
logger.info(f"✓ Cognee MCP server starting on http://127.0.0.1:{mcp_port}/sse")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start MCP server with Docker: {str(e)}")
|
||||
# Start backend server if requested
|
||||
if start_backend:
|
||||
logger.info("Starting cognee backend API server...")
|
||||
|
|
@ -389,16 +545,19 @@ def start_ui(
|
|||
"uvicorn",
|
||||
"cognee.api.client:app",
|
||||
"--host",
|
||||
backend_host,
|
||||
"localhost",
|
||||
"--port",
|
||||
str(backend_port),
|
||||
],
|
||||
# Inherit stdout/stderr from parent process to show logs
|
||||
stdout=None,
|
||||
stderr=None,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
||||
)
|
||||
|
||||
# Start threads to stream backend output with prefix
|
||||
_stream_process_output(backend_process, "stdout", "[BACKEND]", "\033[32m") # Green
|
||||
_stream_process_output(backend_process, "stderr", "[BACKEND]", "\033[32m") # Green
|
||||
|
||||
pid_callback(backend_process.pid)
|
||||
|
||||
# Give the backend a moment to start
|
||||
|
|
@ -408,7 +567,7 @@ def start_ui(
|
|||
logger.error("Backend server failed to start - process exited early")
|
||||
return None
|
||||
|
||||
logger.info(f"✓ Backend API started at http://{backend_host}:{backend_port}")
|
||||
logger.info(f"✓ Backend API started at http://localhost:{backend_port}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start backend server: {str(e)}")
|
||||
|
|
@ -453,24 +612,38 @@ def start_ui(
|
|||
|
||||
# Prepare environment variables
|
||||
env = os.environ.copy()
|
||||
env["HOST"] = host
|
||||
env["HOST"] = "localhost"
|
||||
env["PORT"] = str(port)
|
||||
|
||||
# Start the development server
|
||||
logger.info(f"Starting frontend server at http://{host}:{port}")
|
||||
logger.info(f"Starting frontend server at http://localhost:{port}")
|
||||
logger.info("This may take a moment to compile and start...")
|
||||
|
||||
try:
|
||||
# Create frontend in its own process group for clean termination
|
||||
process = subprocess.Popen(
|
||||
["npm", "run", "dev"],
|
||||
cwd=frontend_path,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
||||
)
|
||||
# Use shell=True on Windows for npm commands
|
||||
if platform.system() == "Windows":
|
||||
process = subprocess.Popen(
|
||||
["npm", "run", "dev"],
|
||||
cwd=frontend_path,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
shell=True,
|
||||
)
|
||||
else:
|
||||
process = subprocess.Popen(
|
||||
["npm", "run", "dev"],
|
||||
cwd=frontend_path,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
||||
)
|
||||
|
||||
# Start threads to stream frontend output with prefix
|
||||
_stream_process_output(process, "stdout", "[FRONTEND]", "\033[33m") # Yellow
|
||||
_stream_process_output(process, "stderr", "[FRONTEND]", "\033[33m") # Yellow
|
||||
|
||||
pid_callback(process.pid)
|
||||
|
||||
|
|
@ -479,10 +652,7 @@ def start_ui(
|
|||
|
||||
# Check if process is still running
|
||||
if process.poll() is not None:
|
||||
stdout, stderr = process.communicate()
|
||||
logger.error("Frontend server failed to start:")
|
||||
logger.error(f"stdout: {stdout}")
|
||||
logger.error(f"stderr: {stderr}")
|
||||
logger.error("Frontend server failed to start - check the logs above for details")
|
||||
return None
|
||||
|
||||
# Open browser if requested
|
||||
|
|
@ -491,7 +661,7 @@ def start_ui(
|
|||
def open_browser_delayed():
|
||||
time.sleep(5) # Give Next.js time to fully start
|
||||
try:
|
||||
webbrowser.open(f"http://{host}:{port}") # TODO: use dashboard url?
|
||||
webbrowser.open(f"http://localhost:{port}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not open browser automatically: {e}")
|
||||
|
||||
|
|
@ -499,13 +669,9 @@ def start_ui(
|
|||
browser_thread.start()
|
||||
|
||||
logger.info("✓ Cognee UI is starting up...")
|
||||
logger.info(f"✓ Open your browser to: http://{host}:{port}")
|
||||
logger.info(f"✓ Open your browser to: http://localhost:{port}")
|
||||
logger.info("✓ The UI will be available once Next.js finishes compiling")
|
||||
|
||||
# Store backend process reference in the frontend process for cleanup
|
||||
if backend_process:
|
||||
process._cognee_backend_process = backend_process
|
||||
|
||||
return process
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -523,102 +689,3 @@ def start_ui(
|
|||
except (OSError, ProcessLookupError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def stop_ui(process: subprocess.Popen) -> bool:
|
||||
"""
|
||||
Stop a running UI server process and backend process (if started), along with all their children.
|
||||
|
||||
Args:
|
||||
process: The subprocess.Popen object returned by start_ui()
|
||||
|
||||
Returns:
|
||||
bool: True if stopped successfully, False otherwise
|
||||
"""
|
||||
if not process:
|
||||
return False
|
||||
|
||||
success = True
|
||||
|
||||
try:
|
||||
# First, stop the backend process if it exists
|
||||
backend_process = getattr(process, "_cognee_backend_process", None)
|
||||
if backend_process:
|
||||
logger.info("Stopping backend server...")
|
||||
try:
|
||||
backend_process.terminate()
|
||||
try:
|
||||
backend_process.wait(timeout=5)
|
||||
logger.info("Backend server stopped gracefully")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("Backend didn't terminate gracefully, forcing kill")
|
||||
backend_process.kill()
|
||||
backend_process.wait()
|
||||
logger.info("Backend server stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping backend server: {str(e)}")
|
||||
success = False
|
||||
|
||||
# Now stop the frontend process
|
||||
logger.info("Stopping frontend server...")
|
||||
# Try to terminate the process group (includes child processes like Next.js)
|
||||
if hasattr(os, "killpg"):
|
||||
try:
|
||||
# Kill the entire process group
|
||||
os.killpg(os.getpgid(process.pid), signal.SIGTERM)
|
||||
logger.debug("Sent SIGTERM to process group")
|
||||
except (OSError, ProcessLookupError):
|
||||
# Fall back to terminating just the main process
|
||||
process.terminate()
|
||||
logger.debug("Terminated main process only")
|
||||
else:
|
||||
process.terminate()
|
||||
logger.debug("Terminated main process (Windows)")
|
||||
|
||||
try:
|
||||
process.wait(timeout=10)
|
||||
logger.info("Frontend server stopped gracefully")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("Frontend didn't terminate gracefully, forcing kill")
|
||||
|
||||
# Force kill the process group
|
||||
if hasattr(os, "killpg"):
|
||||
try:
|
||||
os.killpg(os.getpgid(process.pid), signal.SIGKILL)
|
||||
logger.debug("Sent SIGKILL to process group")
|
||||
except (OSError, ProcessLookupError):
|
||||
process.kill()
|
||||
logger.debug("Force killed main process only")
|
||||
else:
|
||||
process.kill()
|
||||
logger.debug("Force killed main process (Windows)")
|
||||
|
||||
process.wait()
|
||||
|
||||
if success:
|
||||
logger.info("UI servers stopped successfully")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping UI servers: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
# Convenience function similar to DuckDB's approach
|
||||
def ui() -> Optional[subprocess.Popen]:
|
||||
"""
|
||||
Convenient alias for start_ui() with default parameters.
|
||||
Similar to how DuckDB provides simple ui() function.
|
||||
"""
|
||||
return start_ui()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test the UI startup
|
||||
server = start_ui()
|
||||
if server:
|
||||
try:
|
||||
input("Press Enter to stop the server...")
|
||||
finally:
|
||||
stop_ui(server)
|
||||
|
|
|
|||
1
cognee/api/v1/update/__init__.py
Normal file
1
cognee/api/v1/update/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .update import update
|
||||
1
cognee/api/v1/update/routers/__init__.py
Normal file
1
cognee/api/v1/update/routers/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .get_update_router import get_update_router
|
||||
90
cognee/api/v1/update/routers/get_update_router.py
Normal file
90
cognee/api/v1/update/routers/get_update_router.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
from fastapi.responses import JSONResponse
|
||||
from fastapi import File, UploadFile, Depends, Form
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.pipelines.models.PipelineRunInfo import (
|
||||
PipelineRunErrored,
|
||||
)
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def get_update_router() -> APIRouter:
|
||||
router = APIRouter()
|
||||
|
||||
@router.patch("", response_model=None)
|
||||
async def update(
|
||||
data_id: UUID,
|
||||
dataset_id: UUID,
|
||||
data: List[UploadFile] = File(default=None),
|
||||
node_set: Optional[List[str]] = Form(default=[""], example=[""]),
|
||||
user: User = Depends(get_authenticated_user),
|
||||
):
|
||||
"""
|
||||
Update data in a dataset.
|
||||
|
||||
This endpoint updates existing documents in a specified dataset by providing the data_id of the existing document
|
||||
to update and the new document with the changes as the data.
|
||||
The document is updated, analyzed, and the changes are integrated into the knowledge graph.
|
||||
|
||||
## Request Parameters
|
||||
- **data_id** (UUID): UUID of the document to update in Cognee memory
|
||||
- **data** (List[UploadFile]): List of files to upload.
|
||||
- **datasetId** (Optional[UUID]): UUID of an already existing dataset
|
||||
- **node_set** Optional[list[str]]: List of node identifiers for graph organization and access control.
|
||||
Used for grouping related data points in the knowledge graph.
|
||||
|
||||
## Response
|
||||
Returns information about the add operation containing:
|
||||
- Status of the operation
|
||||
- Details about the processed data
|
||||
- Any relevant metadata from the ingestion process
|
||||
|
||||
## Error Codes
|
||||
- **400 Bad Request**: Neither datasetId nor datasetName provided
|
||||
- **409 Conflict**: Error during add operation
|
||||
- **403 Forbidden**: User doesn't have permission to add to dataset
|
||||
|
||||
## Notes
|
||||
- To add data to datasets not owned by the user, use dataset_id (when ENABLE_BACKEND_ACCESS_CONTROL is set to True)
|
||||
- datasetId value can only be the UUID of an already existing dataset
|
||||
"""
|
||||
send_telemetry(
|
||||
"Update API Endpoint Invoked",
|
||||
user.id,
|
||||
additional_properties={
|
||||
"endpoint": "PATCH /v1/update",
|
||||
"dataset_id": str(dataset_id),
|
||||
"data_id": str(data_id),
|
||||
"node_set": str(node_set),
|
||||
},
|
||||
)
|
||||
|
||||
from cognee.api.v1.update import update as cognee_update
|
||||
|
||||
try:
|
||||
update_run = await cognee_update(
|
||||
data_id=data_id,
|
||||
data=data,
|
||||
dataset_id=dataset_id,
|
||||
user=user,
|
||||
node_set=node_set if node_set else None,
|
||||
)
|
||||
|
||||
# If any cognify run errored return JSONResponse with proper error status code
|
||||
if any(isinstance(v, PipelineRunErrored) for v in update_run.values()):
|
||||
return JSONResponse(status_code=420, content=jsonable_encoder(update_run))
|
||||
return update_run
|
||||
|
||||
except Exception as error:
|
||||
logger.error(f"Error during deletion by data_id: {str(error)}")
|
||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||
|
||||
return router
|
||||
100
cognee/api/v1/update/update.py
Normal file
100
cognee/api/v1/update/update.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
from uuid import UUID
|
||||
from typing import Union, BinaryIO, List, Optional
|
||||
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.api.v1.delete import delete
|
||||
from cognee.api.v1.add import add
|
||||
from cognee.api.v1.cognify import cognify
|
||||
|
||||
|
||||
async def update(
|
||||
data_id: UUID,
|
||||
data: Union[BinaryIO, list[BinaryIO], str, list[str]],
|
||||
dataset_id: UUID,
|
||||
user: User = None,
|
||||
node_set: Optional[List[str]] = None,
|
||||
vector_db_config: dict = None,
|
||||
graph_db_config: dict = None,
|
||||
preferred_loaders: List[str] = None,
|
||||
incremental_loading: bool = True,
|
||||
):
|
||||
"""
|
||||
Update existing data in Cognee.
|
||||
|
||||
Supported Input Types:
|
||||
- **Text strings**: Direct text content (str) - any string not starting with "/" or "file://"
|
||||
- **File paths**: Local file paths as strings in these formats:
|
||||
* Absolute paths: "/path/to/document.pdf"
|
||||
* File URLs: "file:///path/to/document.pdf" or "file://relative/path.txt"
|
||||
* S3 paths: "s3://bucket-name/path/to/file.pdf"
|
||||
- **Binary file objects**: File handles/streams (BinaryIO)
|
||||
- **Lists**: Multiple files or text strings in a single call
|
||||
|
||||
Supported File Formats:
|
||||
- Text files (.txt, .md, .csv)
|
||||
- PDFs (.pdf)
|
||||
- Images (.png, .jpg, .jpeg) - extracted via OCR/vision models
|
||||
- Audio files (.mp3, .wav) - transcribed to text
|
||||
- Code files (.py, .js, .ts, etc.) - parsed for structure and content
|
||||
- Office documents (.docx, .pptx)
|
||||
|
||||
Workflow:
|
||||
1. **Data Resolution**: Resolves file paths and validates accessibility
|
||||
2. **Content Extraction**: Extracts text content from various file formats
|
||||
3. **Dataset Storage**: Stores processed content in the specified dataset
|
||||
4. **Metadata Tracking**: Records file metadata, timestamps, and user permissions
|
||||
5. **Permission Assignment**: Grants user read/write/delete/share permissions on dataset
|
||||
|
||||
Args:
|
||||
data_id: UUID of existing data to update
|
||||
data: The latest version of the data. Can be:
|
||||
- Single text string: "Your text content here"
|
||||
- Absolute file path: "/path/to/document.pdf"
|
||||
- File URL: "file:///absolute/path/to/document.pdf" or "file://relative/path.txt"
|
||||
- S3 path: "s3://my-bucket/documents/file.pdf"
|
||||
- List of mixed types: ["text content", "/path/file.pdf", "file://doc.txt", file_handle]
|
||||
- Binary file object: open("file.txt", "rb")
|
||||
dataset_name: Name of the dataset to store data in. Defaults to "main_dataset".
|
||||
Create separate datasets to organize different knowledge domains.
|
||||
user: User object for authentication and permissions. Uses default user if None.
|
||||
Default user: "default_user@example.com" (created automatically on first use).
|
||||
Users can only access datasets they have permissions for.
|
||||
node_set: Optional list of node identifiers for graph organization and access control.
|
||||
Used for grouping related data points in the knowledge graph.
|
||||
vector_db_config: Optional configuration for vector database (for custom setups).
|
||||
graph_db_config: Optional configuration for graph database (for custom setups).
|
||||
dataset_id: Optional specific dataset UUID to use instead of dataset_name.
|
||||
|
||||
Returns:
|
||||
PipelineRunInfo: Information about the ingestion pipeline execution including:
|
||||
- Pipeline run ID for tracking
|
||||
- Dataset ID where data was stored
|
||||
- Processing status and any errors
|
||||
- Execution timestamps and metadata
|
||||
"""
|
||||
await delete(
|
||||
data_id=data_id,
|
||||
dataset_id=dataset_id,
|
||||
user=user,
|
||||
)
|
||||
|
||||
await add(
|
||||
data=data,
|
||||
dataset_id=dataset_id,
|
||||
user=user,
|
||||
node_set=node_set,
|
||||
vector_db_config=vector_db_config,
|
||||
graph_db_config=graph_db_config,
|
||||
preferred_loaders=preferred_loaders,
|
||||
incremental_loading=incremental_loading,
|
||||
)
|
||||
|
||||
cognify_run = await cognify(
|
||||
datasets=[dataset_id],
|
||||
user=user,
|
||||
vector_db_config=vector_db_config,
|
||||
graph_db_config=graph_db_config,
|
||||
incremental_loading=incremental_loading,
|
||||
)
|
||||
|
||||
return cognify_run
|
||||
|
|
@ -183,10 +183,20 @@ def main() -> int:
|
|||
|
||||
for pid in spawned_pids:
|
||||
try:
|
||||
pgid = os.getpgid(pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
|
||||
except (OSError, ProcessLookupError) as e:
|
||||
if hasattr(os, "killpg"):
|
||||
# Unix-like systems: Use process groups
|
||||
pgid = os.getpgid(pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
|
||||
else:
|
||||
# Windows: Use taskkill to terminate process and its children
|
||||
subprocess.run(
|
||||
["taskkill", "/F", "/T", "/PID", str(pid)],
|
||||
capture_output=True,
|
||||
check=False,
|
||||
)
|
||||
fmt.success(f"✓ Process {pid} and its children terminated.")
|
||||
except (OSError, ProcessLookupError, subprocess.SubprocessError) as e:
|
||||
fmt.warning(f"Could not terminate process {pid}: {e}")
|
||||
|
||||
sys.exit(0)
|
||||
|
|
@ -204,19 +214,27 @@ def main() -> int:
|
|||
nonlocal spawned_pids
|
||||
spawned_pids.append(pid)
|
||||
|
||||
frontend_port = 3000
|
||||
start_backend, backend_port = True, 8000
|
||||
start_mcp, mcp_port = True, 8001
|
||||
server_process = start_ui(
|
||||
host="localhost",
|
||||
port=3000,
|
||||
open_browser=True,
|
||||
start_backend=True,
|
||||
auto_download=True,
|
||||
pid_callback=pid_callback,
|
||||
port=frontend_port,
|
||||
open_browser=True,
|
||||
auto_download=True,
|
||||
start_backend=start_backend,
|
||||
backend_port=backend_port,
|
||||
start_mcp=start_mcp,
|
||||
mcp_port=mcp_port,
|
||||
)
|
||||
|
||||
if server_process:
|
||||
fmt.success("UI server started successfully!")
|
||||
fmt.echo("The interface is available at: http://localhost:3000")
|
||||
fmt.echo("The API backend is available at: http://localhost:8000")
|
||||
fmt.echo(f"The interface is available at: http://localhost:{frontend_port}")
|
||||
if start_backend:
|
||||
fmt.echo(f"The API backend is available at: http://localhost:{backend_port}")
|
||||
if start_mcp:
|
||||
fmt.echo(f"The MCP server is available at: http://localhost:{mcp_port}")
|
||||
fmt.note("Press Ctrl+C to stop the server...")
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -70,11 +70,11 @@ After adding data, use `cognee cognify` to process it into knowledge graphs.
|
|||
await cognee.add(data=data_to_add, dataset_name=args.dataset_name)
|
||||
fmt.success(f"Successfully added data to dataset '{args.dataset_name}'")
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to add data: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to add data: {str(e)}") from e
|
||||
|
||||
asyncio.run(run_add())
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, CliCommandInnerException):
|
||||
raise CliCommandException(str(e), error_code=1)
|
||||
raise CliCommandException(f"Error adding data: {str(e)}", error_code=1)
|
||||
raise CliCommandException(str(e), error_code=1) from e
|
||||
raise CliCommandException(f"Failed to add data: {str(e)}", error_code=1) from e
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ After successful cognify processing, use `cognee search` to query the knowledge
|
|||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to cognify: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to cognify: {str(e)}") from e
|
||||
|
||||
result = asyncio.run(run_cognify())
|
||||
|
||||
|
|
@ -124,5 +124,5 @@ After successful cognify processing, use `cognee search` to query the knowledge
|
|||
|
||||
except Exception as e:
|
||||
if isinstance(e, CliCommandInnerException):
|
||||
raise CliCommandException(str(e), error_code=1)
|
||||
raise CliCommandException(f"Error during cognification: {str(e)}", error_code=1)
|
||||
raise CliCommandException(str(e), error_code=1) from e
|
||||
raise CliCommandException(f"Error during cognification: {str(e)}", error_code=1) from e
|
||||
|
|
|
|||
|
|
@ -79,8 +79,10 @@ Configuration changes will affect how cognee processes and stores data.
|
|||
|
||||
except Exception as e:
|
||||
if isinstance(e, CliCommandInnerException):
|
||||
raise CliCommandException(str(e), error_code=1)
|
||||
raise CliCommandException(f"Error managing configuration: {str(e)}", error_code=1)
|
||||
raise CliCommandException(str(e), error_code=1) from e
|
||||
raise CliCommandException(
|
||||
f"Error managing configuration: {str(e)}", error_code=1
|
||||
) from e
|
||||
|
||||
def _handle_get(self, args: argparse.Namespace) -> None:
|
||||
try:
|
||||
|
|
@ -122,7 +124,7 @@ Configuration changes will affect how cognee processes and stores data.
|
|||
fmt.note("Configuration viewing not fully implemented yet")
|
||||
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to get configuration: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to get configuration: {str(e)}") from e
|
||||
|
||||
def _handle_set(self, args: argparse.Namespace) -> None:
|
||||
try:
|
||||
|
|
@ -141,7 +143,7 @@ Configuration changes will affect how cognee processes and stores data.
|
|||
fmt.error(f"Failed to set configuration key '{args.key}'")
|
||||
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to set configuration: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to set configuration: {str(e)}") from e
|
||||
|
||||
def _handle_unset(self, args: argparse.Namespace) -> None:
|
||||
try:
|
||||
|
|
@ -189,7 +191,7 @@ Configuration changes will affect how cognee processes and stores data.
|
|||
fmt.note("Use 'cognee config list' to see all available configuration options")
|
||||
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to unset configuration: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to unset configuration: {str(e)}") from e
|
||||
|
||||
def _handle_list(self, args: argparse.Namespace) -> None:
|
||||
try:
|
||||
|
|
@ -209,7 +211,7 @@ Configuration changes will affect how cognee processes and stores data.
|
|||
fmt.echo(" cognee config reset - Reset all to defaults")
|
||||
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to list configuration: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to list configuration: {str(e)}") from e
|
||||
|
||||
def _handle_reset(self, args: argparse.Namespace) -> None:
|
||||
try:
|
||||
|
|
@ -222,4 +224,4 @@ Configuration changes will affect how cognee processes and stores data.
|
|||
fmt.echo("This would reset all settings to their default values")
|
||||
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to reset configuration: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to reset configuration: {str(e)}") from e
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from cognee.cli.reference import SupportsCliCommand
|
|||
from cognee.cli import DEFAULT_DOCS_URL
|
||||
import cognee.cli.echo as fmt
|
||||
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
|
||||
from cognee.modules.data.methods.get_deletion_counts import get_deletion_counts
|
||||
|
||||
|
||||
class DeleteCommand(SupportsCliCommand):
|
||||
|
|
@ -41,7 +42,34 @@ Be careful with deletion operations as they are irreversible.
|
|||
fmt.error("Please specify what to delete: --dataset-name, --user-id, or --all")
|
||||
return
|
||||
|
||||
# Build confirmation message
|
||||
# If --force is used, skip the preview and go straight to deletion
|
||||
if not args.force:
|
||||
# --- START PREVIEW LOGIC ---
|
||||
fmt.echo("Gathering data for preview...")
|
||||
try:
|
||||
preview_data = asyncio.run(
|
||||
get_deletion_counts(
|
||||
dataset_name=args.dataset_name,
|
||||
user_id=args.user_id,
|
||||
all_data=args.all,
|
||||
)
|
||||
)
|
||||
except CliCommandException as e:
|
||||
fmt.error(f"Error occured when fetching preview data: {str(e)}")
|
||||
return
|
||||
|
||||
if not preview_data:
|
||||
fmt.success("No data found to delete.")
|
||||
return
|
||||
|
||||
fmt.echo("You are about to delete:")
|
||||
fmt.echo(
|
||||
f"Datasets: {preview_data.datasets}\nEntries: {preview_data.entries}\nUsers: {preview_data.users}"
|
||||
)
|
||||
fmt.echo("-" * 20)
|
||||
# --- END PREVIEW LOGIC ---
|
||||
|
||||
# Build operation message for success/failure logging
|
||||
if args.all:
|
||||
confirm_msg = "Delete ALL data from cognee?"
|
||||
operation = "all data"
|
||||
|
|
@ -51,8 +79,9 @@ Be careful with deletion operations as they are irreversible.
|
|||
elif args.user_id:
|
||||
confirm_msg = f"Delete all data for user '{args.user_id}'?"
|
||||
operation = f"data for user '{args.user_id}'"
|
||||
else:
|
||||
operation = "data"
|
||||
|
||||
# Confirm deletion unless forced
|
||||
if not args.force:
|
||||
fmt.warning("This operation is irreversible!")
|
||||
if not fmt.confirm(confirm_msg):
|
||||
|
|
@ -64,17 +93,20 @@ Be careful with deletion operations as they are irreversible.
|
|||
# Run the async delete function
|
||||
async def run_delete():
|
||||
try:
|
||||
# NOTE: The underlying cognee.delete() function is currently not working as expected.
|
||||
# This is a separate bug that this preview feature helps to expose.
|
||||
if args.all:
|
||||
await cognee.delete(dataset_name=None, user_id=args.user_id)
|
||||
else:
|
||||
await cognee.delete(dataset_name=args.dataset_name, user_id=args.user_id)
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to delete: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to delete: {str(e)}") from e
|
||||
|
||||
asyncio.run(run_delete())
|
||||
# This success message may be inaccurate due to the underlying bug, but we leave it for now.
|
||||
fmt.success(f"Successfully deleted {operation}")
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, CliCommandInnerException):
|
||||
raise CliCommandException(str(e), error_code=1)
|
||||
raise CliCommandException(f"Error deleting data: {str(e)}", error_code=1)
|
||||
raise CliCommandException(str(e), error_code=1) from e
|
||||
raise CliCommandException(f"Error deleting data: {str(e)}", error_code=1) from e
|
||||
|
|
|
|||
|
|
@ -31,10 +31,6 @@ Search Types & Use Cases:
|
|||
Traditional RAG using document chunks without graph structure.
|
||||
Best for: Direct document retrieval, specific fact-finding.
|
||||
|
||||
**INSIGHTS**:
|
||||
Structured entity relationships and semantic connections.
|
||||
Best for: Understanding concept relationships, knowledge mapping.
|
||||
|
||||
**CHUNKS**:
|
||||
Raw text segments that match the query semantically.
|
||||
Best for: Finding specific passages, citations, exact content.
|
||||
|
|
@ -108,7 +104,7 @@ Search Types & Use Cases:
|
|||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to search: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to search: {str(e)}") from e
|
||||
|
||||
results = asyncio.run(run_search())
|
||||
|
||||
|
|
@ -145,5 +141,5 @@ Search Types & Use Cases:
|
|||
|
||||
except Exception as e:
|
||||
if isinstance(e, CliCommandInnerException):
|
||||
raise CliCommandException(str(e), error_code=1)
|
||||
raise CliCommandException(f"Error searching: {str(e)}", error_code=1)
|
||||
raise CliCommandException(str(e), error_code=1) from e
|
||||
raise CliCommandException(f"Error searching: {str(e)}", error_code=1) from e
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ COMMAND_DESCRIPTIONS = {
|
|||
SEARCH_TYPE_CHOICES = [
|
||||
"GRAPH_COMPLETION",
|
||||
"RAG_COMPLETION",
|
||||
"INSIGHTS",
|
||||
"CHUNKS",
|
||||
"SUMMARIES",
|
||||
"CODE",
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ from cognee.modules.users.methods import get_user
|
|||
# for different async tasks, threads and processes
|
||||
vector_db_config = ContextVar("vector_db_config", default=None)
|
||||
graph_db_config = ContextVar("graph_db_config", default=None)
|
||||
soup_crawler_config = ContextVar("soup_crawler_config", default=None)
|
||||
tavily_config = ContextVar("tavily_config", default=None)
|
||||
|
||||
|
||||
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
|
||||
|
|
|
|||
|
|
@ -50,26 +50,26 @@ class GraphConfig(BaseSettings):
|
|||
# Model validator updates graph_filename and path dynamically after class creation based on current database provider
|
||||
# If no specific graph_filename or path are provided
|
||||
@pydantic.model_validator(mode="after")
|
||||
def fill_derived(cls, values):
|
||||
provider = values.graph_database_provider.lower()
|
||||
def fill_derived(self):
|
||||
provider = self.graph_database_provider.lower()
|
||||
base_config = get_base_config()
|
||||
|
||||
# Set default filename if no filename is provided
|
||||
if not values.graph_filename:
|
||||
values.graph_filename = f"cognee_graph_{provider}"
|
||||
if not self.graph_filename:
|
||||
self.graph_filename = f"cognee_graph_{provider}"
|
||||
|
||||
# Handle graph file path
|
||||
if values.graph_file_path:
|
||||
if self.graph_file_path:
|
||||
# Check if absolute path is provided
|
||||
values.graph_file_path = ensure_absolute_path(
|
||||
os.path.join(values.graph_file_path, values.graph_filename)
|
||||
self.graph_file_path = ensure_absolute_path(
|
||||
os.path.join(self.graph_file_path, self.graph_filename)
|
||||
)
|
||||
else:
|
||||
# Default path
|
||||
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
|
||||
values.graph_file_path = os.path.join(databases_directory_path, values.graph_filename)
|
||||
self.graph_file_path = os.path.join(databases_directory_path, self.graph_filename)
|
||||
|
||||
return values
|
||||
return self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -44,16 +44,14 @@ def create_graph_engine(
|
|||
Parameters:
|
||||
-----------
|
||||
|
||||
- graph_database_provider: The type of graph database provider to use (e.g., neo4j,
|
||||
falkordb, kuzu).
|
||||
- graph_database_url: The URL for the graph database instance. Required for neo4j
|
||||
and falkordb providers.
|
||||
- graph_database_provider: The type of graph database provider to use (e.g., neo4j, falkor, kuzu).
|
||||
- graph_database_url: The URL for the graph database instance. Required for neo4j and falkordb providers.
|
||||
- graph_database_username: The username for authentication with the graph database.
|
||||
Required for neo4j provider.
|
||||
- graph_database_password: The password for authentication with the graph database.
|
||||
Required for neo4j provider.
|
||||
- graph_database_port: The port number for the graph database connection. Required
|
||||
for the falkordb provider.
|
||||
for the falkordb provider
|
||||
- graph_file_path: The filesystem path to the graph file. Required for the kuzu
|
||||
provider.
|
||||
|
||||
|
|
@ -86,21 +84,6 @@ def create_graph_engine(
|
|||
graph_database_name=graph_database_name or None,
|
||||
)
|
||||
|
||||
elif graph_database_provider == "falkordb":
|
||||
if not (graph_database_url and graph_database_port):
|
||||
raise EnvironmentError("Missing required FalkorDB credentials.")
|
||||
|
||||
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
|
||||
from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
||||
|
||||
embedding_engine = get_embedding_engine()
|
||||
|
||||
return FalkorDBAdapter(
|
||||
database_url=graph_database_url,
|
||||
database_port=graph_database_port,
|
||||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
elif graph_database_provider == "kuzu":
|
||||
if not graph_file_path:
|
||||
raise EnvironmentError("Missing required Kuzu database path.")
|
||||
|
|
@ -179,5 +162,5 @@ def create_graph_engine(
|
|||
|
||||
raise EnvironmentError(
|
||||
f"Unsupported graph database provider: {graph_database_provider}. "
|
||||
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'falkordb', 'kuzu', 'kuzu-remote', 'memgraph', 'neptune', 'neptune_analytics'])}"
|
||||
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'kuzu', 'kuzu-remote', 'memgraph', 'neptune', 'neptune_analytics'])}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -48,6 +48,29 @@ class KuzuAdapter(GraphDBInterface):
|
|||
|
||||
def _initialize_connection(self) -> None:
|
||||
"""Initialize the Kuzu database connection and schema."""
|
||||
|
||||
def _install_json_extension():
|
||||
"""
|
||||
Function handles installing of the json extension for the current Kuzu version.
|
||||
This has to be done with an empty graph db before connecting to an existing database otherwise
|
||||
missing json extension errors will be raised.
|
||||
"""
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=True) as temp_file:
|
||||
temp_graph_file = temp_file.name
|
||||
tmp_db = Database(
|
||||
temp_graph_file,
|
||||
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
|
||||
max_db_size=4096 * 1024 * 1024,
|
||||
)
|
||||
tmp_db.init_database()
|
||||
connection = Connection(tmp_db)
|
||||
connection.execute("INSTALL JSON;")
|
||||
except Exception as e:
|
||||
logger.info(f"JSON extension already installed or not needed: {e}")
|
||||
|
||||
_install_json_extension()
|
||||
|
||||
try:
|
||||
if "s3://" in self.db_path:
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
|
||||
|
|
@ -109,11 +132,6 @@ class KuzuAdapter(GraphDBInterface):
|
|||
self.db.init_database()
|
||||
self.connection = Connection(self.db)
|
||||
|
||||
try:
|
||||
self.connection.execute("INSTALL JSON;")
|
||||
except Exception as e:
|
||||
logger.info(f"JSON extension already installed or not needed: {e}")
|
||||
|
||||
try:
|
||||
self.connection.execute("LOAD EXTENSION JSON;")
|
||||
logger.info("Loaded JSON extension")
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
auth=auth,
|
||||
max_connection_lifetime=120,
|
||||
notifications_min_severity="OFF",
|
||||
keep_alive=True,
|
||||
)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
@ -205,7 +206,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
{
|
||||
"node_id": str(node.id),
|
||||
"label": type(node).__name__,
|
||||
"properties": self.serialize_properties(node.model_dump()),
|
||||
"properties": self.serialize_properties(dict(node)),
|
||||
}
|
||||
for node in nodes
|
||||
]
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from cognee.infrastructure.utils.calculate_backoff import calculate_backoff
|
|||
logger = get_logger("deadlock_retry")
|
||||
|
||||
|
||||
def deadlock_retry(max_retries=5):
|
||||
def deadlock_retry(max_retries=10):
|
||||
"""
|
||||
Decorator that automatically retries an asynchronous function when rate limit errors occur.
|
||||
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ def parse_neptune_url(url: str) -> Tuple[str, str]:
|
|||
return graph_id, region
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse Neptune Analytics URL '{url}': {str(e)}")
|
||||
raise ValueError(f"Failed to parse Neptune Analytics URL '{url}': {str(e)}") from e
|
||||
|
||||
|
||||
def validate_graph_id(graph_id: str) -> bool:
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -234,7 +234,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
|||
collection_name: str,
|
||||
query_text: Optional[str] = None,
|
||||
query_vector: Optional[List[float]] = None,
|
||||
limit: int = None,
|
||||
limit: Optional[int] = None,
|
||||
with_vector: bool = False,
|
||||
):
|
||||
"""
|
||||
|
|
@ -265,10 +265,10 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
|||
"Use this option only when vector data is required."
|
||||
)
|
||||
|
||||
# In the case of excessive limit, or zero / negative value, limit will be set to 10.
|
||||
# In the case of excessive limit, or None / zero / negative value, limit will be set to 10.
|
||||
if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND:
|
||||
logger.warning(
|
||||
"Provided limit (%s) is invalid (zero, negative, or exceeds maximum). "
|
||||
"Provided limit (%s) is invalid (None, zero, negative, or exceeds maximum). "
|
||||
"Defaulting to limit=10.",
|
||||
limit,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,14 +23,14 @@ class RelationalConfig(BaseSettings):
|
|||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
@pydantic.model_validator(mode="after")
|
||||
def fill_derived(cls, values):
|
||||
def fill_derived(self):
|
||||
# Set file path based on graph database provider if no file path is provided
|
||||
if not values.db_path:
|
||||
if not self.db_path:
|
||||
base_config = get_base_config()
|
||||
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
|
||||
values.db_path = databases_directory_path
|
||||
self.db_path = databases_directory_path
|
||||
|
||||
return values
|
||||
return self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -283,7 +283,7 @@ class SQLAlchemyAdapter:
|
|||
try:
|
||||
data_entity = (await session.scalars(select(Data).where(Data.id == data_id))).one()
|
||||
except (ValueError, NoResultFound) as e:
|
||||
raise EntityNotFoundError(message=f"Entity not found: {str(e)}")
|
||||
raise EntityNotFoundError(message=f"Entity not found: {str(e)}") from e
|
||||
|
||||
# Check if other data objects point to the same raw data location
|
||||
raw_data_location_entities = (
|
||||
|
|
|
|||
|
|
@ -352,7 +352,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
collection_name: str,
|
||||
query_text: str = None,
|
||||
query_vector: List[float] = None,
|
||||
limit: int = 15,
|
||||
limit: Optional[int] = 15,
|
||||
with_vector: bool = False,
|
||||
normalized: bool = True,
|
||||
):
|
||||
|
|
@ -386,9 +386,13 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
try:
|
||||
collection = await self.get_collection(collection_name)
|
||||
|
||||
if limit == 0:
|
||||
if limit is None:
|
||||
limit = await collection.count()
|
||||
|
||||
# If limit is still 0, no need to do the search, just return empty results
|
||||
if limit <= 0:
|
||||
return []
|
||||
|
||||
results = await collection.query(
|
||||
query_embeddings=[query_vector],
|
||||
include=["metadatas", "distances", "embeddings"]
|
||||
|
|
@ -428,7 +432,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
for row in vector_list
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Error in search: {str(e)}")
|
||||
logger.warning(f"Error in search: {str(e)}")
|
||||
return []
|
||||
|
||||
async def batch_search(
|
||||
|
|
|
|||
|
|
@ -30,21 +30,21 @@ class VectorConfig(BaseSettings):
|
|||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
@pydantic.model_validator(mode="after")
|
||||
def validate_paths(cls, values):
|
||||
def validate_paths(self):
|
||||
base_config = get_base_config()
|
||||
|
||||
# If vector_db_url is provided and is not a path skip checking if path is absolute (as it can also be a url)
|
||||
if values.vector_db_url and Path(values.vector_db_url).exists():
|
||||
if self.vector_db_url and Path(self.vector_db_url).exists():
|
||||
# Relative path to absolute
|
||||
values.vector_db_url = ensure_absolute_path(
|
||||
values.vector_db_url,
|
||||
self.vector_db_url = ensure_absolute_path(
|
||||
self.vector_db_url,
|
||||
)
|
||||
elif not values.vector_db_url:
|
||||
elif not self.vector_db_url:
|
||||
# Default path
|
||||
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
|
||||
values.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
|
||||
self.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
|
||||
|
||||
return values
|
||||
return self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -19,8 +19,7 @@ def create_vector_engine(
|
|||
for each provider, raising an EnvironmentError if any are missing, or ImportError if the
|
||||
ChromaDB package is not installed.
|
||||
|
||||
Supported providers include: pgvector, FalkorDB, ChromaDB, and
|
||||
LanceDB.
|
||||
Supported providers include: pgvector, ChromaDB, and LanceDB.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
|
@ -79,18 +78,6 @@ def create_vector_engine(
|
|||
embedding_engine,
|
||||
)
|
||||
|
||||
elif vector_db_provider == "falkordb":
|
||||
if not (vector_db_url and vector_db_port):
|
||||
raise EnvironmentError("Missing requred FalkorDB credentials!")
|
||||
|
||||
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
||||
|
||||
return FalkorDBAdapter(
|
||||
database_url=vector_db_url,
|
||||
database_port=vector_db_port,
|
||||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
elif vector_db_provider == "chromadb":
|
||||
try:
|
||||
import chromadb
|
||||
|
|
|
|||
|
|
@ -34,3 +34,12 @@ class EmbeddingEngine(Protocol):
|
|||
- int: An integer representing the number of dimensions in the embedding vector.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_batch_size(self) -> int:
|
||||
"""
|
||||
Return the desired batch size for embedding calls
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -42,11 +42,13 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
|
|||
model: Optional[str] = "openai/text-embedding-3-large",
|
||||
dimensions: Optional[int] = 3072,
|
||||
max_completion_tokens: int = 512,
|
||||
batch_size: int = 100,
|
||||
):
|
||||
self.model = model
|
||||
self.dimensions = dimensions
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.batch_size = batch_size
|
||||
# self.retry_count = 0
|
||||
self.embedding_model = TextEmbedding(model_name=model)
|
||||
|
||||
|
|
@ -88,7 +90,9 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
|
|||
|
||||
except Exception as error:
|
||||
logger.error(f"Embedding error in FastembedEmbeddingEngine: {str(error)}")
|
||||
raise EmbeddingException(f"Failed to index data points using model {self.model}")
|
||||
raise EmbeddingException(
|
||||
f"Failed to index data points using model {self.model}"
|
||||
) from error
|
||||
|
||||
def get_vector_size(self) -> int:
|
||||
"""
|
||||
|
|
@ -101,6 +105,15 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
|
|||
"""
|
||||
return self.dimensions
|
||||
|
||||
def get_batch_size(self) -> int:
|
||||
"""
|
||||
Return the desired batch size for embedding calls
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return self.batch_size
|
||||
|
||||
def get_tokenizer(self):
|
||||
"""
|
||||
Instantiate and return the tokenizer used for preparing text for embedding.
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
endpoint: str = None,
|
||||
api_version: str = None,
|
||||
max_completion_tokens: int = 512,
|
||||
batch_size: int = 100,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.endpoint = endpoint
|
||||
|
|
@ -68,6 +69,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
self.max_completion_tokens = max_completion_tokens
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.retry_count = 0
|
||||
self.batch_size = batch_size
|
||||
|
||||
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
|
||||
if isinstance(enable_mocking, bool):
|
||||
|
|
@ -148,7 +150,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
litellm.exceptions.NotFoundError,
|
||||
) as e:
|
||||
logger.error(f"Embedding error with model {self.model}: {str(e)}")
|
||||
raise EmbeddingException(f"Failed to index data points using model {self.model}")
|
||||
raise EmbeddingException(f"Failed to index data points using model {self.model}") from e
|
||||
|
||||
except Exception as error:
|
||||
logger.error("Error embedding text: %s", str(error))
|
||||
|
|
@ -165,6 +167,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
"""
|
||||
return self.dimensions
|
||||
|
||||
def get_batch_size(self) -> int:
|
||||
"""
|
||||
Return the desired batch size for embedding calls
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return self.batch_size
|
||||
|
||||
def get_tokenizer(self):
|
||||
"""
|
||||
Load and return the appropriate tokenizer for the specified model based on the provider.
|
||||
|
|
@ -183,9 +194,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
model=model, max_completion_tokens=self.max_completion_tokens
|
||||
)
|
||||
elif "gemini" in self.provider.lower():
|
||||
tokenizer = GeminiTokenizer(
|
||||
model=model, max_completion_tokens=self.max_completion_tokens
|
||||
# Since Gemini tokenization needs to send an API request to get the token count we will use TikToken to
|
||||
# count tokens as we calculate tokens word by word
|
||||
tokenizer = TikTokenTokenizer(
|
||||
model=None, max_completion_tokens=self.max_completion_tokens
|
||||
)
|
||||
# Note: Gemini Tokenizer expects an LLM model as input and not the embedding model
|
||||
# tokenizer = GeminiTokenizer(
|
||||
# llm_model=llm_model, max_completion_tokens=self.max_completion_tokens
|
||||
# )
|
||||
elif "mistral" in self.provider.lower():
|
||||
tokenizer = MistralTokenizer(
|
||||
model=model, max_completion_tokens=self.max_completion_tokens
|
||||
|
|
|
|||
|
|
@ -54,12 +54,14 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
|||
max_completion_tokens: int = 512,
|
||||
endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
|
||||
huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
|
||||
batch_size: int = 100,
|
||||
):
|
||||
self.model = model
|
||||
self.dimensions = dimensions
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.endpoint = endpoint
|
||||
self.huggingface_tokenizer_name = huggingface_tokenizer
|
||||
self.batch_size = batch_size
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
|
||||
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
|
||||
|
|
@ -122,6 +124,15 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
|||
"""
|
||||
return self.dimensions
|
||||
|
||||
def get_batch_size(self) -> int:
|
||||
"""
|
||||
Return the desired batch size for embedding calls
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return self.batch_size
|
||||
|
||||
def get_tokenizer(self):
|
||||
"""
|
||||
Load and return a HuggingFace tokenizer for the embedding engine.
|
||||
|
|
|
|||
|
|
@ -19,9 +19,17 @@ class EmbeddingConfig(BaseSettings):
|
|||
embedding_api_key: Optional[str] = None
|
||||
embedding_api_version: Optional[str] = None
|
||||
embedding_max_completion_tokens: Optional[int] = 8191
|
||||
embedding_batch_size: Optional[int] = None
|
||||
huggingface_tokenizer: Optional[str] = None
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
def model_post_init(self, __context) -> None:
|
||||
# If embedding batch size is not defined use 2048 as default for OpenAI and 100 for all other embedding models
|
||||
if not self.embedding_batch_size and self.embedding_provider.lower() == "openai":
|
||||
self.embedding_batch_size = 2048
|
||||
elif not self.embedding_batch_size:
|
||||
self.embedding_batch_size = 100
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
Serialize all embedding configuration settings to a dictionary.
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ def get_embedding_engine() -> EmbeddingEngine:
|
|||
config.embedding_endpoint,
|
||||
config.embedding_api_key,
|
||||
config.embedding_api_version,
|
||||
config.embedding_batch_size,
|
||||
config.huggingface_tokenizer,
|
||||
llm_config.llm_api_key,
|
||||
llm_config.llm_provider,
|
||||
|
|
@ -46,6 +47,7 @@ def create_embedding_engine(
|
|||
embedding_endpoint,
|
||||
embedding_api_key,
|
||||
embedding_api_version,
|
||||
embedding_batch_size,
|
||||
huggingface_tokenizer,
|
||||
llm_api_key,
|
||||
llm_provider,
|
||||
|
|
@ -84,6 +86,7 @@ def create_embedding_engine(
|
|||
model=embedding_model,
|
||||
dimensions=embedding_dimensions,
|
||||
max_completion_tokens=embedding_max_completion_tokens,
|
||||
batch_size=embedding_batch_size,
|
||||
)
|
||||
|
||||
if embedding_provider == "ollama":
|
||||
|
|
@ -95,6 +98,7 @@ def create_embedding_engine(
|
|||
max_completion_tokens=embedding_max_completion_tokens,
|
||||
endpoint=embedding_endpoint,
|
||||
huggingface_tokenizer=huggingface_tokenizer,
|
||||
batch_size=embedding_batch_size,
|
||||
)
|
||||
|
||||
from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
|
||||
|
|
@ -108,4 +112,5 @@ def create_embedding_engine(
|
|||
model=embedding_model,
|
||||
dimensions=embedding_dimensions,
|
||||
max_completion_tokens=embedding_max_completion_tokens,
|
||||
batch_size=embedding_batch_size,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -226,7 +226,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
collection_name: str,
|
||||
query_text: str = None,
|
||||
query_vector: List[float] = None,
|
||||
limit: int = 15,
|
||||
limit: Optional[int] = 15,
|
||||
with_vector: bool = False,
|
||||
normalized: bool = True,
|
||||
):
|
||||
|
|
@ -238,11 +238,11 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
|
||||
collection = await self.get_collection(collection_name)
|
||||
|
||||
if limit == 0:
|
||||
if limit is None:
|
||||
limit = await collection.count_rows()
|
||||
|
||||
# LanceDB search will break if limit is 0 so we must return
|
||||
if limit == 0:
|
||||
if limit <= 0:
|
||||
return []
|
||||
|
||||
result_values = await collection.vector_search(query_vector).limit(limit).to_list()
|
||||
|
|
@ -265,7 +265,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
self,
|
||||
collection_name: str,
|
||||
query_texts: List[str],
|
||||
limit: int = None,
|
||||
limit: Optional[int] = None,
|
||||
with_vectors: bool = False,
|
||||
):
|
||||
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import List, Optional, get_type_hints
|
|||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
|
||||
from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.exc import ProgrammingError
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
||||
|
|
@ -125,41 +125,42 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
data_point_types = get_type_hints(DataPoint)
|
||||
vector_size = self.embedding_engine.get_vector_size()
|
||||
|
||||
async with self.VECTOR_DB_LOCK:
|
||||
if not await self.has_collection(collection_name):
|
||||
if not await self.has_collection(collection_name):
|
||||
async with self.VECTOR_DB_LOCK:
|
||||
if not await self.has_collection(collection_name):
|
||||
|
||||
class PGVectorDataPoint(Base):
|
||||
"""
|
||||
Represent a point in a vector data space with associated data and vector representation.
|
||||
class PGVectorDataPoint(Base):
|
||||
"""
|
||||
Represent a point in a vector data space with associated data and vector representation.
|
||||
|
||||
This class inherits from Base and is associated with a database table defined by
|
||||
__tablename__. It maintains the following public methods and instance variables:
|
||||
This class inherits from Base and is associated with a database table defined by
|
||||
__tablename__. It maintains the following public methods and instance variables:
|
||||
|
||||
- __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance.
|
||||
- __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance.
|
||||
|
||||
Instance variables:
|
||||
- id: Identifier for the data point, defined by data_point_types.
|
||||
- payload: JSON data associated with the data point.
|
||||
- vector: Vector representation of the data point, with size defined by vector_size.
|
||||
"""
|
||||
Instance variables:
|
||||
- id: Identifier for the data point, defined by data_point_types.
|
||||
- payload: JSON data associated with the data point.
|
||||
- vector: Vector representation of the data point, with size defined by vector_size.
|
||||
"""
|
||||
|
||||
__tablename__ = collection_name
|
||||
__table_args__ = {"extend_existing": True}
|
||||
# PGVector requires one column to be the primary key
|
||||
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
|
||||
payload = Column(JSON)
|
||||
vector = Column(self.Vector(vector_size))
|
||||
__tablename__ = collection_name
|
||||
__table_args__ = {"extend_existing": True}
|
||||
# PGVector requires one column to be the primary key
|
||||
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
|
||||
payload = Column(JSON)
|
||||
vector = Column(self.Vector(vector_size))
|
||||
|
||||
def __init__(self, id, payload, vector):
|
||||
self.id = id
|
||||
self.payload = payload
|
||||
self.vector = vector
|
||||
def __init__(self, id, payload, vector):
|
||||
self.id = id
|
||||
self.payload = payload
|
||||
self.vector = vector
|
||||
|
||||
async with self.engine.begin() as connection:
|
||||
if len(Base.metadata.tables.keys()) > 0:
|
||||
await connection.run_sync(
|
||||
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
|
||||
)
|
||||
async with self.engine.begin() as connection:
|
||||
if len(Base.metadata.tables.keys()) > 0:
|
||||
await connection.run_sync(
|
||||
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
|
||||
)
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(DeadlockDetectedError),
|
||||
|
|
@ -298,7 +299,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
collection_name: str,
|
||||
query_text: Optional[str] = None,
|
||||
query_vector: Optional[List[float]] = None,
|
||||
limit: int = 15,
|
||||
limit: Optional[int] = 15,
|
||||
with_vector: bool = False,
|
||||
) -> List[ScoredResult]:
|
||||
if query_text is None and query_vector is None:
|
||||
|
|
@ -310,6 +311,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
# Get PGVectorDataPoint Table from database
|
||||
PGVectorDataPoint = await self.get_table(collection_name)
|
||||
|
||||
if limit is None:
|
||||
async with self.get_async_session() as session:
|
||||
query = select(func.count()).select_from(PGVectorDataPoint)
|
||||
result = await session.execute(query)
|
||||
limit = result.scalar_one()
|
||||
|
||||
# If limit is still 0, no need to do the search, just return empty results
|
||||
if limit <= 0:
|
||||
return []
|
||||
|
||||
# NOTE: This needs to be initialized in case search doesn't return a value
|
||||
closest_items = []
|
||||
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class VectorDBInterface(Protocol):
|
|||
collection_name: str,
|
||||
query_text: Optional[str],
|
||||
query_vector: Optional[List[float]],
|
||||
limit: int,
|
||||
limit: Optional[int],
|
||||
with_vector: bool = False,
|
||||
):
|
||||
"""
|
||||
|
|
@ -98,7 +98,7 @@ class VectorDBInterface(Protocol):
|
|||
collection.
|
||||
- query_vector (Optional[List[float]]): An optional vector representation for
|
||||
searching the collection.
|
||||
- limit (int): The maximum number of results to return from the search.
|
||||
- limit (Optional[int]): The maximum number of results to return from the search.
|
||||
- with_vector (bool): Whether to return the vector representations with search
|
||||
results. (default False)
|
||||
"""
|
||||
|
|
@ -106,7 +106,11 @@ class VectorDBInterface(Protocol):
|
|||
|
||||
@abstractmethod
|
||||
async def batch_search(
|
||||
self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
|
||||
self,
|
||||
collection_name: str,
|
||||
query_texts: List[str],
|
||||
limit: Optional[int],
|
||||
with_vectors: bool = False,
|
||||
):
|
||||
"""
|
||||
Perform a batch search using multiple text queries against a collection.
|
||||
|
|
@ -116,7 +120,7 @@ class VectorDBInterface(Protocol):
|
|||
|
||||
- collection_name (str): The name of the collection to conduct the batch search in.
|
||||
- query_texts (List[str]): A list of text queries to use for the search.
|
||||
- limit (int): The maximum number of results to return for each query.
|
||||
- limit (Optional[int]): The maximum number of results to return for each query.
|
||||
- with_vectors (bool): Whether to include vector representations with search
|
||||
results. (default False)
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ class S3FileStorage(Storage):
|
|||
self.s3 = s3fs.S3FileSystem(
|
||||
key=s3_config.aws_access_key_id,
|
||||
secret=s3_config.aws_secret_access_key,
|
||||
token=s3_config.aws_session_token,
|
||||
anon=False,
|
||||
endpoint_url=s3_config.aws_endpoint_url,
|
||||
client_kwargs={"region_name": s3_config.aws_region},
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ class S3Config(BaseSettings):
|
|||
aws_endpoint_url: Optional[str] = None
|
||||
aws_access_key_id: Optional[str] = None
|
||||
aws_secret_access_key: Optional[str] = None
|
||||
aws_session_token: Optional[str] = None
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class LLMConfig(BaseSettings):
|
|||
|
||||
structured_output_framework: str = "instructor"
|
||||
llm_provider: str = "openai"
|
||||
llm_model: str = "openai/gpt-4o-mini"
|
||||
llm_model: str = "openai/gpt-5-mini"
|
||||
llm_endpoint: str = ""
|
||||
llm_api_key: Optional[str] = None
|
||||
llm_api_version: Optional[str] = None
|
||||
|
|
@ -48,7 +48,7 @@ class LLMConfig(BaseSettings):
|
|||
llm_max_completion_tokens: int = 16384
|
||||
|
||||
baml_llm_provider: str = "openai"
|
||||
baml_llm_model: str = "gpt-4o-mini"
|
||||
baml_llm_model: str = "gpt-5-mini"
|
||||
baml_llm_endpoint: str = ""
|
||||
baml_llm_api_key: Optional[str] = None
|
||||
baml_llm_temperature: float = 0.0
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ def read_query_prompt(prompt_file_name: str, base_directory: str = None):
|
|||
read due to an error.
|
||||
"""
|
||||
logger = get_logger(level=ERROR)
|
||||
|
||||
try:
|
||||
if base_directory is None:
|
||||
base_directory = get_absolute_path("./infrastructure/llm/prompts")
|
||||
|
|
@ -35,8 +36,8 @@ def read_query_prompt(prompt_file_name: str, base_directory: str = None):
|
|||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
return file.read()
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Error: Prompt file not found. Attempted to read: %s {file_path}")
|
||||
logger.error(f"Error: Prompt file not found. Attempted to read: {file_path}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred: %s {e}")
|
||||
logger.error(f"An error occurred: {e}")
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -10,8 +10,6 @@ Here are the available `SearchType` tools and their specific functions:
|
|||
- Summarizing large amounts of information
|
||||
- Quick understanding of complex subjects
|
||||
|
||||
* **`INSIGHTS`**: The `INSIGHTS` search type discovers connections and relationships between entities in the knowledge graph.
|
||||
|
||||
**Best for:**
|
||||
|
||||
- Discovering how entities are connected
|
||||
|
|
@ -95,9 +93,6 @@ Here are the available `SearchType` tools and their specific functions:
|
|||
Query: "Summarize the key findings from these research papers"
|
||||
Response: `SUMMARIES`
|
||||
|
||||
Query: "What is the relationship between the methodologies used in these papers?"
|
||||
Response: `INSIGHTS`
|
||||
|
||||
Query: "When was Einstein born?"
|
||||
Response: `CHUNKS`
|
||||
|
||||
|
|
|
|||
1
cognee/infrastructure/llm/prompts/test.txt
Normal file
1
cognee/infrastructure/llm/prompts/test.txt
Normal file
|
|
@ -0,0 +1 @@
|
|||
Respond with: test
|
||||
|
|
@ -1,115 +1,155 @@
|
|||
import litellm
|
||||
from pydantic import BaseModel
|
||||
from typing import Type, Optional
|
||||
from litellm import acompletion, JSONSchemaValidationError
|
||||
"""Adapter for Generic API LLM provider API"""
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.observability.get_observe import get_observe
|
||||
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
|
||||
import litellm
|
||||
import instructor
|
||||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from openai import ContentFilterFinishReasonError
|
||||
from litellm.exceptions import ContentPolicyViolationError
|
||||
from instructor.core import InstructorRetryException
|
||||
|
||||
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||
LLMInterface,
|
||||
)
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
|
||||
rate_limit_async,
|
||||
sleep_and_retry_async,
|
||||
)
|
||||
|
||||
logger = get_logger()
|
||||
observe = get_observe()
|
||||
|
||||
|
||||
class GeminiAdapter(LLMInterface):
|
||||
"""
|
||||
Handles interactions with a language model API.
|
||||
Adapter for Gemini API LLM provider.
|
||||
|
||||
Public methods include:
|
||||
- acreate_structured_output
|
||||
- show_prompt
|
||||
This class initializes the API adapter with necessary credentials and configurations for
|
||||
interacting with the gemini LLM models. It provides methods for creating structured outputs
|
||||
based on user input and system prompts.
|
||||
|
||||
Public methods:
|
||||
- acreate_structured_output(text_input: str, system_prompt: str, response_model:
|
||||
Type[BaseModel]) -> BaseModel
|
||||
"""
|
||||
|
||||
MAX_RETRIES = 5
|
||||
name: str
|
||||
model: str
|
||||
api_key: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint,
|
||||
api_key: str,
|
||||
model: str,
|
||||
api_version: str,
|
||||
max_completion_tokens: int,
|
||||
endpoint: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
streaming: bool = False,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
fallback_model: str = None,
|
||||
fallback_api_key: str = None,
|
||||
fallback_endpoint: str = None,
|
||||
):
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.endpoint = endpoint
|
||||
self.api_version = api_version
|
||||
self.streaming = streaming
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
|
||||
@observe(as_type="generation")
|
||||
self.fallback_model = fallback_model
|
||||
self.fallback_api_key = fallback_api_key
|
||||
self.fallback_endpoint = fallback_endpoint
|
||||
|
||||
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
|
||||
|
||||
@sleep_and_retry_async()
|
||||
@rate_limit_async
|
||||
async def acreate_structured_output(
|
||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||
) -> BaseModel:
|
||||
"""
|
||||
Generate structured output from the language model based on the provided input and
|
||||
system prompt.
|
||||
Generate a response from a user query.
|
||||
|
||||
This method handles retries and raises a ValueError if the request fails or the response
|
||||
does not conform to the expected schema, logging errors accordingly.
|
||||
This asynchronous method sends a user query and a system prompt to a language model and
|
||||
retrieves the generated response. It handles API communication and retries up to a
|
||||
specified limit in case of request failures.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- text_input (str): The user input text to generate a response for.
|
||||
- system_prompt (str): The system's prompt or context to influence the language
|
||||
model's generation.
|
||||
- response_model (Type[BaseModel]): A model type indicating the expected format of
|
||||
the response.
|
||||
- text_input (str): The input text from the user to generate a response for.
|
||||
- system_prompt (str): A prompt that provides context or instructions for the
|
||||
response generation.
|
||||
- response_model (Type[BaseModel]): A Pydantic model that defines the structure of
|
||||
the expected response.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- BaseModel: Returns the generated response as an instance of the specified response
|
||||
model.
|
||||
- BaseModel: An instance of the specified response model containing the structured
|
||||
output from the language model.
|
||||
"""
|
||||
|
||||
try:
|
||||
if response_model is str:
|
||||
response_schema = {"type": "string"}
|
||||
else:
|
||||
response_schema = response_model
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
api_key=self.api_key,
|
||||
max_retries=5,
|
||||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
response_model=response_model,
|
||||
)
|
||||
except (
|
||||
ContentFilterFinishReasonError,
|
||||
ContentPolicyViolationError,
|
||||
InstructorRetryException,
|
||||
) as error:
|
||||
if (
|
||||
isinstance(error, InstructorRetryException)
|
||||
and "content management policy" not in str(error).lower()
|
||||
):
|
||||
raise error
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": text_input},
|
||||
]
|
||||
|
||||
try:
|
||||
response = await acompletion(
|
||||
model=f"{self.model}",
|
||||
messages=messages,
|
||||
api_key=self.api_key,
|
||||
max_completion_tokens=self.max_completion_tokens,
|
||||
temperature=0.1,
|
||||
response_format=response_schema,
|
||||
timeout=100,
|
||||
num_retries=self.MAX_RETRIES,
|
||||
if not (self.fallback_model and self.fallback_api_key and self.fallback_endpoint):
|
||||
raise ContentPolicyFilterError(
|
||||
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||
)
|
||||
|
||||
if response.choices and response.choices[0].message.content:
|
||||
content = response.choices[0].message.content
|
||||
if response_model is str:
|
||||
return content
|
||||
return response_model.model_validate_json(content)
|
||||
|
||||
except litellm.exceptions.BadRequestError as e:
|
||||
logger.error(f"Bad request error: {str(e)}")
|
||||
raise ValueError(f"Invalid request: {str(e)}")
|
||||
|
||||
raise ValueError("Failed to get valid response after retries")
|
||||
|
||||
except JSONSchemaValidationError as e:
|
||||
logger.error(f"Schema validation failed: {str(e)}")
|
||||
logger.debug(f"Raw response: {e.raw_response}")
|
||||
raise ValueError(f"Response failed schema validation: {str(e)}")
|
||||
try:
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.fallback_model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
max_retries=5,
|
||||
api_key=self.fallback_api_key,
|
||||
api_base=self.fallback_endpoint,
|
||||
response_model=response_model,
|
||||
)
|
||||
except (
|
||||
ContentFilterFinishReasonError,
|
||||
ContentPolicyViolationError,
|
||||
InstructorRetryException,
|
||||
) as error:
|
||||
if (
|
||||
isinstance(error, InstructorRetryException)
|
||||
and "content management policy" not in str(error).lower()
|
||||
):
|
||||
raise error
|
||||
else:
|
||||
raise ContentPolicyFilterError(
|
||||
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Type
|
|||
from pydantic import BaseModel
|
||||
from openai import ContentFilterFinishReasonError
|
||||
from litellm.exceptions import ContentPolicyViolationError
|
||||
from instructor.exceptions import InstructorRetryException
|
||||
from instructor.core import InstructorRetryException
|
||||
|
||||
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||
|
|
@ -56,9 +56,7 @@ class GenericAPIAdapter(LLMInterface):
|
|||
self.fallback_api_key = fallback_api_key
|
||||
self.fallback_endpoint = fallback_endpoint
|
||||
|
||||
self.aclient = instructor.from_litellm(
|
||||
litellm.acompletion, mode=instructor.Mode.JSON, api_key=api_key
|
||||
)
|
||||
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
|
||||
|
||||
@sleep_and_retry_async()
|
||||
@rate_limit_async
|
||||
|
|
@ -102,6 +100,7 @@ class GenericAPIAdapter(LLMInterface):
|
|||
},
|
||||
],
|
||||
max_retries=5,
|
||||
api_key=self.api_key,
|
||||
api_base=self.endpoint,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
|
@ -119,7 +118,7 @@ class GenericAPIAdapter(LLMInterface):
|
|||
if not (self.fallback_model and self.fallback_api_key and self.fallback_endpoint):
|
||||
raise ContentPolicyFilterError(
|
||||
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||
)
|
||||
) from error
|
||||
|
||||
try:
|
||||
return await self.aclient.chat.completions.create(
|
||||
|
|
@ -152,4 +151,4 @@ class GenericAPIAdapter(LLMInterface):
|
|||
else:
|
||||
raise ContentPolicyFilterError(
|
||||
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||
)
|
||||
) from error
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ class LLMProvider(Enum):
|
|||
- ANTHROPIC: Represents the Anthropic provider.
|
||||
- CUSTOM: Represents a custom provider option.
|
||||
- GEMINI: Represents the Gemini provider.
|
||||
- MISTRAL: Represents the Mistral AI provider.
|
||||
"""
|
||||
|
||||
OPENAI = "openai"
|
||||
|
|
@ -30,6 +31,7 @@ class LLMProvider(Enum):
|
|||
ANTHROPIC = "anthropic"
|
||||
CUSTOM = "custom"
|
||||
GEMINI = "gemini"
|
||||
MISTRAL = "mistral"
|
||||
|
||||
|
||||
def get_llm_client(raise_api_key_error: bool = True):
|
||||
|
|
@ -143,7 +145,36 @@ def get_llm_client(raise_api_key_error: bool = True):
|
|||
max_completion_tokens=max_completion_tokens,
|
||||
endpoint=llm_config.llm_endpoint,
|
||||
api_version=llm_config.llm_api_version,
|
||||
streaming=llm_config.llm_streaming,
|
||||
)
|
||||
|
||||
elif provider == LLMProvider.MISTRAL:
|
||||
if llm_config.llm_api_key is None:
|
||||
raise LLMAPIKeyNotSetError()
|
||||
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
|
||||
MistralAdapter,
|
||||
)
|
||||
|
||||
return MistralAdapter(
|
||||
api_key=llm_config.llm_api_key,
|
||||
model=llm_config.llm_model,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
endpoint=llm_config.llm_endpoint,
|
||||
)
|
||||
|
||||
elif provider == LLMProvider.MISTRAL:
|
||||
if llm_config.llm_api_key is None:
|
||||
raise LLMAPIKeyNotSetError()
|
||||
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
|
||||
MistralAdapter,
|
||||
)
|
||||
|
||||
return MistralAdapter(
|
||||
api_key=llm_config.llm_api_key,
|
||||
model=llm_config.llm_model,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
endpoint=llm_config.llm_endpoint,
|
||||
)
|
||||
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,129 @@
|
|||
import litellm
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
from typing import Type, Optional
|
||||
from litellm import acompletion, JSONSchemaValidationError
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.observability.get_observe import get_observe
|
||||
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||
LLMInterface,
|
||||
)
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
|
||||
rate_limit_async,
|
||||
sleep_and_retry_async,
|
||||
)
|
||||
|
||||
logger = get_logger()
|
||||
observe = get_observe()
|
||||
|
||||
|
||||
class MistralAdapter(LLMInterface):
    """
    Adapter for the Mistral AI API (via litellm + instructor) for structured
    output generation and prompt display.

    Public methods:

    - acreate_structured_output
    - show_prompt
    """

    name = "Mistral"
    model: str
    api_key: str
    max_completion_tokens: int

    def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None):
        """
        Parameters:
        -----------
        - api_key (str): API key for the Mistral calls; falls back to the key
          from the global LLM config when not provided.
        - model (str): Model identifier passed through to litellm.
        - max_completion_tokens (int): Upper bound for generated tokens.
        - endpoint (str): Optional custom API base URL.
        """
        self.model = model
        self.max_completion_tokens = max_completion_tokens
        # Fix: honor the explicitly passed api_key instead of always reading the
        # global config (the original accepted the parameter but ignored it).
        self.api_key = api_key or get_llm_config().llm_api_key
        # Fix: keep the endpoint so it can be forwarded to litellm (the
        # original accepted it and silently dropped it).
        self.endpoint = endpoint
        # NOTE: the dead `from mistralai import Mistral` import was removed —
        # this adapter talks to Mistral through litellm, not the mistralai SDK,
        # and the import could raise ImportError for no benefit.

        self.aclient = instructor.from_litellm(
            litellm.acompletion,
            mode=instructor.Mode.MISTRAL_TOOLS,
            api_key=self.api_key,
        )

    @sleep_and_retry_async()
    @rate_limit_async
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
        """
        Generate a structured response from the user query.

        Parameters:
        -----------
        - text_input (str): The input text from the user to be processed.
        - system_prompt (str): A prompt that sets the context for the query.
        - response_model (Type[BaseModel]): The model the response is parsed into.

        Returns:
        --------
        - BaseModel: An instance of response_model containing the structured response.

        Raises:
        -------
        - ValueError: On bad requests, schema-validation failures, or when no
          usable response is produced.
        """
        messages = [
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": f"""Use the given format to extract information
                        from the following input: {text_input}""",
            },
        ]
        try:
            try:
                response = await self.aclient.chat.completions.create(
                    model=self.model,
                    max_tokens=self.max_completion_tokens,
                    max_retries=5,
                    messages=messages,
                    response_model=response_model,
                    # NOTE(review): forward the custom endpoint; None keeps
                    # litellm's default base URL — confirm against deployment.
                    api_base=self.endpoint,
                )
                # Fix: instructor returns an already-parsed response_model
                # instance when response_model is given; the original treated it
                # as a raw completion and crashed on `.choices`.
                if isinstance(response, response_model):
                    return response

                # Defensive: also handle a raw completion object.
                choices = getattr(response, "choices", None)
                if choices and choices[0].message.content:
                    return response_model.model_validate_json(choices[0].message.content)

                raise ValueError("Failed to get valid response after retries")
            except litellm.exceptions.BadRequestError as e:
                logger.error(f"Bad request error: {str(e)}")
                raise ValueError(f"Invalid request: {str(e)}") from e
        except JSONSchemaValidationError as e:
            logger.error(f"Schema validation failed: {str(e)}")
            logger.debug(f"Raw response: {e.raw_response}")
            raise ValueError(f"Response failed schema validation: {str(e)}") from e

    def show_prompt(self, text_input: str, system_prompt: str) -> str:
        """
        Format and display the prompt for a user query.

        Parameters:
        -----------
        - text_input (str): Input text from the user to be included in the prompt.
        - system_prompt (str): Path of the system prompt to read and display.

        Returns:
        --------
        - str: The formatted prompt combining system prompt and user input, or
          None when the prompt path resolves to empty content.

        Raises:
        -------
        - MissingSystemPromptPathError: When no system prompt path is given.
        """
        if not text_input:
            text_input = "No user input provided."
        if not system_prompt:
            raise MissingSystemPromptPathError()

        system_prompt = LLMGateway.read_query_prompt(system_prompt)

        return (
            f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
            if system_prompt
            else None
        )
|
||||
|
|
@ -5,15 +5,13 @@ from typing import Type
|
|||
from pydantic import BaseModel
|
||||
from openai import ContentFilterFinishReasonError
|
||||
from litellm.exceptions import ContentPolicyViolationError
|
||||
from instructor.exceptions import InstructorRetryException
|
||||
from instructor.core import InstructorRetryException
|
||||
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||
LLMInterface,
|
||||
)
|
||||
from cognee.infrastructure.llm.exceptions import (
|
||||
ContentPolicyFilterError,
|
||||
MissingSystemPromptPathError,
|
||||
)
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
|
||||
|
|
@ -29,9 +27,6 @@ observe = get_observe()
|
|||
|
||||
logger = get_logger()
|
||||
|
||||
# Tell litellm to drop unsupported params (e.g., reasoning_effort) when the model does not support them.
|
||||
litellm.drop_params = True
|
||||
|
||||
|
||||
class OpenAIAdapter(LLMInterface):
|
||||
"""
|
||||
|
|
@ -76,8 +71,19 @@ class OpenAIAdapter(LLMInterface):
|
|||
fallback_api_key: str = None,
|
||||
fallback_endpoint: str = None,
|
||||
):
|
||||
self.aclient = instructor.from_litellm(litellm.acompletion)
|
||||
self.client = instructor.from_litellm(litellm.completion)
|
||||
# TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
|
||||
# Make sure all new gpt models will work with this mode as well.
|
||||
if "gpt-5" in model:
|
||||
self.aclient = instructor.from_litellm(
|
||||
litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA
|
||||
)
|
||||
self.client = instructor.from_litellm(
|
||||
litellm.completion, mode=instructor.Mode.JSON_SCHEMA
|
||||
)
|
||||
else:
|
||||
self.aclient = instructor.from_litellm(litellm.acompletion)
|
||||
self.client = instructor.from_litellm(litellm.completion)
|
||||
|
||||
self.transcription_model = transcription_model
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
|
|
@ -135,17 +141,16 @@ class OpenAIAdapter(LLMInterface):
|
|||
api_version=self.api_version,
|
||||
response_model=response_model,
|
||||
max_retries=self.MAX_RETRIES,
|
||||
reasoning_effort="minimal",
|
||||
)
|
||||
except (
|
||||
ContentFilterFinishReasonError,
|
||||
ContentPolicyViolationError,
|
||||
InstructorRetryException,
|
||||
):
|
||||
) as e:
|
||||
if not (self.fallback_model and self.fallback_api_key):
|
||||
raise ContentPolicyFilterError(
|
||||
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||
)
|
||||
) from e
|
||||
|
||||
try:
|
||||
return await self.aclient.chat.completions.create(
|
||||
|
|
@ -178,7 +183,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
else:
|
||||
raise ContentPolicyFilterError(
|
||||
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||
)
|
||||
) from error
|
||||
|
||||
@observe
|
||||
@sleep_and_retry_sync()
|
||||
|
|
@ -223,7 +228,6 @@ class OpenAIAdapter(LLMInterface):
|
|||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
response_model=response_model,
|
||||
reasoning_effort="minimal",
|
||||
max_retries=self.MAX_RETRIES,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import List, Any
|
|||
from ..tokenizer_interface import TokenizerInterface
|
||||
|
||||
|
||||
# NOTE: DEPRECATED as to count tokens you need to send an API request to Google it is too slow to use with Cognee
|
||||
class GeminiTokenizer(TokenizerInterface):
|
||||
"""
|
||||
Implements a tokenizer interface for the Gemini model, managing token extraction and
|
||||
|
|
@ -16,10 +17,10 @@ class GeminiTokenizer(TokenizerInterface):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
llm_model: str,
|
||||
max_completion_tokens: int = 3072,
|
||||
):
|
||||
self.model = model
|
||||
self.llm_model = llm_model
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
|
||||
# Get LLM API key from config
|
||||
|
|
@ -28,12 +29,11 @@ class GeminiTokenizer(TokenizerInterface):
|
|||
get_llm_config,
|
||||
)
|
||||
|
||||
config = get_embedding_config()
|
||||
llm_config = get_llm_config()
|
||||
|
||||
import google.generativeai as genai
|
||||
from google import genai
|
||||
|
||||
genai.configure(api_key=config.embedding_api_key or llm_config.llm_api_key)
|
||||
self.client = genai.Client(api_key=llm_config.llm_api_key)
|
||||
|
||||
def extract_tokens(self, text: str) -> List[Any]:
|
||||
"""
|
||||
|
|
@ -77,6 +77,7 @@ class GeminiTokenizer(TokenizerInterface):
|
|||
|
||||
- int: The number of tokens in the given text.
|
||||
"""
|
||||
import google.generativeai as genai
|
||||
|
||||
return len(genai.embed_content(model=f"models/{self.model}", content=text))
|
||||
tokens_response = self.client.models.count_tokens(model=self.llm_model, contents=text)
|
||||
|
||||
return tokens_response.total_tokens
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ def get_model_max_completion_tokens(model_name: str):
|
|||
max_completion_tokens = litellm.model_cost[model_name]["max_tokens"]
|
||||
logger.debug(f"Max input tokens for {model_name}: {max_completion_tokens}")
|
||||
else:
|
||||
logger.info("Model not found in LiteLLM's model_cost.")
|
||||
logger.debug("Model not found in LiteLLM's model_cost.")
|
||||
|
||||
return max_completion_tokens
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ class LoaderEngine:
|
|||
"image_loader",
|
||||
"audio_loader",
|
||||
"unstructured_loader",
|
||||
"advanced_pdf_loader",
|
||||
]
|
||||
|
||||
def register_loader(self, loader: LoaderInterface) -> bool:
|
||||
|
|
@ -86,7 +87,7 @@ class LoaderEngine:
|
|||
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
|
||||
return loader
|
||||
else:
|
||||
raise ValueError(f"Loader does not exist: {loader_name}")
|
||||
logger.info(f"Skipping {loader_name}: Preferred Loader not registered")
|
||||
|
||||
# Try default priority order
|
||||
for loader_name in self.default_loader_priority:
|
||||
|
|
@ -95,7 +96,9 @@ class LoaderEngine:
|
|||
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
|
||||
return loader
|
||||
else:
|
||||
raise ValueError(f"Loader does not exist: {loader_name}")
|
||||
logger.info(
|
||||
f"Skipping {loader_name}: Loader not registered (in default priority list)."
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -20,3 +20,10 @@ try:
|
|||
__all__.append("UnstructuredLoader")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from .advanced_pdf_loader import AdvancedPdfLoader
|
||||
|
||||
__all__.append("AdvancedPdfLoader")
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
|
|||
244
cognee/infrastructure/loaders/external/advanced_pdf_loader.py
vendored
Normal file
244
cognee/infrastructure/loaders/external/advanced_pdf_loader.py
vendored
Normal file
|
|
@ -0,0 +1,244 @@
|
|||
"""Advanced PDF loader leveraging unstructured for layout-aware extraction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
import asyncio
|
||||
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
|
||||
from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata
|
||||
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from cognee.infrastructure.loaders.external.pypdf_loader import PyPdfLoader
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
try:
|
||||
from unstructured.partition.pdf import partition_pdf
|
||||
except ImportError as e:
|
||||
logger.info(
|
||||
"unstructured[pdf] not installed, can't use AdvancedPdfLoader, will use PyPdfLoader instead."
|
||||
)
|
||||
raise ImportError from e
|
||||
|
||||
|
||||
@dataclass
class _PageBuffer:
    # Accumulates the formatted text segments belonging to one PDF page.
    # page_num is None for elements whose metadata carries no page number.
    page_num: Optional[int]
    # Formatted text chunks (paragraphs, tables, image placeholders) in
    # document order for this page.
    segments: List[str]
|
||||
|
||||
|
||||
class AdvancedPdfLoader(LoaderInterface):
    """
    Layout-aware PDF loader built on the ``unstructured`` library.

    Extracts text, tables, and image placeholders from PDF files page by page.
    On any failure — or when no content is extracted — it falls back to the
    simpler PyPdfLoader.
    """

    @property
    def supported_extensions(self) -> List[str]:
        """File extensions this loader accepts."""
        return ["pdf"]

    @property
    def supported_mime_types(self) -> List[str]:
        """MIME types this loader accepts."""
        return ["application/pdf"]

    @property
    def loader_name(self) -> str:
        """Registry name of this loader."""
        return "advanced_pdf_loader"

    def can_handle(self, extension: str, mime_type: str) -> bool:
        """Return True only when both the extension and the MIME type match."""
        return extension in self.supported_extensions and mime_type in self.supported_mime_types

    async def load(self, file_path: str, strategy: str = "auto", **kwargs: Any) -> str:
        """Load a PDF using unstructured; fall back to PyPdfLoader on any error.

        Args:
            file_path: Path to the document file.
            strategy: Partitioning strategy ("auto", "fast", "hi_res", "ocr_only").
            **kwargs: Additional arguments passed to unstructured's partition_pdf.

        Returns:
            Path of the stored text file containing the extracted content.
        """
        try:
            logger.info(f"Processing PDF: {file_path}")

            with open(file_path, "rb") as f:
                file_metadata = await get_file_metadata(f)

            # Name the ingested file after the original file's content hash so
            # re-ingesting identical content reuses the same stored text.
            storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"

            partition_kwargs: Dict[str, Any] = {
                "filename": file_path,
                "strategy": strategy,
                "infer_table_structure": True,
                "include_page_breaks": False,
                "include_metadata": True,
                **kwargs,
            }
            elements = partition_pdf(**partition_kwargs)

            page_contents = self._format_elements_by_page(elements)

            if not page_contents:
                logger.warning(
                    "AdvancedPdfLoader returned no content. Falling back to PyPDF loader."
                )
                return await self._fallback(file_path, **kwargs)

            full_content = "\n".join(page_contents)

            storage_config = get_storage_config()
            storage = get_file_storage(storage_config["data_root_directory"])

            return await storage.store(storage_file_name, full_content)

        except Exception as exc:
            # Deliberate broad catch: any extraction error degrades to the
            # simpler loader instead of failing the ingestion pipeline.
            logger.warning("Failed to process PDF with AdvancedPdfLoader: %s", exc)
            return await self._fallback(file_path, **kwargs)

    async def _fallback(self, file_path: str, **kwargs: Any) -> str:
        """Delegate loading to the simpler PyPdfLoader."""
        logger.info("Falling back to PyPDF loader for %s", file_path)
        return await PyPdfLoader().load(file_path, **kwargs)

    def _format_elements_by_page(self, elements: List[Any]) -> List[str]:
        """Group formatted element texts by page and render one string per page."""
        page_buffers: List[_PageBuffer] = []
        current_buffer = _PageBuffer(page_num=None, segments=[])

        for element in elements:
            element_dict = self._safe_to_dict(element)
            metadata = element_dict.get("metadata", {})
            page_num = metadata.get("page_number")

            # Flush the buffer whenever the page number changes.
            if current_buffer.page_num != page_num:
                if current_buffer.segments:
                    page_buffers.append(current_buffer)
                current_buffer = _PageBuffer(page_num=page_num, segments=[])

            formatted = self._format_element(element_dict)
            if formatted:
                current_buffer.segments.append(formatted)

        if current_buffer.segments:
            page_buffers.append(current_buffer)

        page_contents: List[str] = []
        for buffer in page_buffers:
            header = f"Page {buffer.page_num}:\n" if buffer.page_num is not None else "Page:"
            page_contents.append(header + "\n\n".join(buffer.segments) + "\n")
        return page_contents

    def _format_element(self, element: Dict[str, Any]) -> str:
        """Render one element; returns "" for ignored (header/footer) elements."""
        # Guard: "type" may be absent, which would crash `.lower()` on None.
        element_type = (element.get("type") or "").lower()
        text = self._clean_text(element.get("text", ""))
        metadata = element.get("metadata", {})

        if element_type == "table":
            return self._format_table_element(element) or text

        if element_type == "image":
            return text or self._format_image_element(metadata)

        # Headers and footers are page boilerplate; drop them. (The original
        # comment said "ignore" but the code still returned their text — fixed.)
        if element_type in ("header", "footer"):
            return ""

        return text

    def _format_table_element(self, element: Dict[str, Any]) -> str:
        """Prefer the HTML rendering of a table; fall back to its plain text."""
        metadata = element.get("metadata", {})
        table_html = metadata.get("text_as_html")

        if table_html:
            return table_html.strip()

        return self._clean_text(element.get("text", ""))

    def _format_image_element(self, metadata: Dict[str, Any]) -> str:
        """Build a textual placeholder for an image, with layout info when known."""
        placeholder = "[Image omitted]"
        details: List[str] = []

        coordinates = metadata.get("coordinates", {})
        points = coordinates.get("points") if isinstance(coordinates, dict) else None
        if points and isinstance(points, tuple) and len(points) == 4:
            leftup = points[0]
            rightdown = points[3]
            if (
                isinstance(leftup, tuple)
                and isinstance(rightdown, tuple)
                and len(leftup) == 2
                and len(rightdown) == 2
            ):
                details.append(
                    f"bbox=({leftup[0]}, {leftup[1]}, {rightdown[0]}, {rightdown[1]})"
                )

        if isinstance(coordinates, dict):
            layout_width = coordinates.get("layout_width")
            layout_height = coordinates.get("layout_height")
            system = coordinates.get("system")
            if layout_width and layout_height and system:
                details.append(
                    f"system={system}, layout_width={layout_width}, layout_height={layout_height}"
                )

        # Fix: the original emitted unbalanced parentheses in the placeholder.
        if details:
            return f"{placeholder} ({', '.join(details)})"
        return placeholder

    def _safe_to_dict(self, element: Any) -> Dict[str, Any]:
        """Convert an unstructured element to a plain dict, defensively."""
        try:
            if hasattr(element, "to_dict"):
                return element.to_dict()
        except Exception:
            # Best-effort conversion; fall through to the manual fallback.
            pass

        fallback_type = getattr(element, "category", None) or type(element).__name__
        metadata = getattr(element, "metadata", {})
        # Fix: unstructured metadata objects are not dicts, but callers rely on
        # dict.get(); normalize to a plain dict here.
        if not isinstance(metadata, dict):
            metadata = metadata.to_dict() if hasattr(metadata, "to_dict") else {}

        return {
            "type": fallback_type,
            "text": getattr(element, "text", ""),
            "metadata": metadata,
        }

    def _clean_text(self, value: Any) -> str:
        """Normalize extracted text: drop non-breaking spaces and strip ends."""
        if value is None:
            return ""
        return str(value).replace("\xa0", " ").strip()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test. Fix: take the PDF path from the command line instead
    # of the previously committed developer-specific absolute path.
    import sys

    if len(sys.argv) < 2:
        print("Usage: python advanced_pdf_loader.py <path-to-pdf>")
        raise SystemExit(1)

    loader = AdvancedPdfLoader()
    asyncio.run(loader.load(sys.argv[1]))
|
||||
|
|
@ -16,3 +16,10 @@ try:
|
|||
supported_loaders[UnstructuredLoader.loader_name] = UnstructuredLoader
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from cognee.infrastructure.loaders.external import AdvancedPdfLoader
|
||||
|
||||
supported_loaders[AdvancedPdfLoader.loader_name] = AdvancedPdfLoader
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -9,7 +9,10 @@ async def get_dataset_data(dataset_id: UUID) -> list[Data]:
|
|||
|
||||
async with db_engine.get_async_session() as session:
|
||||
result = await session.execute(
|
||||
select(Data).join(Data.datasets).filter((Dataset.id == dataset_id))
|
||||
select(Data)
|
||||
.join(Data.datasets)
|
||||
.filter((Dataset.id == dataset_id))
|
||||
.order_by(Data.data_size.desc())
|
||||
)
|
||||
|
||||
data = list(result.scalars().all())
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue