Merge branch 'dev' into batch-document-handling

This commit is contained in:
Vasilije 2025-10-12 13:44:02 +02:00 committed by GitHub
commit c1d633fb75
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
199 changed files with 16428 additions and 22653 deletions

View file

@ -16,7 +16,7 @@
STRUCTURED_OUTPUT_FRAMEWORK="instructor"
LLM_API_KEY="your_api_key"
LLM_MODEL="openai/gpt-4o-mini"
LLM_MODEL="openai/gpt-5-mini"
LLM_PROVIDER="openai"
LLM_ENDPOINT=""
LLM_API_VERSION=""
@ -30,10 +30,13 @@ EMBEDDING_DIMENSIONS=3072
EMBEDDING_MAX_TOKENS=8191
# If embedding key is not provided same key set for LLM_API_KEY will be used
#EMBEDDING_API_KEY="your_api_key"
# Note: OpenAI supports up to 2048 elements and Gemini supports a maximum of 100 elements in an embedding batch,
# Cognee sets the optimal batch size for OpenAI and Gemini, but a custom size can be defined if necessary for other models
#EMBEDDING_BATCH_SIZE=2048
# If using BAML structured output these env variables will be used
BAML_LLM_PROVIDER=openai
BAML_LLM_MODEL="gpt-4o-mini"
BAML_LLM_MODEL="gpt-5-mini"
BAML_LLM_ENDPOINT=""
BAML_LLM_API_KEY="your_api_key"
BAML_LLM_API_VERSION=""
@ -52,18 +55,18 @@ BAML_LLM_API_VERSION=""
################################################################################
# Configure storage backend (local filesystem or S3)
# STORAGE_BACKEND="local" # Default: uses local filesystem
#
#
# -- To switch to S3 storage, uncomment and fill these: ---------------------
# STORAGE_BACKEND="s3"
# STORAGE_BUCKET_NAME="your-bucket-name"
# AWS_REGION="us-east-1"
# AWS_ACCESS_KEY_ID="your-access-key"
# AWS_SECRET_ACCESS_KEY="your-secret-key"
#
#
# -- S3 Root Directories (optional) -----------------------------------------
# DATA_ROOT_DIRECTORY="s3://your-bucket/cognee/data"
# SYSTEM_ROOT_DIRECTORY="s3://your-bucket/cognee/system"
#
#
# -- Cache Directory (auto-configured for S3) -------------------------------
# When STORAGE_BACKEND=s3, cache automatically uses S3: s3://BUCKET/cognee/cache
# To override the automatic S3 cache location, uncomment:
@ -203,6 +206,16 @@ LITELLM_LOG="ERROR"
# DEFAULT_USER_EMAIL=""
# DEFAULT_USER_PASSWORD=""
################################################################################
# 📂 AWS Settings
################################################################################
#AWS_REGION=""
#AWS_ENDPOINT_URL=""
#AWS_ACCESS_KEY_ID=""
#AWS_SECRET_ACCESS_KEY=""
#AWS_SESSION_TOKEN=""
------------------------------- END OF POSSIBLE SETTINGS -------------------------------
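For orientation, here is a minimal sketch of how these settings might be read at startup. It assumes python-dotenv is installed and the variable names match the template above; the default batch sizes are inferred from the provider limits noted in the comment, not taken from Cognee's own code:

```python
# Illustrative sketch only; variable names come from the .env template above,
# and the per-provider batch-size defaults are assumptions based on the note there.
import os

from dotenv import load_dotenv  # assumes python-dotenv is installed

load_dotenv()  # read the .env file from the working directory

llm_model = os.getenv("LLM_MODEL", "openai/gpt-5-mini")
llm_api_key = os.getenv("LLM_API_KEY")
embedding_api_key = os.getenv("EMBEDDING_API_KEY") or llm_api_key  # falls back to the LLM key

# OpenAI accepts up to 2048 embedding inputs per batch, Gemini up to 100
default_batch = 2048 if os.getenv("LLM_PROVIDER", "openai") == "openai" else 100
embedding_batch_size = int(os.getenv("EMBEDDING_BATCH_SIZE", default_batch))
```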

View file

@ -58,7 +58,7 @@ body:
- Python version: [e.g. 3.9.0]
- Cognee version: [e.g. 0.1.0]
- LLM Provider: [e.g. OpenAI, Ollama]
- Database: [e.g. Neo4j, FalkorDB]
- Database: [e.g. Neo4j]
validations:
required: true

View file

@ -41,4 +41,4 @@ runs:
EXTRA_ARGS="$EXTRA_ARGS --extra $extra"
done
fi
uv sync --extra api --extra docs --extra evals --extra gemini --extra codegraph --extra ollama --extra dev --extra neo4j $EXTRA_ARGS
uv sync --extra api --extra docs --extra evals --extra codegraph --extra ollama --extra dev --extra neo4j $EXTRA_ARGS

View file

@ -1,8 +1,8 @@
<!-- .github/pull_request_template.md -->
## Description
<!--
Please provide a clear, human-generated description of the changes in this PR.
<!--
Please provide a clear, human-generated description of the changes in this PR.
DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning.
-->
@ -16,15 +16,6 @@ DO NOT use AI-generated descriptions. We want to understand your thought process
- [ ] Performance improvement
- [ ] Other (please specify):
## Changes Made
<!-- List the specific changes made in this PR -->
-
-
-
## Testing
<!-- Describe how you tested your changes -->
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
@ -40,11 +31,5 @@ DO NOT use AI-generated descriptions. We want to understand your thought process
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages
## Related Issues
<!-- Link any related issues using "Fixes #issue_number" or "Relates to #issue_number" -->
## Additional Notes
<!-- Add any additional notes, concerns, or context for reviewers -->
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

View file

@ -54,6 +54,10 @@ jobs:
with:
python-version: ${{ inputs.python-version }}
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Run Neo4j Example
env:
ENV: dev
@ -66,9 +70,9 @@ jobs:
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: "neo4j"
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
GRAPH_DATABASE_USERNAME: "neo4j"
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
run: |
uv run python examples/database_examples/neo4j_example.py
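The example script itself is not part of this diff; a hypothetical sketch of how such a script could pick up the variables injected by the `setup_neo4j` action (assuming the official `neo4j` Python driver) might look like this:

```python
# Hypothetical sketch; the real examples/database_examples/neo4j_example.py may differ.
import os

from neo4j import GraphDatabase  # assumes the neo4j driver is installed

driver = GraphDatabase.driver(
    os.environ["GRAPH_DATABASE_URL"],  # e.g. bolt://localhost:7687 from the setup_neo4j action
    auth=(os.environ["GRAPH_DATABASE_USERNAME"], os.environ["GRAPH_DATABASE_PASSWORD"]),
)

with driver.session() as session:
    print(session.run("RETURN 1 AS ok").single()["ok"])  # simple connectivity check
driver.close()
```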

.github/workflows/distributed_test.yml (new file, 73 lines)
View file

@ -0,0 +1,73 @@
name: Distributed Cognee test with modal
permissions:
contents: read
on:
workflow_call:
inputs:
python-version:
required: false
type: string
default: '3.11.x'
secrets:
LLM_MODEL:
required: true
LLM_ENDPOINT:
required: true
LLM_API_KEY:
required: true
LLM_API_VERSION:
required: true
EMBEDDING_MODEL:
required: true
EMBEDDING_ENDPOINT:
required: true
EMBEDDING_API_KEY:
required: true
EMBEDDING_API_VERSION:
required: true
OPENAI_API_KEY:
required: true
jobs:
run-server-start-test:
name: Distributed Cognee test (Modal)
runs-on: ubuntu-22.04
steps:
- name: Check out
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: "distributed postgres"
- name: Run Distributed Cognee (Modal)
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
MODAL_SECRET_NAME: ${{ secrets.MODAL_SECRET_NAME }}
GRAPH_DATABASE_PROVIDER: "neo4j"
GRAPH_DATABASE_URL: ${{ secrets.AZURE_NEO4j_URL }}
GRAPH_DATABASE_USERNAME: ${{ secrets.AZURE_NEO4J_USERNAME }}
GRAPH_DATABASE_PASSWORD: ${{ secrets.AZURE_NEO4J_PW }}
DB_PROVIDER: "postgres"
DB_NAME: ${{ secrets.AZURE_POSTGRES_DB_NAME }}
DB_HOST: ${{ secrets.AZURE_POSTGRES_HOST }}
DB_PORT: ${{ secrets.AZURE_POSTGRES_PORT }}
DB_USERNAME: ${{ secrets.AZURE_POSTGRES_USERNAME }}
DB_PASSWORD: ${{ secrets.AZURE_POSTGRES_PW }}
VECTOR_DB_PROVIDER: "pgvector"
COGNEE_DISTRIBUTED: "true"
run: uv run modal run ./distributed/entrypoint.py
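The `distributed/entrypoint.py` file invoked here is not shown in this diff. Purely as an illustration of the Modal pattern the workflow relies on, a hypothetical entrypoint could look like the following; the app name, secret name, and image contents are all assumptions, and a recent Modal SDK is assumed:

```python
# Hypothetical sketch of a Modal entrypoint; not the repository's actual distributed/entrypoint.py.
import modal

app = modal.App("cognee-distributed-demo")  # assumed app name
image = modal.Image.debian_slim().pip_install("cognee")  # assumed image

@app.function(image=image, secrets=[modal.Secret.from_name("cognee-env")])  # assumed secret name
def cognify_remote(text: str) -> None:
    import asyncio
    import cognee

    async def run() -> None:
        await cognee.add(text)
        await cognee.cognify()

    asyncio.run(run())

@app.local_entrypoint()
def main() -> None:
    cognify_remote.remote("Cognee turns documents into AI memory.")
```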

View file

@ -1,5 +1,8 @@
name: Reusable Examples Tests
permissions:
contents: read
on:
workflow_call:
@ -131,3 +134,53 @@ jobs:
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./examples/python/memify_coding_agent_example.py
test-permissions-example:
name: Run Permissions Example
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Run Memify Tests
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./examples/python/permissions_example.py
test_docling_add:
name: Run Add with Docling Test
runs-on: macos-15
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: 'docling'
- name: Run Docling Test
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_add_docling_document.py
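For context, a minimal sketch of what an add-with-Docling call might look like; the loader name "docling" and the file path are assumptions, and the real test may select the loader differently:

```python
# Minimal sketch; the "docling" loader name and document path are hypothetical.
import asyncio
import cognee

async def main() -> None:
    await cognee.add("./document.pdf", preferred_loaders=["docling"])  # assumed local file
    await cognee.cognify()

asyncio.run(main())
```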

View file

@ -71,6 +71,10 @@ jobs:
with:
python-version: ${{ inputs.python-version }}
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Run default Neo4j
env:
ENV: 'dev'
@ -83,9 +87,9 @@ jobs:
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: "neo4j"
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
GRAPH_DATABASE_USERNAME: "neo4j"
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
run: uv run python ./cognee/tests/test_neo4j.py
- name: Run Weighted Edges Tests with Neo4j

View file

@ -186,6 +186,10 @@ jobs:
python-version: '3.11.x'
extra-dependencies: "postgres"
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Install specific db dependency
run: echo "Dependencies already installed in setup"
@ -206,9 +210,9 @@ jobs:
env:
ENV: 'dev'
GRAPH_DATABASE_PROVIDER: "neo4j"
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
GRAPH_DATABASE_USERNAME: "neo4j"
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}

View file

@ -51,20 +51,6 @@ jobs:
name: Search test for Neo4j/LanceDB/Sqlite
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }}
services:
neo4j:
image: neo4j:5.11
env:
NEO4J_AUTH: neo4j/pleaseletmein
NEO4J_PLUGINS: '["apoc","graph-data-science"]'
ports:
- 7474:7474
- 7687:7687
options: >-
--health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
--health-interval=10s
--health-timeout=5s
--health-retries=5
steps:
- name: Check out
@ -77,6 +63,10 @@ jobs:
with:
python-version: ${{ inputs.python-version }}
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
@ -94,9 +84,9 @@ jobs:
GRAPH_DATABASE_PROVIDER: 'neo4j'
VECTOR_DB_PROVIDER: 'lancedb'
DB_PROVIDER: 'sqlite'
GRAPH_DATABASE_URL: bolt://localhost:7687
GRAPH_DATABASE_USERNAME: neo4j
GRAPH_DATABASE_PASSWORD: pleaseletmein
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
run: uv run python ./cognee/tests/test_search_db.py
run-kuzu-pgvector-postgres-search-tests:
@ -158,19 +148,6 @@ jobs:
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/pgvector/postgres') }}
services:
neo4j:
image: neo4j:5.11
env:
NEO4J_AUTH: neo4j/pleaseletmein
NEO4J_PLUGINS: '["apoc","graph-data-science"]'
ports:
- 7474:7474
- 7687:7687
options: >-
--health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
--health-interval=10s
--health-timeout=5s
--health-retries=5
postgres:
image: pgvector/pgvector:pg17
env:
@ -196,6 +173,10 @@ jobs:
python-version: ${{ inputs.python-version }}
extra-dependencies: "postgres"
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
@ -213,9 +194,9 @@ jobs:
GRAPH_DATABASE_PROVIDER: 'neo4j'
VECTOR_DB_PROVIDER: 'pgvector'
DB_PROVIDER: 'postgres'
GRAPH_DATABASE_URL: bolt://localhost:7687
GRAPH_DATABASE_USERNAME: neo4j
GRAPH_DATABASE_PASSWORD: pleaseletmein
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
DB_NAME: cognee_db
DB_HOST: 127.0.0.1
DB_PORT: 5432

View file

@ -51,20 +51,6 @@ jobs:
name: Temporal Graph test Neo4j (lancedb + sqlite)
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }}
services:
neo4j:
image: neo4j:5.11
env:
NEO4J_AUTH: neo4j/pleaseletmein
NEO4J_PLUGINS: '["apoc","graph-data-science"]'
ports:
- 7474:7474
- 7687:7687
options: >-
--health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
--health-interval=10s
--health-timeout=5s
--health-retries=5
steps:
- name: Check out
@ -77,6 +63,10 @@ jobs:
with:
python-version: ${{ inputs.python-version }}
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
@ -94,9 +84,9 @@ jobs:
GRAPH_DATABASE_PROVIDER: 'neo4j'
VECTOR_DB_PROVIDER: 'lancedb'
DB_PROVIDER: 'sqlite'
GRAPH_DATABASE_URL: bolt://localhost:7687
GRAPH_DATABASE_USERNAME: neo4j
GRAPH_DATABASE_PASSWORD: pleaseletmein
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
run: uv run python ./cognee/tests/test_temporal_graph.py
run_temporal_graph_kuzu_postgres_pgvector:

View file

@ -43,7 +43,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ubuntu-22.04, macos-13, macos-15, windows-latest]
os: [ubuntu-22.04, macos-15, windows-latest]
fail-fast: false
steps:
- name: Check out
@ -79,7 +79,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
os: [ ubuntu-22.04, macos-15, windows-latest ]
fail-fast: false
steps:
- name: Check out
@ -115,7 +115,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
os: [ ubuntu-22.04, macos-15, windows-latest ]
fail-fast: false
steps:
- name: Check out
@ -151,7 +151,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
os: [ ubuntu-22.04, macos-15, windows-latest ]
fail-fast: false
steps:
- name: Check out
@ -180,7 +180,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
os: [ ubuntu-22.04, macos-15, windows-latest ]
fail-fast: false
steps:
- name: Check out
@ -210,7 +210,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
os: [ ubuntu-22.04, macos-15, windows-latest ]
fail-fast: false
steps:
- name: Check out

View file

@ -27,7 +27,7 @@ jobs:
env:
LLM_PROVIDER: "gemini"
LLM_API_KEY: ${{ secrets.GEMINI_API_KEY }}
LLM_MODEL: "gemini/gemini-1.5-flash"
LLM_MODEL: "gemini/gemini-2.0-flash"
EMBEDDING_PROVIDER: "gemini"
EMBEDDING_API_KEY: ${{ secrets.GEMINI_API_KEY }}
EMBEDDING_MODEL: "gemini/text-embedding-004"
@ -83,4 +83,4 @@ jobs:
EMBEDDING_MODEL: "openai/text-embedding-3-large"
EMBEDDING_DIMENSIONS: "3072"
EMBEDDING_MAX_TOKENS: "8191"
run: uv run python ./examples/python/simple_example.py
run: uv run python ./examples/python/simple_example.py

View file

@ -6,8 +6,12 @@ on:
permissions:
contents: read
env:
RUNTIME__LOG_LEVEL: ERROR
ENV: 'dev'
jobs:
test-gemini:
test-s3-storage:
name: Run S3 File Storage Test
runs-on: ubuntu-22.04
steps:

View file

@ -27,6 +27,12 @@ jobs:
uses: ./.github/workflows/e2e_tests.yml
secrets: inherit
distributed-tests:
name: Distributed Cognee Test
needs: [ basic-tests, e2e-tests, graph-db-tests ]
uses: ./.github/workflows/distributed_test.yml
secrets: inherit
cli-tests:
name: CLI Tests
uses: ./.github/workflows/cli_tests.yml
@ -104,7 +110,7 @@ jobs:
db-examples-tests:
name: DB Examples Tests
needs: [vector-db-tests, graph-db-tests, relational-db-migration-tests]
needs: [vector-db-tests, graph-db-tests, relational-db-migration-tests, distributed-tests]
uses: ./.github/workflows/db_examples_tests.yml
secrets: inherit

View file

@ -101,3 +101,30 @@ jobs:
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_pgvector.py
run-lancedb-tests:
name: LanceDB Tests
runs-on: ubuntu-22.04
steps:
- name: Check out
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: ${{ inputs.python-version }}
- name: Run LanceDB Tests
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_lancedb.py
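LanceDB is the default vector store (see the `VECTOR_DB_PROVIDER` note in the `add()` docstring further below), so a minimal sketch of selecting it explicitly is short; the sample text and query are illustrative only:

```python
# Minimal sketch; "lancedb" as the VECTOR_DB_PROVIDER value is documented in the add() docstring below.
import asyncio
import os

os.environ["VECTOR_DB_PROVIDER"] = "lancedb"

import cognee

async def main() -> None:
    await cognee.add("LanceDB backs the default vector store.")
    await cognee.cognify()
    for result in await cognee.search("Which vector store is the default?"):
        print(result)

asyncio.run(main())
```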

View file

@ -86,12 +86,19 @@ jobs:
with:
python-version: '3.11'
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
- name: Run Weighted Edges Tests
env:
GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }}
GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }}
GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }}
GRAPH_DATABASE_PASSWORD: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-password || '' }}
run: |
uv run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v --tb=short

View file

@ -22,6 +22,7 @@ RUN apt-get update && apt-get install -y \
libpq-dev \
git \
curl \
cmake \
clang \
build-essential \
&& rm -rf /var/lib/apt/lists/*
@ -31,7 +32,7 @@ COPY README.md pyproject.toml uv.lock entrypoint.sh ./
# Install the project's dependencies using the lockfile and settings
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra gemini --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-install-project --no-dev --no-editable
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-install-project --no-dev --no-editable
# Copy Alembic configuration
COPY alembic.ini /app/alembic.ini
@ -42,7 +43,7 @@ COPY alembic/ /app/alembic
COPY ./cognee /app/cognee
COPY ./distributed /app/distributed
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra gemini --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-dev --no-editable
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-dev --no-editable
FROM python:3.12-slim-bookworm

README.md (133 lines changed)
View file

@ -5,7 +5,7 @@
<br />
cognee - Memory for AI Agents in 5 lines of code
cognee - Memory for AI Agents in 6 lines of code
<p align="center">
<a href="https://www.youtube.com/watch?v=1bezuvLwJmw&t=2s">Demo</a>
@ -43,12 +43,10 @@
**🚀 We launched Cogwit beta (Fully-hosted AI Memory): Sign up [here](https://platform.cognee.ai/)! 🚀**
Build dynamic memory for Agents and replace RAG using scalable, modular ECL (Extract, Cognify, Load) pipelines.
More on [use-cases](https://docs.cognee.ai/use-cases) and [evals](https://github.com/topoteretes/cognee/tree/main/evals)
<p align="center">
🌐 Available Languages
:
@ -70,53 +68,50 @@ More on [use-cases](https://docs.cognee.ai/use-cases) and [evals](https://github
</div>
## Features
- Interconnect and retrieve your past conversations, documents, images and audio transcriptions
- Replaces RAG systems and reduces developer effort, and cost.
- Load data to graph and vector databases using only Pydantic
- Manipulate your data while ingesting from 30+ data sources
## Get Started
Get started quickly with a Google Colab <a href="https://colab.research.google.com/drive/1jHbWVypDgCLwjE71GSXhRL3YxYhCZzG1?usp=sharing">notebook</a> , <a href="https://deepnote.com/workspace/cognee-382213d0-0444-4c89-8265-13770e333c02/project/cognee-demo-78ffacb9-5832-4611-bb1a-560386068b30/notebook/Notebook-1-75b24cda566d4c24ab348f7150792601?utm_source=share-modal&utm_medium=product-shared-content&utm_campaign=notebook&utm_content=78ffacb9-5832-4611-bb1a-560386068b30">Deepnote notebook</a> or <a href="https://github.com/topoteretes/cognee/tree/main/cognee-starter-kit">starter repo</a>
## About cognee
cognee works locally and stores your data on your device.
Our hosted solution is just our deployment of OSS cognee on Modal, with the goal of making development and productionization easier.
## Contributing
Your contributions are at the core of making this a true open source project. Any contributions you make are **greatly appreciated**. See [`CONTRIBUTING.md`](CONTRIBUTING.md) for more information.
Self-hosted package:
- Interconnects any kind of documents: past conversations, files, images, and audio transcriptions
- Replaces RAG systems with a memory layer based on graphs and vectors
- Reduces developer effort and cost, while increasing quality and precision
- Provides Pythonic data pipelines that manage data ingestion from 30+ data sources
- Is highly customizable with custom tasks, pipelines, and a set of built-in search endpoints
Hosted platform:
- Includes a managed UI and a [hosted solution](https://www.cognee.ai)
## Self-Hosted (Open Source)
## 📦 Installation
### 📦 Installation
You can install Cognee using **pip**, **poetry**, **uv**, or any other Python package manager.
Cognee supports Python 3.10 to 3.13
Cognee supports Python 3.10 to 3.12
### With pip
#### With uv
```bash
pip install cognee
uv pip install cognee
```
## Local Cognee installation
Detailed instructions can be found in our [docs](https://docs.cognee.ai/getting-started/installation#environment-configuration)
You can install the local Cognee repo using **uv**, **pip** and **poetry**.
For a local pip installation, please make sure your pip version is above 21.3.
### 💻 Basic Usage
### With uv and all optional dependencies
```bash
uv sync --all-extras
```
## 💻 Basic Usage
### Setup
#### Setup
```
import os
@ -125,10 +120,14 @@ os.environ["LLM_API_KEY"] = "YOUR OPENAI_API_KEY"
```
You can also set the variables by creating .env file, using our <a href="https://github.com/topoteretes/cognee/blob/main/.env.template">template.</a>
To use different LLM providers, for more info check out our <a href="https://docs.cognee.ai">documentation</a>
To use different LLM providers, for more info check out our <a href="https://docs.cognee.ai/setup-configuration/llm-providers">documentation</a>
### Simple example
#### Simple example
##### Python
This script will run the default pipeline:
@ -139,13 +138,16 @@ import asyncio
async def main():
# Add text to cognee
await cognee.add("Natural language processing (NLP) is an interdisciplinary subfield of computer science and information retrieval.")
await cognee.add("Cognee turns documents into AI memory.")
# Generate the knowledge graph
await cognee.cognify()
# Add memory algorithms to the graph
await cognee.memify()
# Query the knowledge graph
results = await cognee.search("Tell me about NLP")
results = await cognee.search("What does cognee do?")
# Display the results
for result in results:
@ -158,33 +160,38 @@ if __name__ == '__main__':
```
Example output:
```
Natural Language Processing (NLP) is a cross-disciplinary and interdisciplinary field that involves computer science and information retrieval. It focuses on the interaction between computers and human language, enabling machines to understand and process natural language.
Cognee turns documents into AI memory.
```
##### Via CLI
## Our paper is out! <a href="https://arxiv.org/abs/2505.24478" target="_blank" rel="noopener noreferrer">Read here</a>
Let's get the basics covered
```
cognee-cli add "Cognee turns documents into AI memory."
cognee-cli cognify
cognee-cli search "What does cognee do?"
cognee-cli delete --all
```
or run
```
cognee-cli -ui
```
<div style="text-align: center">
<img src="assets/cognee-paper.png" alt="cognee paper" width="100%" />
</div>
</div>
## Cognee UI
You can also cognify your files and query using cognee UI.
### Hosted Platform
<img src="assets/cognee-new-ui.webp" width="100%" alt="Cognee UI 2"></a>
Get up and running in minutes with automatic updates, analytics, and enterprise security.
### Running the UI
1. Sign up on [cogwit](https://www.cognee.ai)
2. Add your API key to the local UI and sync your data to Cogwit
Try cognee UI by setting LLM_API_KEY and running ``` cognee-cli -ui ``` command on your terminal.
## Understand our architecture
<div style="text-align: center">
<img src="assets/cognee_diagram.png" alt="cognee concept diagram" width="100%" />
</div>
@ -203,22 +210,26 @@ Try cognee UI by setting LLM_API_KEY and running ``` cognee-cli -ui ``` command
[cognee with local models](https://github.com/user-attachments/assets/8621d3e8-ecb8-4860-afb2-5594f2ee17db)
## Contributing
Your contributions are at the core of making this a true open source project. Any contributions you make are **greatly appreciated**. See [`CONTRIBUTING.md`](CONTRIBUTING.md) for more information.
## Code of Conduct
We are committed to making open source an enjoyable and respectful experience for our community. See <a href="https://github.com/topoteretes/cognee/blob/main/CODE_OF_CONDUCT.md"><code>CODE_OF_CONDUCT</code></a> for more information.
## 💫 Contributors
## Citation
<a href="https://github.com/topoteretes/cognee/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=topoteretes/cognee"/>
</a>
We now have a paper you can cite:
## Sponsors
Thanks to the following companies for sponsoring the ongoing development of cognee.
- [GitHub's Secure Open Source Fund](https://resources.github.com/github-secure-open-source-fund/)
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=topoteretes/cognee&type=Date)](https://star-history.com/#topoteretes/cognee&Date)
```bibtex
@misc{markovic2025optimizinginterfaceknowledgegraphs,
title={Optimizing the Interface Between Knowledge Graphs and LLMs for Complex Reasoning},
author={Vasilije Markovic and Lazar Obradovic and Laszlo Hajdu and Jovan Pavlovic},
year={2025},
eprint={2505.24478},
archivePrefix={arXiv},
primaryClass={cs.AI},
url={https://arxiv.org/abs/2505.24478},
}
```

View file

@ -3,10 +3,18 @@
import classNames from "classnames";
import { MutableRefObject, useEffect, useImperativeHandle, useRef, useState, useCallback } from "react";
import { forceCollide, forceManyBody } from "d3-force-3d";
import ForceGraph, { ForceGraphMethods, GraphData, LinkObject, NodeObject } from "react-force-graph-2d";
import dynamic from "next/dynamic";
import { GraphControlsAPI } from "./GraphControls";
import getColorForNodeType from "./getColorForNodeType";
// Dynamically import ForceGraph to prevent SSR issues
const ForceGraph = dynamic(() => import("react-force-graph-2d"), {
ssr: false,
loading: () => <div className="w-full h-full flex items-center justify-center">Loading graph...</div>
});
import type { ForceGraphMethods, GraphData, LinkObject, NodeObject } from "react-force-graph-2d";
interface GraphVisuzaliationProps {
ref: MutableRefObject<GraphVisualizationAPI>;
data?: GraphData<NodeObject, LinkObject>;
@ -200,7 +208,7 @@ export default function GraphVisualization({ ref, data, graphControls, className
const graphRef = useRef<ForceGraphMethods>();
useEffect(() => {
if (typeof window !== "undefined" && data && graphRef.current) {
if (data && graphRef.current) {
// add collision force
graphRef.current.d3Force("collision", forceCollide(nodeSize * 1.5));
graphRef.current.d3Force("charge", forceManyBody().strength(-10).distanceMin(10).distanceMax(50));
@ -209,63 +217,55 @@ export default function GraphVisualization({ ref, data, graphControls, className
const [graphShape, setGraphShape] = useState<string>();
const zoomToFit: ForceGraphMethods["zoomToFit"] = (
durationMs?: number,
padding?: number,
nodeFilter?: (node: NodeObject) => boolean
) => {
if (!graphRef.current) {
console.warn("GraphVisualization: graphRef not ready yet");
return undefined as any;
}
return graphRef.current.zoomToFit?.(durationMs, padding, nodeFilter);
};
useImperativeHandle(ref, () => ({
zoomToFit: graphRef.current!.zoomToFit,
setGraphShape: setGraphShape,
zoomToFit,
setGraphShape,
}));
return (
<div ref={containerRef} className={classNames("w-full h-full", className)} id="graph-container">
{(data && typeof window !== "undefined") ? (
<ForceGraph
ref={graphRef}
width={dimensions.width}
height={dimensions.height}
dagMode={graphShape as unknown as undefined}
dagLevelDistance={300}
onDagError={handleDagError}
graphData={data}
<ForceGraph
ref={graphRef}
width={dimensions.width}
height={dimensions.height}
dagMode={graphShape as unknown as undefined}
dagLevelDistance={data ? 300 : 100}
onDagError={handleDagError}
graphData={data || {
nodes: [{ id: 1, label: "Add" }, { id: 2, label: "Cognify" }, { id: 3, label: "Search" }],
links: [{ source: 1, target: 2, label: "but don't forget to" }, { source: 2, target: 3, label: "and after that you can" }],
}}
nodeLabel="label"
nodeRelSize={nodeSize}
nodeCanvasObject={renderNode}
nodeCanvasObjectMode={() => "replace"}
nodeLabel="label"
nodeRelSize={data ? nodeSize : 20}
nodeCanvasObject={data ? renderNode : renderInitialNode}
nodeCanvasObjectMode={() => data ? "replace" : "after"}
nodeAutoColorBy={data ? undefined : "type"}
linkLabel="label"
linkCanvasObject={renderLink}
linkCanvasObjectMode={() => "after"}
linkDirectionalArrowLength={3.5}
linkDirectionalArrowRelPos={1}
linkLabel="label"
linkCanvasObject={renderLink}
linkCanvasObjectMode={() => "after"}
linkDirectionalArrowLength={3.5}
linkDirectionalArrowRelPos={1}
onNodeClick={handleNodeClick}
onBackgroundClick={handleBackgroundClick}
d3VelocityDecay={0.3}
/>
) : (
<ForceGraph
ref={graphRef}
width={dimensions.width}
height={dimensions.height}
dagMode={graphShape as unknown as undefined}
dagLevelDistance={100}
graphData={{
nodes: [{ id: 1, label: "Add" }, { id: 2, label: "Cognify" }, { id: 3, label: "Search" }],
links: [{ source: 1, target: 2, label: "but don't forget to" }, { source: 2, target: 3, label: "and after that you can" }],
}}
nodeLabel="label"
nodeRelSize={20}
nodeCanvasObject={renderInitialNode}
nodeCanvasObjectMode={() => "after"}
nodeAutoColorBy="type"
linkLabel="label"
linkCanvasObject={renderLink}
linkCanvasObjectMode={() => "after"}
linkDirectionalArrowLength={3.5}
linkDirectionalArrowRelPos={1}
/>
)}
onNodeClick={handleNodeClick}
onBackgroundClick={handleBackgroundClick}
d3VelocityDecay={data ? 0.3 : undefined}
/>
</div>
);
}

View file

@ -89,15 +89,6 @@ export default function useChat(dataset: Dataset) {
}
interface Node {
name: string;
}
interface Relationship {
relationship_name: string;
}
type InsightMessage = [Node, Relationship, Node];
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function convertToSearchTypeOutput(systemMessage: any[] | any, searchType: string): string {
@ -106,14 +97,6 @@ function convertToSearchTypeOutput(systemMessage: any[] | any, searchType: strin
}
switch (searchType) {
case "INSIGHTS":
return systemMessage.map((message: InsightMessage) => {
const [node1, relationship, node2] = message;
if (node1.name && node2.name) {
return `${node1.name} ${relationship.relationship_name} ${node2.name}.`;
}
return "";
}).join("\n");
case "SUMMARIES":
return systemMessage.map((message: { text: string }) => message.text).join("\n");
case "CHUNKS":

View file

@ -2,10 +2,11 @@
import Link from "next/link";
import Image from "next/image";
import { useBoolean } from "@/utils";
import { useEffect } from "react";
import { useBoolean, fetch } from "@/utils";
import { CloseIcon, CloudIcon, CogneeIcon } from "../Icons";
import { CTAButton, GhostButton, IconButton, Modal } from "../elements";
import { CTAButton, GhostButton, IconButton, Modal, StatusDot } from "../elements";
import syncData from "@/modules/cloud/syncData";
interface HeaderProps {
@ -23,6 +24,12 @@ export default function Header({ user }: HeaderProps) {
setFalse: closeSyncModal,
} = useBoolean(false);
const {
value: isMCPConnected,
setTrue: setMCPConnected,
setFalse: setMCPDisconnected,
} = useBoolean(false);
const handleDataSyncConfirm = () => {
syncData()
.finally(() => {
@ -30,6 +37,19 @@ export default function Header({ user }: HeaderProps) {
});
};
useEffect(() => {
const checkMCPConnection = () => {
fetch.checkMCPHealth()
.then(() => setMCPConnected())
.catch(() => setMCPDisconnected());
};
checkMCPConnection();
const interval = setInterval(checkMCPConnection, 30000);
return () => clearInterval(interval);
}, [setMCPConnected, setMCPDisconnected]);
return (
<>
<header className="relative flex flex-row h-14 min-h-14 px-5 items-center justify-between w-full max-w-[1920px] mx-auto">
@ -39,6 +59,10 @@ export default function Header({ user }: HeaderProps) {
</div>
<div className="flex flex-row items-center gap-2.5">
<Link href="/mcp-status" className="!text-indigo-600 pl-4 pr-4">
<StatusDot className="mr-2" isActive={isMCPConnected} />
{ isMCPConnected ? "MCP connected" : "MCP disconnected" }
</Link>
<GhostButton onClick={openSyncModal} className="text-indigo-600 gap-3 pl-4 pr-4">
<CloudIcon />
<div>Sync</div>

View file

@ -0,0 +1,13 @@
import React from "react";
const StatusDot = ({ isActive, className }: { isActive: boolean, className?: string }) => {
return (
<span
className={`inline-block w-3 h-3 rounded-full ${className} ${
isActive ? "bg-green-500" : "bg-red-500"
}`}
/>
);
};
export default StatusDot;

View file

@ -8,5 +8,6 @@ export { default as IconButton } from "./IconButton";
export { default as GhostButton } from "./GhostButton";
export { default as NeutralButton } from "./NeutralButton";
export { default as StatusIndicator } from "./StatusIndicator";
export { default as StatusDot } from "./StatusDot";
export { default as Accordion } from "./Accordion";
export { default as Notebook } from "./Notebook";

View file

@ -9,6 +9,8 @@ const backendApiUrl = process.env.NEXT_PUBLIC_BACKEND_API_URL || "http://localho
const cloudApiUrl = process.env.NEXT_PUBLIC_CLOUD_API_URL || "http://localhost:8001";
const mcpApiUrl = process.env.NEXT_PUBLIC_MCP_API_URL || "http://localhost:8001";
let apiKey: string | null = process.env.NEXT_PUBLIC_COGWIT_API_KEY || null;
let accessToken: string | null = null;
@ -49,6 +51,13 @@ export default async function fetch(url: string, options: RequestInit = {}, useC
)
.then((response) => handleServerErrors(response, retry, useCloud))
.catch((error) => {
// Handle network errors more gracefully
if (error.name === 'TypeError' && error.message.includes('fetch')) {
return Promise.reject(
new Error("Backend server is not responding. Please check if the server is running.")
);
}
if (error.detail === undefined) {
return Promise.reject(
new Error("No connection to the server.")
@ -62,8 +71,31 @@ export default async function fetch(url: string, options: RequestInit = {}, useC
});
}
fetch.checkHealth = () => {
return global.fetch(`${backendApiUrl.replace("/api", "")}/health`);
fetch.checkHealth = async () => {
const maxRetries = 5;
const retryDelay = 1000; // 1 second
for (let i = 0; i < maxRetries; i++) {
try {
const response = await global.fetch(`${backendApiUrl.replace("/api", "")}/health`);
if (response.ok) {
return response;
}
} catch (error) {
// If this is the last retry, throw the error
if (i === maxRetries - 1) {
throw error;
}
// Wait before retrying
await new Promise(resolve => setTimeout(resolve, retryDelay));
}
}
throw new Error("Backend server is not responding after multiple attempts");
};
fetch.checkMCPHealth = () => {
return global.fetch(`${mcpApiUrl.replace("/api", "")}/health`);
};
fetch.setApiKey = (newApiKey: string) => {

View file

@ -1,153 +0,0 @@
import sys
import asyncio
try:
import cognee
from PySide6.QtWidgets import (
QApplication,
QWidget,
QPushButton,
QLineEdit,
QFileDialog,
QVBoxLayout,
QHBoxLayout,
QLabel,
QMessageBox,
QTextEdit,
QProgressDialog,
)
from PySide6.QtCore import Qt
from qasync import QEventLoop # Import QEventLoop from qasync
except ImportError as e:
print(
"\nPlease install Cognee with optional gui dependencies or manually install missing dependencies.\n"
)
print("\nTo install with poetry use:")
print("\npoetry install -E gui\n")
print("\nOr to install with poetry and all dependencies use:")
print("\npoetry install --all-extras\n")
print("\nTo install with pip use: ")
print('\npip install ".[gui]"\n')
raise e
class FileSearchApp(QWidget):
def __init__(self):
super().__init__()
self.selected_file = None
self.init_ui()
def init_ui(self):
# Horizontal layout for file upload and visualization buttons
button_layout = QHBoxLayout()
# Button to open file dialog
self.file_button = QPushButton("Upload File to Cognee", parent=self)
self.file_button.clicked.connect(self.open_file_dialog)
button_layout.addWidget(self.file_button)
# Button to visualize data
self.visualize_button = QPushButton("Visualize Data", parent=self)
self.visualize_button.clicked.connect(lambda: asyncio.ensure_future(self.visualize_data()))
button_layout.addWidget(self.visualize_button)
# Label to display selected file path
self.file_label = QLabel("No file selected", parent=self)
# Line edit for search input
self.search_input = QLineEdit(parent=self)
self.search_input.setPlaceholderText("Enter text to search...")
# Button to perform search; schedule the async search on click
self.search_button = QPushButton("Cognee Search", parent=self)
self.search_button.clicked.connect(lambda: asyncio.ensure_future(self._cognee_search()))
# Text output area for search results
self.result_output = QTextEdit(parent=self)
self.result_output.setReadOnly(True)
self.result_output.setPlaceholderText("Search results will appear here...")
# Progress dialog
self.progress_dialog = QProgressDialog("Processing..", None, 0, 0, parent=self)
self.progress_dialog.setWindowModality(Qt.WindowModal)
self.progress_dialog.setCancelButton(None) # Remove the cancel button
self.progress_dialog.close()
# Layout setup
layout = QVBoxLayout()
layout.addLayout(button_layout)
layout.addWidget(self.file_label)
layout.addWidget(self.search_input)
layout.addWidget(self.search_button)
layout.addWidget(self.result_output)
self.setLayout(layout)
self.setWindowTitle("Cognee")
self.resize(500, 300)
def open_file_dialog(self):
file_path, _ = QFileDialog.getOpenFileName(
self, "Select a File", "", "All Files (*.*);;Text Files (*.txt)"
)
if file_path:
self.selected_file = file_path
self.file_label.setText(f"Selected: {file_path}")
asyncio.ensure_future(self.process_file_async())
async def process_file_async(self):
"""Asynchronously add and process the selected file."""
# Disable the entire window
self.progress_dialog.show()
self.setEnabled(False)
try:
await cognee.add(self.selected_file)
await cognee.cognify()
except Exception as e:
QMessageBox.critical(self, "Error", f"File processing failed: {str(e)}")
# Once finished, re-enable the window
self.setEnabled(True)
self.progress_dialog.close()
async def _cognee_search(self):
"""Performs an async search and updates the result output."""
# Disable the entire window
self.setEnabled(False)
self.progress_dialog.show()
try:
search_text = self.search_input.text().strip()
result = await cognee.search(query_text=search_text)
print(result)
# Assuming result is a list-like object; adjust if necessary
self.result_output.setText(result[0])
except Exception as e:
QMessageBox.critical(self, "Error", f"Search failed: {str(e)}")
# Once finished, re-enable the window
self.setEnabled(True)
self.progress_dialog.close()
async def visualize_data(self):
"""Async slot for handling visualize data button press."""
import webbrowser
from cognee.api.v1.visualize.visualize import visualize_graph
import os
import pathlib
html_file = os.path.join(pathlib.Path(__file__).parent, ".data", "graph_visualization.html")
await visualize_graph(html_file)
webbrowser.open(f"file://{html_file}")
if __name__ == "__main__":
app = QApplication(sys.argv)
# Create a qasync event loop and set it as the current event loop
loop = QEventLoop(app)
asyncio.set_event_loop(loop)
window = FileSearchApp()
window.show()
with loop:
loop.run_forever()

View file

@ -266,7 +266,7 @@ The MCP server exposes its functionality through tools. Call them from any MCP c
- **codify**: Analyse a code repository, build a code graph, and store it in memory
- **search**: Query memory supports GRAPH_COMPLETION, RAG_COMPLETION, CODE, CHUNKS, INSIGHTS
- **search**: Query memory supports GRAPH_COMPLETION, RAG_COMPLETION, CODE, CHUNKS
- **list_data**: List all datasets and their data items with IDs for deletion operations

View file

@ -48,27 +48,27 @@ if [ "$ENVIRONMENT" = "dev" ] || [ "$ENVIRONMENT" = "local" ]; then
if [ "$DEBUG" = "true" ]; then
echo "Waiting for the debugger to attach..."
if [ "$TRANSPORT_MODE" = "sse" ]; then
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
elif [ "$TRANSPORT_MODE" = "http" ]; then
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
else
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport stdio --no-migration
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport stdio --no-migration
fi
else
if [ "$TRANSPORT_MODE" = "sse" ]; then
exec cognee --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
elif [ "$TRANSPORT_MODE" = "http" ]; then
exec cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
else
exec cognee --transport stdio --no-migration
exec cognee-mcp --transport stdio --no-migration
fi
fi
else
if [ "$TRANSPORT_MODE" = "sse" ]; then
exec cognee --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
elif [ "$TRANSPORT_MODE" = "http" ]; then
exec cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
else
exec cognee --transport stdio --no-migration
exec cognee-mcp --transport stdio --no-migration
fi
fi

View file

@ -8,7 +8,8 @@ requires-python = ">=3.10"
dependencies = [
# For local cognee repo usage remove the comment below and add the absolute path to cognee. Then run `uv sync --reinstall` in the mcp folder on local cognee changes.
#"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j] @ file:/Users/igorilic/Desktop/cognee",
"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j]==0.3.2",
# TODO: Remove gemini from optional dependencies for new Cognee version after 0.3.4
"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j]==0.3.4",
"fastmcp>=2.10.0,<3.0.0",
"mcp>=1.12.0,<2.0.0",
"uv>=0.6.3,<1.0.0",
@ -36,4 +37,4 @@ dev = [
allow-direct-references = true
[project.scripts]
cognee = "src:main"
cognee-mcp = "src:main"

View file

@ -1,8 +1,57 @@
from .server import main as server_main
import warnings
import sys
def main():
"""Main entry point for the package."""
"""Deprecated main entry point for the package."""
import asyncio
deprecation_notice = """
DEPRECATION NOTICE
The CLI entry-point used to start the Cognee MCP service has been renamed from
"cognee" to "cognee-mcp". Calling the old entry-point will stop working in a
future release.
WHAT YOU NEED TO DO:
Locate every place where you launch the MCP process and replace the final
argument "cognee" with "cognee-mcp".
For the example mcpServers block from Cursor shown below the change is:
{
"mcpServers": {
"Cognee": {
"command": "uv",
"args": [
"--directory",
"/path/to/cognee-mcp",
"run",
"cognee" // <-- CHANGE THIS to "cognee-mcp"
]
}
}
}
Continuing to use the old "cognee" entry-point will result in failures once it
is removed, so please update your configuration and any shell scripts as soon
as possible.
"""
warnings.warn(
"The 'cognee' command for cognee-mcp is deprecated and will be removed in a future version. "
"Please use 'cognee-mcp' instead to avoid conflicts with the main cognee library.",
DeprecationWarning,
stacklevel=2,
)
print("⚠️ DEPRECATION WARNING", file=sys.stderr)
print(deprecation_notice, file=sys.stderr)
asyncio.run(server_main())
def main_mcp():
"""Clean main entry point for cognee-mcp command."""
import asyncio
asyncio.run(server_main())

View file

@ -117,5 +117,4 @@ async def add_rule_associations(data: str, rules_nodeset_name: str):
if len(edges_to_save) > 0:
await graph_engine.add_edges(edges_to_save)
await index_graph_edges()
await index_graph_edges(edges_to_save)

View file

@ -19,6 +19,10 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.modules.search.types import SearchType
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.storage.utils import JSONEncoder
from starlette.responses import JSONResponse
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
import uvicorn
try:
@ -38,6 +42,53 @@ mcp = FastMCP("Cognee")
logger = get_logger()
async def run_sse_with_cors():
"""Custom SSE transport with CORS middleware."""
sse_app = mcp.sse_app()
sse_app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000"],
allow_credentials=True,
allow_methods=["GET"],
allow_headers=["*"],
)
config = uvicorn.Config(
sse_app,
host=mcp.settings.host,
port=mcp.settings.port,
log_level=mcp.settings.log_level.lower(),
)
server = uvicorn.Server(config)
await server.serve()
async def run_http_with_cors():
"""Custom HTTP transport with CORS middleware."""
http_app = mcp.streamable_http_app()
http_app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000"],
allow_credentials=True,
allow_methods=["GET"],
allow_headers=["*"],
)
config = uvicorn.Config(
http_app,
host=mcp.settings.host,
port=mcp.settings.port,
log_level=mcp.settings.log_level.lower(),
)
server = uvicorn.Server(config)
await server.serve()
@mcp.custom_route("/health", methods=["GET"])
async def health_check(request):
return JSONResponse({"status": "ok"})
@mcp.tool()
async def cognee_add_developer_rules(
base_path: str = ".", graph_model_file: str = None, graph_model_name: str = None
@ -204,7 +255,7 @@ async def cognify(
# 2. Get entity relationships and connections
relationships = await cognee.search(
"connections between concepts",
query_type=SearchType.INSIGHTS
query_type=SearchType.GRAPH_COMPLETION
)
# 3. Find relevant document chunks
@ -427,11 +478,6 @@ async def search(search_query: str, search_type: str) -> list:
Best for: Direct document retrieval, specific fact-finding.
Returns: LLM responses based on relevant text chunks.
**INSIGHTS**:
Structured entity relationships and semantic connections.
Best for: Understanding concept relationships, knowledge mapping.
Returns: Formatted relationship data and entity connections.
**CHUNKS**:
Raw text segments that match the query semantically.
Best for: Finding specific passages, citations, exact content.
@ -473,7 +519,6 @@ async def search(search_query: str, search_type: str) -> list:
- "RAG_COMPLETION": Returns an LLM response based on the search query and standard RAG data
- "CODE": Returns code-related knowledge in JSON format
- "CHUNKS": Returns raw text chunks from the knowledge graph
- "INSIGHTS": Returns relationships between nodes in readable format
- "SUMMARIES": Returns pre-generated hierarchical summaries
- "CYPHER": Direct graph database queries
- "FEELING_LUCKY": Automatically selects best search type
@ -486,7 +531,6 @@ async def search(search_query: str, search_type: str) -> list:
A list containing a single TextContent object with the search results.
The format of the result depends on the search_type:
- **GRAPH_COMPLETION/RAG_COMPLETION**: Conversational AI response strings
- **INSIGHTS**: Formatted relationship descriptions and entity connections
- **CHUNKS**: Relevant text passages with source metadata
- **SUMMARIES**: Hierarchical summaries from general to specific
- **CODE**: Structured code information with context
@ -496,7 +540,6 @@ async def search(search_query: str, search_type: str) -> list:
Performance & Optimization:
- **GRAPH_COMPLETION**: Slower but most intelligent, uses LLM + graph context
- **RAG_COMPLETION**: Medium speed, uses LLM + document chunks (no graph traversal)
- **INSIGHTS**: Fast, returns structured relationships without LLM processing
- **CHUNKS**: Fastest, pure vector similarity search without LLM
- **SUMMARIES**: Fast, returns pre-computed summaries
- **CODE**: Medium speed, specialized for code understanding
@ -535,9 +578,6 @@ async def search(search_query: str, search_type: str) -> list:
return str(search_results[0])
elif search_type.upper() == "CHUNKS":
return str(search_results)
elif search_type.upper() == "INSIGHTS":
results = retrieved_edges_to_string(search_results)
return results
else:
return str(search_results)
@ -975,12 +1015,12 @@ async def main():
await mcp.run_stdio_async()
elif args.transport == "sse":
logger.info(f"Running MCP server with SSE transport on {args.host}:{args.port}")
await mcp.run_sse_async()
await run_sse_with_cors()
elif args.transport == "http":
logger.info(
f"Running MCP server with Streamable HTTP transport on {args.host}:{args.port}{args.path}"
)
await mcp.run_streamable_http_async()
await run_http_with_cors()
if __name__ == "__main__":

cognee-mcp/uv.lock (generated, 3529 lines changed; file diff suppressed because it is too large)

View file

@ -19,6 +19,7 @@ from .api.v1.add import add
from .api.v1.delete import delete
from .api.v1.cognify import cognify
from .modules.memify import memify
from .api.v1.update import update
from .api.v1.config.config import config
from .api.v1.datasets.datasets import datasets
from .api.v1.prune import prune

View file

@ -28,6 +28,7 @@ from cognee.api.v1.add.routers import get_add_router
from cognee.api.v1.delete.routers import get_delete_router
from cognee.api.v1.responses.routers import get_responses_router
from cognee.api.v1.sync.routers import get_sync_router
from cognee.api.v1.update.routers import get_update_router
from cognee.api.v1.users.routers import (
get_auth_router,
get_register_router,
@ -263,6 +264,8 @@ app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["vi
app.include_router(get_delete_router(), prefix="/api/v1/delete", tags=["delete"])
app.include_router(get_update_router(), prefix="/api/v1/update", tags=["update"])
app.include_router(get_responses_router(), prefix="/api/v1/responses", tags=["responses"])
app.include_router(get_sync_router(), prefix="/api/v1/sync", tags=["sync"])

View file

@ -189,12 +189,12 @@ class HealthChecker:
start_time = time.time()
try:
from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm import LLMGateway
config = get_llm_config()
# Test actual API connection with minimal request
LLMGateway.show_prompt("test", "test")
from cognee.infrastructure.llm.utils import test_llm_connection
await test_llm_connection()
response_time = int((time.time() - start_time) * 1000)
return ComponentHealth(
@ -217,13 +217,9 @@ class HealthChecker:
"""Check embedding service health (non-critical)."""
start_time = time.time()
try:
from cognee.infrastructure.databases.vector.embeddings.get_embedding_engine import (
get_embedding_engine,
)
from cognee.infrastructure.llm.utils import test_embedding_connection
# Test actual embedding generation with minimal text
engine = get_embedding_engine()
await engine.embed_text(["test"])
await test_embedding_connection()
response_time = int((time.time() - start_time) * 1000)
return ComponentHealth(
@ -245,16 +241,6 @@ class HealthChecker:
"""Get comprehensive health status."""
components = {}
# Critical services
critical_components = [
"relational_db",
"vector_db",
"graph_db",
"file_storage",
"llm_provider",
"embedding_service",
]
critical_checks = [
("relational_db", self.check_relational_db()),
("vector_db", self.check_vector_db()),
@ -300,11 +286,11 @@ class HealthChecker:
else:
components[name] = result
critical_comps = [check[0] for check in critical_checks]
# Determine overall status
critical_unhealthy = any(
comp.status == HealthStatus.UNHEALTHY
comp.status == HealthStatus.UNHEALTHY and name in critical_comps
for name, comp in components.items()
if name in critical_components
)
has_degraded = any(comp.status == HealthStatus.DEGRADED for comp in components.values())
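As a side note, the connection utilities the health checker now imports can be reused for a standalone smoke check. This is a minimal sketch that assumes they raise on failure, as the surrounding try/except blocks suggest:

```python
# Minimal sketch; uses only the utilities imported by the health checker above,
# and assumes they raise an exception when the service is unreachable.
import asyncio

from cognee.infrastructure.llm.utils import test_llm_connection, test_embedding_connection

async def smoke_check() -> None:
    await test_llm_connection()        # verifies the configured LLM endpoint
    await test_embedding_connection()  # verifies the embedding service

asyncio.run(smoke_check())
```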

View file

@ -1,6 +1,8 @@
from uuid import UUID
from typing import Union, BinaryIO, List, Optional
import os
from typing import Union, BinaryIO, List, Optional, Dict, Any
from pydantic import BaseModel
from urllib.parse import urlparse
from cognee.modules.users.models import User
from cognee.modules.pipelines import Task, run_pipeline
from cognee.modules.pipelines.layers.resolve_authorized_user_dataset import (
@ -11,6 +13,19 @@ from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
)
from cognee.modules.engine.operations.setup import setup
from cognee.tasks.ingestion import ingest_data, resolve_data_directories
from cognee.shared.logging_utils import get_logger
logger = get_logger()
try:
from cognee.tasks.web_scraper.config import TavilyConfig, SoupCrawlerConfig
from cognee.context_global_variables import (
tavily_config as tavily,
soup_crawler_config as soup_crawler,
)
except ImportError:
logger.debug(f"Unable to import {str(ImportError)}")
pass
async def add(
@ -23,12 +38,15 @@ async def add(
dataset_id: Optional[UUID] = None,
preferred_loaders: List[str] = None,
incremental_loading: bool = True,
extraction_rules: Optional[Dict[str, Any]] = None,
tavily_config: Optional[BaseModel] = None,
soup_crawler_config: Optional[BaseModel] = None,
):
"""
Add data to Cognee for knowledge graph processing.
This is the first step in the Cognee workflow - it ingests raw data and prepares it
for processing. The function accepts various data formats including text, files, and
for processing. The function accepts various data formats including text, files, URLs, and
binary streams, then stores them in a specified dataset for further processing.
Prerequisites:
@ -68,6 +86,7 @@ async def add(
- S3 path: "s3://my-bucket/documents/file.pdf"
- List of mixed types: ["text content", "/path/file.pdf", "file://doc.txt", file_handle]
- Binary file object: open("file.txt", "rb")
- url: A web URL (http or https)
dataset_name: Name of the dataset to store data in. Defaults to "main_dataset".
Create separate datasets to organize different knowledge domains.
user: User object for authentication and permissions. Uses default user if None.
@ -78,6 +97,9 @@ async def add(
vector_db_config: Optional configuration for vector database (for custom setups).
graph_db_config: Optional configuration for graph database (for custom setups).
dataset_id: Optional specific dataset UUID to use instead of dataset_name.
extraction_rules: Optional dictionary of rules (e.g., CSS selectors, XPath) for extracting specific content from web pages using BeautifulSoup.
tavily_config: Optional configuration for the Tavily API, including the API key and extraction settings.
soup_crawler_config: Optional configuration for the BeautifulSoup crawler, specifying concurrency, crawl delay, and extraction rules.
Returns:
PipelineRunInfo: Information about the ingestion pipeline execution including:
@ -126,6 +148,21 @@ async def add(
# Add a single file
await cognee.add("/home/user/documents/analysis.pdf")
# Add a single URL using the BeautifulSoup (bs4) extraction method
extraction_rules = {
"title": "h1",
"description": "p",
"more_info": "a[href*='more-info']"
}
await cognee.add("https://example.com",extraction_rules=extraction_rules)
# Add a single url and tavily extract ingestion method
Make sure to set TAVILY_API_KEY = YOUR_TAVILY_API_KEY as a environment variable
await cognee.add("https://example.com")
# Add multiple urls
await cognee.add(["https://example.com","https://books.toscrape.com"])
```
Environment Variables:
@ -133,22 +170,55 @@ async def add(
- LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)
Optional:
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama"
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral"
- LLM_MODEL: Model name (default: "gpt-5-mini")
- DEFAULT_USER_EMAIL: Custom default user email
- DEFAULT_USER_PASSWORD: Custom default user password
- VECTOR_DB_PROVIDER: "lancedb" (default), "chromadb", "pgvector"
- GRAPH_DATABASE_PROVIDER: "kuzu" (default), "neo4j"
- TAVILY_API_KEY: API key for Tavily; enables the Tavily extraction method for URLs
"""
try:
if not soup_crawler_config and extraction_rules:
soup_crawler_config = SoupCrawlerConfig(extraction_rules=extraction_rules)
if not tavily_config and os.getenv("TAVILY_API_KEY"):
tavily_config = TavilyConfig(api_key=os.getenv("TAVILY_API_KEY"))
soup_crawler.set(soup_crawler_config)
tavily.set(tavily_config)
http_schemes = {"http", "https"}
def _is_http_url(item: Union[str, BinaryIO]) -> bool:
return isinstance(item, str) and urlparse(item).scheme in http_schemes
if _is_http_url(data):
node_set = ["web_content"] if not node_set else node_set + ["web_content"]
elif isinstance(data, list) and any(_is_http_url(item) for item in data):
node_set = ["web_content"] if not node_set else node_set + ["web_content"]
except NameError as name_error:
    logger.debug(f"Web scraper configuration is not available: {name_error}")
tasks = [
Task(resolve_data_directories, include_subdirectories=True),
Task(ingest_data, dataset_name, user, node_set, dataset_id, preferred_loaders),
Task(
ingest_data,
dataset_name,
user,
node_set,
dataset_id,
preferred_loaders,
),
]
await setup()
user, authorized_dataset = await resolve_authorized_user_dataset(dataset_id, dataset_name, user)
user, authorized_dataset = await resolve_authorized_user_dataset(
dataset_name=dataset_name, dataset_id=dataset_id, user=user
)
await reset_dataset_pipeline_run_status(
authorized_dataset.id, user, pipeline_names=["add_pipeline", "cognify_pipeline"]

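A minimal sketch of the URL ingestion path documented above, assuming the optional web scraper dependencies that provide SoupCrawlerConfig are installed:

import asyncio

import cognee
from cognee.tasks.web_scraper.config import SoupCrawlerConfig  # optional extra, assumed installed


async def ingest_page():
    # CSS selectors describing which parts of the page to keep
    extraction_rules = {"title": "h1", "description": "p"}

    # Pass the rules directly and let add() build the crawler config ...
    await cognee.add("https://example.com", extraction_rules=extraction_rules)

    # ... or construct the crawler config explicitly with the same rules
    await cognee.add(
        "https://example.com",
        soup_crawler_config=SoupCrawlerConfig(extraction_rules=extraction_rules),
    )


asyncio.run(ingest_page())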
View file

@ -73,7 +73,11 @@ def get_add_router() -> APIRouter:
try:
add_run = await cognee_add(
data, datasetName, user=user, dataset_id=datasetId, node_set=node_set
data,
datasetName,
user=user,
dataset_id=datasetId,
node_set=node_set if node_set else None,
)
if isinstance(add_run, PipelineRunErrored):

View file

@ -148,7 +148,7 @@ async def cognify(
# 2. Get entity relationships and connections
relationships = await cognee.search(
"connections between concepts",
query_type=SearchType.INSIGHTS
query_type=SearchType.GRAPH_COMPLETION
)
# 3. Find relevant document chunks

View file

@ -94,9 +94,11 @@ def get_permissions_router() -> APIRouter:
from cognee.modules.users.roles.methods import create_role as create_role_method
await create_role_method(role_name=role_name, owner_id=user.id)
role_id = await create_role_method(role_name=role_name, owner_id=user.id)
return JSONResponse(status_code=200, content={"message": "Role created for tenant"})
return JSONResponse(
status_code=200, content={"message": "Role created for tenant", "role_id": str(role_id)}
)
@permissions_router.post("/users/{user_id}/roles")
async def add_user_to_role(
@ -212,8 +214,10 @@ def get_permissions_router() -> APIRouter:
from cognee.modules.users.tenants.methods import create_tenant as create_tenant_method
await create_tenant_method(tenant_name=tenant_name, user_id=user.id)
tenant_id = await create_tenant_method(tenant_name=tenant_name, user_id=user.id)
return JSONResponse(status_code=200, content={"message": "Tenant created."})
return JSONResponse(
status_code=200, content={"message": "Tenant created.", "tenant_id": str(tenant_id)}
)
return permissions_router

View file

@ -14,7 +14,6 @@ DEFAULT_TOOLS = [
"type": "string",
"description": "Type of search to perform",
"enum": [
"INSIGHTS",
"CODE",
"GRAPH_COMPLETION",
"NATURAL_LANGUAGE",

View file

@ -59,7 +59,7 @@ async def handle_search(arguments: Dict[str, Any], user) -> list:
valid_search_types = (
search_tool["parameters"]["properties"]["search_type"]["enum"]
if search_tool
else ["INSIGHTS", "CODE", "GRAPH_COMPLETION", "NATURAL_LANGUAGE"]
else ["CODE", "GRAPH_COMPLETION", "NATURAL_LANGUAGE"]
)
if search_type_str not in valid_search_types:

View file

@ -14,7 +14,6 @@ DEFAULT_TOOLS = [
"type": "string",
"description": "Type of search to perform",
"enum": [
"INSIGHTS",
"CODE",
"GRAPH_COMPLETION",
"NATURAL_LANGUAGE",

View file

@ -52,11 +52,6 @@ async def search(
Best for: Direct document retrieval, specific fact-finding.
Returns: LLM responses based on relevant text chunks.
**INSIGHTS**:
Structured entity relationships and semantic connections.
Best for: Understanding concept relationships, knowledge mapping.
Returns: Formatted relationship data and entity connections.
**CHUNKS**:
Raw text segments that match the query semantically.
Best for: Finding specific passages, citations, exact content.
@ -124,9 +119,6 @@ async def search(
**GRAPH_COMPLETION/RAG_COMPLETION**:
[List of conversational AI response strings]
**INSIGHTS**:
[List of formatted relationship descriptions and entity connections]
**CHUNKS**:
[List of relevant text passages with source metadata]
@ -146,7 +138,6 @@ async def search(
Performance & Optimization:
- **GRAPH_COMPLETION**: Slower but most intelligent, uses LLM + graph context
- **RAG_COMPLETION**: Medium speed, uses LLM + document chunks (no graph traversal)
- **INSIGHTS**: Fast, returns structured relationships without LLM processing
- **CHUNKS**: Fastest, pure vector similarity search without LLM
- **SUMMARIES**: Fast, returns pre-computed summaries
- **CODE**: Medium speed, specialized for code understanding

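With INSIGHTS removed, relationship-style questions go through GRAPH_COMPLETION instead; a short sketch, assuming SearchType is exported from the top-level cognee package as used in the cognify example above:

import asyncio

import cognee
from cognee import SearchType  # assumed export location for SearchType


async def find_connections():
    results = await cognee.search(
        "connections between concepts",
        query_type=SearchType.GRAPH_COMPLETION,
    )
    for result in results:
        print(result)


asyncio.run(find_connections())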
View file

@ -21,7 +21,13 @@ class SettingsDTO(OutDTO):
class LLMConfigInputDTO(InDTO):
provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"], Literal["gemini"]]
provider: Union[
Literal["openai"],
Literal["ollama"],
Literal["anthropic"],
Literal["gemini"],
Literal["mistral"],
]
model: str
api_key: str

View file

@ -1 +1 @@
from .ui import start_ui, stop_ui, ui
from .ui import start_ui

View file

@ -1,5 +1,7 @@
import os
import platform
import signal
import socket
import subprocess
import threading
import time
@ -7,7 +9,7 @@ import webbrowser
import zipfile
import requests
from pathlib import Path
from typing import Callable, Optional, Tuple
from typing import Callable, Optional, Tuple, List
import tempfile
import shutil
@ -17,6 +19,80 @@ from cognee.version import get_cognee_version
logger = get_logger()
def _stream_process_output(
process: subprocess.Popen, stream_name: str, prefix: str, color_code: str = ""
) -> threading.Thread:
"""
Stream output from a process with a prefix to identify the source.
Args:
process: The subprocess to monitor
stream_name: 'stdout' or 'stderr'
prefix: Text prefix for each line (e.g., '[BACKEND]', '[FRONTEND]')
color_code: ANSI color code for the prefix (optional)
Returns:
Thread that handles the streaming
"""
def stream_reader():
stream = getattr(process, stream_name)
if stream is None:
return
reset_code = "\033[0m" if color_code else ""
try:
for line in iter(stream.readline, b""):
if line:
line_text = line.decode("utf-8").rstrip()
if line_text:
print(f"{color_code}{prefix}{reset_code} {line_text}", flush=True)
except Exception:
pass
finally:
if stream:
stream.close()
thread = threading.Thread(target=stream_reader, daemon=True)
thread.start()
return thread
def _is_port_available(port: int) -> bool:
"""
Check if a port is available on localhost.
Returns True if the port is available, False otherwise.
"""
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(1) # 1 second timeout
result = sock.connect_ex(("localhost", port))
return result != 0 # Port is available if connection fails
except Exception:
return False
def _check_required_ports(ports_to_check: List[Tuple[int, str]]) -> Tuple[bool, List[str]]:
"""
Check if all required ports are available on localhost.
Args:
ports_to_check: List of (port, service_name) tuples
Returns:
Tuple of (all_available: bool, unavailable_services: List[str])
"""
unavailable = []
for port, service_name in ports_to_check:
if not _is_port_available(port):
unavailable.append(f"{service_name} (port {port})")
logger.error(f"Port {port} is already in use for {service_name}")
return len(unavailable) == 0, unavailable
def normalize_version_for_comparison(version: str) -> str:
"""
Normalize version string for comparison.
@ -214,6 +290,7 @@ def check_node_npm() -> tuple[bool, str]:
Check if Node.js and npm are available.
Returns (is_available, error_message)
"""
try:
# Check Node.js
result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=10)
@ -223,8 +300,17 @@ def check_node_npm() -> tuple[bool, str]:
node_version = result.stdout.strip()
logger.debug(f"Found Node.js version: {node_version}")
# Check npm
result = subprocess.run(["npm", "--version"], capture_output=True, text=True, timeout=10)
# Check npm - handle Windows PowerShell scripts
if platform.system() == "Windows":
# On Windows, npm might be a PowerShell script, so we need to use shell=True
result = subprocess.run(
["npm", "--version"], capture_output=True, text=True, timeout=10, shell=True
)
else:
result = subprocess.run(
["npm", "--version"], capture_output=True, text=True, timeout=10
)
if result.returncode != 0:
return False, "npm is not installed or not in PATH"
@ -246,6 +332,7 @@ def install_frontend_dependencies(frontend_path: Path) -> bool:
Install frontend dependencies if node_modules doesn't exist.
This is needed for both development and downloaded frontends since both use npm run dev.
"""
node_modules = frontend_path / "node_modules"
if node_modules.exists():
logger.debug("Frontend dependencies already installed")
@ -254,13 +341,24 @@ def install_frontend_dependencies(frontend_path: Path) -> bool:
logger.info("Installing frontend dependencies (this may take a few minutes)...")
try:
result = subprocess.run(
["npm", "install"],
cwd=frontend_path,
capture_output=True,
text=True,
timeout=300, # 5 minutes timeout
)
# Use shell=True on Windows for npm commands
if platform.system() == "Windows":
result = subprocess.run(
["npm", "install"],
cwd=frontend_path,
capture_output=True,
text=True,
timeout=300, # 5 minutes timeout
shell=True,
)
else:
result = subprocess.run(
["npm", "install"],
cwd=frontend_path,
capture_output=True,
text=True,
timeout=300, # 5 minutes timeout
)
if result.returncode == 0:
logger.info("Frontend dependencies installed successfully")
@ -327,55 +425,113 @@ def prompt_user_for_download() -> bool:
def start_ui(
pid_callback: Callable[[int], None],
host: str = "localhost",
port: int = 3000,
open_browser: bool = True,
auto_download: bool = False,
start_backend: bool = False,
backend_host: str = "localhost",
backend_port: int = 8000,
start_mcp: bool = False,
mcp_port: int = 8001,
) -> Optional[subprocess.Popen]:
"""
Start the cognee frontend UI server, optionally with the backend API server.
Start the cognee frontend UI server, optionally with the backend API server and MCP server.
This function will:
1. Optionally start the cognee backend API server
2. Find the cognee-frontend directory (development) or download it (pip install)
3. Check if Node.js and npm are available (for development mode)
4. Install dependencies if needed (development mode)
5. Start the frontend server
6. Optionally open the browser
2. Optionally start the cognee MCP server
3. Find the cognee-frontend directory (development) or download it (pip install)
4. Check if Node.js and npm are available (for development mode)
5. Install dependencies if needed (development mode)
6. Start the frontend server
7. Optionally open the browser
Args:
pid_callback: Callback to notify with PID of each spawned process
host: Host to bind the frontend server to (default: localhost)
port: Port to run the frontend server on (default: 3000)
open_browser: Whether to open the browser automatically (default: True)
auto_download: If True, download frontend without prompting (default: False)
start_backend: If True, also start the cognee API backend server (default: False)
backend_host: Host to bind the backend server to (default: localhost)
backend_port: Port to run the backend server on (default: 8000)
start_mcp: If True, also start the cognee MCP server (default: False)
mcp_port: Port to run the MCP server on (default: 8001)
Returns:
subprocess.Popen object representing the running frontend server, or None if failed
Note: If backend is started, it runs in a separate process that will be cleaned up
when the frontend process is terminated.
Note: If backend and/or MCP server are started, they run in separate processes
that will be cleaned up when the frontend process is terminated.
Example:
>>> import cognee
>>> def dummy_callback(pid): pass
>>> # Start just the frontend
>>> server = cognee.start_ui()
>>> server = cognee.start_ui(dummy_callback)
>>>
>>> # Start both frontend and backend
>>> server = cognee.start_ui(start_backend=True)
>>> server = cognee.start_ui(dummy_callback, start_backend=True)
>>> # UI will be available at http://localhost:3000
>>> # API will be available at http://localhost:8000
>>> # To stop both servers later:
>>>
>>> # Start frontend with MCP server
>>> server = cognee.start_ui(dummy_callback, start_mcp=True)
>>> # UI will be available at http://localhost:3000
>>> # MCP server will be available at http://127.0.0.1:8001/sse
>>> # To stop all servers later:
>>> server.terminate()
"""
logger.info("Starting cognee UI...")
ports_to_check = [(port, "Frontend UI")]
if start_backend:
ports_to_check.append((backend_port, "Backend API"))
if start_mcp:
ports_to_check.append((mcp_port, "MCP Server"))
logger.info("Checking port availability...")
all_ports_available, unavailable_services = _check_required_ports(ports_to_check)
if not all_ports_available:
error_msg = f"Cannot start cognee UI: The following services have ports already in use: {', '.join(unavailable_services)}"
logger.error(error_msg)
logger.error("Please stop the conflicting services or change the port configuration.")
return None
logger.info("✓ All required ports are available")
backend_process = None
if start_mcp:
logger.info("Starting Cognee MCP server with Docker...")
cwd = os.getcwd()
env_file = os.path.join(cwd, ".env")
try:
image = "cognee/cognee-mcp:main"
subprocess.run(["docker", "pull", image], check=True)
mcp_process = subprocess.Popen(
[
"docker",
"run",
"-p",
f"{mcp_port}:8000",
"--rm",
"--env-file",
env_file,
"-e",
"TRANSPORT_MODE=sse",
"cognee/cognee-mcp:main",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
)
_stream_process_output(mcp_process, "stdout", "[MCP]", "\033[34m") # Blue
_stream_process_output(mcp_process, "stderr", "[MCP]", "\033[34m") # Blue
pid_callback(mcp_process.pid)
logger.info(f"✓ Cognee MCP server starting on http://127.0.0.1:{mcp_port}/sse")
except Exception as e:
logger.error(f"Failed to start MCP server with Docker: {str(e)}")
# Start backend server if requested
if start_backend:
logger.info("Starting cognee backend API server...")
@ -389,16 +545,19 @@ def start_ui(
"uvicorn",
"cognee.api.client:app",
"--host",
backend_host,
"localhost",
"--port",
str(backend_port),
],
# Inherit stdout/stderr from parent process to show logs
stdout=None,
stderr=None,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
)
# Start threads to stream backend output with prefix
_stream_process_output(backend_process, "stdout", "[BACKEND]", "\033[32m") # Green
_stream_process_output(backend_process, "stderr", "[BACKEND]", "\033[32m") # Green
pid_callback(backend_process.pid)
# Give the backend a moment to start
@ -408,7 +567,7 @@ def start_ui(
logger.error("Backend server failed to start - process exited early")
return None
logger.info(f"✓ Backend API started at http://{backend_host}:{backend_port}")
logger.info(f"✓ Backend API started at http://localhost:{backend_port}")
except Exception as e:
logger.error(f"Failed to start backend server: {str(e)}")
@ -453,24 +612,38 @@ def start_ui(
# Prepare environment variables
env = os.environ.copy()
env["HOST"] = host
env["HOST"] = "localhost"
env["PORT"] = str(port)
# Start the development server
logger.info(f"Starting frontend server at http://{host}:{port}")
logger.info(f"Starting frontend server at http://localhost:{port}")
logger.info("This may take a moment to compile and start...")
try:
# Create frontend in its own process group for clean termination
process = subprocess.Popen(
["npm", "run", "dev"],
cwd=frontend_path,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
)
# Use shell=True on Windows for npm commands
if platform.system() == "Windows":
process = subprocess.Popen(
["npm", "run", "dev"],
cwd=frontend_path,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True,
)
else:
process = subprocess.Popen(
["npm", "run", "dev"],
cwd=frontend_path,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
)
# Start threads to stream frontend output with prefix
_stream_process_output(process, "stdout", "[FRONTEND]", "\033[33m") # Yellow
_stream_process_output(process, "stderr", "[FRONTEND]", "\033[33m") # Yellow
pid_callback(process.pid)
@ -479,10 +652,7 @@ def start_ui(
# Check if process is still running
if process.poll() is not None:
stdout, stderr = process.communicate()
logger.error("Frontend server failed to start:")
logger.error(f"stdout: {stdout}")
logger.error(f"stderr: {stderr}")
logger.error("Frontend server failed to start - check the logs above for details")
return None
# Open browser if requested
@ -491,7 +661,7 @@ def start_ui(
def open_browser_delayed():
time.sleep(5) # Give Next.js time to fully start
try:
webbrowser.open(f"http://{host}:{port}") # TODO: use dashboard url?
webbrowser.open(f"http://localhost:{port}")
except Exception as e:
logger.warning(f"Could not open browser automatically: {e}")
@ -499,13 +669,9 @@ def start_ui(
browser_thread.start()
logger.info("✓ Cognee UI is starting up...")
logger.info(f"✓ Open your browser to: http://{host}:{port}")
logger.info(f"✓ Open your browser to: http://localhost:{port}")
logger.info("✓ The UI will be available once Next.js finishes compiling")
# Store backend process reference in the frontend process for cleanup
if backend_process:
process._cognee_backend_process = backend_process
return process
except Exception as e:
@ -523,102 +689,3 @@ def start_ui(
except (OSError, ProcessLookupError):
pass
return None
def stop_ui(process: subprocess.Popen) -> bool:
"""
Stop a running UI server process and backend process (if started), along with all their children.
Args:
process: The subprocess.Popen object returned by start_ui()
Returns:
bool: True if stopped successfully, False otherwise
"""
if not process:
return False
success = True
try:
# First, stop the backend process if it exists
backend_process = getattr(process, "_cognee_backend_process", None)
if backend_process:
logger.info("Stopping backend server...")
try:
backend_process.terminate()
try:
backend_process.wait(timeout=5)
logger.info("Backend server stopped gracefully")
except subprocess.TimeoutExpired:
logger.warning("Backend didn't terminate gracefully, forcing kill")
backend_process.kill()
backend_process.wait()
logger.info("Backend server stopped")
except Exception as e:
logger.error(f"Error stopping backend server: {str(e)}")
success = False
# Now stop the frontend process
logger.info("Stopping frontend server...")
# Try to terminate the process group (includes child processes like Next.js)
if hasattr(os, "killpg"):
try:
# Kill the entire process group
os.killpg(os.getpgid(process.pid), signal.SIGTERM)
logger.debug("Sent SIGTERM to process group")
except (OSError, ProcessLookupError):
# Fall back to terminating just the main process
process.terminate()
logger.debug("Terminated main process only")
else:
process.terminate()
logger.debug("Terminated main process (Windows)")
try:
process.wait(timeout=10)
logger.info("Frontend server stopped gracefully")
except subprocess.TimeoutExpired:
logger.warning("Frontend didn't terminate gracefully, forcing kill")
# Force kill the process group
if hasattr(os, "killpg"):
try:
os.killpg(os.getpgid(process.pid), signal.SIGKILL)
logger.debug("Sent SIGKILL to process group")
except (OSError, ProcessLookupError):
process.kill()
logger.debug("Force killed main process only")
else:
process.kill()
logger.debug("Force killed main process (Windows)")
process.wait()
if success:
logger.info("UI servers stopped successfully")
return success
except Exception as e:
logger.error(f"Error stopping UI servers: {str(e)}")
return False
# Convenience function similar to DuckDB's approach
def ui() -> Optional[subprocess.Popen]:
"""
Convenient alias for start_ui() with default parameters.
Similar to how DuckDB provides simple ui() function.
"""
return start_ui()
if __name__ == "__main__":
# Test the UI startup
server = start_ui()
if server:
try:
input("Press Enter to stop the server...")
finally:
stop_ui(server)

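A standalone sketch mirroring the port pre-flight check introduced above; the helper name and service list here are illustrative, not part of the module:

import socket


def is_port_free(port: int, host: str = "localhost") -> bool:
    # Mirrors _is_port_available above: a failed connect means nothing is listening yet.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(1)
        return sock.connect_ex((host, port)) != 0


for port, service in [(3000, "Frontend UI"), (8000, "Backend API"), (8001, "MCP Server")]:
    state = "free" if is_port_free(port) else "already in use"
    print(f"{service} (port {port}): {state}")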
View file

@ -0,0 +1 @@
from .update import update

View file

@ -0,0 +1 @@
from .get_update_router import get_update_router

View file

@ -0,0 +1,90 @@
from fastapi.responses import JSONResponse
from fastapi import File, UploadFile, Depends, Form
from typing import Optional
from fastapi import APIRouter
from fastapi.encoders import jsonable_encoder
from typing import List
from uuid import UUID
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee.modules.pipelines.models.PipelineRunInfo import (
PipelineRunErrored,
)
logger = get_logger()
def get_update_router() -> APIRouter:
router = APIRouter()
@router.patch("", response_model=None)
async def update(
data_id: UUID,
dataset_id: UUID,
data: List[UploadFile] = File(default=None),
node_set: Optional[List[str]] = Form(default=[""], example=[""]),
user: User = Depends(get_authenticated_user),
):
"""
Update data in a dataset.
This endpoint updates an existing document in a specified dataset: provide the data_id of the document to replace and the new version of the document as the data.
The document is updated, analyzed, and the changes are integrated into the knowledge graph.
## Request Parameters
- **data_id** (UUID): UUID of the document to update in Cognee memory
- **data** (List[UploadFile]): List of files to upload.
- **dataset_id** (UUID): UUID of the already existing dataset that contains the document
- **node_set** (Optional[List[str]]): List of node identifiers for graph organization and access control.
Used for grouping related data points in the knowledge graph.
## Response
Returns information about the update operation containing:
- Status of the operation
- Details about the processed data
- Any relevant metadata from the ingestion process
## Error Codes
- **400 Bad Request**: Neither datasetId nor datasetName provided
- **409 Conflict**: Error during update operation
- **403 Forbidden**: User doesn't have permission to update the dataset
## Notes
- To update data in datasets not owned by the user, use dataset_id (when ENABLE_BACKEND_ACCESS_CONTROL is set to True)
- dataset_id must be the UUID of an already existing dataset
"""
send_telemetry(
"Update API Endpoint Invoked",
user.id,
additional_properties={
"endpoint": "PATCH /v1/update",
"dataset_id": str(dataset_id),
"data_id": str(data_id),
"node_set": str(node_set),
},
)
from cognee.api.v1.update import update as cognee_update
try:
update_run = await cognee_update(
data_id=data_id,
data=data,
dataset_id=dataset_id,
user=user,
node_set=node_set if node_set else None,
)
# If any cognify run errored return JSONResponse with proper error status code
if any(isinstance(v, PipelineRunErrored) for v in update_run.values()):
return JSONResponse(status_code=420, content=jsonable_encoder(update_run))
return update_run
except Exception as error:
logger.error(f"Error during deletion by data_id: {str(error)}")
return JSONResponse(status_code=409, content={"error": str(error)})
return router

View file

@ -0,0 +1,100 @@
from uuid import UUID
from typing import Union, BinaryIO, List, Optional
from cognee.modules.users.models import User
from cognee.api.v1.delete import delete
from cognee.api.v1.add import add
from cognee.api.v1.cognify import cognify
async def update(
data_id: UUID,
data: Union[BinaryIO, list[BinaryIO], str, list[str]],
dataset_id: UUID,
user: User = None,
node_set: Optional[List[str]] = None,
vector_db_config: dict = None,
graph_db_config: dict = None,
preferred_loaders: List[str] = None,
incremental_loading: bool = True,
):
"""
Update existing data in Cognee.
Supported Input Types:
- **Text strings**: Direct text content (str) - any string not starting with "/" or "file://"
- **File paths**: Local file paths as strings in these formats:
* Absolute paths: "/path/to/document.pdf"
* File URLs: "file:///path/to/document.pdf" or "file://relative/path.txt"
* S3 paths: "s3://bucket-name/path/to/file.pdf"
- **Binary file objects**: File handles/streams (BinaryIO)
- **Lists**: Multiple files or text strings in a single call
Supported File Formats:
- Text files (.txt, .md, .csv)
- PDFs (.pdf)
- Images (.png, .jpg, .jpeg) - extracted via OCR/vision models
- Audio files (.mp3, .wav) - transcribed to text
- Code files (.py, .js, .ts, etc.) - parsed for structure and content
- Office documents (.docx, .pptx)
Workflow:
1. **Delete**: Removes the existing document identified by data_id from the dataset
2. **Ingestion**: Resolves and extracts the new content, stores it in the same dataset, and records metadata and permissions
3. **Cognify**: Re-processes the dataset so the changes are integrated into the knowledge graph
Args:
data_id: UUID of existing data to update
data: The latest version of the data. Can be:
- Single text string: "Your text content here"
- Absolute file path: "/path/to/document.pdf"
- File URL: "file:///absolute/path/to/document.pdf" or "file://relative/path.txt"
- S3 path: "s3://my-bucket/documents/file.pdf"
- List of mixed types: ["text content", "/path/file.pdf", "file://doc.txt", file_handle]
- Binary file object: open("file.txt", "rb")
dataset_id: UUID of the dataset that contains the document being updated.
The new version of the data is stored back into this dataset.
user: User object for authentication and permissions. Uses default user if None.
Default user: "default_user@example.com" (created automatically on first use).
Users can only access datasets they have permissions for.
node_set: Optional list of node identifiers for graph organization and access control.
Used for grouping related data points in the knowledge graph.
vector_db_config: Optional configuration for vector database (for custom setups).
graph_db_config: Optional configuration for graph database (for custom setups).
preferred_loaders: Optional list of loader names to prefer when ingesting the new data.
Returns:
PipelineRunInfo: Information about the cognify pipeline run triggered after the update, including:
- Pipeline run ID for tracking
- Dataset ID where data was stored
- Processing status and any errors
- Execution timestamps and metadata
"""
await delete(
data_id=data_id,
dataset_id=dataset_id,
user=user,
)
await add(
data=data,
dataset_id=dataset_id,
user=user,
node_set=node_set,
vector_db_config=vector_db_config,
graph_db_config=graph_db_config,
preferred_loaders=preferred_loaders,
incremental_loading=incremental_loading,
)
cognify_run = await cognify(
datasets=[dataset_id],
user=user,
vector_db_config=vector_db_config,
graph_db_config=graph_db_config,
incremental_loading=incremental_loading,
)
return cognify_run

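A brief usage sketch for the new update flow, assuming it is imported the same way as in the update router above; the UUIDs are placeholders:

import asyncio
from uuid import UUID

from cognee.api.v1.update import update  # import path taken from the update router above


async def refresh_document():
    # Replace an existing document and re-run cognify on its dataset.
    run_info = await update(
        data_id=UUID("00000000-0000-0000-0000-000000000001"),  # placeholder: id of the document to replace
        data="/path/to/new_version.pdf",
        dataset_id=UUID("00000000-0000-0000-0000-000000000002"),  # placeholder: dataset containing the document
    )
    print(run_info)


asyncio.run(refresh_document())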
View file

@ -183,10 +183,20 @@ def main() -> int:
for pid in spawned_pids:
try:
pgid = os.getpgid(pid)
os.killpg(pgid, signal.SIGTERM)
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
except (OSError, ProcessLookupError) as e:
if hasattr(os, "killpg"):
# Unix-like systems: Use process groups
pgid = os.getpgid(pid)
os.killpg(pgid, signal.SIGTERM)
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
else:
# Windows: Use taskkill to terminate process and its children
subprocess.run(
["taskkill", "/F", "/T", "/PID", str(pid)],
capture_output=True,
check=False,
)
fmt.success(f"✓ Process {pid} and its children terminated.")
except (OSError, ProcessLookupError, subprocess.SubprocessError) as e:
fmt.warning(f"Could not terminate process {pid}: {e}")
sys.exit(0)
@ -204,19 +214,27 @@ def main() -> int:
nonlocal spawned_pids
spawned_pids.append(pid)
frontend_port = 3000
start_backend, backend_port = True, 8000
start_mcp, mcp_port = True, 8001
server_process = start_ui(
host="localhost",
port=3000,
open_browser=True,
start_backend=True,
auto_download=True,
pid_callback=pid_callback,
port=frontend_port,
open_browser=True,
auto_download=True,
start_backend=start_backend,
backend_port=backend_port,
start_mcp=start_mcp,
mcp_port=mcp_port,
)
if server_process:
fmt.success("UI server started successfully!")
fmt.echo("The interface is available at: http://localhost:3000")
fmt.echo("The API backend is available at: http://localhost:8000")
fmt.echo(f"The interface is available at: http://localhost:{frontend_port}")
if start_backend:
fmt.echo(f"The API backend is available at: http://localhost:{backend_port}")
if start_mcp:
fmt.echo(f"The MCP server is available at: http://localhost:{mcp_port}")
fmt.note("Press Ctrl+C to stop the server...")
try:

View file

@ -70,11 +70,11 @@ After adding data, use `cognee cognify` to process it into knowledge graphs.
await cognee.add(data=data_to_add, dataset_name=args.dataset_name)
fmt.success(f"Successfully added data to dataset '{args.dataset_name}'")
except Exception as e:
raise CliCommandInnerException(f"Failed to add data: {str(e)}")
raise CliCommandInnerException(f"Failed to add data: {str(e)}") from e
asyncio.run(run_add())
except Exception as e:
if isinstance(e, CliCommandInnerException):
raise CliCommandException(str(e), error_code=1)
raise CliCommandException(f"Error adding data: {str(e)}", error_code=1)
raise CliCommandException(str(e), error_code=1) from e
raise CliCommandException(f"Failed to add data: {str(e)}", error_code=1) from e

View file

@ -107,7 +107,7 @@ After successful cognify processing, use `cognee search` to query the knowledge
)
return result
except Exception as e:
raise CliCommandInnerException(f"Failed to cognify: {str(e)}")
raise CliCommandInnerException(f"Failed to cognify: {str(e)}") from e
result = asyncio.run(run_cognify())
@ -124,5 +124,5 @@ After successful cognify processing, use `cognee search` to query the knowledge
except Exception as e:
if isinstance(e, CliCommandInnerException):
raise CliCommandException(str(e), error_code=1)
raise CliCommandException(f"Error during cognification: {str(e)}", error_code=1)
raise CliCommandException(str(e), error_code=1) from e
raise CliCommandException(f"Error during cognification: {str(e)}", error_code=1) from e

View file

@ -79,8 +79,10 @@ Configuration changes will affect how cognee processes and stores data.
except Exception as e:
if isinstance(e, CliCommandInnerException):
raise CliCommandException(str(e), error_code=1)
raise CliCommandException(f"Error managing configuration: {str(e)}", error_code=1)
raise CliCommandException(str(e), error_code=1) from e
raise CliCommandException(
f"Error managing configuration: {str(e)}", error_code=1
) from e
def _handle_get(self, args: argparse.Namespace) -> None:
try:
@ -122,7 +124,7 @@ Configuration changes will affect how cognee processes and stores data.
fmt.note("Configuration viewing not fully implemented yet")
except Exception as e:
raise CliCommandInnerException(f"Failed to get configuration: {str(e)}")
raise CliCommandInnerException(f"Failed to get configuration: {str(e)}") from e
def _handle_set(self, args: argparse.Namespace) -> None:
try:
@ -141,7 +143,7 @@ Configuration changes will affect how cognee processes and stores data.
fmt.error(f"Failed to set configuration key '{args.key}'")
except Exception as e:
raise CliCommandInnerException(f"Failed to set configuration: {str(e)}")
raise CliCommandInnerException(f"Failed to set configuration: {str(e)}") from e
def _handle_unset(self, args: argparse.Namespace) -> None:
try:
@ -189,7 +191,7 @@ Configuration changes will affect how cognee processes and stores data.
fmt.note("Use 'cognee config list' to see all available configuration options")
except Exception as e:
raise CliCommandInnerException(f"Failed to unset configuration: {str(e)}")
raise CliCommandInnerException(f"Failed to unset configuration: {str(e)}") from e
def _handle_list(self, args: argparse.Namespace) -> None:
try:
@ -209,7 +211,7 @@ Configuration changes will affect how cognee processes and stores data.
fmt.echo(" cognee config reset - Reset all to defaults")
except Exception as e:
raise CliCommandInnerException(f"Failed to list configuration: {str(e)}")
raise CliCommandInnerException(f"Failed to list configuration: {str(e)}") from e
def _handle_reset(self, args: argparse.Namespace) -> None:
try:
@ -222,4 +224,4 @@ Configuration changes will affect how cognee processes and stores data.
fmt.echo("This would reset all settings to their default values")
except Exception as e:
raise CliCommandInnerException(f"Failed to reset configuration: {str(e)}")
raise CliCommandInnerException(f"Failed to reset configuration: {str(e)}") from e

View file

@ -6,6 +6,7 @@ from cognee.cli.reference import SupportsCliCommand
from cognee.cli import DEFAULT_DOCS_URL
import cognee.cli.echo as fmt
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
from cognee.modules.data.methods.get_deletion_counts import get_deletion_counts
class DeleteCommand(SupportsCliCommand):
@ -41,7 +42,34 @@ Be careful with deletion operations as they are irreversible.
fmt.error("Please specify what to delete: --dataset-name, --user-id, or --all")
return
# Build confirmation message
# If --force is used, skip the preview and go straight to deletion
if not args.force:
# --- START PREVIEW LOGIC ---
fmt.echo("Gathering data for preview...")
try:
preview_data = asyncio.run(
get_deletion_counts(
dataset_name=args.dataset_name,
user_id=args.user_id,
all_data=args.all,
)
)
except CliCommandException as e:
fmt.error(f"Error occured when fetching preview data: {str(e)}")
return
if not preview_data:
fmt.success("No data found to delete.")
return
fmt.echo("You are about to delete:")
fmt.echo(
f"Datasets: {preview_data.datasets}\nEntries: {preview_data.entries}\nUsers: {preview_data.users}"
)
fmt.echo("-" * 20)
# --- END PREVIEW LOGIC ---
# Build operation message for success/failure logging
if args.all:
confirm_msg = "Delete ALL data from cognee?"
operation = "all data"
@ -51,8 +79,9 @@ Be careful with deletion operations as they are irreversible.
elif args.user_id:
confirm_msg = f"Delete all data for user '{args.user_id}'?"
operation = f"data for user '{args.user_id}'"
else:
operation = "data"
# Confirm deletion unless forced
if not args.force:
fmt.warning("This operation is irreversible!")
if not fmt.confirm(confirm_msg):
@ -64,17 +93,20 @@ Be careful with deletion operations as they are irreversible.
# Run the async delete function
async def run_delete():
try:
# NOTE: The underlying cognee.delete() function is currently not working as expected.
# This is a separate bug that this preview feature helps to expose.
if args.all:
await cognee.delete(dataset_name=None, user_id=args.user_id)
else:
await cognee.delete(dataset_name=args.dataset_name, user_id=args.user_id)
except Exception as e:
raise CliCommandInnerException(f"Failed to delete: {str(e)}")
raise CliCommandInnerException(f"Failed to delete: {str(e)}") from e
asyncio.run(run_delete())
# This success message may be inaccurate due to the underlying bug, but we leave it for now.
fmt.success(f"Successfully deleted {operation}")
except Exception as e:
if isinstance(e, CliCommandInnerException):
raise CliCommandException(str(e), error_code=1)
raise CliCommandException(f"Error deleting data: {str(e)}", error_code=1)
raise CliCommandException(str(e), error_code=1) from e
raise CliCommandException(f"Error deleting data: {str(e)}", error_code=1) from e

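A sketch of the preview call used by the command above, assuming get_deletion_counts keeps the keyword arguments and result fields shown in this diff:

import asyncio

from cognee.modules.data.methods.get_deletion_counts import get_deletion_counts


async def preview():
    counts = await get_deletion_counts(dataset_name="my_dataset", user_id=None, all_data=False)
    if not counts:
        print("No data found to delete.")
        return
    print(f"Datasets: {counts.datasets}\nEntries: {counts.entries}\nUsers: {counts.users}")


asyncio.run(preview())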
View file

@ -31,10 +31,6 @@ Search Types & Use Cases:
Traditional RAG using document chunks without graph structure.
Best for: Direct document retrieval, specific fact-finding.
**INSIGHTS**:
Structured entity relationships and semantic connections.
Best for: Understanding concept relationships, knowledge mapping.
**CHUNKS**:
Raw text segments that match the query semantically.
Best for: Finding specific passages, citations, exact content.
@ -108,7 +104,7 @@ Search Types & Use Cases:
)
return results
except Exception as e:
raise CliCommandInnerException(f"Failed to search: {str(e)}")
raise CliCommandInnerException(f"Failed to search: {str(e)}") from e
results = asyncio.run(run_search())
@ -145,5 +141,5 @@ Search Types & Use Cases:
except Exception as e:
if isinstance(e, CliCommandInnerException):
raise CliCommandException(str(e), error_code=1)
raise CliCommandException(f"Error searching: {str(e)}", error_code=1)
raise CliCommandException(str(e), error_code=1) from e
raise CliCommandException(f"Error searching: {str(e)}", error_code=1) from e

View file

@ -19,7 +19,6 @@ COMMAND_DESCRIPTIONS = {
SEARCH_TYPE_CHOICES = [
"GRAPH_COMPLETION",
"RAG_COMPLETION",
"INSIGHTS",
"CHUNKS",
"SUMMARIES",
"CODE",

View file

@ -12,6 +12,8 @@ from cognee.modules.users.methods import get_user
# for different async tasks, threads and processes
vector_db_config = ContextVar("vector_db_config", default=None)
graph_db_config = ContextVar("graph_db_config", default=None)
soup_crawler_config = ContextVar("soup_crawler_config", default=None)
tavily_config = ContextVar("tavily_config", default=None)
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):

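These ContextVars keep per-request scraper settings isolated between async tasks; a minimal standalone illustration of that behaviour:

import asyncio
from contextvars import ContextVar

crawler_config: ContextVar = ContextVar("crawler_config", default=None)


async def worker(name: str, value: str):
    crawler_config.set(value)  # only visible inside this task's context copy
    await asyncio.sleep(0)
    print(name, "sees", crawler_config.get())


async def main():
    await asyncio.gather(worker("task-a", "config-a"), worker("task-b", "config-b"))
    print("main still sees", crawler_config.get())  # None: the tasks did not leak their values


asyncio.run(main())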
View file

@ -50,26 +50,26 @@ class GraphConfig(BaseSettings):
# Model validator updates graph_filename and path dynamically after class creation based on current database provider
# If no specific graph_filename or path are provided
@pydantic.model_validator(mode="after")
def fill_derived(cls, values):
provider = values.graph_database_provider.lower()
def fill_derived(self):
provider = self.graph_database_provider.lower()
base_config = get_base_config()
# Set default filename if no filename is provided
if not values.graph_filename:
values.graph_filename = f"cognee_graph_{provider}"
if not self.graph_filename:
self.graph_filename = f"cognee_graph_{provider}"
# Handle graph file path
if values.graph_file_path:
if self.graph_file_path:
# Check if absolute path is provided
values.graph_file_path = ensure_absolute_path(
os.path.join(values.graph_file_path, values.graph_filename)
self.graph_file_path = ensure_absolute_path(
os.path.join(self.graph_file_path, self.graph_filename)
)
else:
# Default path
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
values.graph_file_path = os.path.join(databases_directory_path, values.graph_filename)
self.graph_file_path = os.path.join(databases_directory_path, self.graph_filename)
return values
return self
def to_dict(self) -> dict:
"""

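The validator above now follows the pydantic v2 style in which an "after" model validator receives and returns the instance itself; a simplified standalone sketch of that pattern:

import pydantic
from pydantic import BaseModel


class ExampleGraphConfig(BaseModel):
    graph_database_provider: str = "kuzu"
    graph_filename: str = ""

    @pydantic.model_validator(mode="after")
    def fill_derived(self):
        # "after" validators run on the constructed instance and must return it.
        if not self.graph_filename:
            self.graph_filename = f"cognee_graph_{self.graph_database_provider.lower()}"
        return self


print(ExampleGraphConfig().graph_filename)  # cognee_graph_kuzu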
View file

@ -44,16 +44,14 @@ def create_graph_engine(
Parameters:
-----------
- graph_database_provider: The type of graph database provider to use (e.g., neo4j,
falkordb, kuzu).
- graph_database_url: The URL for the graph database instance. Required for neo4j
and falkordb providers.
- graph_database_provider: The type of graph database provider to use (e.g., neo4j, kuzu).
- graph_database_url: The URL for the graph database instance. Required for the neo4j provider.
- graph_database_username: The username for authentication with the graph database.
Required for neo4j provider.
- graph_database_password: The password for authentication with the graph database.
Required for neo4j provider.
- graph_database_port: The port number for the graph database connection. Required
for the falkordb provider.
when the selected provider connects over a dedicated port.
- graph_file_path: The filesystem path to the graph file. Required for the kuzu
provider.
@ -86,21 +84,6 @@ def create_graph_engine(
graph_database_name=graph_database_name or None,
)
elif graph_database_provider == "falkordb":
if not (graph_database_url and graph_database_port):
raise EnvironmentError("Missing required FalkorDB credentials.")
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
embedding_engine = get_embedding_engine()
return FalkorDBAdapter(
database_url=graph_database_url,
database_port=graph_database_port,
embedding_engine=embedding_engine,
)
elif graph_database_provider == "kuzu":
if not graph_file_path:
raise EnvironmentError("Missing required Kuzu database path.")
@ -179,5 +162,5 @@ def create_graph_engine(
raise EnvironmentError(
f"Unsupported graph database provider: {graph_database_provider}. "
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'falkordb', 'kuzu', 'kuzu-remote', 'memgraph', 'neptune', 'neptune_analytics'])}"
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'kuzu', 'kuzu-remote', 'memgraph', 'neptune', 'neptune_analytics'])}"
)

View file

@ -48,6 +48,29 @@ class KuzuAdapter(GraphDBInterface):
def _initialize_connection(self) -> None:
"""Initialize the Kuzu database connection and schema."""
def _install_json_extension():
"""
Handles installing the JSON extension for the current Kuzu version.
This has to be done with an empty graph database before connecting to an existing database;
otherwise, missing JSON extension errors will be raised.
"""
try:
with tempfile.NamedTemporaryFile(mode="w", delete=True) as temp_file:
temp_graph_file = temp_file.name
tmp_db = Database(
temp_graph_file,
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
max_db_size=4096 * 1024 * 1024,
)
tmp_db.init_database()
connection = Connection(tmp_db)
connection.execute("INSTALL JSON;")
except Exception as e:
logger.info(f"JSON extension already installed or not needed: {e}")
_install_json_extension()
try:
if "s3://" in self.db_path:
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
@ -109,11 +132,6 @@ class KuzuAdapter(GraphDBInterface):
self.db.init_database()
self.connection = Connection(self.db)
try:
self.connection.execute("INSTALL JSON;")
except Exception as e:
logger.info(f"JSON extension already installed or not needed: {e}")
try:
self.connection.execute("LOAD EXTENSION JSON;")
logger.info("Loaded JSON extension")

View file

@ -68,6 +68,7 @@ class Neo4jAdapter(GraphDBInterface):
auth=auth,
max_connection_lifetime=120,
notifications_min_severity="OFF",
keep_alive=True,
)
async def initialize(self) -> None:
@ -205,7 +206,7 @@ class Neo4jAdapter(GraphDBInterface):
{
"node_id": str(node.id),
"label": type(node).__name__,
"properties": self.serialize_properties(node.model_dump()),
"properties": self.serialize_properties(dict(node)),
}
for node in nodes
]

View file

@ -8,7 +8,7 @@ from cognee.infrastructure.utils.calculate_backoff import calculate_backoff
logger = get_logger("deadlock_retry")
def deadlock_retry(max_retries=5):
def deadlock_retry(max_retries=10):
"""
Decorator that automatically retries an asynchronous function when rate limit errors occur.

View file

@ -53,7 +53,7 @@ def parse_neptune_url(url: str) -> Tuple[str, str]:
return graph_id, region
except Exception as e:
raise ValueError(f"Failed to parse Neptune Analytics URL '{url}': {str(e)}")
raise ValueError(f"Failed to parse Neptune Analytics URL '{url}': {str(e)}") from e
def validate_graph_id(graph_id: str) -> bool:

View file

@ -234,7 +234,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = None,
limit: Optional[int] = None,
with_vector: bool = False,
):
"""
@ -265,10 +265,10 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
"Use this option only when vector data is required."
)
# In the case of excessive limit, or zero / negative value, limit will be set to 10.
# In the case of excessive limit, or None / zero / negative value, limit will be set to 10.
if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND:
logger.warning(
"Provided limit (%s) is invalid (zero, negative, or exceeds maximum). "
"Provided limit (%s) is invalid (None, zero, negative, or exceeds maximum). "
"Defaulting to limit=10.",
limit,
)

View file

@ -23,14 +23,14 @@ class RelationalConfig(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@pydantic.model_validator(mode="after")
def fill_derived(cls, values):
def fill_derived(self):
# Set file path based on graph database provider if no file path is provided
if not values.db_path:
if not self.db_path:
base_config = get_base_config()
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
values.db_path = databases_directory_path
self.db_path = databases_directory_path
return values
return self
def to_dict(self) -> dict:
"""

View file

@ -283,7 +283,7 @@ class SQLAlchemyAdapter:
try:
data_entity = (await session.scalars(select(Data).where(Data.id == data_id))).one()
except (ValueError, NoResultFound) as e:
raise EntityNotFoundError(message=f"Entity not found: {str(e)}")
raise EntityNotFoundError(message=f"Entity not found: {str(e)}") from e
# Check if other data objects point to the same raw data location
raw_data_location_entities = (

View file

@ -352,7 +352,7 @@ class ChromaDBAdapter(VectorDBInterface):
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
limit: int = 15,
limit: Optional[int] = 15,
with_vector: bool = False,
normalized: bool = True,
):
@ -386,9 +386,13 @@ class ChromaDBAdapter(VectorDBInterface):
try:
collection = await self.get_collection(collection_name)
if limit == 0:
if limit is None:
limit = await collection.count()
# If limit is still 0, no need to do the search, just return empty results
if limit <= 0:
return []
results = await collection.query(
query_embeddings=[query_vector],
include=["metadatas", "distances", "embeddings"]
@ -428,7 +432,7 @@ class ChromaDBAdapter(VectorDBInterface):
for row in vector_list
]
except Exception as e:
logger.error(f"Error in search: {str(e)}")
logger.warning(f"Error in search: {str(e)}")
return []
async def batch_search(

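The search limit handling above now distinguishes None (search the whole collection) from non-positive values (skip the search); a tiny standalone mimic of that branch:

from typing import Optional


def resolve_limit(limit: Optional[int], collection_count: int) -> int:
    # Mirrors the ChromaDB branch above: None means "use the collection size",
    # zero or a negative value means "skip the query and return no results".
    if limit is None:
        limit = collection_count
    return limit if limit > 0 else 0


print(resolve_limit(None, 42))  # 42
print(resolve_limit(0, 42))     # 0
print(resolve_limit(15, 42))    # 15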
View file

@ -30,21 +30,21 @@ class VectorConfig(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@pydantic.model_validator(mode="after")
def validate_paths(cls, values):
def validate_paths(self):
base_config = get_base_config()
# If vector_db_url is provided and is not a path skip checking if path is absolute (as it can also be a url)
if values.vector_db_url and Path(values.vector_db_url).exists():
if self.vector_db_url and Path(self.vector_db_url).exists():
# Relative path to absolute
values.vector_db_url = ensure_absolute_path(
values.vector_db_url,
self.vector_db_url = ensure_absolute_path(
self.vector_db_url,
)
elif not values.vector_db_url:
elif not self.vector_db_url:
# Default path
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
values.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
self.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
return values
return self
def to_dict(self) -> dict:
"""

View file

@ -19,8 +19,7 @@ def create_vector_engine(
for each provider, raising an EnvironmentError if any are missing, or ImportError if the
ChromaDB package is not installed.
Supported providers include: pgvector, FalkorDB, ChromaDB, and
LanceDB.
Supported providers include: pgvector, ChromaDB, and LanceDB.
Parameters:
-----------
@ -79,18 +78,6 @@ def create_vector_engine(
embedding_engine,
)
elif vector_db_provider == "falkordb":
if not (vector_db_url and vector_db_port):
raise EnvironmentError("Missing requred FalkorDB credentials!")
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
return FalkorDBAdapter(
database_url=vector_db_url,
database_port=vector_db_port,
embedding_engine=embedding_engine,
)
elif vector_db_provider == "chromadb":
try:
import chromadb

View file

@ -34,3 +34,12 @@ class EmbeddingEngine(Protocol):
- int: An integer representing the number of dimensions in the embedding vector.
"""
raise NotImplementedError()
def get_batch_size(self) -> int:
"""
Return the desired batch size for embedding calls.

Returns:
    - int: The number of texts to send in a single embedding request.
"""
raise NotImplementedError()

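A toy engine satisfying the extended protocol, including the new get_batch_size hook; embed_text and get_vector_size are assumed from the protocol methods used elsewhere in this diff:

from typing import List


class DummyEmbeddingEngine:
    """Illustrative engine only; returns deterministic fake vectors."""

    def __init__(self, dimensions: int = 8, batch_size: int = 100):
        self.dimensions = dimensions
        self.batch_size = batch_size

    async def embed_text(self, text: List[str]) -> List[List[float]]:
        # One fake vector per input string.
        return [[float(len(item))] * self.dimensions for item in text]

    def get_vector_size(self) -> int:
        return self.dimensions

    def get_batch_size(self) -> int:
        # Callers use this to decide how many texts to send per embedding request.
        return self.batch_size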
View file

@ -42,11 +42,13 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
model: Optional[str] = "openai/text-embedding-3-large",
dimensions: Optional[int] = 3072,
max_completion_tokens: int = 512,
batch_size: int = 100,
):
self.model = model
self.dimensions = dimensions
self.max_completion_tokens = max_completion_tokens
self.tokenizer = self.get_tokenizer()
self.batch_size = batch_size
# self.retry_count = 0
self.embedding_model = TextEmbedding(model_name=model)
@ -88,7 +90,9 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
except Exception as error:
logger.error(f"Embedding error in FastembedEmbeddingEngine: {str(error)}")
raise EmbeddingException(f"Failed to index data points using model {self.model}")
raise EmbeddingException(
f"Failed to index data points using model {self.model}"
) from error
def get_vector_size(self) -> int:
"""
@ -101,6 +105,15 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
"""
return self.dimensions
def get_batch_size(self) -> int:
"""
Return the desired batch size for embedding calls.

Returns:
    - int: The configured batch size for embedding requests.
"""
return self.batch_size
def get_tokenizer(self):
"""
Instantiate and return the tokenizer used for preparing text for embedding.

View file

@ -58,6 +58,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
endpoint: str = None,
api_version: str = None,
max_completion_tokens: int = 512,
batch_size: int = 100,
):
self.api_key = api_key
self.endpoint = endpoint
@ -68,6 +69,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
self.max_completion_tokens = max_completion_tokens
self.tokenizer = self.get_tokenizer()
self.retry_count = 0
self.batch_size = batch_size
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
if isinstance(enable_mocking, bool):
@ -148,7 +150,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
litellm.exceptions.NotFoundError,
) as e:
logger.error(f"Embedding error with model {self.model}: {str(e)}")
raise EmbeddingException(f"Failed to index data points using model {self.model}")
raise EmbeddingException(f"Failed to index data points using model {self.model}") from e
except Exception as error:
logger.error("Error embedding text: %s", str(error))
@ -165,6 +167,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
"""
return self.dimensions
def get_batch_size(self) -> int:
"""
Return the desired batch size for embedding calls.

Returns:
    - int: The configured batch size for embedding requests.
"""
return self.batch_size
def get_tokenizer(self):
"""
Load and return the appropriate tokenizer for the specified model based on the provider.
@ -183,9 +194,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
model=model, max_completion_tokens=self.max_completion_tokens
)
elif "gemini" in self.provider.lower():
tokenizer = GeminiTokenizer(
model=model, max_completion_tokens=self.max_completion_tokens
# Gemini tokenization requires an API request to get a token count, so we use TikToken
# locally to estimate token counts word by word instead
tokenizer = TikTokenTokenizer(
model=None, max_completion_tokens=self.max_completion_tokens
)
# Note: Gemini Tokenizer expects an LLM model as input and not the embedding model
# tokenizer = GeminiTokenizer(
# llm_model=llm_model, max_completion_tokens=self.max_completion_tokens
# )
elif "mistral" in self.provider.lower():
tokenizer = MistralTokenizer(
model=model, max_completion_tokens=self.max_completion_tokens

View file

@ -54,12 +54,14 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
max_completion_tokens: int = 512,
endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
batch_size: int = 100,
):
self.model = model
self.dimensions = dimensions
self.max_completion_tokens = max_completion_tokens
self.endpoint = endpoint
self.huggingface_tokenizer_name = huggingface_tokenizer
self.batch_size = batch_size
self.tokenizer = self.get_tokenizer()
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
@ -122,6 +124,15 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
"""
return self.dimensions
def get_batch_size(self) -> int:
"""
Return the desired batch size for embedding calls.

Returns:
    - int: The configured batch size for embedding requests.
"""
return self.batch_size
def get_tokenizer(self):
"""
Load and return a HuggingFace tokenizer for the embedding engine.

View file

@ -19,9 +19,17 @@ class EmbeddingConfig(BaseSettings):
embedding_api_key: Optional[str] = None
embedding_api_version: Optional[str] = None
embedding_max_completion_tokens: Optional[int] = 8191
embedding_batch_size: Optional[int] = None
huggingface_tokenizer: Optional[str] = None
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def model_post_init(self, __context) -> None:
# If embedding batch size is not defined, use 2048 as the default for OpenAI and 100 for all other embedding models
if not self.embedding_batch_size and self.embedding_provider.lower() == "openai":
self.embedding_batch_size = 2048
elif not self.embedding_batch_size:
self.embedding_batch_size = 100
def to_dict(self) -> dict:
"""
Serialize all embedding configuration settings to a dictionary.

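A standalone mimic of the defaulting rule in model_post_init above (not the real EmbeddingConfig): explicit values win, OpenAI defaults to 2048, every other provider defaults to 100:

from typing import Optional


def resolve_batch_size(provider: str, configured: Optional[int]) -> int:
    if configured:
        return configured
    return 2048 if provider.lower() == "openai" else 100


print(resolve_batch_size("openai", None))  # 2048
print(resolve_batch_size("ollama", None))  # 100
print(resolve_batch_size("openai", 512))   # 512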
View file

@ -31,6 +31,7 @@ def get_embedding_engine() -> EmbeddingEngine:
config.embedding_endpoint,
config.embedding_api_key,
config.embedding_api_version,
config.embedding_batch_size,
config.huggingface_tokenizer,
llm_config.llm_api_key,
llm_config.llm_provider,
@ -46,6 +47,7 @@ def create_embedding_engine(
embedding_endpoint,
embedding_api_key,
embedding_api_version,
embedding_batch_size,
huggingface_tokenizer,
llm_api_key,
llm_provider,
@ -84,6 +86,7 @@ def create_embedding_engine(
model=embedding_model,
dimensions=embedding_dimensions,
max_completion_tokens=embedding_max_completion_tokens,
batch_size=embedding_batch_size,
)
if embedding_provider == "ollama":
@ -95,6 +98,7 @@ def create_embedding_engine(
max_completion_tokens=embedding_max_completion_tokens,
endpoint=embedding_endpoint,
huggingface_tokenizer=huggingface_tokenizer,
batch_size=embedding_batch_size,
)
from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
@ -108,4 +112,5 @@ def create_embedding_engine(
model=embedding_model,
dimensions=embedding_dimensions,
max_completion_tokens=embedding_max_completion_tokens,
batch_size=embedding_batch_size,
)

View file

@ -226,7 +226,7 @@ class LanceDBAdapter(VectorDBInterface):
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
limit: int = 15,
limit: Optional[int] = 15,
with_vector: bool = False,
normalized: bool = True,
):
@ -238,11 +238,11 @@ class LanceDBAdapter(VectorDBInterface):
collection = await self.get_collection(collection_name)
if limit == 0:
if limit is None:
limit = await collection.count_rows()
# LanceDB search will break if limit is 0 so we must return
if limit == 0:
if limit <= 0:
return []
result_values = await collection.vector_search(query_vector).limit(limit).to_list()
@ -265,7 +265,7 @@ class LanceDBAdapter(VectorDBInterface):
self,
collection_name: str,
query_texts: List[str],
limit: int = None,
limit: Optional[int] = None,
with_vectors: bool = False,
):
query_vectors = await self.embedding_engine.embed_text(query_texts)

View file

@ -3,7 +3,7 @@ from typing import List, Optional, get_type_hints
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.exc import ProgrammingError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
@ -125,41 +125,42 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
data_point_types = get_type_hints(DataPoint)
vector_size = self.embedding_engine.get_vector_size()
async with self.VECTOR_DB_LOCK:
if not await self.has_collection(collection_name):
if not await self.has_collection(collection_name):
async with self.VECTOR_DB_LOCK:
if not await self.has_collection(collection_name):
class PGVectorDataPoint(Base):
"""
Represent a point in a vector data space with associated data and vector representation.
class PGVectorDataPoint(Base):
"""
Represent a point in a vector data space with associated data and vector representation.
This class inherits from Base and is associated with a database table defined by
__tablename__. It maintains the following public methods and instance variables:
This class inherits from Base and is associated with a database table defined by
__tablename__. It maintains the following public methods and instance variables:
- __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance.
- __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance.
Instance variables:
- id: Identifier for the data point, defined by data_point_types.
- payload: JSON data associated with the data point.
- vector: Vector representation of the data point, with size defined by vector_size.
"""
Instance variables:
- id: Identifier for the data point, defined by data_point_types.
- payload: JSON data associated with the data point.
- vector: Vector representation of the data point, with size defined by vector_size.
"""
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
def __init__(self, id, payload, vector):
self.id = id
self.payload = payload
self.vector = vector
def __init__(self, id, payload, vector):
self.id = id
self.payload = payload
self.vector = vector
async with self.engine.begin() as connection:
if len(Base.metadata.tables.keys()) > 0:
await connection.run_sync(
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
)
async with self.engine.begin() as connection:
if len(Base.metadata.tables.keys()) > 0:
await connection.run_sync(
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
)
@retry(
retry=retry_if_exception_type(DeadlockDetectedError),
@ -298,7 +299,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 15,
limit: Optional[int] = 15,
with_vector: bool = False,
) -> List[ScoredResult]:
if query_text is None and query_vector is None:
@ -310,6 +311,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
if limit is None:
async with self.get_async_session() as session:
query = select(func.count()).select_from(PGVectorDataPoint)
result = await session.execute(query)
limit = result.scalar_one()
# If limit is still 0, no need to do the search, just return empty results
if limit <= 0:
return []
# NOTE: This needs to be initialized in case search doesn't return a value
closest_items = []

View file

@ -83,7 +83,7 @@ class VectorDBInterface(Protocol):
collection_name: str,
query_text: Optional[str],
query_vector: Optional[List[float]],
limit: int,
limit: Optional[int],
with_vector: bool = False,
):
"""
@ -98,7 +98,7 @@ class VectorDBInterface(Protocol):
collection.
- query_vector (Optional[List[float]]): An optional vector representation for
searching the collection.
- limit (int): The maximum number of results to return from the search.
- limit (Optional[int]): The maximum number of results to return from the search.
- with_vector (bool): Whether to return the vector representations with search
results. (default False)
"""
@ -106,7 +106,11 @@ class VectorDBInterface(Protocol):
@abstractmethod
async def batch_search(
self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
self,
collection_name: str,
query_texts: List[str],
limit: Optional[int],
with_vectors: bool = False,
):
"""
Perform a batch search using multiple text queries against a collection.
@ -116,7 +120,7 @@ class VectorDBInterface(Protocol):
- collection_name (str): The name of the collection to conduct the batch search in.
- query_texts (List[str]): A list of text queries to use for the search.
- limit (int): The maximum number of results to return for each query.
- limit (Optional[int]): The maximum number of results to return for each query.
- with_vectors (bool): Whether to include vector representations with search
results. (default False)
"""

View file
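A hedged usage sketch for the Optional limit introduced above; "adapter" stands for any VectorDBInterface implementation and "documents" is a hypothetical collection name. Per the LanceDB and PGVector hunks, limit=None is resolved to the full collection size before searching, and an empty collection returns no results instead of erroring.

async def search_examples(adapter):
    # Explicit cap, same behavior as before.
    top_results = await adapter.search(
        collection_name="documents",
        query_text="batch document handling",
        limit=15,
    )
    # New semantics: None means "no explicit cap", i.e. search the whole collection.
    all_results = await adapter.search(
        collection_name="documents",
        query_text="batch document handling",
        limit=None,
    )
    return top_results, all_results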

@ -34,6 +34,7 @@ class S3FileStorage(Storage):
self.s3 = s3fs.S3FileSystem(
key=s3_config.aws_access_key_id,
secret=s3_config.aws_secret_access_key,
token=s3_config.aws_session_token,
anon=False,
endpoint_url=s3_config.aws_endpoint_url,
client_kwargs={"region_name": s3_config.aws_region},

View file

@ -8,6 +8,7 @@ class S3Config(BaseSettings):
aws_endpoint_url: Optional[str] = None
aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None
aws_session_token: Optional[str] = None
model_config = SettingsConfigDict(env_file=".env", extra="allow")

View file

@ -39,7 +39,7 @@ class LLMConfig(BaseSettings):
structured_output_framework: str = "instructor"
llm_provider: str = "openai"
llm_model: str = "openai/gpt-4o-mini"
llm_model: str = "openai/gpt-5-mini"
llm_endpoint: str = ""
llm_api_key: Optional[str] = None
llm_api_version: Optional[str] = None
@ -48,7 +48,7 @@ class LLMConfig(BaseSettings):
llm_max_completion_tokens: int = 16384
baml_llm_provider: str = "openai"
baml_llm_model: str = "gpt-4o-mini"
baml_llm_model: str = "gpt-5-mini"
baml_llm_endpoint: str = ""
baml_llm_api_key: Optional[str] = None
baml_llm_temperature: float = 0.0

View file

@ -26,6 +26,7 @@ def read_query_prompt(prompt_file_name: str, base_directory: str = None):
read due to an error.
"""
logger = get_logger(level=ERROR)
try:
if base_directory is None:
base_directory = get_absolute_path("./infrastructure/llm/prompts")
@ -35,8 +36,8 @@ def read_query_prompt(prompt_file_name: str, base_directory: str = None):
with open(file_path, "r", encoding="utf-8") as file:
return file.read()
except FileNotFoundError:
logger.error(f"Error: Prompt file not found. Attempted to read: %s {file_path}")
logger.error(f"Error: Prompt file not found. Attempted to read: {file_path}")
return None
except Exception as e:
logger.error(f"An error occurred: %s {e}")
logger.error(f"An error occurred: {e}")
return None

View file

@ -10,8 +10,6 @@ Here are the available `SearchType` tools and their specific functions:
- Summarizing large amounts of information
- Quick understanding of complex subjects
* **`INSIGHTS`**: The `INSIGHTS` search type discovers connections and relationships between entities in the knowledge graph.
**Best for:**
- Discovering how entities are connected
@ -95,9 +93,6 @@ Here are the available `SearchType` tools and their specific functions:
Query: "Summarize the key findings from these research papers"
Response: `SUMMARIES`
Query: "What is the relationship between the methodologies used in these papers?"
Response: `INSIGHTS`
Query: "When was Einstein born?"
Response: `CHUNKS`

View file

@ -0,0 +1 @@
Respond with: test

View file

@ -1,115 +1,155 @@
import litellm
from pydantic import BaseModel
from typing import Type, Optional
from litellm import acompletion, JSONSchemaValidationError
"""Adapter for Generic API LLM provider API"""
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
import litellm
import instructor
from typing import Type
from pydantic import BaseModel
from openai import ContentFilterFinishReasonError
from litellm.exceptions import ContentPolicyViolationError
from instructor.core import InstructorRetryException
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async,
sleep_and_retry_async,
)
logger = get_logger()
observe = get_observe()
class GeminiAdapter(LLMInterface):
"""
Handles interactions with a language model API.
Adapter for Gemini API LLM provider.
Public methods include:
- acreate_structured_output
- show_prompt
This class initializes the API adapter with necessary credentials and configurations for
interacting with the Gemini LLM models. It provides methods for creating structured outputs
based on user input and system prompts.
Public methods:
- acreate_structured_output(text_input: str, system_prompt: str, response_model:
Type[BaseModel]) -> BaseModel
"""
MAX_RETRIES = 5
name: str
model: str
api_key: str
def __init__(
self,
endpoint,
api_key: str,
model: str,
api_version: str,
max_completion_tokens: int,
endpoint: Optional[str] = None,
api_version: Optional[str] = None,
streaming: bool = False,
) -> None:
self.api_key = api_key
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
):
self.model = model
self.api_key = api_key
self.endpoint = endpoint
self.api_version = api_version
self.streaming = streaming
self.max_completion_tokens = max_completion_tokens
@observe(as_type="generation")
self.fallback_model = fallback_model
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
@sleep_and_retry_async()
@rate_limit_async
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
"""
Generate structured output from the language model based on the provided input and
system prompt.
Generate a response from a user query.
This method handles retries and raises a ValueError if the request fails or the response
does not conform to the expected schema, logging errors accordingly.
This asynchronous method sends a user query and a system prompt to a language model and
retrieves the generated response. It handles API communication and retries up to a
specified limit in case of request failures.
Parameters:
-----------
- text_input (str): The user input text to generate a response for.
- system_prompt (str): The system's prompt or context to influence the language
model's generation.
- response_model (Type[BaseModel]): A model type indicating the expected format of
the response.
- text_input (str): The input text from the user to generate a response for.
- system_prompt (str): A prompt that provides context or instructions for the
response generation.
- response_model (Type[BaseModel]): A Pydantic model that defines the structure of
the expected response.
Returns:
--------
- BaseModel: Returns the generated response as an instance of the specified response
model.
- BaseModel: An instance of the specified response model containing the structured
output from the language model.
"""
try:
if response_model is str:
response_schema = {"type": "string"}
else:
response_schema = response_model
return await self.aclient.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
api_key=self.api_key,
max_retries=5,
api_base=self.endpoint,
api_version=self.api_version,
response_model=response_model,
)
except (
ContentFilterFinishReasonError,
ContentPolicyViolationError,
InstructorRetryException,
) as error:
if (
isinstance(error, InstructorRetryException)
and "content management policy" not in str(error).lower()
):
raise error
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": text_input},
]
try:
response = await acompletion(
model=f"{self.model}",
messages=messages,
api_key=self.api_key,
max_completion_tokens=self.max_completion_tokens,
temperature=0.1,
response_format=response_schema,
timeout=100,
num_retries=self.MAX_RETRIES,
if not (self.fallback_model and self.fallback_api_key and self.fallback_endpoint):
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)
if response.choices and response.choices[0].message.content:
content = response.choices[0].message.content
if response_model is str:
return content
return response_model.model_validate_json(content)
except litellm.exceptions.BadRequestError as e:
logger.error(f"Bad request error: {str(e)}")
raise ValueError(f"Invalid request: {str(e)}")
raise ValueError("Failed to get valid response after retries")
except JSONSchemaValidationError as e:
logger.error(f"Schema validation failed: {str(e)}")
logger.debug(f"Raw response: {e.raw_response}")
raise ValueError(f"Response failed schema validation: {str(e)}")
try:
return await self.aclient.chat.completions.create(
model=self.fallback_model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=5,
api_key=self.fallback_api_key,
api_base=self.fallback_endpoint,
response_model=response_model,
)
except (
ContentFilterFinishReasonError,
ContentPolicyViolationError,
InstructorRetryException,
) as error:
if (
isinstance(error, InstructorRetryException)
and "content management policy" not in str(error).lower()
):
raise error
else:
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)

View file
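A construction sketch for the reworked GeminiAdapter (all values are placeholders and the import of GeminiAdapter is omitted); the fallback_* arguments are optional and are only used when the primary model raises a content-policy error, mirroring the GenericAPIAdapter below.

adapter = GeminiAdapter(
    api_key="YOUR_GEMINI_API_KEY",                  # placeholder credential
    model="gemini/gemini-2.0-flash",                # placeholder model name
    max_completion_tokens=4096,
    fallback_model="openai/gpt-5-mini",             # placeholder, used only on content-policy errors
    fallback_api_key="YOUR_FALLBACK_API_KEY",       # placeholder credential
    fallback_endpoint="https://api.openai.com/v1",  # placeholder endpoint
)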

@ -6,7 +6,7 @@ from typing import Type
from pydantic import BaseModel
from openai import ContentFilterFinishReasonError
from litellm.exceptions import ContentPolicyViolationError
from instructor.exceptions import InstructorRetryException
from instructor.core import InstructorRetryException
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
@ -56,9 +56,7 @@ class GenericAPIAdapter(LLMInterface):
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode.JSON, api_key=api_key
)
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
@sleep_and_retry_async()
@rate_limit_async
@ -102,6 +100,7 @@ class GenericAPIAdapter(LLMInterface):
},
],
max_retries=5,
api_key=self.api_key,
api_base=self.endpoint,
response_model=response_model,
)
@ -119,7 +118,7 @@ class GenericAPIAdapter(LLMInterface):
if not (self.fallback_model and self.fallback_api_key and self.fallback_endpoint):
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)
) from error
try:
return await self.aclient.chat.completions.create(
@ -152,4 +151,4 @@ class GenericAPIAdapter(LLMInterface):
else:
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)
) from error

View file

@ -23,6 +23,7 @@ class LLMProvider(Enum):
- ANTHROPIC: Represents the Anthropic provider.
- CUSTOM: Represents a custom provider option.
- GEMINI: Represents the Gemini provider.
- MISTRAL: Represents the Mistral AI provider.
"""
OPENAI = "openai"
@ -30,6 +31,7 @@ class LLMProvider(Enum):
ANTHROPIC = "anthropic"
CUSTOM = "custom"
GEMINI = "gemini"
MISTRAL = "mistral"
def get_llm_client(raise_api_key_error: bool = True):
@ -143,7 +145,36 @@ def get_llm_client(raise_api_key_error: bool = True):
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
api_version=llm_config.llm_api_version,
streaming=llm_config.llm_streaming,
)
elif provider == LLMProvider.MISTRAL:
if llm_config.llm_api_key is None:
raise LLMAPIKeyNotSetError()
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
MistralAdapter,
)
return MistralAdapter(
api_key=llm_config.llm_api_key,
model=llm_config.llm_model,
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
)
else:

View file

@ -0,0 +1,129 @@
import litellm
import instructor
from pydantic import BaseModel
from typing import Type, Optional
from litellm import acompletion, JSONSchemaValidationError
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async,
sleep_and_retry_async,
)
logger = get_logger()
observe = get_observe()
class MistralAdapter(LLMInterface):
"""
Adapter for Mistral AI API, for structured output generation and prompt display.
Public methods:
- acreate_structured_output
- show_prompt
"""
name = "Mistral"
model: str
api_key: str
max_completion_tokens: int
def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None):
# Imported only as an early dependency check; litellm performs the actual API calls.
from mistralai import Mistral  # noqa: F401
self.model = model
self.api_key = api_key or get_llm_config().llm_api_key
self.endpoint = endpoint
self.max_completion_tokens = max_completion_tokens
self.aclient = instructor.from_litellm(
litellm.acompletion,
mode=instructor.Mode.MISTRAL_TOOLS,
api_key=self.api_key,
)
@sleep_and_retry_async()
@rate_limit_async
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
"""
Generate a response from the user query.
Parameters:
-----------
- text_input (str): The input text from the user to be processed.
- system_prompt (str): A prompt that sets the context for the query.
- response_model (Type[BaseModel]): The model to structure the response according to
its format.
Returns:
--------
- BaseModel: An instance of BaseModel containing the structured response.
"""
try:
messages = [
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": f"""Use the given format to extract information
from the following input: {text_input}""",
},
]
try:
# instructor returns a validated instance of response_model directly,
# so there is no raw completion object with .choices to unpack here.
return await self.aclient.chat.completions.create(
model=self.model,
max_tokens=self.max_completion_tokens,
max_retries=5,
messages=messages,
response_model=response_model,
)
except litellm.exceptions.BadRequestError as e:
logger.error(f"Bad request error: {str(e)}")
raise ValueError(f"Invalid request: {str(e)}")
except JSONSchemaValidationError as e:
logger.error(f"Schema validation failed: {str(e)}")
logger.debug(f"Raw response: {e.raw_response}")
raise ValueError(f"Response failed schema validation: {str(e)}")
def show_prompt(self, text_input: str, system_prompt: str) -> str:
"""
Format and display the prompt for a user query.
Parameters:
-----------
- text_input (str): Input text from the user to be included in the prompt.
- system_prompt (str): The system prompt that will be shown alongside the user input.
Returns:
--------
- str: The formatted prompt string combining system prompt and user input.
"""
if not text_input:
text_input = "No user input provided."
if not system_prompt:
raise MissingSystemPromptPathError()
system_prompt = LLMGateway.read_query_prompt(system_prompt)
formatted_prompt = (
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
if system_prompt
else None
)
return formatted_prompt

View file
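A hedged usage sketch for the new MistralAdapter; the import path is the one used in get_llm_client above, while the model name, prompt, and PersonInfo response model are placeholders for illustration only.

import asyncio
from pydantic import BaseModel
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
    MistralAdapter,
)

class PersonInfo(BaseModel):
    # Hypothetical response model used only for this sketch.
    name: str
    occupation: str

async def main():
    adapter = MistralAdapter(
        api_key="YOUR_MISTRAL_API_KEY",        # placeholder credential
        model="mistral/mistral-small-latest",  # placeholder model name
        max_completion_tokens=1024,
    )
    result = await adapter.acreate_structured_output(
        text_input="Ada Lovelace was an English mathematician and writer.",
        system_prompt="Extract the person's name and occupation.",
        response_model=PersonInfo,
    )
    print(result)

asyncio.run(main())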

@ -5,15 +5,13 @@ from typing import Type
from pydantic import BaseModel
from openai import ContentFilterFinishReasonError
from litellm.exceptions import ContentPolicyViolationError
from instructor.exceptions import InstructorRetryException
from instructor.core import InstructorRetryException
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.exceptions import (
ContentPolicyFilterError,
MissingSystemPromptPathError,
)
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
@ -29,9 +27,6 @@ observe = get_observe()
logger = get_logger()
# litellm to drop unsupported params, e.g., reasoning_effort when not supported by the model.
litellm.drop_params = True
class OpenAIAdapter(LLMInterface):
"""
@ -76,8 +71,19 @@ class OpenAIAdapter(LLMInterface):
fallback_api_key: str = None,
fallback_endpoint: str = None,
):
self.aclient = instructor.from_litellm(litellm.acompletion)
self.client = instructor.from_litellm(litellm.completion)
# TODO: With gpt-5 series models OpenAI expects JSON_SCHEMA as the mode for structured outputs.
# Make sure all new GPT models work with this mode as well.
if "gpt-5" in model:
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA
)
self.client = instructor.from_litellm(
litellm.completion, mode=instructor.Mode.JSON_SCHEMA
)
else:
self.aclient = instructor.from_litellm(litellm.acompletion)
self.client = instructor.from_litellm(litellm.completion)
self.transcription_model = transcription_model
self.model = model
self.api_key = api_key
@ -135,17 +141,16 @@ class OpenAIAdapter(LLMInterface):
api_version=self.api_version,
response_model=response_model,
max_retries=self.MAX_RETRIES,
reasoning_effort="minimal",
)
except (
ContentFilterFinishReasonError,
ContentPolicyViolationError,
InstructorRetryException,
):
) as e:
if not (self.fallback_model and self.fallback_api_key):
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)
) from e
try:
return await self.aclient.chat.completions.create(
@ -178,7 +183,7 @@ class OpenAIAdapter(LLMInterface):
else:
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)
) from error
@observe
@sleep_and_retry_sync()
@ -223,7 +228,6 @@ class OpenAIAdapter(LLMInterface):
api_base=self.endpoint,
api_version=self.api_version,
response_model=response_model,
reasoning_effort="minimal",
max_retries=self.MAX_RETRIES,
)

View file

@ -3,6 +3,7 @@ from typing import List, Any
from ..tokenizer_interface import TokenizerInterface
# NOTE: DEPRECATED: counting tokens requires an API request to Google, which is too slow to use with Cognee
class GeminiTokenizer(TokenizerInterface):
"""
Implements a tokenizer interface for the Gemini model, managing token extraction and
@ -16,10 +17,10 @@ class GeminiTokenizer(TokenizerInterface):
def __init__(
self,
model: str,
llm_model: str,
max_completion_tokens: int = 3072,
):
self.model = model
self.llm_model = llm_model
self.max_completion_tokens = max_completion_tokens
# Get LLM API key from config
@ -28,12 +29,11 @@ class GeminiTokenizer(TokenizerInterface):
get_llm_config,
)
config = get_embedding_config()
llm_config = get_llm_config()
import google.generativeai as genai
from google import genai
genai.configure(api_key=config.embedding_api_key or llm_config.llm_api_key)
self.client = genai.Client(api_key=llm_config.llm_api_key)
def extract_tokens(self, text: str) -> List[Any]:
"""
@ -77,6 +77,7 @@ class GeminiTokenizer(TokenizerInterface):
- int: The number of tokens in the given text.
"""
import google.generativeai as genai
return len(genai.embed_content(model=f"models/{self.model}", content=text))
tokens_response = self.client.models.count_tokens(model=self.llm_model, contents=text)
return tokens_response.total_tokens

View file

@ -63,7 +63,7 @@ def get_model_max_completion_tokens(model_name: str):
max_completion_tokens = litellm.model_cost[model_name]["max_tokens"]
logger.debug(f"Max input tokens for {model_name}: {max_completion_tokens}")
else:
logger.info("Model not found in LiteLLM's model_cost.")
logger.debug("Model not found in LiteLLM's model_cost.")
return max_completion_tokens

View file

@ -31,6 +31,7 @@ class LoaderEngine:
"image_loader",
"audio_loader",
"unstructured_loader",
"advanced_pdf_loader",
]
def register_loader(self, loader: LoaderInterface) -> bool:
@ -86,7 +87,7 @@ class LoaderEngine:
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
return loader
else:
raise ValueError(f"Loader does not exist: {loader_name}")
logger.info(f"Skipping {loader_name}: Preferred Loader not registered")
# Try default priority order
for loader_name in self.default_loader_priority:
@ -95,7 +96,9 @@ class LoaderEngine:
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
return loader
else:
raise ValueError(f"Loader does not exist: {loader_name}")
logger.info(
f"Skipping {loader_name}: Loader not registered (in default priority list)."
)
return None

View file

@ -20,3 +20,10 @@ try:
__all__.append("UnstructuredLoader")
except ImportError:
pass
try:
from .advanced_pdf_loader import AdvancedPdfLoader
__all__.append("AdvancedPdfLoader")
except ImportError:
pass

View file

@ -0,0 +1,244 @@
"""Advanced PDF loader leveraging unstructured for layout-aware extraction."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import asyncio
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.loaders.external.pypdf_loader import PyPdfLoader
logger = get_logger(__name__)
try:
from unstructured.partition.pdf import partition_pdf
except ImportError as e:
logger.info(
"unstructured[pdf] not installed, can't use AdvancedPdfLoader, will use PyPdfLoader instead."
)
raise ImportError from e
@dataclass
class _PageBuffer:
page_num: Optional[int]
segments: List[str]
class AdvancedPdfLoader(LoaderInterface):
"""
PDF loader using unstructured library.
Extracts text content, images, tables from PDF files page by page, providing
structured page information and handling PDF-specific errors.
"""
@property
def supported_extensions(self) -> List[str]:
return ["pdf"]
@property
def supported_mime_types(self) -> List[str]:
return ["application/pdf"]
@property
def loader_name(self) -> str:
return "advanced_pdf_loader"
def can_handle(self, extension: str, mime_type: str) -> bool:
"""Check if file can be handled by this loader."""
# Check both the file extension and the MIME type
if extension in self.supported_extensions and mime_type in self.supported_mime_types:
return True
return False
async def load(self, file_path: str, strategy: str = "auto", **kwargs: Any) -> str:
"""Load PDF file using unstructured library. If Exception occurs, fallback to PyPDFLoader.
Args:
file_path: Path to the document file
strategy: Partitioning strategy ("auto", "fast", "hi_res", "ocr_only")
**kwargs: Additional arguments passed to unstructured partition
Returns:
str: Path to the stored text file containing the extracted content
"""
try:
logger.info(f"Processing PDF: {file_path}")
with open(file_path, "rb") as f:
file_metadata = await get_file_metadata(f)
# Name the file ingested by this loader after the original file's content hash
storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"
# Set partitioning parameters
partition_kwargs: Dict[str, Any] = {
"filename": file_path,
"strategy": strategy,
"infer_table_structure": True,
"include_page_breaks": False,
"include_metadata": True,
**kwargs,
}
# Use partition to extract elements
elements = partition_pdf(**partition_kwargs)
# Process elements into text content
page_contents = self._format_elements_by_page(elements)
# Check if there is any content
if not page_contents:
logger.warning(
"AdvancedPdfLoader returned no content. Falling back to PyPDF loader."
)
return await self._fallback(file_path, **kwargs)
# Combine all page outputs
full_content = "\n".join(page_contents)
# Store the content
storage_config = get_storage_config()
data_root_directory = storage_config["data_root_directory"]
storage = get_file_storage(data_root_directory)
full_file_path = await storage.store(storage_file_name, full_content)
return full_file_path
except Exception as exc:
logger.warning("Failed to process PDF with AdvancedPdfLoader: %s", exc)
return await self._fallback(file_path, **kwargs)
async def _fallback(self, file_path: str, **kwargs: Any) -> str:
logger.info("Falling back to PyPDF loader for %s", file_path)
fallback_loader = PyPdfLoader()
return await fallback_loader.load(file_path, **kwargs)
def _format_elements_by_page(self, elements: List[Any]) -> List[str]:
"""Format elements by page."""
page_buffers: List[_PageBuffer] = []
current_buffer = _PageBuffer(page_num=None, segments=[])
for element in elements:
element_dict = self._safe_to_dict(element)
metadata = element_dict.get("metadata", {})
page_num = metadata.get("page_number")
if current_buffer.page_num != page_num:
if current_buffer.segments:
page_buffers.append(current_buffer)
current_buffer = _PageBuffer(page_num=page_num, segments=[])
formatted = self._format_element(element_dict)
if formatted:
current_buffer.segments.append(formatted)
if current_buffer.segments:
page_buffers.append(current_buffer)
page_contents: List[str] = []
for buffer in page_buffers:
header = f"Page {buffer.page_num}:\n" if buffer.page_num is not None else "Page:"
content = header + "\n\n".join(buffer.segments) + "\n"
page_contents.append(str(content))
return page_contents
def _format_element(
self,
element: Dict[str, Any],
) -> str:
"""Format element."""
element_type = element.get("type")
text = self._clean_text(element.get("text", ""))
metadata = element.get("metadata", {})
if element_type.lower() == "table":
return self._format_table_element(element) or text
if element_type.lower() == "image":
description = text or self._format_image_element(metadata)
return description
# Ignore header and footer elements
if element_type.lower() in ["header", "footer"]:
return ""
return text
def _format_table_element(self, element: Dict[str, Any]) -> str:
"""Format table element."""
metadata = element.get("metadata", {})
text = self._clean_text(element.get("text", ""))
table_html = metadata.get("text_as_html")
if table_html:
return table_html.strip()
return text
def _format_image_element(self, metadata: Dict[str, Any]) -> str:
"""Format image."""
placeholder = "[Image omitted]"
image_text = placeholder
coordinates = metadata.get("coordinates", {})
points = coordinates.get("points") if isinstance(coordinates, dict) else None
if points and isinstance(points, tuple) and len(points) == 4:
leftup = points[0]
rightdown = points[3]
if (
isinstance(leftup, tuple)
and isinstance(rightdown, tuple)
and len(leftup) == 2
and len(rightdown) == 2
):
image_text = f"{placeholder} (bbox=({leftup[0]}, {leftup[1]}, {rightdown[0]}, {rightdown[1]}))"
layout_width = coordinates.get("layout_width")
layout_height = coordinates.get("layout_height")
system = coordinates.get("system")
if layout_width and layout_height and system:
image_text = (
image_text
+ f", system={system}, layout_width={layout_width}, layout_height={layout_height}))"
)
return image_text
def _safe_to_dict(self, element: Any) -> Dict[str, Any]:
"""Safe to dict."""
try:
if hasattr(element, "to_dict"):
return element.to_dict()
except Exception:
pass
fallback_type = getattr(element, "category", None)
if not fallback_type:
fallback_type = getattr(element, "__class__", type("", (), {})).__name__
return {
"type": fallback_type,
"text": getattr(element, "text", ""),
"metadata": getattr(element, "metadata", {}),
}
def _clean_text(self, value: Any) -> str:
if value is None:
return ""
return str(value).replace("\xa0", " ").strip()
if __name__ == "__main__":
loader = AdvancedPdfLoader()
asyncio.run(
loader.load(
"/Users/xiaotao/work/cognee/cognee/infrastructure/loaders/external/attention_is_all_you_need.pdf"
)
)

View file

@ -16,3 +16,10 @@ try:
supported_loaders[UnstructuredLoader.loader_name] = UnstructuredLoader
except ImportError:
pass
try:
from cognee.infrastructure.loaders.external import AdvancedPdfLoader
supported_loaders[AdvancedPdfLoader.loader_name] = AdvancedPdfLoader
except ImportError:
pass

View file

@ -9,7 +9,10 @@ async def get_dataset_data(dataset_id: UUID) -> list[Data]:
async with db_engine.get_async_session() as session:
result = await session.execute(
select(Data).join(Data.datasets).filter((Dataset.id == dataset_id))
select(Data)
.join(Data.datasets)
.filter((Dataset.id == dataset_id))
.order_by(Data.data_size.desc())
)
data = list(result.scalars().all())

Some files were not shown because too many files have changed in this diff.