Merge branch 'dev' into feature/cog-2985-add-ci-tests-that-run-more-examples

Commit 6e3370399b: 154 changed files with 8669 additions and 19135 deletions

@@ -16,7 +16,7 @@
STRUCTURED_OUTPUT_FRAMEWORK="instructor"
LLM_API_KEY="your_api_key"
LLM_MODEL="openai/gpt-4o-mini"
LLM_MODEL="openai/gpt-5-mini"
LLM_PROVIDER="openai"
LLM_ENDPOINT=""
LLM_API_VERSION=""

@@ -30,10 +30,13 @@ EMBEDDING_DIMENSIONS=3072
EMBEDDING_MAX_TOKENS=8191
# If embedding key is not provided same key set for LLM_API_KEY will be used
#EMBEDDING_API_KEY="your_api_key"
# Note: OpenAI support up to 2048 elements and Gemini supports a maximum of 100 elements in an embedding batch,
# Cognee sets the optimal batch size for OpenAI and Gemini, but a custom size can be defined if necessary for other models
#EMBEDDING_BATCH_SIZE=2048

# If using BAML structured output these env variables will be used
BAML_LLM_PROVIDER=openai
BAML_LLM_MODEL="gpt-4o-mini"
BAML_LLM_MODEL="gpt-5-mini"
BAML_LLM_ENDPOINT=""
BAML_LLM_API_KEY="your_api_key"
BAML_LLM_API_VERSION=""

@@ -52,18 +55,18 @@ BAML_LLM_API_VERSION=""
################################################################################
# Configure storage backend (local filesystem or S3)
# STORAGE_BACKEND="local" # Default: uses local filesystem
#
#
# -- To switch to S3 storage, uncomment and fill these: ---------------------
# STORAGE_BACKEND="s3"
# STORAGE_BUCKET_NAME="your-bucket-name"
# AWS_REGION="us-east-1"
# AWS_ACCESS_KEY_ID="your-access-key"
# AWS_SECRET_ACCESS_KEY="your-secret-key"
#
#
# -- S3 Root Directories (optional) -----------------------------------------
# DATA_ROOT_DIRECTORY="s3://your-bucket/cognee/data"
# SYSTEM_ROOT_DIRECTORY="s3://your-bucket/cognee/system"
#
#
# -- Cache Directory (auto-configured for S3) -------------------------------
# When STORAGE_BACKEND=s3, cache automatically uses S3: s3://BUCKET/cognee/cache
# To override the automatic S3 cache location, uncomment:
.github/ISSUE_TEMPLATE/bug_report.yml (2 changes, vendored)

@@ -58,7 +58,7 @@ body:
- Python version: [e.g. 3.9.0]
- Cognee version: [e.g. 0.1.0]
- LLM Provider: [e.g. OpenAI, Ollama]
- Database: [e.g. Neo4j, FalkorDB]
- Database: [e.g. Neo4j]
validations:
required: true

.github/actions/cognee_setup/action.yml (2 changes, vendored)

@@ -41,4 +41,4 @@ runs:
EXTRA_ARGS="$EXTRA_ARGS --extra $extra"
done
fi
uv sync --extra api --extra docs --extra evals --extra gemini --extra codegraph --extra ollama --extra dev --extra neo4j $EXTRA_ARGS
uv sync --extra api --extra docs --extra evals --extra codegraph --extra ollama --extra dev --extra neo4j $EXTRA_ARGS
.github/workflows/db_examples_tests.yml (10 changes, vendored)

@@ -54,6 +54,10 @@ jobs:
with:
python-version: ${{ inputs.python-version }}

- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j

- name: Run Neo4j Example
env:
ENV: dev

@@ -66,9 +70,9 @@
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: "neo4j"
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
GRAPH_DATABASE_USERNAME: "neo4j"
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
run: |
uv run python examples/database_examples/neo4j_example.py
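The workflow now reads the connection details from the setup_neo4j action's outputs instead of organization-level secrets. As a hedged illustration (not part of this diff), a test script could confirm the instance is reachable from the three GRAPH_DATABASE_* variables exported above; this sketch assumes the official `neo4j` Python driver is installed.

```python
import os
from neo4j import GraphDatabase  # assumes the official `neo4j` driver package

# Connection details exported by the workflow step above.
url = os.environ["GRAPH_DATABASE_URL"]        # e.g. bolt://localhost:7687
user = os.environ["GRAPH_DATABASE_USERNAME"]
password = os.environ["GRAPH_DATABASE_PASSWORD"]

# Open a session and run a trivial query to verify the service answers.
with GraphDatabase.driver(url, auth=(user, password)) as driver:
    with driver.session() as session:
        ok = session.run("RETURN 1 AS ok").single()["ok"]
        print("Neo4j reachable:", ok == 1)
```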
.github/workflows/distributed_test.yml (73 changes, vendored, new file)

@@ -0,0 +1,73 @@
name: Distributed Cognee test with modal
permissions:
contents: read
on:
workflow_call:
inputs:
python-version:
required: false
type: string
default: '3.11.x'
secrets:
LLM_MODEL:
required: true
LLM_ENDPOINT:
required: true
LLM_API_KEY:
required: true
LLM_API_VERSION:
required: true
EMBEDDING_MODEL:
required: true
EMBEDDING_ENDPOINT:
required: true
EMBEDDING_API_KEY:
required: true
EMBEDDING_API_VERSION:
required: true
OPENAI_API_KEY:
required: true

jobs:
run-server-start-test:
name: Distributed Cognee test (Modal)
runs-on: ubuntu-22.04
steps:
- name: Check out
uses: actions/checkout@v4
with:
fetch-depth: 0

- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: "distributed postgres"

- name: Run Distributed Cognee (Modal)
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
MODAL_SECRET_NAME: ${{ secrets.MODAL_SECRET_NAME }}
GRAPH_DATABASE_PROVIDER: "neo4j"
GRAPH_DATABASE_URL: ${{ secrets.AZURE_NEO4j_URL }}
GRAPH_DATABASE_USERNAME: ${{ secrets.AZURE_NEO4J_USERNAME }}
GRAPH_DATABASE_PASSWORD: ${{ secrets.AZURE_NEO4J_PW }}
DB_PROVIDER: "postgres"
DB_NAME: ${{ secrets.AZURE_POSTGRES_DB_NAME }}
DB_HOST: ${{ secrets.AZURE_POSTGRES_HOST }}
DB_PORT: ${{ secrets.AZURE_POSTGRES_PORT }}
DB_USERNAME: ${{ secrets.AZURE_POSTGRES_USERNAME }}
DB_PASSWORD: ${{ secrets.AZURE_POSTGRES_PW }}
VECTOR_DB_PROVIDER: "pgvector"
COGNEE_DISTRIBUTED: "true"
run: uv run modal run ./distributed/entrypoint.py
.github/workflows/examples_tests.yml (25 changes, vendored)

@@ -234,3 +234,28 @@ jobs:
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./examples/python/permissions_example.py
test_docling_add:
name: Run Add with Docling Test
runs-on: macos-15
steps:
- name: Check out repository
uses: actions/checkout@v4

- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: 'docling'

- name: Run Docling Test
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_add_docling_document.py
.github/workflows/graph_db_tests.yml (10 changes, vendored)

@@ -71,6 +71,10 @@ jobs:
with:
python-version: ${{ inputs.python-version }}

- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j

- name: Run default Neo4j
env:
ENV: 'dev'

@@ -83,9 +87,9 @@
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: "neo4j"
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
GRAPH_DATABASE_USERNAME: "neo4j"
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
run: uv run python ./cognee/tests/test_neo4j.py

- name: Run Weighted Edges Tests with Neo4j
@@ -186,6 +186,10 @@ jobs:
python-version: '3.11.x'
extra-dependencies: "postgres"

- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j

- name: Install specific db dependency
run: echo "Dependencies already installed in setup"

@@ -206,9 +210,9 @@
env:
ENV: 'dev'
GRAPH_DATABASE_PROVIDER: "neo4j"
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
GRAPH_DATABASE_USERNAME: "neo4j"
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}

LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
.github/workflows/search_db_tests.yml (47 changes, vendored)

@@ -51,20 +51,6 @@ jobs:
name: Search test for Neo4j/LanceDB/Sqlite
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }}
services:
neo4j:
image: neo4j:5.11
env:
NEO4J_AUTH: neo4j/pleaseletmein
NEO4J_PLUGINS: '["apoc","graph-data-science"]'
ports:
- 7474:7474
- 7687:7687
options: >-
--health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
--health-interval=10s
--health-timeout=5s
--health-retries=5

steps:
- name: Check out

@@ -77,6 +63,10 @@
with:
python-version: ${{ inputs.python-version }}

- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j

- name: Dependencies already installed
run: echo "Dependencies already installed in setup"

@@ -94,9 +84,9 @@
GRAPH_DATABASE_PROVIDER: 'neo4j'
VECTOR_DB_PROVIDER: 'lancedb'
DB_PROVIDER: 'sqlite'
GRAPH_DATABASE_URL: bolt://localhost:7687
GRAPH_DATABASE_USERNAME: neo4j
GRAPH_DATABASE_PASSWORD: pleaseletmein
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
run: uv run python ./cognee/tests/test_search_db.py

run-kuzu-pgvector-postgres-search-tests:

@@ -158,19 +148,6 @@
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/pgvector/postgres') }}
services:
neo4j:
image: neo4j:5.11
env:
NEO4J_AUTH: neo4j/pleaseletmein
NEO4J_PLUGINS: '["apoc","graph-data-science"]'
ports:
- 7474:7474
- 7687:7687
options: >-
--health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
--health-interval=10s
--health-timeout=5s
--health-retries=5
postgres:
image: pgvector/pgvector:pg17
env:

@@ -196,6 +173,10 @@
python-version: ${{ inputs.python-version }}
extra-dependencies: "postgres"

- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j

- name: Dependencies already installed
run: echo "Dependencies already installed in setup"

@@ -213,9 +194,9 @@
GRAPH_DATABASE_PROVIDER: 'neo4j'
VECTOR_DB_PROVIDER: 'pgvector'
DB_PROVIDER: 'postgres'
GRAPH_DATABASE_URL: bolt://localhost:7687
GRAPH_DATABASE_USERNAME: neo4j
GRAPH_DATABASE_PASSWORD: pleaseletmein
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
DB_NAME: cognee_db
DB_HOST: 127.0.0.1
DB_PORT: 5432
.github/workflows/temporal_graph_tests.yml (24 changes, vendored)

@@ -51,20 +51,6 @@ jobs:
name: Temporal Graph test Neo4j (lancedb + sqlite)
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }}
services:
neo4j:
image: neo4j:5.11
env:
NEO4J_AUTH: neo4j/pleaseletmein
NEO4J_PLUGINS: '["apoc","graph-data-science"]'
ports:
- 7474:7474
- 7687:7687
options: >-
--health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
--health-interval=10s
--health-timeout=5s
--health-retries=5

steps:
- name: Check out

@@ -77,6 +63,10 @@
with:
python-version: ${{ inputs.python-version }}

- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j

- name: Dependencies already installed
run: echo "Dependencies already installed in setup"

@@ -94,9 +84,9 @@
GRAPH_DATABASE_PROVIDER: 'neo4j'
VECTOR_DB_PROVIDER: 'lancedb'
DB_PROVIDER: 'sqlite'
GRAPH_DATABASE_URL: bolt://localhost:7687
GRAPH_DATABASE_USERNAME: neo4j
GRAPH_DATABASE_PASSWORD: pleaseletmein
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
run: uv run python ./cognee/tests/test_temporal_graph.py

run_temporal_graph_kuzu_postgres_pgvector:
.github/workflows/test_llms.yml (4 changes, vendored)

@@ -27,7 +27,7 @@ jobs:
env:
LLM_PROVIDER: "gemini"
LLM_API_KEY: ${{ secrets.GEMINI_API_KEY }}
LLM_MODEL: "gemini/gemini-1.5-flash"
LLM_MODEL: "gemini/gemini-2.0-flash"
EMBEDDING_PROVIDER: "gemini"
EMBEDDING_API_KEY: ${{ secrets.GEMINI_API_KEY }}
EMBEDDING_MODEL: "gemini/text-embedding-004"

@@ -83,4 +83,4 @@
EMBEDDING_MODEL: "openai/text-embedding-3-large"
EMBEDDING_DIMENSIONS: "3072"
EMBEDDING_MAX_TOKENS: "8191"
run: uv run python ./examples/python/simple_example.py
run: uv run python ./examples/python/simple_example.py
.github/workflows/test_s3_file_storage.yml (6 changes, vendored)

@@ -6,8 +6,12 @@ on:
permissions:
contents: read

env:
RUNTIME__LOG_LEVEL: ERROR
ENV: 'dev'

jobs:
test-gemini:
test-s3-storage:
name: Run S3 File Storage Test
runs-on: ubuntu-22.04
steps:
.github/workflows/test_suites.yml (8 changes, vendored)

@@ -27,6 +27,12 @@ jobs:
uses: ./.github/workflows/e2e_tests.yml
secrets: inherit

distributed-tests:
name: Distributed Cognee Test
needs: [ basic-tests, e2e-tests, graph-db-tests ]
uses: ./.github/workflows/distributed_test.yml
secrets: inherit

cli-tests:
name: CLI Tests
uses: ./.github/workflows/cli_tests.yml

@@ -104,7 +110,7 @@

db-examples-tests:
name: DB Examples Tests
needs: [vector-db-tests, graph-db-tests, relational-db-migration-tests]
needs: [vector-db-tests, graph-db-tests, relational-db-migration-tests, distributed-tests]
uses: ./.github/workflows/db_examples_tests.yml
secrets: inherit
.github/workflows/weighted_edges_tests.yml (7 changes, vendored)

@@ -86,12 +86,19 @@ jobs:
with:
python-version: '3.11'

- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j

- name: Dependencies already installed
run: echo "Dependencies already installed in setup"

- name: Run Weighted Edges Tests
env:
GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }}
GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }}
GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }}
GRAPH_DATABASE_PASSWORD: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-password || '' }}
run: |
uv run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v --tb=short
@@ -22,6 +22,7 @@ RUN apt-get update && apt-get install -y \
libpq-dev \
git \
curl \
cmake \
clang \
build-essential \
&& rm -rf /var/lib/apt/lists/*

@@ -31,7 +32,7 @@ COPY README.md pyproject.toml uv.lock entrypoint.sh ./

# Install the project's dependencies using the lockfile and settings
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra gemini --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-install-project --no-dev --no-editable
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-install-project --no-dev --no-editable

# Copy Alembic configuration
COPY alembic.ini /app/alembic.ini

@@ -42,7 +43,7 @@ COPY alembic/ /app/alembic
COPY ./cognee /app/cognee
COPY ./distributed /app/distributed
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra gemini --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-dev --no-editable
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-dev --no-editable

FROM python:3.12-slim-bookworm
@@ -76,6 +76,9 @@ Get started quickly with a Google Colab <a href="https://colab.research.google.

## About cognee

cognee works locally and stores your data on your device.
Our hosted solution is just our deployment of OSS cognee on Modal, with the goal of making development and productionization easier.

Self-hosted package:

- Interconnects any kind of documents: past conversations, files, images, and audio transcriptions
@@ -217,10 +217,24 @@ export default function GraphVisualization({ ref, data, graphControls, className

const [graphShape, setGraphShape] = useState<string>();

const zoomToFit: ForceGraphMethods["zoomToFit"] = (
durationMs?: number,
padding?: number,
nodeFilter?: (node: NodeObject) => boolean
) => {
if (!graphRef.current) {
console.warn("GraphVisualization: graphRef not ready yet");
return undefined as any;
}

return graphRef.current.zoomToFit?.(durationMs, padding, nodeFilter);
};

useImperativeHandle(ref, () => ({
zoomToFit: graphRef.current!.zoomToFit,
setGraphShape: setGraphShape,
zoomToFit,
setGraphShape,
}));

return (
<div ref={containerRef} className={classNames("w-full h-full", className)} id="graph-container">
@@ -89,15 +89,6 @@ export default function useChat(dataset: Dataset) {
}

interface Node {
name: string;
}

interface Relationship {
relationship_name: string;
}

type InsightMessage = [Node, Relationship, Node];

// eslint-disable-next-line @typescript-eslint/no-explicit-any
function convertToSearchTypeOutput(systemMessage: any[] | any, searchType: string): string {

@@ -106,14 +97,6 @@ function convertToSearchTypeOutput(systemMessage: any[] | any, searchType: strin
}

switch (searchType) {
case "INSIGHTS":
return systemMessage.map((message: InsightMessage) => {
const [node1, relationship, node2] = message;
if (node1.name && node2.name) {
return `${node1.name} ${relationship.relationship_name} ${node2.name}.`;
}
return "";
}).join("\n");
case "SUMMARIES":
return systemMessage.map((message: { text: string }) => message.text).join("\n");
case "CHUNKS":
@@ -65,6 +65,9 @@ ENV PYTHONUNBUFFERED=1
ENV MCP_LOG_LEVEL=DEBUG
ENV PYTHONPATH=/app

# Add labels for API mode usage
LABEL org.opencontainers.image.description="Cognee MCP Server with API mode support"

# Use the application name from pyproject.toml for normal operation
# For testing, we'll override this with a direct command
ENTRYPOINT ["/app/entrypoint.sh"]
@@ -38,7 +38,8 @@ Build memory for Agents and query from any client that speaks MCP – in your t
## ✨ Features

- Multiple transports – choose Streamable HTTP --transport http (recommended for web deployments), SSE --transport sse (real‑time streaming), or stdio (classic pipe, default)
- Integrated logging – all actions written to a rotating file (see get_log_file_location()) and mirrored to console in dev
- **API Mode** – connect to an already running Cognee FastAPI server instead of using cognee directly (see [API Mode](#-api-mode) below)
- Integrated logging – all actions written to a rotating file (see get_log_file_location()) and mirrored to console in dev
- Local file ingestion – feed .md, source files, Cursor rule‑sets, etc. straight from disk
- Background pipelines – long‑running cognify & codify jobs spawn off‑thread; check progress with status tools
- Developer rules bootstrap – one call indexes .cursorrules, .cursor/rules, AGENT.md, and friends into the developer_rules nodeset

@@ -91,7 +92,7 @@ To use different LLM providers / database configurations, and for more info chec

## 🐳 Docker Usage

If you’d rather run cognee-mcp in a container, you have two options:
If you'd rather run cognee-mcp in a container, you have two options:

1. **Build locally**
1. Make sure you are in /cognee root directory and have a fresh `.env` containing only your `LLM_API_KEY` (and your chosen settings).
@@ -128,6 +129,64 @@ If you’d rather run cognee-mcp in a container, you have two options:
- ✅ Direct: `python src/server.py --transport http`
- ❌ Direct: `-e TRANSPORT_MODE=http` (won't work)

### **Docker API Mode**

To connect the MCP Docker container to a Cognee API server running on your host machine:

#### **Simple Usage (Automatic localhost handling):**
```bash
# Start your Cognee API server on the host
python -m cognee.api.client

# Run MCP container in API mode - localhost is automatically converted!
docker run \
-e TRANSPORT_MODE=sse \
-e API_URL=http://localhost:8000 \
-e API_TOKEN=your_auth_token \
-p 8001:8000 \
--rm -it cognee/cognee-mcp:main
```
**Note:** The container will automatically convert `localhost` to `host.docker.internal` on Mac/Windows/Docker Desktop. You'll see a message in the logs showing the conversion.

#### **Explicit host.docker.internal (Mac/Windows):**
```bash
# Or explicitly use host.docker.internal
docker run \
-e TRANSPORT_MODE=sse \
-e API_URL=http://host.docker.internal:8000 \
-e API_TOKEN=your_auth_token \
-p 8001:8000 \
--rm -it cognee/cognee-mcp:main
```

#### **On Linux (use host network or container IP):**
```bash
# Option 1: Use host network (simplest)
docker run \
--network host \
-e TRANSPORT_MODE=sse \
-e API_URL=http://localhost:8000 \
-e API_TOKEN=your_auth_token \
--rm -it cognee/cognee-mcp:main

# Option 2: Use host IP address
# First, get your host IP: ip addr show docker0
docker run \
-e TRANSPORT_MODE=sse \
-e API_URL=http://172.17.0.1:8000 \
-e API_TOKEN=your_auth_token \
-p 8001:8000 \
--rm -it cognee/cognee-mcp:main
```

**Environment variables for API mode:**
- `API_URL`: URL of the running Cognee API server
- `API_TOKEN`: Authentication token (optional, required if API has authentication enabled)

**Note:** When running in API mode:
- Database migrations are automatically skipped (API server handles its own DB)
- Some features are limited (see [API Mode Limitations](#-api-mode))

## 🔗 MCP Client Configuration
@@ -255,6 +314,76 @@ You can configure both transports simultaneously for testing:

**Note:** Only enable the server you're actually running to avoid connection errors.

## 🌐 API Mode

The MCP server can operate in two modes:

### **Direct Mode** (Default)
The MCP server directly imports and uses the cognee library. This is the default mode with full feature support.

### **API Mode**
The MCP server connects to an already running Cognee FastAPI server via HTTP requests. This is useful when:
- You have a centralized Cognee API server running
- You want to separate the MCP server from the knowledge graph backend
- You need multiple MCP servers to share the same knowledge graph

**Starting the MCP server in API mode:**
```bash
# Start your Cognee FastAPI server first (default port 8000)
cd /path/to/cognee
python -m cognee.api.client

# Then start the MCP server in API mode
cd cognee-mcp
python src/server.py --api-url http://localhost:8000 --api-token YOUR_AUTH_TOKEN
```

**API Mode with different transports:**
```bash
# With SSE transport
python src/server.py --transport sse --api-url http://localhost:8000 --api-token YOUR_TOKEN

# With HTTP transport
python src/server.py --transport http --api-url http://localhost:8000 --api-token YOUR_TOKEN
```

**API Mode with Docker:**
```bash
# On Mac/Windows (use host.docker.internal to access host)
docker run \
-e TRANSPORT_MODE=sse \
-e API_URL=http://host.docker.internal:8000 \
-e API_TOKEN=YOUR_TOKEN \
-p 8001:8000 \
--rm -it cognee/cognee-mcp:main

# On Linux (use host network)
docker run \
--network host \
-e TRANSPORT_MODE=sse \
-e API_URL=http://localhost:8000 \
-e API_TOKEN=YOUR_TOKEN \
--rm -it cognee/cognee-mcp:main
```

**Command-line arguments for API mode:**
- `--api-url`: Base URL of the running Cognee FastAPI server (e.g., `http://localhost:8000`)
- `--api-token`: Authentication token for the API (optional, required if API has authentication enabled)

**Docker environment variables for API mode:**
- `API_URL`: Base URL of the running Cognee FastAPI server
- `API_TOKEN`: Authentication token (optional, required if API has authentication enabled)

**API Mode limitations:**
Some features are only available in direct mode:
- `codify` (code graph pipeline)
- `cognify_status` / `codify_status` (pipeline status tracking)
- `prune` (data reset)
- `get_developer_rules` (developer rules retrieval)
- `list_data` with specific dataset_id (detailed data listing)

Basic operations like `cognify`, `search`, `delete`, and `list_data` (all datasets) work in both modes.
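For orientation, the sketch below shows roughly the HTTP calls that API mode issues under the hood, mirroring the `CogneeClient` added in this PR (`cognee-mcp/src/cognee_client.py`); the endpoint paths and payload fields are taken from that client and may differ in other Cognee versions.

```python
import asyncio
import httpx

API_URL = "http://localhost:8000"                       # running Cognee FastAPI server
HEADERS = {"Authorization": "Bearer YOUR_AUTH_TOKEN"}    # only needed if auth is enabled

async def main():
    async with httpx.AsyncClient(timeout=300.0) as client:
        # add: multipart upload, as the MCP client does in API mode
        await client.post(
            f"{API_URL}/api/v1/add",
            files={"data": ("data.txt", "Cognee turns documents into graphs.", "text/plain")},
            data={"datasetName": "main_dataset"},
            headers=HEADERS,
        )
        # cognify: build the knowledge graph for the dataset
        await client.post(
            f"{API_URL}/api/v1/cognify",
            json={"datasets": ["main_dataset"], "run_in_background": False},
            headers=HEADERS,
        )
        # search: query the graph
        response = await client.post(
            f"{API_URL}/api/v1/search",
            json={"query": "What does Cognee do?", "search_type": "GRAPH_COMPLETION", "top_k": 10},
            headers=HEADERS,
        )
        print(response.json())

asyncio.run(main())
```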

## 💻 Basic Usage

The MCP server exposes its functionality through tools. Call them from any MCP client (Cursor, Claude Desktop, Cline, Roo and more).
@@ -266,7 +395,7 @@ The MCP server exposes its functionality through tools. Call them from any MCP c

- **codify**: Analyse a code repository, build a code graph, and store it in memory

- **search**: Query memory – supports GRAPH_COMPLETION, RAG_COMPLETION, CODE, CHUNKS, INSIGHTS
- **search**: Query memory – supports GRAPH_COMPLETION, RAG_COMPLETION, CODE, CHUNKS

- **list_data**: List all datasets and their data items with IDs for deletion operations
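As a rough illustration of calling these tools programmatically, the sketch below uses the MCP Python SDK's stdio client; the launch command, directory, and the exact tool argument names are assumptions based on the configuration and server code excerpts shown elsewhere in this diff.

```python
import asyncio
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

# Launch the cognee-mcp server over stdio (paths and args are assumptions).
server = StdioServerParameters(
    command="uv",
    args=["--directory", "/path/to/cognee-mcp", "run", "cognee-mcp"],
)

async def main():
    async with stdio_client(server) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            # Tool and parameter names assumed from the server code in this PR.
            await session.call_tool("cognify", {"data": "Cognee builds memory for agents."})
            result = await session.call_tool(
                "search",
                {"search_query": "What does cognee build?", "search_type": "GRAPH_COMPLETION"},
            )
            print(result.content)

asyncio.run(main())
```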
@ -14,61 +14,94 @@ HTTP_PORT=${HTTP_PORT:-8000}
|
|||
echo "Debug port: $DEBUG_PORT"
|
||||
echo "HTTP port: $HTTP_PORT"
|
||||
|
||||
# Run Alembic migrations with proper error handling.
|
||||
# Note on UserAlreadyExists error handling:
|
||||
# During database migrations, we attempt to create a default user. If this user
|
||||
# already exists (e.g., from a previous deployment or migration), it's not a
|
||||
# critical error and shouldn't prevent the application from starting. This is
|
||||
# different from other migration errors which could indicate database schema
|
||||
# inconsistencies and should cause the startup to fail. This check allows for
|
||||
# smooth redeployments and container restarts while maintaining data integrity.
|
||||
echo "Running database migrations..."
|
||||
# Check if API mode is enabled
|
||||
if [ -n "$API_URL" ]; then
|
||||
echo "API mode enabled: $API_URL"
|
||||
echo "Skipping database migrations (API server handles its own database)"
|
||||
else
|
||||
echo "Direct mode: Using local cognee instance"
|
||||
# Run Alembic migrations with proper error handling.
|
||||
# Note on UserAlreadyExists error handling:
|
||||
# During database migrations, we attempt to create a default user. If this user
|
||||
# already exists (e.g., from a previous deployment or migration), it's not a
|
||||
# critical error and shouldn't prevent the application from starting. This is
|
||||
# different from other migration errors which could indicate database schema
|
||||
# inconsistencies and should cause the startup to fail. This check allows for
|
||||
# smooth redeployments and container restarts while maintaining data integrity.
|
||||
echo "Running database migrations..."
|
||||
|
||||
MIGRATION_OUTPUT=$(alembic upgrade head)
|
||||
MIGRATION_EXIT_CODE=$?
|
||||
MIGRATION_OUTPUT=$(alembic upgrade head)
|
||||
MIGRATION_EXIT_CODE=$?
|
||||
|
||||
if [[ $MIGRATION_EXIT_CODE -ne 0 ]]; then
|
||||
if [[ "$MIGRATION_OUTPUT" == *"UserAlreadyExists"* ]] || [[ "$MIGRATION_OUTPUT" == *"User default_user@example.com already exists"* ]]; then
|
||||
echo "Warning: Default user already exists, continuing startup..."
|
||||
else
|
||||
echo "Migration failed with unexpected error."
|
||||
exit 1
|
||||
if [[ $MIGRATION_EXIT_CODE -ne 0 ]]; then
|
||||
if [[ "$MIGRATION_OUTPUT" == *"UserAlreadyExists"* ]] || [[ "$MIGRATION_OUTPUT" == *"User default_user@example.com already exists"* ]]; then
|
||||
echo "Warning: Default user already exists, continuing startup..."
|
||||
else
|
||||
echo "Migration failed with unexpected error."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Database migrations done."
|
||||
echo "Database migrations done."
|
||||
fi
|
||||
|
||||
echo "Starting Cognee MCP Server with transport mode: $TRANSPORT_MODE"
|
||||
|
||||
# Add startup delay to ensure DB is ready
|
||||
sleep 2
|
||||
|
||||
# Build API arguments if API_URL is set
|
||||
API_ARGS=""
|
||||
if [ -n "$API_URL" ]; then
|
||||
# Handle localhost in API_URL - convert to host-accessible address
|
||||
if echo "$API_URL" | grep -q "localhost" || echo "$API_URL" | grep -q "127.0.0.1"; then
|
||||
echo "⚠️ Warning: API_URL contains localhost/127.0.0.1"
|
||||
echo " Original: $API_URL"
|
||||
|
||||
# Try to use host.docker.internal (works on Mac/Windows and recent Linux with Docker Desktop)
|
||||
FIXED_API_URL=$(echo "$API_URL" | sed 's/localhost/host.docker.internal/g' | sed 's/127\.0\.0\.1/host.docker.internal/g')
|
||||
|
||||
echo " Converted to: $FIXED_API_URL"
|
||||
echo " This will work on Mac/Windows/Docker Desktop."
|
||||
echo " On Linux without Docker Desktop, you may need to:"
|
||||
echo " - Use --network host, OR"
|
||||
echo " - Set API_URL=http://172.17.0.1:8000 (Docker bridge IP)"
|
||||
|
||||
API_URL="$FIXED_API_URL"
|
||||
fi
|
||||
|
||||
API_ARGS="--api-url $API_URL"
|
||||
if [ -n "$API_TOKEN" ]; then
|
||||
API_ARGS="$API_ARGS --api-token $API_TOKEN"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Modified startup with transport mode selection and error handling
|
||||
if [ "$ENVIRONMENT" = "dev" ] || [ "$ENVIRONMENT" = "local" ]; then
|
||||
if [ "$DEBUG" = "true" ]; then
|
||||
echo "Waiting for the debugger to attach..."
|
||||
if [ "$TRANSPORT_MODE" = "sse" ]; then
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration $API_ARGS
|
||||
elif [ "$TRANSPORT_MODE" = "http" ]; then
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration $API_ARGS
|
||||
else
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport stdio --no-migration
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport stdio --no-migration $API_ARGS
|
||||
fi
|
||||
else
|
||||
if [ "$TRANSPORT_MODE" = "sse" ]; then
|
||||
exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration $API_ARGS
|
||||
elif [ "$TRANSPORT_MODE" = "http" ]; then
|
||||
exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration $API_ARGS
|
||||
else
|
||||
exec cognee-mcp --transport stdio --no-migration
|
||||
exec cognee-mcp --transport stdio --no-migration $API_ARGS
|
||||
fi
|
||||
fi
|
||||
else
|
||||
if [ "$TRANSPORT_MODE" = "sse" ]; then
|
||||
exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration $API_ARGS
|
||||
elif [ "$TRANSPORT_MODE" = "http" ]; then
|
||||
exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration $API_ARGS
|
||||
else
|
||||
exec cognee-mcp --transport stdio --no-migration
|
||||
exec cognee-mcp --transport stdio --no-migration $API_ARGS
|
||||
fi
|
||||
fi
|
||||
|
|
|
|||
|
|
@@ -8,10 +8,12 @@ requires-python = ">=3.10"
dependencies = [
# For local cognee repo usage remove the comment below and add the absolute path to cognee. Then run `uv sync --reinstall` in the mcp folder on local cognee changes.
#"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j] @ file:/Users/igorilic/Desktop/cognee",
# TODO: Remove gemini from optional dependencies for new Cognee version after 0.3.4
"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j]==0.3.4",
"fastmcp>=2.10.0,<3.0.0",
"mcp>=1.12.0,<2.0.0",
"uv>=0.6.3,<1.0.0",
"httpx>=0.27.0,<1.0.0",
]

authors = [

@@ -36,4 +38,5 @@ dev = [
allow-direct-references = true

[project.scripts]
cognee-mcp = "src:main"
cognee = "src:main"
cognee-mcp = "src:main_mcp"
@@ -1,8 +1,60 @@
from .server import main as server_main
try:
from .server import main as server_main
except ImportError:
from server import main as server_main
import warnings
import sys


def main():
"""Main entry point for the package."""
"""Deprecated main entry point for the package."""
import asyncio

deprecation_notice = """
DEPRECATION NOTICE
The CLI entry-point used to start the Cognee MCP service has been renamed from
"cognee" to "cognee-mcp". Calling the old entry-point will stop working in a
future release.

WHAT YOU NEED TO DO:
Locate every place where you launch the MCP process and replace the final
argument cognee → cognee-mcp.

For the example mcpServers block from Cursor shown below the change is:
{
"mcpServers": {
"Cognee": {
"command": "uv",
"args": [
"--directory",
"/path/to/cognee-mcp",
"run",
"cognee" // <-- CHANGE THIS to "cognee-mcp"
]
}
}
}

Continuing to use the old "cognee" entry-point will result in failures once it
is removed, so please update your configuration and any shell scripts as soon
as possible.
"""

warnings.warn(
"The 'cognee' command for cognee-mcp is deprecated and will be removed in a future version. "
"Please use 'cognee-mcp' instead to avoid conflicts with the main cognee library.",
DeprecationWarning,
stacklevel=2,
)

print("⚠️ DEPRECATION WARNING", file=sys.stderr)
print(deprecation_notice, file=sys.stderr)

asyncio.run(server_main())


def main_mcp():
"""Clean main entry point for cognee-mcp command."""
import asyncio

asyncio.run(server_main())
@@ -117,5 +117,4 @@ async def add_rule_associations(data: str, rules_nodeset_name: str):

if len(edges_to_save) > 0:
await graph_engine.add_edges(edges_to_save)

await index_graph_edges()
await index_graph_edges(edges_to_save)
cognee-mcp/src/cognee_client.py (338 changes, new file)

@@ -0,0 +1,338 @@
|
|||
"""
|
||||
Cognee Client abstraction that supports both direct function calls and HTTP API calls.
|
||||
|
||||
This module provides a unified interface for interacting with Cognee, supporting:
|
||||
- Direct mode: Directly imports and calls cognee functions (default behavior)
|
||||
- API mode: Makes HTTP requests to a running Cognee FastAPI server
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Optional, Any, List, Dict
|
||||
from uuid import UUID
|
||||
from contextlib import redirect_stdout
|
||||
import httpx
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
import json
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class CogneeClient:
|
||||
"""
|
||||
Unified client for interacting with Cognee via direct calls or HTTP API.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
api_url : str, optional
|
||||
Base URL of the Cognee API server (e.g., "http://localhost:8000").
|
||||
If None, uses direct cognee function calls.
|
||||
api_token : str, optional
|
||||
Authentication token for the API (optional, required if API has authentication enabled).
|
||||
"""
|
||||
|
||||
def __init__(self, api_url: Optional[str] = None, api_token: Optional[str] = None):
|
||||
self.api_url = api_url.rstrip("/") if api_url else None
|
||||
self.api_token = api_token
|
||||
self.use_api = bool(api_url)
|
||||
|
||||
if self.use_api:
|
||||
logger.info(f"Cognee client initialized in API mode: {self.api_url}")
|
||||
self.client = httpx.AsyncClient(timeout=300.0) # 5 minute timeout for long operations
|
||||
else:
|
||||
logger.info("Cognee client initialized in direct mode")
|
||||
# Import cognee only if we're using direct mode
|
||||
import cognee as _cognee
|
||||
|
||||
self.cognee = _cognee
|
||||
|
||||
def _get_headers(self) -> Dict[str, str]:
|
||||
"""Get headers for API requests."""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.api_token:
|
||||
headers["Authorization"] = f"Bearer {self.api_token}"
|
||||
return headers
|
||||
|
||||
async def add(
|
||||
self, data: Any, dataset_name: str = "main_dataset", node_set: Optional[List[str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Add data to Cognee for processing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : Any
|
||||
Data to add (text, file path, etc.)
|
||||
dataset_name : str
|
||||
Name of the dataset to add data to
|
||||
node_set : List[str], optional
|
||||
List of node identifiers for graph organization
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, Any]
|
||||
Result of the add operation
|
||||
"""
|
||||
if self.use_api:
|
||||
endpoint = f"{self.api_url}/api/v1/add"
|
||||
|
||||
files = {"data": ("data.txt", str(data), "text/plain")}
|
||||
form_data = {
|
||||
"datasetName": dataset_name,
|
||||
}
|
||||
if node_set is not None:
|
||||
form_data["node_set"] = json.dumps(node_set)
|
||||
|
||||
response = await self.client.post(
|
||||
endpoint,
|
||||
files=files,
|
||||
data=form_data,
|
||||
headers={"Authorization": f"Bearer {self.api_token}"} if self.api_token else {},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
else:
|
||||
with redirect_stdout(sys.stderr):
|
||||
await self.cognee.add(data, dataset_name=dataset_name, node_set=node_set)
|
||||
return {"status": "success", "message": "Data added successfully"}
|
||||
|
||||
async def cognify(
|
||||
self,
|
||||
datasets: Optional[List[str]] = None,
|
||||
custom_prompt: Optional[str] = None,
|
||||
graph_model: Any = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Transform data into a knowledge graph.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
datasets : List[str], optional
|
||||
List of dataset names to process
|
||||
custom_prompt : str, optional
|
||||
Custom prompt for entity extraction
|
||||
graph_model : Any, optional
|
||||
Custom graph model (only used in direct mode)
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, Any]
|
||||
Result of the cognify operation
|
||||
"""
|
||||
if self.use_api:
|
||||
# API mode: Make HTTP request
|
||||
endpoint = f"{self.api_url}/api/v1/cognify"
|
||||
payload = {
|
||||
"datasets": datasets or ["main_dataset"],
|
||||
"run_in_background": False,
|
||||
}
|
||||
if custom_prompt:
|
||||
payload["custom_prompt"] = custom_prompt
|
||||
|
||||
response = await self.client.post(endpoint, json=payload, headers=self._get_headers())
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
else:
|
||||
# Direct mode: Call cognee directly
|
||||
with redirect_stdout(sys.stderr):
|
||||
kwargs = {}
|
||||
if datasets:
|
||||
kwargs["datasets"] = datasets
|
||||
if custom_prompt:
|
||||
kwargs["custom_prompt"] = custom_prompt
|
||||
if graph_model:
|
||||
kwargs["graph_model"] = graph_model
|
||||
|
||||
await self.cognee.cognify(**kwargs)
|
||||
return {"status": "success", "message": "Cognify completed successfully"}
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query_text: str,
|
||||
query_type: str,
|
||||
datasets: Optional[List[str]] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
top_k: int = 10,
|
||||
) -> Any:
|
||||
"""
|
||||
Search the knowledge graph.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query_text : str
|
||||
The search query
|
||||
query_type : str
|
||||
Type of search (e.g., "GRAPH_COMPLETION", "INSIGHTS", etc.)
|
||||
datasets : List[str], optional
|
||||
List of datasets to search
|
||||
system_prompt : str, optional
|
||||
System prompt for completion searches
|
||||
top_k : int
|
||||
Maximum number of results
|
||||
|
||||
Returns
|
||||
-------
|
||||
Any
|
||||
Search results
|
||||
"""
|
||||
if self.use_api:
|
||||
# API mode: Make HTTP request
|
||||
endpoint = f"{self.api_url}/api/v1/search"
|
||||
payload = {"query": query_text, "search_type": query_type.upper(), "top_k": top_k}
|
||||
if datasets:
|
||||
payload["datasets"] = datasets
|
||||
if system_prompt:
|
||||
payload["system_prompt"] = system_prompt
|
||||
|
||||
response = await self.client.post(endpoint, json=payload, headers=self._get_headers())
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
else:
|
||||
# Direct mode: Call cognee directly
|
||||
from cognee.modules.search.types import SearchType
|
||||
|
||||
with redirect_stdout(sys.stderr):
|
||||
results = await self.cognee.search(
|
||||
query_type=SearchType[query_type.upper()], query_text=query_text
|
||||
)
|
||||
return results
|
||||
|
||||
async def delete(self, data_id: UUID, dataset_id: UUID, mode: str = "soft") -> Dict[str, Any]:
|
||||
"""
|
||||
Delete data from a dataset.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data_id : UUID
|
||||
ID of the data to delete
|
||||
dataset_id : UUID
|
||||
ID of the dataset containing the data
|
||||
mode : str
|
||||
Deletion mode ("soft" or "hard")
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, Any]
|
||||
Result of the deletion
|
||||
"""
|
||||
if self.use_api:
|
||||
# API mode: Make HTTP request
|
||||
endpoint = f"{self.api_url}/api/v1/delete"
|
||||
params = {"data_id": str(data_id), "dataset_id": str(dataset_id), "mode": mode}
|
||||
|
||||
response = await self.client.delete(
|
||||
endpoint, params=params, headers=self._get_headers()
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
else:
|
||||
# Direct mode: Call cognee directly
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
||||
with redirect_stdout(sys.stderr):
|
||||
user = await get_default_user()
|
||||
result = await self.cognee.delete(
|
||||
data_id=data_id, dataset_id=dataset_id, mode=mode, user=user
|
||||
)
|
||||
return result
|
||||
|
||||
async def prune_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Prune all data from the knowledge graph.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, Any]
|
||||
Result of the prune operation
|
||||
"""
|
||||
if self.use_api:
|
||||
# Note: The API doesn't expose a prune endpoint, so we'll need to handle this
|
||||
# For now, raise an error
|
||||
raise NotImplementedError("Prune operation is not available via API")
|
||||
else:
|
||||
# Direct mode: Call cognee directly
|
||||
with redirect_stdout(sys.stderr):
|
||||
await self.cognee.prune.prune_data()
|
||||
return {"status": "success", "message": "Data pruned successfully"}
|
||||
|
||||
async def prune_system(self, metadata: bool = True) -> Dict[str, Any]:
|
||||
"""
|
||||
Prune system data from the knowledge graph.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metadata : bool
|
||||
Whether to prune metadata
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, Any]
|
||||
Result of the prune operation
|
||||
"""
|
||||
if self.use_api:
|
||||
# Note: The API doesn't expose a prune endpoint
|
||||
raise NotImplementedError("Prune system operation is not available via API")
|
||||
else:
|
||||
# Direct mode: Call cognee directly
|
||||
with redirect_stdout(sys.stderr):
|
||||
await self.cognee.prune.prune_system(metadata=metadata)
|
||||
return {"status": "success", "message": "System pruned successfully"}
|
||||
|
||||
async def get_pipeline_status(self, dataset_ids: List[UUID], pipeline_name: str) -> str:
|
||||
"""
|
||||
Get the status of a pipeline run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset_ids : List[UUID]
|
||||
List of dataset IDs
|
||||
pipeline_name : str
|
||||
Name of the pipeline
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Status information
|
||||
"""
|
||||
if self.use_api:
|
||||
# Note: This would need a custom endpoint on the API side
|
||||
raise NotImplementedError("Pipeline status is not available via API")
|
||||
else:
|
||||
# Direct mode: Call cognee directly
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||
|
||||
with redirect_stdout(sys.stderr):
|
||||
status = await get_pipeline_status(dataset_ids, pipeline_name)
|
||||
return str(status)
|
||||
|
||||
async def list_datasets(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
List all datasets.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[Dict[str, Any]]
|
||||
List of datasets
|
||||
"""
|
||||
if self.use_api:
|
||||
# API mode: Make HTTP request
|
||||
endpoint = f"{self.api_url}/api/v1/datasets"
|
||||
response = await self.client.get(endpoint, headers=self._get_headers())
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
else:
|
||||
# Direct mode: Call cognee directly
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.data.methods import get_datasets
|
||||
|
||||
with redirect_stdout(sys.stderr):
|
||||
user = await get_default_user()
|
||||
datasets = await get_datasets(user.id)
|
||||
return [
|
||||
{"id": str(d.id), "name": d.name, "created_at": str(d.created_at)}
|
||||
for d in datasets
|
||||
]
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP client if in API mode."""
|
||||
if self.use_api and hasattr(self, "client"):
|
||||
await self.client.aclose()
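
A minimal usage sketch of this client follows (the import path and token value are assumptions; construct the client with no `api_url` to fall back to direct mode):

```python
import asyncio
from src.cognee_client import CogneeClient  # import path assumed; adjust to your layout

async def main():
    # API mode: point the client at a running Cognee FastAPI server.
    client = CogneeClient(api_url="http://localhost:8000", api_token="YOUR_AUTH_TOKEN")
    try:
        await client.add("Cognee turns documents into a knowledge graph.", dataset_name="main_dataset")
        await client.cognify(datasets=["main_dataset"])
        results = await client.search("What does Cognee do?", query_type="GRAPH_COMPLETION")
        print(results)
    finally:
        # Closes the underlying httpx client when running in API mode.
        await client.close()

asyncio.run(main())
```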
|
||||
|
|
@ -2,28 +2,27 @@ import json
|
|||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import cognee
|
||||
import asyncio
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from cognee.shared.logging_utils import get_logger, setup_logging, get_log_file_location
|
||||
import importlib.util
|
||||
from contextlib import redirect_stdout
|
||||
import mcp.types as types
|
||||
from mcp.server import FastMCP
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
import uvicorn
|
||||
|
||||
try:
|
||||
from .cognee_client import CogneeClient
|
||||
except ImportError:
|
||||
from cognee_client import CogneeClient
|
||||
|
||||
|
||||
try:
|
||||
from cognee.tasks.codingagents.coding_rule_associations import (
|
||||
|
|
@ -41,6 +40,8 @@ mcp = FastMCP("Cognee")
|
|||
|
||||
logger = get_logger()
|
||||
|
||||
cognee_client: Optional[CogneeClient] = None
|
||||
|
||||
|
||||
async def run_sse_with_cors():
|
||||
"""Custom SSE transport with CORS middleware."""
|
||||
|
|
@ -141,11 +142,20 @@ async def cognee_add_developer_rules(
|
|||
with redirect_stdout(sys.stderr):
|
||||
logger.info(f"Starting cognify for: {file_path}")
|
||||
try:
|
||||
await cognee.add(file_path, node_set=["developer_rules"])
|
||||
model = KnowledgeGraph
|
||||
await cognee_client.add(file_path, node_set=["developer_rules"])
|
||||
|
||||
model = None
|
||||
if graph_model_file and graph_model_name:
|
||||
model = load_class(graph_model_file, graph_model_name)
|
||||
await cognee.cognify(graph_model=model)
|
||||
if cognee_client.use_api:
|
||||
logger.warning(
|
||||
"Custom graph models are not supported in API mode, ignoring."
|
||||
)
|
||||
else:
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
|
||||
model = load_class(graph_model_file, graph_model_name)
|
||||
|
||||
await cognee_client.cognify(graph_model=model)
|
||||
logger.info(f"Cognify finished for: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Cognify failed for {file_path}: {str(e)}")
|
||||
|
|
@ -255,7 +265,7 @@ async def cognify(
|
|||
# 2. Get entity relationships and connections
|
||||
relationships = await cognee.search(
|
||||
"connections between concepts",
|
||||
query_type=SearchType.INSIGHTS
|
||||
query_type=SearchType.GRAPH_COMPLETION
|
||||
)
|
||||
|
||||
# 3. Find relevant document chunks
|
||||
|
|
@ -293,15 +303,20 @@ async def cognify(
|
|||
# going to stdout ( like the print function ) to stderr.
|
||||
with redirect_stdout(sys.stderr):
|
||||
logger.info("Cognify process starting.")
|
||||
if graph_model_file and graph_model_name:
|
||||
graph_model = load_class(graph_model_file, graph_model_name)
|
||||
else:
|
||||
graph_model = KnowledgeGraph
|
||||
|
||||
await cognee.add(data)
|
||||
graph_model = None
|
||||
if graph_model_file and graph_model_name:
|
||||
if cognee_client.use_api:
|
||||
logger.warning("Custom graph models are not supported in API mode, ignoring.")
|
||||
else:
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
|
||||
graph_model = load_class(graph_model_file, graph_model_name)
|
||||
|
||||
await cognee_client.add(data)
|
||||
|
||||
try:
|
||||
await cognee.cognify(graph_model=graph_model, custom_prompt=custom_prompt)
|
||||
await cognee_client.cognify(custom_prompt=custom_prompt, graph_model=graph_model)
|
||||
logger.info("Cognify process finished.")
|
||||
except Exception as e:
|
||||
logger.error("Cognify process failed.")
|
||||
|
|
@ -354,16 +369,19 @@ async def save_interaction(data: str) -> list:
|
|||
with redirect_stdout(sys.stderr):
|
||||
logger.info("Save interaction process starting.")
|
||||
|
||||
await cognee.add(data, node_set=["user_agent_interaction"])
|
||||
await cognee_client.add(data, node_set=["user_agent_interaction"])
|
||||
|
||||
try:
|
||||
await cognee.cognify()
|
||||
await cognee_client.cognify()
|
||||
logger.info("Save interaction process finished.")
|
||||
logger.info("Generating associated rules from interaction data.")
|
||||
|
||||
await add_rule_associations(data=data, rules_nodeset_name="coding_agent_rules")
|
||||
|
||||
logger.info("Associated rules generated from interaction data.")
|
||||
# Rule associations only work in direct mode
|
||||
if not cognee_client.use_api:
|
||||
logger.info("Generating associated rules from interaction data.")
|
||||
await add_rule_associations(data=data, rules_nodeset_name="coding_agent_rules")
|
||||
logger.info("Associated rules generated from interaction data.")
|
||||
else:
|
||||
logger.warning("Rule associations are not available in API mode, skipping.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Save interaction process failed.")
|
||||
|
|
@ -420,11 +438,18 @@ async def codify(repo_path: str) -> list:
|
|||
- All stdout is redirected to stderr to maintain MCP communication integrity
|
||||
"""
|
||||
|
||||
if cognee_client.use_api:
|
||||
error_msg = "❌ Codify operation is not available in API mode. Please use direct mode for code graph pipeline."
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
|
||||
async def codify_task(repo_path: str):
|
||||
# NOTE: MCP uses stdout to communicate, we must redirect all output
|
||||
# going to stdout ( like the print function ) to stderr.
|
||||
with redirect_stdout(sys.stderr):
|
||||
logger.info("Codify process starting.")
|
||||
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
|
||||
|
||||
results = []
|
||||
async for result in run_code_graph_pipeline(repo_path, False):
|
||||
results.append(result)
|
||||
|
|
@ -478,11 +503,6 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
Best for: Direct document retrieval, specific fact-finding.
|
||||
Returns: LLM responses based on relevant text chunks.
|
||||
|
||||
**INSIGHTS**:
|
||||
Structured entity relationships and semantic connections.
|
||||
Best for: Understanding concept relationships, knowledge mapping.
|
||||
Returns: Formatted relationship data and entity connections.
|
||||
|
||||
**CHUNKS**:
|
||||
Raw text segments that match the query semantically.
|
||||
Best for: Finding specific passages, citations, exact content.
|
||||
|
|
@ -524,7 +544,6 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
- "RAG_COMPLETION": Returns an LLM response based on the search query and standard RAG data
|
||||
- "CODE": Returns code-related knowledge in JSON format
|
||||
- "CHUNKS": Returns raw text chunks from the knowledge graph
|
||||
- "INSIGHTS": Returns relationships between nodes in readable format
|
||||
- "SUMMARIES": Returns pre-generated hierarchical summaries
|
||||
- "CYPHER": Direct graph database queries
|
||||
- "FEELING_LUCKY": Automatically selects best search type
|
||||
|
|
@ -537,7 +556,6 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
A list containing a single TextContent object with the search results.
|
||||
The format of the result depends on the search_type:
|
||||
- **GRAPH_COMPLETION/RAG_COMPLETION**: Conversational AI response strings
|
||||
- **INSIGHTS**: Formatted relationship descriptions and entity connections
|
||||
- **CHUNKS**: Relevant text passages with source metadata
|
||||
- **SUMMARIES**: Hierarchical summaries from general to specific
|
||||
- **CODE**: Structured code information with context
|
||||
|
|
@ -547,7 +565,6 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
Performance & Optimization:
|
||||
- **GRAPH_COMPLETION**: Slower but most intelligent, uses LLM + graph context
|
||||
- **RAG_COMPLETION**: Medium speed, uses LLM + document chunks (no graph traversal)
|
||||
- **INSIGHTS**: Fast, returns structured relationships without LLM processing
|
||||
- **CHUNKS**: Fastest, pure vector similarity search without LLM
|
||||
- **SUMMARIES**: Fast, returns pre-computed summaries
|
||||
- **CODE**: Medium speed, specialized for code understanding
|
||||
|
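A minimal sketch of how these trade-offs translate into direct cognee calls (assumes data has already been added and cognified; the SearchType import path is an assumption and the query strings are illustrative):

import cognee
from cognee import SearchType  # import path assumed; SearchType may also live under cognee.modules.search.types


async def answer(query: str, need_reasoning: bool):
    if need_reasoning:
        # Slowest but most intelligent: LLM answer grounded in graph context
        return await cognee.search(query_text=query, query_type=SearchType.GRAPH_COMPLETION)
    # Fastest: pure vector similarity over chunks, no LLM call
    return await cognee.search(query_text=query, query_type=SearchType.CHUNKS)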
|
@ -574,23 +591,40 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
# NOTE: MCP uses stdout to communicate, we must redirect all output
|
||||
# going to stdout ( like the print function ) to stderr.
|
||||
with redirect_stdout(sys.stderr):
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType[search_type.upper()], query_text=search_query
|
||||
search_results = await cognee_client.search(
|
||||
query_text=search_query, query_type=search_type
|
||||
)
|
||||
|
||||
if search_type.upper() == "CODE":
|
||||
return json.dumps(search_results, cls=JSONEncoder)
|
||||
elif (
|
||||
search_type.upper() == "GRAPH_COMPLETION" or search_type.upper() == "RAG_COMPLETION"
|
||||
):
|
||||
return str(search_results[0])
|
||||
elif search_type.upper() == "CHUNKS":
|
||||
return str(search_results)
|
||||
elif search_type.upper() == "INSIGHTS":
|
||||
results = retrieved_edges_to_string(search_results)
|
||||
return results
|
||||
# Handle different result formats based on API vs direct mode
|
||||
if cognee_client.use_api:
|
||||
# API mode returns JSON-serialized results
|
||||
if isinstance(search_results, str):
|
||||
return search_results
|
||||
elif isinstance(search_results, list):
|
||||
if (
|
||||
search_type.upper() in ["GRAPH_COMPLETION", "RAG_COMPLETION"]
|
||||
and len(search_results) > 0
|
||||
):
|
||||
return str(search_results[0])
|
||||
return str(search_results)
|
||||
else:
|
||||
return json.dumps(search_results, cls=JSONEncoder)
|
||||
else:
|
||||
return str(search_results)
|
||||
# Direct mode processing
|
||||
if search_type.upper() == "CODE":
|
||||
return json.dumps(search_results, cls=JSONEncoder)
|
||||
elif (
|
||||
search_type.upper() == "GRAPH_COMPLETION"
|
||||
or search_type.upper() == "RAG_COMPLETION"
|
||||
):
|
||||
return str(search_results[0])
|
||||
elif search_type.upper() == "CHUNKS":
|
||||
return str(search_results)
|
||||
elif search_type.upper() == "INSIGHTS":
|
||||
results = retrieved_edges_to_string(search_results)
|
||||
return results
|
||||
else:
|
||||
return str(search_results)
|
||||
|
||||
search_results = await search_task(search_query, search_type)
|
||||
return [types.TextContent(type="text", text=search_results)]
|
||||
|
|
@ -623,6 +657,10 @@ async def get_developer_rules() -> list:
|
|||
async def fetch_rules_from_cognee() -> str:
|
||||
"""Collect all developer rules from Cognee"""
|
||||
with redirect_stdout(sys.stderr):
|
||||
if cognee_client.use_api:
|
||||
logger.warning("Developer rules retrieval is not available in API mode")
|
||||
return "Developer rules retrieval is not available in API mode"
|
||||
|
||||
developer_rules = await get_existing_rules(rules_nodeset_name="coding_agent_rules")
|
||||
return developer_rules
|
||||
|
||||
|
|
@ -662,16 +700,24 @@ async def list_data(dataset_id: str = None) -> list:
|
|||
|
||||
with redirect_stdout(sys.stderr):
|
||||
try:
|
||||
user = await get_default_user()
|
||||
output_lines = []
|
||||
|
||||
if dataset_id:
|
||||
# List data for specific dataset
|
||||
# Detailed data listing for specific dataset is only available in direct mode
|
||||
if cognee_client.use_api:
|
||||
return [
|
||||
types.TextContent(
|
||||
type="text",
|
||||
text="❌ Detailed data listing for specific datasets is not available in API mode.\nPlease use the API directly or use direct mode.",
|
||||
)
|
||||
]
|
||||
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.data.methods import get_dataset, get_dataset_data
|
||||
|
||||
logger.info(f"Listing data for dataset: {dataset_id}")
|
||||
dataset_uuid = UUID(dataset_id)
|
||||
|
||||
# Get the dataset information
|
||||
from cognee.modules.data.methods import get_dataset, get_dataset_data
|
||||
user = await get_default_user()
|
||||
|
||||
dataset = await get_dataset(user.id, dataset_uuid)
|
||||
|
||||
|
|
@ -700,11 +746,9 @@ async def list_data(dataset_id: str = None) -> list:
|
|||
output_lines.append(" (No data items in this dataset)")
|
||||
|
||||
else:
|
||||
# List all datasets
|
||||
# List all datasets - works in both modes
|
||||
logger.info("Listing all datasets")
|
||||
from cognee.modules.data.methods import get_datasets
|
||||
|
||||
datasets = await get_datasets(user.id)
|
||||
datasets = await cognee_client.list_datasets()
|
||||
|
||||
if not datasets:
|
||||
return [
|
||||
|
|
@ -719,20 +763,21 @@ async def list_data(dataset_id: str = None) -> list:
|
|||
output_lines.append("")
|
||||
|
||||
for i, dataset in enumerate(datasets, 1):
|
||||
# Get data count for each dataset
|
||||
from cognee.modules.data.methods import get_dataset_data
|
||||
|
||||
data_items = await get_dataset_data(dataset.id)
|
||||
|
||||
output_lines.append(f"{i}. 📁 {dataset.name}")
|
||||
output_lines.append(f" Dataset ID: {dataset.id}")
|
||||
output_lines.append(f" Created: {dataset.created_at}")
|
||||
output_lines.append(f" Data items: {len(data_items)}")
|
||||
# In API mode, dataset is a dict; in direct mode, it's formatted as dict
|
||||
if isinstance(dataset, dict):
|
||||
output_lines.append(f"{i}. 📁 {dataset.get('name', 'Unnamed')}")
|
||||
output_lines.append(f" Dataset ID: {dataset.get('id')}")
|
||||
output_lines.append(f" Created: {dataset.get('created_at', 'N/A')}")
|
||||
else:
|
||||
output_lines.append(f"{i}. 📁 {dataset.name}")
|
||||
output_lines.append(f" Dataset ID: {dataset.id}")
|
||||
output_lines.append(f" Created: {dataset.created_at}")
|
||||
output_lines.append("")
|
||||
|
||||
output_lines.append("💡 To see data items in a specific dataset, use:")
|
||||
output_lines.append(' list_data(dataset_id="your-dataset-id-here")')
|
||||
output_lines.append("")
|
||||
if not cognee_client.use_api:
|
||||
output_lines.append("💡 To see data items in a specific dataset, use:")
|
||||
output_lines.append(' list_data(dataset_id="your-dataset-id-here")')
|
||||
output_lines.append("")
|
||||
output_lines.append("🗑️ To delete specific data, use:")
|
||||
output_lines.append(' delete(data_id="data-id", dataset_id="dataset-id")')
|
||||
|
||||
|
|
@ -801,12 +846,9 @@ async def delete(data_id: str, dataset_id: str, mode: str = "soft") -> list:
|
|||
data_uuid = UUID(data_id)
|
||||
dataset_uuid = UUID(dataset_id)
|
||||
|
||||
# Get default user for the operation
|
||||
user = await get_default_user()
|
||||
|
||||
# Call the cognee delete function
|
||||
result = await cognee.delete(
|
||||
data_id=data_uuid, dataset_id=dataset_uuid, mode=mode, user=user
|
||||
# Call the cognee delete function via client
|
||||
result = await cognee_client.delete(
|
||||
data_id=data_uuid, dataset_id=dataset_uuid, mode=mode
|
||||
)
|
||||
|
||||
logger.info(f"Delete operation completed successfully: {result}")
|
||||
|
|
@ -853,11 +895,21 @@ async def prune():
|
|||
-----
|
||||
- This operation cannot be undone. All memory data will be permanently deleted.
|
||||
- The function prunes both data content (using prune_data) and system metadata (using prune_system)
|
||||
- This operation is not available in API mode
|
||||
"""
|
||||
with redirect_stdout(sys.stderr):
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
return [types.TextContent(type="text", text="Pruned")]
|
||||
try:
|
||||
await cognee_client.prune_data()
|
||||
await cognee_client.prune_system(metadata=True)
|
||||
return [types.TextContent(type="text", text="Pruned")]
|
||||
except NotImplementedError:
|
||||
error_msg = "❌ Prune operation is not available in API mode"
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
except Exception as e:
|
||||
error_msg = f"❌ Prune operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
|
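For reference, the direct-mode path wraps cognee's own prune API; a minimal (and destructive) sketch of the underlying calls the tool makes:

import cognee


async def reset_memory():
    # Remove all ingested data content
    await cognee.prune.prune_data()
    # Remove system metadata (graph, vector and pipeline bookkeeping) as well
    await cognee.prune.prune_system(metadata=True)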
|
@ -880,13 +932,26 @@ async def cognify_status():
|
|||
- The function retrieves pipeline status specifically for the "cognify_pipeline" on the "main_dataset"
|
||||
- Status information includes job progress, execution time, and completion status
|
||||
- The status is returned in string format for easy reading
|
||||
- This operation is not available in API mode
|
||||
"""
|
||||
with redirect_stdout(sys.stderr):
|
||||
user = await get_default_user()
|
||||
status = await get_pipeline_status(
|
||||
[await get_unique_dataset_id("main_dataset", user)], "cognify_pipeline"
|
||||
)
|
||||
return [types.TextContent(type="text", text=str(status))]
|
||||
try:
|
||||
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
||||
user = await get_default_user()
|
||||
status = await cognee_client.get_pipeline_status(
|
||||
[await get_unique_dataset_id("main_dataset", user)], "cognify_pipeline"
|
||||
)
|
||||
return [types.TextContent(type="text", text=str(status))]
|
||||
except NotImplementedError:
|
||||
error_msg = "❌ Pipeline status is not available in API mode"
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
except Exception as e:
|
||||
error_msg = f"❌ Failed to get cognify status: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
|
|
@ -909,13 +974,26 @@ async def codify_status():
|
|||
- The function retrieves pipeline status specifically for the "cognify_code_pipeline" on the "codebase" dataset
|
||||
- Status information includes job progress, execution time, and completion status
|
||||
- The status is returned in string format for easy reading
|
||||
- This operation is not available in API mode
|
||||
"""
|
||||
with redirect_stdout(sys.stderr):
|
||||
user = await get_default_user()
|
||||
status = await get_pipeline_status(
|
||||
[await get_unique_dataset_id("codebase", user)], "cognify_code_pipeline"
|
||||
)
|
||||
return [types.TextContent(type="text", text=str(status))]
|
||||
try:
|
||||
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
||||
user = await get_default_user()
|
||||
status = await cognee_client.get_pipeline_status(
|
||||
[await get_unique_dataset_id("codebase", user)], "cognify_code_pipeline"
|
||||
)
|
||||
return [types.TextContent(type="text", text=str(status))]
|
||||
except NotImplementedError:
|
||||
error_msg = "❌ Pipeline status is not available in API mode"
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
except Exception as e:
|
||||
error_msg = f"❌ Failed to get codify status: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
|
||||
|
||||
def node_to_string(node):
|
||||
|
|
@ -949,6 +1027,8 @@ def load_class(model_file, model_name):
|
|||
|
||||
|
||||
async def main():
|
||||
global cognee_client
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -992,12 +1072,30 @@ async def main():
|
|||
help="Argument stops database migration from being attempted",
|
||||
)
|
||||
|
||||
# Cognee API connection options
|
||||
parser.add_argument(
|
||||
"--api-url",
|
||||
default=None,
|
||||
help="Base URL of a running Cognee FastAPI server (e.g., http://localhost:8000). "
|
||||
"If provided, the MCP server will connect to the API instead of using cognee directly.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--api-token",
|
||||
default=None,
|
||||
help="Authentication token for the API (optional, required if API has authentication enabled).",
|
||||
)
|
||||
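# Hedged invocation sketch (the server module path below is an assumption; only the
# flags themselves are defined in this file):
#
#   uv run python src/server.py --transport sse --api-url http://localhost:8000 --api-token <token>
#
# Leaving out --api-url keeps the MCP server in direct mode, using the embedded cognee library.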
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize the global CogneeClient
|
||||
cognee_client = CogneeClient(api_url=args.api_url, api_token=args.api_token)
|
||||
|
||||
mcp.settings.host = args.host
|
||||
mcp.settings.port = args.port
|
||||
|
||||
if not args.no_migration:
|
||||
# Skip migrations when in API mode (the API server handles its own database)
|
||||
if not args.no_migration and not args.api_url:
|
||||
# Run Alembic migrations from the main cognee directory where alembic.ini is located
|
||||
logger.info("Running database migrations...")
|
||||
migration_result = subprocess.run(
|
||||
|
|
@ -1020,6 +1118,8 @@ async def main():
|
|||
sys.exit(1)
|
||||
|
||||
logger.info("Database migrations done.")
|
||||
elif args.api_url:
|
||||
logger.info("Skipping database migrations (using API mode)")
|
||||
|
||||
logger.info(f"Starting MCP server with transport: {args.transport}")
|
||||
if args.transport == "stdio":
|
||||
|
|
|
|||
2831
cognee-mcp/uv.lock
generated
File diff suppressed because it is too large
|
|
@ -19,6 +19,7 @@ from .api.v1.add import add
|
|||
from .api.v1.delete import delete
|
||||
from .api.v1.cognify import cognify
|
||||
from .modules.memify import memify
|
||||
from .api.v1.update import update
|
||||
from .api.v1.config.config import config
|
||||
from .api.v1.datasets.datasets import datasets
|
||||
from .api.v1.prune import prune
|
||||
|
|
|
|||
|
|
@ -189,12 +189,12 @@ class HealthChecker:
|
|||
start_time = time.time()
|
||||
try:
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
|
||||
config = get_llm_config()
|
||||
|
||||
# Test actual API connection with minimal request
|
||||
LLMGateway.show_prompt("test", "test.txt")
|
||||
from cognee.infrastructure.llm.utils import test_llm_connection
|
||||
|
||||
await test_llm_connection()
|
||||
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
return ComponentHealth(
|
||||
|
|
@ -217,13 +217,9 @@ class HealthChecker:
|
|||
"""Check embedding service health (non-critical)."""
|
||||
start_time = time.time()
|
||||
try:
|
||||
from cognee.infrastructure.databases.vector.embeddings.get_embedding_engine import (
|
||||
get_embedding_engine,
|
||||
)
|
||||
from cognee.infrastructure.llm.utils import test_embedding_connection
|
||||
|
||||
# Test actual embedding generation with minimal text
|
||||
engine = get_embedding_engine()
|
||||
await engine.embed_text(["test"])
|
||||
await test_embedding_connection()
|
||||
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
return ComponentHealth(
|
||||
|
|
@ -245,16 +241,6 @@ class HealthChecker:
|
|||
"""Get comprehensive health status."""
|
||||
components = {}
|
||||
|
||||
# Critical services
|
||||
critical_components = [
|
||||
"relational_db",
|
||||
"vector_db",
|
||||
"graph_db",
|
||||
"file_storage",
|
||||
"llm_provider",
|
||||
"embedding_service",
|
||||
]
|
||||
|
||||
critical_checks = [
|
||||
("relational_db", self.check_relational_db()),
|
||||
("vector_db", self.check_vector_db()),
|
||||
|
|
@ -300,11 +286,11 @@ class HealthChecker:
|
|||
else:
|
||||
components[name] = result
|
||||
|
||||
critical_comps = [check[0] for check in critical_checks]
|
||||
# Determine overall status
|
||||
critical_unhealthy = any(
|
||||
comp.status == HealthStatus.UNHEALTHY
|
||||
comp.status == HealthStatus.UNHEALTHY and name in critical_comps
|
||||
for name, comp in components.items()
|
||||
if name in critical_components
|
||||
)
|
||||
|
||||
has_degraded = any(comp.status == HealthStatus.DEGRADED for comp in components.values())
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from uuid import UUID
|
||||
from typing import Union, BinaryIO, List, Optional
|
||||
|
||||
import os
|
||||
from typing import Union, BinaryIO, List, Optional, Dict, Any
|
||||
from pydantic import BaseModel
|
||||
from urllib.parse import urlparse
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.pipelines import Task, run_pipeline
|
||||
from cognee.modules.pipelines.layers.resolve_authorized_user_dataset import (
|
||||
|
|
@ -11,6 +13,19 @@ from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
|
|||
)
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.tasks.ingestion import ingest_data, resolve_data_directories
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
try:
|
||||
from cognee.tasks.web_scraper.config import TavilyConfig, SoupCrawlerConfig
|
||||
from cognee.context_global_variables import (
|
||||
tavily_config as tavily,
|
||||
soup_crawler_config as soup_crawler,
|
||||
)
|
||||
except ImportError:
|
||||
logger.debug(f"Unable to import {str(ImportError)}")
|
||||
pass
|
||||
|
||||
|
||||
async def add(
|
||||
|
|
@ -23,12 +38,15 @@ async def add(
|
|||
dataset_id: Optional[UUID] = None,
|
||||
preferred_loaders: List[str] = None,
|
||||
incremental_loading: bool = True,
|
||||
extraction_rules: Optional[Dict[str, Any]] = None,
|
||||
tavily_config: Optional[BaseModel] = None,
|
||||
soup_crawler_config: Optional[BaseModel] = None,
|
||||
):
|
||||
"""
|
||||
Add data to Cognee for knowledge graph processing.
|
||||
|
||||
This is the first step in the Cognee workflow - it ingests raw data and prepares it
|
||||
for processing. The function accepts various data formats including text, files, and
|
||||
for processing. The function accepts various data formats including text, files, urls and
|
||||
binary streams, then stores them in a specified dataset for further processing.
|
||||
|
||||
Prerequisites:
|
||||
|
|
@ -68,6 +86,7 @@ async def add(
|
|||
- S3 path: "s3://my-bucket/documents/file.pdf"
|
||||
- List of mixed types: ["text content", "/path/file.pdf", "file://doc.txt", file_handle]
|
||||
- Binary file object: open("file.txt", "rb")
|
||||
- url: A web link url (https or http)
|
||||
dataset_name: Name of the dataset to store data in. Defaults to "main_dataset".
|
||||
Create separate datasets to organize different knowledge domains.
|
||||
user: User object for authentication and permissions. Uses default user if None.
|
||||
|
|
@ -78,6 +97,9 @@ async def add(
|
|||
vector_db_config: Optional configuration for vector database (for custom setups).
|
||||
graph_db_config: Optional configuration for graph database (for custom setups).
|
||||
dataset_id: Optional specific dataset UUID to use instead of dataset_name.
|
||||
extraction_rules: Optional dictionary of rules (e.g., CSS selectors, XPath) for extracting specific content from web pages using BeautifulSoup
|
||||
tavily_config: Optional configuration for Tavily API, including API key and extraction settings
|
||||
soup_crawler_config: Optional configuration for BeautifulSoup crawler, specifying concurrency, crawl delay, and extraction rules.
|
||||
|
||||
Returns:
|
||||
PipelineRunInfo: Information about the ingestion pipeline execution including:
|
||||
|
|
@ -126,6 +148,21 @@ async def add(
|
|||
|
||||
# Add a single file
|
||||
await cognee.add("/home/user/documents/analysis.pdf")
|
||||
|
||||
# Add a single URL using the BeautifulSoup (bs4) extraction method
|
||||
extraction_rules = {
|
||||
"title": "h1",
|
||||
"description": "p",
|
||||
"more_info": "a[href*='more-info']"
|
||||
}
|
||||
await cognee.add("https://example.com",extraction_rules=extraction_rules)
|
||||
|
||||
# Add a single URL using the Tavily extraction method
|
||||
# Make sure to set the TAVILY_API_KEY environment variable before calling add
|
||||
await cognee.add("https://example.com")
|
||||
|
||||
# Add multiple URLs
|
||||
await cognee.add(["https://example.com","https://books.toscrape.com"])
|
||||
```
|
||||
|
||||
Environment Variables:
|
||||
|
|
@ -133,22 +170,55 @@ async def add(
|
|||
- LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)
|
||||
|
||||
Optional:
|
||||
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama"
|
||||
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral"
|
||||
- LLM_MODEL: Model name (default: "gpt-5-mini")
|
||||
- DEFAULT_USER_EMAIL: Custom default user email
|
||||
- DEFAULT_USER_PASSWORD: Custom default user password
|
||||
- VECTOR_DB_PROVIDER: "lancedb" (default), "chromadb", "pgvector"
|
||||
- GRAPH_DATABASE_PROVIDER: "kuzu" (default), "neo4j"
|
||||
- TAVILY_API_KEY: YOUR_TAVILY_API_KEY
|
||||
|
||||
"""
|
||||
|
||||
try:
|
||||
if not soup_crawler_config and extraction_rules:
|
||||
soup_crawler_config = SoupCrawlerConfig(extraction_rules=extraction_rules)
|
||||
if not tavily_config and os.getenv("TAVILY_API_KEY"):
|
||||
tavily_config = TavilyConfig(api_key=os.getenv("TAVILY_API_KEY"))
|
||||
|
||||
soup_crawler.set(soup_crawler_config)
|
||||
tavily.set(tavily_config)
|
||||
|
||||
http_schemes = {"http", "https"}
|
||||
|
||||
def _is_http_url(item: Union[str, BinaryIO]) -> bool:
|
||||
return isinstance(item, str) and urlparse(item).scheme in http_schemes
|
||||
|
||||
if _is_http_url(data):
|
||||
node_set = ["web_content"] if not node_set else node_set + ["web_content"]
|
||||
elif isinstance(data, list) and any(_is_http_url(item) for item in data):
|
||||
node_set = ["web_content"] if not node_set else node_set + ["web_content"]
|
||||
except NameError:
|
||||
logger.debug(f"Unable to import {str(ImportError)}")
|
||||
pass
|
||||
|
||||
tasks = [
|
||||
Task(resolve_data_directories, include_subdirectories=True),
|
||||
Task(ingest_data, dataset_name, user, node_set, dataset_id, preferred_loaders),
|
||||
Task(
|
||||
ingest_data,
|
||||
dataset_name,
|
||||
user,
|
||||
node_set,
|
||||
dataset_id,
|
||||
preferred_loaders,
|
||||
),
|
||||
]
|
||||
|
||||
await setup()
|
||||
|
||||
user, authorized_dataset = await resolve_authorized_user_dataset(dataset_id, dataset_name, user)
|
||||
user, authorized_dataset = await resolve_authorized_user_dataset(
|
||||
dataset_name=dataset_name, dataset_id=dataset_id, user=user
|
||||
)
|
||||
|
||||
await reset_dataset_pipeline_run_status(
|
||||
authorized_dataset.id, user, pipeline_names=["add_pipeline", "cognify_pipeline"]
|
||||
|
|
|
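The same web-scraping behaviour can also be driven by passing a config object instead of raw extraction_rules; a hedged sketch (only extraction_rules is known from this diff, no other SoupCrawlerConfig fields are assumed):

import cognee
from cognee.tasks.web_scraper.config import SoupCrawlerConfig


async def add_page():
    soup_config = SoupCrawlerConfig(extraction_rules={"title": "h1", "description": "p"})
    # URLs are automatically tagged with the "web_content" node set by add()
    await cognee.add("https://example.com", soup_crawler_config=soup_config)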
|||
|
|
@ -73,7 +73,11 @@ def get_add_router() -> APIRouter:
|
|||
|
||||
try:
|
||||
add_run = await cognee_add(
|
||||
data, datasetName, user=user, dataset_id=datasetId, node_set=node_set
|
||||
data,
|
||||
datasetName,
|
||||
user=user,
|
||||
dataset_id=datasetId,
|
||||
node_set=node_set if node_set else None,
|
||||
)
|
||||
|
||||
if isinstance(add_run, PipelineRunErrored):
|
||||
|
|
|
|||
|
|
@ -148,7 +148,7 @@ async def cognify(
|
|||
# 2. Get entity relationships and connections
|
||||
relationships = await cognee.search(
|
||||
"connections between concepts",
|
||||
query_type=SearchType.INSIGHTS
|
||||
query_type=SearchType.GRAPH_COMPLETION
|
||||
)
|
||||
|
||||
# 3. Find relevant document chunks
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ DEFAULT_TOOLS = [
|
|||
"type": "string",
|
||||
"description": "Type of search to perform",
|
||||
"enum": [
|
||||
"INSIGHTS",
|
||||
"CODE",
|
||||
"GRAPH_COMPLETION",
|
||||
"NATURAL_LANGUAGE",
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ async def handle_search(arguments: Dict[str, Any], user) -> list:
|
|||
valid_search_types = (
|
||||
search_tool["parameters"]["properties"]["search_type"]["enum"]
|
||||
if search_tool
|
||||
else ["INSIGHTS", "CODE", "GRAPH_COMPLETION", "NATURAL_LANGUAGE"]
|
||||
else ["CODE", "GRAPH_COMPLETION", "NATURAL_LANGUAGE"]
|
||||
)
|
||||
|
||||
if search_type_str not in valid_search_types:
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ DEFAULT_TOOLS = [
|
|||
"type": "string",
|
||||
"description": "Type of search to perform",
|
||||
"enum": [
|
||||
"INSIGHTS",
|
||||
"CODE",
|
||||
"GRAPH_COMPLETION",
|
||||
"NATURAL_LANGUAGE",
|
||||
|
|
|
|||
|
|
@ -52,11 +52,6 @@ async def search(
|
|||
Best for: Direct document retrieval, specific fact-finding.
|
||||
Returns: LLM responses based on relevant text chunks.
|
||||
|
||||
**INSIGHTS**:
|
||||
Structured entity relationships and semantic connections.
|
||||
Best for: Understanding concept relationships, knowledge mapping.
|
||||
Returns: Formatted relationship data and entity connections.
|
||||
|
||||
**CHUNKS**:
|
||||
Raw text segments that match the query semantically.
|
||||
Best for: Finding specific passages, citations, exact content.
|
||||
|
|
@ -124,9 +119,6 @@ async def search(
|
|||
**GRAPH_COMPLETION/RAG_COMPLETION**:
|
||||
[List of conversational AI response strings]
|
||||
|
||||
**INSIGHTS**:
|
||||
[List of formatted relationship descriptions and entity connections]
|
||||
|
||||
**CHUNKS**:
|
||||
[List of relevant text passages with source metadata]
|
||||
|
||||
|
|
@ -146,7 +138,6 @@ async def search(
|
|||
Performance & Optimization:
|
||||
- **GRAPH_COMPLETION**: Slower but most intelligent, uses LLM + graph context
|
||||
- **RAG_COMPLETION**: Medium speed, uses LLM + document chunks (no graph traversal)
|
||||
- **INSIGHTS**: Fast, returns structured relationships without LLM processing
|
||||
- **CHUNKS**: Fastest, pure vector similarity search without LLM
|
||||
- **SUMMARIES**: Fast, returns pre-computed summaries
|
||||
- **CODE**: Medium speed, specialized for code understanding
|
||||
|
|
|
|||
|
|
@ -21,7 +21,13 @@ class SettingsDTO(OutDTO):
|
|||
|
||||
|
||||
class LLMConfigInputDTO(InDTO):
|
||||
provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"], Literal["gemini"]]
|
||||
provider: Union[
|
||||
Literal["openai"],
|
||||
Literal["ollama"],
|
||||
Literal["anthropic"],
|
||||
Literal["gemini"],
|
||||
Literal["mistral"],
|
||||
]
|
||||
model: str
|
||||
api_key: str
|
||||
|
||||
|
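A hedged sketch of a payload this DTO would now accept with "mistral" added as a provider (the endpoint that consumes it is outside this hunk, and the model name is illustrative):

llm_config_payload = {
    "provider": "mistral",            # newly accepted alongside openai/ollama/anthropic/gemini
    "model": "mistral-large-latest",  # illustrative model name
    "api_key": "your_api_key",
}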
|
|
|||
|
|
@ -502,22 +502,48 @@ def start_ui(
|
|||
|
||||
if start_mcp:
|
||||
logger.info("Starting Cognee MCP server with Docker...")
|
||||
cwd = os.getcwd()
|
||||
env_file = os.path.join(cwd, ".env")
|
||||
try:
|
||||
image = "cognee/cognee-mcp:feature-standalone-mcp" # TODO: change to "cognee/cognee-mcp:main" right before merging into main
|
||||
subprocess.run(["docker", "pull", image], check=True)
|
||||
|
||||
import uuid
|
||||
|
||||
container_name = f"cognee-mcp-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
docker_cmd = [
|
||||
"docker",
|
||||
"run",
|
||||
"--name",
|
||||
container_name,
|
||||
"-p",
|
||||
f"{mcp_port}:8000",
|
||||
"--rm",
|
||||
"-e",
|
||||
"TRANSPORT_MODE=sse",
|
||||
]
|
||||
|
||||
if start_backend:
|
||||
docker_cmd.extend(
|
||||
[
|
||||
"-e",
|
||||
f"API_URL=http://localhost:{backend_port}",
|
||||
]
|
||||
)
|
||||
logger.info(
|
||||
f"Configuring MCP to connect to backend API at http://localhost:{backend_port}"
|
||||
)
|
||||
logger.info("(localhost will be auto-converted to host.docker.internal)")
|
||||
else:
|
||||
cwd = os.getcwd()
|
||||
env_file = os.path.join(cwd, ".env")
|
||||
docker_cmd.extend(["--env-file", env_file])
|
||||
|
||||
docker_cmd.append(
|
||||
image
|
||||
) # TODO: change to "cognee/cognee-mcp:main" right before merging into main
|
||||
|
||||
mcp_process = subprocess.Popen(
|
||||
[
|
||||
"docker",
|
||||
"run",
|
||||
"-p",
|
||||
f"{mcp_port}:8000",
|
||||
"--rm",
|
||||
"--env-file",
|
||||
env_file,
|
||||
"-e",
|
||||
"TRANSPORT_MODE=sse",
|
||||
"cognee/cognee-mcp:daulet-dev",
|
||||
],
|
||||
docker_cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
||||
|
|
@ -526,8 +552,13 @@ def start_ui(
|
|||
_stream_process_output(mcp_process, "stdout", "[MCP]", "\033[34m") # Blue
|
||||
_stream_process_output(mcp_process, "stderr", "[MCP]", "\033[34m") # Blue
|
||||
|
||||
pid_callback(mcp_process.pid)
|
||||
logger.info(f"✓ Cognee MCP server starting on http://127.0.0.1:{mcp_port}/sse")
|
||||
# Pass both PID and container name using a tuple
|
||||
pid_callback((mcp_process.pid, container_name))
|
||||
|
||||
mode_info = "API mode" if start_backend else "direct mode"
|
||||
logger.info(
|
||||
f"✓ Cognee MCP server starting on http://127.0.0.1:{mcp_port}/sse ({mode_info})"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start MCP server with Docker: {str(e)}")
|
||||
# Start backend server if requested
|
||||
|
|
@ -627,7 +658,6 @@ def start_ui(
|
|||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
shell=True,
|
||||
)
|
||||
else:
|
||||
|
|
@ -637,7 +667,6 @@ def start_ui(
|
|||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ def get_update_router() -> APIRouter:
|
|||
data=data,
|
||||
dataset_id=dataset_id,
|
||||
user=user,
|
||||
node_set=node_set,
|
||||
node_set=node_set if node_set else None,
|
||||
)
|
||||
|
||||
# If any cognify run errored return JSONResponse with proper error status code
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@ from cognee.api.v1.cognify import cognify
|
|||
async def update(
|
||||
data_id: UUID,
|
||||
data: Union[BinaryIO, list[BinaryIO], str, list[str]],
|
||||
dataset_id: UUID,
|
||||
user: User = None,
|
||||
node_set: Optional[List[str]] = None,
|
||||
dataset_id: Optional[UUID] = None,
|
||||
vector_db_config: dict = None,
|
||||
graph_db_config: dict = None,
|
||||
preferred_loaders: List[str] = None,
|
||||
|
|
|
|||
|
|
@ -175,19 +175,59 @@ def main() -> int:
|
|||
# Handle UI flag
|
||||
if hasattr(args, "start_ui") and args.start_ui:
|
||||
spawned_pids = []
|
||||
docker_container = None
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
"""Handle Ctrl+C and other termination signals"""
|
||||
nonlocal spawned_pids
|
||||
fmt.echo("\nShutting down UI server...")
|
||||
nonlocal spawned_pids, docker_container
|
||||
|
||||
try:
|
||||
fmt.echo("\nShutting down UI server...")
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
|
||||
# First, stop Docker container if running
|
||||
if docker_container:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["docker", "stop", docker_container],
|
||||
capture_output=True,
|
||||
timeout=10,
|
||||
check=False,
|
||||
)
|
||||
try:
|
||||
if result.returncode == 0:
|
||||
fmt.success(f"✓ Docker container {docker_container} stopped.")
|
||||
else:
|
||||
fmt.warning(
|
||||
f"Could not stop container {docker_container}: {result.stderr.decode()}"
|
||||
)
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
except subprocess.TimeoutExpired:
|
||||
try:
|
||||
fmt.warning(
|
||||
f"Timeout stopping container {docker_container}, forcing removal..."
|
||||
)
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
subprocess.run(
|
||||
["docker", "rm", "-f", docker_container], capture_output=True, check=False
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Then, stop regular processes
|
||||
for pid in spawned_pids:
|
||||
try:
|
||||
if hasattr(os, "killpg"):
|
||||
# Unix-like systems: Use process groups
|
||||
pgid = os.getpgid(pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
|
||||
try:
|
||||
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
else:
|
||||
# Windows: Use taskkill to terminate process and its children
|
||||
subprocess.run(
|
||||
|
|
@ -195,24 +235,35 @@ def main() -> int:
|
|||
capture_output=True,
|
||||
check=False,
|
||||
)
|
||||
fmt.success(f"✓ Process {pid} and its children terminated.")
|
||||
except (OSError, ProcessLookupError, subprocess.SubprocessError) as e:
|
||||
fmt.warning(f"Could not terminate process {pid}: {e}")
|
||||
try:
|
||||
fmt.success(f"✓ Process {pid} and its children terminated.")
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
except (OSError, ProcessLookupError, subprocess.SubprocessError):
|
||||
pass
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
||||
signal.signal(signal.SIGTERM, signal_handler) # Termination request
|
||||
if hasattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, signal_handler)
|
||||
|
||||
try:
|
||||
from cognee import start_ui
|
||||
|
||||
fmt.echo("Starting cognee UI...")
|
||||
|
||||
# Callback to capture PIDs of all spawned processes
|
||||
def pid_callback(pid):
|
||||
nonlocal spawned_pids
|
||||
spawned_pids.append(pid)
|
||||
# Callback to capture PIDs and Docker container of all spawned processes
|
||||
def pid_callback(pid_or_tuple):
|
||||
nonlocal spawned_pids, docker_container
|
||||
# Handle both regular PIDs and (PID, container_name) tuples
|
||||
if isinstance(pid_or_tuple, tuple):
|
||||
pid, container_name = pid_or_tuple
|
||||
spawned_pids.append(pid)
|
||||
docker_container = container_name
|
||||
else:
|
||||
spawned_pids.append(pid_or_tuple)
|
||||
|
||||
frontend_port = 3000
|
||||
start_backend, backend_port = True, 8000
|
||||
|
|
|
|||
|
|
@ -70,11 +70,11 @@ After adding data, use `cognee cognify` to process it into knowledge graphs.
|
|||
await cognee.add(data=data_to_add, dataset_name=args.dataset_name)
|
||||
fmt.success(f"Successfully added data to dataset '{args.dataset_name}'")
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to add data: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to add data: {str(e)}") from e
|
||||
|
||||
asyncio.run(run_add())
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, CliCommandInnerException):
|
||||
raise CliCommandException(str(e), error_code=1)
|
||||
raise CliCommandException(f"Error adding data: {str(e)}", error_code=1)
|
||||
raise CliCommandException(str(e), error_code=1) from e
|
||||
raise CliCommandException(f"Failed to add data: {str(e)}", error_code=1) from e
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ After successful cognify processing, use `cognee search` to query the knowledge
|
|||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to cognify: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to cognify: {str(e)}") from e
|
||||
|
||||
result = asyncio.run(run_cognify())
|
||||
|
||||
|
|
@ -124,5 +124,5 @@ After successful cognify processing, use `cognee search` to query the knowledge
|
|||
|
||||
except Exception as e:
|
||||
if isinstance(e, CliCommandInnerException):
|
||||
raise CliCommandException(str(e), error_code=1)
|
||||
raise CliCommandException(f"Error during cognification: {str(e)}", error_code=1)
|
||||
raise CliCommandException(str(e), error_code=1) from e
|
||||
raise CliCommandException(f"Error during cognification: {str(e)}", error_code=1) from e
|
||||
|
|
|
|||
|
|
@ -79,8 +79,10 @@ Configuration changes will affect how cognee processes and stores data.
|
|||
|
||||
except Exception as e:
|
||||
if isinstance(e, CliCommandInnerException):
|
||||
raise CliCommandException(str(e), error_code=1)
|
||||
raise CliCommandException(f"Error managing configuration: {str(e)}", error_code=1)
|
||||
raise CliCommandException(str(e), error_code=1) from e
|
||||
raise CliCommandException(
|
||||
f"Error managing configuration: {str(e)}", error_code=1
|
||||
) from e
|
||||
|
||||
def _handle_get(self, args: argparse.Namespace) -> None:
|
||||
try:
|
||||
|
|
@ -122,7 +124,7 @@ Configuration changes will affect how cognee processes and stores data.
|
|||
fmt.note("Configuration viewing not fully implemented yet")
|
||||
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to get configuration: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to get configuration: {str(e)}") from e
|
||||
|
||||
def _handle_set(self, args: argparse.Namespace) -> None:
|
||||
try:
|
||||
|
|
@ -141,7 +143,7 @@ Configuration changes will affect how cognee processes and stores data.
|
|||
fmt.error(f"Failed to set configuration key '{args.key}'")
|
||||
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to set configuration: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to set configuration: {str(e)}") from e
|
||||
|
||||
def _handle_unset(self, args: argparse.Namespace) -> None:
|
||||
try:
|
||||
|
|
@ -189,7 +191,7 @@ Configuration changes will affect how cognee processes and stores data.
|
|||
fmt.note("Use 'cognee config list' to see all available configuration options")
|
||||
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to unset configuration: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to unset configuration: {str(e)}") from e
|
||||
|
||||
def _handle_list(self, args: argparse.Namespace) -> None:
|
||||
try:
|
||||
|
|
@ -209,7 +211,7 @@ Configuration changes will affect how cognee processes and stores data.
|
|||
fmt.echo(" cognee config reset - Reset all to defaults")
|
||||
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to list configuration: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to list configuration: {str(e)}") from e
|
||||
|
||||
def _handle_reset(self, args: argparse.Namespace) -> None:
|
||||
try:
|
||||
|
|
@ -222,4 +224,4 @@ Configuration changes will affect how cognee processes and stores data.
|
|||
fmt.echo("This would reset all settings to their default values")
|
||||
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to reset configuration: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to reset configuration: {str(e)}") from e
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from cognee.cli.reference import SupportsCliCommand
|
|||
from cognee.cli import DEFAULT_DOCS_URL
|
||||
import cognee.cli.echo as fmt
|
||||
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
|
||||
from cognee.modules.data.methods.get_deletion_counts import get_deletion_counts
|
||||
|
||||
|
||||
class DeleteCommand(SupportsCliCommand):
|
||||
|
|
@ -41,7 +42,34 @@ Be careful with deletion operations as they are irreversible.
|
|||
fmt.error("Please specify what to delete: --dataset-name, --user-id, or --all")
|
||||
return
|
||||
|
||||
# Build confirmation message
|
||||
# If --force is used, skip the preview and go straight to deletion
|
||||
if not args.force:
|
||||
# --- START PREVIEW LOGIC ---
|
||||
fmt.echo("Gathering data for preview...")
|
||||
try:
|
||||
preview_data = asyncio.run(
|
||||
get_deletion_counts(
|
||||
dataset_name=args.dataset_name,
|
||||
user_id=args.user_id,
|
||||
all_data=args.all,
|
||||
)
|
||||
)
|
||||
except CliCommandException as e:
|
||||
fmt.error(f"Error occured when fetching preview data: {str(e)}")
|
||||
return
|
||||
|
||||
if not preview_data:
|
||||
fmt.success("No data found to delete.")
|
||||
return
|
||||
|
||||
fmt.echo("You are about to delete:")
|
||||
fmt.echo(
|
||||
f"Datasets: {preview_data.datasets}\nEntries: {preview_data.entries}\nUsers: {preview_data.users}"
|
||||
)
|
||||
fmt.echo("-" * 20)
|
||||
# --- END PREVIEW LOGIC ---
|
||||
|
||||
# Build operation message for success/failure logging
|
||||
if args.all:
|
||||
confirm_msg = "Delete ALL data from cognee?"
|
||||
operation = "all data"
|
||||
|
|
@ -51,8 +79,9 @@ Be careful with deletion operations as they are irreversible.
|
|||
elif args.user_id:
|
||||
confirm_msg = f"Delete all data for user '{args.user_id}'?"
|
||||
operation = f"data for user '{args.user_id}'"
|
||||
else:
|
||||
operation = "data"
|
||||
|
||||
# Confirm deletion unless forced
|
||||
if not args.force:
|
||||
fmt.warning("This operation is irreversible!")
|
||||
if not fmt.confirm(confirm_msg):
|
||||
|
|
@ -64,17 +93,20 @@ Be careful with deletion operations as they are irreversible.
|
|||
# Run the async delete function
|
||||
async def run_delete():
|
||||
try:
|
||||
# NOTE: The underlying cognee.delete() function is currently not working as expected.
|
||||
# This is a separate bug that this preview feature helps to expose.
|
||||
if args.all:
|
||||
await cognee.delete(dataset_name=None, user_id=args.user_id)
|
||||
else:
|
||||
await cognee.delete(dataset_name=args.dataset_name, user_id=args.user_id)
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to delete: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to delete: {str(e)}") from e
|
||||
|
||||
asyncio.run(run_delete())
|
||||
# This success message may be inaccurate due to the underlying bug, but we leave it for now.
|
||||
fmt.success(f"Successfully deleted {operation}")
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, CliCommandInnerException):
|
||||
raise CliCommandException(str(e), error_code=1)
|
||||
raise CliCommandException(f"Error deleting data: {str(e)}", error_code=1)
|
||||
raise CliCommandException(str(e), error_code=1) from e
|
||||
raise CliCommandException(f"Error deleting data: {str(e)}", error_code=1) from e
|
||||
|
|
|
|||
|
|
@ -31,10 +31,6 @@ Search Types & Use Cases:
|
|||
Traditional RAG using document chunks without graph structure.
|
||||
Best for: Direct document retrieval, specific fact-finding.
|
||||
|
||||
**INSIGHTS**:
|
||||
Structured entity relationships and semantic connections.
|
||||
Best for: Understanding concept relationships, knowledge mapping.
|
||||
|
||||
**CHUNKS**:
|
||||
Raw text segments that match the query semantically.
|
||||
Best for: Finding specific passages, citations, exact content.
|
||||
|
|
@ -108,7 +104,7 @@ Search Types & Use Cases:
|
|||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
raise CliCommandInnerException(f"Failed to search: {str(e)}")
|
||||
raise CliCommandInnerException(f"Failed to search: {str(e)}") from e
|
||||
|
||||
results = asyncio.run(run_search())
|
||||
|
||||
|
|
@ -145,5 +141,5 @@ Search Types & Use Cases:
|
|||
|
||||
except Exception as e:
|
||||
if isinstance(e, CliCommandInnerException):
|
||||
raise CliCommandException(str(e), error_code=1)
|
||||
raise CliCommandException(f"Error searching: {str(e)}", error_code=1)
|
||||
raise CliCommandException(str(e), error_code=1) from e
|
||||
raise CliCommandException(f"Error searching: {str(e)}", error_code=1) from e
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ COMMAND_DESCRIPTIONS = {
|
|||
SEARCH_TYPE_CHOICES = [
|
||||
"GRAPH_COMPLETION",
|
||||
"RAG_COMPLETION",
|
||||
"INSIGHTS",
|
||||
"CHUNKS",
|
||||
"SUMMARIES",
|
||||
"CODE",
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ from cognee.modules.users.methods import get_user
|
|||
# for different async tasks, threads and processes
|
||||
vector_db_config = ContextVar("vector_db_config", default=None)
|
||||
graph_db_config = ContextVar("graph_db_config", default=None)
|
||||
soup_crawler_config = ContextVar("soup_crawler_config", default=None)
|
||||
tavily_config = ContextVar("tavily_config", default=None)
|
||||
|
||||
|
||||
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
|
||||
|
|
|
|||
|
|
@ -50,26 +50,26 @@ class GraphConfig(BaseSettings):
|
|||
# Model validator updates graph_filename and path dynamically after class creation based on current database provider
|
||||
# If no specific graph_filename or path are provided
|
||||
@pydantic.model_validator(mode="after")
|
||||
def fill_derived(cls, values):
|
||||
provider = values.graph_database_provider.lower()
|
||||
def fill_derived(self):
|
||||
provider = self.graph_database_provider.lower()
|
||||
base_config = get_base_config()
|
||||
|
||||
# Set default filename if no filename is provided
|
||||
if not values.graph_filename:
|
||||
values.graph_filename = f"cognee_graph_{provider}"
|
||||
if not self.graph_filename:
|
||||
self.graph_filename = f"cognee_graph_{provider}"
|
||||
|
||||
# Handle graph file path
|
||||
if values.graph_file_path:
|
||||
if self.graph_file_path:
|
||||
# Check if absolute path is provided
|
||||
values.graph_file_path = ensure_absolute_path(
|
||||
os.path.join(values.graph_file_path, values.graph_filename)
|
||||
self.graph_file_path = ensure_absolute_path(
|
||||
os.path.join(self.graph_file_path, self.graph_filename)
|
||||
)
|
||||
else:
|
||||
# Default path
|
||||
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
|
||||
values.graph_file_path = os.path.join(databases_directory_path, values.graph_filename)
|
||||
self.graph_file_path = os.path.join(databases_directory_path, self.graph_filename)
|
||||
|
||||
return values
|
||||
return self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
|
|
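The move from a (cls, values) signature to self matches Pydantic v2 "after" model validators, which receive the constructed instance; a standalone sketch of the pattern (the class and fields below are illustrative, not the real GraphConfig):

import os
import pydantic
from pydantic_settings import BaseSettings


class ExampleConfig(BaseSettings):
    provider: str = "kuzu"
    file_path: str = ""

    @pydantic.model_validator(mode="after")
    def fill_derived(self):
        # mode="after" hands the validator the built model, so derived defaults
        # are assigned on self and self is returned.
        if not self.file_path:
            self.file_path = os.path.join("databases", f"example_{self.provider}")
        return self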
|
|||
|
|
@ -44,16 +44,14 @@ def create_graph_engine(
|
|||
Parameters:
|
||||
-----------
|
||||
|
||||
- graph_database_provider: The type of graph database provider to use (e.g., neo4j,
|
||||
falkordb, kuzu).
|
||||
- graph_database_url: The URL for the graph database instance. Required for neo4j
|
||||
and falkordb providers.
|
||||
- graph_database_provider: The type of graph database provider to use (e.g., neo4j, falkor, kuzu).
|
||||
- graph_database_url: The URL for the graph database instance. Required for neo4j and falkordb providers.
|
||||
- graph_database_username: The username for authentication with the graph database.
|
||||
Required for neo4j provider.
|
||||
- graph_database_password: The password for authentication with the graph database.
|
||||
Required for neo4j provider.
|
||||
- graph_database_port: The port number for the graph database connection. Required
|
||||
for the falkordb provider.
|
||||
for the falkordb provider
|
||||
- graph_file_path: The filesystem path to the graph file. Required for the kuzu
|
||||
provider.
|
||||
|
||||
|
|
@ -86,21 +84,6 @@ def create_graph_engine(
|
|||
graph_database_name=graph_database_name or None,
|
||||
)
|
||||
|
||||
elif graph_database_provider == "falkordb":
|
||||
if not (graph_database_url and graph_database_port):
|
||||
raise EnvironmentError("Missing required FalkorDB credentials.")
|
||||
|
||||
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
|
||||
from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
||||
|
||||
embedding_engine = get_embedding_engine()
|
||||
|
||||
return FalkorDBAdapter(
|
||||
database_url=graph_database_url,
|
||||
database_port=graph_database_port,
|
||||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
elif graph_database_provider == "kuzu":
|
||||
if not graph_file_path:
|
||||
raise EnvironmentError("Missing required Kuzu database path.")
|
||||
|
|
@ -179,5 +162,5 @@ def create_graph_engine(
|
|||
|
||||
raise EnvironmentError(
|
||||
f"Unsupported graph database provider: {graph_database_provider}. "
|
||||
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'falkordb', 'kuzu', 'kuzu-remote', 'memgraph', 'neptune', 'neptune_analytics'])}"
|
||||
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'kuzu', 'kuzu-remote', 'memgraph', 'neptune', 'neptune_analytics'])}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -48,6 +48,29 @@ class KuzuAdapter(GraphDBInterface):
|
|||
|
||||
def _initialize_connection(self) -> None:
|
||||
"""Initialize the Kuzu database connection and schema."""
|
||||
|
||||
def _install_json_extension():
|
||||
"""
|
||||
Install the JSON extension for the current Kuzu version.
|
||||
This has to be done with an empty graph database before connecting to an existing one; otherwise,
|
||||
missing JSON extension errors will be raised.
|
||||
"""
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=True) as temp_file:
|
||||
temp_graph_file = temp_file.name
|
||||
tmp_db = Database(
|
||||
temp_graph_file,
|
||||
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
|
||||
max_db_size=4096 * 1024 * 1024,
|
||||
)
|
||||
tmp_db.init_database()
|
||||
connection = Connection(tmp_db)
|
||||
connection.execute("INSTALL JSON;")
|
||||
except Exception as e:
|
||||
logger.info(f"JSON extension already installed or not needed: {e}")
|
||||
|
||||
_install_json_extension()
|
||||
|
||||
try:
|
||||
if "s3://" in self.db_path:
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
|
||||
|
|
@ -109,11 +132,6 @@ class KuzuAdapter(GraphDBInterface):
|
|||
self.db.init_database()
|
||||
self.connection = Connection(self.db)
|
||||
|
||||
try:
|
||||
self.connection.execute("INSTALL JSON;")
|
||||
except Exception as e:
|
||||
logger.info(f"JSON extension already installed or not needed: {e}")
|
||||
|
||||
try:
|
||||
self.connection.execute("LOAD EXTENSION JSON;")
|
||||
logger.info("Loaded JSON extension")
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
auth=auth,
|
||||
max_connection_lifetime=120,
|
||||
notifications_min_severity="OFF",
|
||||
keep_alive=True,
|
||||
)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
@ -205,7 +206,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
{
|
||||
"node_id": str(node.id),
|
||||
"label": type(node).__name__,
|
||||
"properties": self.serialize_properties(node.model_dump()),
|
||||
"properties": self.serialize_properties(dict(node)),
|
||||
}
|
||||
for node in nodes
|
||||
]
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from cognee.infrastructure.utils.calculate_backoff import calculate_backoff
|
|||
logger = get_logger("deadlock_retry")
|
||||
|
||||
|
||||
def deadlock_retry(max_retries=5):
|
||||
def deadlock_retry(max_retries=10):
|
||||
"""
|
||||
Decorator that automatically retries an asynchronous function when database deadlock errors occur.
|
||||
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ def parse_neptune_url(url: str) -> Tuple[str, str]:
|
|||
return graph_id, region
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse Neptune Analytics URL '{url}': {str(e)}")
|
||||
raise ValueError(f"Failed to parse Neptune Analytics URL '{url}': {str(e)}") from e
|
||||
|
||||
|
||||
def validate_graph_id(graph_id: str) -> bool:
|
||||
|
|
|
|||
File diff suppressed because it is too large
|
|
@ -23,14 +23,14 @@ class RelationalConfig(BaseSettings):
|
|||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
@pydantic.model_validator(mode="after")
|
||||
def fill_derived(cls, values):
|
||||
def fill_derived(self):
|
||||
# Set file path based on graph database provider if no file path is provided
|
||||
if not values.db_path:
|
||||
if not self.db_path:
|
||||
base_config = get_base_config()
|
||||
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
|
||||
values.db_path = databases_directory_path
|
||||
self.db_path = databases_directory_path
|
||||
|
||||
return values
|
||||
return self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -283,7 +283,7 @@ class SQLAlchemyAdapter:
|
|||
try:
|
||||
data_entity = (await session.scalars(select(Data).where(Data.id == data_id))).one()
|
||||
except (ValueError, NoResultFound) as e:
|
||||
raise EntityNotFoundError(message=f"Entity not found: {str(e)}")
|
||||
raise EntityNotFoundError(message=f"Entity not found: {str(e)}") from e
|
||||
|
||||
# Check if other data objects point to the same raw data location
|
||||
raw_data_location_entities = (
|
||||
|
|
|
|||
|
|
@ -30,21 +30,21 @@ class VectorConfig(BaseSettings):
|
|||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
@pydantic.model_validator(mode="after")
|
||||
def validate_paths(cls, values):
|
||||
def validate_paths(self):
|
||||
base_config = get_base_config()
|
||||
|
||||
# If vector_db_url is provided and is not a path skip checking if path is absolute (as it can also be a url)
|
||||
if values.vector_db_url and Path(values.vector_db_url).exists():
|
||||
if self.vector_db_url and Path(self.vector_db_url).exists():
|
||||
# Relative path to absolute
|
||||
values.vector_db_url = ensure_absolute_path(
|
||||
values.vector_db_url,
|
||||
self.vector_db_url = ensure_absolute_path(
|
||||
self.vector_db_url,
|
||||
)
|
||||
elif not values.vector_db_url:
|
||||
elif not self.vector_db_url:
|
||||
# Default path
|
||||
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
|
||||
values.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
|
||||
self.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
|
||||
|
||||
return values
|
||||
return self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -19,8 +19,7 @@ def create_vector_engine(
|
|||
for each provider, raising an EnvironmentError if any are missing, or ImportError if the
|
||||
ChromaDB package is not installed.
|
||||
|
||||
Supported providers include: pgvector, FalkorDB, ChromaDB, and
|
||||
LanceDB.
|
||||
Supported providers include: pgvector, ChromaDB, and LanceDB.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
|
@ -79,18 +78,6 @@ def create_vector_engine(
|
|||
embedding_engine,
|
||||
)
|
||||
|
||||
elif vector_db_provider == "falkordb":
|
||||
if not (vector_db_url and vector_db_port):
|
||||
raise EnvironmentError("Missing requred FalkorDB credentials!")
|
||||
|
||||
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
||||
|
||||
return FalkorDBAdapter(
|
||||
database_url=vector_db_url,
|
||||
database_port=vector_db_port,
|
||||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
elif vector_db_provider == "chromadb":
|
||||
try:
|
||||
import chromadb
|
||||
|
|
|
|||
|
|
@ -34,3 +34,12 @@ class EmbeddingEngine(Protocol):
|
|||
- int: An integer representing the number of dimensions in the embedding vector.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_batch_size(self) -> int:
|
||||
"""
|
||||
Return the desired batch size for embedding calls
|
||||
|
||||
Returns:
- int: The maximum number of texts to embed in a single request.
|
||||
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
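A hedged sketch of how a caller can honor the new method when embedding large collections (the helper itself is illustrative and not part of this diff):

from typing import List


async def embed_in_batches(engine, texts: List[str]) -> List[List[float]]:
    # Respect the engine's preferred batch size (provider limits differ, e.g. OpenAI vs Gemini)
    batch_size = engine.get_batch_size()
    embeddings: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        embeddings.extend(await engine.embed_text(texts[start:start + batch_size]))
    return embeddings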
|
|
|||
|
|
@ -42,11 +42,13 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
|
|||
model: Optional[str] = "openai/text-embedding-3-large",
|
||||
dimensions: Optional[int] = 3072,
|
||||
max_completion_tokens: int = 512,
|
||||
batch_size: int = 100,
|
||||
):
|
||||
self.model = model
|
||||
self.dimensions = dimensions
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.batch_size = batch_size
|
||||
# self.retry_count = 0
|
||||
self.embedding_model = TextEmbedding(model_name=model)
|
||||
|
||||
|
|
@@ -88,7 +90,9 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
|
|||
|
||||
except Exception as error:
|
||||
logger.error(f"Embedding error in FastembedEmbeddingEngine: {str(error)}")
|
||||
raise EmbeddingException(f"Failed to index data points using model {self.model}")
|
||||
raise EmbeddingException(
|
||||
f"Failed to index data points using model {self.model}"
|
||||
) from error
|
||||
|
||||
def get_vector_size(self) -> int:
|
||||
"""
|
||||
|
|
@@ -101,6 +105,15 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
|
|||
"""
|
||||
return self.dimensions
|
||||
|
||||
def get_batch_size(self) -> int:
|
||||
"""
|
||||
Return the desired batch size for embedding calls
|
||||
|
||||
Returns:
--------
- int: The configured batch size for this engine.
|
||||
|
||||
"""
|
||||
return self.batch_size
|
||||
|
||||
def get_tokenizer(self):
|
||||
"""
|
||||
Instantiate and return the tokenizer used for preparing text for embedding.
|
||||
|
|
|
|||
|
|
@@ -58,6 +58,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
endpoint: str = None,
|
||||
api_version: str = None,
|
||||
max_completion_tokens: int = 512,
|
||||
batch_size: int = 100,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.endpoint = endpoint
|
||||
|
|
@@ -68,6 +69,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
self.max_completion_tokens = max_completion_tokens
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.retry_count = 0
|
||||
self.batch_size = batch_size
|
||||
|
||||
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
|
||||
if isinstance(enable_mocking, bool):
|
||||
|
|
@@ -148,7 +150,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
litellm.exceptions.NotFoundError,
|
||||
) as e:
|
||||
logger.error(f"Embedding error with model {self.model}: {str(e)}")
|
||||
raise EmbeddingException(f"Failed to index data points using model {self.model}")
|
||||
raise EmbeddingException(f"Failed to index data points using model {self.model}") from e
|
||||
|
||||
except Exception as error:
|
||||
logger.error("Error embedding text: %s", str(error))
|
||||
|
|
@@ -165,6 +167,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
"""
|
||||
return self.dimensions
|
||||
|
||||
def get_batch_size(self) -> int:
|
||||
"""
|
||||
Return the desired batch size for embedding calls
|
||||
|
||||
Returns:
--------
- int: The configured batch size for this engine.
|
||||
|
||||
"""
|
||||
return self.batch_size
|
||||
|
||||
def get_tokenizer(self):
|
||||
"""
|
||||
Load and return the appropriate tokenizer for the specified model based on the provider.
|
||||
|
|
@@ -183,9 +194,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
model=model, max_completion_tokens=self.max_completion_tokens
|
||||
)
|
||||
elif "gemini" in self.provider.lower():
|
||||
tokenizer = GeminiTokenizer(
|
||||
model=model, max_completion_tokens=self.max_completion_tokens
|
||||
# Since Gemini tokenization needs to send an API request to get the token count, we will use TikToken to
|
||||
# count tokens instead, as we calculate tokens word by word
|
||||
tokenizer = TikTokenTokenizer(
|
||||
model=None, max_completion_tokens=self.max_completion_tokens
|
||||
)
|
||||
# Note: Gemini Tokenizer expects an LLM model as input and not the embedding model
|
||||
# tokenizer = GeminiTokenizer(
|
||||
# llm_model=llm_model, max_completion_tokens=self.max_completion_tokens
|
||||
# )
|
||||
elif "mistral" in self.provider.lower():
|
||||
tokenizer = MistralTokenizer(
|
||||
model=model, max_completion_tokens=self.max_completion_tokens
|
||||
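The comment above explains why Gemini embeddings now fall back to TikToken for counting: the count happens locally, with no API round-trip per text. A small sketch of purely local counting, assuming the `tiktoken` package is installed:

```python
import tiktoken

# A generic encoding works when no specific model is given (model=None in the adapter above).
encoding = tiktoken.get_encoding("cl100k_base")
print(len(encoding.encode("Counting tokens locally avoids one API call per chunk.")))
```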
|
|
|
|||
|
|
@@ -54,12 +54,14 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
|||
max_completion_tokens: int = 512,
|
||||
endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
|
||||
huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
|
||||
batch_size: int = 100,
|
||||
):
|
||||
self.model = model
|
||||
self.dimensions = dimensions
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.endpoint = endpoint
|
||||
self.huggingface_tokenizer_name = huggingface_tokenizer
|
||||
self.batch_size = batch_size
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
|
||||
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
|
||||
|
|
@@ -122,6 +124,15 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
|||
"""
|
||||
return self.dimensions
|
||||
|
||||
def get_batch_size(self) -> int:
|
||||
"""
|
||||
Return the desired batch size for embedding calls
|
||||
|
||||
Returns:
--------
- int: The configured batch size for this engine.
|
||||
|
||||
"""
|
||||
return self.batch_size
|
||||
|
||||
def get_tokenizer(self):
|
||||
"""
|
||||
Load and return a HuggingFace tokenizer for the embedding engine.
|
||||
|
|
|
|||
|
|
@@ -19,9 +19,17 @@ class EmbeddingConfig(BaseSettings):
|
|||
embedding_api_key: Optional[str] = None
|
||||
embedding_api_version: Optional[str] = None
|
||||
embedding_max_completion_tokens: Optional[int] = 8191
|
||||
embedding_batch_size: Optional[int] = None
|
||||
huggingface_tokenizer: Optional[str] = None
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
def model_post_init(self, __context) -> None:
|
||||
# If embedding batch size is not defined, use 2048 as the default for OpenAI and 100 for all other embedding providers
|
||||
if not self.embedding_batch_size and self.embedding_provider.lower() == "openai":
|
||||
self.embedding_batch_size = 2048
|
||||
elif not self.embedding_batch_size:
|
||||
self.embedding_batch_size = 100
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
Serialize all embedding configuration settings to a dictionary.
|
||||
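The `model_post_init` hook above resolves the batch-size default per provider. A minimal sketch of that defaulting in isolation, assuming `pydantic-settings` is installed; the field names mirror the diff, but the demo class is not cognee's:

```python
from typing import Optional
from pydantic_settings import BaseSettings

class EmbeddingConfigSketch(BaseSettings):
    embedding_provider: str = "openai"
    embedding_batch_size: Optional[int] = None

    def model_post_init(self, __context) -> None:
        # 2048 for OpenAI, 100 for every other provider, unless a size was set explicitly.
        if not self.embedding_batch_size and self.embedding_provider.lower() == "openai":
            self.embedding_batch_size = 2048
        elif not self.embedding_batch_size:
            self.embedding_batch_size = 100

print(EmbeddingConfigSketch().embedding_batch_size)                              # 2048
print(EmbeddingConfigSketch(embedding_provider="ollama").embedding_batch_size)   # 100
```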
|
|
|
|||
|
|
@@ -31,6 +31,7 @@ def get_embedding_engine() -> EmbeddingEngine:
|
|||
config.embedding_endpoint,
|
||||
config.embedding_api_key,
|
||||
config.embedding_api_version,
|
||||
config.embedding_batch_size,
|
||||
config.huggingface_tokenizer,
|
||||
llm_config.llm_api_key,
|
||||
llm_config.llm_provider,
|
||||
|
|
@@ -46,6 +47,7 @@ def create_embedding_engine(
|
|||
embedding_endpoint,
|
||||
embedding_api_key,
|
||||
embedding_api_version,
|
||||
embedding_batch_size,
|
||||
huggingface_tokenizer,
|
||||
llm_api_key,
|
||||
llm_provider,
|
||||
|
|
@@ -84,6 +86,7 @@ def create_embedding_engine(
|
|||
model=embedding_model,
|
||||
dimensions=embedding_dimensions,
|
||||
max_completion_tokens=embedding_max_completion_tokens,
|
||||
batch_size=embedding_batch_size,
|
||||
)
|
||||
|
||||
if embedding_provider == "ollama":
|
||||
|
|
@@ -95,6 +98,7 @@ def create_embedding_engine(
|
|||
max_completion_tokens=embedding_max_completion_tokens,
|
||||
endpoint=embedding_endpoint,
|
||||
huggingface_tokenizer=huggingface_tokenizer,
|
||||
batch_size=embedding_batch_size,
|
||||
)
|
||||
|
||||
from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
|
||||
|
|
@@ -108,4 +112,5 @@ def create_embedding_engine(
|
|||
model=embedding_model,
|
||||
dimensions=embedding_dimensions,
|
||||
max_completion_tokens=embedding_max_completion_tokens,
|
||||
batch_size=embedding_batch_size,
|
||||
)
|
||||
|
|
|
|||
|
|
@@ -125,41 +125,42 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
data_point_types = get_type_hints(DataPoint)
|
||||
vector_size = self.embedding_engine.get_vector_size()
|
||||
|
||||
async with self.VECTOR_DB_LOCK:
|
||||
if not await self.has_collection(collection_name):
|
||||
if not await self.has_collection(collection_name):
|
||||
async with self.VECTOR_DB_LOCK:
|
||||
if not await self.has_collection(collection_name):
|
||||
|
||||
class PGVectorDataPoint(Base):
|
||||
"""
|
||||
Represent a point in a vector data space with associated data and vector representation.
|
||||
class PGVectorDataPoint(Base):
|
||||
"""
|
||||
Represent a point in a vector data space with associated data and vector representation.
|
||||
|
||||
This class inherits from Base and is associated with a database table defined by
|
||||
__tablename__. It maintains the following public methods and instance variables:
|
||||
This class inherits from Base and is associated with a database table defined by
|
||||
__tablename__. It maintains the following public methods and instance variables:
|
||||
|
||||
- __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance.
|
||||
- __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance.
|
||||
|
||||
Instance variables:
|
||||
- id: Identifier for the data point, defined by data_point_types.
|
||||
- payload: JSON data associated with the data point.
|
||||
- vector: Vector representation of the data point, with size defined by vector_size.
|
||||
"""
|
||||
Instance variables:
|
||||
- id: Identifier for the data point, defined by data_point_types.
|
||||
- payload: JSON data associated with the data point.
|
||||
- vector: Vector representation of the data point, with size defined by vector_size.
|
||||
"""
|
||||
|
||||
__tablename__ = collection_name
|
||||
__table_args__ = {"extend_existing": True}
|
||||
# PGVector requires one column to be the primary key
|
||||
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
|
||||
payload = Column(JSON)
|
||||
vector = Column(self.Vector(vector_size))
|
||||
__tablename__ = collection_name
|
||||
__table_args__ = {"extend_existing": True}
|
||||
# PGVector requires one column to be the primary key
|
||||
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
|
||||
payload = Column(JSON)
|
||||
vector = Column(self.Vector(vector_size))
|
||||
|
||||
def __init__(self, id, payload, vector):
|
||||
self.id = id
|
||||
self.payload = payload
|
||||
self.vector = vector
|
||||
def __init__(self, id, payload, vector):
|
||||
self.id = id
|
||||
self.payload = payload
|
||||
self.vector = vector
|
||||
|
||||
async with self.engine.begin() as connection:
|
||||
if len(Base.metadata.tables.keys()) > 0:
|
||||
await connection.run_sync(
|
||||
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
|
||||
)
|
||||
async with self.engine.begin() as connection:
|
||||
if len(Base.metadata.tables.keys()) > 0:
|
||||
await connection.run_sync(
|
||||
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
|
||||
)
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(DeadlockDetectedError),
|
||||
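The reshuffled block above turns collection creation into a double-checked lock: a cheap check outside the lock, then a re-check once the lock is held, so only one coroutine performs the creation. A runnable sketch of the pattern (the set stands in for the real DDL work):

```python
import asyncio

created = set()
lock = asyncio.Lock()

async def ensure_collection(name: str) -> None:
    if name not in created:            # fast path, no lock contention
        async with lock:
            if name not in created:    # re-check: another coroutine may have created it meanwhile
                await asyncio.sleep(0)  # stand-in for the actual table creation
                created.add(name)

async def main():
    await asyncio.gather(*(ensure_collection("chunks") for _ in range(5)))
    print(created)  # {'chunks'}

asyncio.run(main())
```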
|
|
|
|||
|
|
@@ -39,7 +39,7 @@ class LLMConfig(BaseSettings):
|
|||
|
||||
structured_output_framework: str = "instructor"
|
||||
llm_provider: str = "openai"
|
||||
llm_model: str = "openai/gpt-4o-mini"
|
||||
llm_model: str = "openai/gpt-5-mini"
|
||||
llm_endpoint: str = ""
|
||||
llm_api_key: Optional[str] = None
|
||||
llm_api_version: Optional[str] = None
|
||||
|
|
@@ -48,7 +48,7 @@ class LLMConfig(BaseSettings):
|
|||
llm_max_completion_tokens: int = 16384
|
||||
|
||||
baml_llm_provider: str = "openai"
|
||||
baml_llm_model: str = "gpt-4o-mini"
|
||||
baml_llm_model: str = "gpt-5-mini"
|
||||
baml_llm_endpoint: str = ""
|
||||
baml_llm_api_key: Optional[str] = None
|
||||
baml_llm_temperature: float = 0.0
|
||||
|
|
|
|||
|
|
@@ -10,8 +10,6 @@ Here are the available `SearchType` tools and their specific functions:
|
|||
- Summarizing large amounts of information
|
||||
- Quick understanding of complex subjects
|
||||
|
||||
* **`INSIGHTS`**: The `INSIGHTS` search type discovers connections and relationships between entities in the knowledge graph.
|
||||
|
||||
**Best for:**
|
||||
|
||||
- Discovering how entities are connected
|
||||
|
|
@@ -95,9 +93,6 @@ Here are the available `SearchType` tools and their specific functions:
|
|||
Query: "Summarize the key findings from these research papers"
|
||||
Response: `SUMMARIES`
|
||||
|
||||
Query: "What is the relationship between the methodologies used in these papers?"
|
||||
Response: `INSIGHTS`
|
||||
|
||||
Query: "When was Einstein born?"
|
||||
Response: `CHUNKS`
|
||||
|
||||
|
|
|
|||
|
|
@@ -1,115 +1,155 @@
|
|||
import litellm
|
||||
from pydantic import BaseModel
|
||||
from typing import Type, Optional
|
||||
from litellm import acompletion, JSONSchemaValidationError
|
||||
"""Adapter for Generic API LLM provider API"""
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.observability.get_observe import get_observe
|
||||
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
|
||||
import litellm
|
||||
import instructor
|
||||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from openai import ContentFilterFinishReasonError
|
||||
from litellm.exceptions import ContentPolicyViolationError
|
||||
from instructor.core import InstructorRetryException
|
||||
|
||||
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||
LLMInterface,
|
||||
)
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
|
||||
rate_limit_async,
|
||||
sleep_and_retry_async,
|
||||
)
|
||||
|
||||
logger = get_logger()
|
||||
observe = get_observe()
|
||||
|
||||
|
||||
class GeminiAdapter(LLMInterface):
|
||||
"""
|
||||
Handles interactions with a language model API.
|
||||
Adapter for Gemini API LLM provider.
|
||||
|
||||
Public methods include:
|
||||
- acreate_structured_output
|
||||
- show_prompt
|
||||
This class initializes the API adapter with necessary credentials and configurations for
|
||||
interacting with the gemini LLM models. It provides methods for creating structured outputs
|
||||
based on user input and system prompts.
|
||||
|
||||
Public methods:
|
||||
- acreate_structured_output(text_input: str, system_prompt: str, response_model:
|
||||
Type[BaseModel]) -> BaseModel
|
||||
"""
|
||||
|
||||
MAX_RETRIES = 5
|
||||
name: str
|
||||
model: str
|
||||
api_key: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint,
|
||||
api_key: str,
|
||||
model: str,
|
||||
api_version: str,
|
||||
max_completion_tokens: int,
|
||||
endpoint: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
streaming: bool = False,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
fallback_model: str = None,
|
||||
fallback_api_key: str = None,
|
||||
fallback_endpoint: str = None,
|
||||
):
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.endpoint = endpoint
|
||||
self.api_version = api_version
|
||||
self.streaming = streaming
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
|
||||
@observe(as_type="generation")
|
||||
self.fallback_model = fallback_model
|
||||
self.fallback_api_key = fallback_api_key
|
||||
self.fallback_endpoint = fallback_endpoint
|
||||
|
||||
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
|
||||
|
||||
@sleep_and_retry_async()
|
||||
@rate_limit_async
|
||||
async def acreate_structured_output(
|
||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||
) -> BaseModel:
|
||||
"""
|
||||
Generate structured output from the language model based on the provided input and
|
||||
system prompt.
|
||||
Generate a response from a user query.
|
||||
|
||||
This method handles retries and raises a ValueError if the request fails or the response
|
||||
does not conform to the expected schema, logging errors accordingly.
|
||||
This asynchronous method sends a user query and a system prompt to a language model and
|
||||
retrieves the generated response. It handles API communication and retries up to a
|
||||
specified limit in case of request failures.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- text_input (str): The user input text to generate a response for.
|
||||
- system_prompt (str): The system's prompt or context to influence the language
|
||||
model's generation.
|
||||
- response_model (Type[BaseModel]): A model type indicating the expected format of
|
||||
the response.
|
||||
- text_input (str): The input text from the user to generate a response for.
|
||||
- system_prompt (str): A prompt that provides context or instructions for the
|
||||
response generation.
|
||||
- response_model (Type[BaseModel]): A Pydantic model that defines the structure of
|
||||
the expected response.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- BaseModel: Returns the generated response as an instance of the specified response
|
||||
model.
|
||||
- BaseModel: An instance of the specified response model containing the structured
|
||||
output from the language model.
|
||||
"""
|
||||
|
||||
try:
|
||||
if response_model is str:
|
||||
response_schema = {"type": "string"}
|
||||
else:
|
||||
response_schema = response_model
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
api_key=self.api_key,
|
||||
max_retries=5,
|
||||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
response_model=response_model,
|
||||
)
|
||||
except (
|
||||
ContentFilterFinishReasonError,
|
||||
ContentPolicyViolationError,
|
||||
InstructorRetryException,
|
||||
) as error:
|
||||
if (
|
||||
isinstance(error, InstructorRetryException)
|
||||
and "content management policy" not in str(error).lower()
|
||||
):
|
||||
raise error
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": text_input},
|
||||
]
|
||||
|
||||
try:
|
||||
response = await acompletion(
|
||||
model=f"{self.model}",
|
||||
messages=messages,
|
||||
api_key=self.api_key,
|
||||
max_completion_tokens=self.max_completion_tokens,
|
||||
temperature=0.1,
|
||||
response_format=response_schema,
|
||||
timeout=100,
|
||||
num_retries=self.MAX_RETRIES,
|
||||
if not (self.fallback_model and self.fallback_api_key and self.fallback_endpoint):
|
||||
raise ContentPolicyFilterError(
|
||||
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||
)
|
||||
|
||||
if response.choices and response.choices[0].message.content:
|
||||
content = response.choices[0].message.content
|
||||
if response_model is str:
|
||||
return content
|
||||
return response_model.model_validate_json(content)
|
||||
|
||||
except litellm.exceptions.BadRequestError as e:
|
||||
logger.error(f"Bad request error: {str(e)}")
|
||||
raise ValueError(f"Invalid request: {str(e)}")
|
||||
|
||||
raise ValueError("Failed to get valid response after retries")
|
||||
|
||||
except JSONSchemaValidationError as e:
|
||||
logger.error(f"Schema validation failed: {str(e)}")
|
||||
logger.debug(f"Raw response: {e.raw_response}")
|
||||
raise ValueError(f"Response failed schema validation: {str(e)}")
|
||||
try:
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.fallback_model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
max_retries=5,
|
||||
api_key=self.fallback_api_key,
|
||||
api_base=self.fallback_endpoint,
|
||||
response_model=response_model,
|
||||
)
|
||||
except (
|
||||
ContentFilterFinishReasonError,
|
||||
ContentPolicyViolationError,
|
||||
InstructorRetryException,
|
||||
) as error:
|
||||
if (
|
||||
isinstance(error, InstructorRetryException)
|
||||
and "content management policy" not in str(error).lower()
|
||||
):
|
||||
raise error
|
||||
else:
|
||||
raise ContentPolicyFilterError(
|
||||
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||
)
|
||||
|
|
|
|||
|
|
@@ -6,7 +6,7 @@ from typing import Type
|
|||
from pydantic import BaseModel
|
||||
from openai import ContentFilterFinishReasonError
|
||||
from litellm.exceptions import ContentPolicyViolationError
|
||||
from instructor.exceptions import InstructorRetryException
|
||||
from instructor.core import InstructorRetryException
|
||||
|
||||
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||
|
|
@@ -56,9 +56,7 @@ class GenericAPIAdapter(LLMInterface):
|
|||
self.fallback_api_key = fallback_api_key
|
||||
self.fallback_endpoint = fallback_endpoint
|
||||
|
||||
self.aclient = instructor.from_litellm(
|
||||
litellm.acompletion, mode=instructor.Mode.JSON, api_key=api_key
|
||||
)
|
||||
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
|
||||
|
||||
@sleep_and_retry_async()
|
||||
@rate_limit_async
|
||||
|
|
@@ -102,6 +100,7 @@ class GenericAPIAdapter(LLMInterface):
|
|||
},
|
||||
],
|
||||
max_retries=5,
|
||||
api_key=self.api_key,
|
||||
api_base=self.endpoint,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
|
@@ -119,7 +118,7 @@ class GenericAPIAdapter(LLMInterface):
|
|||
if not (self.fallback_model and self.fallback_api_key and self.fallback_endpoint):
|
||||
raise ContentPolicyFilterError(
|
||||
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||
)
|
||||
) from error
|
||||
|
||||
try:
|
||||
return await self.aclient.chat.completions.create(
|
||||
|
|
@@ -152,4 +151,4 @@ class GenericAPIAdapter(LLMInterface):
|
|||
else:
|
||||
raise ContentPolicyFilterError(
|
||||
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||
)
|
||||
) from error
|
||||
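The adapters above share one shape: try the primary model, and only on a content-policy style failure retry once against a configured fallback, re-raising with the original error chained when no fallback exists. A minimal runnable sketch of that pattern with illustrative names (not cognee's real call sites):

```python
import asyncio
from typing import Optional

class ContentPolicyFilterError(Exception):
    pass

async def call_model(name: str, text: str) -> str:
    if name == "primary":
        raise ContentPolicyFilterError("blocked by policy")
    return f"{name} handled: {text}"

async def structured_output(text: str, fallback: Optional[str] = "fallback") -> str:
    try:
        return await call_model("primary", text)
    except ContentPolicyFilterError as error:
        if not fallback:
            # Mirror the adapters above: re-raise with the original error chained.
            raise ContentPolicyFilterError(
                f"content is not aligned with our content policy: {text}"
            ) from error
        return await call_model(fallback, text)

print(asyncio.run(structured_output("hello")))  # fallback handled: hello
```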
|
|
|
|||
|
|
@@ -23,6 +23,7 @@ class LLMProvider(Enum):
|
|||
- ANTHROPIC: Represents the Anthropic provider.
|
||||
- CUSTOM: Represents a custom provider option.
|
||||
- GEMINI: Represents the Gemini provider.
|
||||
- MISTRAL: Represents the Mistral AI provider.
|
||||
"""
|
||||
|
||||
OPENAI = "openai"
|
||||
|
|
@@ -30,6 +31,7 @@ class LLMProvider(Enum):
|
|||
ANTHROPIC = "anthropic"
|
||||
CUSTOM = "custom"
|
||||
GEMINI = "gemini"
|
||||
MISTRAL = "mistral"
|
||||
|
||||
|
||||
def get_llm_client(raise_api_key_error: bool = True):
|
||||
|
|
@@ -143,7 +145,36 @@ def get_llm_client(raise_api_key_error: bool = True):
|
|||
max_completion_tokens=max_completion_tokens,
|
||||
endpoint=llm_config.llm_endpoint,
|
||||
api_version=llm_config.llm_api_version,
|
||||
streaming=llm_config.llm_streaming,
|
||||
)
|
||||
|
||||
elif provider == LLMProvider.MISTRAL:
|
||||
if llm_config.llm_api_key is None:
|
||||
raise LLMAPIKeyNotSetError()
|
||||
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
|
||||
MistralAdapter,
|
||||
)
|
||||
|
||||
return MistralAdapter(
|
||||
api_key=llm_config.llm_api_key,
|
||||
model=llm_config.llm_model,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
endpoint=llm_config.llm_endpoint,
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,129 @@
|
|||
import litellm
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
from typing import Type, Optional
|
||||
from litellm import acompletion, JSONSchemaValidationError
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.observability.get_observe import get_observe
|
||||
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||
LLMInterface,
|
||||
)
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
|
||||
rate_limit_async,
|
||||
sleep_and_retry_async,
|
||||
)
|
||||
|
||||
logger = get_logger()
|
||||
observe = get_observe()
|
||||
|
||||
|
||||
class MistralAdapter(LLMInterface):
|
||||
"""
|
||||
Adapter for Mistral AI API, for structured output generation and prompt display.
|
||||
|
||||
Public methods:
|
||||
- acreate_structured_output
|
||||
- show_prompt
|
||||
"""
|
||||
|
||||
name = "Mistral"
|
||||
model: str
|
||||
api_key: str
|
||||
max_completion_tokens: int
|
||||
|
||||
def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None):
|
||||
from mistralai import Mistral
|
||||
|
||||
self.model = model
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
|
||||
self.aclient = instructor.from_litellm(
|
||||
litellm.acompletion,
|
||||
mode=instructor.Mode.MISTRAL_TOOLS,
|
||||
api_key=get_llm_config().llm_api_key,
|
||||
)
|
||||
|
||||
@sleep_and_retry_async()
|
||||
@rate_limit_async
|
||||
async def acreate_structured_output(
|
||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||
) -> BaseModel:
|
||||
"""
|
||||
Generate a response from the user query.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
- text_input (str): The input text from the user to be processed.
|
||||
- system_prompt (str): A prompt that sets the context for the query.
|
||||
- response_model (Type[BaseModel]): The model to structure the response according to
|
||||
its format.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
- BaseModel: An instance of BaseModel containing the structured response.
|
||||
"""
|
||||
try:
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""Use the given format to extract information
|
||||
from the following input: {text_input}""",
|
||||
},
|
||||
]
|
||||
try:
|
||||
response = await self.aclient.chat.completions.create(
|
||||
model=self.model,
|
||||
max_tokens=self.max_completion_tokens,
|
||||
max_retries=5,
|
||||
messages=messages,
|
||||
response_model=response_model,
|
||||
)
|
||||
if response.choices and response.choices[0].message.content:
|
||||
content = response.choices[0].message.content
|
||||
return response_model.model_validate_json(content)
|
||||
else:
|
||||
raise ValueError("Failed to get valid response after retries")
|
||||
except litellm.exceptions.BadRequestError as e:
|
||||
logger.error(f"Bad request error: {str(e)}")
|
||||
raise ValueError(f"Invalid request: {str(e)}")
|
||||
|
||||
except JSONSchemaValidationError as e:
|
||||
logger.error(f"Schema validation failed: {str(e)}")
|
||||
logger.debug(f"Raw response: {e.raw_response}")
|
||||
raise ValueError(f"Response failed schema validation: {str(e)}")
|
||||
|
||||
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
||||
"""
|
||||
Format and display the prompt for a user query.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
- text_input (str): Input text from the user to be included in the prompt.
|
||||
- system_prompt (str): The system prompt that will be shown alongside the user input.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
- str: The formatted prompt string combining system prompt and user input.
|
||||
"""
|
||||
if not text_input:
|
||||
text_input = "No user input provided."
|
||||
if not system_prompt:
|
||||
raise MissingSystemPromptPathError()
|
||||
|
||||
system_prompt = LLMGateway.read_query_prompt(system_prompt)
|
||||
|
||||
formatted_prompt = (
|
||||
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
|
||||
if system_prompt
|
||||
else None
|
||||
)
|
||||
|
||||
return formatted_prompt
|
||||
|
|
@@ -5,15 +5,13 @@ from typing import Type
|
|||
from pydantic import BaseModel
|
||||
from openai import ContentFilterFinishReasonError
|
||||
from litellm.exceptions import ContentPolicyViolationError
|
||||
from instructor.exceptions import InstructorRetryException
|
||||
from instructor.core import InstructorRetryException
|
||||
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||
LLMInterface,
|
||||
)
|
||||
from cognee.infrastructure.llm.exceptions import (
|
||||
ContentPolicyFilterError,
|
||||
MissingSystemPromptPathError,
|
||||
)
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
|
||||
|
|
@@ -148,11 +146,11 @@ class OpenAIAdapter(LLMInterface):
|
|||
ContentFilterFinishReasonError,
|
||||
ContentPolicyViolationError,
|
||||
InstructorRetryException,
|
||||
):
|
||||
) as e:
|
||||
if not (self.fallback_model and self.fallback_api_key):
|
||||
raise ContentPolicyFilterError(
|
||||
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||
)
|
||||
) from e
|
||||
|
||||
try:
|
||||
return await self.aclient.chat.completions.create(
|
||||
|
|
@@ -185,7 +183,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
else:
|
||||
raise ContentPolicyFilterError(
|
||||
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||
)
|
||||
) from error
|
||||
|
||||
@observe
|
||||
@sleep_and_retry_sync()
|
||||
|
|
|
|||
|
|
@@ -3,6 +3,7 @@ from typing import List, Any
|
|||
from ..tokenizer_interface import TokenizerInterface
|
||||
|
||||
|
||||
# NOTE: DEPRECATED: counting tokens requires an API request to Google, which is too slow to use with Cognee
|
||||
class GeminiTokenizer(TokenizerInterface):
|
||||
"""
|
||||
Implements a tokenizer interface for the Gemini model, managing token extraction and
|
||||
|
|
@@ -16,10 +17,10 @@ class GeminiTokenizer(TokenizerInterface):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
llm_model: str,
|
||||
max_completion_tokens: int = 3072,
|
||||
):
|
||||
self.model = model
|
||||
self.llm_model = llm_model
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
|
||||
# Get LLM API key from config
|
||||
|
|
@@ -28,12 +29,11 @@ class GeminiTokenizer(TokenizerInterface):
|
|||
get_llm_config,
|
||||
)
|
||||
|
||||
config = get_embedding_config()
|
||||
llm_config = get_llm_config()
|
||||
|
||||
import google.generativeai as genai
|
||||
from google import genai
|
||||
|
||||
genai.configure(api_key=config.embedding_api_key or llm_config.llm_api_key)
|
||||
self.client = genai.Client(api_key=llm_config.llm_api_key)
|
||||
|
||||
def extract_tokens(self, text: str) -> List[Any]:
|
||||
"""
|
||||
|
|
@@ -77,6 +77,7 @@ class GeminiTokenizer(TokenizerInterface):
|
|||
|
||||
- int: The number of tokens in the given text.
|
||||
"""
|
||||
import google.generativeai as genai
|
||||
|
||||
return len(genai.embed_content(model=f"models/{self.model}", content=text))
|
||||
tokens_response = self.client.models.count_tokens(model=self.llm_model, contents=text)
|
||||
|
||||
return tokens_response.total_tokens
|
||||
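The tokenizer now goes through the google-genai SDK's `Client.models.count_tokens`, which returns a response carrying `total_tokens`. A hedged usage sketch (a real API key is required, and the model name here is only an example):

```python
from google import genai

client = genai.Client(api_key="YOUR_GEMINI_API_KEY")  # placeholder key
response = client.models.count_tokens(
    model="gemini-2.0-flash",  # example model; the tokenizer uses the configured llm_model
    contents="How many tokens is this sentence?",
)
print(response.total_tokens)
```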
|
|
|
|||
|
|
@@ -27,11 +27,11 @@ class LoaderEngine:
|
|||
|
||||
self.default_loader_priority = [
|
||||
"text_loader",
|
||||
"advanced_pdf_loader",
|
||||
"pypdf_loader",
|
||||
"image_loader",
|
||||
"audio_loader",
|
||||
"unstructured_loader",
|
||||
"advanced_pdf_loader",
|
||||
]
|
||||
|
||||
def register_loader(self, loader: LoaderInterface) -> bool:
|
||||
|
|
|
|||
|
|
@@ -9,7 +9,10 @@ async def get_dataset_data(dataset_id: UUID) -> list[Data]:
|
|||
|
||||
async with db_engine.get_async_session() as session:
|
||||
result = await session.execute(
|
||||
select(Data).join(Data.datasets).filter((Dataset.id == dataset_id))
|
||||
select(Data)
|
||||
.join(Data.datasets)
|
||||
.filter((Dataset.id == dataset_id))
|
||||
.order_by(Data.data_size.desc())
|
||||
)
|
||||
|
||||
data = list(result.scalars().all())
|
||||
|
|
|
|||
92
cognee/modules/data/methods/get_deletion_counts.py
Normal file
92
cognee/modules/data/methods/get_deletion_counts.py
Normal file
|
|
@@ -0,0 +1,92 @@
|
|||
from uuid import UUID
|
||||
from cognee.cli.exceptions import CliCommandException
|
||||
from cognee.infrastructure.databases.exceptions.exceptions import EntityNotFoundError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.sql import func
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.data.models import Dataset, Data, DatasetData
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_user
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeletionCountsPreview:
|
||||
datasets: int = 0
|
||||
data_entries: int = 0
|
||||
users: int = 0
|
||||
|
||||
|
||||
async def get_deletion_counts(
|
||||
dataset_name: str = None, user_id: str = None, all_data: bool = False
|
||||
) -> DeletionCountsPreview:
|
||||
"""
|
||||
Calculates the number of items that will be deleted based on the provided arguments.
|
||||
"""
|
||||
counts = DeletionCountsPreview()
|
||||
relational_engine = get_relational_engine()
|
||||
async with relational_engine.get_async_session() as session:
|
||||
if dataset_name:
|
||||
# Find the dataset by name
|
||||
dataset_result = await session.execute(
|
||||
select(Dataset).where(Dataset.name == dataset_name)
|
||||
)
|
||||
dataset = dataset_result.scalar_one_or_none()
|
||||
|
||||
if dataset is None:
|
||||
raise CliCommandException(
|
||||
f"No Dataset exists with the name {dataset_name}", error_code=1
|
||||
)
|
||||
|
||||
# Count data entries linked to this dataset
|
||||
count_query = (
|
||||
select(func.count())
|
||||
.select_from(DatasetData)
|
||||
.where(DatasetData.dataset_id == dataset.id)
|
||||
)
|
||||
data_entry_count = (await session.execute(count_query)).scalar_one()
|
||||
counts.users = 1
|
||||
counts.datasets = 1
|
||||
counts.data_entries = data_entry_count
|
||||
return counts
|
||||
|
||||
elif all_data:
|
||||
# Simplified logic: Get total counts directly from the tables.
|
||||
counts.datasets = (
|
||||
await session.execute(select(func.count()).select_from(Dataset))
|
||||
).scalar_one()
|
||||
counts.data_entries = (
|
||||
await session.execute(select(func.count()).select_from(Data))
|
||||
).scalar_one()
|
||||
counts.users = (
|
||||
await session.execute(select(func.count()).select_from(User))
|
||||
).scalar_one()
|
||||
return counts
|
||||
|
||||
# Count datasets and data entries owned by the given user
|
||||
elif user_id:
|
||||
user = None
|
||||
try:
|
||||
user_uuid = UUID(user_id)
|
||||
user = await get_user(user_uuid)
|
||||
except (ValueError, EntityNotFoundError):
|
||||
raise CliCommandException(f"No User exists with ID {user_id}", error_code=1)
|
||||
counts.users = 1
|
||||
# Find all datasets owned by this user
|
||||
datasets_query = select(Dataset).where(Dataset.owner_id == user.id)
|
||||
user_datasets = (await session.execute(datasets_query)).scalars().all()
|
||||
dataset_count = len(user_datasets)
|
||||
counts.datasets = dataset_count
|
||||
if dataset_count > 0:
|
||||
dataset_ids = [d.id for d in user_datasets]
|
||||
# Count all data entries across all of the user's datasets
|
||||
data_count_query = (
|
||||
select(func.count())
|
||||
.select_from(DatasetData)
|
||||
.where(DatasetData.dataset_id.in_(dataset_ids))
|
||||
)
|
||||
data_entry_count = (await session.execute(data_count_query)).scalar_one()
|
||||
counts.data_entries = data_entry_count
|
||||
else:
|
||||
counts.data_entries = 0
|
||||
return counts
|
||||
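A hypothetical call into the deletion-count preview added above; it requires a configured Cognee relational database and an existing dataset, so the run line is left commented out:

```python
import asyncio
from cognee.modules.data.methods.get_deletion_counts import get_deletion_counts

async def preview() -> None:
    counts = await get_deletion_counts(dataset_name="my_dataset")  # "my_dataset" is a placeholder
    print(f"Would delete {counts.datasets} dataset(s) and {counts.data_entries} data entries.")

# asyncio.run(preview())
```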
|
|
@@ -5,7 +5,6 @@ from typing import Optional
|
|||
|
||||
class TableRow(DataPoint):
|
||||
name: str
|
||||
is_a: Optional[TableType] = None
|
||||
description: str
|
||||
properties: str
|
||||
|
||||
|
|
|
|||
|
|
@@ -1,47 +1,31 @@
|
|||
from uuid import UUID
|
||||
from typing import Optional
|
||||
|
||||
from cognee.api.v1.exceptions import DatasetNotFoundError
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.data.methods import (
|
||||
create_authorized_dataset,
|
||||
get_authorized_dataset,
|
||||
get_authorized_dataset_by_name,
|
||||
from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
|
||||
resolve_authorized_user_datasets,
|
||||
)
|
||||
|
||||
|
||||
async def resolve_authorized_user_dataset(dataset_id: UUID, dataset_name: str, user: User):
|
||||
async def resolve_authorized_user_dataset(
|
||||
dataset_name: str, dataset_id: Optional[UUID] = None, user: Optional[User] = None
|
||||
):
|
||||
"""
|
||||
Function handles creation and dataset authorization if dataset already exist for Cognee.
|
||||
Verifies that the provided user has the necessary permission for the provided Dataset.
|
||||
If the Dataset does not exist, it is created and the creating user is granted permission on it.
|
||||
|
||||
Args:
|
||||
dataset_id: Id of the dataset.
|
||||
dataset_name: Name of the dataset.
|
||||
dataset_id: Id of the dataset.
|
||||
user: Cognee User the request is being processed for; if None, the default user will be used.
|
||||
|
||||
Returns:
|
||||
Tuple[User, Dataset]: A tuple containing the user and the authorized dataset.
|
||||
"""
|
||||
if not user:
|
||||
user = await get_default_user()
|
||||
|
||||
if dataset_id:
|
||||
authorized_dataset = await get_authorized_dataset(user, dataset_id, "write")
|
||||
elif dataset_name:
|
||||
authorized_dataset = await get_authorized_dataset_by_name(dataset_name, user, "write")
|
||||
user, authorized_datasets = await resolve_authorized_user_datasets(
|
||||
datasets=dataset_id if dataset_id else dataset_name, user=user
|
||||
)
|
||||
|
||||
if not authorized_dataset:
|
||||
authorized_dataset = await create_authorized_dataset(
|
||||
dataset_name=dataset_name, user=user
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either dataset_id or dataset_name must be provided.")
|
||||
|
||||
if not authorized_dataset:
|
||||
raise DatasetNotFoundError(
|
||||
message=f"Dataset ({str(dataset_id) or dataset_name}) not found."
|
||||
)
|
||||
|
||||
return user, authorized_dataset
|
||||
return user, authorized_datasets[0]
|
||||
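A hypothetical usage of the rewritten helper: with only a dataset name it resolves the default user and returns the single authorized (or newly created) dataset. The import path below is assumed, since the file path is not shown in this hunk, and a configured Cognee environment is required:

```python
import asyncio
from cognee.modules.pipelines.layers.resolve_authorized_user_dataset import (  # assumed module path
    resolve_authorized_user_dataset,
)

async def main() -> None:
    user, dataset = await resolve_authorized_user_dataset(dataset_name="my_dataset")  # placeholder name
    print(user.id, dataset.id)

# asyncio.run(main())
```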
|
|
|
|||
|
|
@@ -1,5 +1,5 @@
|
|||
from uuid import UUID
|
||||
from typing import Union, Tuple, List
|
||||
from typing import Union, Tuple, List, Optional
|
||||
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.models import User
|
||||
|
|
@@ -13,7 +13,7 @@ from cognee.modules.data.methods import (
|
|||
|
||||
|
||||
async def resolve_authorized_user_datasets(
|
||||
datasets: Union[str, UUID, list[str], list[UUID]], user: User = None
|
||||
datasets: Union[str, UUID, list[str], list[UUID]], user: Optional[User] = None
|
||||
) -> Tuple[User, List[Dataset]]:
|
||||
"""
|
||||
Function handles creation and dataset authorization if datasets already exist for Cognee.
|
||||
|
|
|
|||
|
|
@@ -4,35 +4,28 @@ import asyncio
|
|||
from uuid import UUID
|
||||
from typing import Any, List
|
||||
from functools import wraps
|
||||
from sqlalchemy import select
|
||||
|
||||
import cognee.modules.ingestion as ingestion
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.pipelines.operations.run_tasks_distributed import run_tasks_distributed
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.pipelines.utils import generate_pipeline_id
|
||||
from cognee.modules.pipelines.exceptions import PipelineRunFailedError
|
||||
from cognee.tasks.ingestion import save_data_item_to_storage, resolve_data_directories
|
||||
from cognee.tasks.ingestion import resolve_data_directories
|
||||
from cognee.modules.pipelines.models.PipelineRunInfo import (
|
||||
PipelineRunCompleted,
|
||||
PipelineRunErrored,
|
||||
PipelineRunStarted,
|
||||
PipelineRunYield,
|
||||
PipelineRunAlreadyCompleted,
|
||||
)
|
||||
from cognee.modules.pipelines.models.DataItemStatus import DataItemStatus
|
||||
|
||||
from cognee.modules.pipelines.operations import (
|
||||
log_pipeline_run_start,
|
||||
log_pipeline_run_complete,
|
||||
log_pipeline_run_error,
|
||||
)
|
||||
from .run_tasks_with_telemetry import run_tasks_with_telemetry
|
||||
from .run_tasks_data_item import run_tasks_data_item
|
||||
from ..tasks.task import Task
|
||||
|
||||
|
||||
|
|
@@ -68,176 +61,6 @@ async def run_tasks(
|
|||
context: dict = None,
|
||||
incremental_loading: bool = False,
|
||||
):
|
||||
async def _run_tasks_data_item_incremental(
|
||||
data_item,
|
||||
dataset,
|
||||
tasks,
|
||||
pipeline_name,
|
||||
pipeline_id,
|
||||
pipeline_run_id,
|
||||
context,
|
||||
user,
|
||||
):
|
||||
db_engine = get_relational_engine()
|
||||
# If incremental_loading of data is set to True don't process documents already processed by pipeline
|
||||
# If data is being added to Cognee for the first time calculate the id of the data
|
||||
if not isinstance(data_item, Data):
|
||||
file_path = await save_data_item_to_storage(data_item)
|
||||
# Ingest data and add metadata
|
||||
async with open_data_file(file_path) as file:
|
||||
classified_data = ingestion.classify(file)
|
||||
# data_id is the hash of file contents + owner id to avoid duplicate data
|
||||
data_id = ingestion.identify(classified_data, user)
|
||||
else:
|
||||
# If data was already processed by Cognee get data id
|
||||
data_id = data_item.id
|
||||
|
||||
# Check pipeline status, if Data already processed for pipeline before skip current processing
|
||||
async with db_engine.get_async_session() as session:
|
||||
data_point = (
|
||||
await session.execute(select(Data).filter(Data.id == data_id))
|
||||
).scalar_one_or_none()
|
||||
if data_point:
|
||||
if (
|
||||
data_point.pipeline_status.get(pipeline_name, {}).get(str(dataset.id))
|
||||
== DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
|
||||
):
|
||||
yield {
|
||||
"run_info": PipelineRunAlreadyCompleted(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
),
|
||||
"data_id": data_id,
|
||||
}
|
||||
return
|
||||
|
||||
try:
|
||||
# Process data based on data_item and list of tasks
|
||||
async for result in run_tasks_with_telemetry(
|
||||
tasks=tasks,
|
||||
data=[data_item],
|
||||
user=user,
|
||||
pipeline_name=pipeline_id,
|
||||
context=context,
|
||||
):
|
||||
yield PipelineRunYield(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
payload=result,
|
||||
)
|
||||
|
||||
# Update pipeline status for Data element
|
||||
async with db_engine.get_async_session() as session:
|
||||
data_point = (
|
||||
await session.execute(select(Data).filter(Data.id == data_id))
|
||||
).scalar_one_or_none()
|
||||
data_point.pipeline_status[pipeline_name] = {
|
||||
str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
|
||||
}
|
||||
await session.merge(data_point)
|
||||
await session.commit()
|
||||
|
||||
yield {
|
||||
"run_info": PipelineRunCompleted(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
),
|
||||
"data_id": data_id,
|
||||
}
|
||||
|
||||
except Exception as error:
|
||||
# Temporarily swallow error and try to process rest of documents first, then re-raise error at end of data ingestion pipeline
|
||||
logger.error(
|
||||
f"Exception caught while processing data: {error}.\n Data processing failed for data item: {data_item}."
|
||||
)
|
||||
yield {
|
||||
"run_info": PipelineRunErrored(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
payload=repr(error),
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
),
|
||||
"data_id": data_id,
|
||||
}
|
||||
|
||||
if os.getenv("RAISE_INCREMENTAL_LOADING_ERRORS", "true").lower() == "true":
|
||||
raise error
|
||||
|
||||
async def _run_tasks_data_item_regular(
|
||||
data_item,
|
||||
dataset,
|
||||
tasks,
|
||||
pipeline_id,
|
||||
pipeline_run_id,
|
||||
context,
|
||||
user,
|
||||
):
|
||||
# Process data based on data_item and list of tasks
|
||||
async for result in run_tasks_with_telemetry(
|
||||
tasks=tasks,
|
||||
data=[data_item],
|
||||
user=user,
|
||||
pipeline_name=pipeline_id,
|
||||
context=context,
|
||||
):
|
||||
yield PipelineRunYield(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
payload=result,
|
||||
)
|
||||
|
||||
yield {
|
||||
"run_info": PipelineRunCompleted(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
)
|
||||
}
|
||||
|
||||
async def _run_tasks_data_item(
|
||||
data_item,
|
||||
dataset,
|
||||
tasks,
|
||||
pipeline_name,
|
||||
pipeline_id,
|
||||
pipeline_run_id,
|
||||
context,
|
||||
user,
|
||||
incremental_loading,
|
||||
):
|
||||
# Go through async generator and return data item processing result. Result can be PipelineRunAlreadyCompleted when data item is skipped,
|
||||
# PipelineRunCompleted when processing was successful and PipelineRunErrored if there were issues
|
||||
result = None
|
||||
if incremental_loading:
|
||||
async for result in _run_tasks_data_item_incremental(
|
||||
data_item=data_item,
|
||||
dataset=dataset,
|
||||
tasks=tasks,
|
||||
pipeline_name=pipeline_name,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
context=context,
|
||||
user=user,
|
||||
):
|
||||
pass
|
||||
else:
|
||||
async for result in _run_tasks_data_item_regular(
|
||||
data_item=data_item,
|
||||
dataset=dataset,
|
||||
tasks=tasks,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
context=context,
|
||||
user=user,
|
||||
):
|
||||
pass
|
||||
|
||||
return result
|
||||
|
||||
if not user:
|
||||
user = await get_default_user()
|
||||
|
||||
|
|
@@ -269,7 +92,7 @@ async def run_tasks(
|
|||
# Create async tasks per data item that will run the pipeline for the data item
|
||||
data_item_tasks = [
|
||||
asyncio.create_task(
|
||||
_run_tasks_data_item(
|
||||
run_tasks_data_item(
|
||||
data_item,
|
||||
dataset,
|
||||
tasks,
|
||||
|
|
|
|||
261
cognee/modules/pipelines/operations/run_tasks_data_item.py
Normal file
261
cognee/modules/pipelines/operations/run_tasks_data_item.py
Normal file
|
|
@@ -0,0 +1,261 @@
|
|||
"""
|
||||
Data item processing functions for pipeline operations.
|
||||
|
||||
This module contains reusable functions for processing individual data items
|
||||
within pipeline operations, supporting both incremental and regular processing modes.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, AsyncGenerator, Optional
|
||||
from sqlalchemy import select
|
||||
|
||||
import cognee.modules.ingestion as ingestion
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.data.models import Data, Dataset
|
||||
from cognee.tasks.ingestion import save_data_item_to_storage
|
||||
from cognee.modules.pipelines.models.PipelineRunInfo import (
|
||||
PipelineRunCompleted,
|
||||
PipelineRunErrored,
|
||||
PipelineRunYield,
|
||||
PipelineRunAlreadyCompleted,
|
||||
)
|
||||
from cognee.modules.pipelines.models.DataItemStatus import DataItemStatus
|
||||
from cognee.modules.pipelines.operations.run_tasks_with_telemetry import run_tasks_with_telemetry
|
||||
from ..tasks.task import Task
|
||||
|
||||
logger = get_logger("run_tasks_data_item")
|
||||
|
||||
|
||||
async def run_tasks_data_item_incremental(
|
||||
data_item: Any,
|
||||
dataset: Dataset,
|
||||
tasks: list[Task],
|
||||
pipeline_name: str,
|
||||
pipeline_id: str,
|
||||
pipeline_run_id: str,
|
||||
context: Optional[Dict[str, Any]],
|
||||
user: User,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
Process a single data item with incremental loading support.
|
||||
|
||||
This function handles incremental processing by checking if the data item
|
||||
has already been processed for the given pipeline and dataset. If it has,
|
||||
it skips processing and returns a completion status.
|
||||
|
||||
Args:
|
||||
data_item: The data item to process
|
||||
dataset: The dataset containing the data item
|
||||
tasks: List of tasks to execute on the data item
|
||||
pipeline_name: Name of the pipeline
|
||||
pipeline_id: Unique identifier for the pipeline
|
||||
pipeline_run_id: Unique identifier for this pipeline run
|
||||
context: Optional context dictionary
|
||||
user: User performing the operation
|
||||
|
||||
Yields:
|
||||
Dict containing run_info and data_id for each processing step
|
||||
"""
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
# If incremental_loading of data is set to True don't process documents already processed by pipeline
|
||||
# If data is being added to Cognee for the first time calculate the id of the data
|
||||
if not isinstance(data_item, Data):
|
||||
file_path = await save_data_item_to_storage(data_item)
|
||||
# Ingest data and add metadata
|
||||
async with open_data_file(file_path) as file:
|
||||
classified_data = ingestion.classify(file)
|
||||
# data_id is the hash of file contents + owner id to avoid duplicate data
|
||||
data_id = ingestion.identify(classified_data, user)
|
||||
else:
|
||||
# If data was already processed by Cognee get data id
|
||||
data_id = data_item.id
|
||||
|
||||
# Check pipeline status, if Data already processed for pipeline before skip current processing
|
||||
async with db_engine.get_async_session() as session:
|
||||
data_point = (
|
||||
await session.execute(select(Data).filter(Data.id == data_id))
|
||||
).scalar_one_or_none()
|
||||
if data_point:
|
||||
if (
|
||||
data_point.pipeline_status.get(pipeline_name, {}).get(str(dataset.id))
|
||||
== DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
|
||||
):
|
||||
yield {
|
||||
"run_info": PipelineRunAlreadyCompleted(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
),
|
||||
"data_id": data_id,
|
||||
}
|
||||
return
|
||||
|
||||
try:
|
||||
# Process data based on data_item and list of tasks
|
||||
async for result in run_tasks_with_telemetry(
|
||||
tasks=tasks,
|
||||
data=[data_item],
|
||||
user=user,
|
||||
pipeline_name=pipeline_id,
|
||||
context=context,
|
||||
):
|
||||
yield PipelineRunYield(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
payload=result,
|
||||
)
|
||||
|
||||
# Update pipeline status for Data element
|
||||
async with db_engine.get_async_session() as session:
|
||||
data_point = (
|
||||
await session.execute(select(Data).filter(Data.id == data_id))
|
||||
).scalar_one_or_none()
|
||||
data_point.pipeline_status[pipeline_name] = {
|
||||
str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
|
||||
}
|
||||
await session.merge(data_point)
|
||||
await session.commit()
|
||||
|
||||
yield {
|
||||
"run_info": PipelineRunCompleted(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
),
|
||||
"data_id": data_id,
|
||||
}
|
||||
|
||||
except Exception as error:
|
||||
# Temporarily swallow error and try to process rest of documents first, then re-raise error at end of data ingestion pipeline
|
||||
logger.error(
|
||||
f"Exception caught while processing data: {error}.\n Data processing failed for data item: {data_item}."
|
||||
)
|
||||
yield {
|
||||
"run_info": PipelineRunErrored(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
payload=repr(error),
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
),
|
||||
"data_id": data_id,
|
||||
}
|
||||
|
||||
if os.getenv("RAISE_INCREMENTAL_LOADING_ERRORS", "true").lower() == "true":
|
||||
raise error
|
||||
|
||||
|
||||
async def run_tasks_data_item_regular(
|
||||
data_item: Any,
|
||||
dataset: Dataset,
|
||||
tasks: list[Task],
|
||||
pipeline_id: str,
|
||||
pipeline_run_id: str,
|
||||
context: Optional[Dict[str, Any]],
|
||||
user: User,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
Process a single data item in regular (non-incremental) mode.
|
||||
|
||||
This function processes a data item without checking for previous processing
|
||||
status, executing all tasks on the data item.
|
||||
|
||||
Args:
|
||||
data_item: The data item to process
|
||||
dataset: The dataset containing the data item
|
||||
tasks: List of tasks to execute on the data item
|
||||
pipeline_id: Unique identifier for the pipeline
|
||||
pipeline_run_id: Unique identifier for this pipeline run
|
||||
context: Optional context dictionary
|
||||
user: User performing the operation
|
||||
|
||||
Yields:
|
||||
Dict containing run_info for each processing step
|
||||
"""
|
||||
# Process data based on data_item and list of tasks
|
||||
async for result in run_tasks_with_telemetry(
|
||||
tasks=tasks,
|
||||
data=[data_item],
|
||||
user=user,
|
||||
pipeline_name=pipeline_id,
|
||||
context=context,
|
||||
):
|
||||
yield PipelineRunYield(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
payload=result,
|
||||
)
|
||||
|
||||
yield {
|
||||
"run_info": PipelineRunCompleted(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
async def run_tasks_data_item(
|
||||
data_item: Any,
|
||||
dataset: Dataset,
|
||||
tasks: list[Task],
|
||||
pipeline_name: str,
|
||||
pipeline_id: str,
|
||||
pipeline_run_id: str,
|
||||
context: Optional[Dict[str, Any]],
|
||||
user: User,
|
||||
incremental_loading: bool,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Process a single data item, choosing between incremental and regular processing.
|
||||
|
||||
This is the main entry point for data item processing that delegates to either
|
||||
incremental or regular processing based on the incremental_loading flag.
|
||||
|
||||
Args:
|
||||
data_item: The data item to process
|
||||
dataset: The dataset containing the data item
|
||||
tasks: List of tasks to execute on the data item
|
||||
pipeline_name: Name of the pipeline
|
||||
pipeline_id: Unique identifier for the pipeline
|
||||
pipeline_run_id: Unique identifier for this pipeline run
|
||||
context: Optional context dictionary
|
||||
user: User performing the operation
|
||||
incremental_loading: Whether to use incremental processing
|
||||
|
||||
Returns:
|
||||
Dict containing the final processing result, or None if processing was skipped
|
||||
"""
|
||||
# Go through async generator and return data item processing result. Result can be PipelineRunAlreadyCompleted when data item is skipped,
|
||||
# PipelineRunCompleted when processing was successful and PipelineRunErrored if there were issues
|
||||
result = None
|
||||
if incremental_loading:
|
||||
async for result in run_tasks_data_item_incremental(
|
||||
data_item=data_item,
|
||||
dataset=dataset,
|
||||
tasks=tasks,
|
||||
pipeline_name=pipeline_name,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
context=context,
|
||||
user=user,
|
||||
):
|
||||
pass
|
||||
else:
|
||||
async for result in run_tasks_data_item_regular(
|
||||
data_item=data_item,
|
||||
dataset=dataset,
|
||||
tasks=tasks,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
context=context,
|
||||
user=user,
|
||||
):
|
||||
pass
|
||||
|
||||
return result
|
||||
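The dispatcher above drains the per-item async generator and returns only the last yielded status. A minimal standalone sketch of that "drain the generator, keep the final yield" pattern:

```python
import asyncio
from typing import AsyncGenerator, Optional

async def steps() -> AsyncGenerator[str, None]:
    yield "started"
    yield "completed"   # the caller only cares about this final status

async def last_result() -> Optional[str]:
    result = None
    async for result in steps():
        pass            # intermediate yields are intentionally discarded
    return result

print(asyncio.run(last_result()))  # completed
```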
|
|
@@ -3,49 +3,96 @@ try:
|
|||
except ModuleNotFoundError:
|
||||
modal = None
|
||||
|
||||
from typing import Any, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.pipelines.models import (
|
||||
PipelineRunStarted,
|
||||
PipelineRunYield,
|
||||
PipelineRunCompleted,
|
||||
PipelineRunErrored,
|
||||
)
|
||||
from cognee.modules.pipelines.operations import log_pipeline_run_start, log_pipeline_run_complete
|
||||
from cognee.modules.pipelines.utils.generate_pipeline_id import generate_pipeline_id
|
||||
from cognee.modules.pipelines.operations import (
|
||||
log_pipeline_run_start,
|
||||
log_pipeline_run_complete,
|
||||
log_pipeline_run_error,
|
||||
)
|
||||
from cognee.modules.pipelines.utils import generate_pipeline_id
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from .run_tasks_with_telemetry import run_tasks_with_telemetry
|
||||
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.pipelines.exceptions import PipelineRunFailedError
|
||||
from cognee.tasks.ingestion import resolve_data_directories
|
||||
from .run_tasks_data_item import run_tasks_data_item
|
||||
|
||||
logger = get_logger("run_tasks_distributed()")
|
||||
|
||||
|
||||
if modal:
|
||||
import os
|
||||
from distributed.app import app
|
||||
from distributed.modal_image import image
|
||||
|
||||
secret_name = os.environ.get("MODAL_SECRET_NAME", "distributed_cognee")
|
||||
|
||||
@app.function(
|
||||
retries=3,
|
||||
image=image,
|
||||
timeout=86400,
|
||||
max_containers=50,
|
||||
secrets=[modal.Secret.from_name("distributed_cognee")],
|
||||
secrets=[modal.Secret.from_name(secret_name)],
|
||||
)
|
||||
async def run_tasks_on_modal(tasks, data_item, user, pipeline_name, context):
|
||||
pipeline_run = run_tasks_with_telemetry(tasks, data_item, user, pipeline_name, context)
|
||||
async def run_tasks_on_modal(
|
||||
data_item,
|
||||
dataset_id: UUID,
|
||||
tasks: List[Task],
|
||||
pipeline_name: str,
|
||||
pipeline_id: str,
|
||||
pipeline_run_id: str,
|
||||
context: Optional[dict],
|
||||
user: User,
|
||||
incremental_loading: bool,
|
||||
):
|
||||
"""
|
||||
Wrapper that runs the run_tasks_data_item function.
|
||||
This is the function/code that runs on modal executor and produces the graph/vector db objects
|
||||
"""
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
run_info = None
|
||||
async with get_relational_engine().get_async_session() as session:
|
||||
from cognee.modules.data.models import Dataset
|
||||
|
||||
async for pipeline_run_info in pipeline_run:
|
||||
run_info = pipeline_run_info
|
||||
dataset = await session.get(Dataset, dataset_id)
|
||||
|
||||
return run_info
|
||||
result = await run_tasks_data_item(
|
||||
data_item=data_item,
|
||||
dataset=dataset,
|
||||
tasks=tasks,
|
||||
pipeline_name=pipeline_name,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
context=context,
|
||||
user=user,
|
||||
incremental_loading=incremental_loading,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, context):
|
||||
async def run_tasks_distributed(
|
||||
tasks: List[Task],
|
||||
dataset_id: UUID,
|
||||
data: List[Any] = None,
|
||||
user: User = None,
|
||||
pipeline_name: str = "unknown_pipeline",
|
||||
context: dict = None,
|
||||
incremental_loading: bool = False,
|
||||
):
|
||||
if not user:
|
||||
user = await get_default_user()
|
||||
|
||||
# Get dataset object
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
from cognee.modules.data.models import Dataset
|
||||
|
|
@ -53,9 +100,7 @@ async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, co
|
|||
dataset = await session.get(Dataset, dataset_id)
|
||||
|
||||
pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
|
||||
|
||||
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
|
||||
|
||||
pipeline_run_id = pipeline_run.pipeline_run_id
|
||||
|
||||
yield PipelineRunStarted(
|
||||
|
|
@ -65,30 +110,67 @@ async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, co
|
|||
payload=data,
|
||||
)
|
||||
|
||||
data_count = len(data) if isinstance(data, list) else 1
|
||||
try:
|
||||
if not isinstance(data, list):
|
||||
data = [data]
|
||||
|
||||
arguments = [
|
||||
[tasks] * data_count,
|
||||
[[data_item] for data_item in data[:data_count]] if data_count > 1 else [data],
|
||||
[user] * data_count,
|
||||
[pipeline_name] * data_count,
|
||||
[context] * data_count,
|
||||
]
|
||||
data = await resolve_data_directories(data)
|
||||
|
||||
async for result in run_tasks_on_modal.map.aio(*arguments):
|
||||
logger.info(f"Received result: {result}")
|
||||
number_of_data_items = len(data) if isinstance(data, list) else 1
|
||||
|
||||
yield PipelineRunYield(
|
||||
data_item_tasks = [
|
||||
data,
|
||||
[dataset.id] * number_of_data_items,
|
||||
[tasks] * number_of_data_items,
|
||||
[pipeline_name] * number_of_data_items,
|
||||
[pipeline_id] * number_of_data_items,
|
||||
[pipeline_run_id] * number_of_data_items,
|
||||
[context] * number_of_data_items,
|
||||
[user] * number_of_data_items,
|
||||
[incremental_loading] * number_of_data_items,
|
||||
]
|
||||
|
||||
results = []
|
||||
async for result in run_tasks_on_modal.map.aio(*data_item_tasks):
|
||||
if not result:
|
||||
continue
|
||||
results.append(result)
|
||||
|
||||
# Remove skipped results
|
||||
results = [r for r in results if r]
|
||||
|
||||
# If any data item failed, raise PipelineRunFailedError
|
||||
errored = [
|
||||
r
|
||||
for r in results
|
||||
if r and r.get("run_info") and isinstance(r["run_info"], PipelineRunErrored)
|
||||
]
|
||||
if errored:
|
||||
raise PipelineRunFailedError("Pipeline run failed. Data item could not be processed.")
|
||||
|
||||
await log_pipeline_run_complete(
|
||||
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data
|
||||
)
|
||||
|
||||
yield PipelineRunCompleted(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
payload=result,
|
||||
data_ingestion_info=results,
|
||||
)
|
||||
|
||||
await log_pipeline_run_complete(pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data)
|
||||
except Exception as error:
|
||||
await log_pipeline_run_error(
|
||||
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, error
|
||||
)
|
||||
|
||||
yield PipelineRunCompleted(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
)
|
||||
yield PipelineRunErrored(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
payload=repr(error),
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
data_ingestion_info=locals().get("results"),
|
||||
)
|
||||
|
||||
if not isinstance(error, PipelineRunFailedError):
|
||||
raise
|
||||
|
|
|
|||
|
|
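
For orientation only (not part of this diff), a sketch of how a caller might drive the distributed runner above; the task list, dataset_id and user are assumed to exist already:

    # Hypothetical consumer; run_tasks_distributed is an async generator.
    async for run_info in run_tasks_distributed(
        tasks=tasks,
        dataset_id=dataset_id,
        data=["notes.md", "s3://bucket/report.pdf"],
        user=user,
        pipeline_name="cognify_pipeline",
        incremental_loading=True,
    ):
        if isinstance(run_info, PipelineRunErrored):
            logger.error(f"Distributed run failed: {run_info.payload}")
        elif isinstance(run_info, PipelineRunCompleted):
            logger.info("All data items were processed on Modal workers.")
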
@@ -194,7 +194,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
            belongs_to_set=interactions_node_set,
        )

        await add_data_points(data_points=[cognee_user_interaction], update_edge_collection=False)
        await add_data_points(data_points=[cognee_user_interaction])

        relationships = []
        relationship_name = "used_graph_element_to_answer"

@@ -1,133 +0,0 @@
import asyncio
from typing import Any, Optional

from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError

logger = get_logger("InsightsRetriever")


class InsightsRetriever(BaseGraphRetriever):
    """
    Retriever for handling graph connection-based insights.

    Public methods include:
    - get_context
    - get_completion

    Instance variables include:
    - exploration_levels
    - top_k
    """

    def __init__(self, exploration_levels: int = 1, top_k: Optional[int] = 5):
        """Initialize retriever with exploration levels and search parameters."""
        self.exploration_levels = exploration_levels
        self.top_k = top_k

    async def get_context(self, query: str) -> list:
        """
        Find neighbours of a given node in the graph.

        If the provided query does not correspond to an existing node,
        search for similar entities and retrieve their connections.
        Reraises NoDataError if there is no data found in the system.

        Parameters:
        -----------

        - query (str): A string identifier for the node whose neighbours are to be
          retrieved.

        Returns:
        --------

        - list: A list of unique connections found for the queried node.
        """
        if query is None:
            return []

        node_id = query
        graph_engine = await get_graph_engine()
        exact_node = await graph_engine.extract_node(node_id)

        if exact_node is not None and "id" in exact_node:
            node_connections = await graph_engine.get_connections(str(exact_node["id"]))
        else:
            vector_engine = get_vector_engine()

            try:
                results = await asyncio.gather(
                    vector_engine.search("Entity_name", query_text=query, limit=self.top_k),
                    vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
                )
            except CollectionNotFoundError as error:
                logger.error("Entity collections not found")
                raise NoDataError("No data found in the system, please add data first.") from error

            results = [*results[0], *results[1]]
            relevant_results = [result for result in results if result.score < 0.5][: self.top_k]

            if len(relevant_results) == 0:
                return []

            node_connections_results = await asyncio.gather(
                *[graph_engine.get_connections(result.id) for result in relevant_results]
            )

            node_connections = []
            for neighbours in node_connections_results:
                node_connections.extend(neighbours)

        unique_node_connections_map = {}
        unique_node_connections = []

        for node_connection in node_connections:
            if "id" not in node_connection[0] or "id" not in node_connection[2]:
                continue

            unique_id = f"{node_connection[0]['id']} {node_connection[1]['relationship_name']} {node_connection[2]['id']}"
            if unique_id not in unique_node_connections_map:
                unique_node_connections_map[unique_id] = True
                unique_node_connections.append(node_connection)

        return unique_node_connections
        # return [
        #     Edge(
        #         node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
        #         node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
        #         attributes={
        #             **connection[1],
        #             "relationship_type": connection[1]["relationship_name"],
        #         },
        #     )
        #     for connection in unique_node_connections
        # ]

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """
        Returns the graph connections context.

        If a context is not provided, it fetches the context using the query provided.

        Parameters:
        -----------

        - query (str): A string identifier used to fetch the context.
        - context (Optional[Any]): An optional context to use for the completion; if None,
          it fetches the context based on the query. (default None)

        Returns:
        --------

        - Any: The context used for the completion, which is either provided or fetched
          based on the query.
        """
        if context is None:
            context = await self.get_context(query)
        return context

@@ -8,7 +8,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.base_feedback import BaseFeedback
from cognee.modules.retrieval.utils.models import CogneeUserFeedback
from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation
from cognee.tasks.storage import add_data_points
from cognee.tasks.storage import add_data_points, index_graph_edges

logger = get_logger("CompletionRetriever")


@@ -47,7 +47,7 @@ class UserQAFeedback(BaseFeedback):
            belongs_to_set=feedbacks_node_set,
        )

        await add_data_points(data_points=[cognee_user_feedback], update_edge_collection=False)
        await add_data_points(data_points=[cognee_user_feedback])

        relationships = []
        relationship_name = "gives_feedback_to"

@@ -76,6 +76,7 @@ class UserQAFeedback(BaseFeedback):
        if len(relationships) > 0:
            graph_engine = await get_graph_engine()
            await graph_engine.add_edges(relationships)
            await index_graph_edges(relationships)
            await graph_engine.apply_feedback_weight(
                node_ids=to_node_ids, weight=feedback_sentiment.score
            )

@@ -62,7 +62,7 @@ async def code_description_to_code_part(

    try:
        if include_docs:
            search_results = await search(query_text=query, query_type="INSIGHTS")
            search_results = await search(query_text=query, query_type="GRAPH_COMPLETION")

            concatenated_descriptions = " ".join(
                obj["description"]

@@ -9,7 +9,6 @@ from cognee.modules.search.exceptions import UnsupportedSearchTypeError
# Retrievers
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever

@@ -44,10 +43,6 @@ async def get_search_type_tools(
            SummariesRetriever(top_k=top_k).get_completion,
            SummariesRetriever(top_k=top_k).get_context,
        ],
        SearchType.INSIGHTS: [
            InsightsRetriever(top_k=top_k).get_completion,
            InsightsRetriever(top_k=top_k).get_context,
        ],
        SearchType.CHUNKS: [
            ChunksRetriever(top_k=top_k).get_completion,
            ChunksRetriever(top_k=top_k).get_context,

@@ -19,7 +19,9 @@ from cognee.modules.search.types import (
from cognee.modules.search.operations import log_query, log_result
from cognee.modules.users.models import User
from cognee.modules.data.models import Dataset
from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
from cognee.modules.data.methods.get_authorized_existing_datasets import (
    get_authorized_existing_datasets,
)

from .get_search_type_tools import get_search_type_tools
from .no_access_control_search import no_access_control_search

@@ -202,7 +204,9 @@ async def authorized_search(
    Not to be used outside of active access control mode.
    """
    # Find datasets user has read access for (if datasets are provided only return them. Provided user has read access)
    search_datasets = await get_specific_user_permission_datasets(user.id, "read", dataset_ids)
    search_datasets = await get_authorized_existing_datasets(
        datasets=dataset_ids, permission_type="read", user=user
    )

    if use_combined_context:
        search_responses = await search_in_datasets_context(

@@ -3,7 +3,6 @@ from enum import Enum

class SearchType(Enum):
    SUMMARIES = "SUMMARIES"
    INSIGHTS = "INSIGHTS"
    CHUNKS = "CHUNKS"
    RAG_COMPLETION = "RAG_COMPLETION"
    GRAPH_COMPLETION = "GRAPH_COMPLETION"

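
With INSIGHTS removed from the enum above, former INSIGHTS callers have to pick one of the remaining modes. A hedged sketch (not part of this diff), assuming the usual top-level search entry point:

    # Hypothetical caller; GRAPH_COMPLETION stands in for the removed INSIGHTS mode.
    from cognee import SearchType, search

    results = await search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="How are pipeline runs related to datasets?",
    )
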
@@ -15,6 +15,7 @@ class ModelName(Enum):
    ollama = "ollama"
    anthropic = "anthropic"
    gemini = "gemini"
    mistral = "mistral"


class LLMConfig(BaseModel):

@@ -72,6 +73,10 @@ def get_settings() -> SettingsDict:
            "value": "gemini",
            "label": "Gemini",
        },
        {
            "value": "mistral",
            "label": "Mistral",
        },
    ]

    return SettingsDict.model_validate(

@@ -134,6 +139,24 @@ def get_settings() -> SettingsDict:
                    "label": "Gemini 2.0 Flash",
                },
            ],
            "mistral": [
                {
                    "value": "mistral-medium-2508",
                    "label": "Mistral Medium 3.1",
                },
                {
                    "value": "magistral-medium-2509",
                    "label": "Magistral Medium 1.2",
                },
                {
                    "value": "magistral-medium-2507",
                    "label": "Magistral Medium 1.1",
                },
                {
                    "value": "mistral-large-2411",
                    "label": "Mistral Large 2.1",
                },
            ],
        },
    },
    vector_db={


@@ -37,6 +37,8 @@ async def get_authenticated_user(
    except Exception as e:
        # Convert any get_default_user failure into a proper HTTP 500 error
        logger.error(f"Failed to create default user: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to create default user: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to create default user: {str(e)}"
        ) from e

    return user

@@ -40,8 +40,8 @@ async def create_role(
        # Add association directly to the association table
        role = Role(name=role_name, tenant_id=tenant.id)
        session.add(role)
    except IntegrityError:
        raise EntityAlreadyExistsError(message="Role already exists for tenant.")
    except IntegrityError as e:
        raise EntityAlreadyExistsError(message="Role already exists for tenant.") from e

    await session.commit()
    await session.refresh(role)

@@ -35,5 +35,5 @@ async def create_tenant(tenant_name: str, user_id: UUID) -> UUID:
        await session.merge(user)
        await session.commit()
        return tenant.id
    except IntegrityError:
        raise EntityAlreadyExistsError(message="Tenant already exists.")
    except IntegrityError as e:
        raise EntityAlreadyExistsError(message="Tenant already exists.") from e

@@ -288,7 +288,6 @@ class SummarizedCode(BaseModel):
class GraphDBType(Enum):
    NETWORKX = auto()
    NEO4J = auto()
    FALKORDB = auto()
    KUZU = auto()


@@ -124,5 +124,4 @@ async def add_rule_associations(

    if len(edges_to_save) > 0:
        await graph_engine.add_edges(edges_to_save)

    await index_graph_edges()
        await index_graph_edges(edges_to_save)

@@ -4,6 +4,7 @@ from pydantic import BaseModel

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
from cognee.tasks.storage import index_graph_edges
from cognee.tasks.storage.add_data_points import add_data_points
from cognee.modules.ontology.ontology_config import Config
from cognee.modules.ontology.get_default_ontology_resolver import (

@@ -88,6 +89,7 @@ async def integrate_chunk_graphs(

    if len(graph_edges) > 0:
        await graph_engine.add_edges(graph_edges)
        await index_graph_edges(graph_edges)

    return data_chunks

@@ -8,6 +8,7 @@ from cognee.modules.ingestion import save_data_to_file
from cognee.shared.logging_utils import get_logger
from pydantic_settings import BaseSettings, SettingsConfigDict


logger = get_logger()


@@ -17,6 +18,13 @@ class SaveDataSettings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", extra="allow")


class HTMLContent(str):
    def __new__(cls, value: str):
        if not ("<" in value and ">" in value):
            raise ValueError("Not valid HTML-like content")
        return super().__new__(cls, value)


settings = SaveDataSettings()

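
The new HTMLContent helper above is just a validating str subclass; a quick behavioral sketch (illustration only, not part of this diff):

    snippet = HTMLContent("<p>hello</p>")   # passes the check and behaves like a normal str
    HTMLContent("plain text")               # raises ValueError("Not valid HTML-like content")
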

@@ -27,6 +35,12 @@ async def save_data_item_to_storage(data_item: Union[BinaryIO, str, Any]) -> str

        return await get_data_from_llama_index(data_item)

    if "docling" in str(type(data_item)):
        from docling_core.types import DoclingDocument

        if isinstance(data_item, DoclingDocument):
            data_item = data_item.export_to_text()

    # data is a file object coming from upload.
    if hasattr(data_item, "file"):
        return await save_data_to_file(data_item.file, filename=data_item.filename)

@@ -48,6 +62,40 @@ async def save_data_item_to_storage(data_item: Union[BinaryIO, str, Any]) -> str
    # data is s3 file path
    if parsed_url.scheme == "s3":
        return data_item
    elif parsed_url.scheme == "http" or parsed_url.scheme == "https":
        # Validate URL by sending a HEAD request
        try:
            from cognee.context_global_variables import tavily_config, soup_crawler_config
            from cognee.tasks.web_scraper import fetch_page_content

            tavily = tavily_config.get()
            soup_crawler = soup_crawler_config.get()
            preferred_tool = "beautifulsoup" if soup_crawler else "tavily"
            if preferred_tool == "tavily" and tavily is None:
                raise IngestionError(
                    message="TavilyConfig must be set on the ingestion context when fetching HTTP URLs without a SoupCrawlerConfig."
                )
            if preferred_tool == "beautifulsoup" and soup_crawler is None:
                raise IngestionError(
                    message="SoupCrawlerConfig must be set on the ingestion context when using the BeautifulSoup scraper."
                )

            data = await fetch_page_content(
                data_item,
                preferred_tool=preferred_tool,
                tavily_config=tavily,
                soup_crawler_config=soup_crawler,
            )
            content = ""
            for key, value in data.items():
                content += f"{key}:\n{value}\n\n"
            return await save_data_to_file(content)
        except IngestionError:
            raise
        except Exception as e:
            raise IngestionError(
                message=f"Error ingesting webpage results of url {data_item}: {str(e)}"
            )

    # data is local file path
    elif parsed_url.scheme == "file":

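
For illustration only (not part of this diff), the new HTTP(S) branch above seen from the caller's side; it assumes a SoupCrawlerConfig or TavilyConfig has already been set on the ingestion context:

    # Hypothetical call inside the ingestion flow; the scraped page content is
    # written to storage and the resulting file path is returned.
    file_path = await save_data_item_to_storage("https://example.com/blog/post")
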
@@ -10,9 +10,7 @@ from cognee.tasks.storage.exceptions import (
)


async def add_data_points(
    data_points: List[DataPoint], update_edge_collection: bool = True
) -> List[DataPoint]:
async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
    """
    Add a batch of data points to the graph database by extracting nodes and edges,
    deduplicating them, and indexing them for retrieval.

@@ -25,9 +23,6 @@ async def add_data_points(
    Args:
        data_points (List[DataPoint]):
            A list of data points to process and insert into the graph.
        update_edge_collection (bool, optional):
            Whether to update the edge index after adding edges.
            Defaults to True.

    Returns:
        List[DataPoint]:

@@ -73,12 +68,10 @@ async def add_data_points(

    graph_engine = await get_graph_engine()

    await graph_engine.add_nodes(nodes)
    await index_data_points(nodes)

    await graph_engine.add_nodes(nodes)
    await graph_engine.add_edges(edges)

    if update_edge_collection:
        await index_graph_edges()
    await index_graph_edges(edges)

    return data_points

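
A short usage sketch of the simplified add_data_points signature above (not part of this diff); the data point list is hypothetical:

    # Callers no longer pass update_edge_collection; edge indexing now always runs
    # over the edges extracted from the freshly added data points.
    from cognee.tasks.storage import add_data_points

    stored_points = await add_data_points(data_points=my_data_points)
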