Merge branch 'dev' into feature/cog-2985-add-ci-tests-that-run-more-examples

This commit is contained in:
Andrej Milicevic 2025-10-16 11:02:48 +02:00
commit 6e3370399b
154 changed files with 8669 additions and 19135 deletions

View file

@ -16,7 +16,7 @@
STRUCTURED_OUTPUT_FRAMEWORK="instructor"
LLM_API_KEY="your_api_key"
LLM_MODEL="openai/gpt-4o-mini"
LLM_MODEL="openai/gpt-5-mini"
LLM_PROVIDER="openai"
LLM_ENDPOINT=""
LLM_API_VERSION=""
@ -30,10 +30,13 @@ EMBEDDING_DIMENSIONS=3072
EMBEDDING_MAX_TOKENS=8191
# If embedding key is not provided same key set for LLM_API_KEY will be used
#EMBEDDING_API_KEY="your_api_key"
# Note: OpenAI supports up to 2048 elements and Gemini supports a maximum of 100 elements in an embedding batch;
# Cognee sets the optimal batch size for OpenAI and Gemini, but a custom size can be defined if necessary for other models.
#EMBEDDING_BATCH_SIZE=2048
# If using BAML structured output these env variables will be used
BAML_LLM_PROVIDER=openai
BAML_LLM_MODEL="gpt-4o-mini"
BAML_LLM_MODEL="gpt-5-mini"
BAML_LLM_ENDPOINT=""
BAML_LLM_API_KEY="your_api_key"
BAML_LLM_API_VERSION=""
@ -52,18 +55,18 @@ BAML_LLM_API_VERSION=""
################################################################################
# Configure storage backend (local filesystem or S3)
# STORAGE_BACKEND="local" # Default: uses local filesystem
#
#
# -- To switch to S3 storage, uncomment and fill these: ---------------------
# STORAGE_BACKEND="s3"
# STORAGE_BUCKET_NAME="your-bucket-name"
# AWS_REGION="us-east-1"
# AWS_ACCESS_KEY_ID="your-access-key"
# AWS_SECRET_ACCESS_KEY="your-secret-key"
#
#
# -- S3 Root Directories (optional) -----------------------------------------
# DATA_ROOT_DIRECTORY="s3://your-bucket/cognee/data"
# SYSTEM_ROOT_DIRECTORY="s3://your-bucket/cognee/system"
#
#
# -- Cache Directory (auto-configured for S3) -------------------------------
# When STORAGE_BACKEND=s3, cache automatically uses S3: s3://BUCKET/cognee/cache
# To override the automatic S3 cache location, uncomment:

View file

@ -58,7 +58,7 @@ body:
- Python version: [e.g. 3.9.0]
- Cognee version: [e.g. 0.1.0]
- LLM Provider: [e.g. OpenAI, Ollama]
- Database: [e.g. Neo4j, FalkorDB]
- Database: [e.g. Neo4j]
validations:
required: true

View file

@ -41,4 +41,4 @@ runs:
EXTRA_ARGS="$EXTRA_ARGS --extra $extra"
done
fi
uv sync --extra api --extra docs --extra evals --extra gemini --extra codegraph --extra ollama --extra dev --extra neo4j $EXTRA_ARGS
uv sync --extra api --extra docs --extra evals --extra codegraph --extra ollama --extra dev --extra neo4j $EXTRA_ARGS

View file

@ -54,6 +54,10 @@ jobs:
with:
python-version: ${{ inputs.python-version }}
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Run Neo4j Example
env:
ENV: dev
@ -66,9 +70,9 @@ jobs:
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: "neo4j"
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
GRAPH_DATABASE_USERNAME: "neo4j"
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
run: |
uv run python examples/database_examples/neo4j_example.py
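Not part of this commit: a minimal sanity check, sketched in Python, of the Neo4j instance the setup_neo4j action provisions. It assumes the workflow exports the action's outputs as the GRAPH_DATABASE_* variables shown above and that the official neo4j driver package is installed.

```python
# Hypothetical CI sanity check (not part of this commit): verify the Neo4j
# instance provisioned by the setup_neo4j action is reachable, using the same
# GRAPH_DATABASE_* variables the workflow step exports.
import os
from neo4j import GraphDatabase  # assumes the `neo4j` driver package is available

uri = os.environ["GRAPH_DATABASE_URL"]  # e.g. bolt://localhost:7687
auth = (os.environ["GRAPH_DATABASE_USERNAME"], os.environ["GRAPH_DATABASE_PASSWORD"])

with GraphDatabase.driver(uri, auth=auth) as driver:
    driver.verify_connectivity()
    with driver.session() as session:
        # Run a trivial query, mirroring the health check the old service block used
        assert session.run("RETURN 1 AS ok").single()["ok"] == 1
print("Neo4j is reachable")
```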

.github/workflows/distributed_test.yml (new file, 73 lines added)
View file

@ -0,0 +1,73 @@
name: Distributed Cognee test with modal
permissions:
contents: read
on:
workflow_call:
inputs:
python-version:
required: false
type: string
default: '3.11.x'
secrets:
LLM_MODEL:
required: true
LLM_ENDPOINT:
required: true
LLM_API_KEY:
required: true
LLM_API_VERSION:
required: true
EMBEDDING_MODEL:
required: true
EMBEDDING_ENDPOINT:
required: true
EMBEDDING_API_KEY:
required: true
EMBEDDING_API_VERSION:
required: true
OPENAI_API_KEY:
required: true
jobs:
run-server-start-test:
name: Distributed Cognee test (Modal)
runs-on: ubuntu-22.04
steps:
- name: Check out
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: "distributed postgres"
- name: Run Distributed Cognee (Modal)
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
MODAL_SECRET_NAME: ${{ secrets.MODAL_SECRET_NAME }}
GRAPH_DATABASE_PROVIDER: "neo4j"
GRAPH_DATABASE_URL: ${{ secrets.AZURE_NEO4j_URL }}
GRAPH_DATABASE_USERNAME: ${{ secrets.AZURE_NEO4J_USERNAME }}
GRAPH_DATABASE_PASSWORD: ${{ secrets.AZURE_NEO4J_PW }}
DB_PROVIDER: "postgres"
DB_NAME: ${{ secrets.AZURE_POSTGRES_DB_NAME }}
DB_HOST: ${{ secrets.AZURE_POSTGRES_HOST }}
DB_PORT: ${{ secrets.AZURE_POSTGRES_PORT }}
DB_USERNAME: ${{ secrets.AZURE_POSTGRES_USERNAME }}
DB_PASSWORD: ${{ secrets.AZURE_POSTGRES_PW }}
VECTOR_DB_PROVIDER: "pgvector"
COGNEE_DISTRIBUTED: "true"
run: uv run modal run ./distributed/entrypoint.py

View file

@ -234,3 +234,28 @@ jobs:
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./examples/python/permissions_example.py
test_docling_add:
name: Run Add with Docling Test
runs-on: macos-15
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: 'docling'
- name: Run Docling Test
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_add_docling_document.py

View file

@ -71,6 +71,10 @@ jobs:
with:
python-version: ${{ inputs.python-version }}
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Run default Neo4j
env:
ENV: 'dev'
@ -83,9 +87,9 @@ jobs:
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: "neo4j"
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
GRAPH_DATABASE_USERNAME: "neo4j"
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
run: uv run python ./cognee/tests/test_neo4j.py
- name: Run Weighted Edges Tests with Neo4j

View file

@ -186,6 +186,10 @@ jobs:
python-version: '3.11.x'
extra-dependencies: "postgres"
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Install specific db dependency
run: echo "Dependencies already installed in setup"
@ -206,9 +210,9 @@ jobs:
env:
ENV: 'dev'
GRAPH_DATABASE_PROVIDER: "neo4j"
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
GRAPH_DATABASE_USERNAME: "neo4j"
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}

View file

@ -51,20 +51,6 @@ jobs:
name: Search test for Neo4j/LanceDB/Sqlite
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }}
services:
neo4j:
image: neo4j:5.11
env:
NEO4J_AUTH: neo4j/pleaseletmein
NEO4J_PLUGINS: '["apoc","graph-data-science"]'
ports:
- 7474:7474
- 7687:7687
options: >-
--health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
--health-interval=10s
--health-timeout=5s
--health-retries=5
steps:
- name: Check out
@ -77,6 +63,10 @@ jobs:
with:
python-version: ${{ inputs.python-version }}
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
@ -94,9 +84,9 @@ jobs:
GRAPH_DATABASE_PROVIDER: 'neo4j'
VECTOR_DB_PROVIDER: 'lancedb'
DB_PROVIDER: 'sqlite'
GRAPH_DATABASE_URL: bolt://localhost:7687
GRAPH_DATABASE_USERNAME: neo4j
GRAPH_DATABASE_PASSWORD: pleaseletmein
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
run: uv run python ./cognee/tests/test_search_db.py
run-kuzu-pgvector-postgres-search-tests:
@ -158,19 +148,6 @@ jobs:
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/pgvector/postgres') }}
services:
neo4j:
image: neo4j:5.11
env:
NEO4J_AUTH: neo4j/pleaseletmein
NEO4J_PLUGINS: '["apoc","graph-data-science"]'
ports:
- 7474:7474
- 7687:7687
options: >-
--health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
--health-interval=10s
--health-timeout=5s
--health-retries=5
postgres:
image: pgvector/pgvector:pg17
env:
@ -196,6 +173,10 @@ jobs:
python-version: ${{ inputs.python-version }}
extra-dependencies: "postgres"
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
@ -213,9 +194,9 @@ jobs:
GRAPH_DATABASE_PROVIDER: 'neo4j'
VECTOR_DB_PROVIDER: 'pgvector'
DB_PROVIDER: 'postgres'
GRAPH_DATABASE_URL: bolt://localhost:7687
GRAPH_DATABASE_USERNAME: neo4j
GRAPH_DATABASE_PASSWORD: pleaseletmein
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
DB_NAME: cognee_db
DB_HOST: 127.0.0.1
DB_PORT: 5432

View file

@ -51,20 +51,6 @@ jobs:
name: Temporal Graph test Neo4j (lancedb + sqlite)
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }}
services:
neo4j:
image: neo4j:5.11
env:
NEO4J_AUTH: neo4j/pleaseletmein
NEO4J_PLUGINS: '["apoc","graph-data-science"]'
ports:
- 7474:7474
- 7687:7687
options: >-
--health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
--health-interval=10s
--health-timeout=5s
--health-retries=5
steps:
- name: Check out
@ -77,6 +63,10 @@ jobs:
with:
python-version: ${{ inputs.python-version }}
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
@ -94,9 +84,9 @@ jobs:
GRAPH_DATABASE_PROVIDER: 'neo4j'
VECTOR_DB_PROVIDER: 'lancedb'
DB_PROVIDER: 'sqlite'
GRAPH_DATABASE_URL: bolt://localhost:7687
GRAPH_DATABASE_USERNAME: neo4j
GRAPH_DATABASE_PASSWORD: pleaseletmein
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
run: uv run python ./cognee/tests/test_temporal_graph.py
run_temporal_graph_kuzu_postgres_pgvector:

View file

@ -27,7 +27,7 @@ jobs:
env:
LLM_PROVIDER: "gemini"
LLM_API_KEY: ${{ secrets.GEMINI_API_KEY }}
LLM_MODEL: "gemini/gemini-1.5-flash"
LLM_MODEL: "gemini/gemini-2.0-flash"
EMBEDDING_PROVIDER: "gemini"
EMBEDDING_API_KEY: ${{ secrets.GEMINI_API_KEY }}
EMBEDDING_MODEL: "gemini/text-embedding-004"
@ -83,4 +83,4 @@ jobs:
EMBEDDING_MODEL: "openai/text-embedding-3-large"
EMBEDDING_DIMENSIONS: "3072"
EMBEDDING_MAX_TOKENS: "8191"
run: uv run python ./examples/python/simple_example.py
run: uv run python ./examples/python/simple_example.py

View file

@ -6,8 +6,12 @@ on:
permissions:
contents: read
env:
RUNTIME__LOG_LEVEL: ERROR
ENV: 'dev'
jobs:
test-gemini:
test-s3-storage:
name: Run S3 File Storage Test
runs-on: ubuntu-22.04
steps:

View file

@ -27,6 +27,12 @@ jobs:
uses: ./.github/workflows/e2e_tests.yml
secrets: inherit
distributed-tests:
name: Distributed Cognee Test
needs: [ basic-tests, e2e-tests, graph-db-tests ]
uses: ./.github/workflows/distributed_test.yml
secrets: inherit
cli-tests:
name: CLI Tests
uses: ./.github/workflows/cli_tests.yml
@ -104,7 +110,7 @@ jobs:
db-examples-tests:
name: DB Examples Tests
needs: [vector-db-tests, graph-db-tests, relational-db-migration-tests]
needs: [vector-db-tests, graph-db-tests, relational-db-migration-tests, distributed-tests]
uses: ./.github/workflows/db_examples_tests.yml
secrets: inherit

View file

@ -86,12 +86,19 @@ jobs:
with:
python-version: '3.11'
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
- name: Run Weighted Edges Tests
env:
GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }}
GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }}
GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }}
GRAPH_DATABASE_PASSWORD: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-password || '' }}
run: |
uv run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v --tb=short

View file

@ -22,6 +22,7 @@ RUN apt-get update && apt-get install -y \
libpq-dev \
git \
curl \
cmake \
clang \
build-essential \
&& rm -rf /var/lib/apt/lists/*
@ -31,7 +32,7 @@ COPY README.md pyproject.toml uv.lock entrypoint.sh ./
# Install the project's dependencies using the lockfile and settings
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra gemini --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-install-project --no-dev --no-editable
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-install-project --no-dev --no-editable
# Copy Alembic configuration
COPY alembic.ini /app/alembic.ini
@ -42,7 +43,7 @@ COPY alembic/ /app/alembic
COPY ./cognee /app/cognee
COPY ./distributed /app/distributed
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra gemini --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-dev --no-editable
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-dev --no-editable
FROM python:3.12-slim-bookworm

View file

@ -76,6 +76,9 @@ Get started quickly with a Google Colab <a href="https://colab.research.google.
## About cognee
cognee works locally and stores your data on your device.
Our hosted solution is just our deployment of OSS cognee on Modal, with the goal of making development and productionization easier.
Self-hosted package:
- Interconnects any kind of documents: past conversations, files, images, and audio transcriptions

View file

@ -217,10 +217,24 @@ export default function GraphVisualization({ ref, data, graphControls, className
const [graphShape, setGraphShape] = useState<string>();
const zoomToFit: ForceGraphMethods["zoomToFit"] = (
durationMs?: number,
padding?: number,
nodeFilter?: (node: NodeObject) => boolean
) => {
if (!graphRef.current) {
console.warn("GraphVisualization: graphRef not ready yet");
return undefined as any;
}
return graphRef.current.zoomToFit?.(durationMs, padding, nodeFilter);
};
useImperativeHandle(ref, () => ({
zoomToFit: graphRef.current!.zoomToFit,
setGraphShape: setGraphShape,
zoomToFit,
setGraphShape,
}));
return (
<div ref={containerRef} className={classNames("w-full h-full", className)} id="graph-container">

View file

@ -89,15 +89,6 @@ export default function useChat(dataset: Dataset) {
}
interface Node {
name: string;
}
interface Relationship {
relationship_name: string;
}
type InsightMessage = [Node, Relationship, Node];
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function convertToSearchTypeOutput(systemMessage: any[] | any, searchType: string): string {
@ -106,14 +97,6 @@ function convertToSearchTypeOutput(systemMessage: any[] | any, searchType: strin
}
switch (searchType) {
case "INSIGHTS":
return systemMessage.map((message: InsightMessage) => {
const [node1, relationship, node2] = message;
if (node1.name && node2.name) {
return `${node1.name} ${relationship.relationship_name} ${node2.name}.`;
}
return "";
}).join("\n");
case "SUMMARIES":
return systemMessage.map((message: { text: string }) => message.text).join("\n");
case "CHUNKS":

View file

@ -65,6 +65,9 @@ ENV PYTHONUNBUFFERED=1
ENV MCP_LOG_LEVEL=DEBUG
ENV PYTHONPATH=/app
# Add labels for API mode usage
LABEL org.opencontainers.image.description="Cognee MCP Server with API mode support"
# Use the application name from pyproject.toml for normal operation
# For testing, we'll override this with a direct command
ENTRYPOINT ["/app/entrypoint.sh"]

View file

@ -38,7 +38,8 @@ Build memory for Agents and query from any client that speaks MCP  in your t
## ✨ Features
- Multiple transports – choose Streamable HTTP --transport http (recommended for web deployments), SSE --transport sse (real-time streaming), or stdio (classic pipe, default)
- Integrated logging – all actions written to a rotating file (see get_log_file_location()) and mirrored to console in dev
- **API Mode** – connect to an already running Cognee FastAPI server instead of using cognee directly (see [API Mode](#-api-mode) below)
- Integrated logging – all actions written to a rotating file (see get_log_file_location()) and mirrored to console in dev
- Local file ingestion – feed .md, source files, Cursor rulesets, etc. straight from disk
- Background pipelines – long-running cognify & codify jobs spawn off-thread; check progress with status tools
- Developer rules bootstrap – one call indexes .cursorrules, .cursor/rules, AGENT.md, and friends into the developer_rules nodeset
@ -91,7 +92,7 @@ To use different LLM providers / database configurations, and for more info chec
## 🐳 Docker Usage
If youd rather run cognee-mcp in a container, you have two options:
If you'd rather run cognee-mcp in a container, you have two options:
1. **Build locally**
1. Make sure you are in /cognee root directory and have a fresh `.env` containing only your `LLM_API_KEY` (and your chosen settings).
@ -128,6 +129,64 @@ If youd rather run cognee-mcp in a container, you have two options:
- ✅ Direct: `python src/server.py --transport http`
- ❌ Direct: `-e TRANSPORT_MODE=http` (won't work)
### **Docker API Mode**
To connect the MCP Docker container to a Cognee API server running on your host machine:
#### **Simple Usage (Automatic localhost handling):**
```bash
# Start your Cognee API server on the host
python -m cognee.api.client
# Run MCP container in API mode - localhost is automatically converted!
docker run \
-e TRANSPORT_MODE=sse \
-e API_URL=http://localhost:8000 \
-e API_TOKEN=your_auth_token \
-p 8001:8000 \
--rm -it cognee/cognee-mcp:main
```
**Note:** The container will automatically convert `localhost` to `host.docker.internal` on Mac/Windows/Docker Desktop. You'll see a message in the logs showing the conversion.
#### **Explicit host.docker.internal (Mac/Windows):**
```bash
# Or explicitly use host.docker.internal
docker run \
-e TRANSPORT_MODE=sse \
-e API_URL=http://host.docker.internal:8000 \
-e API_TOKEN=your_auth_token \
-p 8001:8000 \
--rm -it cognee/cognee-mcp:main
```
#### **On Linux (use host network or container IP):**
```bash
# Option 1: Use host network (simplest)
docker run \
--network host \
-e TRANSPORT_MODE=sse \
-e API_URL=http://localhost:8000 \
-e API_TOKEN=your_auth_token \
--rm -it cognee/cognee-mcp:main
# Option 2: Use host IP address
# First, get your host IP: ip addr show docker0
docker run \
-e TRANSPORT_MODE=sse \
-e API_URL=http://172.17.0.1:8000 \
-e API_TOKEN=your_auth_token \
-p 8001:8000 \
--rm -it cognee/cognee-mcp:main
```
**Environment variables for API mode:**
- `API_URL`: URL of the running Cognee API server
- `API_TOKEN`: Authentication token (optional, required if API has authentication enabled)
**Note:** When running in API mode:
- Database migrations are automatically skipped (API server handles its own DB)
- Some features are limited (see [API Mode Limitations](#-api-mode))
## 🔗 MCP Client Configuration
@ -255,6 +314,76 @@ You can configure both transports simultaneously for testing:
**Note:** Only enable the server you're actually running to avoid connection errors.
## 🌐 API Mode
The MCP server can operate in two modes:
### **Direct Mode** (Default)
The MCP server directly imports and uses the cognee library. This is the default mode with full feature support.
### **API Mode**
The MCP server connects to an already running Cognee FastAPI server via HTTP requests. This is useful when:
- You have a centralized Cognee API server running
- You want to separate the MCP server from the knowledge graph backend
- You need multiple MCP servers to share the same knowledge graph
**Starting the MCP server in API mode:**
```bash
# Start your Cognee FastAPI server first (default port 8000)
cd /path/to/cognee
python -m cognee.api.client
# Then start the MCP server in API mode
cd cognee-mcp
python src/server.py --api-url http://localhost:8000 --api-token YOUR_AUTH_TOKEN
```
**API Mode with different transports:**
```bash
# With SSE transport
python src/server.py --transport sse --api-url http://localhost:8000 --api-token YOUR_TOKEN
# With HTTP transport
python src/server.py --transport http --api-url http://localhost:8000 --api-token YOUR_TOKEN
```
**API Mode with Docker:**
```bash
# On Mac/Windows (use host.docker.internal to access host)
docker run \
-e TRANSPORT_MODE=sse \
-e API_URL=http://host.docker.internal:8000 \
-e API_TOKEN=YOUR_TOKEN \
-p 8001:8000 \
--rm -it cognee/cognee-mcp:main
# On Linux (use host network)
docker run \
--network host \
-e TRANSPORT_MODE=sse \
-e API_URL=http://localhost:8000 \
-e API_TOKEN=YOUR_TOKEN \
--rm -it cognee/cognee-mcp:main
```
**Command-line arguments for API mode:**
- `--api-url`: Base URL of the running Cognee FastAPI server (e.g., `http://localhost:8000`)
- `--api-token`: Authentication token for the API (optional, required if API has authentication enabled)
**Docker environment variables for API mode:**
- `API_URL`: Base URL of the running Cognee FastAPI server
- `API_TOKEN`: Authentication token (optional, required if API has authentication enabled)
**API Mode limitations:**
Some features are only available in direct mode:
- `codify` (code graph pipeline)
- `cognify_status` / `codify_status` (pipeline status tracking)
- `prune` (data reset)
- `get_developer_rules` (developer rules retrieval)
- `list_data` with specific dataset_id (detailed data listing)
Basic operations like `cognify`, `search`, `delete`, and `list_data` (all datasets) work in both modes.
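Not part of the original README: a minimal sketch of what API mode does over the wire, assuming a Cognee FastAPI server on http://localhost:8000 and the same /api/v1/add, /api/v1/cognify, and /api/v1/search endpoints that the MCP server's API client (added later in this diff) calls. The token is only needed if authentication is enabled.

```python
# Minimal sketch (not part of this commit) of the HTTP calls API mode makes.
import asyncio
import httpx

API_URL = "http://localhost:8000"   # assumption: local Cognee FastAPI server
API_TOKEN = "your_auth_token"       # optional, only if auth is enabled


async def demo():
    headers = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}
    async with httpx.AsyncClient(timeout=300.0) as client:
        # Add a piece of text to the default dataset
        resp = await client.post(
            f"{API_URL}/api/v1/add",
            files={"data": ("data.txt", "Cognee builds memory for agents.", "text/plain")},
            data={"datasetName": "main_dataset"},
            headers=headers,
        )
        resp.raise_for_status()

        # Build the knowledge graph
        resp = await client.post(
            f"{API_URL}/api/v1/cognify",
            json={"datasets": ["main_dataset"], "run_in_background": False},
            headers=headers,
        )
        resp.raise_for_status()

        # Query it
        resp = await client.post(
            f"{API_URL}/api/v1/search",
            json={"query": "What does Cognee do?", "search_type": "GRAPH_COMPLETION", "top_k": 10},
            headers=headers,
        )
        resp.raise_for_status()
        print(resp.json())


asyncio.run(demo())
```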
## 💻 Basic Usage
The MCP server exposes its functionality through tools. Call them from any MCP client (Cursor, Claude Desktop, Cline, Roo and more).
@ -266,7 +395,7 @@ The MCP server exposes its functionality through tools. Call them from any MCP c
- **codify**: Analyse a code repository, build a code graph, and store it in memory
- **search**: Query memory – supports GRAPH_COMPLETION, RAG_COMPLETION, CODE, CHUNKS, INSIGHTS
- **search**: Query memory – supports GRAPH_COMPLETION, RAG_COMPLETION, CODE, CHUNKS
- **list_data**: List all datasets and their data items with IDs for deletion operations

View file

@ -14,61 +14,94 @@ HTTP_PORT=${HTTP_PORT:-8000}
echo "Debug port: $DEBUG_PORT"
echo "HTTP port: $HTTP_PORT"
# Run Alembic migrations with proper error handling.
# Note on UserAlreadyExists error handling:
# During database migrations, we attempt to create a default user. If this user
# already exists (e.g., from a previous deployment or migration), it's not a
# critical error and shouldn't prevent the application from starting. This is
# different from other migration errors which could indicate database schema
# inconsistencies and should cause the startup to fail. This check allows for
# smooth redeployments and container restarts while maintaining data integrity.
echo "Running database migrations..."
# Check if API mode is enabled
if [ -n "$API_URL" ]; then
echo "API mode enabled: $API_URL"
echo "Skipping database migrations (API server handles its own database)"
else
echo "Direct mode: Using local cognee instance"
# Run Alembic migrations with proper error handling.
# Note on UserAlreadyExists error handling:
# During database migrations, we attempt to create a default user. If this user
# already exists (e.g., from a previous deployment or migration), it's not a
# critical error and shouldn't prevent the application from starting. This is
# different from other migration errors which could indicate database schema
# inconsistencies and should cause the startup to fail. This check allows for
# smooth redeployments and container restarts while maintaining data integrity.
echo "Running database migrations..."
MIGRATION_OUTPUT=$(alembic upgrade head)
MIGRATION_EXIT_CODE=$?
MIGRATION_OUTPUT=$(alembic upgrade head)
MIGRATION_EXIT_CODE=$?
if [[ $MIGRATION_EXIT_CODE -ne 0 ]]; then
if [[ "$MIGRATION_OUTPUT" == *"UserAlreadyExists"* ]] || [[ "$MIGRATION_OUTPUT" == *"User default_user@example.com already exists"* ]]; then
echo "Warning: Default user already exists, continuing startup..."
else
echo "Migration failed with unexpected error."
exit 1
if [[ $MIGRATION_EXIT_CODE -ne 0 ]]; then
if [[ "$MIGRATION_OUTPUT" == *"UserAlreadyExists"* ]] || [[ "$MIGRATION_OUTPUT" == *"User default_user@example.com already exists"* ]]; then
echo "Warning: Default user already exists, continuing startup..."
else
echo "Migration failed with unexpected error."
exit 1
fi
fi
fi
echo "Database migrations done."
echo "Database migrations done."
fi
echo "Starting Cognee MCP Server with transport mode: $TRANSPORT_MODE"
# Add startup delay to ensure DB is ready
sleep 2
# Build API arguments if API_URL is set
API_ARGS=""
if [ -n "$API_URL" ]; then
# Handle localhost in API_URL - convert to host-accessible address
if echo "$API_URL" | grep -q "localhost" || echo "$API_URL" | grep -q "127.0.0.1"; then
echo "⚠️ Warning: API_URL contains localhost/127.0.0.1"
echo " Original: $API_URL"
# Try to use host.docker.internal (works on Mac/Windows and recent Linux with Docker Desktop)
FIXED_API_URL=$(echo "$API_URL" | sed 's/localhost/host.docker.internal/g' | sed 's/127\.0\.0\.1/host.docker.internal/g')
echo " Converted to: $FIXED_API_URL"
echo " This will work on Mac/Windows/Docker Desktop."
echo " On Linux without Docker Desktop, you may need to:"
echo " - Use --network host, OR"
echo " - Set API_URL=http://172.17.0.1:8000 (Docker bridge IP)"
API_URL="$FIXED_API_URL"
fi
API_ARGS="--api-url $API_URL"
if [ -n "$API_TOKEN" ]; then
API_ARGS="$API_ARGS --api-token $API_TOKEN"
fi
fi
# Modified startup with transport mode selection and error handling
if [ "$ENVIRONMENT" = "dev" ] || [ "$ENVIRONMENT" = "local" ]; then
if [ "$DEBUG" = "true" ]; then
echo "Waiting for the debugger to attach..."
if [ "$TRANSPORT_MODE" = "sse" ]; then
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration $API_ARGS
elif [ "$TRANSPORT_MODE" = "http" ]; then
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration $API_ARGS
else
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport stdio --no-migration
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport stdio --no-migration $API_ARGS
fi
else
if [ "$TRANSPORT_MODE" = "sse" ]; then
exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration $API_ARGS
elif [ "$TRANSPORT_MODE" = "http" ]; then
exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration $API_ARGS
else
exec cognee-mcp --transport stdio --no-migration
exec cognee-mcp --transport stdio --no-migration $API_ARGS
fi
fi
else
if [ "$TRANSPORT_MODE" = "sse" ]; then
exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration $API_ARGS
elif [ "$TRANSPORT_MODE" = "http" ]; then
exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration $API_ARGS
else
exec cognee-mcp --transport stdio --no-migration
exec cognee-mcp --transport stdio --no-migration $API_ARGS
fi
fi

View file

@ -8,10 +8,12 @@ requires-python = ">=3.10"
dependencies = [
# For local cognee repo usage, remove the comment below and add the absolute path to cognee. Then run `uv sync --reinstall` in the mcp folder on local cognee changes.
#"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j] @ file:/Users/igorilic/Desktop/cognee",
# TODO: Remove gemini from optional dependencies for new Cognee version after 0.3.4
"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j]==0.3.4",
"fastmcp>=2.10.0,<3.0.0",
"mcp>=1.12.0,<2.0.0",
"uv>=0.6.3,<1.0.0",
"httpx>=0.27.0,<1.0.0",
]
authors = [
@ -36,4 +38,5 @@ dev = [
allow-direct-references = true
[project.scripts]
cognee-mcp = "src:main"
cognee = "src:main"
cognee-mcp = "src:main_mcp"

View file

@ -1,8 +1,60 @@
from .server import main as server_main
try:
from .server import main as server_main
except ImportError:
from server import main as server_main
import warnings
import sys
def main():
"""Main entry point for the package."""
"""Deprecated main entry point for the package."""
import asyncio
deprecation_notice = """
DEPRECATION NOTICE
The CLI entry-point used to start the Cognee MCP service has been renamed from
"cognee" to "cognee-mcp". Calling the old entry-point will stop working in a
future release.
WHAT YOU NEED TO DO:
Locate every place where you launch the MCP process and replace the final
argument "cognee" with "cognee-mcp".
For the example mcpServers block from Cursor shown below, the change is:
{
"mcpServers": {
"Cognee": {
"command": "uv",
"args": [
"--directory",
"/path/to/cognee-mcp",
"run",
"cognee" // <-- CHANGE THIS to "cognee-mcp"
]
}
}
}
Continuing to use the old "cognee" entry-point will result in failures once it
is removed, so please update your configuration and any shell scripts as soon
as possible.
"""
warnings.warn(
"The 'cognee' command for cognee-mcp is deprecated and will be removed in a future version. "
"Please use 'cognee-mcp' instead to avoid conflicts with the main cognee library.",
DeprecationWarning,
stacklevel=2,
)
print("⚠️ DEPRECATION WARNING", file=sys.stderr)
print(deprecation_notice, file=sys.stderr)
asyncio.run(server_main())
def main_mcp():
"""Clean main entry point for cognee-mcp command."""
import asyncio
asyncio.run(server_main())

View file

@ -117,5 +117,4 @@ async def add_rule_associations(data: str, rules_nodeset_name: str):
if len(edges_to_save) > 0:
await graph_engine.add_edges(edges_to_save)
await index_graph_edges()
await index_graph_edges(edges_to_save)

View file

@ -0,0 +1,338 @@
"""
Cognee Client abstraction that supports both direct function calls and HTTP API calls.
This module provides a unified interface for interacting with Cognee, supporting:
- Direct mode: Directly imports and calls cognee functions (default behavior)
- API mode: Makes HTTP requests to a running Cognee FastAPI server
"""
import sys
from typing import Optional, Any, List, Dict
from uuid import UUID
from contextlib import redirect_stdout
import httpx
from cognee.shared.logging_utils import get_logger
import json
logger = get_logger()
class CogneeClient:
"""
Unified client for interacting with Cognee via direct calls or HTTP API.
Parameters
----------
api_url : str, optional
Base URL of the Cognee API server (e.g., "http://localhost:8000").
If None, uses direct cognee function calls.
api_token : str, optional
Authentication token for the API (optional, required if API has authentication enabled).
"""
def __init__(self, api_url: Optional[str] = None, api_token: Optional[str] = None):
self.api_url = api_url.rstrip("/") if api_url else None
self.api_token = api_token
self.use_api = bool(api_url)
if self.use_api:
logger.info(f"Cognee client initialized in API mode: {self.api_url}")
self.client = httpx.AsyncClient(timeout=300.0) # 5 minute timeout for long operations
else:
logger.info("Cognee client initialized in direct mode")
# Import cognee only if we're using direct mode
import cognee as _cognee
self.cognee = _cognee
def _get_headers(self) -> Dict[str, str]:
"""Get headers for API requests."""
headers = {"Content-Type": "application/json"}
if self.api_token:
headers["Authorization"] = f"Bearer {self.api_token}"
return headers
async def add(
self, data: Any, dataset_name: str = "main_dataset", node_set: Optional[List[str]] = None
) -> Dict[str, Any]:
"""
Add data to Cognee for processing.
Parameters
----------
data : Any
Data to add (text, file path, etc.)
dataset_name : str
Name of the dataset to add data to
node_set : List[str], optional
List of node identifiers for graph organization
Returns
-------
Dict[str, Any]
Result of the add operation
"""
if self.use_api:
endpoint = f"{self.api_url}/api/v1/add"
files = {"data": ("data.txt", str(data), "text/plain")}
form_data = {
"datasetName": dataset_name,
}
if node_set is not None:
form_data["node_set"] = json.dumps(node_set)
response = await self.client.post(
endpoint,
files=files,
data=form_data,
headers={"Authorization": f"Bearer {self.api_token}"} if self.api_token else {},
)
response.raise_for_status()
return response.json()
else:
with redirect_stdout(sys.stderr):
await self.cognee.add(data, dataset_name=dataset_name, node_set=node_set)
return {"status": "success", "message": "Data added successfully"}
async def cognify(
self,
datasets: Optional[List[str]] = None,
custom_prompt: Optional[str] = None,
graph_model: Any = None,
) -> Dict[str, Any]:
"""
Transform data into a knowledge graph.
Parameters
----------
datasets : List[str], optional
List of dataset names to process
custom_prompt : str, optional
Custom prompt for entity extraction
graph_model : Any, optional
Custom graph model (only used in direct mode)
Returns
-------
Dict[str, Any]
Result of the cognify operation
"""
if self.use_api:
# API mode: Make HTTP request
endpoint = f"{self.api_url}/api/v1/cognify"
payload = {
"datasets": datasets or ["main_dataset"],
"run_in_background": False,
}
if custom_prompt:
payload["custom_prompt"] = custom_prompt
response = await self.client.post(endpoint, json=payload, headers=self._get_headers())
response.raise_for_status()
return response.json()
else:
# Direct mode: Call cognee directly
with redirect_stdout(sys.stderr):
kwargs = {}
if datasets:
kwargs["datasets"] = datasets
if custom_prompt:
kwargs["custom_prompt"] = custom_prompt
if graph_model:
kwargs["graph_model"] = graph_model
await self.cognee.cognify(**kwargs)
return {"status": "success", "message": "Cognify completed successfully"}
async def search(
self,
query_text: str,
query_type: str,
datasets: Optional[List[str]] = None,
system_prompt: Optional[str] = None,
top_k: int = 10,
) -> Any:
"""
Search the knowledge graph.
Parameters
----------
query_text : str
The search query
query_type : str
Type of search (e.g., "GRAPH_COMPLETION", "INSIGHTS", etc.)
datasets : List[str], optional
List of datasets to search
system_prompt : str, optional
System prompt for completion searches
top_k : int
Maximum number of results
Returns
-------
Any
Search results
"""
if self.use_api:
# API mode: Make HTTP request
endpoint = f"{self.api_url}/api/v1/search"
payload = {"query": query_text, "search_type": query_type.upper(), "top_k": top_k}
if datasets:
payload["datasets"] = datasets
if system_prompt:
payload["system_prompt"] = system_prompt
response = await self.client.post(endpoint, json=payload, headers=self._get_headers())
response.raise_for_status()
return response.json()
else:
# Direct mode: Call cognee directly
from cognee.modules.search.types import SearchType
with redirect_stdout(sys.stderr):
results = await self.cognee.search(
query_type=SearchType[query_type.upper()], query_text=query_text
)
return results
async def delete(self, data_id: UUID, dataset_id: UUID, mode: str = "soft") -> Dict[str, Any]:
"""
Delete data from a dataset.
Parameters
----------
data_id : UUID
ID of the data to delete
dataset_id : UUID
ID of the dataset containing the data
mode : str
Deletion mode ("soft" or "hard")
Returns
-------
Dict[str, Any]
Result of the deletion
"""
if self.use_api:
# API mode: Make HTTP request
endpoint = f"{self.api_url}/api/v1/delete"
params = {"data_id": str(data_id), "dataset_id": str(dataset_id), "mode": mode}
response = await self.client.delete(
endpoint, params=params, headers=self._get_headers()
)
response.raise_for_status()
return response.json()
else:
# Direct mode: Call cognee directly
from cognee.modules.users.methods import get_default_user
with redirect_stdout(sys.stderr):
user = await get_default_user()
result = await self.cognee.delete(
data_id=data_id, dataset_id=dataset_id, mode=mode, user=user
)
return result
async def prune_data(self) -> Dict[str, Any]:
"""
Prune all data from the knowledge graph.
Returns
-------
Dict[str, Any]
Result of the prune operation
"""
if self.use_api:
# Note: The API doesn't expose a prune endpoint, so we'll need to handle this
# For now, raise an error
raise NotImplementedError("Prune operation is not available via API")
else:
# Direct mode: Call cognee directly
with redirect_stdout(sys.stderr):
await self.cognee.prune.prune_data()
return {"status": "success", "message": "Data pruned successfully"}
async def prune_system(self, metadata: bool = True) -> Dict[str, Any]:
"""
Prune system data from the knowledge graph.
Parameters
----------
metadata : bool
Whether to prune metadata
Returns
-------
Dict[str, Any]
Result of the prune operation
"""
if self.use_api:
# Note: The API doesn't expose a prune endpoint
raise NotImplementedError("Prune system operation is not available via API")
else:
# Direct mode: Call cognee directly
with redirect_stdout(sys.stderr):
await self.cognee.prune.prune_system(metadata=metadata)
return {"status": "success", "message": "System pruned successfully"}
async def get_pipeline_status(self, dataset_ids: List[UUID], pipeline_name: str) -> str:
"""
Get the status of a pipeline run.
Parameters
----------
dataset_ids : List[UUID]
List of dataset IDs
pipeline_name : str
Name of the pipeline
Returns
-------
str
Status information
"""
if self.use_api:
# Note: This would need a custom endpoint on the API side
raise NotImplementedError("Pipeline status is not available via API")
else:
# Direct mode: Call cognee directly
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
with redirect_stdout(sys.stderr):
status = await get_pipeline_status(dataset_ids, pipeline_name)
return str(status)
async def list_datasets(self) -> List[Dict[str, Any]]:
"""
List all datasets.
Returns
-------
List[Dict[str, Any]]
List of datasets
"""
if self.use_api:
# API mode: Make HTTP request
endpoint = f"{self.api_url}/api/v1/datasets"
response = await self.client.get(endpoint, headers=self._get_headers())
response.raise_for_status()
return response.json()
else:
# Direct mode: Call cognee directly
from cognee.modules.users.methods import get_default_user
from cognee.modules.data.methods import get_datasets
with redirect_stdout(sys.stderr):
user = await get_default_user()
datasets = await get_datasets(user.id)
return [
{"id": str(d.id), "name": d.name, "created_at": str(d.created_at)}
for d in datasets
]
async def close(self):
"""Close the HTTP client if in API mode."""
if self.use_api and hasattr(self, "client"):
await self.client.aclose()
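Not part of this commit: a hypothetical usage sketch of the new CogneeClient in both modes, using the method signatures defined above. The flat `cognee_client` import mirrors the fallback import used in server.py; the URL and token are placeholders.

```python
# Hypothetical usage of CogneeClient (not part of this commit).
import asyncio

from cognee_client import CogneeClient  # assumption: same fallback import as server.py


async def demo():
    # Direct mode: calls the cognee library in-process
    direct = CogneeClient()
    await direct.add("Cognee builds memory for agents.", dataset_name="main_dataset")
    await direct.cognify(datasets=["main_dataset"])
    print(await direct.search("What does Cognee do?", "GRAPH_COMPLETION"))

    # API mode: talks to a running FastAPI server over HTTP instead
    api = CogneeClient(api_url="http://localhost:8000", api_token="your_auth_token")
    print(await api.list_datasets())
    await api.close()  # only closes the underlying HTTP client in API mode


asyncio.run(demo())
```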

View file

@ -2,28 +2,27 @@ import json
import os
import sys
import argparse
import cognee
import asyncio
import subprocess
from pathlib import Path
from typing import Optional
from cognee.shared.logging_utils import get_logger, setup_logging, get_log_file_location
import importlib.util
from contextlib import redirect_stdout
import mcp.types as types
from mcp.server import FastMCP
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
from cognee.modules.users.methods import get_default_user
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.modules.search.types import SearchType
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.storage.utils import JSONEncoder
from starlette.responses import JSONResponse
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
import uvicorn
try:
from .cognee_client import CogneeClient
except ImportError:
from cognee_client import CogneeClient
try:
from cognee.tasks.codingagents.coding_rule_associations import (
@ -41,6 +40,8 @@ mcp = FastMCP("Cognee")
logger = get_logger()
cognee_client: Optional[CogneeClient] = None
async def run_sse_with_cors():
"""Custom SSE transport with CORS middleware."""
@ -141,11 +142,20 @@ async def cognee_add_developer_rules(
with redirect_stdout(sys.stderr):
logger.info(f"Starting cognify for: {file_path}")
try:
await cognee.add(file_path, node_set=["developer_rules"])
model = KnowledgeGraph
await cognee_client.add(file_path, node_set=["developer_rules"])
model = None
if graph_model_file and graph_model_name:
model = load_class(graph_model_file, graph_model_name)
await cognee.cognify(graph_model=model)
if cognee_client.use_api:
logger.warning(
"Custom graph models are not supported in API mode, ignoring."
)
else:
from cognee.shared.data_models import KnowledgeGraph
model = load_class(graph_model_file, graph_model_name)
await cognee_client.cognify(graph_model=model)
logger.info(f"Cognify finished for: {file_path}")
except Exception as e:
logger.error(f"Cognify failed for {file_path}: {str(e)}")
@ -255,7 +265,7 @@ async def cognify(
# 2. Get entity relationships and connections
relationships = await cognee.search(
"connections between concepts",
query_type=SearchType.INSIGHTS
query_type=SearchType.GRAPH_COMPLETION
)
# 3. Find relevant document chunks
@ -293,15 +303,20 @@ async def cognify(
# going to stdout ( like the print function ) to stderr.
with redirect_stdout(sys.stderr):
logger.info("Cognify process starting.")
if graph_model_file and graph_model_name:
graph_model = load_class(graph_model_file, graph_model_name)
else:
graph_model = KnowledgeGraph
await cognee.add(data)
graph_model = None
if graph_model_file and graph_model_name:
if cognee_client.use_api:
logger.warning("Custom graph models are not supported in API mode, ignoring.")
else:
from cognee.shared.data_models import KnowledgeGraph
graph_model = load_class(graph_model_file, graph_model_name)
await cognee_client.add(data)
try:
await cognee.cognify(graph_model=graph_model, custom_prompt=custom_prompt)
await cognee_client.cognify(custom_prompt=custom_prompt, graph_model=graph_model)
logger.info("Cognify process finished.")
except Exception as e:
logger.error("Cognify process failed.")
@ -354,16 +369,19 @@ async def save_interaction(data: str) -> list:
with redirect_stdout(sys.stderr):
logger.info("Save interaction process starting.")
await cognee.add(data, node_set=["user_agent_interaction"])
await cognee_client.add(data, node_set=["user_agent_interaction"])
try:
await cognee.cognify()
await cognee_client.cognify()
logger.info("Save interaction process finished.")
logger.info("Generating associated rules from interaction data.")
await add_rule_associations(data=data, rules_nodeset_name="coding_agent_rules")
logger.info("Associated rules generated from interaction data.")
# Rule associations only work in direct mode
if not cognee_client.use_api:
logger.info("Generating associated rules from interaction data.")
await add_rule_associations(data=data, rules_nodeset_name="coding_agent_rules")
logger.info("Associated rules generated from interaction data.")
else:
logger.warning("Rule associations are not available in API mode, skipping.")
except Exception as e:
logger.error("Save interaction process failed.")
@ -420,11 +438,18 @@ async def codify(repo_path: str) -> list:
- All stdout is redirected to stderr to maintain MCP communication integrity
"""
if cognee_client.use_api:
error_msg = "❌ Codify operation is not available in API mode. Please use direct mode for code graph pipeline."
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
async def codify_task(repo_path: str):
# NOTE: MCP uses stdout to communicate, we must redirect all output
# going to stdout ( like the print function ) to stderr.
with redirect_stdout(sys.stderr):
logger.info("Codify process starting.")
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
results = []
async for result in run_code_graph_pipeline(repo_path, False):
results.append(result)
@ -478,11 +503,6 @@ async def search(search_query: str, search_type: str) -> list:
Best for: Direct document retrieval, specific fact-finding.
Returns: LLM responses based on relevant text chunks.
**INSIGHTS**:
Structured entity relationships and semantic connections.
Best for: Understanding concept relationships, knowledge mapping.
Returns: Formatted relationship data and entity connections.
**CHUNKS**:
Raw text segments that match the query semantically.
Best for: Finding specific passages, citations, exact content.
@ -524,7 +544,6 @@ async def search(search_query: str, search_type: str) -> list:
- "RAG_COMPLETION": Returns an LLM response based on the search query and standard RAG data
- "CODE": Returns code-related knowledge in JSON format
- "CHUNKS": Returns raw text chunks from the knowledge graph
- "INSIGHTS": Returns relationships between nodes in readable format
- "SUMMARIES": Returns pre-generated hierarchical summaries
- "CYPHER": Direct graph database queries
- "FEELING_LUCKY": Automatically selects best search type
@ -537,7 +556,6 @@ async def search(search_query: str, search_type: str) -> list:
A list containing a single TextContent object with the search results.
The format of the result depends on the search_type:
- **GRAPH_COMPLETION/RAG_COMPLETION**: Conversational AI response strings
- **INSIGHTS**: Formatted relationship descriptions and entity connections
- **CHUNKS**: Relevant text passages with source metadata
- **SUMMARIES**: Hierarchical summaries from general to specific
- **CODE**: Structured code information with context
@ -547,7 +565,6 @@ async def search(search_query: str, search_type: str) -> list:
Performance & Optimization:
- **GRAPH_COMPLETION**: Slower but most intelligent, uses LLM + graph context
- **RAG_COMPLETION**: Medium speed, uses LLM + document chunks (no graph traversal)
- **INSIGHTS**: Fast, returns structured relationships without LLM processing
- **CHUNKS**: Fastest, pure vector similarity search without LLM
- **SUMMARIES**: Fast, returns pre-computed summaries
- **CODE**: Medium speed, specialized for code understanding
@ -574,23 +591,40 @@ async def search(search_query: str, search_type: str) -> list:
# NOTE: MCP uses stdout to communicate, we must redirect all output
# going to stdout ( like the print function ) to stderr.
with redirect_stdout(sys.stderr):
search_results = await cognee.search(
query_type=SearchType[search_type.upper()], query_text=search_query
search_results = await cognee_client.search(
query_text=search_query, query_type=search_type
)
if search_type.upper() == "CODE":
return json.dumps(search_results, cls=JSONEncoder)
elif (
search_type.upper() == "GRAPH_COMPLETION" or search_type.upper() == "RAG_COMPLETION"
):
return str(search_results[0])
elif search_type.upper() == "CHUNKS":
return str(search_results)
elif search_type.upper() == "INSIGHTS":
results = retrieved_edges_to_string(search_results)
return results
# Handle different result formats based on API vs direct mode
if cognee_client.use_api:
# API mode returns JSON-serialized results
if isinstance(search_results, str):
return search_results
elif isinstance(search_results, list):
if (
search_type.upper() in ["GRAPH_COMPLETION", "RAG_COMPLETION"]
and len(search_results) > 0
):
return str(search_results[0])
return str(search_results)
else:
return json.dumps(search_results, cls=JSONEncoder)
else:
return str(search_results)
# Direct mode processing
if search_type.upper() == "CODE":
return json.dumps(search_results, cls=JSONEncoder)
elif (
search_type.upper() == "GRAPH_COMPLETION"
or search_type.upper() == "RAG_COMPLETION"
):
return str(search_results[0])
elif search_type.upper() == "CHUNKS":
return str(search_results)
elif search_type.upper() == "INSIGHTS":
results = retrieved_edges_to_string(search_results)
return results
else:
return str(search_results)
search_results = await search_task(search_query, search_type)
return [types.TextContent(type="text", text=search_results)]
@ -623,6 +657,10 @@ async def get_developer_rules() -> list:
async def fetch_rules_from_cognee() -> str:
"""Collect all developer rules from Cognee"""
with redirect_stdout(sys.stderr):
if cognee_client.use_api:
logger.warning("Developer rules retrieval is not available in API mode")
return "Developer rules retrieval is not available in API mode"
developer_rules = await get_existing_rules(rules_nodeset_name="coding_agent_rules")
return developer_rules
@ -662,16 +700,24 @@ async def list_data(dataset_id: str = None) -> list:
with redirect_stdout(sys.stderr):
try:
user = await get_default_user()
output_lines = []
if dataset_id:
# List data for specific dataset
# Detailed data listing for specific dataset is only available in direct mode
if cognee_client.use_api:
return [
types.TextContent(
type="text",
text="❌ Detailed data listing for specific datasets is not available in API mode.\nPlease use the API directly or use direct mode.",
)
]
from cognee.modules.users.methods import get_default_user
from cognee.modules.data.methods import get_dataset, get_dataset_data
logger.info(f"Listing data for dataset: {dataset_id}")
dataset_uuid = UUID(dataset_id)
# Get the dataset information
from cognee.modules.data.methods import get_dataset, get_dataset_data
user = await get_default_user()
dataset = await get_dataset(user.id, dataset_uuid)
@ -700,11 +746,9 @@ async def list_data(dataset_id: str = None) -> list:
output_lines.append(" (No data items in this dataset)")
else:
# List all datasets
# List all datasets - works in both modes
logger.info("Listing all datasets")
from cognee.modules.data.methods import get_datasets
datasets = await get_datasets(user.id)
datasets = await cognee_client.list_datasets()
if not datasets:
return [
@ -719,20 +763,21 @@ async def list_data(dataset_id: str = None) -> list:
output_lines.append("")
for i, dataset in enumerate(datasets, 1):
# Get data count for each dataset
from cognee.modules.data.methods import get_dataset_data
data_items = await get_dataset_data(dataset.id)
output_lines.append(f"{i}. 📁 {dataset.name}")
output_lines.append(f" Dataset ID: {dataset.id}")
output_lines.append(f" Created: {dataset.created_at}")
output_lines.append(f" Data items: {len(data_items)}")
# In API mode, dataset is a dict; in direct mode, it's formatted as dict
if isinstance(dataset, dict):
output_lines.append(f"{i}. 📁 {dataset.get('name', 'Unnamed')}")
output_lines.append(f" Dataset ID: {dataset.get('id')}")
output_lines.append(f" Created: {dataset.get('created_at', 'N/A')}")
else:
output_lines.append(f"{i}. 📁 {dataset.name}")
output_lines.append(f" Dataset ID: {dataset.id}")
output_lines.append(f" Created: {dataset.created_at}")
output_lines.append("")
output_lines.append("💡 To see data items in a specific dataset, use:")
output_lines.append(' list_data(dataset_id="your-dataset-id-here")')
output_lines.append("")
if not cognee_client.use_api:
output_lines.append("💡 To see data items in a specific dataset, use:")
output_lines.append(' list_data(dataset_id="your-dataset-id-here")')
output_lines.append("")
output_lines.append("🗑️ To delete specific data, use:")
output_lines.append(' delete(data_id="data-id", dataset_id="dataset-id")')
@ -801,12 +846,9 @@ async def delete(data_id: str, dataset_id: str, mode: str = "soft") -> list:
data_uuid = UUID(data_id)
dataset_uuid = UUID(dataset_id)
# Get default user for the operation
user = await get_default_user()
# Call the cognee delete function
result = await cognee.delete(
data_id=data_uuid, dataset_id=dataset_uuid, mode=mode, user=user
# Call the cognee delete function via client
result = await cognee_client.delete(
data_id=data_uuid, dataset_id=dataset_uuid, mode=mode
)
logger.info(f"Delete operation completed successfully: {result}")
@ -853,11 +895,21 @@ async def prune():
-----
- This operation cannot be undone. All memory data will be permanently deleted.
- The function prunes both data content (using prune_data) and system metadata (using prune_system)
- This operation is not available in API mode
"""
with redirect_stdout(sys.stderr):
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
return [types.TextContent(type="text", text="Pruned")]
try:
await cognee_client.prune_data()
await cognee_client.prune_system(metadata=True)
return [types.TextContent(type="text", text="Pruned")]
except NotImplementedError:
error_msg = "❌ Prune operation is not available in API mode"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
except Exception as e:
error_msg = f"❌ Prune operation failed: {str(e)}"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
@mcp.tool()
@ -880,13 +932,26 @@ async def cognify_status():
- The function retrieves pipeline status specifically for the "cognify_pipeline" on the "main_dataset"
- Status information includes job progress, execution time, and completion status
- The status is returned in string format for easy reading
- This operation is not available in API mode
"""
with redirect_stdout(sys.stderr):
user = await get_default_user()
status = await get_pipeline_status(
[await get_unique_dataset_id("main_dataset", user)], "cognify_pipeline"
)
return [types.TextContent(type="text", text=str(status))]
try:
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
from cognee.modules.users.methods import get_default_user
user = await get_default_user()
status = await cognee_client.get_pipeline_status(
[await get_unique_dataset_id("main_dataset", user)], "cognify_pipeline"
)
return [types.TextContent(type="text", text=str(status))]
except NotImplementedError:
error_msg = "❌ Pipeline status is not available in API mode"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
except Exception as e:
error_msg = f"❌ Failed to get cognify status: {str(e)}"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
@mcp.tool()
@ -909,13 +974,26 @@ async def codify_status():
- The function retrieves pipeline status specifically for the "cognify_code_pipeline" on the "codebase" dataset
- Status information includes job progress, execution time, and completion status
- The status is returned in string format for easy reading
- This operation is not available in API mode
"""
with redirect_stdout(sys.stderr):
user = await get_default_user()
status = await get_pipeline_status(
[await get_unique_dataset_id("codebase", user)], "cognify_code_pipeline"
)
return [types.TextContent(type="text", text=str(status))]
try:
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
from cognee.modules.users.methods import get_default_user
user = await get_default_user()
status = await cognee_client.get_pipeline_status(
[await get_unique_dataset_id("codebase", user)], "cognify_code_pipeline"
)
return [types.TextContent(type="text", text=str(status))]
except NotImplementedError:
error_msg = "❌ Pipeline status is not available in API mode"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
except Exception as e:
error_msg = f"❌ Failed to get codify status: {str(e)}"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
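The three handlers above repeat the same degrade-gracefully pattern when the MCP server runs against a remote API. A condensed sketch of that pattern, assuming the same `types` import used by the server module and that `CogneeClient` signals unsupported operations with `NotImplementedError`, as the handlers above suggest:

```python
from typing import Awaitable, Callable, List

import mcp.types as types


async def run_api_optional(
    operation: Callable[[], Awaitable[object]], unavailable_message: str
) -> List[types.TextContent]:
    """Run an operation that may be unsupported in API mode and report the outcome."""
    try:
        result = await operation()
        return [types.TextContent(type="text", text=str(result))]
    except NotImplementedError:
        # The HTTP API does not expose this operation
        return [types.TextContent(type="text", text=unavailable_message)]
    except Exception as e:
        return [types.TextContent(type="text", text=f"❌ Operation failed: {str(e)}")]
```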
def node_to_string(node):
@ -949,6 +1027,8 @@ def load_class(model_file, model_name):
async def main():
global cognee_client
parser = argparse.ArgumentParser()
parser.add_argument(
@ -992,12 +1072,30 @@ async def main():
help="Argument stops database migration from being attempted",
)
# Cognee API connection options
parser.add_argument(
"--api-url",
default=None,
help="Base URL of a running Cognee FastAPI server (e.g., http://localhost:8000). "
"If provided, the MCP server will connect to the API instead of using cognee directly.",
)
parser.add_argument(
"--api-token",
default=None,
help="Authentication token for the API (optional, required if API has authentication enabled).",
)
args = parser.parse_args()
# Initialize the global CogneeClient
cognee_client = CogneeClient(api_url=args.api_url, api_token=args.api_token)
mcp.settings.host = args.host
mcp.settings.port = args.port
if not args.no_migration:
# Skip migrations when in API mode (the API server handles its own database)
if not args.no_migration and not args.api_url:
# Run Alembic migrations from the main cognee directory where alembic.ini is located
logger.info("Running database migrations...")
migration_result = subprocess.run(
@ -1020,6 +1118,8 @@ async def main():
sys.exit(1)
logger.info("Database migrations done.")
elif args.api_url:
logger.info("Skipping database migrations (using API mode)")
logger.info(f"Starting MCP server with transport: {args.transport}")
if args.transport == "stdio":

2831
cognee-mcp/uv.lock generated

File diff suppressed because it is too large

View file

@ -19,6 +19,7 @@ from .api.v1.add import add
from .api.v1.delete import delete
from .api.v1.cognify import cognify
from .modules.memify import memify
from .api.v1.update import update
from .api.v1.config.config import config
from .api.v1.datasets.datasets import datasets
from .api.v1.prune import prune

View file

@ -189,12 +189,12 @@ class HealthChecker:
start_time = time.time()
try:
from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm import LLMGateway
config = get_llm_config()
# Test actual API connection with minimal request
LLMGateway.show_prompt("test", "test.txt")
from cognee.infrastructure.llm.utils import test_llm_connection
await test_llm_connection()
response_time = int((time.time() - start_time) * 1000)
return ComponentHealth(
@ -217,13 +217,9 @@ class HealthChecker:
"""Check embedding service health (non-critical)."""
start_time = time.time()
try:
from cognee.infrastructure.databases.vector.embeddings.get_embedding_engine import (
get_embedding_engine,
)
from cognee.infrastructure.llm.utils import test_embedding_connection
# Test actual embedding generation with minimal text
engine = get_embedding_engine()
await engine.embed_text(["test"])
await test_embedding_connection()
response_time = int((time.time() - start_time) * 1000)
return ComponentHealth(
@ -245,16 +241,6 @@ class HealthChecker:
"""Get comprehensive health status."""
components = {}
# Critical services
critical_components = [
"relational_db",
"vector_db",
"graph_db",
"file_storage",
"llm_provider",
"embedding_service",
]
critical_checks = [
("relational_db", self.check_relational_db()),
("vector_db", self.check_vector_db()),
@ -300,11 +286,11 @@ class HealthChecker:
else:
components[name] = result
critical_comps = [check[0] for check in critical_checks]
# Determine overall status
critical_unhealthy = any(
comp.status == HealthStatus.UNHEALTHY
comp.status == HealthStatus.UNHEALTHY and name in critical_comps
for name, comp in components.items()
if name in critical_components
)
has_degraded = any(comp.status == HealthStatus.DEGRADED for comp in components.values())
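The checker now delegates connectivity tests to shared helpers. A minimal sketch of reusing those helpers outside the health checker, assuming (as the try/except blocks above suggest) that they raise when the service is unreachable:

```python
import asyncio

from cognee.infrastructure.llm.utils import test_embedding_connection, test_llm_connection


async def preflight() -> None:
    # Both helpers are awaited and are assumed to raise if the service is unreachable
    await test_llm_connection()
    await test_embedding_connection()


asyncio.run(preflight())
```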

View file

@ -1,6 +1,8 @@
from uuid import UUID
from typing import Union, BinaryIO, List, Optional
import os
from typing import Union, BinaryIO, List, Optional, Dict, Any
from pydantic import BaseModel
from urllib.parse import urlparse
from cognee.modules.users.models import User
from cognee.modules.pipelines import Task, run_pipeline
from cognee.modules.pipelines.layers.resolve_authorized_user_dataset import (
@ -11,6 +13,19 @@ from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
)
from cognee.modules.engine.operations.setup import setup
from cognee.tasks.ingestion import ingest_data, resolve_data_directories
from cognee.shared.logging_utils import get_logger
logger = get_logger()
try:
from cognee.tasks.web_scraper.config import TavilyConfig, SoupCrawlerConfig
from cognee.context_global_variables import (
tavily_config as tavily,
soup_crawler_config as soup_crawler,
)
except ImportError as import_error:
logger.debug(f"Unable to import web scraper configuration: {str(import_error)}")
async def add(
@ -23,12 +38,15 @@ async def add(
dataset_id: Optional[UUID] = None,
preferred_loaders: List[str] = None,
incremental_loading: bool = True,
extraction_rules: Optional[Dict[str, Any]] = None,
tavily_config: Optional[BaseModel] = None,
soup_crawler_config: Optional[BaseModel] = None,
):
"""
Add data to Cognee for knowledge graph processing.
This is the first step in the Cognee workflow - it ingests raw data and prepares it
for processing. The function accepts various data formats including text, files, and
for processing. The function accepts various data formats including text, files, URLs, and
binary streams, then stores them in a specified dataset for further processing.
Prerequisites:
@ -68,6 +86,7 @@ async def add(
- S3 path: "s3://my-bucket/documents/file.pdf"
- List of mixed types: ["text content", "/path/file.pdf", "file://doc.txt", file_handle]
- Binary file object: open("file.txt", "rb")
- URL: A web page link (http or https)
dataset_name: Name of the dataset to store data in. Defaults to "main_dataset".
Create separate datasets to organize different knowledge domains.
user: User object for authentication and permissions. Uses default user if None.
@ -78,6 +97,9 @@ async def add(
vector_db_config: Optional configuration for vector database (for custom setups).
graph_db_config: Optional configuration for graph database (for custom setups).
dataset_id: Optional specific dataset UUID to use instead of dataset_name.
extraction_rules: Optional dictionary of rules (e.g., CSS selectors, XPath) for extracting specific content from web pages using BeautifulSoup.
tavily_config: Optional configuration for the Tavily API, including API key and extraction settings.
soup_crawler_config: Optional configuration for the BeautifulSoup crawler, specifying concurrency, crawl delay, and extraction rules.
Returns:
PipelineRunInfo: Information about the ingestion pipeline execution including:
@ -126,6 +148,21 @@ async def add(
# Add a single file
await cognee.add("/home/user/documents/analysis.pdf")
# Add a single URL using the BeautifulSoup (bs4) extraction method
extraction_rules = {
"title": "h1",
"description": "p",
"more_info": "a[href*='more-info']"
}
await cognee.add("https://example.com", extraction_rules=extraction_rules)
# Add a single URL using the Tavily extraction method
# Make sure TAVILY_API_KEY is set as an environment variable
await cognee.add("https://example.com")
# Add multiple URLs
await cognee.add(["https://example.com", "https://books.toscrape.com"])
```
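A further sketch, not part of the original examples: passing the crawler configurations explicitly instead of relying on `extraction_rules` or the `TAVILY_API_KEY` environment variable. Only the fields shown in this diff (`extraction_rules`, `api_key`) are used; any other config fields would be assumptions.

```python
from cognee.tasks.web_scraper.config import SoupCrawlerConfig, TavilyConfig

soup_config = SoupCrawlerConfig(extraction_rules={"title": "h1", "body": "p"})
tavily_config = TavilyConfig(api_key="your_tavily_api_key")

await cognee.add(
    "https://example.com",
    soup_crawler_config=soup_config,
    tavily_config=tavily_config,
)
```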
Environment Variables:
@ -133,22 +170,55 @@ async def add(
- LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)
Optional:
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama"
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral"
- LLM_MODEL: Model name (default: "gpt-5-mini")
- DEFAULT_USER_EMAIL: Custom default user email
- DEFAULT_USER_PASSWORD: Custom default user password
- VECTOR_DB_PROVIDER: "lancedb" (default), "chromadb", "pgvector"
- GRAPH_DATABASE_PROVIDER: "kuzu" (default), "neo4j"
- TAVILY_API_KEY: API key for Tavily-based web page extraction
"""
try:
if not soup_crawler_config and extraction_rules:
soup_crawler_config = SoupCrawlerConfig(extraction_rules=extraction_rules)
if not tavily_config and os.getenv("TAVILY_API_KEY"):
tavily_config = TavilyConfig(api_key=os.getenv("TAVILY_API_KEY"))
soup_crawler.set(soup_crawler_config)
tavily.set(tavily_config)
http_schemes = {"http", "https"}
def _is_http_url(item: Union[str, BinaryIO]) -> bool:
return isinstance(item, str) and urlparse(item).scheme in http_schemes
if _is_http_url(data):
node_set = ["web_content"] if not node_set else node_set + ["web_content"]
elif isinstance(data, list) and any(_is_http_url(item) for item in data):
node_set = ["web_content"] if not node_set else node_set + ["web_content"]
except NameError as name_error:
logger.debug(f"Web scraper configuration is unavailable: {str(name_error)}")
tasks = [
Task(resolve_data_directories, include_subdirectories=True),
Task(ingest_data, dataset_name, user, node_set, dataset_id, preferred_loaders),
Task(
ingest_data,
dataset_name,
user,
node_set,
dataset_id,
preferred_loaders,
),
]
await setup()
user, authorized_dataset = await resolve_authorized_user_dataset(dataset_id, dataset_name, user)
user, authorized_dataset = await resolve_authorized_user_dataset(
dataset_name=dataset_name, dataset_id=dataset_id, user=user
)
await reset_dataset_pipeline_run_status(
authorized_dataset.id, user, pipeline_names=["add_pipeline", "cognify_pipeline"]

View file

@ -73,7 +73,11 @@ def get_add_router() -> APIRouter:
try:
add_run = await cognee_add(
data, datasetName, user=user, dataset_id=datasetId, node_set=node_set
data,
datasetName,
user=user,
dataset_id=datasetId,
node_set=node_set if node_set else None,
)
if isinstance(add_run, PipelineRunErrored):

View file

@ -148,7 +148,7 @@ async def cognify(
# 2. Get entity relationships and connections
relationships = await cognee.search(
"connections between concepts",
query_type=SearchType.INSIGHTS
query_type=SearchType.GRAPH_COMPLETION
)
# 3. Find relevant document chunks

View file

@ -14,7 +14,6 @@ DEFAULT_TOOLS = [
"type": "string",
"description": "Type of search to perform",
"enum": [
"INSIGHTS",
"CODE",
"GRAPH_COMPLETION",
"NATURAL_LANGUAGE",

View file

@ -59,7 +59,7 @@ async def handle_search(arguments: Dict[str, Any], user) -> list:
valid_search_types = (
search_tool["parameters"]["properties"]["search_type"]["enum"]
if search_tool
else ["INSIGHTS", "CODE", "GRAPH_COMPLETION", "NATURAL_LANGUAGE"]
else ["CODE", "GRAPH_COMPLETION", "NATURAL_LANGUAGE"]
)
if search_type_str not in valid_search_types:

View file

@ -14,7 +14,6 @@ DEFAULT_TOOLS = [
"type": "string",
"description": "Type of search to perform",
"enum": [
"INSIGHTS",
"CODE",
"GRAPH_COMPLETION",
"NATURAL_LANGUAGE",

View file

@ -52,11 +52,6 @@ async def search(
Best for: Direct document retrieval, specific fact-finding.
Returns: LLM responses based on relevant text chunks.
**INSIGHTS**:
Structured entity relationships and semantic connections.
Best for: Understanding concept relationships, knowledge mapping.
Returns: Formatted relationship data and entity connections.
**CHUNKS**:
Raw text segments that match the query semantically.
Best for: Finding specific passages, citations, exact content.
@ -124,9 +119,6 @@ async def search(
**GRAPH_COMPLETION/RAG_COMPLETION**:
[List of conversational AI response strings]
**INSIGHTS**:
[List of formatted relationship descriptions and entity connections]
**CHUNKS**:
[List of relevant text passages with source metadata]
@ -146,7 +138,6 @@ async def search(
Performance & Optimization:
- **GRAPH_COMPLETION**: Slower but most intelligent, uses LLM + graph context
- **RAG_COMPLETION**: Medium speed, uses LLM + document chunks (no graph traversal)
- **INSIGHTS**: Fast, returns structured relationships without LLM processing
- **CHUNKS**: Fastest, pure vector similarity search without LLM
- **SUMMARIES**: Fast, returns pre-computed summaries
- **CODE**: Medium speed, specialized for code understanding

View file

@ -21,7 +21,13 @@ class SettingsDTO(OutDTO):
class LLMConfigInputDTO(InDTO):
provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"], Literal["gemini"]]
provider: Union[
Literal["openai"],
Literal["ollama"],
Literal["anthropic"],
Literal["gemini"],
Literal["mistral"],
]
model: str
api_key: str

View file

@ -502,22 +502,48 @@ def start_ui(
if start_mcp:
logger.info("Starting Cognee MCP server with Docker...")
cwd = os.getcwd()
env_file = os.path.join(cwd, ".env")
try:
image = "cognee/cognee-mcp:feature-standalone-mcp" # TODO: change to "cognee/cognee-mcp:main" right before merging into main
subprocess.run(["docker", "pull", image], check=True)
import uuid
container_name = f"cognee-mcp-{uuid.uuid4().hex[:8]}"
docker_cmd = [
"docker",
"run",
"--name",
container_name,
"-p",
f"{mcp_port}:8000",
"--rm",
"-e",
"TRANSPORT_MODE=sse",
]
if start_backend:
docker_cmd.extend(
[
"-e",
f"API_URL=http://localhost:{backend_port}",
]
)
logger.info(
f"Configuring MCP to connect to backend API at http://localhost:{backend_port}"
)
logger.info("(localhost will be auto-converted to host.docker.internal)")
else:
cwd = os.getcwd()
env_file = os.path.join(cwd, ".env")
docker_cmd.extend(["--env-file", env_file])
docker_cmd.append(
image
) # TODO: change to "cognee/cognee-mcp:main" right before merging into main
mcp_process = subprocess.Popen(
[
"docker",
"run",
"-p",
f"{mcp_port}:8000",
"--rm",
"--env-file",
env_file,
"-e",
"TRANSPORT_MODE=sse",
"cognee/cognee-mcp:daulet-dev",
],
docker_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
@ -526,8 +552,13 @@ def start_ui(
_stream_process_output(mcp_process, "stdout", "[MCP]", "\033[34m") # Blue
_stream_process_output(mcp_process, "stderr", "[MCP]", "\033[34m") # Blue
pid_callback(mcp_process.pid)
logger.info(f"✓ Cognee MCP server starting on http://127.0.0.1:{mcp_port}/sse")
# Pass both PID and container name using a tuple
pid_callback((mcp_process.pid, container_name))
mode_info = "API mode" if start_backend else "direct mode"
logger.info(
f"✓ Cognee MCP server starting on http://127.0.0.1:{mcp_port}/sse ({mode_info})"
)
except Exception as e:
logger.error(f"Failed to start MCP server with Docker: {str(e)}")
# Start backend server if requested
@ -627,7 +658,6 @@ def start_ui(
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
shell=True,
)
else:
@ -637,7 +667,6 @@ def start_ui(
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
)

View file

@ -75,7 +75,7 @@ def get_update_router() -> APIRouter:
data=data,
dataset_id=dataset_id,
user=user,
node_set=node_set,
node_set=node_set if node_set else None,
)
# If any cognify run errored return JSONResponse with proper error status code

View file

@ -10,9 +10,9 @@ from cognee.api.v1.cognify import cognify
async def update(
data_id: UUID,
data: Union[BinaryIO, list[BinaryIO], str, list[str]],
dataset_id: UUID,
user: User = None,
node_set: Optional[List[str]] = None,
dataset_id: Optional[UUID] = None,
vector_db_config: dict = None,
graph_db_config: dict = None,
preferred_loaders: List[str] = None,

View file

@ -175,19 +175,59 @@ def main() -> int:
# Handle UI flag
if hasattr(args, "start_ui") and args.start_ui:
spawned_pids = []
docker_container = None
def signal_handler(signum, frame):
"""Handle Ctrl+C and other termination signals"""
nonlocal spawned_pids
fmt.echo("\nShutting down UI server...")
nonlocal spawned_pids, docker_container
try:
fmt.echo("\nShutting down UI server...")
except (BrokenPipeError, OSError):
pass
# First, stop Docker container if running
if docker_container:
try:
result = subprocess.run(
["docker", "stop", docker_container],
capture_output=True,
timeout=10,
check=False,
)
try:
if result.returncode == 0:
fmt.success(f"✓ Docker container {docker_container} stopped.")
else:
fmt.warning(
f"Could not stop container {docker_container}: {result.stderr.decode()}"
)
except (BrokenPipeError, OSError):
pass
except subprocess.TimeoutExpired:
try:
fmt.warning(
f"Timeout stopping container {docker_container}, forcing removal..."
)
except (BrokenPipeError, OSError):
pass
subprocess.run(
["docker", "rm", "-f", docker_container], capture_output=True, check=False
)
except Exception:
pass
# Then, stop regular processes
for pid in spawned_pids:
try:
if hasattr(os, "killpg"):
# Unix-like systems: Use process groups
pgid = os.getpgid(pid)
os.killpg(pgid, signal.SIGTERM)
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
try:
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
except (BrokenPipeError, OSError):
pass
else:
# Windows: Use taskkill to terminate process and its children
subprocess.run(
@ -195,24 +235,35 @@ def main() -> int:
capture_output=True,
check=False,
)
fmt.success(f"✓ Process {pid} and its children terminated.")
except (OSError, ProcessLookupError, subprocess.SubprocessError) as e:
fmt.warning(f"Could not terminate process {pid}: {e}")
try:
fmt.success(f"✓ Process {pid} and its children terminated.")
except (BrokenPipeError, OSError):
pass
except (OSError, ProcessLookupError, subprocess.SubprocessError):
pass
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Termination request
if hasattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, signal_handler)
try:
from cognee import start_ui
fmt.echo("Starting cognee UI...")
# Callback to capture PIDs of all spawned processes
def pid_callback(pid):
nonlocal spawned_pids
spawned_pids.append(pid)
# Callback to capture PIDs and Docker container of all spawned processes
def pid_callback(pid_or_tuple):
nonlocal spawned_pids, docker_container
# Handle both regular PIDs and (PID, container_name) tuples
if isinstance(pid_or_tuple, tuple):
pid, container_name = pid_or_tuple
spawned_pids.append(pid)
docker_container = container_name
else:
spawned_pids.append(pid_or_tuple)
frontend_port = 3000
start_backend, backend_port = True, 8000

View file

@ -70,11 +70,11 @@ After adding data, use `cognee cognify` to process it into knowledge graphs.
await cognee.add(data=data_to_add, dataset_name=args.dataset_name)
fmt.success(f"Successfully added data to dataset '{args.dataset_name}'")
except Exception as e:
raise CliCommandInnerException(f"Failed to add data: {str(e)}")
raise CliCommandInnerException(f"Failed to add data: {str(e)}") from e
asyncio.run(run_add())
except Exception as e:
if isinstance(e, CliCommandInnerException):
raise CliCommandException(str(e), error_code=1)
raise CliCommandException(f"Error adding data: {str(e)}", error_code=1)
raise CliCommandException(str(e), error_code=1) from e
raise CliCommandException(f"Failed to add data: {str(e)}", error_code=1) from e

View file

@ -107,7 +107,7 @@ After successful cognify processing, use `cognee search` to query the knowledge
)
return result
except Exception as e:
raise CliCommandInnerException(f"Failed to cognify: {str(e)}")
raise CliCommandInnerException(f"Failed to cognify: {str(e)}") from e
result = asyncio.run(run_cognify())
@ -124,5 +124,5 @@ After successful cognify processing, use `cognee search` to query the knowledge
except Exception as e:
if isinstance(e, CliCommandInnerException):
raise CliCommandException(str(e), error_code=1)
raise CliCommandException(f"Error during cognification: {str(e)}", error_code=1)
raise CliCommandException(str(e), error_code=1) from e
raise CliCommandException(f"Error during cognification: {str(e)}", error_code=1) from e

View file

@ -79,8 +79,10 @@ Configuration changes will affect how cognee processes and stores data.
except Exception as e:
if isinstance(e, CliCommandInnerException):
raise CliCommandException(str(e), error_code=1)
raise CliCommandException(f"Error managing configuration: {str(e)}", error_code=1)
raise CliCommandException(str(e), error_code=1) from e
raise CliCommandException(
f"Error managing configuration: {str(e)}", error_code=1
) from e
def _handle_get(self, args: argparse.Namespace) -> None:
try:
@ -122,7 +124,7 @@ Configuration changes will affect how cognee processes and stores data.
fmt.note("Configuration viewing not fully implemented yet")
except Exception as e:
raise CliCommandInnerException(f"Failed to get configuration: {str(e)}")
raise CliCommandInnerException(f"Failed to get configuration: {str(e)}") from e
def _handle_set(self, args: argparse.Namespace) -> None:
try:
@ -141,7 +143,7 @@ Configuration changes will affect how cognee processes and stores data.
fmt.error(f"Failed to set configuration key '{args.key}'")
except Exception as e:
raise CliCommandInnerException(f"Failed to set configuration: {str(e)}")
raise CliCommandInnerException(f"Failed to set configuration: {str(e)}") from e
def _handle_unset(self, args: argparse.Namespace) -> None:
try:
@ -189,7 +191,7 @@ Configuration changes will affect how cognee processes and stores data.
fmt.note("Use 'cognee config list' to see all available configuration options")
except Exception as e:
raise CliCommandInnerException(f"Failed to unset configuration: {str(e)}")
raise CliCommandInnerException(f"Failed to unset configuration: {str(e)}") from e
def _handle_list(self, args: argparse.Namespace) -> None:
try:
@ -209,7 +211,7 @@ Configuration changes will affect how cognee processes and stores data.
fmt.echo(" cognee config reset - Reset all to defaults")
except Exception as e:
raise CliCommandInnerException(f"Failed to list configuration: {str(e)}")
raise CliCommandInnerException(f"Failed to list configuration: {str(e)}") from e
def _handle_reset(self, args: argparse.Namespace) -> None:
try:
@ -222,4 +224,4 @@ Configuration changes will affect how cognee processes and stores data.
fmt.echo("This would reset all settings to their default values")
except Exception as e:
raise CliCommandInnerException(f"Failed to reset configuration: {str(e)}")
raise CliCommandInnerException(f"Failed to reset configuration: {str(e)}") from e

View file

@ -6,6 +6,7 @@ from cognee.cli.reference import SupportsCliCommand
from cognee.cli import DEFAULT_DOCS_URL
import cognee.cli.echo as fmt
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
from cognee.modules.data.methods.get_deletion_counts import get_deletion_counts
class DeleteCommand(SupportsCliCommand):
@ -41,7 +42,34 @@ Be careful with deletion operations as they are irreversible.
fmt.error("Please specify what to delete: --dataset-name, --user-id, or --all")
return
# Build confirmation message
# If --force is used, skip the preview and go straight to deletion
if not args.force:
# --- START PREVIEW LOGIC ---
fmt.echo("Gathering data for preview...")
try:
preview_data = asyncio.run(
get_deletion_counts(
dataset_name=args.dataset_name,
user_id=args.user_id,
all_data=args.all,
)
)
except CliCommandException as e:
fmt.error(f"Error occured when fetching preview data: {str(e)}")
return
if not preview_data:
fmt.success("No data found to delete.")
return
fmt.echo("You are about to delete:")
fmt.echo(
f"Datasets: {preview_data.datasets}\nEntries: {preview_data.entries}\nUsers: {preview_data.users}"
)
fmt.echo("-" * 20)
# --- END PREVIEW LOGIC ---
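For reference, the preview object is consumed through its `datasets`, `entries`, and `users` attributes; a hypothetical sketch of that shape, since the real return type of `get_deletion_counts` is not shown in this diff:

```python
from dataclasses import dataclass


@dataclass
class DeletionCounts:
    """Hypothetical shape of the object returned by get_deletion_counts."""

    datasets: int
    entries: int
    users: int
```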
# Build operation message for success/failure logging
if args.all:
confirm_msg = "Delete ALL data from cognee?"
operation = "all data"
@ -51,8 +79,9 @@ Be careful with deletion operations as they are irreversible.
elif args.user_id:
confirm_msg = f"Delete all data for user '{args.user_id}'?"
operation = f"data for user '{args.user_id}'"
else:
operation = "data"
# Confirm deletion unless forced
if not args.force:
fmt.warning("This operation is irreversible!")
if not fmt.confirm(confirm_msg):
@ -64,17 +93,20 @@ Be careful with deletion operations as they are irreversible.
# Run the async delete function
async def run_delete():
try:
# NOTE: The underlying cognee.delete() function is currently not working as expected.
# This is a separate bug that this preview feature helps to expose.
if args.all:
await cognee.delete(dataset_name=None, user_id=args.user_id)
else:
await cognee.delete(dataset_name=args.dataset_name, user_id=args.user_id)
except Exception as e:
raise CliCommandInnerException(f"Failed to delete: {str(e)}")
raise CliCommandInnerException(f"Failed to delete: {str(e)}") from e
asyncio.run(run_delete())
# This success message may be inaccurate due to the underlying bug, but we leave it for now.
fmt.success(f"Successfully deleted {operation}")
except Exception as e:
if isinstance(e, CliCommandInnerException):
raise CliCommandException(str(e), error_code=1)
raise CliCommandException(f"Error deleting data: {str(e)}", error_code=1)
raise CliCommandException(str(e), error_code=1) from e
raise CliCommandException(f"Error deleting data: {str(e)}", error_code=1) from e

View file

@ -31,10 +31,6 @@ Search Types & Use Cases:
Traditional RAG using document chunks without graph structure.
Best for: Direct document retrieval, specific fact-finding.
**INSIGHTS**:
Structured entity relationships and semantic connections.
Best for: Understanding concept relationships, knowledge mapping.
**CHUNKS**:
Raw text segments that match the query semantically.
Best for: Finding specific passages, citations, exact content.
@ -108,7 +104,7 @@ Search Types & Use Cases:
)
return results
except Exception as e:
raise CliCommandInnerException(f"Failed to search: {str(e)}")
raise CliCommandInnerException(f"Failed to search: {str(e)}") from e
results = asyncio.run(run_search())
@ -145,5 +141,5 @@ Search Types & Use Cases:
except Exception as e:
if isinstance(e, CliCommandInnerException):
raise CliCommandException(str(e), error_code=1)
raise CliCommandException(f"Error searching: {str(e)}", error_code=1)
raise CliCommandException(str(e), error_code=1) from e
raise CliCommandException(f"Error searching: {str(e)}", error_code=1) from e

View file

@ -19,7 +19,6 @@ COMMAND_DESCRIPTIONS = {
SEARCH_TYPE_CHOICES = [
"GRAPH_COMPLETION",
"RAG_COMPLETION",
"INSIGHTS",
"CHUNKS",
"SUMMARIES",
"CODE",

View file

@ -12,6 +12,8 @@ from cognee.modules.users.methods import get_user
# for different async tasks, threads and processes
vector_db_config = ContextVar("vector_db_config", default=None)
graph_db_config = ContextVar("graph_db_config", default=None)
soup_crawler_config = ContextVar("soup_crawler_config", default=None)
tavily_config = ContextVar("tavily_config", default=None)
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):

View file

@ -50,26 +50,26 @@ class GraphConfig(BaseSettings):
# Model validator updates graph_filename and path dynamically after class creation based on current database provider
# If no specific graph_filename or path are provided
@pydantic.model_validator(mode="after")
def fill_derived(cls, values):
provider = values.graph_database_provider.lower()
def fill_derived(self):
provider = self.graph_database_provider.lower()
base_config = get_base_config()
# Set default filename if no filename is provided
if not values.graph_filename:
values.graph_filename = f"cognee_graph_{provider}"
if not self.graph_filename:
self.graph_filename = f"cognee_graph_{provider}"
# Handle graph file path
if values.graph_file_path:
if self.graph_file_path:
# Check if absolute path is provided
values.graph_file_path = ensure_absolute_path(
os.path.join(values.graph_file_path, values.graph_filename)
self.graph_file_path = ensure_absolute_path(
os.path.join(self.graph_file_path, self.graph_filename)
)
else:
# Default path
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
values.graph_file_path = os.path.join(databases_directory_path, values.graph_filename)
self.graph_file_path = os.path.join(databases_directory_path, self.graph_filename)
return values
return self
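The validator now uses the pydantic v2 instance form (`mode="after"` with `self`), the same change applied to the relational and vector configs below. A minimal generic sketch of that pattern, with illustrative names not taken from cognee:

```python
import pydantic
from pydantic_settings import BaseSettings


class ExampleConfig(BaseSettings):
    provider: str = "kuzu"
    filename: str = ""

    @pydantic.model_validator(mode="after")
    def fill_derived(self):
        # In mode="after" the validator receives the constructed instance,
        # so derived fields are set on self and self is returned.
        if not self.filename:
            self.filename = f"example_graph_{self.provider.lower()}"
        return self
```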
def to_dict(self) -> dict:
"""

View file

@ -44,16 +44,14 @@ def create_graph_engine(
Parameters:
-----------
- graph_database_provider: The type of graph database provider to use (e.g., neo4j,
falkordb, kuzu).
- graph_database_url: The URL for the graph database instance. Required for neo4j
and falkordb providers.
- graph_database_provider: The type of graph database provider to use (e.g., neo4j, falkor, kuzu).
- graph_database_url: The URL for the graph database instance. Required for neo4j and falkordb providers.
- graph_database_username: The username for authentication with the graph database.
Required for neo4j provider.
- graph_database_password: The password for authentication with the graph database.
Required for neo4j provider.
- graph_database_port: The port number for the graph database connection. Required
for the falkordb provider.
for the falkordb provider
- graph_file_path: The filesystem path to the graph file. Required for the kuzu
provider.
@ -86,21 +84,6 @@ def create_graph_engine(
graph_database_name=graph_database_name or None,
)
elif graph_database_provider == "falkordb":
if not (graph_database_url and graph_database_port):
raise EnvironmentError("Missing required FalkorDB credentials.")
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
embedding_engine = get_embedding_engine()
return FalkorDBAdapter(
database_url=graph_database_url,
database_port=graph_database_port,
embedding_engine=embedding_engine,
)
elif graph_database_provider == "kuzu":
if not graph_file_path:
raise EnvironmentError("Missing required Kuzu database path.")
@ -179,5 +162,5 @@ def create_graph_engine(
raise EnvironmentError(
f"Unsupported graph database provider: {graph_database_provider}. "
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'falkordb', 'kuzu', 'kuzu-remote', 'memgraph', 'neptune', 'neptune_analytics'])}"
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'kuzu', 'kuzu-remote', 'memgraph', 'neptune', 'neptune_analytics'])}"
)

View file

@ -48,6 +48,29 @@ class KuzuAdapter(GraphDBInterface):
def _initialize_connection(self) -> None:
"""Initialize the Kuzu database connection and schema."""
def _install_json_extension():
"""
Install the JSON extension for the current Kuzu version.
This has to be done against an empty graph database before connecting to an existing one,
otherwise missing JSON extension errors will be raised.
"""
try:
with tempfile.NamedTemporaryFile(mode="w", delete=True) as temp_file:
temp_graph_file = temp_file.name
tmp_db = Database(
temp_graph_file,
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
max_db_size=4096 * 1024 * 1024,
)
tmp_db.init_database()
connection = Connection(tmp_db)
connection.execute("INSTALL JSON;")
except Exception as e:
logger.info(f"JSON extension already installed or not needed: {e}")
_install_json_extension()
try:
if "s3://" in self.db_path:
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
@ -109,11 +132,6 @@ class KuzuAdapter(GraphDBInterface):
self.db.init_database()
self.connection = Connection(self.db)
try:
self.connection.execute("INSTALL JSON;")
except Exception as e:
logger.info(f"JSON extension already installed or not needed: {e}")
try:
self.connection.execute("LOAD EXTENSION JSON;")
logger.info("Loaded JSON extension")

View file

@ -68,6 +68,7 @@ class Neo4jAdapter(GraphDBInterface):
auth=auth,
max_connection_lifetime=120,
notifications_min_severity="OFF",
keep_alive=True,
)
async def initialize(self) -> None:
@ -205,7 +206,7 @@ class Neo4jAdapter(GraphDBInterface):
{
"node_id": str(node.id),
"label": type(node).__name__,
"properties": self.serialize_properties(node.model_dump()),
"properties": self.serialize_properties(dict(node)),
}
for node in nodes
]

View file

@ -8,7 +8,7 @@ from cognee.infrastructure.utils.calculate_backoff import calculate_backoff
logger = get_logger("deadlock_retry")
def deadlock_retry(max_retries=5):
def deadlock_retry(max_retries=10):
"""
Decorator that automatically retries an asynchronous function when database deadlock errors occur.

View file

@ -53,7 +53,7 @@ def parse_neptune_url(url: str) -> Tuple[str, str]:
return graph_id, region
except Exception as e:
raise ValueError(f"Failed to parse Neptune Analytics URL '{url}': {str(e)}")
raise ValueError(f"Failed to parse Neptune Analytics URL '{url}': {str(e)}") from e
def validate_graph_id(graph_id: str) -> bool:

View file

@ -23,14 +23,14 @@ class RelationalConfig(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@pydantic.model_validator(mode="after")
def fill_derived(cls, values):
def fill_derived(self):
# Set file path based on graph database provider if no file path is provided
if not values.db_path:
if not self.db_path:
base_config = get_base_config()
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
values.db_path = databases_directory_path
self.db_path = databases_directory_path
return values
return self
def to_dict(self) -> dict:
"""

View file

@ -283,7 +283,7 @@ class SQLAlchemyAdapter:
try:
data_entity = (await session.scalars(select(Data).where(Data.id == data_id))).one()
except (ValueError, NoResultFound) as e:
raise EntityNotFoundError(message=f"Entity not found: {str(e)}")
raise EntityNotFoundError(message=f"Entity not found: {str(e)}") from e
# Check if other data objects point to the same raw data location
raw_data_location_entities = (

View file

@ -30,21 +30,21 @@ class VectorConfig(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@pydantic.model_validator(mode="after")
def validate_paths(cls, values):
def validate_paths(self):
base_config = get_base_config()
# If vector_db_url is provided and is not a path skip checking if path is absolute (as it can also be a url)
if values.vector_db_url and Path(values.vector_db_url).exists():
if self.vector_db_url and Path(self.vector_db_url).exists():
# Relative path to absolute
values.vector_db_url = ensure_absolute_path(
values.vector_db_url,
self.vector_db_url = ensure_absolute_path(
self.vector_db_url,
)
elif not values.vector_db_url:
elif not self.vector_db_url:
# Default path
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
values.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
self.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
return values
return self
def to_dict(self) -> dict:
"""

View file

@ -19,8 +19,7 @@ def create_vector_engine(
for each provider, raising an EnvironmentError if any are missing, or ImportError if the
ChromaDB package is not installed.
Supported providers include: pgvector, FalkorDB, ChromaDB, and
LanceDB.
Supported providers include: pgvector, ChromaDB, and LanceDB.
Parameters:
-----------
@ -79,18 +78,6 @@ def create_vector_engine(
embedding_engine,
)
elif vector_db_provider == "falkordb":
if not (vector_db_url and vector_db_port):
raise EnvironmentError("Missing requred FalkorDB credentials!")
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
return FalkorDBAdapter(
database_url=vector_db_url,
database_port=vector_db_port,
embedding_engine=embedding_engine,
)
elif vector_db_provider == "chromadb":
try:
import chromadb

View file

@ -34,3 +34,12 @@ class EmbeddingEngine(Protocol):
- int: An integer representing the number of dimensions in the embedding vector.
"""
raise NotImplementedError()
def get_batch_size(self) -> int:
"""
Return the desired batch size for embedding calls.

Returns:
--------
- int: The number of texts to embed in a single batch.
"""
raise NotImplementedError()

View file

@ -42,11 +42,13 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
model: Optional[str] = "openai/text-embedding-3-large",
dimensions: Optional[int] = 3072,
max_completion_tokens: int = 512,
batch_size: int = 100,
):
self.model = model
self.dimensions = dimensions
self.max_completion_tokens = max_completion_tokens
self.tokenizer = self.get_tokenizer()
self.batch_size = batch_size
# self.retry_count = 0
self.embedding_model = TextEmbedding(model_name=model)
@ -88,7 +90,9 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
except Exception as error:
logger.error(f"Embedding error in FastembedEmbeddingEngine: {str(error)}")
raise EmbeddingException(f"Failed to index data points using model {self.model}")
raise EmbeddingException(
f"Failed to index data points using model {self.model}"
) from error
def get_vector_size(self) -> int:
"""
@ -101,6 +105,15 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
"""
return self.dimensions
def get_batch_size(self) -> int:
"""
Return the desired batch size for embedding calls.

Returns:
--------
- int: The configured embedding batch size.
"""
return self.batch_size
def get_tokenizer(self):
"""
Instantiate and return the tokenizer used for preparing text for embedding.

View file

@ -58,6 +58,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
endpoint: str = None,
api_version: str = None,
max_completion_tokens: int = 512,
batch_size: int = 100,
):
self.api_key = api_key
self.endpoint = endpoint
@ -68,6 +69,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
self.max_completion_tokens = max_completion_tokens
self.tokenizer = self.get_tokenizer()
self.retry_count = 0
self.batch_size = batch_size
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
if isinstance(enable_mocking, bool):
@ -148,7 +150,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
litellm.exceptions.NotFoundError,
) as e:
logger.error(f"Embedding error with model {self.model}: {str(e)}")
raise EmbeddingException(f"Failed to index data points using model {self.model}")
raise EmbeddingException(f"Failed to index data points using model {self.model}") from e
except Exception as error:
logger.error("Error embedding text: %s", str(error))
@ -165,6 +167,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
"""
return self.dimensions
def get_batch_size(self) -> int:
"""
Return the desired batch size for embedding calls.

Returns:
--------
- int: The configured embedding batch size.
"""
return self.batch_size
def get_tokenizer(self):
"""
Load and return the appropriate tokenizer for the specified model based on the provider.
@ -183,9 +194,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
model=model, max_completion_tokens=self.max_completion_tokens
)
elif "gemini" in self.provider.lower():
tokenizer = GeminiTokenizer(
model=model, max_completion_tokens=self.max_completion_tokens
# Since Gemini tokenization requires an API request to get a token count, we use TikToken
# here to approximate the token count locally instead
tokenizer = TikTokenTokenizer(
model=None, max_completion_tokens=self.max_completion_tokens
)
# Note: Gemini Tokenizer expects an LLM model as input and not the embedding model
# tokenizer = GeminiTokenizer(
# llm_model=llm_model, max_completion_tokens=self.max_completion_tokens
# )
elif "mistral" in self.provider.lower():
tokenizer = MistralTokenizer(
model=model, max_completion_tokens=self.max_completion_tokens

View file

@ -54,12 +54,14 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
max_completion_tokens: int = 512,
endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
batch_size: int = 100,
):
self.model = model
self.dimensions = dimensions
self.max_completion_tokens = max_completion_tokens
self.endpoint = endpoint
self.huggingface_tokenizer_name = huggingface_tokenizer
self.batch_size = batch_size
self.tokenizer = self.get_tokenizer()
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
@ -122,6 +124,15 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
"""
return self.dimensions
def get_batch_size(self) -> int:
"""
Return the desired batch size for embedding calls.

Returns:
--------
- int: The configured embedding batch size.
"""
return self.batch_size
def get_tokenizer(self):
"""
Load and return a HuggingFace tokenizer for the embedding engine.

View file

@ -19,9 +19,17 @@ class EmbeddingConfig(BaseSettings):
embedding_api_key: Optional[str] = None
embedding_api_version: Optional[str] = None
embedding_max_completion_tokens: Optional[int] = 8191
embedding_batch_size: Optional[int] = None
huggingface_tokenizer: Optional[str] = None
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def model_post_init(self, __context) -> None:
# If embedding batch size is not defined use 2048 as default for OpenAI and 100 for all other embedding models
if not self.embedding_batch_size and self.embedding_provider.lower() == "openai":
self.embedding_batch_size = 2048
elif not self.embedding_batch_size:
self.embedding_batch_size = 100
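Illustration of how the defaulting above is expected to resolve, assuming the remaining `EmbeddingConfig` fields have defaults or are supplied through the environment:

```python
cfg = EmbeddingConfig(embedding_provider="openai")
assert cfg.embedding_batch_size == 2048  # OpenAI default

cfg = EmbeddingConfig(embedding_provider="ollama")
assert cfg.embedding_batch_size == 100  # default for every other provider

cfg = EmbeddingConfig(embedding_provider="openai", embedding_batch_size=512)
assert cfg.embedding_batch_size == 512  # an explicit value is never overridden
```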
def to_dict(self) -> dict:
"""
Serialize all embedding configuration settings to a dictionary.

View file

@ -31,6 +31,7 @@ def get_embedding_engine() -> EmbeddingEngine:
config.embedding_endpoint,
config.embedding_api_key,
config.embedding_api_version,
config.embedding_batch_size,
config.huggingface_tokenizer,
llm_config.llm_api_key,
llm_config.llm_provider,
@ -46,6 +47,7 @@ def create_embedding_engine(
embedding_endpoint,
embedding_api_key,
embedding_api_version,
embedding_batch_size,
huggingface_tokenizer,
llm_api_key,
llm_provider,
@ -84,6 +86,7 @@ def create_embedding_engine(
model=embedding_model,
dimensions=embedding_dimensions,
max_completion_tokens=embedding_max_completion_tokens,
batch_size=embedding_batch_size,
)
if embedding_provider == "ollama":
@ -95,6 +98,7 @@ def create_embedding_engine(
max_completion_tokens=embedding_max_completion_tokens,
endpoint=embedding_endpoint,
huggingface_tokenizer=huggingface_tokenizer,
batch_size=embedding_batch_size,
)
from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
@ -108,4 +112,5 @@ def create_embedding_engine(
model=embedding_model,
dimensions=embedding_dimensions,
max_completion_tokens=embedding_max_completion_tokens,
batch_size=embedding_batch_size,
)

View file

@ -125,41 +125,42 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
data_point_types = get_type_hints(DataPoint)
vector_size = self.embedding_engine.get_vector_size()
async with self.VECTOR_DB_LOCK:
if not await self.has_collection(collection_name):
if not await self.has_collection(collection_name):
async with self.VECTOR_DB_LOCK:
if not await self.has_collection(collection_name):
class PGVectorDataPoint(Base):
"""
Represent a point in a vector data space with associated data and vector representation.
class PGVectorDataPoint(Base):
"""
Represent a point in a vector data space with associated data and vector representation.
This class inherits from Base and is associated with a database table defined by
__tablename__. It maintains the following public methods and instance variables:
This class inherits from Base and is associated with a database table defined by
__tablename__. It maintains the following public methods and instance variables:
- __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance.
- __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance.
Instance variables:
- id: Identifier for the data point, defined by data_point_types.
- payload: JSON data associated with the data point.
- vector: Vector representation of the data point, with size defined by vector_size.
"""
Instance variables:
- id: Identifier for the data point, defined by data_point_types.
- payload: JSON data associated with the data point.
- vector: Vector representation of the data point, with size defined by vector_size.
"""
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
def __init__(self, id, payload, vector):
self.id = id
self.payload = payload
self.vector = vector
def __init__(self, id, payload, vector):
self.id = id
self.payload = payload
self.vector = vector
async with self.engine.begin() as connection:
if len(Base.metadata.tables.keys()) > 0:
await connection.run_sync(
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
)
async with self.engine.begin() as connection:
if len(Base.metadata.tables.keys()) > 0:
await connection.run_sync(
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
)
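The reshaped block above follows a double-checked locking pattern: a lock-free existence check, then a re-check after acquiring the lock so concurrent callers cannot both create the collection. A generic sketch of the same idea, with illustrative names only:

```python
import asyncio

_collections_lock = asyncio.Lock()
_created_collections: set = set()


async def ensure_collection(name: str) -> None:
    if name in _created_collections:  # fast path, no lock taken
        return
    async with _collections_lock:  # slow path, serialize creation
        if name not in _created_collections:  # re-check after acquiring the lock
            # ... create the table / collection here ...
            _created_collections.add(name)
```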
@retry(
retry=retry_if_exception_type(DeadlockDetectedError),

View file

@ -39,7 +39,7 @@ class LLMConfig(BaseSettings):
structured_output_framework: str = "instructor"
llm_provider: str = "openai"
llm_model: str = "openai/gpt-4o-mini"
llm_model: str = "openai/gpt-5-mini"
llm_endpoint: str = ""
llm_api_key: Optional[str] = None
llm_api_version: Optional[str] = None
@ -48,7 +48,7 @@ class LLMConfig(BaseSettings):
llm_max_completion_tokens: int = 16384
baml_llm_provider: str = "openai"
baml_llm_model: str = "gpt-4o-mini"
baml_llm_model: str = "gpt-5-mini"
baml_llm_endpoint: str = ""
baml_llm_api_key: Optional[str] = None
baml_llm_temperature: float = 0.0

View file

@ -10,8 +10,6 @@ Here are the available `SearchType` tools and their specific functions:
- Summarizing large amounts of information
- Quick understanding of complex subjects
* **`INSIGHTS`**: The `INSIGHTS` search type discovers connections and relationships between entities in the knowledge graph.
**Best for:**
- Discovering how entities are connected
@ -95,9 +93,6 @@ Here are the available `SearchType` tools and their specific functions:
Query: "Summarize the key findings from these research papers"
Response: `SUMMARIES`
Query: "What is the relationship between the methodologies used in these papers?"
Response: `INSIGHTS`
Query: "When was Einstein born?"
Response: `CHUNKS`

View file

@ -1,115 +1,155 @@
import litellm
from pydantic import BaseModel
from typing import Type, Optional
from litellm import acompletion, JSONSchemaValidationError
"""Adapter for Generic API LLM provider API"""
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
import litellm
import instructor
from typing import Type
from pydantic import BaseModel
from openai import ContentFilterFinishReasonError
from litellm.exceptions import ContentPolicyViolationError
from instructor.core import InstructorRetryException
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async,
sleep_and_retry_async,
)
logger = get_logger()
observe = get_observe()
class GeminiAdapter(LLMInterface):
"""
Handles interactions with a language model API.
Adapter for Gemini API LLM provider.
Public methods include:
- acreate_structured_output
- show_prompt
This class initializes the API adapter with necessary credentials and configurations for
interacting with the gemini LLM models. It provides methods for creating structured outputs
based on user input and system prompts.
Public methods:
- acreate_structured_output(text_input: str, system_prompt: str, response_model:
Type[BaseModel]) -> BaseModel
"""
MAX_RETRIES = 5
name: str
model: str
api_key: str
def __init__(
self,
endpoint,
api_key: str,
model: str,
api_version: str,
max_completion_tokens: int,
endpoint: Optional[str] = None,
api_version: Optional[str] = None,
streaming: bool = False,
) -> None:
self.api_key = api_key
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
):
self.model = model
self.api_key = api_key
self.endpoint = endpoint
self.api_version = api_version
self.streaming = streaming
self.max_completion_tokens = max_completion_tokens
@observe(as_type="generation")
self.fallback_model = fallback_model
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
@sleep_and_retry_async()
@rate_limit_async
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
"""
Generate structured output from the language model based on the provided input and
system prompt.
Generate a response from a user query.
This method handles retries and raises a ValueError if the request fails or the response
does not conform to the expected schema, logging errors accordingly.
This asynchronous method sends a user query and a system prompt to a language model and
retrieves the generated response. It handles API communication and retries up to a
specified limit in case of request failures.
Parameters:
-----------
- text_input (str): The user input text to generate a response for.
- system_prompt (str): The system's prompt or context to influence the language
model's generation.
- response_model (Type[BaseModel]): A model type indicating the expected format of
the response.
- text_input (str): The input text from the user to generate a response for.
- system_prompt (str): A prompt that provides context or instructions for the
response generation.
- response_model (Type[BaseModel]): A Pydantic model that defines the structure of
the expected response.
Returns:
--------
- BaseModel: Returns the generated response as an instance of the specified response
model.
- BaseModel: An instance of the specified response model containing the structured
output from the language model.
"""
try:
if response_model is str:
response_schema = {"type": "string"}
else:
response_schema = response_model
return await self.aclient.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
api_key=self.api_key,
max_retries=5,
api_base=self.endpoint,
api_version=self.api_version,
response_model=response_model,
)
except (
ContentFilterFinishReasonError,
ContentPolicyViolationError,
InstructorRetryException,
) as error:
if (
isinstance(error, InstructorRetryException)
and "content management policy" not in str(error).lower()
):
raise error
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": text_input},
]
try:
response = await acompletion(
model=f"{self.model}",
messages=messages,
api_key=self.api_key,
max_completion_tokens=self.max_completion_tokens,
temperature=0.1,
response_format=response_schema,
timeout=100,
num_retries=self.MAX_RETRIES,
if not (self.fallback_model and self.fallback_api_key and self.fallback_endpoint):
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
) from error
if response.choices and response.choices[0].message.content:
content = response.choices[0].message.content
if response_model is str:
return content
return response_model.model_validate_json(content)
except litellm.exceptions.BadRequestError as e:
logger.error(f"Bad request error: {str(e)}")
raise ValueError(f"Invalid request: {str(e)}")
raise ValueError("Failed to get valid response after retries")
except JSONSchemaValidationError as e:
logger.error(f"Schema validation failed: {str(e)}")
logger.debug(f"Raw response: {e.raw_response}")
raise ValueError(f"Response failed schema validation: {str(e)}")
try:
return await self.aclient.chat.completions.create(
model=self.fallback_model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=5,
api_key=self.fallback_api_key,
api_base=self.fallback_endpoint,
response_model=response_model,
)
except (
ContentFilterFinishReasonError,
ContentPolicyViolationError,
InstructorRetryException,
) as error:
if (
isinstance(error, InstructorRetryException)
and "content management policy" not in str(error).lower()
):
raise error
else:
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
) from error

View file

@ -6,7 +6,7 @@ from typing import Type
from pydantic import BaseModel
from openai import ContentFilterFinishReasonError
from litellm.exceptions import ContentPolicyViolationError
from instructor.exceptions import InstructorRetryException
from instructor.core import InstructorRetryException
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
@ -56,9 +56,7 @@ class GenericAPIAdapter(LLMInterface):
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode.JSON, api_key=api_key
)
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
@sleep_and_retry_async()
@rate_limit_async
@ -102,6 +100,7 @@ class GenericAPIAdapter(LLMInterface):
},
],
max_retries=5,
api_key=self.api_key,
api_base=self.endpoint,
response_model=response_model,
)
@ -119,7 +118,7 @@ class GenericAPIAdapter(LLMInterface):
if not (self.fallback_model and self.fallback_api_key and self.fallback_endpoint):
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)
) from error
try:
return await self.aclient.chat.completions.create(
@ -152,4 +151,4 @@ class GenericAPIAdapter(LLMInterface):
else:
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)
) from error

View file

@ -23,6 +23,7 @@ class LLMProvider(Enum):
- ANTHROPIC: Represents the Anthropic provider.
- CUSTOM: Represents a custom provider option.
- GEMINI: Represents the Gemini provider.
- MISTRAL: Represents the Mistral AI provider.
"""
OPENAI = "openai"
@ -30,6 +31,7 @@ class LLMProvider(Enum):
ANTHROPIC = "anthropic"
CUSTOM = "custom"
GEMINI = "gemini"
MISTRAL = "mistral"
def get_llm_client(raise_api_key_error: bool = True):
@ -143,7 +145,36 @@ def get_llm_client(raise_api_key_error: bool = True):
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
api_version=llm_config.llm_api_version,
streaming=llm_config.llm_streaming,
)
elif provider == LLMProvider.MISTRAL:
if llm_config.llm_api_key is None:
raise LLMAPIKeyNotSetError()
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
MistralAdapter,
)
return MistralAdapter(
api_key=llm_config.llm_api_key,
model=llm_config.llm_model,
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
)
else:

View file

@ -0,0 +1,129 @@
import litellm
import instructor
from pydantic import BaseModel
from typing import Type, Optional
from litellm import acompletion, JSONSchemaValidationError
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async,
sleep_and_retry_async,
)
logger = get_logger()
observe = get_observe()
class MistralAdapter(LLMInterface):
"""
Adapter for Mistral AI API, for structured output generation and prompt display.
Public methods:
- acreate_structured_output
- show_prompt
"""
name = "Mistral"
model: str
api_key: str
max_completion_tokens: int
def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None):
# Imported only to fail fast if the optional mistralai package is missing
from mistralai import Mistral  # noqa: F401
self.api_key = api_key
self.model = model
self.endpoint = endpoint
self.max_completion_tokens = max_completion_tokens
self.aclient = instructor.from_litellm(
litellm.acompletion,
mode=instructor.Mode.MISTRAL_TOOLS,
api_key=get_llm_config().llm_api_key,
)
@sleep_and_retry_async()
@rate_limit_async
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
"""
Generate a response from the user query.
Parameters:
-----------
- text_input (str): The input text from the user to be processed.
- system_prompt (str): A prompt that sets the context for the query.
- response_model (Type[BaseModel]): The model to structure the response according to
its format.
Returns:
--------
- BaseModel: An instance of BaseModel containing the structured response.
"""
        try:
            messages = [
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {
                    "role": "user",
                    "content": f"""Use the given format to extract information
                    from the following input: {text_input}""",
                },
            ]
            response = await self.aclient.chat.completions.create(
                model=self.model,
                max_tokens=self.max_completion_tokens,
                max_retries=5,
                messages=messages,
                response_model=response_model,
            )
            # Instructor validates the completion and returns a response_model instance directly.
            if response is None:
                raise ValueError("Failed to get valid response after retries")
            return response
        except litellm.exceptions.BadRequestError as e:
            logger.error(f"Bad request error: {str(e)}")
            raise ValueError(f"Invalid request: {str(e)}") from e
        except JSONSchemaValidationError as e:
            logger.error(f"Schema validation failed: {str(e)}")
            logger.debug(f"Raw response: {e.raw_response}")
            raise ValueError(f"Response failed schema validation: {str(e)}") from e
def show_prompt(self, text_input: str, system_prompt: str) -> str:
"""
Format and display the prompt for a user query.
Parameters:
-----------
- text_input (str): Input text from the user to be included in the prompt.
- system_prompt (str): The system prompt that will be shown alongside the user input.
Returns:
--------
- str: The formatted prompt string combining system prompt and user input.
"""
if not text_input:
text_input = "No user input provided."
if not system_prompt:
raise MissingSystemPromptPathError()
system_prompt = LLMGateway.read_query_prompt(system_prompt)
formatted_prompt = (
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
if system_prompt
else None
)
return formatted_prompt
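A hedged usage sketch of the new adapter (not part of the diff): it assumes LLM credentials are configured, that the litellm model string is routable (the "mistral/" prefix here is an assumption), and it uses a hypothetical ExtractedFact response model.

import asyncio
from pydantic import BaseModel

from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
    MistralAdapter,
)

class ExtractedFact(BaseModel):
    subject: str
    predicate: str
    object: str

async def main():
    # Assumption: the key matches LLM_API_KEY and the model id is routable by litellm.
    adapter = MistralAdapter(
        api_key="your_api_key",
        model="mistral/mistral-large-2411",
        max_completion_tokens=2048,
    )
    fact = await adapter.acreate_structured_output(
        text_input="Berlin is the capital of Germany.",
        system_prompt="Extract one subject-predicate-object fact from the input.",
        response_model=ExtractedFact,
    )
    print(fact.model_dump())

asyncio.run(main())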

View file

@ -5,15 +5,13 @@ from typing import Type
from pydantic import BaseModel
from openai import ContentFilterFinishReasonError
from litellm.exceptions import ContentPolicyViolationError
from instructor.exceptions import InstructorRetryException
from instructor.core import InstructorRetryException
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.exceptions import (
ContentPolicyFilterError,
MissingSystemPromptPathError,
)
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
@ -148,11 +146,11 @@ class OpenAIAdapter(LLMInterface):
ContentFilterFinishReasonError,
ContentPolicyViolationError,
InstructorRetryException,
):
) as e:
if not (self.fallback_model and self.fallback_api_key):
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)
) from e
try:
return await self.aclient.chat.completions.create(
@ -185,7 +183,7 @@ class OpenAIAdapter(LLMInterface):
else:
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)
) from error
@observe
@sleep_and_retry_sync()
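A minimal, self-contained illustration of the `raise ... from e` pattern this changeset adopts (not part of the diff): chaining keeps the upstream exception attached as __cause__, so the original traceback is preserved. The local ContentPolicyFilterError below is a stand-in, not the cognee class.

class ContentPolicyFilterError(Exception):
    pass

def check(text: str) -> None:
    try:
        raise ValueError("flagged by provider")  # stand-in for the upstream failure
    except ValueError as error:
        raise ContentPolicyFilterError(f"Input rejected: {text}") from error

try:
    check("some input")
except ContentPolicyFilterError as err:
    print(type(err.__cause__).__name__)  # ValueError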

View file

@ -3,6 +3,7 @@ from typing import List, Any
from ..tokenizer_interface import TokenizerInterface
# NOTE: DEPRECATED. Counting tokens requires an API request to Google, which is too slow to use with Cognee.
class GeminiTokenizer(TokenizerInterface):
"""
Implements a tokenizer interface for the Gemini model, managing token extraction and
@ -16,10 +17,10 @@ class GeminiTokenizer(TokenizerInterface):
def __init__(
self,
model: str,
llm_model: str,
max_completion_tokens: int = 3072,
):
self.model = model
self.llm_model = llm_model
self.max_completion_tokens = max_completion_tokens
# Get LLM API key from config
@ -28,12 +29,11 @@ class GeminiTokenizer(TokenizerInterface):
get_llm_config,
)
config = get_embedding_config()
llm_config = get_llm_config()
import google.generativeai as genai
from google import genai
genai.configure(api_key=config.embedding_api_key or llm_config.llm_api_key)
self.client = genai.Client(api_key=llm_config.llm_api_key)
def extract_tokens(self, text: str) -> List[Any]:
"""
@ -77,6 +77,7 @@ class GeminiTokenizer(TokenizerInterface):
- int: The number of tokens in the given text.
"""
import google.generativeai as genai
return len(genai.embed_content(model=f"models/{self.model}", content=text))
tokens_response = self.client.models.count_tokens(model=self.llm_model, contents=text)
return tokens_response.total_tokens
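A hedged sketch of the underlying google-genai call the tokenizer now wraps (not part of the diff); the API key and model id are placeholders.

from google import genai

client = genai.Client(api_key="your_api_key")
response = client.models.count_tokens(
    model="gemini-2.0-flash",  # assumption: any Gemini model id accepted by count_tokens
    contents="The quick brown fox jumps over the lazy dog.",
)
print(response.total_tokens)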

View file

@ -27,11 +27,11 @@ class LoaderEngine:
self.default_loader_priority = [
"text_loader",
"advanced_pdf_loader",
"pypdf_loader",
"image_loader",
"audio_loader",
"unstructured_loader",
"advanced_pdf_loader",
]
def register_loader(self, loader: LoaderInterface) -> bool:

View file

@ -9,7 +9,10 @@ async def get_dataset_data(dataset_id: UUID) -> list[Data]:
async with db_engine.get_async_session() as session:
result = await session.execute(
select(Data).join(Data.datasets).filter((Dataset.id == dataset_id))
select(Data)
.join(Data.datasets)
.filter((Dataset.id == dataset_id))
.order_by(Data.data_size.desc())
)
data = list(result.scalars().all())

View file

@ -0,0 +1,92 @@
from uuid import UUID
from cognee.cli.exceptions import CliCommandException
from cognee.infrastructure.databases.exceptions.exceptions import EntityNotFoundError
from sqlalchemy import select
from sqlalchemy.sql import func
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Dataset, Data, DatasetData
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_user
from dataclasses import dataclass
@dataclass
class DeletionCountsPreview:
datasets: int = 0
data_entries: int = 0
users: int = 0
async def get_deletion_counts(
dataset_name: str = None, user_id: str = None, all_data: bool = False
) -> DeletionCountsPreview:
"""
Calculates the number of items that will be deleted based on the provided arguments.
"""
counts = DeletionCountsPreview()
relational_engine = get_relational_engine()
async with relational_engine.get_async_session() as session:
if dataset_name:
# Find the dataset by name
dataset_result = await session.execute(
select(Dataset).where(Dataset.name == dataset_name)
)
dataset = dataset_result.scalar_one_or_none()
if dataset is None:
raise CliCommandException(
f"No Dataset exists with the name {dataset_name}", error_code=1
)
# Count data entries linked to this dataset
count_query = (
select(func.count())
.select_from(DatasetData)
.where(DatasetData.dataset_id == dataset.id)
)
data_entry_count = (await session.execute(count_query)).scalar_one()
counts.users = 1
counts.datasets = 1
            counts.data_entries = data_entry_count
return counts
elif all_data:
# Simplified logic: Get total counts directly from the tables.
counts.datasets = (
await session.execute(select(func.count()).select_from(Dataset))
).scalar_one()
            counts.data_entries = (
                await session.execute(select(func.count()).select_from(Data))
            ).scalar_one()
counts.users = (
await session.execute(select(func.count()).select_from(User))
).scalar_one()
return counts
        # Count the user, their datasets, and the data entries across those datasets
elif user_id:
user = None
try:
user_uuid = UUID(user_id)
user = await get_user(user_uuid)
except (ValueError, EntityNotFoundError):
raise CliCommandException(f"No User exists with ID {user_id}", error_code=1)
counts.users = 1
# Find all datasets owned by this user
datasets_query = select(Dataset).where(Dataset.owner_id == user.id)
user_datasets = (await session.execute(datasets_query)).scalars().all()
dataset_count = len(user_datasets)
counts.datasets = dataset_count
if dataset_count > 0:
dataset_ids = [d.id for d in user_datasets]
# Count all data entries across all of the user's datasets
data_count_query = (
select(func.count())
.select_from(DatasetData)
.where(DatasetData.dataset_id.in_(dataset_ids))
)
data_entry_count = (await session.execute(data_count_query)).scalar_one()
                counts.data_entries = data_entry_count
else:
                counts.data_entries = 0
return counts
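A hedged usage sketch (not part of the diff): the import path is an assumption about where this new helper lives, and it presumes the relational database has already been initialised.

import asyncio

# Assumed import path for the new helper; adjust to the actual module location.
from cognee.cli.get_deletion_counts import get_deletion_counts

async def preview():
    counts = await get_deletion_counts(dataset_name="my_dataset")
    print(f"{counts.datasets} dataset(s), {counts.data_entries} data entries, {counts.users} user(s)")

asyncio.run(preview())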

View file

@ -5,7 +5,6 @@ from typing import Optional
class TableRow(DataPoint):
name: str
is_a: Optional[TableType] = None
description: str
properties: str

View file

@ -1,47 +1,31 @@
from uuid import UUID
from typing import Optional
from cognee.api.v1.exceptions import DatasetNotFoundError
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.data.methods import (
create_authorized_dataset,
get_authorized_dataset,
get_authorized_dataset_by_name,
from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
resolve_authorized_user_datasets,
)
async def resolve_authorized_user_dataset(dataset_id: UUID, dataset_name: str, user: User):
async def resolve_authorized_user_dataset(
dataset_name: str, dataset_id: Optional[UUID] = None, user: Optional[User] = None
):
"""
    Handles dataset creation, and dataset authorization if the dataset already exists in Cognee.
    Verifies that the provided user has the necessary permission for the provided Dataset.
    If the Dataset does not exist, creates it and grants permission to the user creating the dataset.
Args:
dataset_id: Id of the dataset.
dataset_name: Name of the dataset.
dataset_id: Id of the dataset.
user: Cognee User request is being processed for, if None default user will be used.
Returns:
Tuple[User, Dataset]: A tuple containing the user and the authorized dataset.
"""
if not user:
user = await get_default_user()
if dataset_id:
authorized_dataset = await get_authorized_dataset(user, dataset_id, "write")
elif dataset_name:
authorized_dataset = await get_authorized_dataset_by_name(dataset_name, user, "write")
user, authorized_datasets = await resolve_authorized_user_datasets(
datasets=dataset_id if dataset_id else dataset_name, user=user
)
if not authorized_dataset:
authorized_dataset = await create_authorized_dataset(
dataset_name=dataset_name, user=user
)
else:
raise ValueError("Either dataset_id or dataset_name must be provided.")
if not authorized_dataset:
raise DatasetNotFoundError(
message=f"Dataset ({str(dataset_id) or dataset_name}) not found."
)
return user, authorized_dataset
return user, authorized_datasets[0]
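A hedged usage sketch of the slimmed-down wrapper (not part of the diff): the import path is assumed from the surrounding module layout, and the default user is used when no user is passed.

import asyncio

# Assumed import path; the function now delegates to resolve_authorized_user_datasets.
from cognee.modules.pipelines.layers.resolve_authorized_user_dataset import (
    resolve_authorized_user_dataset,
)

async def main():
    user, dataset = await resolve_authorized_user_dataset(dataset_name="my_dataset")
    print(user.id, dataset.id, dataset.name)

asyncio.run(main())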

View file

@ -1,5 +1,5 @@
from uuid import UUID
from typing import Union, Tuple, List
from typing import Union, Tuple, List, Optional
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
@ -13,7 +13,7 @@ from cognee.modules.data.methods import (
async def resolve_authorized_user_datasets(
datasets: Union[str, UUID, list[str], list[UUID]], user: User = None
datasets: Union[str, UUID, list[str], list[UUID]], user: Optional[User] = None
) -> Tuple[User, List[Dataset]]:
"""
Function handles creation and dataset authorization if datasets already exist for Cognee.

View file

@ -4,35 +4,28 @@ import asyncio
from uuid import UUID
from typing import Any, List
from functools import wraps
from sqlalchemy import select
import cognee.modules.ingestion as ingestion
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.pipelines.operations.run_tasks_distributed import run_tasks_distributed
from cognee.modules.users.models import User
from cognee.modules.data.models import Data
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines.utils import generate_pipeline_id
from cognee.modules.pipelines.exceptions import PipelineRunFailedError
from cognee.tasks.ingestion import save_data_item_to_storage, resolve_data_directories
from cognee.tasks.ingestion import resolve_data_directories
from cognee.modules.pipelines.models.PipelineRunInfo import (
PipelineRunCompleted,
PipelineRunErrored,
PipelineRunStarted,
PipelineRunYield,
PipelineRunAlreadyCompleted,
)
from cognee.modules.pipelines.models.DataItemStatus import DataItemStatus
from cognee.modules.pipelines.operations import (
log_pipeline_run_start,
log_pipeline_run_complete,
log_pipeline_run_error,
)
from .run_tasks_with_telemetry import run_tasks_with_telemetry
from .run_tasks_data_item import run_tasks_data_item
from ..tasks.task import Task
@ -68,176 +61,6 @@ async def run_tasks(
context: dict = None,
incremental_loading: bool = False,
):
async def _run_tasks_data_item_incremental(
data_item,
dataset,
tasks,
pipeline_name,
pipeline_id,
pipeline_run_id,
context,
user,
):
db_engine = get_relational_engine()
# If incremental_loading of data is set to True don't process documents already processed by pipeline
# If data is being added to Cognee for the first time calculate the id of the data
if not isinstance(data_item, Data):
file_path = await save_data_item_to_storage(data_item)
# Ingest data and add metadata
async with open_data_file(file_path) as file:
classified_data = ingestion.classify(file)
# data_id is the hash of file contents + owner id to avoid duplicate data
data_id = ingestion.identify(classified_data, user)
else:
# If data was already processed by Cognee get data id
data_id = data_item.id
# Check pipeline status, if Data already processed for pipeline before skip current processing
async with db_engine.get_async_session() as session:
data_point = (
await session.execute(select(Data).filter(Data.id == data_id))
).scalar_one_or_none()
if data_point:
if (
data_point.pipeline_status.get(pipeline_name, {}).get(str(dataset.id))
== DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
):
yield {
"run_info": PipelineRunAlreadyCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
),
"data_id": data_id,
}
return
try:
# Process data based on data_item and list of tasks
async for result in run_tasks_with_telemetry(
tasks=tasks,
data=[data_item],
user=user,
pipeline_name=pipeline_id,
context=context,
):
yield PipelineRunYield(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
payload=result,
)
# Update pipeline status for Data element
async with db_engine.get_async_session() as session:
data_point = (
await session.execute(select(Data).filter(Data.id == data_id))
).scalar_one_or_none()
data_point.pipeline_status[pipeline_name] = {
str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
}
await session.merge(data_point)
await session.commit()
yield {
"run_info": PipelineRunCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
),
"data_id": data_id,
}
except Exception as error:
# Temporarily swallow error and try to process rest of documents first, then re-raise error at end of data ingestion pipeline
logger.error(
f"Exception caught while processing data: {error}.\n Data processing failed for data item: {data_item}."
)
yield {
"run_info": PipelineRunErrored(
pipeline_run_id=pipeline_run_id,
payload=repr(error),
dataset_id=dataset.id,
dataset_name=dataset.name,
),
"data_id": data_id,
}
if os.getenv("RAISE_INCREMENTAL_LOADING_ERRORS", "true").lower() == "true":
raise error
async def _run_tasks_data_item_regular(
data_item,
dataset,
tasks,
pipeline_id,
pipeline_run_id,
context,
user,
):
# Process data based on data_item and list of tasks
async for result in run_tasks_with_telemetry(
tasks=tasks,
data=[data_item],
user=user,
pipeline_name=pipeline_id,
context=context,
):
yield PipelineRunYield(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
payload=result,
)
yield {
"run_info": PipelineRunCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
)
}
async def _run_tasks_data_item(
data_item,
dataset,
tasks,
pipeline_name,
pipeline_id,
pipeline_run_id,
context,
user,
incremental_loading,
):
# Go through async generator and return data item processing result. Result can be PipelineRunAlreadyCompleted when data item is skipped,
# PipelineRunCompleted when processing was successful and PipelineRunErrored if there were issues
result = None
if incremental_loading:
async for result in _run_tasks_data_item_incremental(
data_item=data_item,
dataset=dataset,
tasks=tasks,
pipeline_name=pipeline_name,
pipeline_id=pipeline_id,
pipeline_run_id=pipeline_run_id,
context=context,
user=user,
):
pass
else:
async for result in _run_tasks_data_item_regular(
data_item=data_item,
dataset=dataset,
tasks=tasks,
pipeline_id=pipeline_id,
pipeline_run_id=pipeline_run_id,
context=context,
user=user,
):
pass
return result
if not user:
user = await get_default_user()
@ -269,7 +92,7 @@ async def run_tasks(
# Create async tasks per data item that will run the pipeline for the data item
data_item_tasks = [
asyncio.create_task(
_run_tasks_data_item(
run_tasks_data_item(
data_item,
dataset,
tasks,

View file

@ -0,0 +1,261 @@
"""
Data item processing functions for pipeline operations.
This module contains reusable functions for processing individual data items
within pipeline operations, supporting both incremental and regular processing modes.
"""
import os
from typing import Any, Dict, AsyncGenerator, Optional
from sqlalchemy import select
import cognee.modules.ingestion as ingestion
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.modules.data.models import Data, Dataset
from cognee.tasks.ingestion import save_data_item_to_storage
from cognee.modules.pipelines.models.PipelineRunInfo import (
PipelineRunCompleted,
PipelineRunErrored,
PipelineRunYield,
PipelineRunAlreadyCompleted,
)
from cognee.modules.pipelines.models.DataItemStatus import DataItemStatus
from cognee.modules.pipelines.operations.run_tasks_with_telemetry import run_tasks_with_telemetry
from ..tasks.task import Task
logger = get_logger("run_tasks_data_item")
async def run_tasks_data_item_incremental(
data_item: Any,
dataset: Dataset,
tasks: list[Task],
pipeline_name: str,
pipeline_id: str,
pipeline_run_id: str,
context: Optional[Dict[str, Any]],
user: User,
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Process a single data item with incremental loading support.
This function handles incremental processing by checking if the data item
has already been processed for the given pipeline and dataset. If it has,
it skips processing and returns a completion status.
Args:
data_item: The data item to process
dataset: The dataset containing the data item
tasks: List of tasks to execute on the data item
pipeline_name: Name of the pipeline
pipeline_id: Unique identifier for the pipeline
pipeline_run_id: Unique identifier for this pipeline run
context: Optional context dictionary
user: User performing the operation
Yields:
Dict containing run_info and data_id for each processing step
"""
db_engine = get_relational_engine()
# If incremental_loading of data is set to True don't process documents already processed by pipeline
# If data is being added to Cognee for the first time calculate the id of the data
if not isinstance(data_item, Data):
file_path = await save_data_item_to_storage(data_item)
# Ingest data and add metadata
async with open_data_file(file_path) as file:
classified_data = ingestion.classify(file)
# data_id is the hash of file contents + owner id to avoid duplicate data
data_id = ingestion.identify(classified_data, user)
else:
# If data was already processed by Cognee get data id
data_id = data_item.id
# Check pipeline status, if Data already processed for pipeline before skip current processing
async with db_engine.get_async_session() as session:
data_point = (
await session.execute(select(Data).filter(Data.id == data_id))
).scalar_one_or_none()
if data_point:
if (
data_point.pipeline_status.get(pipeline_name, {}).get(str(dataset.id))
== DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
):
yield {
"run_info": PipelineRunAlreadyCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
),
"data_id": data_id,
}
return
try:
# Process data based on data_item and list of tasks
async for result in run_tasks_with_telemetry(
tasks=tasks,
data=[data_item],
user=user,
pipeline_name=pipeline_id,
context=context,
):
yield PipelineRunYield(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
payload=result,
)
# Update pipeline status for Data element
async with db_engine.get_async_session() as session:
data_point = (
await session.execute(select(Data).filter(Data.id == data_id))
).scalar_one_or_none()
data_point.pipeline_status[pipeline_name] = {
str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
}
await session.merge(data_point)
await session.commit()
yield {
"run_info": PipelineRunCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
),
"data_id": data_id,
}
except Exception as error:
# Temporarily swallow error and try to process rest of documents first, then re-raise error at end of data ingestion pipeline
logger.error(
f"Exception caught while processing data: {error}.\n Data processing failed for data item: {data_item}."
)
yield {
"run_info": PipelineRunErrored(
pipeline_run_id=pipeline_run_id,
payload=repr(error),
dataset_id=dataset.id,
dataset_name=dataset.name,
),
"data_id": data_id,
}
if os.getenv("RAISE_INCREMENTAL_LOADING_ERRORS", "true").lower() == "true":
raise error
async def run_tasks_data_item_regular(
data_item: Any,
dataset: Dataset,
tasks: list[Task],
pipeline_id: str,
pipeline_run_id: str,
context: Optional[Dict[str, Any]],
user: User,
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Process a single data item in regular (non-incremental) mode.
This function processes a data item without checking for previous processing
status, executing all tasks on the data item.
Args:
data_item: The data item to process
dataset: The dataset containing the data item
tasks: List of tasks to execute on the data item
pipeline_id: Unique identifier for the pipeline
pipeline_run_id: Unique identifier for this pipeline run
context: Optional context dictionary
user: User performing the operation
Yields:
Dict containing run_info for each processing step
"""
# Process data based on data_item and list of tasks
async for result in run_tasks_with_telemetry(
tasks=tasks,
data=[data_item],
user=user,
pipeline_name=pipeline_id,
context=context,
):
yield PipelineRunYield(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
payload=result,
)
yield {
"run_info": PipelineRunCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
)
}
async def run_tasks_data_item(
data_item: Any,
dataset: Dataset,
tasks: list[Task],
pipeline_name: str,
pipeline_id: str,
pipeline_run_id: str,
context: Optional[Dict[str, Any]],
user: User,
incremental_loading: bool,
) -> Optional[Dict[str, Any]]:
"""
Process a single data item, choosing between incremental and regular processing.
This is the main entry point for data item processing that delegates to either
incremental or regular processing based on the incremental_loading flag.
Args:
data_item: The data item to process
dataset: The dataset containing the data item
tasks: List of tasks to execute on the data item
pipeline_name: Name of the pipeline
pipeline_id: Unique identifier for the pipeline
pipeline_run_id: Unique identifier for this pipeline run
context: Optional context dictionary
user: User performing the operation
incremental_loading: Whether to use incremental processing
Returns:
Dict containing the final processing result, or None if processing was skipped
"""
# Go through async generator and return data item processing result. Result can be PipelineRunAlreadyCompleted when data item is skipped,
# PipelineRunCompleted when processing was successful and PipelineRunErrored if there were issues
result = None
if incremental_loading:
async for result in run_tasks_data_item_incremental(
data_item=data_item,
dataset=dataset,
tasks=tasks,
pipeline_name=pipeline_name,
pipeline_id=pipeline_id,
pipeline_run_id=pipeline_run_id,
context=context,
user=user,
):
pass
else:
async for result in run_tasks_data_item_regular(
data_item=data_item,
dataset=dataset,
tasks=tasks,
pipeline_id=pipeline_id,
pipeline_run_id=pipeline_run_id,
context=context,
user=user,
):
pass
return result
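A self-contained sketch of the dispatch pattern used by run_tasks_data_item (not part of the diff): drain the async generator and keep only its final yielded value as the result.

import asyncio
from typing import AsyncGenerator, Optional

async def _steps() -> AsyncGenerator[dict, None]:
    yield {"run_info": "started"}
    yield {"run_info": "completed"}

async def last_result() -> Optional[dict]:
    result = None
    async for result in _steps():  # each iteration overwrites `result`
        pass
    return result  # None if the generator yielded nothing

print(asyncio.run(last_result()))  # {'run_info': 'completed'}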

View file

@ -3,49 +3,96 @@ try:
except ModuleNotFoundError:
modal = None
from typing import Any, List, Optional
from uuid import UUID
from cognee.modules.pipelines.tasks.task import Task
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.pipelines.models import (
PipelineRunStarted,
PipelineRunYield,
PipelineRunCompleted,
PipelineRunErrored,
)
from cognee.modules.pipelines.operations import log_pipeline_run_start, log_pipeline_run_complete
from cognee.modules.pipelines.utils.generate_pipeline_id import generate_pipeline_id
from cognee.modules.pipelines.operations import (
log_pipeline_run_start,
log_pipeline_run_complete,
log_pipeline_run_error,
)
from cognee.modules.pipelines.utils import generate_pipeline_id
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from .run_tasks_with_telemetry import run_tasks_with_telemetry
from cognee.modules.users.models import User
from cognee.modules.pipelines.exceptions import PipelineRunFailedError
from cognee.tasks.ingestion import resolve_data_directories
from .run_tasks_data_item import run_tasks_data_item
logger = get_logger("run_tasks_distributed()")
if modal:
import os
from distributed.app import app
from distributed.modal_image import image
secret_name = os.environ.get("MODAL_SECRET_NAME", "distributed_cognee")
@app.function(
retries=3,
image=image,
timeout=86400,
max_containers=50,
secrets=[modal.Secret.from_name("distributed_cognee")],
secrets=[modal.Secret.from_name(secret_name)],
)
async def run_tasks_on_modal(tasks, data_item, user, pipeline_name, context):
pipeline_run = run_tasks_with_telemetry(tasks, data_item, user, pipeline_name, context)
async def run_tasks_on_modal(
data_item,
dataset_id: UUID,
tasks: List[Task],
pipeline_name: str,
pipeline_id: str,
pipeline_run_id: str,
context: Optional[dict],
user: User,
incremental_loading: bool,
):
"""
Wrapper that runs the run_tasks_data_item function.
This is the function/code that runs on modal executor and produces the graph/vector db objects
"""
from cognee.infrastructure.databases.relational import get_relational_engine
run_info = None
async with get_relational_engine().get_async_session() as session:
from cognee.modules.data.models import Dataset
async for pipeline_run_info in pipeline_run:
run_info = pipeline_run_info
dataset = await session.get(Dataset, dataset_id)
return run_info
result = await run_tasks_data_item(
data_item=data_item,
dataset=dataset,
tasks=tasks,
pipeline_name=pipeline_name,
pipeline_id=pipeline_id,
pipeline_run_id=pipeline_run_id,
context=context,
user=user,
incremental_loading=incremental_loading,
)
return result
async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, context):
async def run_tasks_distributed(
tasks: List[Task],
dataset_id: UUID,
data: List[Any] = None,
user: User = None,
pipeline_name: str = "unknown_pipeline",
context: dict = None,
incremental_loading: bool = False,
):
if not user:
user = await get_default_user()
# Get dataset object
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
from cognee.modules.data.models import Dataset
@ -53,9 +100,7 @@ async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, co
dataset = await session.get(Dataset, dataset_id)
pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
pipeline_run_id = pipeline_run.pipeline_run_id
yield PipelineRunStarted(
@ -65,30 +110,67 @@ async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, co
payload=data,
)
data_count = len(data) if isinstance(data, list) else 1
try:
if not isinstance(data, list):
data = [data]
arguments = [
[tasks] * data_count,
[[data_item] for data_item in data[:data_count]] if data_count > 1 else [data],
[user] * data_count,
[pipeline_name] * data_count,
[context] * data_count,
]
data = await resolve_data_directories(data)
async for result in run_tasks_on_modal.map.aio(*arguments):
logger.info(f"Received result: {result}")
number_of_data_items = len(data) if isinstance(data, list) else 1
yield PipelineRunYield(
data_item_tasks = [
data,
[dataset.id] * number_of_data_items,
[tasks] * number_of_data_items,
[pipeline_name] * number_of_data_items,
[pipeline_id] * number_of_data_items,
[pipeline_run_id] * number_of_data_items,
[context] * number_of_data_items,
[user] * number_of_data_items,
[incremental_loading] * number_of_data_items,
]
results = []
async for result in run_tasks_on_modal.map.aio(*data_item_tasks):
if not result:
continue
results.append(result)
# Remove skipped results
results = [r for r in results if r]
# If any data item failed, raise PipelineRunFailedError
errored = [
r
for r in results
if r and r.get("run_info") and isinstance(r["run_info"], PipelineRunErrored)
]
if errored:
raise PipelineRunFailedError("Pipeline run failed. Data item could not be processed.")
await log_pipeline_run_complete(
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data
)
yield PipelineRunCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
payload=result,
data_ingestion_info=results,
)
await log_pipeline_run_complete(pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data)
except Exception as error:
await log_pipeline_run_error(
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, error
)
yield PipelineRunCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
)
yield PipelineRunErrored(
pipeline_run_id=pipeline_run_id,
payload=repr(error),
dataset_id=dataset.id,
dataset_name=dataset.name,
data_ingestion_info=locals().get("results"),
)
if not isinstance(error, PipelineRunFailedError):
raise
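A hedged, plain-asyncio sketch of the fan-out shape used with run_tasks_on_modal.map.aio (not part of the diff): one entry per data item in each parallel argument list, with empty results filtered out. Modal is replaced here by asyncio.gather and a placeholder worker.

import asyncio

async def process(data_item, dataset_id, incremental):
    await asyncio.sleep(0)  # placeholder for the real per-item pipeline work
    return {"data_item": data_item, "dataset_id": dataset_id, "incremental": incremental}

async def main():
    data = ["a.txt", "b.txt", "c.txt"]
    count = len(data)
    # Same shape as data_item_tasks above: parallel lists, one position per data item.
    argument_lists = [data, ["dataset-1"] * count, [True] * count]
    results = await asyncio.gather(*(process(*args) for args in zip(*argument_lists)))
    results = [result for result in results if result]  # drop skipped items
    print(results)

asyncio.run(main())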

View file

@ -194,7 +194,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
belongs_to_set=interactions_node_set,
)
await add_data_points(data_points=[cognee_user_interaction], update_edge_collection=False)
await add_data_points(data_points=[cognee_user_interaction])
relationships = []
relationship_name = "used_graph_element_to_answer"

View file

@ -1,133 +0,0 @@
import asyncio
from typing import Any, Optional
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
logger = get_logger("InsightsRetriever")
class InsightsRetriever(BaseGraphRetriever):
"""
Retriever for handling graph connection-based insights.
Public methods include:
- get_context
- get_completion
Instance variables include:
- exploration_levels
- top_k
"""
def __init__(self, exploration_levels: int = 1, top_k: Optional[int] = 5):
"""Initialize retriever with exploration levels and search parameters."""
self.exploration_levels = exploration_levels
self.top_k = top_k
async def get_context(self, query: str) -> list:
"""
Find neighbours of a given node in the graph.
If the provided query does not correspond to an existing node,
search for similar entities and retrieve their connections.
Reraises NoDataError if there is no data found in the system.
Parameters:
-----------
- query (str): A string identifier for the node whose neighbours are to be
retrieved.
Returns:
--------
- list: A list of unique connections found for the queried node.
"""
if query is None:
return []
node_id = query
graph_engine = await get_graph_engine()
exact_node = await graph_engine.extract_node(node_id)
if exact_node is not None and "id" in exact_node:
node_connections = await graph_engine.get_connections(str(exact_node["id"]))
else:
vector_engine = get_vector_engine()
try:
results = await asyncio.gather(
vector_engine.search("Entity_name", query_text=query, limit=self.top_k),
vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
)
except CollectionNotFoundError as error:
logger.error("Entity collections not found")
raise NoDataError("No data found in the system, please add data first.") from error
results = [*results[0], *results[1]]
relevant_results = [result for result in results if result.score < 0.5][: self.top_k]
if len(relevant_results) == 0:
return []
node_connections_results = await asyncio.gather(
*[graph_engine.get_connections(result.id) for result in relevant_results]
)
node_connections = []
for neighbours in node_connections_results:
node_connections.extend(neighbours)
unique_node_connections_map = {}
unique_node_connections = []
for node_connection in node_connections:
if "id" not in node_connection[0] or "id" not in node_connection[2]:
continue
unique_id = f"{node_connection[0]['id']} {node_connection[1]['relationship_name']} {node_connection[2]['id']}"
if unique_id not in unique_node_connections_map:
unique_node_connections_map[unique_id] = True
unique_node_connections.append(node_connection)
return unique_node_connections
# return [
# Edge(
# node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
# node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
# attributes={
# **connection[1],
# "relationship_type": connection[1]["relationship_name"],
# },
# )
# for connection in unique_node_connections
# ]
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""
Returns the graph connections context.
If a context is not provided, it fetches the context using the query provided.
Parameters:
-----------
- query (str): A string identifier used to fetch the context.
- context (Optional[Any]): An optional context to use for the completion; if None,
it fetches the context based on the query. (default None)
Returns:
--------
- Any: The context used for the completion, which is either provided or fetched
based on the query.
"""
if context is None:
context = await self.get_context(query)
return context

View file

@ -8,7 +8,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.base_feedback import BaseFeedback
from cognee.modules.retrieval.utils.models import CogneeUserFeedback
from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation
from cognee.tasks.storage import add_data_points
from cognee.tasks.storage import add_data_points, index_graph_edges
logger = get_logger("CompletionRetriever")
@ -47,7 +47,7 @@ class UserQAFeedback(BaseFeedback):
belongs_to_set=feedbacks_node_set,
)
await add_data_points(data_points=[cognee_user_feedback], update_edge_collection=False)
await add_data_points(data_points=[cognee_user_feedback])
relationships = []
relationship_name = "gives_feedback_to"
@ -76,6 +76,7 @@ class UserQAFeedback(BaseFeedback):
if len(relationships) > 0:
graph_engine = await get_graph_engine()
await graph_engine.add_edges(relationships)
await index_graph_edges(relationships)
await graph_engine.apply_feedback_weight(
node_ids=to_node_ids, weight=feedback_sentiment.score
)

View file

@ -62,7 +62,7 @@ async def code_description_to_code_part(
try:
if include_docs:
search_results = await search(query_text=query, query_type="INSIGHTS")
search_results = await search(query_text=query, query_type="GRAPH_COMPLETION")
concatenated_descriptions = " ".join(
obj["description"]

View file

@ -9,7 +9,6 @@ from cognee.modules.search.exceptions import UnsupportedSearchTypeError
# Retrievers
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@ -44,10 +43,6 @@ async def get_search_type_tools(
SummariesRetriever(top_k=top_k).get_completion,
SummariesRetriever(top_k=top_k).get_context,
],
SearchType.INSIGHTS: [
InsightsRetriever(top_k=top_k).get_completion,
InsightsRetriever(top_k=top_k).get_context,
],
SearchType.CHUNKS: [
ChunksRetriever(top_k=top_k).get_completion,
ChunksRetriever(top_k=top_k).get_context,

View file

@ -19,7 +19,9 @@ from cognee.modules.search.types import (
from cognee.modules.search.operations import log_query, log_result
from cognee.modules.users.models import User
from cognee.modules.data.models import Dataset
from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
from cognee.modules.data.methods.get_authorized_existing_datasets import (
get_authorized_existing_datasets,
)
from .get_search_type_tools import get_search_type_tools
from .no_access_control_search import no_access_control_search
@ -202,7 +204,9 @@ async def authorized_search(
Not to be used outside of active access control mode.
"""
# Find datasets user has read access for (if datasets are provided only return them. Provided user has read access)
search_datasets = await get_specific_user_permission_datasets(user.id, "read", dataset_ids)
search_datasets = await get_authorized_existing_datasets(
datasets=dataset_ids, permission_type="read", user=user
)
if use_combined_context:
search_responses = await search_in_datasets_context(

View file

@ -3,7 +3,6 @@ from enum import Enum
class SearchType(Enum):
SUMMARIES = "SUMMARIES"
INSIGHTS = "INSIGHTS"
CHUNKS = "CHUNKS"
RAG_COMPLETION = "RAG_COMPLETION"
GRAPH_COMPLETION = "GRAPH_COMPLETION"

View file

@ -15,6 +15,7 @@ class ModelName(Enum):
ollama = "ollama"
anthropic = "anthropic"
gemini = "gemini"
mistral = "mistral"
class LLMConfig(BaseModel):
@ -72,6 +73,10 @@ def get_settings() -> SettingsDict:
"value": "gemini",
"label": "Gemini",
},
{
"value": "mistral",
"label": "Mistral",
},
]
return SettingsDict.model_validate(
@ -134,6 +139,24 @@ def get_settings() -> SettingsDict:
"label": "Gemini 2.0 Flash",
},
],
"mistral": [
{
"value": "mistral-medium-2508",
"label": "Mistral Medium 3.1",
},
{
"value": "magistral-medium-2509",
"label": "Magistral Medium 1.2",
},
{
"value": "magistral-medium-2507",
"label": "Magistral Medium 1.1",
},
{
"value": "mistral-large-2411",
"label": "Mistral Large 2.1",
},
],
},
},
vector_db={

View file

@ -37,6 +37,8 @@ async def get_authenticated_user(
except Exception as e:
# Convert any get_default_user failure into a proper HTTP 500 error
logger.error(f"Failed to create default user: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to create default user: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Failed to create default user: {str(e)}"
) from e
return user

View file

@ -40,8 +40,8 @@ async def create_role(
# Add association directly to the association table
role = Role(name=role_name, tenant_id=tenant.id)
session.add(role)
except IntegrityError:
raise EntityAlreadyExistsError(message="Role already exists for tenant.")
except IntegrityError as e:
raise EntityAlreadyExistsError(message="Role already exists for tenant.") from e
await session.commit()
await session.refresh(role)

View file

@ -35,5 +35,5 @@ async def create_tenant(tenant_name: str, user_id: UUID) -> UUID:
await session.merge(user)
await session.commit()
return tenant.id
except IntegrityError:
raise EntityAlreadyExistsError(message="Tenant already exists.")
except IntegrityError as e:
raise EntityAlreadyExistsError(message="Tenant already exists.") from e

View file

@ -288,7 +288,6 @@ class SummarizedCode(BaseModel):
class GraphDBType(Enum):
NETWORKX = auto()
NEO4J = auto()
FALKORDB = auto()
KUZU = auto()

View file

@ -124,5 +124,4 @@ async def add_rule_associations(
if len(edges_to_save) > 0:
await graph_engine.add_edges(edges_to_save)
await index_graph_edges()
await index_graph_edges(edges_to_save)

View file

@ -4,6 +4,7 @@ from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
from cognee.tasks.storage import index_graph_edges
from cognee.tasks.storage.add_data_points import add_data_points
from cognee.modules.ontology.ontology_config import Config
from cognee.modules.ontology.get_default_ontology_resolver import (
@ -88,6 +89,7 @@ async def integrate_chunk_graphs(
if len(graph_edges) > 0:
await graph_engine.add_edges(graph_edges)
await index_graph_edges(graph_edges)
return data_chunks

View file

@ -8,6 +8,7 @@ from cognee.modules.ingestion import save_data_to_file
from cognee.shared.logging_utils import get_logger
from pydantic_settings import BaseSettings, SettingsConfigDict
logger = get_logger()
@ -17,6 +18,13 @@ class SaveDataSettings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", extra="allow")
class HTMLContent(str):
def __new__(cls, value: str):
if not ("<" in value and ">" in value):
raise ValueError("Not valid HTML-like content")
return super().__new__(cls, value)
settings = SaveDataSettings()
@ -27,6 +35,12 @@ async def save_data_item_to_storage(data_item: Union[BinaryIO, str, Any]) -> str
return await get_data_from_llama_index(data_item)
if "docling" in str(type(data_item)):
from docling_core.types import DoclingDocument
if isinstance(data_item, DoclingDocument):
data_item = data_item.export_to_text()
# data is a file object coming from upload.
if hasattr(data_item, "file"):
return await save_data_to_file(data_item.file, filename=data_item.filename)
@ -48,6 +62,40 @@ async def save_data_item_to_storage(data_item: Union[BinaryIO, str, Any]) -> str
# data is s3 file path
if parsed_url.scheme == "s3":
return data_item
elif parsed_url.scheme == "http" or parsed_url.scheme == "https":
# Validate URL by sending a HEAD request
try:
from cognee.context_global_variables import tavily_config, soup_crawler_config
from cognee.tasks.web_scraper import fetch_page_content
tavily = tavily_config.get()
soup_crawler = soup_crawler_config.get()
preferred_tool = "beautifulsoup" if soup_crawler else "tavily"
if preferred_tool == "tavily" and tavily is None:
raise IngestionError(
message="TavilyConfig must be set on the ingestion context when fetching HTTP URLs without a SoupCrawlerConfig."
)
if preferred_tool == "beautifulsoup" and soup_crawler is None:
raise IngestionError(
message="SoupCrawlerConfig must be set on the ingestion context when using the BeautifulSoup scraper."
)
data = await fetch_page_content(
data_item,
preferred_tool=preferred_tool,
tavily_config=tavily,
soup_crawler_config=soup_crawler,
)
content = ""
for key, value in data.items():
content += f"{key}:\n{value}\n\n"
return await save_data_to_file(content)
except IngestionError:
raise
        except Exception as e:
            raise IngestionError(
                message=f"Error ingesting webpage content from url {data_item}: {str(e)}"
            ) from e
# data is local file path
elif parsed_url.scheme == "file":
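A minimal sketch exercising the HTMLContent guard introduced in this file (not part of the diff); the class is redefined locally so the snippet runs on its own.

class HTMLContent(str):
    def __new__(cls, value: str):
        if not ("<" in value and ">" in value):
            raise ValueError("Not valid HTML-like content")
        return super().__new__(cls, value)

print(HTMLContent("<p>hello</p>"))  # ok: looks HTML-like

try:
    HTMLContent("plain text")
except ValueError as error:
    print(error)  # Not valid HTML-like content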

View file

@ -10,9 +10,7 @@ from cognee.tasks.storage.exceptions import (
)
async def add_data_points(
data_points: List[DataPoint], update_edge_collection: bool = True
) -> List[DataPoint]:
async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
"""
Add a batch of data points to the graph database by extracting nodes and edges,
deduplicating them, and indexing them for retrieval.
@ -25,9 +23,6 @@ async def add_data_points(
Args:
data_points (List[DataPoint]):
A list of data points to process and insert into the graph.
update_edge_collection (bool, optional):
Whether to update the edge index after adding edges.
Defaults to True.
Returns:
List[DataPoint]:
@ -73,12 +68,10 @@ async def add_data_points(
graph_engine = await get_graph_engine()
await graph_engine.add_nodes(nodes)
await index_data_points(nodes)
await graph_engine.add_nodes(nodes)
await graph_engine.add_edges(edges)
if update_edge_collection:
await index_graph_edges()
await index_graph_edges(edges)
return data_points

Some files were not shown because too many files have changed in this diff.