merged dev into this branch

This commit is contained in:
Andrej Milicevic 2025-11-03 17:08:08 +01:00
commit d0a3bfd39f
147 changed files with 10201 additions and 6233 deletions

View file

@ -242,13 +242,14 @@ LITELLM_LOG="ERROR"
########## Local LLM via Ollama ###############################################
#LLM_API_KEY ="ollama"
#LLM_MODEL="llama3.1:8b"
#LLM_PROVIDER="ollama"
#LLM_ENDPOINT="http://localhost:11434/v1"
#EMBEDDING_PROVIDER="ollama"
#EMBEDDING_MODEL="nomic-embed-text:latest"
#EMBEDDING_ENDPOINT="http://localhost:11434/api/embeddings"
#EMBEDDING_ENDPOINT="http://localhost:11434/api/embed"
#EMBEDDING_DIMENSIONS=768
#HUGGINGFACE_TOKENIZER="nomic-ai/nomic-embed-text-v1.5"

View file

@ -123,6 +123,7 @@ jobs:
uses: ./.github/actions/cognee_setup
with:
python-version: ${{ inputs.python-version }}
extra-dependencies: "scraping"
- name: Run Integration Tests
run: uv run pytest cognee/tests/integration/
@ -161,11 +162,11 @@ jobs:
runs-on: ubuntu-22.04
env:
STRUCTURED_OUTPUT_FRAMEWORK: "BAML"
BAML_LLM_PROVIDER: azure-openai
BAML_LLM_MODEL: ${{ secrets.LLM_MODEL }}
BAML_LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
BAML_LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
BAML_LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
BAML_LLM_PROVIDER: openai
BAML_LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
BAML_LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
BAML_LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
# BAML_LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}

View file

@ -7,14 +7,29 @@ on:
jobs:
docker-build-and-push:
runs-on: ubuntu-latest
runs-on:
group: Default
labels:
- docker_build_runner
steps:
- name: Check and free disk space before build
run: |
echo "=== Before cleanup ==="
df -h
echo "Removing unused preinstalled SDKs to free space..."
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc || true
docker system prune -af || true
echo "=== After cleanup ==="
df -h
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
buildkitd-flags: --root /tmp/buildkit
- name: Log in to Docker Hub
uses: docker/login-action@v3
@ -34,7 +49,7 @@ jobs:
- name: Build and push
id: build
uses: docker/build-push-action@v5
uses: docker/build-push-action@v6
with:
context: .
platforms: linux/amd64,linux/arm64
@ -45,5 +60,6 @@ jobs:
cache-from: type=registry,ref=cognee/cognee-mcp:buildcache
cache-to: type=registry,ref=cognee/cognee-mcp:buildcache,mode=max
- name: Image digest
run: echo ${{ steps.build.outputs.digest }}

View file

@ -332,6 +332,125 @@ jobs:
DB_PASSWORD: cognee
run: uv run python ./cognee/tests/test_concurrent_subprocess_access.py
test-entity-extraction:
name: Test Entity Extraction
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
- name: Run Entity Extraction Test
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/tasks/entity_extraction/entity_extraction_test.py
test-feedback-enrichment:
name: Test Feedback Enrichment
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
- name: Run Feedback Enrichment Test
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_feedback_enrichment.py
run_conversation_sessions_test:
name: Conversation sessions test
runs-on: ubuntu-latest
defaults:
run:
shell: bash
services:
postgres:
image: pgvector/pgvector:pg17
env:
POSTGRES_USER: cognee
POSTGRES_PASSWORD: cognee
POSTGRES_DB: cognee_db
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 5432:5432
redis:
image: redis:7
ports:
- 6379:6379
options: >-
--health-cmd "redis-cli ping"
--health-interval 5s
--health-timeout 3s
--health-retries 5
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: "postgres redis"
- name: Run Conversation session tests
env:
ENV: dev
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: 'kuzu'
CACHING: true
DB_PROVIDER: 'postgres'
DB_NAME: 'cognee_db'
DB_HOST: '127.0.0.1'
DB_PORT: 5432
DB_USERNAME: cognee
DB_PASSWORD: cognee
run: uv run python ./cognee/tests/test_conversation_history.py
test-load:
name: Test Load
runs-on: ubuntu-22.04

78
.github/workflows/scorecard.yml vendored Normal file
View file

@ -0,0 +1,78 @@
# This workflow uses actions that are not certified by GitHub. They are provided
# by a third-party and are governed by separate terms of service, privacy
# policy, and support documentation.
name: Scorecard supply-chain security
on:
# For Branch-Protection check. Only the default branch is supported. See
# https://github.com/ossf/scorecard/blob/main/docs/checks.md#branch-protection
branch_protection_rule:
# To guarantee Maintained check is occasionally updated. See
# https://github.com/ossf/scorecard/blob/main/docs/checks.md#maintained
schedule:
- cron: '35 8 * * 2'
push:
branches: [ "main" ]
# Declare default permissions as read only.
permissions: read-all
jobs:
analysis:
name: Scorecard analysis
runs-on: ubuntu-latest
# `publish_results: true` only works when run from the default branch. conditional can be removed if disabled.
if: github.event.repository.default_branch == github.ref_name || github.event_name == 'pull_request'
permissions:
# Needed to upload the results to code-scanning dashboard.
security-events: write
# Needed to publish results and get a badge (see publish_results below).
id-token: write
# Uncomment the permissions below if installing in a private repository.
# contents: read
# actions: read
steps:
- name: "Checkout code"
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
- name: "Run analysis"
uses: ossf/scorecard-action@f49aabe0b5af0936a0987cfb85d86b75731b0186 # v2.4.1
with:
results_file: results.sarif
results_format: sarif
# (Optional) "write" PAT token. Uncomment the `repo_token` line below if:
# - you want to enable the Branch-Protection check on a *public* repository, or
# - you are installing Scorecard on a *private* repository
# To create the PAT, follow the steps in https://github.com/ossf/scorecard-action?tab=readme-ov-file#authentication-with-fine-grained-pat-optional.
# repo_token: ${{ secrets.SCORECARD_TOKEN }}
# Public repositories:
# - Publish results to OpenSSF REST API for easy access by consumers
# - Allows the repository to include the Scorecard badge.
# - See https://github.com/ossf/scorecard-action#publishing-results.
# For private repositories:
# - `publish_results` will always be set to `false`, regardless
# of the value entered here.
publish_results: true
# (Optional) Uncomment file_mode if you have a .gitattributes with files marked export-ignore
# file_mode: git
# Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
# format to the repository Actions tab.
- name: "Upload artifact"
uses: actions/upload-artifact@4cec3d8aa04e39d1a68397de0c4cd6fb9dce8ec1 # v4.6.1
with:
name: SARIF file
path: results.sarif
retention-days: 5
# Upload the results to GitHub's code scanning dashboard (optional).
# Commenting out will disable upload of results to your repo's Code Scanning dashboard
- name: "Upload to code-scanning"
uses: github/codeql-action/upload-sarif@v3
with:
sarif_file: results.sarif

View file

@ -34,10 +34,9 @@ jobs:
- name: Run Temporal Graph with Kuzu (lancedb + sqlite)
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
@ -73,10 +72,9 @@ jobs:
- name: Run Temporal Graph with Neo4j (lancedb + sqlite)
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
@ -125,10 +123,9 @@ jobs:
- name: Run Temporal Graph with Kuzu (postgres + pgvector)
env:
ENV: dev
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
@ -192,10 +189,9 @@ jobs:
- name: Run Temporal Graph with Neo4j (postgres + pgvector)
env:
ENV: dev
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}

View file

@ -10,6 +10,10 @@ on:
required: false
type: string
default: '["3.10.x", "3.12.x", "3.13.x"]'
os:
required: false
type: string
default: '["ubuntu-22.04", "macos-15", "windows-latest"]'
secrets:
LLM_PROVIDER:
required: true
@ -43,7 +47,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ubuntu-22.04, macos-15, windows-latest]
os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out
@ -79,7 +83,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-15, windows-latest ]
os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out
@ -115,7 +119,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-15, windows-latest ]
os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out
@ -151,7 +155,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-15, windows-latest ]
os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out
@ -180,7 +184,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-15, windows-latest ]
os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out
@ -217,7 +221,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-15, windows-latest ]
os: ${{ fromJSON(inputs.os) }}
fail-fast: false
steps:
- name: Check out

View file

@ -75,7 +75,7 @@ jobs:
{ "role": "user", "content": "Whatever I say, answer with Yes." }
]
}'
curl -X POST http://127.0.0.1:11434/v1/embeddings \
curl -X POST http://127.0.0.1:11434/api/embed \
-H "Content-Type: application/json" \
-d '{
"model": "avr/sfr-embedding-mistral:latest",
@ -98,7 +98,7 @@ jobs:
LLM_MODEL: "phi4"
EMBEDDING_PROVIDER: "ollama"
EMBEDDING_MODEL: "avr/sfr-embedding-mistral:latest"
EMBEDDING_ENDPOINT: "http://localhost:11434/api/embeddings"
EMBEDDING_ENDPOINT: "http://localhost:11434/api/embed"
EMBEDDING_DIMENSIONS: "4096"
HUGGINGFACE_TOKENIZER: "Salesforce/SFR-Embedding-Mistral"
run: uv run python ./examples/python/simple_example.py

View file

@ -1,4 +1,6 @@
name: Test Suites
permissions:
contents: read
on:
push:
@ -80,12 +82,22 @@ jobs:
uses: ./.github/workflows/notebooks_tests.yml
secrets: inherit
different-operating-systems-tests:
name: Operating System and Python Tests
different-os-tests-basic:
name: OS and Python Tests Ubuntu
needs: [basic-tests, e2e-tests]
uses: ./.github/workflows/test_different_operating_systems.yml
with:
python-versions: '["3.10.x", "3.11.x", "3.12.x", "3.13.x"]'
os: '["ubuntu-22.04"]'
secrets: inherit
different-os-tests-extended:
name: OS and Python Tests Extended
needs: [basic-tests, e2e-tests]
uses: ./.github/workflows/test_different_operating_systems.yml
with:
python-versions: '["3.13.x"]'
os: '["macos-15", "windows-latest"]'
secrets: inherit
# Matrix-based vector database tests
@ -135,7 +147,8 @@ jobs:
e2e-tests,
graph-db-tests,
notebook-tests,
different-operating-systems-tests,
different-os-tests-basic,
different-os-tests-extended,
vector-db-tests,
example-tests,
llm-tests,
@ -155,7 +168,8 @@ jobs:
cli-tests,
graph-db-tests,
notebook-tests,
different-operating-systems-tests,
different-os-tests-basic,
different-os-tests-extended,
vector-db-tests,
example-tests,
db-examples-tests,
@ -176,7 +190,8 @@ jobs:
"${{ needs.cli-tests.result }}" == "success" &&
"${{ needs.graph-db-tests.result }}" == "success" &&
"${{ needs.notebook-tests.result }}" == "success" &&
"${{ needs.different-operating-systems-tests.result }}" == "success" &&
"${{ needs.different-os-tests-basic.result }}" == "success" &&
"${{ needs.different-os-tests-extended.result }}" == "success" &&
"${{ needs.vector-db-tests.result }}" == "success" &&
"${{ needs.example-tests.result }}" == "success" &&
"${{ needs.db-examples-tests.result }}" == "success" &&

132
AGENTS.md Normal file
View file

@ -0,0 +1,132 @@
## Repository Guidelines
This document summarizes how to work with the cognee repository: how its organized, how to build, test, lint, and contribute. It mirrors our actual tooling and CI while providing quick commands for local development.
## Project Structure & Module Organization
- `cognee/`: Core Python library and API.
- `api/`: FastAPI application and versioned routers (add, cognify, memify, search, delete, users, datasets, responses, visualize, settings, sync, update, checks).
- `cli/`: CLI entry points and subcommands invoked via `cognee` / `cognee-cli`.
- `infrastructure/`: Databases, LLM providers, embeddings, loaders, and storage adapters.
- `modules/`: Domain logic (graph, retrieval, ontology, users, processing, observability, etc.).
- `tasks/`: Reusable tasks (e.g., code graph, web scraping, storage). Extend with new tasks here.
- `eval_framework/`: Evaluation utilities and adapters.
- `shared/`: Cross-cutting helpers (logging, settings, utils).
- `tests/`: Unit, integration, CLI, and end-to-end tests organized by feature.
- `__main__.py`: Entrypoint to route to CLI.
- `cognee-mcp/`: Model Context Protocol server exposing cognee as MCP tools (SSE/HTTP/stdio). Contains its own README and Dockerfile.
- `cognee-frontend/`: Next.js UI for local development and demos.
- `distributed/`: Utilities for distributed execution (Modal, workers, queues).
- `examples/`: Example scripts demonstrating the public APIs and features (graph, code graph, multimodal, permissions, etc.).
- `notebooks/`: Jupyter notebooks for demos and tutorials.
- `alembic/`: Database migrations for relational backends.
Notes:
- Co-locate feature-specific helpers under their respective package (`modules/`, `infrastructure/`, or `tasks/`).
- Extend the system by adding new tasks, loaders, or retrievers rather than modifying core pipeline mechanisms.
## Build, Test, and Development Commands
Python (root) requires Python >= 3.10 and < 3.14. We recommend `uv` for speed and reproducibility.
- Create/refresh env and install dev deps:
```bash
uv sync --dev --all-extras --reinstall
```
- Run the CLI (examples):
```bash
uv run cognee-cli add "Cognee turns documents into AI memory."
uv run cognee-cli cognify
uv run cognee-cli search "What does cognee do?"
uv run cognee-cli -ui # Launches UI, backend API, and MCP server together
```
- Start the FastAPI server directly:
```bash
uv run python -m cognee.api.client
```
- Run tests (CI mirrors these commands):
```bash
uv run pytest cognee/tests/unit/ -v
uv run pytest cognee/tests/integration/ -v
```
- Lint and format (ruff):
```bash
uv run ruff check .
uv run ruff format .
```
- Optional static type checks (mypy):
```bash
uv run mypy cognee/
```
MCP Server (`cognee-mcp/`):
- Install and run locally:
```bash
cd cognee-mcp
uv sync --dev --all-extras --reinstall
uv run python src/server.py # stdio (default)
uv run python src/server.py --transport sse
uv run python src/server.py --transport http --host 127.0.0.1 --port 8000 --path /mcp
```
- API Mode (connect to a running Cognee API):
```bash
uv run python src/server.py --transport sse --api-url http://localhost:8000 --api-token YOUR_TOKEN
```
- Docker quickstart (examples): see `cognee-mcp/README.md` for full details
```bash
docker run -e TRANSPORT_MODE=http --env-file ./.env -p 8000:8000 --rm -it cognee/cognee-mcp:main
```
Frontend (`cognee-frontend/`):
```bash
cd cognee-frontend
npm install
npm run dev # Next.js dev server
npm run lint # ESLint
npm run build && npm start
```
## Coding Style & Naming Conventions
Python:
- 4-space indentation, modules and functions in `snake_case`, classes in `PascalCase`.
- Public APIs should be type-annotated where practical.
- Use `ruff format` before committing; `ruff check` enforces import hygiene and style (line-length 100 configured in `pyproject.toml`).
- Prefer explicit, structured error handling. Use shared logging utilities in `cognee.shared.logging_utils`.
MCP server and Frontend:
- Follow the local `README.md` and ESLint/TypeScript configuration in `cognee-frontend/`.
## Testing Guidelines
- Place Python tests under `cognee/tests/`.
- Unit tests: `cognee/tests/unit/`
- Integration tests: `cognee/tests/integration/`
- CLI tests: `cognee/tests/cli_tests/`
- Name test files `test_*.py`. Use `pytest.mark.asyncio` for async tests.
- Avoid external state; rely on test fixtures and the CI-provided env vars when LLM/embedding providers are required. See CI workflows under `.github/workflows/` for expected environment variables.
- When adding public APIs, provide/update targeted examples under `examples/python/`.
## Commit & Pull Request Guidelines
- Use clear, imperative subjects (≤ 72 chars) and conventional commit styling in PR titles. Our CI validates semantic PR titles (see `.github/workflows/pr_lint`). Examples:
- `feat(graph): add temporal edge weighting`
- `fix(api): handle missing auth cookie`
- `docs: update installation instructions`
- Reference related issues/discussions in the PR body and provide brief context.
- PRs should describe scope, list local test commands run, and mention any impacts on MCP server or UI if applicable.
- Sign commits and affirm the DCO (see `CONTRIBUTING.md`).
## CI Mirrors Local Commands
Our GitHub Actions run the same ruff checks and pytest suites shown above (`.github/workflows/basic_tests.yml` and related workflows). Use the commands in this document locally to minimize CI surprises.

View file

@ -97,7 +97,7 @@ Hosted platform:
### 📦 Installation
You can install Cognee using either **pip**, **poetry**, **uv** or any other python package manager.
You can install Cognee using either **pip**, **poetry**, **uv** or any other python package manager..
Cognee supports Python 3.10 to 3.12

View file

@ -110,6 +110,47 @@ If you'd rather run cognee-mcp in a container, you have two options:
# For stdio transport (default)
docker run -e TRANSPORT_MODE=stdio --env-file ./.env --rm -it cognee/cognee-mcp:main
```
**Installing optional dependencies at runtime:**
You can install optional dependencies when running the container by setting the `EXTRAS` environment variable:
```bash
# Install a single optional dependency group at runtime
docker run \
-e TRANSPORT_MODE=http \
-e EXTRAS=aws \
--env-file ./.env \
-p 8000:8000 \
--rm -it cognee/cognee-mcp:main
# Install multiple optional dependency groups at runtime (comma-separated)
docker run \
-e TRANSPORT_MODE=sse \
-e EXTRAS=aws,postgres,neo4j \
--env-file ./.env \
-p 8000:8000 \
--rm -it cognee/cognee-mcp:main
```
**Available optional dependency groups:**
- `aws` - S3 storage support
- `postgres` / `postgres-binary` - PostgreSQL database support
- `neo4j` - Neo4j graph database support
- `neptune` - AWS Neptune support
- `chromadb` - ChromaDB vector store support
- `scraping` - Web scraping capabilities
- `distributed` - Modal distributed execution
- `langchain` - LangChain integration
- `llama-index` - LlamaIndex integration
- `anthropic` - Anthropic models
- `groq` - Groq models
- `mistral` - Mistral models
- `ollama` / `huggingface` - Local model support
- `docs` - Document processing
- `codegraph` - Code analysis
- `monitoring` - Sentry & Langfuse monitoring
- `redis` - Redis support
- And more (see [pyproject.toml](https://github.com/topoteretes/cognee/blob/main/pyproject.toml) for full list)
2. **Pull from Docker Hub** (no build required):
```bash
# With HTTP transport (recommended for web deployments)
@ -119,6 +160,17 @@ If you'd rather run cognee-mcp in a container, you have two options:
# With stdio transport (default)
docker run -e TRANSPORT_MODE=stdio --env-file ./.env --rm -it cognee/cognee-mcp:main
```
**With runtime installation of optional dependencies:**
```bash
# Install optional dependencies from Docker Hub image
docker run \
-e TRANSPORT_MODE=http \
-e EXTRAS=aws,postgres \
--env-file ./.env \
-p 8000:8000 \
--rm -it cognee/cognee-mcp:main
```
### **Important: Docker vs Direct Usage**
**Docker uses environment variables**, not command line arguments:

View file

@ -4,6 +4,42 @@ set -e # Exit on error
echo "Debug mode: $DEBUG"
echo "Environment: $ENVIRONMENT"
# Install optional dependencies if EXTRAS is set
if [ -n "$EXTRAS" ]; then
echo "Installing optional dependencies: $EXTRAS"
# Get the cognee version that's currently installed
COGNEE_VERSION=$(uv pip show cognee | grep "Version:" | awk '{print $2}')
echo "Current cognee version: $COGNEE_VERSION"
# Build the extras list for cognee
IFS=',' read -ra EXTRA_ARRAY <<< "$EXTRAS"
# Combine base extras from pyproject.toml with requested extras
ALL_EXTRAS=""
for extra in "${EXTRA_ARRAY[@]}"; do
# Trim whitespace
extra=$(echo "$extra" | xargs)
# Add to extras list if not already present
if [[ ! "$ALL_EXTRAS" =~ (^|,)"$extra"(,|$) ]]; then
if [ -z "$ALL_EXTRAS" ]; then
ALL_EXTRAS="$extra"
else
ALL_EXTRAS="$ALL_EXTRAS,$extra"
fi
fi
done
echo "Installing cognee with extras: $ALL_EXTRAS"
echo "Running: uv pip install 'cognee[$ALL_EXTRAS]==$COGNEE_VERSION'"
uv pip install "cognee[$ALL_EXTRAS]==$COGNEE_VERSION"
# Verify installation
echo ""
echo "✓ Optional dependencies installation completed"
else
echo "No optional dependencies specified"
fi
# Set default transport mode if not specified
TRANSPORT_MODE=${TRANSPORT_MODE:-"stdio"}
echo "Transport mode: $TRANSPORT_MODE"

View file

@ -9,7 +9,7 @@ dependencies = [
# For local cognee repo usage remove comment bellow and add absolute path to cognee. Then run `uv sync --reinstall` in the mcp folder on local cognee changes.
#"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j] @ file:/Users/igorilic/Desktop/cognee",
# TODO: Remove gemini from optional dependecnies for new Cognee version after 0.3.4
"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j]==0.3.4",
"cognee[postgres,docs,neo4j]==0.3.7",
"fastmcp>=2.10.0,<3.0.0",
"mcp>=1.12.0,<2.0.0",
"uv>=0.6.3,<1.0.0",

View file

@ -37,12 +37,10 @@ async def run():
toolResult = await session.call_tool("prune", arguments={})
toolResult = await session.call_tool(
"codify", arguments={"repo_path": "SOME_REPO_PATH"}
)
toolResult = await session.call_tool("cognify", arguments={})
toolResult = await session.call_tool(
"search", arguments={"search_type": "CODE", "search_query": "exceptions"}
"search", arguments={"search_type": "GRAPH_COMPLETION"}
)
print(f"Cognify result: {toolResult.content}")

8575
cognee-mcp/uv.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,8 +1,5 @@
from uuid import UUID
import os
from typing import Union, BinaryIO, List, Optional, Dict, Any
from pydantic import BaseModel
from urllib.parse import urlparse
from typing import Union, BinaryIO, List, Optional, Any
from cognee.modules.users.models import User
from cognee.modules.pipelines import Task, run_pipeline
from cognee.modules.pipelines.layers.resolve_authorized_user_dataset import (
@ -17,16 +14,6 @@ from cognee.shared.logging_utils import get_logger
logger = get_logger()
try:
from cognee.tasks.web_scraper.config import TavilyConfig, SoupCrawlerConfig
from cognee.context_global_variables import (
tavily_config as tavily,
soup_crawler_config as soup_crawler,
)
except ImportError:
logger.debug(f"Unable to import {str(ImportError)}")
pass
async def add(
data: Union[BinaryIO, list[BinaryIO], str, list[str]],
@ -36,11 +23,8 @@ async def add(
vector_db_config: dict = None,
graph_db_config: dict = None,
dataset_id: Optional[UUID] = None,
preferred_loaders: List[str] = None,
preferred_loaders: Optional[List[Union[str, dict[str, dict[str, Any]]]]] = None,
incremental_loading: bool = True,
extraction_rules: Optional[Dict[str, Any]] = None,
tavily_config: Optional[BaseModel] = None,
soup_crawler_config: Optional[BaseModel] = None,
data_per_batch: Optional[int] = 20,
):
"""
@ -180,28 +164,14 @@ async def add(
- TAVILY_API_KEY: YOUR_TAVILY_API_KEY
"""
try:
if not soup_crawler_config and extraction_rules:
soup_crawler_config = SoupCrawlerConfig(extraction_rules=extraction_rules)
if not tavily_config and os.getenv("TAVILY_API_KEY"):
tavily_config = TavilyConfig(api_key=os.getenv("TAVILY_API_KEY"))
soup_crawler.set(soup_crawler_config)
tavily.set(tavily_config)
http_schemes = {"http", "https"}
def _is_http_url(item: Union[str, BinaryIO]) -> bool:
return isinstance(item, str) and urlparse(item).scheme in http_schemes
if _is_http_url(data):
node_set = ["web_content"] if not node_set else node_set + ["web_content"]
elif isinstance(data, list) and any(_is_http_url(item) for item in data):
node_set = ["web_content"] if not node_set else node_set + ["web_content"]
except NameError:
logger.debug(f"Unable to import {str(ImportError)}")
pass
if preferred_loaders is not None:
transformed = {}
for item in preferred_loaders:
if isinstance(item, dict):
transformed.update(item)
else:
transformed[item] = {}
preferred_loaders = transformed
tasks = [
Task(resolve_data_directories, include_subdirectories=True),

View file

@ -10,6 +10,7 @@ from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee.modules.pipelines.models import PipelineRunErrored
from cognee.shared.logging_utils import get_logger
from cognee import __version__ as cognee_version
logger = get_logger()
@ -63,7 +64,11 @@ def get_add_router() -> APIRouter:
send_telemetry(
"Add API Endpoint Invoked",
user.id,
additional_properties={"endpoint": "POST /v1/add", "node_set": node_set},
additional_properties={
"endpoint": "POST /v1/add",
"node_set": node_set,
"cognee_version": cognee_version,
},
)
from cognee.api.v1.add import add as cognee_add

View file

@ -29,7 +29,7 @@ from cognee.modules.pipelines.queues.pipeline_run_info_queues import (
)
from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
logger = get_logger("api.cognify")
@ -98,6 +98,7 @@ def get_cognify_router() -> APIRouter:
user.id,
additional_properties={
"endpoint": "POST /v1/cognify",
"cognee_version": cognee_version,
},
)

View file

@ -1,4 +1,5 @@
from uuid import UUID
from cognee.modules.data.methods import has_dataset_data
from cognee.modules.users.methods import get_default_user
from cognee.modules.ingestion import discover_directory_datasets
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
@ -26,6 +27,16 @@ class datasets:
return await get_dataset_data(dataset.id)
@staticmethod
async def has_data(dataset_id: str) -> bool:
from cognee.modules.data.methods import get_dataset
user = await get_default_user()
dataset = await get_dataset(user.id, dataset_id)
return await has_dataset_data(dataset.id)
@staticmethod
async def get_status(dataset_ids: list[UUID]) -> dict:
return await get_pipeline_status(dataset_ids, pipeline_name="cognify_pipeline")

View file

@ -24,6 +24,7 @@ from cognee.modules.users.permissions.methods import (
from cognee.modules.graph.methods import get_formatted_graph_data
from cognee.modules.pipelines.models import PipelineRunStatus
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
logger = get_logger()
@ -100,6 +101,7 @@ def get_datasets_router() -> APIRouter:
user.id,
additional_properties={
"endpoint": "GET /v1/datasets",
"cognee_version": cognee_version,
},
)
@ -147,6 +149,7 @@ def get_datasets_router() -> APIRouter:
user.id,
additional_properties={
"endpoint": "POST /v1/datasets",
"cognee_version": cognee_version,
},
)
@ -201,6 +204,7 @@ def get_datasets_router() -> APIRouter:
additional_properties={
"endpoint": f"DELETE /v1/datasets/{str(dataset_id)}",
"dataset_id": str(dataset_id),
"cognee_version": cognee_version,
},
)
@ -246,6 +250,7 @@ def get_datasets_router() -> APIRouter:
"endpoint": f"DELETE /v1/datasets/{str(dataset_id)}/data/{str(data_id)}",
"dataset_id": str(dataset_id),
"data_id": str(data_id),
"cognee_version": cognee_version,
},
)
@ -327,6 +332,7 @@ def get_datasets_router() -> APIRouter:
additional_properties={
"endpoint": f"GET /v1/datasets/{str(dataset_id)}/data",
"dataset_id": str(dataset_id),
"cognee_version": cognee_version,
},
)
@ -387,6 +393,7 @@ def get_datasets_router() -> APIRouter:
additional_properties={
"endpoint": "GET /v1/datasets/status",
"datasets": [str(dataset_id) for dataset_id in datasets],
"cognee_version": cognee_version,
},
)
@ -433,6 +440,7 @@ def get_datasets_router() -> APIRouter:
"endpoint": f"GET /v1/datasets/{str(dataset_id)}/data/{str(data_id)}/raw",
"dataset_id": str(dataset_id),
"data_id": str(data_id),
"cognee_version": cognee_version,
},
)

View file

@ -6,6 +6,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
logger = get_logger()
@ -39,6 +40,7 @@ def get_delete_router() -> APIRouter:
"endpoint": "DELETE /v1/delete",
"dataset_id": str(dataset_id),
"data_id": str(data_id),
"cognee_version": cognee_version,
},
)

View file

@ -12,6 +12,7 @@ from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee.modules.pipelines.models import PipelineRunErrored
from cognee.shared.logging_utils import get_logger
from cognee import __version__ as cognee_version
logger = get_logger()
@ -73,7 +74,7 @@ def get_memify_router() -> APIRouter:
send_telemetry(
"Memify API Endpoint Invoked",
user.id,
additional_properties={"endpoint": "POST /v1/memify"},
additional_properties={"endpoint": "POST /v1/memify", "cognee_version": cognee_version},
)
if not payload.dataset_id and not payload.dataset_name:

View file

@ -7,6 +7,7 @@ from fastapi.responses import JSONResponse
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
def get_permissions_router() -> APIRouter:
@ -48,6 +49,7 @@ def get_permissions_router() -> APIRouter:
"endpoint": f"POST /v1/permissions/datasets/{str(principal_id)}",
"dataset_ids": str(dataset_ids),
"principal_id": str(principal_id),
"cognee_version": cognee_version,
},
)
@ -89,6 +91,7 @@ def get_permissions_router() -> APIRouter:
additional_properties={
"endpoint": "POST /v1/permissions/roles",
"role_name": role_name,
"cognee_version": cognee_version,
},
)
@ -133,6 +136,7 @@ def get_permissions_router() -> APIRouter:
"endpoint": f"POST /v1/permissions/users/{str(user_id)}/roles",
"user_id": str(user_id),
"role_id": str(role_id),
"cognee_version": cognee_version,
},
)
@ -175,6 +179,7 @@ def get_permissions_router() -> APIRouter:
"endpoint": f"POST /v1/permissions/users/{str(user_id)}/tenants",
"user_id": str(user_id),
"tenant_id": str(tenant_id),
"cognee_version": cognee_version,
},
)
@ -209,6 +214,7 @@ def get_permissions_router() -> APIRouter:
additional_properties={
"endpoint": "POST /v1/permissions/tenants",
"tenant_name": tenant_name,
"cognee_version": cognee_version,
},
)

View file

@ -13,6 +13,7 @@ from cognee.modules.users.models import User
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
# Note: Datasets sent by name will only map to datasets owned by the request sender
@ -61,9 +62,7 @@ def get_search_router() -> APIRouter:
send_telemetry(
"Search API Endpoint Invoked",
user.id,
additional_properties={
"endpoint": "GET /v1/search",
},
additional_properties={"endpoint": "GET /v1/search", "cognee_version": cognee_version},
)
try:
@ -118,6 +117,7 @@ def get_search_router() -> APIRouter:
"top_k": payload.top_k,
"only_context": payload.only_context,
"use_combined_context": payload.use_combined_context,
"cognee_version": cognee_version,
},
)

View file

@ -1,6 +1,7 @@
from uuid import UUID
from typing import Union, Optional, List, Type
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.users.models import User
from cognee.modules.search.types import SearchResult, SearchType, CombinedSearchResult
@ -8,6 +9,10 @@ from cognee.modules.users.methods import get_default_user
from cognee.modules.search.methods import search as search_function
from cognee.modules.data.methods import get_authorized_existing_datasets
from cognee.modules.data.exceptions import DatasetNotFoundError
from cognee.context_global_variables import set_session_user_context_variable
from cognee.shared.logging_utils import get_logger
logger = get_logger()
async def search(
@ -25,6 +30,7 @@ async def search(
last_k: Optional[int] = 1,
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
) -> Union[List[SearchResult], CombinedSearchResult]:
"""
Search and query the knowledge graph for insights, information, and connections.
@ -113,6 +119,8 @@ async def search(
save_interaction: Save interaction (query, context, answer connected to triplet endpoints) results into the graph or not
session_id: Optional session identifier for caching Q&A interactions. Defaults to 'default_session' if None.
Returns:
list: Search results in format determined by query_type:
@ -168,6 +176,8 @@ async def search(
if user is None:
user = await get_default_user()
await set_session_user_context_variable(user)
# Transform string based datasets to UUID - String based datasets can only be found for current user
if datasets is not None and [all(isinstance(dataset, str) for dataset in datasets)]:
datasets = await get_authorized_existing_datasets(datasets, "read", user)
@ -189,6 +199,7 @@ async def search(
last_k=last_k,
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
)
return filtered_search_results

View file

@ -12,6 +12,7 @@ from cognee.modules.sync.methods import get_running_sync_operations_for_user, ge
from cognee.shared.utils import send_telemetry
from cognee.shared.logging_utils import get_logger
from cognee.api.v1.sync import SyncResponse
from cognee import __version__ as cognee_version
from cognee.context_global_variables import set_database_global_context_variables
logger = get_logger()
@ -99,6 +100,7 @@ def get_sync_router() -> APIRouter:
user.id,
additional_properties={
"endpoint": "POST /v1/sync",
"cognee_version": cognee_version,
"dataset_ids": [str(id) for id in request.dataset_ids]
if request.dataset_ids
else "*",
@ -205,6 +207,7 @@ def get_sync_router() -> APIRouter:
user.id,
additional_properties={
"endpoint": "GET /v1/sync/status",
"cognee_version": cognee_version,
},
)

View file

@ -503,7 +503,7 @@ def start_ui(
if start_mcp:
logger.info("Starting Cognee MCP server with Docker...")
try:
image = "cognee/cognee-mcp:feature-standalone-mcp" # TODO: change to "cognee/cognee-mcp:main" right before merging into main
image = "cognee/cognee-mcp:main"
subprocess.run(["docker", "pull", image], check=True)
import uuid
@ -538,9 +538,7 @@ def start_ui(
env_file = os.path.join(cwd, ".env")
docker_cmd.extend(["--env-file", env_file])
docker_cmd.append(
image
) # TODO: change to "cognee/cognee-mcp:main" right before merging into main
docker_cmd.append(image)
mcp_process = subprocess.Popen(
docker_cmd,

View file

@ -9,6 +9,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
from cognee.modules.pipelines.models.PipelineRunInfo import (
PipelineRunErrored,
)
@ -64,6 +65,7 @@ def get_update_router() -> APIRouter:
"dataset_id": str(dataset_id),
"data_id": str(data_id),
"node_set": str(node_set),
"cognee_version": cognee_version,
},
)

View file

@ -1,5 +1,5 @@
from uuid import UUID
from typing import Union, BinaryIO, List, Optional
from typing import Union, BinaryIO, List, Optional, Any
from cognee.modules.users.models import User
from cognee.api.v1.delete import delete
@ -15,7 +15,7 @@ async def update(
node_set: Optional[List[str]] = None,
vector_db_config: dict = None,
graph_db_config: dict = None,
preferred_loaders: List[str] = None,
preferred_loaders: dict[str, dict[str, Any]] = None,
incremental_loading: bool = True,
):
"""

View file

@ -8,6 +8,7 @@ from cognee.modules.users.models import User
from cognee.context_global_variables import set_database_global_context_variables
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
logger = get_logger()
@ -46,6 +47,7 @@ def get_visualize_router() -> APIRouter:
additional_properties={
"endpoint": "GET /v1/visualize",
"dataset_id": str(dataset_id),
"cognee_version": cognee_version,
},
)

View file

@ -1,4 +1,5 @@
import os
from pathlib import Path
from typing import Optional
from functools import lru_cache
from cognee.root_dir import get_absolute_path, ensure_absolute_path
@ -11,6 +12,9 @@ class BaseConfig(BaseSettings):
data_root_directory: str = get_absolute_path(".data_storage")
system_root_directory: str = get_absolute_path(".cognee_system")
cache_root_directory: str = get_absolute_path(".cognee_cache")
logs_root_directory: str = os.getenv(
"COGNEE_LOGS_DIR", str(os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs"))
)
monitoring_tool: object = Observer.NONE
@pydantic.model_validator(mode="after")
@ -30,6 +34,8 @@ class BaseConfig(BaseSettings):
# Require absolute paths for root directories
self.data_root_directory = ensure_absolute_path(self.data_root_directory)
self.system_root_directory = ensure_absolute_path(self.system_root_directory)
self.logs_root_directory = ensure_absolute_path(self.logs_root_directory)
# Set monitoring tool based on available keys
if self.langfuse_public_key and self.langfuse_secret_key:
self.monitoring_tool = Observer.LANGFUSE
@ -49,6 +55,7 @@ class BaseConfig(BaseSettings):
"system_root_directory": self.system_root_directory,
"monitoring_tool": self.monitoring_tool,
"cache_root_directory": self.cache_root_directory,
"logs_root_directory": self.logs_root_directory,
}

View file

@ -12,8 +12,11 @@ from cognee.modules.users.methods import get_user
# for different async tasks, threads and processes
vector_db_config = ContextVar("vector_db_config", default=None)
graph_db_config = ContextVar("graph_db_config", default=None)
soup_crawler_config = ContextVar("soup_crawler_config", default=None)
tavily_config = ContextVar("tavily_config", default=None)
session_user = ContextVar("session_user", default=None)
async def set_session_user_context_variable(user):
session_user.set(user)
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):

View file

@ -56,7 +56,7 @@ class CogneeValidationError(CogneeApiError):
self,
message: str = "A validation error occurred.",
name: str = "CogneeValidationError",
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
log=True,
log_level="ERROR",
):

View file

@ -40,3 +40,40 @@ class CacheDBInterface(ABC):
yield
finally:
self.release()
@abstractmethod
async def add_qa(
self,
user_id: str,
session_id: str,
question: str,
context: str,
answer: str,
ttl: int | None = 86400,
):
"""
Add a Q/A/context triplet to a cache session.
"""
pass
@abstractmethod
async def get_latest_qa(self, user_id: str, session_id: str, last_n: int = 5):
"""
Retrieve the most recent Q/A/context triplets for a session.
"""
pass
@abstractmethod
async def get_all_qas(self, user_id: str, session_id: str):
"""
Retrieve all Q/A/context triplets for the given session.
"""
pass
@abstractmethod
async def close(self):
"""
Gracefully close any async connections.
"""
pass

View file

@ -1,5 +1,6 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from functools import lru_cache
from typing import Optional
class CacheConfig(BaseSettings):
@ -18,6 +19,8 @@ class CacheConfig(BaseSettings):
shared_kuzu_lock: bool = False
cache_host: str = "localhost"
cache_port: int = 6379
cache_username: Optional[str] = None
cache_password: Optional[str] = None
agentic_lock_expire: int = 240
agentic_lock_timeout: int = 300
@ -29,6 +32,8 @@ class CacheConfig(BaseSettings):
"shared_kuzu_lock": self.shared_kuzu_lock,
"cache_host": self.cache_host,
"cache_port": self.cache_port,
"cache_username": self.cache_username,
"cache_password": self.cache_password,
"agentic_lock_expire": self.agentic_lock_expire,
"agentic_lock_timeout": self.agentic_lock_timeout,
}

View file

@ -1,8 +1,8 @@
"""Factory to get the appropriate cache coordination engine (e.g., Redis)."""
from functools import lru_cache
from typing import Optional
from cognee.infrastructure.databases.cache.config import get_cache_config
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
config = get_cache_config()
@ -12,6 +12,8 @@ config = get_cache_config()
def create_cache_engine(
cache_host: str,
cache_port: int,
cache_username: str,
cache_password: str,
lock_key: str,
agentic_lock_expire: int = 240,
agentic_lock_timeout: int = 300,
@ -23,6 +25,8 @@ def create_cache_engine(
-----------
- cache_host: Hostname or IP of the cache server.
- cache_port: Port number to connect to.
- cache_username: Username to authenticate with.
- cache_password: Password to authenticate with.
- lock_key: Identifier used for the locking resource.
- agentic_lock_expire: Duration to hold the lock after acquisition.
- agentic_lock_timeout: Max time to wait for the lock before failing.
@ -37,6 +41,8 @@ def create_cache_engine(
return RedisAdapter(
host=cache_host,
port=cache_port,
username=cache_username,
password=cache_password,
lock_name=lock_key,
timeout=agentic_lock_expire,
blocking_timeout=agentic_lock_timeout,
@ -45,7 +51,7 @@ def create_cache_engine(
return None
def get_cache_engine(lock_key: str) -> CacheDBInterface:
def get_cache_engine(lock_key: Optional[str] = None) -> CacheDBInterface:
"""
Returns a cache adapter instance using current context configuration.
"""
@ -53,6 +59,8 @@ def get_cache_engine(lock_key: str) -> CacheDBInterface:
return create_cache_engine(
cache_host=config.cache_host,
cache_port=config.cache_port,
cache_username=config.cache_username,
cache_password=config.cache_password,
lock_key=lock_key,
agentic_lock_expire=config.agentic_lock_expire,
agentic_lock_timeout=config.agentic_lock_timeout,

View file

@ -1,20 +1,81 @@
import asyncio
import redis
import redis.asyncio as aioredis
from contextlib import contextmanager
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
from cognee.infrastructure.databases.exceptions import CacheConnectionError
from cognee.shared.logging_utils import get_logger
from datetime import datetime
import json
logger = get_logger("RedisAdapter")
class RedisAdapter(CacheDBInterface):
def __init__(self, host, port, lock_name, timeout=240, blocking_timeout=300):
def __init__(
self,
host,
port,
lock_name="default_lock",
username=None,
password=None,
timeout=240,
blocking_timeout=300,
connection_timeout=30,
):
super().__init__(host, port, lock_name)
self.redis = redis.Redis(host=host, port=port)
self.timeout = timeout
self.blocking_timeout = blocking_timeout
self.host = host
self.port = port
self.connection_timeout = connection_timeout
try:
self.sync_redis = redis.Redis(
host=host,
port=port,
username=username,
password=password,
socket_connect_timeout=connection_timeout,
socket_timeout=connection_timeout,
)
self.async_redis = aioredis.Redis(
host=host,
port=port,
username=username,
password=password,
decode_responses=True,
socket_connect_timeout=connection_timeout,
)
self.timeout = timeout
self.blocking_timeout = blocking_timeout
# Validate connection on initialization
self._validate_connection()
logger.info(f"Successfully connected to Redis at {host}:{port}")
except (redis.ConnectionError, redis.TimeoutError) as e:
error_msg = f"Failed to connect to Redis at {host}:{port}: {str(e)}"
logger.error(error_msg)
raise CacheConnectionError(error_msg) from e
except Exception as e:
error_msg = f"Unexpected error initializing Redis adapter: {str(e)}"
logger.error(error_msg)
raise CacheConnectionError(error_msg) from e
def _validate_connection(self):
"""Validate Redis connection is available."""
try:
self.sync_redis.ping()
except (redis.ConnectionError, redis.TimeoutError) as e:
raise CacheConnectionError(
f"Cannot connect to Redis at {self.host}:{self.port}: {str(e)}"
) from e
def acquire_lock(self):
"""
Acquire the Redis lock manually. Raises if acquisition fails.
Acquire the Redis lock manually. Raises if acquisition fails. (Sync because of Kuzu)
"""
self.lock = self.redis.lock(
self.lock = self.sync_redis.lock(
name=self.lock_key,
timeout=self.timeout,
blocking_timeout=self.blocking_timeout,
@ -28,7 +89,7 @@ class RedisAdapter(CacheDBInterface):
def release_lock(self):
"""
Release the Redis lock manually, if held.
Release the Redis lock manually, if held. (Sync because of Kuzu)
"""
if self.lock:
try:
@ -40,10 +101,143 @@ class RedisAdapter(CacheDBInterface):
@contextmanager
def hold_lock(self):
"""
Context manager for acquiring and releasing the Redis lock automatically.
Context manager for acquiring and releasing the Redis lock automatically. (Sync because of Kuzu)
"""
self.acquire()
try:
yield
finally:
self.release()
async def add_qa(
self,
user_id: str,
session_id: str,
question: str,
context: str,
answer: str,
ttl: int | None = 86400,
):
"""
Add a Q/A/context triplet to a Redis list for this session.
Creates the session if it doesn't exist.
Args:
user_id (str): The user ID.
session_id: Unique identifier for the session.
question: User question text.
context: Context used to answer.
answer: Assistant answer text.
ttl: Optional time-to-live (seconds). If provided, the session expires after this time.
Raises:
CacheConnectionError: If Redis connection fails or times out.
"""
try:
session_key = f"agent_sessions:{user_id}:{session_id}"
qa_entry = {
"time": datetime.utcnow().isoformat(),
"question": question,
"context": context,
"answer": answer,
}
await self.async_redis.rpush(session_key, json.dumps(qa_entry))
if ttl is not None:
await self.async_redis.expire(session_key, ttl)
except (redis.ConnectionError, redis.TimeoutError) as e:
error_msg = f"Redis connection error while adding Q&A: {str(e)}"
logger.error(error_msg)
raise CacheConnectionError(error_msg) from e
except Exception as e:
error_msg = f"Unexpected error while adding Q&A to Redis: {str(e)}"
logger.error(error_msg)
raise CacheConnectionError(error_msg) from e
async def get_latest_qa(self, user_id: str, session_id: str, last_n: int = 5):
"""
Retrieve the most recent Q/A/context triplet(s) for the given session.
"""
session_key = f"agent_sessions:{user_id}:{session_id}"
if last_n == 1:
data = await self.async_redis.lindex(session_key, -1)
return [json.loads(data)] if data else None
else:
data = await self.async_redis.lrange(session_key, -last_n, -1)
return [json.loads(d) for d in data] if data else []
async def get_all_qas(self, user_id: str, session_id: str):
"""
Retrieve all Q/A/context triplets for the given session.
"""
session_key = f"agent_sessions:{user_id}:{session_id}"
entries = await self.async_redis.lrange(session_key, 0, -1)
return [json.loads(e) for e in entries]
async def close(self):
"""
Gracefully close the async Redis connection.
"""
await self.async_redis.aclose()
async def main():
HOST = "localhost"
PORT = 6379
adapter = RedisAdapter(host=HOST, port=PORT)
session_id = "demo_session"
user_id = "demo_user_id"
print("\nAdding sample Q/A pairs...")
await adapter.add_qa(
user_id,
session_id,
"What is Redis?",
"Basic DB context",
"Redis is an in-memory data store.",
)
await adapter.add_qa(
user_id,
session_id,
"Who created Redis?",
"Historical context",
"Salvatore Sanfilippo (antirez).",
)
print("\nLatest QA:")
latest = await adapter.get_latest_qa(user_id, session_id)
print(json.dumps(latest, indent=2))
print("\nLast 2 QAs:")
last_two = await adapter.get_latest_qa(user_id, session_id, last_n=2)
print(json.dumps(last_two, indent=2))
session_id = "session_expire_demo"
await adapter.add_qa(
user_id,
session_id,
"What is Redis?",
"Database context",
"Redis is an in-memory data store.",
)
await adapter.add_qa(
user_id,
session_id,
"Who created Redis?",
"History context",
"Salvatore Sanfilippo (antirez).",
)
print(await adapter.get_all_qas(user_id, session_id))
await adapter.close()
if __name__ == "__main__":
asyncio.run(main())

View file

@ -11,4 +11,5 @@ from .exceptions import (
EmbeddingException,
MissingQueryParameterError,
MutuallyExclusiveQueryParametersError,
CacheConnectionError,
)

View file

@ -15,7 +15,7 @@ class DatabaseNotCreatedError(CogneeSystemError):
self,
message: str = "The database has not been created yet. Please call `await setup()` first.",
name: str = "DatabaseNotCreatedError",
status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
status_code: int = status.HTTP_422_UNPROCESSABLE_CONTENT,
):
super().__init__(message, name, status_code)
@ -99,7 +99,7 @@ class EmbeddingException(CogneeConfigurationError):
self,
message: str = "Embedding Exception.",
name: str = "EmbeddingException",
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
):
super().__init__(message, name, status_code)
@ -132,3 +132,19 @@ class MutuallyExclusiveQueryParametersError(CogneeValidationError):
):
message = "The search function accepts either text or embedding as input, but not both."
super().__init__(message, name, status_code)
class CacheConnectionError(CogneeConfigurationError):
"""
Raised when connection to the cache database (e.g., Redis) fails.
This error indicates that the cache service is unavailable or misconfigured.
"""
def __init__(
self,
message: str = "Failed to connect to cache database. Please check your cache configuration.",
name: str = "CacheConnectionError",
status_code: int = status.HTTP_503_SERVICE_UNAVAILABLE,
):
super().__init__(message, name, status_code)

View file

@ -159,6 +159,11 @@ class GraphDBInterface(ABC):
- get_connections
"""
@abstractmethod
async def is_empty(self) -> bool:
logger.warning("is_empty() is not implemented")
return True
@abstractmethod
async def query(self, query: str, params: dict) -> List[Any]:
"""

View file

@ -198,6 +198,15 @@ class KuzuAdapter(GraphDBInterface):
except FileNotFoundError:
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
async def is_empty(self) -> bool:
query = """
MATCH (n)
RETURN true
LIMIT 1;
"""
query_result = await self.query(query)
return len(query_result) == 0
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
"""
Execute a Kuzu query asynchronously with automatic reconnection.
@ -1357,9 +1366,15 @@ class KuzuAdapter(GraphDBInterface):
params[param_name] = values
where_clause = " AND ".join(where_clauses)
nodes_query = (
f"MATCH (n:Node) WHERE {where_clause} RETURN n.id, {{properties: n.properties}}"
)
nodes_query = f"""
MATCH (n:Node)
WHERE {where_clause}
RETURN n.id, {{
name: n.name,
type: n.type,
properties: n.properties
}}
"""
edges_query = f"""
MATCH (n1:Node)-[r:EDGE]->(n2:Node)
WHERE {where_clause.replace("n.", "n1.")} AND {where_clause.replace("n.", "n2.")}

View file

@ -87,6 +87,15 @@ class Neo4jAdapter(GraphDBInterface):
async with self.driver.session(database=self.graph_database_name) as session:
yield session
async def is_empty(self) -> bool:
query = """
RETURN EXISTS {
MATCH (n)
} AS node_exists;
"""
query_result = await self.query(query)
return not query_result[0]["node_exists"]
@deadlock_retry()
async def query(
self,

View file

@ -47,7 +47,7 @@ def create_vector_engine(
embedding_engine=embedding_engine,
)
if vector_db_provider == "pgvector":
if vector_db_provider.lower() == "pgvector":
from cognee.infrastructure.databases.relational import get_relational_config
# Get configuration for postgres database
@ -78,7 +78,7 @@ def create_vector_engine(
embedding_engine,
)
elif vector_db_provider == "chromadb":
elif vector_db_provider.lower() == "chromadb":
try:
import chromadb
except ImportError:
@ -94,7 +94,7 @@ def create_vector_engine(
embedding_engine=embedding_engine,
)
elif vector_db_provider == "neptune_analytics":
elif vector_db_provider.lower() == "neptune_analytics":
try:
from langchain_aws import NeptuneAnalyticsGraph
except ImportError:
@ -122,7 +122,7 @@ def create_vector_engine(
embedding_engine=embedding_engine,
)
else:
elif vector_db_provider.lower() == "lancedb":
from .lancedb.LanceDBAdapter import LanceDBAdapter
return LanceDBAdapter(
@ -130,3 +130,9 @@ def create_vector_engine(
api_key=vector_db_key,
embedding_engine=embedding_engine,
)
else:
raise EnvironmentError(
f"Unsupported vector database provider: {vector_db_provider}. "
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['LanceDB', 'PGVector', 'neptune_analytics', 'ChromaDB'])}"
)

View file

@ -124,7 +124,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
self.endpoint, json=payload, headers=headers, timeout=60.0
) as response:
data = await response.json()
return data["embedding"]
return data["embeddings"][0]
def get_vector_size(self) -> int:
"""

View file

@ -15,7 +15,7 @@ class CollectionNotFoundError(CogneeValidationError):
self,
message,
name: str = "CollectionNotFoundError",
status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
status_code: int = status.HTTP_422_UNPROCESSABLE_CONTENT,
log=True,
log_level="DEBUG",
):

View file

@ -324,7 +324,6 @@ class LanceDBAdapter(VectorDBInterface):
def get_data_point_schema(self, model_type: BaseModel):
related_models_fields = []
for field_name, field_config in model_type.model_fields.items():
if hasattr(field_config, "model_fields"):
related_models_fields.append(field_name)

View file

@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from typing import Optional, Any, Dict
@ -18,9 +18,21 @@ class Edge(BaseModel):
# Mixed usage
has_items: (Edge(weight=0.5, weights={"confidence": 0.9}), list[Item])
# With edge_text for rich embedding representation
contains: (Edge(relationship_type="contains", edge_text="relationship_name: contains; entity_description: Alice"), Entity)
"""
weight: Optional[float] = None
weights: Optional[Dict[str, float]] = None
relationship_type: Optional[str] = None
properties: Optional[Dict[str, Any]] = None
edge_text: Optional[str] = None
@field_validator("edge_text", mode="before")
@classmethod
def ensure_edge_text(cls, v, info):
"""Auto-populate edge_text from relationship_type if not explicitly provided."""
if v is None and info.data.get("relationship_type"):
return info.data["relationship_type"]
return v

View file

@ -8,6 +8,6 @@ class FileContentHashingError(Exception):
self,
message: str = "Failed to hash content of the file.",
name: str = "FileContentHashingError",
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
):
super().__init__(message, name, status_code)

View file

@ -82,16 +82,16 @@ class LocalFileStorage(Storage):
self.ensure_directory_exists(file_dir_path)
if overwrite or not os.path.exists(full_file_path):
with open(
full_file_path,
mode="w" if isinstance(data, str) else "wb",
encoding="utf-8" if isinstance(data, str) else None,
) as file:
if hasattr(data, "read"):
data.seek(0)
file.write(data.read())
else:
if isinstance(data, str):
with open(full_file_path, mode="w", encoding="utf-8", newline="\n") as file:
file.write(data)
else:
with open(full_file_path, mode="wb") as file:
if hasattr(data, "read"):
data.seek(0)
file.write(data.read())
else:
file.write(data)
file.close()

View file

@ -70,18 +70,18 @@ class S3FileStorage(Storage):
if overwrite or not await self.file_exists(file_path):
def save_data_to_file():
with self.s3.open(
full_file_path,
mode="w" if isinstance(data, str) else "wb",
encoding="utf-8" if isinstance(data, str) else None,
) as file:
if hasattr(data, "read"):
data.seek(0)
file.write(data.read())
else:
if isinstance(data, str):
with self.s3.open(
full_file_path, mode="w", encoding="utf-8", newline="\n"
) as file:
file.write(data)
file.close()
else:
with self.s3.open(full_file_path, mode="wb") as file:
if hasattr(data, "read"):
data.seek(0)
file.write(data.read())
else:
file.write(data)
await run_async(save_data_to_file)

View file

@ -1,6 +1,6 @@
import io
import os.path
from typing import BinaryIO, TypedDict
from typing import BinaryIO, TypedDict, Optional
from pathlib import Path
from cognee.shared.logging_utils import get_logger
@ -27,7 +27,7 @@ class FileMetadata(TypedDict):
file_size: int
async def get_file_metadata(file: BinaryIO) -> FileMetadata:
async def get_file_metadata(file: BinaryIO, name: Optional[str] = None) -> FileMetadata:
"""
Retrieve metadata from a file object.
@ -53,7 +53,7 @@ async def get_file_metadata(file: BinaryIO) -> FileMetadata:
except io.UnsupportedOperation as error:
logger.error(f"Error retrieving content hash for file: {file.name} \n{str(error)}\n\n")
file_type = guess_file_type(file)
file_type = guess_file_type(file, name)
file_path = getattr(file, "name", None) or getattr(file, "full_name", None)

View file

@ -1,6 +1,9 @@
from typing import BinaryIO
import io
from pathlib import Path
from typing import BinaryIO, Optional, Any
import filetype
from .is_text_content import is_text_content
from tempfile import SpooledTemporaryFile
from filetype.types.base import Type
class FileTypeException(Exception):
@ -22,90 +25,7 @@ class FileTypeException(Exception):
self.message = message
class TxtFileType(filetype.Type):
"""
Represents a text file type with specific MIME and extension properties.
Public methods:
- match: Determines whether a given buffer matches the text file type.
"""
MIME = "text/plain"
EXTENSION = "txt"
def __init__(self):
super(TxtFileType, self).__init__(mime=TxtFileType.MIME, extension=TxtFileType.EXTENSION)
def match(self, buf):
"""
Determine if the given buffer contains text content.
Parameters:
-----------
- buf: The buffer to check for text content.
Returns:
--------
Returns True if the buffer is identified as text content, otherwise False.
"""
return is_text_content(buf)
txt_file_type = TxtFileType()
filetype.add_type(txt_file_type)
class CustomPdfMatcher(filetype.Type):
"""
Match PDF file types based on MIME type and extension.
Public methods:
- match
Instance variables:
- MIME: The MIME type of the PDF.
- EXTENSION: The file extension of the PDF.
"""
MIME = "application/pdf"
EXTENSION = "pdf"
def __init__(self):
super(CustomPdfMatcher, self).__init__(
mime=CustomPdfMatcher.MIME, extension=CustomPdfMatcher.EXTENSION
)
def match(self, buf):
"""
Determine if the provided buffer is a PDF file.
This method checks for the presence of the PDF signature in the buffer.
Raises:
- TypeError: If the buffer is not of bytes type.
Parameters:
-----------
- buf: The buffer containing the data to be checked.
Returns:
--------
Returns True if the buffer contains a PDF signature, otherwise returns False.
"""
return b"PDF-" in buf
custom_pdf_matcher = CustomPdfMatcher()
filetype.add_type(custom_pdf_matcher)
def guess_file_type(file: BinaryIO) -> filetype.Type:
def guess_file_type(file: BinaryIO, name: Optional[str] = None) -> filetype.Type:
"""
Guess the file type from the given binary file stream.
@ -122,12 +42,23 @@ def guess_file_type(file: BinaryIO) -> filetype.Type:
- filetype.Type: The guessed file type, represented as filetype.Type.
"""
# Note: If file has .txt or .text extension, consider it a plain text file as filetype.guess may not detect it properly
# as it contains no magic number encoding
ext = None
if isinstance(file, str):
ext = Path(file).suffix
elif name is not None:
ext = Path(name).suffix
if ext in [".txt", ".text"]:
file_type = Type("text/plain", "txt")
return file_type
file_type = filetype.guess(file)
# If file type could not be determined consider it a plain text file as they don't have magic number encoding
if file_type is None:
from filetype.types.base import Type
file_type = Type("text/plain", "txt")
if file_type is None:

View file

@ -1,15 +1,13 @@
For the purposes of identifying timestamps in a query, you are tasked with extracting relevant timestamps from the query.
## Timestamp requirements
- If the query contains interval extrack both starts_at and ends_at properties
- If the query contains an instantaneous timestamp, starts_at and ends_at should be the same
- If the query its open-ended (before 2009 or after 2009), the corresponding non defined end of the time should be none
-For example: "before 2009" -- starts_at: None, ends_at: 2009 or "after 2009" -- starts_at: 2009, ends_at: None
- Put always the data that comes first in time as starts_at and the timestamps that comes second in time as ends_at
- If starts_at or ends_at cannot be extracted both of them has to be None
## Output Format
Your reply should be a JSON: list of dictionaries with the following structure:
```python
class QueryInterval(BaseModel):
starts_at: Optional[Timestamp] = None
ends_at: Optional[Timestamp] = None
```
You are tasked with identifying relevant time periods where the answer to a given query should be searched.
Current date is: `{{ time_now }}`. Determine relevant period(s) and return structured intervals.
Extraction rules:
1. Query without specific timestamp: use the time period with starts_at set to None and ends_at set to now.
2. Explicit time intervals: If the query specifies a range (e.g., from 2010 to 2020, between January and March 2023), extract both start and end dates. Always assign the earlier date to starts_at and the later date to ends_at.
3. Single timestamp: If the query refers to one specific moment (e.g., in 2015, on March 5, 2022), set starts_at and ends_at to that same timestamp.
4. Open-ended time references: For phrases such as "before X" or "after X", represent the unspecified side as None. For example: before 2009 → starts_at: None, ends_at: 2009; after 2009 → starts_at: 2009, ends_at: None.
5. Current-time references ("now", "current", "today"): If the query explicitly refers to the present, set both starts_at and ends_at to now (the ingestion timestamp).
6. "Who is" and "Who was" questions: These imply a general identity or biographical inquiry without a specific temporal scope. Set both starts_at and ends_at to None.
7. Ordering rule: Always ensure the earlier date is assigned to starts_at and the later date to ends_at.
8. No temporal information: If no valid or inferable time reference is found, set both starts_at and ends_at to None.

View file

@ -0,0 +1,14 @@
A question was previously answered, but the answer received negative feedback.
Please reconsider and improve the response.
Question: {question}
Context originally used: {context}
Previous answer: {wrong_answer}
Feedback on that answer: {negative_feedback}
Task: Provide a better response. The new answer should be short and direct.
Then explain briefly why this answer is better.
Format your reply as:
Answer: <improved answer>
Explanation: <short explanation>

View file

@ -0,0 +1,13 @@
Write a concise, stand-alone paragraph that explains the correct answer to the question below.
The paragraph should read naturally on its own, providing all necessary context and reasoning
so the answer is clear and well-supported.
Question: {question}
Correct answer: {improved_answer}
Supporting context: {new_context}
Your paragraph should:
- First sentence clearly states the correct answer as a full sentence
- Remainder flows from first sentence and provides explanation based on context
- Use simple, direct language that is easy to follow
- Use shorter sentences, no long-winded explanations

View file

@ -0,0 +1,5 @@
Question: {question}
Context: {context}
Provide a one paragraph human readable summary of this interaction context,
listing all the relevant facts and information in a simple and direct way.

View file

@ -1,6 +1,7 @@
import filetype
from typing import Dict, List, Optional, Any
from .LoaderInterface import LoaderInterface
from cognee.infrastructure.files.utils.guess_file_type import guess_file_type
from cognee.shared.logging_utils import get_logger
logger = get_logger(__name__)
@ -64,7 +65,9 @@ class LoaderEngine:
return True
def get_loader(
self, file_path: str, preferred_loaders: List[str] = None
self,
file_path: str,
preferred_loaders: dict[str, dict[str, Any]],
) -> Optional[LoaderInterface]:
"""
Get appropriate loader for a file.
@ -76,14 +79,21 @@ class LoaderEngine:
Returns:
LoaderInterface that can handle the file, or None if not found
"""
from pathlib import Path
file_info = filetype.guess(file_path)
file_info = guess_file_type(file_path)
path_extension = Path(file_path).suffix.lstrip(".")
# Try preferred loaders first
if preferred_loaders:
for loader_name in preferred_loaders:
if loader_name in self._loaders:
loader = self._loaders[loader_name]
# Try with path extension first (for text formats like html)
if loader.can_handle(extension=path_extension, mime_type=file_info.mime):
return loader
# Fall back to content-detected extension
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
return loader
else:
@ -93,6 +103,10 @@ class LoaderEngine:
for loader_name in self.default_loader_priority:
if loader_name in self._loaders:
loader = self._loaders[loader_name]
# Try with path extension first (for text formats like html)
if loader.can_handle(extension=path_extension, mime_type=file_info.mime):
return loader
# Fall back to content-detected extension
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
return loader
else:
@ -105,7 +119,7 @@ class LoaderEngine:
async def load_file(
self,
file_path: str,
preferred_loaders: Optional[List[str]] = None,
preferred_loaders: dict[str, dict[str, Any]] = None,
**kwargs,
):
"""
@ -113,7 +127,7 @@ class LoaderEngine:
Args:
file_path: Path to the file to be processed
preferred_loaders: List of preferred loader names to try first
preferred_loaders: Dict of loader names to their configurations
**kwargs: Additional loader-specific configuration
Raises:
@ -125,8 +139,16 @@ class LoaderEngine:
raise ValueError(f"No loader found for file: {file_path}")
logger.debug(f"Loading {file_path} with {loader.loader_name}")
# TODO: loading needs to be reworked to work with both file streams and file locations
return await loader.load(file_path, **kwargs)
# Extract loader-specific config from preferred_loaders
loader_config = {}
if preferred_loaders and loader.loader_name in preferred_loaders:
loader_config = preferred_loaders[loader.loader_name]
# Merge with any additional kwargs (kwargs take precedence)
merged_kwargs = {**loader_config, **kwargs}
return await loader.load(file_path, **merged_kwargs)
def get_available_loaders(self) -> List[str]:
"""

View file

@ -42,6 +42,7 @@ class AudioLoader(LoaderInterface):
"audio/wav",
"audio/amr",
"audio/aiff",
"audio/x-wav",
]
@property

View file

@ -27,3 +27,10 @@ try:
__all__.append("AdvancedPdfLoader")
except ImportError:
pass
try:
from .beautiful_soup_loader import BeautifulSoupLoader
__all__.append("BeautifulSoupLoader")
except ImportError:
pass

View file

@ -0,0 +1,310 @@
"""BeautifulSoup-based web crawler for extracting content from web pages.
This module provides the BeautifulSoupCrawler class for fetching and extracting content
from web pages using BeautifulSoup or Playwright for JavaScript-rendered pages. It
supports robots.txt handling, rate limiting, and custom extraction rules.
"""
from typing import Union, Dict, Any, Optional, List
from dataclasses import dataclass
from bs4 import BeautifulSoup
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
from cognee.shared.logging_utils import get_logger
logger = get_logger(__name__)
@dataclass
class ExtractionRule:
    """Normalized extraction rule for web content.

    Exactly one of ``selector`` or ``xpath`` is normally set; ``xpath`` takes
    precedence when both are present (see ``_extract_from_html``).

    Attributes:
        selector: CSS selector for extraction (if any).
        xpath: XPath expression for extraction (if any).
        attr: HTML attribute to extract instead of element text (if any).
        all: If True, extract all matching elements; otherwise, extract first.
        join_with: String used to join multiple extracted elements.
    """

    selector: Optional[str] = None
    xpath: Optional[str] = None
    attr: Optional[str] = None
    all: bool = False
    join_with: str = " "
class BeautifulSoupLoader(LoaderInterface):
    """Loader that extracts text content from HTML files using BeautifulSoup.

    Applies a set of extraction rules (CSS selectors or XPath expressions) to an
    HTML file, joins the extracted text pieces, and stores the result as a plain
    text file in the configured file storage. Falls back to treating the input as
    plain text when the file contains no HTML tags at all.

    NOTE(review): the class name and earlier docstring suggested crawler features
    (concurrency, robots.txt handling, rate limiting, Playwright rendering), but
    this class only loads local HTML files — no HTTP requests are made here.
    """

    @property
    def supported_extensions(self) -> List[str]:
        # File extensions this loader accepts (matched by LoaderEngine.can_handle).
        return ["html"]

    @property
    def supported_mime_types(self) -> List[str]:
        # text/plain is included because pre-extracted web content (e.g. Tavily
        # text output) may be saved with an .html name but detected as plain text.
        return ["text/html", "text/plain"]

    @property
    def loader_name(self) -> str:
        # Registry key used in supported_loaders and preferred_loaders configs.
        return "beautiful_soup_loader"

    def can_handle(self, extension: str, mime_type: str) -> bool:
        """Return True only when BOTH the extension and MIME type are supported."""
        can = extension in self.supported_extensions and mime_type in self.supported_mime_types
        return can

    def _get_default_extraction_rules(self):
        """Return comprehensive default extraction rules for common HTML content.

        Keys are descriptive labels (not used for matching); values are rule dicts
        consumed by ``_normalize_rule``.
        """
        # Comprehensive default extraction rules for common HTML content
        return {
            # Meta information
            "title": {"selector": "title", "all": False},
            "meta_description": {
                "selector": "meta[name='description']",
                "attr": "content",
                "all": False,
            },
            "meta_keywords": {
                "selector": "meta[name='keywords']",
                "attr": "content",
                "all": False,
            },
            # Open Graph meta tags
            "og_title": {
                "selector": "meta[property='og:title']",
                "attr": "content",
                "all": False,
            },
            "og_description": {
                "selector": "meta[property='og:description']",
                "attr": "content",
                "all": False,
            },
            # Main content areas (prioritized selectors)
            "article": {"selector": "article", "all": True, "join_with": "\n\n"},
            "main": {"selector": "main", "all": True, "join_with": "\n\n"},
            # Semantic content sections
            "headers_h1": {"selector": "h1", "all": True, "join_with": "\n"},
            "headers_h2": {"selector": "h2", "all": True, "join_with": "\n"},
            "headers_h3": {"selector": "h3", "all": True, "join_with": "\n"},
            "headers_h4": {"selector": "h4", "all": True, "join_with": "\n"},
            "headers_h5": {"selector": "h5", "all": True, "join_with": "\n"},
            "headers_h6": {"selector": "h6", "all": True, "join_with": "\n"},
            # Text content
            "paragraphs": {"selector": "p", "all": True, "join_with": "\n\n"},
            "blockquotes": {"selector": "blockquote", "all": True, "join_with": "\n\n"},
            "preformatted": {"selector": "pre", "all": True, "join_with": "\n\n"},
            # Lists
            "ordered_lists": {"selector": "ol", "all": True, "join_with": "\n"},
            "unordered_lists": {"selector": "ul", "all": True, "join_with": "\n"},
            "list_items": {"selector": "li", "all": True, "join_with": "\n"},
            "definition_lists": {"selector": "dl", "all": True, "join_with": "\n"},
            # Tables
            "tables": {"selector": "table", "all": True, "join_with": "\n\n"},
            "table_captions": {
                "selector": "caption",
                "all": True,
                "join_with": "\n",
            },
            # Code blocks
            "code_blocks": {"selector": "code", "all": True, "join_with": "\n"},
            # Figures and media descriptions
            "figures": {"selector": "figure", "all": True, "join_with": "\n\n"},
            "figcaptions": {"selector": "figcaption", "all": True, "join_with": "\n"},
            "image_alts": {"selector": "img", "attr": "alt", "all": True, "join_with": " "},
            # Links (text content, not URLs to avoid clutter)
            "link_text": {"selector": "a", "all": True, "join_with": " "},
            # Emphasized text
            "strong": {"selector": "strong", "all": True, "join_with": " "},
            "emphasis": {"selector": "em", "all": True, "join_with": " "},
            "marked": {"selector": "mark", "all": True, "join_with": " "},
            # Time and data elements
            "time": {"selector": "time", "all": True, "join_with": " "},
            "data": {"selector": "data", "all": True, "join_with": " "},
            # Sections and semantic structure
            "sections": {"selector": "section", "all": True, "join_with": "\n\n"},
            "asides": {"selector": "aside", "all": True, "join_with": "\n\n"},
            "details": {"selector": "details", "all": True, "join_with": "\n"},
            "summary": {"selector": "summary", "all": True, "join_with": "\n"},
            # Navigation (may contain important links/structure)
            "nav": {"selector": "nav", "all": True, "join_with": "\n"},
            # Footer information
            "footer": {"selector": "footer", "all": True, "join_with": "\n"},
            # Divs with specific content roles
            "content_divs": {
                "selector": "div[role='main'], div[role='article'], div.content, div#content",
                "all": True,
                "join_with": "\n\n",
            },
        }

    async def load(
        self,
        file_path: str,
        extraction_rules: dict[str, Any] = None,
        join_all_matches: bool = False,
        **kwargs,
    ):
        """Load an HTML file, extract content, and save to storage.

        Args:
            file_path: Path to the HTML file
            extraction_rules: Dict of CSS selector rules for content extraction;
                defaults to ``_get_default_extraction_rules()`` when None.
            join_all_matches: If True, extract all matching elements for each rule
            **kwargs: Additional arguments (currently unused)

        Returns:
            Path to the stored extracted text file
        """
        if extraction_rules is None:
            extraction_rules = self._get_default_extraction_rules()
            logger.info("Using default comprehensive extraction rules for HTML content")

        logger.info(f"Processing HTML file: {file_path}")

        # Local imports keep optional infrastructure out of module import time.
        from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata
        from cognee.infrastructure.files.storage import get_file_storage, get_storage_config

        with open(file_path, "rb") as f:
            file_metadata = await get_file_metadata(f)
            # get_file_metadata consumes the stream; rewind before reading the raw bytes.
            f.seek(0)
            html = f.read()

        # Content-addressed output name so identical inputs dedupe in storage.
        storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"

        # Normalize extraction rules (rule dict keys are labels only and are dropped)
        normalized_rules: List[ExtractionRule] = []
        for _, rule in extraction_rules.items():
            r = self._normalize_rule(rule)
            if join_all_matches:
                r.all = True
            normalized_rules.append(r)

        pieces = []
        for rule in normalized_rules:
            text = self._extract_from_html(html, rule)
            if text:
                pieces.append(text)
        full_content = " ".join(pieces).strip()

        # remove after defaults for extraction rules
        # Fallback: If no content extracted, check if the file is plain text (not HTML)
        if not full_content:
            from bs4 import BeautifulSoup

            soup = BeautifulSoup(html, "html.parser")
            # If there are no HTML tags, treat as plain text
            if not soup.find():
                logger.warning(
                    f"No HTML tags found in {file_path}. Treating as plain text. "
                    "This may happen when content is pre-extracted (e.g., via Tavily with text format)."
                )
                # html was read in binary mode, so decode before storing.
                full_content = html.decode("utf-8") if isinstance(html, bytes) else html
                full_content = full_content.strip()

        if not full_content:
            logger.warning(f"No content extracted from HTML file: {file_path}")

        # Store the extracted content
        storage_config = get_storage_config()
        data_root_directory = storage_config["data_root_directory"]
        storage = get_file_storage(data_root_directory)
        full_file_path = await storage.store(storage_file_name, full_content)

        logger.info(f"Extracted {len(full_content)} characters from HTML")
        return full_file_path

    def _normalize_rule(self, rule: Union[str, Dict[str, Any]]) -> ExtractionRule:
        """Normalize an extraction rule to an ExtractionRule dataclass.

        Args:
            rule: A string (CSS selector) or dict with extraction parameters.

        Returns:
            ExtractionRule: Normalized extraction rule.

        Raises:
            ValueError: If the rule is neither a string nor a dict.
        """
        if isinstance(rule, str):
            # Bare string is shorthand for a first-match CSS selector rule.
            return ExtractionRule(selector=rule)
        if isinstance(rule, dict):
            return ExtractionRule(
                selector=rule.get("selector"),
                xpath=rule.get("xpath"),
                attr=rule.get("attr"),
                all=bool(rule.get("all", False)),
                join_with=rule.get("join_with", " "),
            )
        raise ValueError(f"Invalid extraction rule: {rule}")

    def _extract_from_html(self, html: str, rule: ExtractionRule) -> str:
        """Extract content from HTML using BeautifulSoup or lxml XPath.

        XPath rules take precedence over CSS selectors when both are set.

        Args:
            html: The HTML content to extract from (bytes or str; BeautifulSoup
                and lxml accept both).
            rule: The extraction rule to apply.

        Returns:
            str: The extracted content ("" when nothing matches).

        Raises:
            RuntimeError: If XPath is used but lxml is not installed.
        """
        soup = BeautifulSoup(html, "html.parser")
        if rule.xpath:
            try:
                from lxml import html as lxml_html
            except ImportError:
                raise RuntimeError(
                    "XPath requested but lxml is not available. Install lxml or use CSS selectors."
                )
            doc = lxml_html.fromstring(html)
            nodes = doc.xpath(rule.xpath)
            texts = []
            for n in nodes:
                # XPath can yield elements (with text_content) or raw strings/attrs.
                if hasattr(n, "text_content"):
                    texts.append(n.text_content().strip())
                else:
                    texts.append(str(n).strip())
            return rule.join_with.join(t for t in texts if t)
        if not rule.selector:
            # Neither xpath nor selector set: nothing to extract.
            return ""
        if rule.all:
            nodes = soup.select(rule.selector)
            pieces = []
            for el in nodes:
                if rule.attr:
                    val = el.get(rule.attr)
                    if val:
                        pieces.append(val.strip())
                else:
                    text = el.get_text(strip=True)
                    if text:
                        pieces.append(text)
            return rule.join_with.join(pieces).strip()
        else:
            el = soup.select_one(rule.selector)
            if el is None:
                return ""
            if rule.attr:
                val = el.get(rule.attr)
                return (val or "").strip()
            return el.get_text(strip=True)

View file

@ -23,3 +23,10 @@ try:
supported_loaders[AdvancedPdfLoader.loader_name] = AdvancedPdfLoader
except ImportError:
pass
try:
from cognee.infrastructure.loaders.external import BeautifulSoupLoader
supported_loaders[BeautifulSoupLoader.loader_name] = BeautifulSoupLoader
except ImportError:
pass

View file

@ -1,6 +1,7 @@
from typing import List, Union
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.Edge import Edge
from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity
from cognee.tasks.temporal_graph.models import Event
@ -31,6 +32,6 @@ class DocumentChunk(DataPoint):
chunk_index: int
cut_type: str
is_part_of: Document
contains: List[Union[Entity, Event]] = None
contains: List[Union[Entity, Event, tuple[Edge, Entity]]] = None
metadata: dict = {"index_fields": ["text"]}

View file

@ -10,7 +10,7 @@ class UnstructuredLibraryImportError(CogneeConfigurationError):
self,
message: str = "Import error. Unstructured library is not installed.",
name: str = "UnstructuredModuleImportError",
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
):
super().__init__(message, name, status_code)

View file

@ -23,3 +23,6 @@ from .create_authorized_dataset import create_authorized_dataset
# Check
from .check_dataset_name import check_dataset_name
# Boolean check
from .has_dataset_data import has_dataset_data

View file

@ -0,0 +1,21 @@
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.sql import func
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import DatasetData
async def has_dataset_data(dataset_id: UUID) -> bool:
    """Return True when at least one data row is linked to the given dataset.

    Args:
        dataset_id: UUID of the dataset to check.

    Returns:
        bool: True if the DatasetData table holds any rows for dataset_id.
    """
    engine = get_relational_engine()
    # COUNT(*) over DatasetData rows belonging to this dataset.
    row_count_stmt = select(func.count()).select_from(DatasetData).where(
        DatasetData.dataset_id == dataset_id
    )
    async with engine.get_async_session() as session:
        result = await session.execute(row_count_stmt)
        return result.scalar_one() > 0

View file

@ -171,8 +171,10 @@ class CogneeGraph(CogneeAbstractGraph):
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
for edge in self.edges:
relationship_type = edge.attributes.get("relationship_type")
distance = embedding_map.get(relationship_type, None)
edge_key = edge.attributes.get("edge_text") or edge.attributes.get(
"relationship_type"
)
distance = embedding_map.get(edge_key, None)
if distance is not None:
edge.attributes["vector_distance"] = distance

View file

@ -1,5 +1,6 @@
from typing import Optional
from cognee.infrastructure.engine.models.Edge import Edge
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.engine.utils import (
@ -243,10 +244,26 @@ def _process_graph_nodes(
ontology_relationships,
)
# Add entity to data chunk
if data_chunk.contains is None:
data_chunk.contains = []
data_chunk.contains.append(entity_node)
edge_text = "; ".join(
[
"relationship_name: contains",
f"entity_name: {entity_node.name}",
f"entity_description: {entity_node.description}",
]
)
data_chunk.contains.append(
(
Edge(
relationship_type="contains",
edge_text=edge_text,
),
entity_node,
)
)
def _process_graph_edges(

View file

@ -1,71 +1,70 @@
import string
from typing import List
from collections import Counter
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
def _get_top_n_frequent_words(
text: str, stop_words: set = None, top_n: int = 3, separator: str = ", "
) -> str:
"""Concatenates the top N frequent words in text."""
if stop_words is None:
stop_words = DEFAULT_STOP_WORDS
words = [word.lower().strip(string.punctuation) for word in text.split()]
words = [word for word in words if word and word not in stop_words]
top_words = [word for word, freq in Counter(words).most_common(top_n)]
return separator.join(top_words)
def _create_title_from_text(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
"""Creates a title by combining first words with most frequent words from the text."""
first_words = text.split()[:first_n_words]
top_words = _get_top_n_frequent_words(text, top_n=top_n_words)
return f"{' '.join(first_words)}... [{top_words}]"
def _extract_nodes_from_edges(retrieved_edges: List[Edge]) -> dict:
"""Creates a dictionary of nodes with their names and content."""
nodes = {}
for edge in retrieved_edges:
for node in (edge.node1, edge.node2):
if node.id in nodes:
continue
text = node.attributes.get("text")
if text:
name = _create_title_from_text(text)
content = text
else:
name = node.attributes.get("name", "Unnamed Node")
content = node.attributes.get("description", name)
nodes[node.id] = {"node": node, "name": name, "content": content}
return nodes
async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str:
"""
Converts retrieved graph edges into a human-readable string format.
"""Converts retrieved graph edges into a human-readable string format."""
nodes = _extract_nodes_from_edges(retrieved_edges)
Parameters:
-----------
- retrieved_edges (list): A list of edges retrieved from the graph.
Returns:
--------
- str: A formatted string representation of the nodes and their connections.
"""
def _get_nodes(retrieved_edges: List[Edge]) -> dict:
def _get_title(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
def _top_n_words(text, stop_words=None, top_n=3, separator=", "):
"""Concatenates the top N frequent words in text."""
if stop_words is None:
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
stop_words = DEFAULT_STOP_WORDS
import string
words = [word.lower().strip(string.punctuation) for word in text.split()]
if stop_words:
words = [word for word in words if word and word not in stop_words]
from collections import Counter
top_words = [word for word, freq in Counter(words).most_common(top_n)]
return separator.join(top_words)
"""Creates a title, by combining first words with most frequent words from the text."""
first_words = text.split()[:first_n_words]
top_words = _top_n_words(text, top_n=first_n_words)
return f"{' '.join(first_words)}... [{top_words}]"
"""Creates a dictionary of nodes with their names and content."""
nodes = {}
for edge in retrieved_edges:
for node in (edge.node1, edge.node2):
if node.id not in nodes:
text = node.attributes.get("text")
if text:
name = _get_title(text)
content = text
else:
name = node.attributes.get("name", "Unnamed Node")
content = node.attributes.get("description", name)
nodes[node.id] = {"node": node, "name": name, "content": content}
return nodes
nodes = _get_nodes(retrieved_edges)
node_section = "\n".join(
f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
for info in nodes.values()
)
connection_section = "\n".join(
f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
for edge in retrieved_edges
)
connections = []
for edge in retrieved_edges:
source_name = nodes[edge.node1.id]["name"]
target_name = nodes[edge.node2.id]["name"]
edge_label = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
connections.append(f"{source_name} --[{edge_label}]--> {target_name}")
connection_section = "\n".join(connections)
return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"

View file

@ -30,7 +30,7 @@ class BinaryData(IngestionData):
async def ensure_metadata(self):
if self.metadata is None:
self.metadata = await get_file_metadata(self.data)
self.metadata = await get_file_metadata(self.data, name=self.name)
if self.metadata["name"] is None:
self.metadata["name"] = self.name

View file

@ -1,10 +1,12 @@
from typing import BinaryIO, Union
from typing import BinaryIO, Union, Optional
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
from .classify import classify
import hashlib
async def save_data_to_file(data: Union[str, BinaryIO], filename: str = None):
async def save_data_to_file(
data: Union[str, BinaryIO], filename: str = None, file_extension: Optional[str] = None
):
storage_config = get_storage_config()
data_root_directory = storage_config["data_root_directory"]
@ -21,6 +23,11 @@ async def save_data_to_file(data: Union[str, BinaryIO], filename: str = None):
file_name = file_metadata["name"]
if file_extension is not None:
extension = file_extension.lstrip(".")
file_name_without_ext = file_name.rsplit(".", 1)[0]
file_name = f"{file_name_without_ext}.{extension}"
storage = get_file_storage(data_root_directory)
full_file_path = await storage.store(file_name, data)

View file

@ -21,7 +21,8 @@ def get_ontology_resolver_from_env(
Supported value: "rdflib".
matching_strategy (str): The matching strategy to apply.
Supported value: "fuzzy".
ontology_file_path (str): Path to the ontology file required for the resolver.
ontology_file_path (str): Path to the ontology file(s) required for the resolver.
Can be a single path or comma-separated paths for multiple files.
Returns:
BaseOntologyResolver: An instance of the requested ontology resolver.
@ -31,8 +32,13 @@ def get_ontology_resolver_from_env(
or if required parameters are missing.
"""
if ontology_resolver == "rdflib" and matching_strategy == "fuzzy" and ontology_file_path:
if "," in ontology_file_path:
file_paths = [path.strip() for path in ontology_file_path.split(",")]
else:
file_paths = ontology_file_path
return RDFLibOntologyResolver(
matching_strategy=FuzzyMatchingStrategy(), ontology_file=ontology_file_path
matching_strategy=FuzzyMatchingStrategy(), ontology_file=file_paths
)
else:
raise EnvironmentError(

View file

@ -2,7 +2,7 @@ import os
import difflib
from cognee.shared.logging_utils import get_logger
from collections import deque
from typing import List, Tuple, Dict, Optional, Any
from typing import List, Tuple, Dict, Optional, Any, Union
from rdflib import Graph, URIRef, RDF, RDFS, OWL
from cognee.modules.ontology.exceptions import (
@ -26,22 +26,50 @@ class RDFLibOntologyResolver(BaseOntologyResolver):
def __init__(
self,
ontology_file: Optional[str] = None,
ontology_file: Optional[Union[str, List[str]]] = None,
matching_strategy: Optional[MatchingStrategy] = None,
) -> None:
super().__init__(matching_strategy)
self.ontology_file = ontology_file
try:
if ontology_file and os.path.exists(ontology_file):
files_to_load = []
if ontology_file is not None:
if isinstance(ontology_file, str):
files_to_load = [ontology_file]
elif isinstance(ontology_file, list):
files_to_load = ontology_file
else:
raise ValueError(
f"ontology_file must be a string, list of strings, or None. Got: {type(ontology_file)}"
)
if files_to_load:
self.graph = Graph()
self.graph.parse(ontology_file)
logger.info("Ontology loaded successfully from file: %s", ontology_file)
loaded_files = []
for file_path in files_to_load:
if os.path.exists(file_path):
self.graph.parse(file_path)
loaded_files.append(file_path)
logger.info("Ontology loaded successfully from file: %s", file_path)
else:
logger.warning(
"Ontology file '%s' not found. Skipping this file.",
file_path,
)
if not loaded_files:
logger.info(
"No valid ontology files found. No owl ontology will be attached to the graph."
)
self.graph = None
else:
logger.info("Total ontology files loaded: %d", len(loaded_files))
else:
logger.info(
"Ontology file '%s' not found. No owl ontology will be attached to the graph.",
ontology_file,
"No ontology file provided. No owl ontology will be attached to the graph."
)
self.graph = None
self.build_lookup()
except Exception as e:
logger.error("Failed to load ontology", exc_info=e)

View file

@ -7,6 +7,6 @@ class PipelineRunFailedError(CogneeSystemError):
self,
message: str = "Pipeline run failed.",
name: str = "PipelineRunFailedError",
status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
status_code: int = status.HTTP_422_UNPROCESSABLE_CONTENT,
):
super().__init__(message, name, status_code)

View file

@ -20,6 +20,7 @@ from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
from cognee.modules.pipelines.layers.check_pipeline_run_qualification import (
check_pipeline_run_qualification,
)
from typing import Any
logger = get_logger("cognee.pipeline")
@ -80,7 +81,14 @@ async def run_pipeline_per_dataset(
return
pipeline_run = run_tasks(
tasks, dataset.id, data, user, pipeline_name, context, incremental_loading, data_per_batch
tasks,
dataset.id,
data,
user,
pipeline_name,
context,
incremental_loading,
data_per_batch,
)
async for pipeline_run_info in pipeline_run:

View file

@ -2,6 +2,7 @@ import inspect
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
from ..tasks.task import Task
@ -25,6 +26,8 @@ async def handle_task(
user_id=user.id,
additional_properties={
"task_name": running_task.executable.__name__,
"cognee_version": cognee_version,
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
},
)
@ -46,6 +49,8 @@ async def handle_task(
user_id=user.id,
additional_properties={
"task_name": running_task.executable.__name__,
"cognee_version": cognee_version,
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
},
)
except Exception as error:
@ -58,6 +63,8 @@ async def handle_task(
user_id=user.id,
additional_properties={
"task_name": running_task.executable.__name__,
"cognee_version": cognee_version,
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
},
)
raise error

View file

@ -4,6 +4,7 @@ from cognee.modules.settings import get_current_settings
from cognee.modules.users.models import User
from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
from .run_tasks_base import run_tasks_base
from ..tasks.task import Task
@ -26,6 +27,8 @@ async def run_tasks_with_telemetry(
user.id,
additional_properties={
"pipeline_name": str(pipeline_name),
"cognee_version": cognee_version,
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
}
| config,
)
@ -39,7 +42,10 @@ async def run_tasks_with_telemetry(
user.id,
additional_properties={
"pipeline_name": str(pipeline_name),
},
"cognee_version": cognee_version,
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
}
| config,
)
except Exception as error:
logger.error(
@ -53,6 +59,8 @@ async def run_tasks_with_telemetry(
user.id,
additional_properties={
"pipeline_name": str(pipeline_name),
"cognee_version": cognee_version,
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
}
| config,
)

View file

@ -1,10 +1,17 @@
import asyncio
from typing import Any, Optional, List
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
from cognee.infrastructure.context.BaseContextProvider import BaseContextProvider
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history,
get_conversation_history,
)
from cognee.context_global_variables import session_user
from cognee.infrastructure.databases.cache.config import CacheConfig
logger = get_logger("entity_completion_retriever")
@ -77,7 +84,9 @@ class EntityCompletionRetriever(BaseRetriever):
logger.error(f"Context retrieval failed: {str(e)}")
return None
async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]:
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> List[str]:
"""
Generate completion using provided context or fetch new context.
@ -91,6 +100,8 @@ class EntityCompletionRetriever(BaseRetriever):
- query (str): The query string for which completion is being generated.
- context (Optional[Any]): Optional context to be used for generating completion;
fetched if not provided. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------
@ -105,12 +116,41 @@ class EntityCompletionRetriever(BaseRetriever):
if context is None:
return ["No relevant entities found for the query."]
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
# Check if we need to generate context summary for caching
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)
context_summary, completion = await asyncio.gather(
summarize_text(str(context)),
generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
conversation_history=conversation_history,
),
)
else:
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
if session_save:
await save_conversation_history(
query=query,
context_summary=context_summary,
answer=completion,
session_id=session_id,
)
return [completion]
except Exception as e:

View file

@ -13,6 +13,8 @@ class BaseGraphRetriever(ABC):
pass
@abstractmethod
async def get_completion(self, query: str, context: Optional[List[Edge]] = None) -> str:
async def get_completion(
self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None
) -> str:
"""Generates a response using the query and optional context (triplets)."""
pass

View file

@ -11,6 +11,8 @@ class BaseRetriever(ABC):
pass
@abstractmethod
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
"""Generates a response using the query and optional context."""
pass

View file

@ -61,7 +61,9 @@ class ChunksRetriever(BaseRetriever):
logger.info(f"Returning {len(chunk_payloads)} chunk payloads")
return chunk_payloads
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
"""
Generates a completion using document chunks context.
@ -74,6 +76,8 @@ class ChunksRetriever(BaseRetriever):
- query (str): The query string to be used for generating a completion.
- context (Optional[Any]): Optional pre-fetched context to use for generating the
completion; if None, it retrieves the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------

View file

@ -207,8 +207,26 @@ class CodeRetriever(BaseRetriever):
logger.info(f"Returning {len(result)} code file contexts")
return result
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""Returns the code files context."""
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
"""
Returns the code files context.
Parameters:
-----------
- query (str): The query string to retrieve code context for.
- context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------
- Any: The code files context, either provided or retrieved.
"""
if context is None:
context = await self.get_context(query)
return context

View file

@ -1,11 +1,18 @@
import asyncio
from typing import Any, Optional
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history,
get_conversation_history,
)
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.context_global_variables import session_user
from cognee.infrastructure.databases.cache.config import CacheConfig
logger = get_logger("CompletionRetriever")
@ -67,7 +74,9 @@ class CompletionRetriever(BaseRetriever):
logger.error("DocumentChunk_text collection not found")
raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(self, query: str, context: Optional[Any] = None) -> str:
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> str:
"""
Generates an LLM completion using the context.
@ -80,6 +89,8 @@ class CompletionRetriever(BaseRetriever):
- query (str): The query string to be used for generating a completion.
- context (Optional[Any]): Optional pre-fetched context to use for generating the
completion; if None, it retrieves the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------
@ -89,11 +100,41 @@ class CompletionRetriever(BaseRetriever):
if context is None:
context = await self.get_context(query)
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
# Check if we need to generate context summary for caching
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)
context_summary, completion = await asyncio.gather(
summarize_text(context),
generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
),
)
else:
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
if session_save:
await save_conversation_history(
query=query,
context_summary=context_summary,
answer=completion,
session_id=session_id,
)
return completion

View file

@ -44,13 +44,21 @@ class CypherSearchRetriever(BaseRetriever):
"""
try:
graph_engine = await get_graph_engine()
is_empty = await graph_engine.is_empty()
if is_empty:
logger.warning("Search attempt on an empty knowledge graph")
return []
result = await graph_engine.query(query)
except Exception as e:
logger.error("Failed to execture cypher search retrieval: %s", str(e))
raise CypherSearchError() from e
return result
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
"""
Returns the graph connections context.
@ -62,6 +70,8 @@ class CypherSearchRetriever(BaseRetriever):
- query (str): The query to retrieve context.
- context (Optional[Any]): Optional context to use, otherwise fetched using the
query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------

View file

@ -1,8 +1,15 @@
import asyncio
from typing import Optional, List, Type
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history,
get_conversation_history,
)
from cognee.context_global_variables import session_user
from cognee.infrastructure.databases.cache.config import CacheConfig
logger = get_logger()
@ -47,6 +54,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
context_extension_rounds=4,
) -> List[str]:
"""
@ -64,6 +72,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
- query (str): The input query for which the completion is generated.
- context (Optional[Any]): The existing context to use for enhancing the query; if
None, it will be initialized from triplets generated for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- context_extension_rounds: The maximum number of rounds to extend the context with
new triplets before halting. (default 4)
@ -115,17 +125,46 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
round_idx += 1
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
# Check if we need to generate context summary for caching
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)
context_summary, completion = await asyncio.gather(
summarize_text(context_text),
generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
),
)
else:
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
if self.save_interaction and context_text and triplets and completion:
await self.save_qa(
question=query, answer=completion, context=context_text, triplets=triplets
)
if session_save:
await save_conversation_history(
query=query,
context_summary=context_summary,
answer=completion,
session_id=session_id,
)
return [completion]

View file

@ -1,15 +1,41 @@
import asyncio
import json
from typing import Optional, List, Type, Any
from pydantic import BaseModel
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.completion import (
generate_structured_completion,
summarize_text,
)
from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history,
get_conversation_history,
)
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
from cognee.context_global_variables import session_user
from cognee.infrastructure.databases.cache.config import CacheConfig
logger = get_logger()
def _as_answer_text(completion: Any) -> str:
"""Convert completion to human-readable text for validation and follow-up prompts."""
if isinstance(completion, str):
return completion
if isinstance(completion, BaseModel):
# Add notice that this is a structured response
json_str = completion.model_dump_json(indent=2)
return f"[Structured Response]\n{json_str}"
try:
return json.dumps(completion, indent=2)
except TypeError:
return str(completion)
class GraphCompletionCotRetriever(GraphCompletionRetriever):
"""
Handles graph completion by generating responses based on a series of interactions with
@ -18,6 +44,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
questions based on reasoning. The public methods are:
- get_completion
- get_structured_completion
Instance variables include:
- validation_system_prompt_path
@ -54,33 +81,30 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
self.followup_system_prompt_path = followup_system_prompt_path
self.followup_user_prompt_path = followup_user_prompt_path
async def get_completion(
async def _run_cot_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
max_iter=4,
) -> List[str]:
conversation_history: str = "",
max_iter: int = 4,
response_model: Type = str,
) -> tuple[Any, str, List[Edge]]:
"""
Generate completion responses based on a user query and contextual information.
This method interacts with a language model client to retrieve a structured response,
using a series of iterations to refine the answers and generate follow-up questions
based on reasoning derived from previous outputs. It raises exceptions if the context
retrieval fails or if the model encounters issues in generating outputs.
Run chain-of-thought completion with optional structured output.
Parameters:
-----------
- query (str): The user's query to be processed and answered.
- context (Optional[Any]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None)
- max_iter: The maximum number of iterations to refine the answer and generate
follow-up questions. (default 4)
- query: User query
- context: Optional pre-fetched context edges
- conversation_history: Optional conversation history string
- max_iter: Maximum CoT iterations
- response_model: Type for structured output (str for plain text)
Returns:
--------
- List[str]: A list containing the generated answer to the user's query.
- completion_result: The generated completion (string or structured model)
- context_text: The resolved context text
- triplets: The list of triplets used
"""
followup_question = ""
triplets = []
@ -97,16 +121,21 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
triplets += await self.get_context(followup_question)
context_text = await self.resolve_edges_to_text(list(set(triplets)))
completion = await generate_completion(
completion = await generate_structured_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history if conversation_history else None,
response_model=response_model,
)
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
if round_idx < max_iter:
valid_args = {"query": query, "answer": completion, "context": context_text}
answer_text = _as_answer_text(completion)
valid_args = {"query": query, "answer": answer_text, "context": context_text}
valid_user_prompt = render_prompt(
filename=self.validation_user_prompt_path, context=valid_args
)
@ -119,7 +148,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
system_prompt=valid_system_prompt,
response_model=str,
)
followup_args = {"query": query, "answer": completion, "reasoning": reasoning}
followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning}
followup_prompt = render_prompt(
filename=self.followup_user_prompt_path, context=followup_args
)
@ -134,9 +163,110 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
)
return completion, context_text, triplets
async def get_structured_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
max_iter: int = 4,
response_model: Type = str,
) -> Any:
"""
Generate structured completion responses based on a user query and contextual information.
This method applies the same chain-of-thought logic as get_completion but returns
structured output using the provided response model.
Parameters:
-----------
- query (str): The user's query to be processed and answered.
- context (Optional[List[Edge]]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- max_iter: The maximum number of iterations to refine the answer and generate
follow-up questions. (default 4)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
- Any: The generated structured completion based on the response model.
"""
# Check if session saving is enabled
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
# Load conversation history if enabled
conversation_history = ""
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)
completion, context_text, triplets = await self._run_cot_completion(
query=query,
context=context,
conversation_history=conversation_history,
max_iter=max_iter,
response_model=response_model,
)
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=completion, context=context_text, triplets=triplets
question=query, answer=str(completion), context=context_text, triplets=triplets
)
# Save to session cache if enabled
if session_save:
context_summary = await summarize_text(context_text)
await save_conversation_history(
query=query,
context_summary=context_summary,
answer=str(completion),
session_id=session_id,
)
return completion
async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
max_iter=4,
) -> List[str]:
"""
Generate completion responses based on a user query and contextual information.
This method interacts with a language model client to retrieve a structured response,
using a series of iterations to refine the answers and generate follow-up questions
based on reasoning derived from previous outputs. It raises exceptions if the context
retrieval fails or if the model encounters issues in generating outputs.
Parameters:
-----------
- query (str): The user's query to be processed and answered.
- context (Optional[Any]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- max_iter: The maximum number of iterations to refine the answer and generate
follow-up questions. (default 4)
Returns:
--------
- List[str]: A list containing the generated answer to the user's query.
"""
completion = await self.get_structured_completion(
query=query,
context=context,
session_id=session_id,
max_iter=max_iter,
response_model=str,
)
return [completion]

View file

@ -1,20 +1,26 @@
import asyncio
from typing import Any, Optional, Type, List
from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.users.methods import get_default_user
from cognee.tasks.storage import add_data_points
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history,
get_conversation_history,
)
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
from cognee.modules.engine.models.node_set import NodeSet
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.context_global_variables import session_user
from cognee.infrastructure.databases.cache.config import CacheConfig
logger = get_logger("GraphCompletionRetriever")
@ -118,6 +124,13 @@ class GraphCompletionRetriever(BaseGraphRetriever):
- str: A string representing the resolved context from the retrieved triplets, or an
empty string if no triplets are found.
"""
graph_engine = await get_graph_engine()
is_empty = await graph_engine.is_empty()
if is_empty:
logger.warning("Search attempt on an empty knowledge graph")
return []
triplets = await self.get_triplets(query)
if len(triplets) == 0:
@ -132,6 +145,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
) -> List[str]:
"""
Generates a completion using graph connections context based on a query.
@ -142,6 +156,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
- query (str): The query string for which a completion is generated.
- context (Optional[Any]): Optional context to use for generating the completion; if
not provided, context is retrieved based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------
@ -155,19 +171,47 @@ class GraphCompletionRetriever(BaseGraphRetriever):
context_text = await resolve_edges_to_text(triplets)
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)
context_summary, completion = await asyncio.gather(
summarize_text(context_text),
generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
),
)
else:
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=completion, context=context_text, triplets=triplets
)
if session_save:
await save_conversation_history(
query=query,
context_summary=context_summary,
answer=completion,
session_id=session_id,
)
return [completion]
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:

View file

@ -116,8 +116,26 @@ class LexicalRetriever(BaseRetriever):
else:
return [self.payloads[chunk_id] for chunk_id, _ in top_results]
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""Returns context for the given query (retrieves if not provided)."""
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
"""
Returns context for the given query (retrieves if not provided).
Parameters:
-----------
- query (str): The query string to retrieve context for.
- context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------
- Any: The context, either provided or retrieved.
"""
if context is None:
context = await self.get_context(query)
return context

View file

@ -122,10 +122,17 @@ class NaturalLanguageRetriever(BaseRetriever):
query.
"""
graph_engine = await get_graph_engine()
is_empty = await graph_engine.is_empty()
if is_empty:
logger.warning("Search attempt on an empty knowledge graph")
return []
return await self._execute_cypher_query(query, graph_engine)
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
"""
Returns a completion based on the query and context.
@ -139,6 +146,8 @@ class NaturalLanguageRetriever(BaseRetriever):
- query (str): The natural language query to get a completion from.
- context (Optional[Any]): The context in which to base the completion; if not
provided, it will be retrieved using the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------

View file

@ -62,7 +62,9 @@ class SummariesRetriever(BaseRetriever):
logger.info(f"Returning {len(summary_payloads)} summary payloads")
return summary_payloads
async def get_completion(self, query: str, context: Optional[Any] = None, **kwargs) -> Any:
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs
) -> Any:
"""
Generates a completion using summaries context.
@ -75,6 +77,8 @@ class SummariesRetriever(BaseRetriever):
- query (str): The search query for generating the completion.
- context (Optional[Any]): Optional context for the completion; if not provided,
will be retrieved based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------

View file

@ -1,16 +1,22 @@
import os
import asyncio
from typing import Any, Optional, List, Type
from datetime import datetime
from operator import itemgetter
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history,
get_conversation_history,
)
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.llm.prompts import render_prompt
from cognee.infrastructure.llm import LLMGateway
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.shared.logging_utils import get_logger
from cognee.context_global_variables import session_user
from cognee.infrastructure.databases.cache.config import CacheConfig
from cognee.tasks.temporal_graph.models import QueryInterval
@ -73,7 +79,11 @@ class TemporalRetriever(GraphCompletionRetriever):
else:
base_directory = None
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
time_now = datetime.now().strftime("%d-%m-%Y")
system_prompt = render_prompt(
prompt_path, {"time_now": time_now}, base_directory=base_directory
)
interval = await LLMGateway.acreate_structured_output(query, system_prompt, QueryInterval)
@ -102,8 +112,6 @@ class TemporalRetriever(GraphCompletionRetriever):
graph_engine = await get_graph_engine()
triplets = []
if time_from and time_to:
ids = await graph_engine.collect_time_ids(time_from=time_from, time_to=time_to)
elif time_from:
@ -137,17 +145,63 @@ class TemporalRetriever(GraphCompletionRetriever):
return self.descriptions_to_string(top_k_events)
async def get_completion(self, query: str, context: Optional[str] = None) -> List[str]:
"""Generates a response using the query and optional context."""
async def get_completion(
self, query: str, context: Optional[str] = None, session_id: Optional[str] = None
) -> List[str]:
"""
Generates a response using the query and optional context.
Parameters:
-----------
- query (str): The query string for which a completion is generated.
- context (Optional[str]): Optional context to use; if None, it will be
retrieved based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------
- List[str]: A list containing the generated completion.
"""
if not context:
context = await self.get_context(query=query)
if context:
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
# Check if we need to generate context summary for caching
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)
context_summary, completion = await asyncio.gather(
summarize_text(context),
generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
conversation_history=conversation_history,
),
)
else:
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
if session_save:
await save_conversation_history(
query=query,
context_summary=context_summary,
answer=completion,
session_id=session_id,
)
return [completion]

View file

@ -71,7 +71,7 @@ async def get_memory_fragment(
await memory_fragment.project_graph_from_db(
graph_engine,
node_properties_to_project=properties_to_project,
edge_properties_to_project=["relationship_name"],
edge_properties_to_project=["relationship_name", "edge_text"],
node_type=node_type,
node_name=node_name,
)

View file

@ -1,23 +1,49 @@
from typing import Optional
from typing import Optional, Type, Any
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
async def generate_structured_completion(
query: str,
context: str,
user_prompt_path: str,
system_prompt_path: str,
system_prompt: Optional[str] = None,
conversation_history: Optional[str] = None,
response_model: Type = str,
) -> Any:
"""Generates a structured completion using LLM with given context and prompts."""
args = {"question": query, "context": context}
user_prompt = render_prompt(user_prompt_path, args)
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
if conversation_history:
#:TODO: I would separate the history and put it into the system prompt but we have to test what works best with longer convos
system_prompt = conversation_history + "\nTASK:" + system_prompt
return await LLMGateway.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=response_model,
)
async def generate_completion(
query: str,
context: str,
user_prompt_path: str,
system_prompt_path: str,
system_prompt: Optional[str] = None,
conversation_history: Optional[str] = None,
) -> str:
"""Generates a completion using LLM with given context and prompts."""
args = {"question": query, "context": context}
user_prompt = render_prompt(user_prompt_path, args)
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
return await LLMGateway.acreate_structured_output(
text_input=user_prompt,
return await generate_structured_completion(
query=query,
context=context,
user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
conversation_history=conversation_history,
response_model=str,
)

View file

@ -0,0 +1,156 @@
from typing import Optional, List, Dict, Any
from cognee.context_global_variables import session_user
from cognee.infrastructure.databases.cache.config import CacheConfig
from cognee.infrastructure.databases.exceptions import CacheConnectionError
from cognee.shared.logging_utils import get_logger
logger = get_logger("session_cache")
async def save_conversation_history(
query: str,
context_summary: str,
answer: str,
session_id: Optional[str] = None,
) -> bool:
"""
Saves Q&A interaction to the session cache if user is authenticated and caching is enabled.
Handles cache unavailability gracefully by logging warnings instead of failing.
Parameters:
-----------
- query (str): The user's query/question.
- context_summary (str): Summarized context used for generating the answer.
- answer (str): The generated answer/completion.
- session_id (Optional[str]): Session identifier. Defaults to 'default_session' if None.
Returns:
--------
- bool: True if successfully saved to cache, False otherwise.
"""
try:
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
if not (user_id and cache_config.caching):
logger.debug("Session caching disabled or user not authenticated")
return False
if session_id is None:
session_id = "default_session"
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
cache_engine = get_cache_engine()
if cache_engine is None:
logger.warning("Cache engine not available, skipping session save")
return False
await cache_engine.add_qa(
str(user_id),
session_id=session_id,
question=query,
context=context_summary,
answer=answer,
)
logger.info(
f"Successfully saved Q&A to session cache: user_id={user_id}, session_id={session_id}"
)
return True
except CacheConnectionError as e:
logger.warning(f"Cache unavailable, continuing without session save: {e.message}")
return False
except Exception as e:
logger.error(
f"Unexpected error saving to session cache: {type(e).__name__}: {str(e)}. Continuing without caching."
)
return False
async def get_conversation_history(
session_id: Optional[str] = None,
) -> str:
"""
Retrieves conversation history from cache and formats it as text.
Returns formatted conversation history with time, question, context, and answer
for the last N Q&A pairs (N is determined by cache engine default).
Parameters:
-----------
- session_id (Optional[str]): Session identifier. Defaults to 'default_session' if None.
Returns:
--------
- str: Formatted conversation history string, or empty string if no history or error.
Format:
-------
Previous conversation:
[2024-01-15 10:30:45]
QUESTION: What is X?
CONTEXT: X is a concept...
ANSWER: X is...
[2024-01-15 10:31:20]
QUESTION: How does Y work?
CONTEXT: Y is related to...
ANSWER: Y works by...
"""
try:
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
if not (user_id and cache_config.caching):
logger.debug("Session caching disabled or user not authenticated")
return ""
if session_id is None:
session_id = "default_session"
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
cache_engine = get_cache_engine()
if cache_engine is None:
logger.warning("Cache engine not available, skipping conversation history retrieval")
return ""
history_entries = await cache_engine.get_latest_qa(str(user_id), session_id)
if not history_entries:
logger.debug("No conversation history found")
return ""
history_text = "Previous conversation:\n\n"
for entry in history_entries:
history_text += f"[{entry.get('time', 'Unknown time')}]\n"
history_text += f"QUESTION: {entry.get('question', '')}\n"
history_text += f"CONTEXT: {entry.get('context', '')}\n"
history_text += f"ANSWER: {entry.get('answer', '')}\n\n"
logger.debug(f"Retrieved {len(history_entries)} conversation history entries")
return history_text
except CacheConnectionError as e:
logger.warning(f"Cache unavailable, continuing without conversation history: {e.message}")
return ""
except Exception as e:
logger.warning(
f"Unexpected error retrieving conversation history: {type(e).__name__}: {str(e)}"
)
return ""

View file

@ -1,12 +1,16 @@
from typing import Any, List, Optional, Tuple, Type, Union
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.data.models.Dataset import Dataset
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.search.types import SearchType
from cognee.shared.logging_utils import get_logger
from .get_search_type_tools import get_search_type_tools
logger = get_logger()
async def no_access_control_search(
query_type: SearchType,
@ -19,6 +23,7 @@ async def no_access_control_search(
save_interaction: bool = False,
last_k: Optional[int] = None,
only_context: bool = False,
session_id: Optional[str] = None,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
search_tools = await get_search_type_tools(
query_type=query_type,
@ -31,6 +36,12 @@ async def no_access_control_search(
save_interaction=save_interaction,
last_k=last_k,
)
graph_engine = await get_graph_engine()
is_empty = await graph_engine.is_empty()
if is_empty:
# TODO: we can log here, but not all search types use graph. Still keeping this here for reviewer input
logger.warning("Search attempt on an empty knowledge graph")
if len(search_tools) == 2:
[get_completion, get_context] = search_tools
@ -38,7 +49,7 @@ async def no_access_control_search(
return None, await get_context(query_text), []
context = await get_context(query_text)
result = await get_completion(query_text, context)
result = await get_completion(query_text, context, session_id=session_id)
else:
unknown_tool = search_tools[0]
result = await unknown_tool(query_text)

View file

@ -5,6 +5,8 @@ from uuid import UUID
from fastapi.encoders import jsonable_encoder
from typing import Any, List, Optional, Tuple, Type, Union
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry
from cognee.context_global_variables import set_database_global_context_variables
@ -22,11 +24,13 @@ from cognee.modules.data.models import Dataset
from cognee.modules.data.methods.get_authorized_existing_datasets import (
get_authorized_existing_datasets,
)
from cognee import __version__ as cognee_version
from .get_search_type_tools import get_search_type_tools
from .no_access_control_search import no_access_control_search
from ..utils.prepare_search_result import prepare_search_result
logger = get_logger()
async def search(
query_text: str,
@ -42,6 +46,7 @@ async def search(
last_k: Optional[int] = None,
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
) -> Union[CombinedSearchResult, List[SearchResult]]:
"""
@ -59,7 +64,14 @@ async def search(
Searching by dataset is only available in ENABLE_BACKEND_ACCESS_CONTROL mode
"""
query = await log_query(query_text, query_type.value, user.id)
send_telemetry("cognee.search EXECUTION STARTED", user.id)
send_telemetry(
"cognee.search EXECUTION STARTED",
user.id,
additional_properties={
"cognee_version": cognee_version,
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
},
)
# Use search function filtered by permissions if access control is enabled
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
@ -77,6 +89,7 @@ async def search(
last_k=last_k,
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
)
else:
search_results = [
@ -91,10 +104,18 @@ async def search(
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
session_id=session_id,
)
]
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
send_telemetry(
"cognee.search EXECUTION COMPLETED",
user.id,
additional_properties={
"cognee_version": cognee_version,
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
},
)
await log_result(
query.id,
@ -195,6 +216,7 @@ async def authorized_search(
last_k: Optional[int] = None,
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
) -> Union[
Tuple[Any, Union[List[Edge], str], List[Dataset]],
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
@ -221,6 +243,7 @@ async def authorized_search(
save_interaction=save_interaction,
last_k=last_k,
only_context=True,
session_id=session_id,
)
context = {}
@ -263,7 +286,7 @@ async def authorized_search(
return combined_context
combined_context = prepare_combined_context(context)
completion = await get_completion(query_text, combined_context)
completion = await get_completion(query_text, combined_context, session_id=session_id)
return completion, combined_context, datasets
@ -280,6 +303,7 @@ async def authorized_search(
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
session_id=session_id,
)
return search_results
@ -298,6 +322,7 @@ async def search_in_datasets_context(
last_k: Optional[int] = None,
only_context: bool = False,
context: Optional[Any] = None,
session_id: Optional[str] = None,
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
"""
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
@ -317,10 +342,30 @@ async def search_in_datasets_context(
last_k: Optional[int] = None,
only_context: bool = False,
context: Optional[Any] = None,
session_id: Optional[str] = None,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
# Set database configuration in async context for each dataset user has access for
await set_database_global_context_variables(dataset.id, dataset.owner_id)
graph_engine = await get_graph_engine()
is_empty = await graph_engine.is_empty()
if is_empty:
# TODO: we can log here, but not all search types use graph. Still keeping this here for reviewer input
from cognee.modules.data.methods import get_dataset_data
dataset_data = await get_dataset_data(dataset.id)
if len(dataset_data) > 0:
logger.warning(
f"Dataset '{dataset.name}' has {len(dataset_data)} data item(s) but the knowledge graph is empty. "
"Please run cognify to process the data before searching."
)
else:
logger.warning(
"Search attempt on an empty knowledge graph - no data has been added to this dataset"
)
specific_search_tools = await get_search_type_tools(
query_type=query_type,
query_text=query_text,
@ -340,7 +385,7 @@ async def search_in_datasets_context(
return None, await get_context(query_text), [dataset]
search_context = context or await get_context(query_text)
search_result = await get_completion(query_text, search_context)
search_result = await get_completion(query_text, search_context, session_id=session_id)
return search_result, search_context, [dataset]
else:
@ -365,6 +410,7 @@ async def search_in_datasets_context(
last_k=last_k,
only_context=only_context,
context=context,
session_id=session_id,
)
)

View file

@ -27,12 +27,7 @@ async def get_default_user() -> SimpleNamespace:
if user is None:
return await create_default_user()
# We return a SimpleNamespace to have the same user type as our SaaS
# SimpleNamespace is just a dictionary which can be accessed through attributes
auth_data = SimpleNamespace(
id=user.id, email=user.email, tenant_id=user.tenant_id, roles=[]
)
return auth_data
return user
except Exception as error:
if "principals" in str(error.args):
raise DatabaseNotCreatedError() from error

View file

@ -16,17 +16,17 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
nodes_list = []
color_map = {
"Entity": "#f47710",
"EntityType": "#6510f4",
"DocumentChunk": "#801212",
"TextSummary": "#1077f4",
"TableRow": "#f47710",
"TableType": "#6510f4",
"ColumnValue": "#13613a",
"SchemaTable": "#f47710",
"DatabaseSchema": "#6510f4",
"SchemaRelationship": "#13613a",
"default": "#D3D3D3",
"Entity": "#5C10F4",
"EntityType": "#A550FF",
"DocumentChunk": "#0DFF00",
"TextSummary": "#5C10F4",
"TableRow": "#A550FF",
"TableType": "#5C10F4",
"ColumnValue": "#757470",
"SchemaTable": "#A550FF",
"DatabaseSchema": "#5C10F4",
"SchemaRelationship": "#323332",
"default": "#D8D8D8",
}
for node_id, node_info in nodes_data:
@ -98,16 +98,19 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
<head>
<meta charset="utf-8">
<script src="https://d3js.org/d3.v5.min.js"></script>
<script src="https://d3js.org/d3-contour.v1.min.js"></script>
<style>
body, html { margin: 0; padding: 0; width: 100%; height: 100%; overflow: hidden; background: linear-gradient(90deg, #101010, #1a1a2e); color: white; font-family: 'Inter', sans-serif; }
svg { width: 100vw; height: 100vh; display: block; }
.links line { stroke: rgba(255, 255, 255, 0.4); stroke-width: 2px; }
.links line.weighted { stroke: rgba(255, 215, 0, 0.7); }
.links line.multi-weighted { stroke: rgba(0, 255, 127, 0.8); }
.nodes circle { stroke: white; stroke-width: 0.5px; filter: drop-shadow(0 0 5px rgba(255,255,255,0.3)); }
.node-label { font-size: 5px; font-weight: bold; fill: white; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
.edge-label { font-size: 3px; fill: rgba(255, 255, 255, 0.7); text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
.links line { stroke: rgba(160, 160, 160, 0.25); stroke-width: 1.5px; stroke-linecap: round; }
.links line.weighted { stroke: rgba(255, 215, 0, 0.4); }
.links line.multi-weighted { stroke: rgba(0, 255, 127, 0.45); }
.nodes circle { stroke: white; stroke-width: 0.5px; }
.node-label { font-size: 5px; font-weight: bold; fill: #F4F4F4; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
.edge-label { font-size: 3px; fill: #F4F4F4; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; paint-order: stroke; stroke: rgba(50,51,50,0.75); stroke-width: 1px; }
.density path { mix-blend-mode: screen; }
.tooltip {
position: absolute;
@ -125,11 +128,32 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
max-width: 300px;
word-wrap: break-word;
}
#info-panel {
position: fixed;
left: 12px;
top: 12px;
width: 340px;
max-height: calc(100vh - 24px);
overflow: auto;
background: rgba(50, 51, 50, 0.7);
backdrop-filter: blur(6px);
border: 1px solid rgba(216, 216, 216, 0.35);
border-radius: 8px;
color: #F4F4F4;
padding: 12px 14px;
z-index: 1100;
}
#info-panel h3 { margin: 0 0 8px 0; font-size: 14px; color: #F4F4F4; }
#info-panel .kv { font-size: 12px; line-height: 1.4; }
#info-panel .kv .k { color: #D8D8D8; }
#info-panel .kv .v { color: #F4F4F4; }
#info-panel .placeholder { opacity: 0.7; font-size: 12px; }
</style>
</head>
<body>
<svg></svg>
<div class="tooltip" id="tooltip"></div>
<div id="info-panel"><div class="placeholder">Hover a node or edge to inspect details</div></div>
<script>
var nodes = {nodes};
var links = {links};
@ -140,19 +164,141 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
var container = svg.append("g");
var tooltip = d3.select("#tooltip");
var infoPanel = d3.select('#info-panel');
function renderInfo(title, entries){
function esc(s){ return String(s).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;'); }
var html = '<h3>' + esc(title) + '</h3>';
html += '<div class="kv">';
entries.forEach(function(e){
html += '<div><span class="k">' + esc(e.k) + ':</span> <span class="v">' + esc(e.v) + '</span></div>';
});
html += '</div>';
infoPanel.html(html);
}
function pickDescription(obj){
if (!obj) return null;
var keys = ['description','summary','text','content'];
for (var i=0; i<keys.length; i++){
var v = obj[keys[i]];
if (typeof v === 'string' && v.trim()) return v.trim();
}
return null;
}
function truncate(s, n){ if (!s) return s; return s.length > n ? (s.slice(0, n) + '') : s; }
function renderNodeInfo(n){
var entries = [];
if (n.name) entries.push({k:'Name', v: n.name});
if (n.type) entries.push({k:'Type', v: n.type});
if (n.id) entries.push({k:'ID', v: n.id});
var desc = pickDescription(n) || pickDescription(n.properties);
if (desc) entries.push({k:'Description', v: truncate(desc.replace(/\s+/g,' ').trim(), 280)});
if (n.properties) {
Object.keys(n.properties).slice(0, 12).forEach(function(key){
var v = n.properties[key];
if (v !== undefined && v !== null && typeof v !== 'object') entries.push({k: key, v: String(v)});
});
}
renderInfo(n.name || 'Node', entries);
}
function renderEdgeInfo(e){
var entries = [];
if (e.relation) entries.push({k:'Relation', v: e.relation});
if (e.weight !== undefined && e.weight !== null) entries.push({k:'Weight', v: e.weight});
if (e.all_weights && Object.keys(e.all_weights).length){
Object.keys(e.all_weights).slice(0, 8).forEach(function(k){ entries.push({k: 'w.'+k, v: e.all_weights[k]}); });
}
if (e.relationship_type) entries.push({k:'Type', v: e.relationship_type});
var edesc = pickDescription(e.edge_info);
if (edesc) entries.push({k:'Description', v: truncate(edesc.replace(/\s+/g,' ').trim(), 280)});
renderInfo('Edge', entries);
}
// Basic runtime diagnostics
console.log('[Cognee Visualization] nodes:', nodes ? nodes.length : 0, 'links:', links ? links.length : 0);
window.addEventListener('error', function(e){
try {
tooltip.html('<strong>Error:</strong> ' + e.message)
.style('left', '12px')
.style('top', '12px')
.style('opacity', 1);
} catch(_) {}
});
// Normalize node IDs and link endpoints for robustness
function resolveId(d){ return (d && (d.id || d.node_id || d.uuid || d.external_id || d.name)) || undefined; }
if (Array.isArray(nodes)) {
nodes.forEach(function(n){ var id = resolveId(n); if (id !== undefined) n.id = id; });
}
if (Array.isArray(links)) {
links.forEach(function(l){
if (typeof l.source === 'object') l.source = resolveId(l.source);
if (typeof l.target === 'object') l.target = resolveId(l.target);
});
}
if (!nodes || nodes.length === 0) {
container.append('text')
.attr('x', width / 2)
.attr('y', height / 2)
.attr('fill', '#fff')
.attr('font-size', 14)
.attr('text-anchor', 'middle')
.text('No graph data available');
}
// Visual defs - reusable glow
var defs = svg.append("defs");
var glow = defs.append("filter").attr("id", "glow")
.attr("x", "-30%")
.attr("y", "-30%")
.attr("width", "160%")
.attr("height", "160%");
glow.append("feGaussianBlur").attr("stdDeviation", 8).attr("result", "coloredBlur");
var feMerge = glow.append("feMerge");
feMerge.append("feMergeNode").attr("in", "coloredBlur");
feMerge.append("feMergeNode").attr("in", "SourceGraphic");
// Stronger glow for hovered adjacency
var glowStrong = defs.append("filter").attr("id", "glow-strong")
.attr("x", "-40%")
.attr("y", "-40%")
.attr("width", "180%")
.attr("height", "180%");
glowStrong.append("feGaussianBlur").attr("stdDeviation", 14).attr("result", "coloredBlur");
var feMerge2 = glowStrong.append("feMerge");
feMerge2.append("feMergeNode").attr("in", "coloredBlur");
feMerge2.append("feMergeNode").attr("in", "SourceGraphic");
var currentTransform = d3.zoomIdentity;
var densityZoomTimer = null;
var isInteracting = false;
var labelBaseSize = 10;
function getGroupKey(d){ return d && (d.type || d.category || d.group || d.color) || 'default'; }
var simulation = d3.forceSimulation(nodes)
.force("link", d3.forceLink(links).id(d => d.id).strength(0.1))
.force("charge", d3.forceManyBody().strength(-275))
.force("link", d3.forceLink(links).id(function(d){ return d.id; }).distance(100).strength(0.2))
.force("charge", d3.forceManyBody().strength(-180))
.force("collide", d3.forceCollide().radius(16).iterations(2))
.force("center", d3.forceCenter(width / 2, height / 2))
.force("x", d3.forceX().strength(0.1).x(width / 2))
.force("y", d3.forceY().strength(0.1).y(height / 2));
.force("x", d3.forceX().strength(0.06).x(width / 2))
.force("y", d3.forceY().strength(0.06).y(height / 2))
.alphaDecay(0.06)
.velocityDecay(0.6);
// Density layer (sibling of container to avoid double transforms)
var densityLayer = svg.append("g")
.attr("class", "density")
.style("pointer-events", "none");
if (densityLayer.lower) densityLayer.lower();
var link = container.append("g")
.attr("class", "links")
.selectAll("line")
.data(links)
.enter().append("line")
.style("opacity", 0)
.style("pointer-events", "none")
.attr("stroke-width", d => {
if (d.weight) return Math.max(2, d.weight * 5);
if (d.all_weights && Object.keys(d.all_weights).length > 0) {
@ -168,6 +314,7 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
})
.on("mouseover", function(d) {
// Create tooltip content for edge
renderEdgeInfo(d);
var content = "<strong>Edge Information</strong><br/>";
content += "Relationship: " + d.relation + "<br/>";
@ -212,6 +359,7 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
.data(links)
.enter().append("text")
.attr("class", "edge-label")
.style("opacity", 0)
.text(d => {
var label = d.relation;
if (d.all_weights && Object.keys(d.all_weights).length > 1) {
@ -232,21 +380,225 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
.data(nodes)
.enter().append("g");
// Color fallback by type when d.color is missing
var colorByType = {
"Entity": "#5C10F4",
"EntityType": "#A550FF",
"DocumentChunk": "#0DFF00",
"TextSummary": "#5C10F4",
"TableRow": "#A550FF",
"TableType": "#5C10F4",
"ColumnValue": "#757470",
"SchemaTable": "#A550FF",
"DatabaseSchema": "#5C10F4",
"SchemaRelationship": "#323332"
};
var node = nodeGroup.append("circle")
.attr("r", 13)
.attr("fill", d => d.color)
.attr("fill", function(d){ return d.color || colorByType[d.type] || "#D3D3D3"; })
.style("filter", "url(#glow)")
.attr("shape-rendering", "geometricPrecision")
.call(d3.drag()
.on("start", dragstarted)
.on("drag", dragged)
.on("drag", function(d){ dragged(d); updateDensity(); showAdjacency(d); })
.on("end", dragended));
nodeGroup.append("text")
// Show links only for hovered node adjacency
function isAdjacent(linkDatum, nodeId) {
var sid = linkDatum && linkDatum.source && (linkDatum.source.id || linkDatum.source);
var tid = linkDatum && linkDatum.target && (linkDatum.target.id || linkDatum.target);
return sid === nodeId || tid === nodeId;
}
function showAdjacency(d) {
var nodeId = d && (d.id || d.node_id || d.uuid || d.external_id || d.name);
if (!nodeId) return;
// Build neighbor set
var neighborIds = {};
neighborIds[nodeId] = true;
for (var i = 0; i < links.length; i++) {
var l = links[i];
var sid = l && l.source && (l.source.id || l.source);
var tid = l && l.target && (l.target.id || l.target);
if (sid === nodeId) neighborIds[tid] = true;
if (tid === nodeId) neighborIds[sid] = true;
}
link
.style("opacity", function(l){ return isAdjacent(l, nodeId) ? 0.95 : 0; })
.style("stroke", function(l){ return isAdjacent(l, nodeId) ? "rgba(255,255,255,0.95)" : null; })
.style("stroke-width", function(l){ return isAdjacent(l, nodeId) ? 2.5 : 1.5; });
edgeLabels.style("opacity", function(l){ return isAdjacent(l, nodeId) ? 1 : 0; });
densityLayer.style("opacity", 0.35);
// Highlight neighbor nodes and dim others
node
.style("opacity", function(n){ return neighborIds[n.id] ? 1 : 0.25; })
.style("filter", function(n){ return neighborIds[n.id] ? "url(#glow-strong)" : "url(#glow)"; })
.attr("r", function(n){ return neighborIds[n.id] ? 15 : 13; });
// Raise highlighted nodes
node.filter(function(n){ return neighborIds[n.id]; }).raise();
// Neighbor labels brighter
nodeGroup.select("text")
.style("opacity", function(n){ return neighborIds[n.id] ? 1 : 0.2; })
.style("font-size", function(n){
var size = neighborIds[n.id] ? Math.min(22, labelBaseSize * 1.25) : labelBaseSize;
return size + "px";
});
}
function clearAdjacency() {
link.style("opacity", 0)
.style("stroke", null)
.style("stroke-width", 1.5);
edgeLabels.style("opacity", 0);
densityLayer.style("opacity", 1);
node
.style("opacity", 1)
.style("filter", "url(#glow)")
.attr("r", 13);
nodeGroup.select("text")
.style("opacity", 1)
.style("font-size", labelBaseSize + "px");
}
node.on("mouseover", function(d){ showAdjacency(d); })
.on("mouseout", function(){ clearAdjacency(); });
node.on("mouseover", function(d){ renderNodeInfo(d); tooltip.style('opacity', 0); });
// Also bind on the group so labels trigger adjacency too
nodeGroup.on("mouseover", function(d){ showAdjacency(d); })
.on("mouseout", function(){ clearAdjacency(); });
// Density always on; no hover gating
// Add labels sparsely to reduce clutter (every ~50th node), and truncate long text
nodeGroup
.filter(function(d, i){ return i % 14 === 0; })
.append("text")
.attr("class", "node-label")
.attr("dy", 4)
.attr("text-anchor", "middle")
.text(d => d.name);
.text(function(d){
var s = d && d.name ? String(d.name) : '';
return s.length > 40 ? (s.slice(0, 40) + "") : s;
})
.style("font-size", labelBaseSize + "px");
node.append("title").text(d => JSON.stringify(d));
function applyLabelSize() {
var k = (currentTransform && currentTransform.k) || 1;
// Keep labels readable across zoom levels and hide when too small
labelBaseSize = Math.max(7, Math.min(18, 10 / Math.sqrt(k)));
nodeGroup.select("text")
.style("font-size", labelBaseSize + "px")
.style("display", (k < 0.35 ? "none" : null));
}
// Density cloud computation (throttled)
var densityTick = 0;
var geoPath = d3.geoPath().projection(null);
var MAX_POINTS_PER_GROUP = 400;
function updateDensity() {
try {
if (isInteracting) return; // skip during interaction for smoother UX
if (typeof d3 === 'undefined' || typeof d3.contourDensity !== 'function') {
return; // d3-contour not available; skip gracefully
}
if (!nodes || nodes.length === 0) return;
var usable = nodes.filter(function(d){ return d && typeof d.x === 'number' && isFinite(d.x) && typeof d.y === 'number' && isFinite(d.y); });
if (usable.length < 3) return; // not enough positioned points yet
var t = currentTransform || d3.zoomIdentity;
if (t.k && t.k < 0.08) {
// Skip density at extreme zoom-out to avoid numerical instability/perf issues
densityLayer.selectAll('*').remove();
return;
}
function hexToRgb(hex){
if (!hex) return {r: 0, g: 200, b: 255};
var c = hex.replace('#','');
if (c.length === 3) c = c.split('').map(function(x){ return x+x; }).join('');
var num = parseInt(c, 16);
return { r: (num >> 16) & 255, g: (num >> 8) & 255, b: num & 255 };
}
// Build groups across all nodes
var groups = {};
for (var i = 0; i < usable.length; i++) {
var k = getGroupKey(usable[i]);
if (!groups[k]) groups[k] = [];
groups[k].push(usable[i]);
}
densityLayer.selectAll('*').remove();
Object.keys(groups).forEach(function(key){
var arr = groups[key];
if (!arr || arr.length < 3) return;
// Transform positions into screen space and sample to cap cost
var arrT = [];
var step = Math.max(1, Math.floor(arr.length / MAX_POINTS_PER_GROUP));
for (var j = 0; j < arr.length; j += step) {
var nx = t.applyX(arr[j].x);
var ny = t.applyY(arr[j].y);
if (isFinite(nx) && isFinite(ny)) {
arrT.push({ x: nx, y: ny, type: arr[j].type, color: arr[j].color });
}
}
if (arrT.length < 3) return;
// Compute adaptive bandwidth based on group spread
var cx = 0, cy = 0;
for (var k = 0; k < arrT.length; k++){ cx += arrT[k].x; cy += arrT[k].y; }
cx /= arrT.length; cy /= arrT.length;
var sumR = 0;
for (var k2 = 0; k2 < arrT.length; k2++){
var dx = arrT[k2].x - cx, dy = arrT[k2].y - cy;
sumR += Math.sqrt(dx*dx + dy*dy);
}
var avgR = sumR / arrT.length;
var dynamicBandwidth = Math.max(12, Math.min(80, avgR));
var densityBandwidth = dynamicBandwidth / (t.k || 1);
var contours = d3.contourDensity()
.x(function(d){ return d.x; })
.y(function(d){ return d.y; })
.size([width, height])
.bandwidth(densityBandwidth)
.thresholds(8)
(arrT);
if (!contours || contours.length === 0) return;
var maxVal = d3.max(contours, function(d){ return d.value; }) || 1;
// Use the first node color in the group or fallback neon palette
var baseColor = (arr.find(function(d){ return d && d.color; }) || {}).color || '#00c8ff';
var rgb = hexToRgb(baseColor);
var g = densityLayer.append('g').attr('data-group', key);
g.selectAll('path')
.data(contours)
.enter()
.append('path')
.attr('d', geoPath)
.attr('fill', 'rgb(' + rgb.r + ',' + rgb.g + ',' + rgb.b + ')')
.attr('stroke', 'none')
.style('opacity', function(d){
var v = maxVal ? (d.value / maxVal) : 0;
var alpha = Math.pow(Math.max(0, Math.min(1, v)), 1.6); // accentuate clusters
return 0.65 * alpha; // up to 0.65 opacity at peak density
})
.style('filter', 'blur(2px)');
});
} catch (e) {
// Reduce impact of any runtime errors during zoom
console.warn('Density update failed:', e);
}
}
simulation.on("tick", function() {
link.attr("x1", d => d.source.x)
@ -266,16 +618,29 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
.attr("y", d => d.y)
.attr("dy", 4)
.attr("text-anchor", "middle");
densityTick += 1;
if (densityTick % 24 === 0) updateDensity();
});
svg.call(d3.zoom().on("zoom", function() {
container.attr("transform", d3.event.transform);
}));
var zoomBehavior = d3.zoom()
.on("start", function(){ isInteracting = true; densityLayer.style("opacity", 0.2); })
.on("zoom", function(){
currentTransform = d3.event.transform;
container.attr("transform", currentTransform);
})
.on("end", function(){
if (densityZoomTimer) clearTimeout(densityZoomTimer);
densityZoomTimer = setTimeout(function(){ isInteracting = false; densityLayer.style("opacity", 1); updateDensity(); }, 140);
});
svg.call(zoomBehavior);
function dragstarted(d) {
if (!d3.event.active) simulation.alphaTarget(0.3).restart();
d.fx = d.x;
d.fy = d.y;
isInteracting = true;
densityLayer.style("opacity", 0.2);
}
function dragged(d) {
@ -287,6 +652,8 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
if (!d3.event.active) simulation.alphaTarget(0);
d.fx = null;
d.fy = null;
if (densityZoomTimer) clearTimeout(densityZoomTimer);
densityZoomTimer = setTimeout(function(){ isInteracting = false; densityLayer.style("opacity", 1); updateDensity(); }, 140);
}
window.addEventListener("resize", function() {
@ -295,7 +662,13 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
svg.attr("width", width).attr("height", height);
simulation.force("center", d3.forceCenter(width / 2, height / 2));
simulation.alpha(1).restart();
updateDensity();
applyLabelSize();
});
// Initial density draw
updateDensity();
applyLabelSize();
</script>
<svg style="position: fixed; bottom: 10px; right: 10px; width: 150px; height: auto; z-index: 9999;" viewBox="0 0 158 44" fill="none" xmlns="http://www.w3.org/2000/svg">
@ -305,8 +678,12 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
</html>
"""
html_content = html_template.replace("{nodes}", json.dumps(nodes_list))
html_content = html_content.replace("{links}", json.dumps(links_list))
# Safely embed JSON inside <script> by escaping </ to avoid prematurely closing the tag
def _safe_json_embed(obj):
return json.dumps(obj).replace("</", "<\\/")
html_content = html_template.replace("{nodes}", _safe_json_embed(nodes_list))
html_content = html_content.replace("{links}", _safe_json_embed(links_list))
if not destination_file_path:
home_dir = os.path.expanduser("~")

View file

@ -7,6 +7,6 @@ class IngestionError(CogneeValidationError):
self,
message: str = "Failed to load data.",
name: str = "IngestionError",
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
):
super().__init__(message, name, status_code)

View file

@ -1,6 +1,7 @@
import os
import sys
import logging
import tempfile
import structlog
import traceback
import platform
@ -76,9 +77,38 @@ log_levels = {
# Track if structlog logging has been configured
_is_structlog_configured = False
# Path to logs directory
LOGS_DIR = Path(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "logs"))
LOGS_DIR.mkdir(exist_ok=True) # Create logs dir if it doesn't exist
def resolve_logs_dir():
    """Resolve a writable logs directory.

    Priority:
        1) BaseConfig.logs_root_directory (respects COGNEE_LOGS_DIR)
        2) <system temp dir>/cognee_logs (default, best-effort create)

    Returns:
        Path: a writable/creatable logs directory, or None if neither
        candidate could be created with write access.
    """
    # Imported locally to avoid a circular import at module load time
    # (presumably — TODO confirm against cognee.base_config).
    from cognee.base_config import get_base_config

    base_config = get_base_config()
    logs_root_directory = Path(base_config.logs_root_directory)
    try:
        logs_root_directory.mkdir(parents=True, exist_ok=True)
        if os.access(logs_root_directory, os.W_OK):
            return logs_root_directory
    except Exception:
        # Read-only filesystem or insufficient permissions; fall back to temp dir.
        pass

    try:
        # Use the platform temp directory instead of a hard-coded "/tmp" so this
        # also works on Windows and honors TMPDIR/TEMP overrides.
        tmp_log_path = Path(tempfile.gettempdir()) / "cognee_logs"
        tmp_log_path.mkdir(parents=True, exist_ok=True)
        if os.access(tmp_log_path, os.W_OK):
            return tmp_log_path
    except Exception:
        pass

    # No writable location found; callers treat None as "file logging disabled".
    return None
# Maximum number of log files to keep
MAX_LOG_FILES = 10
@ -430,28 +460,38 @@ def setup_logging(log_level=None, name=None):
stream_handler.setFormatter(console_formatter)
stream_handler.setLevel(log_level)
root_logger = logging.getLogger()
if root_logger.hasHandlers():
root_logger.handlers.clear()
root_logger.addHandler(stream_handler)
# Note: root logger needs to be set at NOTSET to allow all messages through and specific stream and file handlers
# can define their own levels.
root_logger.setLevel(logging.NOTSET)
# Resolve logs directory with env and safe fallbacks
logs_dir = resolve_logs_dir()
# Check if we already have a log file path from the environment
# NOTE: environment variable must be used here as it allows us to
# log to a single file with a name based on a timestamp in a multiprocess setting.
# Without it, we would have a separate log file for every process.
log_file_path = os.environ.get("LOG_FILE_NAME")
if not log_file_path:
if not log_file_path and logs_dir is not None:
# Create a new log file name with the cognee start time
start_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_file_path = os.path.join(LOGS_DIR, f"{start_time}.log")
log_file_path = str((logs_dir / f"{start_time}.log").resolve())
os.environ["LOG_FILE_NAME"] = log_file_path
# Create a file handler that uses our custom PlainFileHandler
file_handler = PlainFileHandler(log_file_path, encoding="utf-8")
file_handler.setLevel(DEBUG)
# Configure root logger
root_logger = logging.getLogger()
if root_logger.hasHandlers():
root_logger.handlers.clear()
root_logger.addHandler(stream_handler)
root_logger.addHandler(file_handler)
root_logger.setLevel(log_level)
try:
# Create a file handler that uses our custom PlainFileHandler
file_handler = PlainFileHandler(log_file_path, encoding="utf-8")
file_handler.setLevel(DEBUG)
root_logger.addHandler(file_handler)
except Exception as e:
# Note: Exceptions happen in case of read-only file systems or a log file path pointing
# to a location where we do not have write permission. Logging to a file is not mandatory,
# so we just log a warning to the console.
root_logger.warning(f"Warning: Could not create log file handler at {log_file_path}: {e}")
if log_level > logging.DEBUG:
import warnings
@ -466,7 +506,8 @@ def setup_logging(log_level=None, name=None):
)
# Clean up old log files, keeping only the most recent ones
cleanup_old_logs(LOGS_DIR, MAX_LOG_FILES)
if logs_dir is not None:
cleanup_old_logs(logs_dir, MAX_LOG_FILES)
# Mark logging as configured
_is_structlog_configured = True
@ -490,6 +531,10 @@ def setup_logging(log_level=None, name=None):
# Get a configured logger and log system information
logger = structlog.get_logger(name if name else __name__)
if logs_dir is not None:
logger.info(f"Log file created at: {log_file_path}", log_file=log_file_path)
# Detailed initialization for regular usage
logger.info(
"Logging initialized",

Some files were not shown because too many files have changed in this diff Show more