merged dev into this branch
This commit is contained in:
commit
d0a3bfd39f
147 changed files with 10201 additions and 6233 deletions
|
|
@ -242,13 +242,14 @@ LITELLM_LOG="ERROR"
|
|||
|
||||
########## Local LLM via Ollama ###############################################
|
||||
|
||||
|
||||
#LLM_API_KEY="ollama"
|
||||
#LLM_MODEL="llama3.1:8b"
|
||||
#LLM_PROVIDER="ollama"
|
||||
#LLM_ENDPOINT="http://localhost:11434/v1"
|
||||
#EMBEDDING_PROVIDER="ollama"
|
||||
#EMBEDDING_MODEL="nomic-embed-text:latest"
|
||||
#EMBEDDING_ENDPOINT="http://localhost:11434/api/embeddings"
|
||||
#EMBEDDING_ENDPOINT="http://localhost:11434/api/embed"
|
||||
#EMBEDDING_DIMENSIONS=768
|
||||
#HUGGINGFACE_TOKENIZER="nomic-ai/nomic-embed-text-v1.5"
|
||||
|
||||
|
|
|
|||
11
.github/workflows/basic_tests.yml
vendored
11
.github/workflows/basic_tests.yml
vendored
|
|
@ -123,6 +123,7 @@ jobs:
|
|||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
extra-dependencies: "scraping"
|
||||
|
||||
- name: Run Integration Tests
|
||||
run: uv run pytest cognee/tests/integration/
|
||||
|
|
@ -161,11 +162,11 @@ jobs:
|
|||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
STRUCTURED_OUTPUT_FRAMEWORK: "BAML"
|
||||
BAML_LLM_PROVIDER: azure-openai
|
||||
BAML_LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
BAML_LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
BAML_LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
BAML_LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
BAML_LLM_PROVIDER: openai
|
||||
BAML_LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||
BAML_LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
||||
BAML_LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
# BAML_LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
|
||||
LLM_PROVIDER: openai
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
|
|
|
|||
20
.github/workflows/dockerhub-mcp.yml
vendored
20
.github/workflows/dockerhub-mcp.yml
vendored
|
|
@ -7,14 +7,29 @@ on:
|
|||
|
||||
jobs:
|
||||
docker-build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on:
|
||||
group: Default
|
||||
labels:
|
||||
- docker_build_runner
|
||||
|
||||
steps:
|
||||
- name: Check and free disk space before build
|
||||
run: |
|
||||
echo "=== Before cleanup ==="
|
||||
df -h
|
||||
echo "Removing unused preinstalled SDKs to free space..."
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc || true
|
||||
docker system prune -af || true
|
||||
echo "=== After cleanup ==="
|
||||
df -h
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
buildkitd-flags: --root /tmp/buildkit
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
|
|
@ -34,7 +49,7 @@ jobs:
|
|||
|
||||
- name: Build and push
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
|
|
@ -45,5 +60,6 @@ jobs:
|
|||
cache-from: type=registry,ref=cognee/cognee-mcp:buildcache
|
||||
cache-to: type=registry,ref=cognee/cognee-mcp:buildcache,mode=max
|
||||
|
||||
|
||||
- name: Image digest
|
||||
run: echo ${{ steps.build.outputs.digest }}
|
||||
|
|
|
|||
119
.github/workflows/e2e_tests.yml
vendored
119
.github/workflows/e2e_tests.yml
vendored
|
|
@ -332,6 +332,125 @@ jobs:
|
|||
DB_PASSWORD: cognee
|
||||
run: uv run python ./cognee/tests/test_concurrent_subprocess_access.py
|
||||
|
||||
test-entity-extraction:
|
||||
name: Test Entity Extraction
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
|
||||
- name: Dependencies already installed
|
||||
run: echo "Dependencies already installed in setup"
|
||||
|
||||
- name: Run Entity Extraction Test
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/tasks/entity_extraction/entity_extraction_test.py
|
||||
|
||||
test-feedback-enrichment:
|
||||
name: Test Feedback Enrichment
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
|
||||
- name: Dependencies already installed
|
||||
run: echo "Dependencies already installed in setup"
|
||||
|
||||
- name: Run Feedback Enrichment Test
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_feedback_enrichment.py
|
||||
|
||||
run_conversation_sessions_test:
|
||||
name: Conversation sessions test
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
services:
|
||||
postgres:
|
||||
image: pgvector/pgvector:pg17
|
||||
env:
|
||||
POSTGRES_USER: cognee
|
||||
POSTGRES_PASSWORD: cognee
|
||||
POSTGRES_DB: cognee_db
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
ports:
|
||||
- 5432:5432
|
||||
|
||||
redis:
|
||||
image: redis:7
|
||||
ports:
|
||||
- 6379:6379
|
||||
options: >-
|
||||
--health-cmd "redis-cli ping"
|
||||
--health-interval 5s
|
||||
--health-timeout 3s
|
||||
--health-retries 5
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
extra-dependencies: "postgres redis"
|
||||
|
||||
- name: Run Conversation session tests
|
||||
env:
|
||||
ENV: dev
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPH_DATABASE_PROVIDER: 'kuzu'
|
||||
CACHING: true
|
||||
DB_PROVIDER: 'postgres'
|
||||
DB_NAME: 'cognee_db'
|
||||
DB_HOST: '127.0.0.1'
|
||||
DB_PORT: 5432
|
||||
DB_USERNAME: cognee
|
||||
DB_PASSWORD: cognee
|
||||
run: uv run python ./cognee/tests/test_conversation_history.py
|
||||
|
||||
|
||||
test-load:
|
||||
name: Test Load
|
||||
runs-on: ubuntu-22.04
|
||||
|
|
|
|||
78
.github/workflows/scorecard.yml
vendored
Normal file
78
.github/workflows/scorecard.yml
vendored
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
# This workflow uses actions that are not certified by GitHub. They are provided
|
||||
# by a third-party and are governed by separate terms of service, privacy
|
||||
# policy, and support documentation.
|
||||
|
||||
name: Scorecard supply-chain security
|
||||
on:
|
||||
# For Branch-Protection check. Only the default branch is supported. See
|
||||
# https://github.com/ossf/scorecard/blob/main/docs/checks.md#branch-protection
|
||||
branch_protection_rule:
|
||||
# To guarantee Maintained check is occasionally updated. See
|
||||
# https://github.com/ossf/scorecard/blob/main/docs/checks.md#maintained
|
||||
schedule:
|
||||
- cron: '35 8 * * 2'
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
|
||||
# Declare default permissions as read only.
|
||||
permissions: read-all
|
||||
|
||||
jobs:
|
||||
analysis:
|
||||
name: Scorecard analysis
|
||||
runs-on: ubuntu-latest
|
||||
# `publish_results: true` only works when run from the default branch. conditional can be removed if disabled.
|
||||
if: github.event.repository.default_branch == github.ref_name || github.event_name == 'pull_request'
|
||||
permissions:
|
||||
# Needed to upload the results to code-scanning dashboard.
|
||||
security-events: write
|
||||
# Needed to publish results and get a badge (see publish_results below).
|
||||
id-token: write
|
||||
# Uncomment the permissions below if installing in a private repository.
|
||||
# contents: read
|
||||
# actions: read
|
||||
|
||||
steps:
|
||||
- name: "Checkout code"
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: "Run analysis"
|
||||
uses: ossf/scorecard-action@f49aabe0b5af0936a0987cfb85d86b75731b0186 # v2.4.1
|
||||
with:
|
||||
results_file: results.sarif
|
||||
results_format: sarif
|
||||
# (Optional) "write" PAT token. Uncomment the `repo_token` line below if:
|
||||
# - you want to enable the Branch-Protection check on a *public* repository, or
|
||||
# - you are installing Scorecard on a *private* repository
|
||||
# To create the PAT, follow the steps in https://github.com/ossf/scorecard-action?tab=readme-ov-file#authentication-with-fine-grained-pat-optional.
|
||||
# repo_token: ${{ secrets.SCORECARD_TOKEN }}
|
||||
|
||||
# Public repositories:
|
||||
# - Publish results to OpenSSF REST API for easy access by consumers
|
||||
# - Allows the repository to include the Scorecard badge.
|
||||
# - See https://github.com/ossf/scorecard-action#publishing-results.
|
||||
# For private repositories:
|
||||
# - `publish_results` will always be set to `false`, regardless
|
||||
# of the value entered here.
|
||||
publish_results: true
|
||||
|
||||
# (Optional) Uncomment file_mode if you have a .gitattributes with files marked export-ignore
|
||||
# file_mode: git
|
||||
|
||||
# Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
|
||||
# format to the repository Actions tab.
|
||||
- name: "Upload artifact"
|
||||
uses: actions/upload-artifact@4cec3d8aa04e39d1a68397de0c4cd6fb9dce8ec1 # v4.6.1
|
||||
with:
|
||||
name: SARIF file
|
||||
path: results.sarif
|
||||
retention-days: 5
|
||||
|
||||
# Upload the results to GitHub's code scanning dashboard (optional).
|
||||
# Commenting out will disable upload of results to your repo's Code Scanning dashboard
|
||||
- name: "Upload to code-scanning"
|
||||
uses: github/codeql-action/upload-sarif@v3
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
28
.github/workflows/temporal_graph_tests.yml
vendored
28
.github/workflows/temporal_graph_tests.yml
vendored
|
|
@ -34,10 +34,9 @@ jobs:
|
|||
- name: Run Temporal Graph with Kuzu (lancedb + sqlite)
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
|
|
@ -73,10 +72,9 @@ jobs:
|
|||
- name: Run Temporal Graph with Neo4j (lancedb + sqlite)
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
|
|
@ -125,10 +123,9 @@ jobs:
|
|||
- name: Run Temporal Graph with Kuzu (postgres + pgvector)
|
||||
env:
|
||||
ENV: dev
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
|
|
@ -192,10 +189,9 @@ jobs:
|
|||
- name: Run Temporal Graph with Neo4j (postgres + pgvector)
|
||||
env:
|
||||
ENV: dev
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,10 @@ on:
|
|||
required: false
|
||||
type: string
|
||||
default: '["3.10.x", "3.12.x", "3.13.x"]'
|
||||
os:
|
||||
required: false
|
||||
type: string
|
||||
default: '["ubuntu-22.04", "macos-15", "windows-latest"]'
|
||||
secrets:
|
||||
LLM_PROVIDER:
|
||||
required: true
|
||||
|
|
@ -43,7 +47,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ubuntu-22.04, macos-15, windows-latest]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -79,7 +83,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -115,7 +119,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -151,7 +155,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -180,7 +184,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -217,7 +221,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
|
|||
4
.github/workflows/test_ollama.yml
vendored
4
.github/workflows/test_ollama.yml
vendored
|
|
@ -75,7 +75,7 @@ jobs:
|
|||
{ "role": "user", "content": "Whatever I say, answer with Yes." }
|
||||
]
|
||||
}'
|
||||
curl -X POST http://127.0.0.1:11434/v1/embeddings \
|
||||
curl -X POST http://127.0.0.1:11434/api/embed \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "avr/sfr-embedding-mistral:latest",
|
||||
|
|
@ -98,7 +98,7 @@ jobs:
|
|||
LLM_MODEL: "phi4"
|
||||
EMBEDDING_PROVIDER: "ollama"
|
||||
EMBEDDING_MODEL: "avr/sfr-embedding-mistral:latest"
|
||||
EMBEDDING_ENDPOINT: "http://localhost:11434/api/embeddings"
|
||||
EMBEDDING_ENDPOINT: "http://localhost:11434/api/embed"
|
||||
EMBEDDING_DIMENSIONS: "4096"
|
||||
HUGGINGFACE_TOKENIZER: "Salesforce/SFR-Embedding-Mistral"
|
||||
run: uv run python ./examples/python/simple_example.py
|
||||
|
|
|
|||
25
.github/workflows/test_suites.yml
vendored
25
.github/workflows/test_suites.yml
vendored
|
|
@ -1,4 +1,6 @@
|
|||
name: Test Suites
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
on:
|
||||
push:
|
||||
|
|
@ -80,12 +82,22 @@ jobs:
|
|||
uses: ./.github/workflows/notebooks_tests.yml
|
||||
secrets: inherit
|
||||
|
||||
different-operating-systems-tests:
|
||||
name: Operating System and Python Tests
|
||||
different-os-tests-basic:
|
||||
name: OS and Python Tests Ubuntu
|
||||
needs: [basic-tests, e2e-tests]
|
||||
uses: ./.github/workflows/test_different_operating_systems.yml
|
||||
with:
|
||||
python-versions: '["3.10.x", "3.11.x", "3.12.x", "3.13.x"]'
|
||||
os: '["ubuntu-22.04"]'
|
||||
secrets: inherit
|
||||
|
||||
different-os-tests-extended:
|
||||
name: OS and Python Tests Extended
|
||||
needs: [basic-tests, e2e-tests]
|
||||
uses: ./.github/workflows/test_different_operating_systems.yml
|
||||
with:
|
||||
python-versions: '["3.13.x"]'
|
||||
os: '["macos-15", "windows-latest"]'
|
||||
secrets: inherit
|
||||
|
||||
# Matrix-based vector database tests
|
||||
|
|
@ -135,7 +147,8 @@ jobs:
|
|||
e2e-tests,
|
||||
graph-db-tests,
|
||||
notebook-tests,
|
||||
different-operating-systems-tests,
|
||||
different-os-tests-basic,
|
||||
different-os-tests-extended,
|
||||
vector-db-tests,
|
||||
example-tests,
|
||||
llm-tests,
|
||||
|
|
@ -155,7 +168,8 @@ jobs:
|
|||
cli-tests,
|
||||
graph-db-tests,
|
||||
notebook-tests,
|
||||
different-operating-systems-tests,
|
||||
different-os-tests-basic,
|
||||
different-os-tests-extended,
|
||||
vector-db-tests,
|
||||
example-tests,
|
||||
db-examples-tests,
|
||||
|
|
@ -176,7 +190,8 @@ jobs:
|
|||
"${{ needs.cli-tests.result }}" == "success" &&
|
||||
"${{ needs.graph-db-tests.result }}" == "success" &&
|
||||
"${{ needs.notebook-tests.result }}" == "success" &&
|
||||
"${{ needs.different-operating-systems-tests.result }}" == "success" &&
|
||||
"${{ needs.different-os-tests-basic.result }}" == "success" &&
|
||||
"${{ needs.different-os-tests-extended.result }}" == "success" &&
|
||||
"${{ needs.vector-db-tests.result }}" == "success" &&
|
||||
"${{ needs.example-tests.result }}" == "success" &&
|
||||
"${{ needs.db-examples-tests.result }}" == "success" &&
|
||||
|
|
|
|||
132
AGENTS.md
Normal file
132
AGENTS.md
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
## Repository Guidelines
|
||||
|
||||
This document summarizes how to work with the cognee repository: how it’s organized, how to build, test, lint, and contribute. It mirrors our actual tooling and CI while providing quick commands for local development.
|
||||
|
||||
## Project Structure & Module Organization
|
||||
|
||||
- `cognee/`: Core Python library and API.
|
||||
- `api/`: FastAPI application and versioned routers (add, cognify, memify, search, delete, users, datasets, responses, visualize, settings, sync, update, checks).
|
||||
- `cli/`: CLI entry points and subcommands invoked via `cognee` / `cognee-cli`.
|
||||
- `infrastructure/`: Databases, LLM providers, embeddings, loaders, and storage adapters.
|
||||
- `modules/`: Domain logic (graph, retrieval, ontology, users, processing, observability, etc.).
|
||||
- `tasks/`: Reusable tasks (e.g., code graph, web scraping, storage). Extend with new tasks here.
|
||||
- `eval_framework/`: Evaluation utilities and adapters.
|
||||
- `shared/`: Cross-cutting helpers (logging, settings, utils).
|
||||
- `tests/`: Unit, integration, CLI, and end-to-end tests organized by feature.
|
||||
- `__main__.py`: Entrypoint to route to CLI.
|
||||
- `cognee-mcp/`: Model Context Protocol server exposing cognee as MCP tools (SSE/HTTP/stdio). Contains its own README and Dockerfile.
|
||||
- `cognee-frontend/`: Next.js UI for local development and demos.
|
||||
- `distributed/`: Utilities for distributed execution (Modal, workers, queues).
|
||||
- `examples/`: Example scripts demonstrating the public APIs and features (graph, code graph, multimodal, permissions, etc.).
|
||||
- `notebooks/`: Jupyter notebooks for demos and tutorials.
|
||||
- `alembic/`: Database migrations for relational backends.
|
||||
|
||||
Notes:
|
||||
- Co-locate feature-specific helpers under their respective package (`modules/`, `infrastructure/`, or `tasks/`).
|
||||
- Extend the system by adding new tasks, loaders, or retrievers rather than modifying core pipeline mechanisms.
|
||||
|
||||
## Build, Test, and Development Commands
|
||||
|
||||
Python (root) – requires Python >= 3.10 and < 3.14. We recommend `uv` for speed and reproducibility.
|
||||
|
||||
- Create/refresh env and install dev deps:
|
||||
```bash
|
||||
uv sync --dev --all-extras --reinstall
|
||||
```
|
||||
|
||||
- Run the CLI (examples):
|
||||
```bash
|
||||
uv run cognee-cli add "Cognee turns documents into AI memory."
|
||||
uv run cognee-cli cognify
|
||||
uv run cognee-cli search "What does cognee do?"
|
||||
uv run cognee-cli -ui # Launches UI, backend API, and MCP server together
|
||||
```
|
||||
|
||||
- Start the FastAPI server directly:
|
||||
```bash
|
||||
uv run python -m cognee.api.client
|
||||
```
|
||||
|
||||
- Run tests (CI mirrors these commands):
|
||||
```bash
|
||||
uv run pytest cognee/tests/unit/ -v
|
||||
uv run pytest cognee/tests/integration/ -v
|
||||
```
|
||||
|
||||
- Lint and format (ruff):
|
||||
```bash
|
||||
uv run ruff check .
|
||||
uv run ruff format .
|
||||
```
|
||||
|
||||
- Optional static type checks (mypy):
|
||||
```bash
|
||||
uv run mypy cognee/
|
||||
```
|
||||
|
||||
MCP Server (`cognee-mcp/`):
|
||||
|
||||
- Install and run locally:
|
||||
```bash
|
||||
cd cognee-mcp
|
||||
uv sync --dev --all-extras --reinstall
|
||||
uv run python src/server.py # stdio (default)
|
||||
uv run python src/server.py --transport sse
|
||||
uv run python src/server.py --transport http --host 127.0.0.1 --port 8000 --path /mcp
|
||||
```
|
||||
|
||||
- API Mode (connect to a running Cognee API):
|
||||
```bash
|
||||
uv run python src/server.py --transport sse --api-url http://localhost:8000 --api-token YOUR_TOKEN
|
||||
```
|
||||
|
||||
- Docker quickstart (examples): see `cognee-mcp/README.md` for full details
|
||||
```bash
|
||||
docker run -e TRANSPORT_MODE=http --env-file ./.env -p 8000:8000 --rm -it cognee/cognee-mcp:main
|
||||
```
|
||||
|
||||
Frontend (`cognee-frontend/`):
|
||||
```bash
|
||||
cd cognee-frontend
|
||||
npm install
|
||||
npm run dev # Next.js dev server
|
||||
npm run lint # ESLint
|
||||
npm run build && npm start
|
||||
```
|
||||
|
||||
## Coding Style & Naming Conventions
|
||||
|
||||
Python:
|
||||
- 4-space indentation, modules and functions in `snake_case`, classes in `PascalCase`.
|
||||
- Public APIs should be type-annotated where practical.
|
||||
- Use `ruff format` before committing; `ruff check` enforces import hygiene and style (line-length 100 configured in `pyproject.toml`).
|
||||
- Prefer explicit, structured error handling. Use shared logging utilities in `cognee.shared.logging_utils`.
|
||||
|
||||
MCP server and Frontend:
|
||||
- Follow the local `README.md` and ESLint/TypeScript configuration in `cognee-frontend/`.
|
||||
|
||||
## Testing Guidelines
|
||||
|
||||
- Place Python tests under `cognee/tests/`.
|
||||
- Unit tests: `cognee/tests/unit/`
|
||||
- Integration tests: `cognee/tests/integration/`
|
||||
- CLI tests: `cognee/tests/cli_tests/`
|
||||
- Name test files `test_*.py`. Use `pytest.mark.asyncio` for async tests.
|
||||
- Avoid external state; rely on test fixtures and the CI-provided env vars when LLM/embedding providers are required. See CI workflows under `.github/workflows/` for expected environment variables.
|
||||
- When adding public APIs, provide/update targeted examples under `examples/python/`.
|
||||
|
||||
## Commit & Pull Request Guidelines
|
||||
|
||||
- Use clear, imperative subjects (≤ 72 chars) and conventional commit styling in PR titles. Our CI validates semantic PR titles (see `.github/workflows/pr_lint`). Examples:
|
||||
- `feat(graph): add temporal edge weighting`
|
||||
- `fix(api): handle missing auth cookie`
|
||||
- `docs: update installation instructions`
|
||||
- Reference related issues/discussions in the PR body and provide brief context.
|
||||
- PRs should describe scope, list local test commands run, and mention any impacts on MCP server or UI if applicable.
|
||||
- Sign commits and affirm the DCO (see `CONTRIBUTING.md`).
|
||||
|
||||
## CI Mirrors Local Commands
|
||||
|
||||
Our GitHub Actions run the same ruff checks and pytest suites shown above (`.github/workflows/basic_tests.yml` and related workflows). Use the commands in this document locally to minimize CI surprises.
|
||||
|
||||
|
||||
|
|
@ -97,7 +97,7 @@ Hosted platform:
|
|||
|
||||
### 📦 Installation
|
||||
|
||||
You can install Cognee using either **pip**, **poetry**, **uv** or any other python package manager.
|
||||
You can install Cognee using either **pip**, **poetry**, **uv** or any other python package manager.
|
||||
|
||||
Cognee supports Python 3.10 to 3.12
|
||||
|
||||
|
|
|
|||
|
|
@ -110,6 +110,47 @@ If you'd rather run cognee-mcp in a container, you have two options:
|
|||
# For stdio transport (default)
|
||||
docker run -e TRANSPORT_MODE=stdio --env-file ./.env --rm -it cognee/cognee-mcp:main
|
||||
```
|
||||
|
||||
**Installing optional dependencies at runtime:**
|
||||
|
||||
You can install optional dependencies when running the container by setting the `EXTRAS` environment variable:
|
||||
```bash
|
||||
# Install a single optional dependency group at runtime
|
||||
docker run \
|
||||
-e TRANSPORT_MODE=http \
|
||||
-e EXTRAS=aws \
|
||||
--env-file ./.env \
|
||||
-p 8000:8000 \
|
||||
--rm -it cognee/cognee-mcp:main
|
||||
|
||||
# Install multiple optional dependency groups at runtime (comma-separated)
|
||||
docker run \
|
||||
-e TRANSPORT_MODE=sse \
|
||||
-e EXTRAS=aws,postgres,neo4j \
|
||||
--env-file ./.env \
|
||||
-p 8000:8000 \
|
||||
--rm -it cognee/cognee-mcp:main
|
||||
```
|
||||
|
||||
**Available optional dependency groups:**
|
||||
- `aws` - S3 storage support
|
||||
- `postgres` / `postgres-binary` - PostgreSQL database support
|
||||
- `neo4j` - Neo4j graph database support
|
||||
- `neptune` - AWS Neptune support
|
||||
- `chromadb` - ChromaDB vector store support
|
||||
- `scraping` - Web scraping capabilities
|
||||
- `distributed` - Modal distributed execution
|
||||
- `langchain` - LangChain integration
|
||||
- `llama-index` - LlamaIndex integration
|
||||
- `anthropic` - Anthropic models
|
||||
- `groq` - Groq models
|
||||
- `mistral` - Mistral models
|
||||
- `ollama` / `huggingface` - Local model support
|
||||
- `docs` - Document processing
|
||||
- `codegraph` - Code analysis
|
||||
- `monitoring` - Sentry & Langfuse monitoring
|
||||
- `redis` - Redis support
|
||||
- And more (see [pyproject.toml](https://github.com/topoteretes/cognee/blob/main/pyproject.toml) for full list)
|
||||
2. **Pull from Docker Hub** (no build required):
|
||||
```bash
|
||||
# With HTTP transport (recommended for web deployments)
|
||||
|
|
@ -119,6 +160,17 @@ If you'd rather run cognee-mcp in a container, you have two options:
|
|||
# With stdio transport (default)
|
||||
docker run -e TRANSPORT_MODE=stdio --env-file ./.env --rm -it cognee/cognee-mcp:main
|
||||
```
|
||||
|
||||
**With runtime installation of optional dependencies:**
|
||||
```bash
|
||||
# Install optional dependencies from Docker Hub image
|
||||
docker run \
|
||||
-e TRANSPORT_MODE=http \
|
||||
-e EXTRAS=aws,postgres \
|
||||
--env-file ./.env \
|
||||
-p 8000:8000 \
|
||||
--rm -it cognee/cognee-mcp:main
|
||||
```
|
||||
|
||||
### **Important: Docker vs Direct Usage**
|
||||
**Docker uses environment variables**, not command line arguments:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,42 @@ set -e # Exit on error
|
|||
echo "Debug mode: $DEBUG"
|
||||
echo "Environment: $ENVIRONMENT"
|
||||
|
||||
# Install optional dependencies if EXTRAS is set.
# EXTRAS is a comma-separated list of cognee extras, e.g. "aws,postgres".
if [ -n "$EXTRAS" ]; then
    echo "Installing optional dependencies: $EXTRAS"

    # Get the cognee version that's currently installed; the extras are
    # installed against exactly this version (see the `==` pin below).
    COGNEE_VERSION=$(uv pip show cognee | grep "Version:" | awk '{print $2}')
    if [ -z "$COGNEE_VERSION" ]; then
        # Without a version the requirement would become "cognee[...]==",
        # which uv rejects with a much less helpful message.
        echo "Error: could not determine the installed cognee version; cannot install extras" >&2
        exit 1
    fi
    echo "Current cognee version: $COGNEE_VERSION"

    # Build the extras list for cognee: split on commas, trim whitespace,
    # drop empty entries and duplicates.
    IFS=',' read -ra EXTRA_ARRAY <<< "$EXTRAS"
    ALL_EXTRAS=""
    for extra in "${EXTRA_ARRAY[@]}"; do
        # Trim whitespace
        extra=$(echo "$extra" | xargs)

        # Skip empty entries (e.g. "aws, ,postgres"); they would otherwise
        # yield an invalid requirement such as "cognee[aws,,postgres]".
        if [ -z "$extra" ]; then
            continue
        fi

        # Add to extras list if not already present
        case ",$ALL_EXTRAS," in
            *",$extra,"*)
                ;;
            *)
                ALL_EXTRAS="${ALL_EXTRAS:+$ALL_EXTRAS,}$extra"
                ;;
        esac
    done

    if [ -z "$ALL_EXTRAS" ]; then
        # EXTRAS contained only separators/whitespace; nothing to install.
        echo "No valid optional dependencies found in EXTRAS; skipping installation"
    else
        echo "Installing cognee with extras: $ALL_EXTRAS"
        echo "Running: uv pip install 'cognee[$ALL_EXTRAS]==$COGNEE_VERSION'"
        uv pip install "cognee[$ALL_EXTRAS]==$COGNEE_VERSION"

        # Reached only if the install succeeded (the script runs under `set -e`).
        echo ""
        echo "✓ Optional dependencies installation completed"
    fi
else
    echo "No optional dependencies specified"
fi
|
||||
|
||||
# Set default transport mode if not specified
|
||||
TRANSPORT_MODE=${TRANSPORT_MODE:-"stdio"}
|
||||
echo "Transport mode: $TRANSPORT_MODE"
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ dependencies = [
|
|||
# For local cognee repo usage remove comment below and add absolute path to cognee. Then run `uv sync --reinstall` in the mcp folder on local cognee changes.
|
||||
#"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j] @ file:/Users/igorilic/Desktop/cognee",
|
||||
# TODO: Remove gemini from optional dependencies for new Cognee version after 0.3.4
|
||||
"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j]==0.3.4",
|
||||
"cognee[postgres,docs,neo4j]==0.3.7",
|
||||
"fastmcp>=2.10.0,<3.0.0",
|
||||
"mcp>=1.12.0,<2.0.0",
|
||||
"uv>=0.6.3,<1.0.0",
|
||||
|
|
|
|||
|
|
@ -37,12 +37,10 @@ async def run():
|
|||
|
||||
toolResult = await session.call_tool("prune", arguments={})
|
||||
|
||||
toolResult = await session.call_tool(
|
||||
"codify", arguments={"repo_path": "SOME_REPO_PATH"}
|
||||
)
|
||||
toolResult = await session.call_tool("cognify", arguments={})
|
||||
|
||||
toolResult = await session.call_tool(
|
||||
"search", arguments={"search_type": "CODE", "search_query": "exceptions"}
|
||||
"search", arguments={"search_type": "GRAPH_COMPLETION"}
|
||||
)
|
||||
|
||||
print(f"Cognify result: {toolResult.content}")
|
||||
|
|
|
|||
8575
cognee-mcp/uv.lock
generated
8575
cognee-mcp/uv.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,8 +1,5 @@
|
|||
from uuid import UUID
|
||||
import os
|
||||
from typing import Union, BinaryIO, List, Optional, Dict, Any
|
||||
from pydantic import BaseModel
|
||||
from urllib.parse import urlparse
|
||||
from typing import Union, BinaryIO, List, Optional, Any
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.pipelines import Task, run_pipeline
|
||||
from cognee.modules.pipelines.layers.resolve_authorized_user_dataset import (
|
||||
|
|
@ -17,16 +14,6 @@ from cognee.shared.logging_utils import get_logger
|
|||
|
||||
logger = get_logger()
|
||||
|
||||
try:
|
||||
from cognee.tasks.web_scraper.config import TavilyConfig, SoupCrawlerConfig
|
||||
from cognee.context_global_variables import (
|
||||
tavily_config as tavily,
|
||||
soup_crawler_config as soup_crawler,
|
||||
)
|
||||
except ImportError:
|
||||
logger.debug(f"Unable to import {str(ImportError)}")
|
||||
pass
|
||||
|
||||
|
||||
async def add(
|
||||
data: Union[BinaryIO, list[BinaryIO], str, list[str]],
|
||||
|
|
@ -36,11 +23,8 @@ async def add(
|
|||
vector_db_config: dict = None,
|
||||
graph_db_config: dict = None,
|
||||
dataset_id: Optional[UUID] = None,
|
||||
preferred_loaders: List[str] = None,
|
||||
preferred_loaders: Optional[List[Union[str, dict[str, dict[str, Any]]]]] = None,
|
||||
incremental_loading: bool = True,
|
||||
extraction_rules: Optional[Dict[str, Any]] = None,
|
||||
tavily_config: Optional[BaseModel] = None,
|
||||
soup_crawler_config: Optional[BaseModel] = None,
|
||||
data_per_batch: Optional[int] = 20,
|
||||
):
|
||||
"""
|
||||
|
|
@ -180,28 +164,14 @@ async def add(
|
|||
- TAVILY_API_KEY: YOUR_TAVILY_API_KEY
|
||||
|
||||
"""
|
||||
|
||||
try:
|
||||
if not soup_crawler_config and extraction_rules:
|
||||
soup_crawler_config = SoupCrawlerConfig(extraction_rules=extraction_rules)
|
||||
if not tavily_config and os.getenv("TAVILY_API_KEY"):
|
||||
tavily_config = TavilyConfig(api_key=os.getenv("TAVILY_API_KEY"))
|
||||
|
||||
soup_crawler.set(soup_crawler_config)
|
||||
tavily.set(tavily_config)
|
||||
|
||||
http_schemes = {"http", "https"}
|
||||
|
||||
def _is_http_url(item: Union[str, BinaryIO]) -> bool:
|
||||
return isinstance(item, str) and urlparse(item).scheme in http_schemes
|
||||
|
||||
if _is_http_url(data):
|
||||
node_set = ["web_content"] if not node_set else node_set + ["web_content"]
|
||||
elif isinstance(data, list) and any(_is_http_url(item) for item in data):
|
||||
node_set = ["web_content"] if not node_set else node_set + ["web_content"]
|
||||
except NameError:
|
||||
logger.debug(f"Unable to import {str(ImportError)}")
|
||||
pass
|
||||
if preferred_loaders is not None:
|
||||
transformed = {}
|
||||
for item in preferred_loaders:
|
||||
if isinstance(item, dict):
|
||||
transformed.update(item)
|
||||
else:
|
||||
transformed[item] = {}
|
||||
preferred_loaders = transformed
|
||||
|
||||
tasks = [
|
||||
Task(resolve_data_directories, include_subdirectories=True),
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from cognee.modules.users.methods import get_authenticated_user
|
|||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.pipelines.models import PipelineRunErrored
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee import __version__ as cognee_version
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -63,7 +64,11 @@ def get_add_router() -> APIRouter:
|
|||
send_telemetry(
|
||||
"Add API Endpoint Invoked",
|
||||
user.id,
|
||||
additional_properties={"endpoint": "POST /v1/add", "node_set": node_set},
|
||||
additional_properties={
|
||||
"endpoint": "POST /v1/add",
|
||||
"node_set": node_set,
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
from cognee.api.v1.add import add as cognee_add
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ from cognee.modules.pipelines.queues.pipeline_run_info_queues import (
|
|||
)
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.shared.utils import send_telemetry
|
||||
|
||||
from cognee import __version__ as cognee_version
|
||||
|
||||
logger = get_logger("api.cognify")
|
||||
|
||||
|
|
@ -98,6 +98,7 @@ def get_cognify_router() -> APIRouter:
|
|||
user.id,
|
||||
additional_properties={
|
||||
"endpoint": "POST /v1/cognify",
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from uuid import UUID
|
||||
from cognee.modules.data.methods import has_dataset_data
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.ingestion import discover_directory_datasets
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||
|
|
@ -26,6 +27,16 @@ class datasets:
|
|||
|
||||
return await get_dataset_data(dataset.id)
|
||||
|
||||
@staticmethod
async def has_data(dataset_id: str) -> bool:
    """
    Report whether the given dataset of the default user contains any data.

    The dataset is looked up via `get_dataset` using the default user's id and
    the supplied `dataset_id`; the answer is whatever `has_dataset_data`
    returns for the resolved dataset's id.
    """
    # Imported lazily, exactly as in the original implementation.
    from cognee.modules.data.methods import get_dataset

    default_user = await get_default_user()
    resolved_dataset = await get_dataset(default_user.id, dataset_id)
    dataset_has_data = await has_dataset_data(resolved_dataset.id)
    return dataset_has_data
|
||||
|
||||
@staticmethod
|
||||
async def get_status(dataset_ids: list[UUID]) -> dict:
|
||||
return await get_pipeline_status(dataset_ids, pipeline_name="cognify_pipeline")
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from cognee.modules.users.permissions.methods import (
|
|||
from cognee.modules.graph.methods import get_formatted_graph_data
|
||||
from cognee.modules.pipelines.models import PipelineRunStatus
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee import __version__ as cognee_version
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -100,6 +101,7 @@ def get_datasets_router() -> APIRouter:
|
|||
user.id,
|
||||
additional_properties={
|
||||
"endpoint": "GET /v1/datasets",
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -147,6 +149,7 @@ def get_datasets_router() -> APIRouter:
|
|||
user.id,
|
||||
additional_properties={
|
||||
"endpoint": "POST /v1/datasets",
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -201,6 +204,7 @@ def get_datasets_router() -> APIRouter:
|
|||
additional_properties={
|
||||
"endpoint": f"DELETE /v1/datasets/{str(dataset_id)}",
|
||||
"dataset_id": str(dataset_id),
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -246,6 +250,7 @@ def get_datasets_router() -> APIRouter:
|
|||
"endpoint": f"DELETE /v1/datasets/{str(dataset_id)}/data/{str(data_id)}",
|
||||
"dataset_id": str(dataset_id),
|
||||
"data_id": str(data_id),
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -327,6 +332,7 @@ def get_datasets_router() -> APIRouter:
|
|||
additional_properties={
|
||||
"endpoint": f"GET /v1/datasets/{str(dataset_id)}/data",
|
||||
"dataset_id": str(dataset_id),
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -387,6 +393,7 @@ def get_datasets_router() -> APIRouter:
|
|||
additional_properties={
|
||||
"endpoint": "GET /v1/datasets/status",
|
||||
"datasets": [str(dataset_id) for dataset_id in datasets],
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -433,6 +440,7 @@ def get_datasets_router() -> APIRouter:
|
|||
"endpoint": f"GET /v1/datasets/{str(dataset_id)}/data/{str(data_id)}/raw",
|
||||
"dataset_id": str(dataset_id),
|
||||
"data_id": str(data_id),
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from cognee.shared.logging_utils import get_logger
|
|||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee import __version__ as cognee_version
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -39,6 +40,7 @@ def get_delete_router() -> APIRouter:
|
|||
"endpoint": "DELETE /v1/delete",
|
||||
"dataset_id": str(dataset_id),
|
||||
"data_id": str(data_id),
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from cognee.modules.users.methods import get_authenticated_user
|
|||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.pipelines.models import PipelineRunErrored
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee import __version__ as cognee_version
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -73,7 +74,7 @@ def get_memify_router() -> APIRouter:
|
|||
send_telemetry(
|
||||
"Memify API Endpoint Invoked",
|
||||
user.id,
|
||||
additional_properties={"endpoint": "POST /v1/memify"},
|
||||
additional_properties={"endpoint": "POST /v1/memify", "cognee_version": cognee_version},
|
||||
)
|
||||
|
||||
if not payload.dataset_id and not payload.dataset_name:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from fastapi.responses import JSONResponse
|
|||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee import __version__ as cognee_version
|
||||
|
||||
|
||||
def get_permissions_router() -> APIRouter:
|
||||
|
|
@ -48,6 +49,7 @@ def get_permissions_router() -> APIRouter:
|
|||
"endpoint": f"POST /v1/permissions/datasets/{str(principal_id)}",
|
||||
"dataset_ids": str(dataset_ids),
|
||||
"principal_id": str(principal_id),
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -89,6 +91,7 @@ def get_permissions_router() -> APIRouter:
|
|||
additional_properties={
|
||||
"endpoint": "POST /v1/permissions/roles",
|
||||
"role_name": role_name,
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -133,6 +136,7 @@ def get_permissions_router() -> APIRouter:
|
|||
"endpoint": f"POST /v1/permissions/users/{str(user_id)}/roles",
|
||||
"user_id": str(user_id),
|
||||
"role_id": str(role_id),
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -175,6 +179,7 @@ def get_permissions_router() -> APIRouter:
|
|||
"endpoint": f"POST /v1/permissions/users/{str(user_id)}/tenants",
|
||||
"user_id": str(user_id),
|
||||
"tenant_id": str(tenant_id),
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -209,6 +214,7 @@ def get_permissions_router() -> APIRouter:
|
|||
additional_properties={
|
||||
"endpoint": "POST /v1/permissions/tenants",
|
||||
"tenant_name": tenant_name,
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from cognee.modules.users.models import User
|
|||
from cognee.modules.search.operations import get_history
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee import __version__ as cognee_version
|
||||
|
||||
|
||||
# Note: Datasets sent by name will only map to datasets owned by the request sender
|
||||
|
|
@ -61,9 +62,7 @@ def get_search_router() -> APIRouter:
|
|||
send_telemetry(
|
||||
"Search API Endpoint Invoked",
|
||||
user.id,
|
||||
additional_properties={
|
||||
"endpoint": "GET /v1/search",
|
||||
},
|
||||
additional_properties={"endpoint": "GET /v1/search", "cognee_version": cognee_version},
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -118,6 +117,7 @@ def get_search_router() -> APIRouter:
|
|||
"top_k": payload.top_k,
|
||||
"only_context": payload.only_context,
|
||||
"use_combined_context": payload.use_combined_context,
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from uuid import UUID
|
||||
from typing import Union, Optional, List, Type
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.search.types import SearchResult, SearchType, CombinedSearchResult
|
||||
|
|
@ -8,6 +9,10 @@ from cognee.modules.users.methods import get_default_user
|
|||
from cognee.modules.search.methods import search as search_function
|
||||
from cognee.modules.data.methods import get_authorized_existing_datasets
|
||||
from cognee.modules.data.exceptions import DatasetNotFoundError
|
||||
from cognee.context_global_variables import set_session_user_context_variable
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def search(
|
||||
|
|
@ -25,6 +30,7 @@ async def search(
|
|||
last_k: Optional[int] = 1,
|
||||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Union[List[SearchResult], CombinedSearchResult]:
|
||||
"""
|
||||
Search and query the knowledge graph for insights, information, and connections.
|
||||
|
|
@ -113,6 +119,8 @@ async def search(
|
|||
|
||||
save_interaction: Save interaction (query, context, answer connected to triplet endpoints) results into the graph or not
|
||||
|
||||
session_id: Optional session identifier for caching Q&A interactions. Defaults to 'default_session' if None.
|
||||
|
||||
Returns:
|
||||
list: Search results in format determined by query_type:
|
||||
|
||||
|
|
@ -168,6 +176,8 @@ async def search(
|
|||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
await set_session_user_context_variable(user)
|
||||
|
||||
# Transform string based datasets to UUID - String based datasets can only be found for current user
|
||||
if datasets is not None and [all(isinstance(dataset, str) for dataset in datasets)]:
|
||||
datasets = await get_authorized_existing_datasets(datasets, "read", user)
|
||||
|
|
@ -189,6 +199,7 @@ async def search(
|
|||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
use_combined_context=use_combined_context,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return filtered_search_results
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from cognee.modules.sync.methods import get_running_sync_operations_for_user, ge
|
|||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.api.v1.sync import SyncResponse
|
||||
from cognee import __version__ as cognee_version
|
||||
from cognee.context_global_variables import set_database_global_context_variables
|
||||
|
||||
logger = get_logger()
|
||||
|
|
@ -99,6 +100,7 @@ def get_sync_router() -> APIRouter:
|
|||
user.id,
|
||||
additional_properties={
|
||||
"endpoint": "POST /v1/sync",
|
||||
"cognee_version": cognee_version,
|
||||
"dataset_ids": [str(id) for id in request.dataset_ids]
|
||||
if request.dataset_ids
|
||||
else "*",
|
||||
|
|
@ -205,6 +207,7 @@ def get_sync_router() -> APIRouter:
|
|||
user.id,
|
||||
additional_properties={
|
||||
"endpoint": "GET /v1/sync/status",
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -503,7 +503,7 @@ def start_ui(
|
|||
if start_mcp:
|
||||
logger.info("Starting Cognee MCP server with Docker...")
|
||||
try:
|
||||
image = "cognee/cognee-mcp:feature-standalone-mcp" # TODO: change to "cognee/cognee-mcp:main" right before merging into main
|
||||
image = "cognee/cognee-mcp:main"
|
||||
subprocess.run(["docker", "pull", image], check=True)
|
||||
|
||||
import uuid
|
||||
|
|
@ -538,9 +538,7 @@ def start_ui(
|
|||
env_file = os.path.join(cwd, ".env")
|
||||
docker_cmd.extend(["--env-file", env_file])
|
||||
|
||||
docker_cmd.append(
|
||||
image
|
||||
) # TODO: change to "cognee/cognee-mcp:main" right before merging into main
|
||||
docker_cmd.append(image)
|
||||
|
||||
mcp_process = subprocess.Popen(
|
||||
docker_cmd,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from cognee.shared.logging_utils import get_logger
|
|||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee import __version__ as cognee_version
|
||||
from cognee.modules.pipelines.models.PipelineRunInfo import (
|
||||
PipelineRunErrored,
|
||||
)
|
||||
|
|
@ -64,6 +65,7 @@ def get_update_router() -> APIRouter:
|
|||
"dataset_id": str(dataset_id),
|
||||
"data_id": str(data_id),
|
||||
"node_set": str(node_set),
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from uuid import UUID
|
||||
from typing import Union, BinaryIO, List, Optional
|
||||
from typing import Union, BinaryIO, List, Optional, Any
|
||||
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.api.v1.delete import delete
|
||||
|
|
@ -15,7 +15,7 @@ async def update(
|
|||
node_set: Optional[List[str]] = None,
|
||||
vector_db_config: dict = None,
|
||||
graph_db_config: dict = None,
|
||||
preferred_loaders: List[str] = None,
|
||||
preferred_loaders: dict[str, dict[str, Any]] = None,
|
||||
incremental_loading: bool = True,
|
||||
):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from cognee.modules.users.models import User
|
|||
|
||||
from cognee.context_global_variables import set_database_global_context_variables
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee import __version__ as cognee_version
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -46,6 +47,7 @@ def get_visualize_router() -> APIRouter:
|
|||
additional_properties={
|
||||
"endpoint": "GET /v1/visualize",
|
||||
"dataset_id": str(dataset_id),
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from functools import lru_cache
|
||||
from cognee.root_dir import get_absolute_path, ensure_absolute_path
|
||||
|
|
@ -11,6 +12,9 @@ class BaseConfig(BaseSettings):
|
|||
data_root_directory: str = get_absolute_path(".data_storage")
|
||||
system_root_directory: str = get_absolute_path(".cognee_system")
|
||||
cache_root_directory: str = get_absolute_path(".cognee_cache")
|
||||
logs_root_directory: str = os.getenv(
|
||||
"COGNEE_LOGS_DIR", str(os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs"))
|
||||
)
|
||||
monitoring_tool: object = Observer.NONE
|
||||
|
||||
@pydantic.model_validator(mode="after")
|
||||
|
|
@ -30,6 +34,8 @@ class BaseConfig(BaseSettings):
|
|||
# Require absolute paths for root directories
|
||||
self.data_root_directory = ensure_absolute_path(self.data_root_directory)
|
||||
self.system_root_directory = ensure_absolute_path(self.system_root_directory)
|
||||
self.logs_root_directory = ensure_absolute_path(self.logs_root_directory)
|
||||
|
||||
# Set monitoring tool based on available keys
|
||||
if self.langfuse_public_key and self.langfuse_secret_key:
|
||||
self.monitoring_tool = Observer.LANGFUSE
|
||||
|
|
@ -49,6 +55,7 @@ class BaseConfig(BaseSettings):
|
|||
"system_root_directory": self.system_root_directory,
|
||||
"monitoring_tool": self.monitoring_tool,
|
||||
"cache_root_directory": self.cache_root_directory,
|
||||
"logs_root_directory": self.logs_root_directory,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,8 +12,11 @@ from cognee.modules.users.methods import get_user
|
|||
# for different async tasks, threads and processes
|
||||
vector_db_config = ContextVar("vector_db_config", default=None)
|
||||
graph_db_config = ContextVar("graph_db_config", default=None)
|
||||
soup_crawler_config = ContextVar("soup_crawler_config", default=None)
|
||||
tavily_config = ContextVar("tavily_config", default=None)
|
||||
session_user = ContextVar("session_user", default=None)
|
||||
|
||||
|
||||
async def set_session_user_context_variable(user):
    """Store `user` in the `session_user` context variable."""
    session_user.set(user)
|
||||
|
||||
|
||||
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ class CogneeValidationError(CogneeApiError):
|
|||
self,
|
||||
message: str = "A validation error occurred.",
|
||||
name: str = "CogneeValidationError",
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||
log=True,
|
||||
log_level="ERROR",
|
||||
):
|
||||
|
|
|
|||
|
|
@ -40,3 +40,40 @@ class CacheDBInterface(ABC):
|
|||
yield
|
||||
finally:
|
||||
self.release()
|
||||
|
||||
@abstractmethod
|
||||
async def add_qa(
|
||||
self,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
question: str,
|
||||
context: str,
|
||||
answer: str,
|
||||
ttl: int | None = 86400,
|
||||
):
|
||||
"""
|
||||
Add a Q/A/context triplet to a cache session.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
@abstractmethod
async def get_latest_qa(self, user_id: str, session_id: str, last_n: int = 5):
    """
    Retrieve the most recent Q/A/context triplets for a session.

    Parameters:
    -----------
    - user_id: Identifier of the user owning the session.
    - session_id: Identifier of the session to read.
    - last_n: Number of most recent entries requested (default 5).
    """
    # Abstract hook: concrete adapters provide the implementation.
    return None
|
||||
|
||||
@abstractmethod
async def get_all_qas(self, user_id: str, session_id: str):
    """
    Retrieve all Q/A/context triplets for the given session.

    Parameters:
    -----------
    - user_id: Identifier of the user owning the session.
    - session_id: Identifier of the session to read.
    """
    # Abstract hook: concrete adapters provide the implementation.
    return None
|
||||
|
||||
@abstractmethod
async def close(self):
    """
    Gracefully close any async connections.
    """
    # Abstract hook: concrete adapters provide the implementation.
    return None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class CacheConfig(BaseSettings):
|
||||
|
|
@ -18,6 +19,8 @@ class CacheConfig(BaseSettings):
|
|||
shared_kuzu_lock: bool = False
|
||||
cache_host: str = "localhost"
|
||||
cache_port: int = 6379
|
||||
cache_username: Optional[str] = None
|
||||
cache_password: Optional[str] = None
|
||||
agentic_lock_expire: int = 240
|
||||
agentic_lock_timeout: int = 300
|
||||
|
||||
|
|
@ -29,6 +32,8 @@ class CacheConfig(BaseSettings):
|
|||
"shared_kuzu_lock": self.shared_kuzu_lock,
|
||||
"cache_host": self.cache_host,
|
||||
"cache_port": self.cache_port,
|
||||
"cache_username": self.cache_username,
|
||||
"cache_password": self.cache_password,
|
||||
"agentic_lock_expire": self.agentic_lock_expire,
|
||||
"agentic_lock_timeout": self.agentic_lock_timeout,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
"""Factory to get the appropriate cache coordination engine (e.g., Redis)."""
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
from cognee.infrastructure.databases.cache.config import get_cache_config
|
||||
|
||||
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
|
||||
|
||||
config = get_cache_config()
|
||||
|
|
@ -12,6 +12,8 @@ config = get_cache_config()
|
|||
def create_cache_engine(
|
||||
cache_host: str,
|
||||
cache_port: int,
|
||||
cache_username: str,
|
||||
cache_password: str,
|
||||
lock_key: str,
|
||||
agentic_lock_expire: int = 240,
|
||||
agentic_lock_timeout: int = 300,
|
||||
|
|
@ -23,6 +25,8 @@ def create_cache_engine(
|
|||
-----------
|
||||
- cache_host: Hostname or IP of the cache server.
|
||||
- cache_port: Port number to connect to.
|
||||
- cache_username: Username to authenticate with.
|
||||
- cache_password: Password to authenticate with.
|
||||
- lock_key: Identifier used for the locking resource.
|
||||
- agentic_lock_expire: Duration to hold the lock after acquisition.
|
||||
- agentic_lock_timeout: Max time to wait for the lock before failing.
|
||||
|
|
@ -37,6 +41,8 @@ def create_cache_engine(
|
|||
return RedisAdapter(
|
||||
host=cache_host,
|
||||
port=cache_port,
|
||||
username=cache_username,
|
||||
password=cache_password,
|
||||
lock_name=lock_key,
|
||||
timeout=agentic_lock_expire,
|
||||
blocking_timeout=agentic_lock_timeout,
|
||||
|
|
@ -45,7 +51,7 @@ def create_cache_engine(
|
|||
return None
|
||||
|
||||
|
||||
def get_cache_engine(lock_key: str) -> CacheDBInterface:
|
||||
def get_cache_engine(lock_key: Optional[str] = None) -> CacheDBInterface:
|
||||
"""
|
||||
Returns a cache adapter instance using current context configuration.
|
||||
"""
|
||||
|
|
@ -53,6 +59,8 @@ def get_cache_engine(lock_key: str) -> CacheDBInterface:
|
|||
return create_cache_engine(
|
||||
cache_host=config.cache_host,
|
||||
cache_port=config.cache_port,
|
||||
cache_username=config.cache_username,
|
||||
cache_password=config.cache_password,
|
||||
lock_key=lock_key,
|
||||
agentic_lock_expire=config.agentic_lock_expire,
|
||||
agentic_lock_timeout=config.agentic_lock_timeout,
|
||||
|
|
|
|||
|
|
@ -1,20 +1,81 @@
|
|||
import asyncio
|
||||
import redis
|
||||
import redis.asyncio as aioredis
|
||||
from contextlib import contextmanager
|
||||
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
|
||||
from cognee.infrastructure.databases.exceptions import CacheConnectionError
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
logger = get_logger("RedisAdapter")
|
||||
|
||||
|
||||
class RedisAdapter(CacheDBInterface):
|
||||
def __init__(self, host, port, lock_name, timeout=240, blocking_timeout=300):
|
||||
def __init__(
|
||||
self,
|
||||
host,
|
||||
port,
|
||||
lock_name="default_lock",
|
||||
username=None,
|
||||
password=None,
|
||||
timeout=240,
|
||||
blocking_timeout=300,
|
||||
connection_timeout=30,
|
||||
):
|
||||
super().__init__(host, port, lock_name)
|
||||
self.redis = redis.Redis(host=host, port=port)
|
||||
self.timeout = timeout
|
||||
self.blocking_timeout = blocking_timeout
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.connection_timeout = connection_timeout
|
||||
|
||||
try:
|
||||
self.sync_redis = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
username=username,
|
||||
password=password,
|
||||
socket_connect_timeout=connection_timeout,
|
||||
socket_timeout=connection_timeout,
|
||||
)
|
||||
self.async_redis = aioredis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
username=username,
|
||||
password=password,
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=connection_timeout,
|
||||
)
|
||||
self.timeout = timeout
|
||||
self.blocking_timeout = blocking_timeout
|
||||
|
||||
# Validate connection on initialization
|
||||
self._validate_connection()
|
||||
logger.info(f"Successfully connected to Redis at {host}:{port}")
|
||||
|
||||
except (redis.ConnectionError, redis.TimeoutError) as e:
|
||||
error_msg = f"Failed to connect to Redis at {host}:{port}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise CacheConnectionError(error_msg) from e
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error initializing Redis adapter: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise CacheConnectionError(error_msg) from e
|
||||
|
||||
def _validate_connection(self):
|
||||
"""Validate Redis connection is available."""
|
||||
try:
|
||||
self.sync_redis.ping()
|
||||
except (redis.ConnectionError, redis.TimeoutError) as e:
|
||||
raise CacheConnectionError(
|
||||
f"Cannot connect to Redis at {self.host}:{self.port}: {str(e)}"
|
||||
) from e
|
||||
|
||||
def acquire_lock(self):
|
||||
"""
|
||||
Acquire the Redis lock manually. Raises if acquisition fails.
|
||||
Acquire the Redis lock manually. Raises if acquisition fails. (Sync because of Kuzu)
|
||||
"""
|
||||
self.lock = self.redis.lock(
|
||||
self.lock = self.sync_redis.lock(
|
||||
name=self.lock_key,
|
||||
timeout=self.timeout,
|
||||
blocking_timeout=self.blocking_timeout,
|
||||
|
|
@ -28,7 +89,7 @@ class RedisAdapter(CacheDBInterface):
|
|||
|
||||
def release_lock(self):
|
||||
"""
|
||||
Release the Redis lock manually, if held.
|
||||
Release the Redis lock manually, if held. (Sync because of Kuzu)
|
||||
"""
|
||||
if self.lock:
|
||||
try:
|
||||
|
|
@ -40,10 +101,143 @@ class RedisAdapter(CacheDBInterface):
|
|||
@contextmanager
|
||||
def hold_lock(self):
|
||||
"""
|
||||
Context manager for acquiring and releasing the Redis lock automatically.
|
||||
Context manager for acquiring and releasing the Redis lock automatically. (Sync because of Kuzu)
|
||||
"""
|
||||
self.acquire()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.release()
|
||||
|
||||
async def add_qa(
|
||||
self,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
question: str,
|
||||
context: str,
|
||||
answer: str,
|
||||
ttl: int | None = 86400,
|
||||
):
|
||||
"""
|
||||
Add a Q/A/context triplet to a Redis list for this session.
|
||||
Creates the session if it doesn't exist.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID.
|
||||
session_id: Unique identifier for the session.
|
||||
question: User question text.
|
||||
context: Context used to answer.
|
||||
answer: Assistant answer text.
|
||||
ttl: Optional time-to-live (seconds). If provided, the session expires after this time.
|
||||
|
||||
Raises:
|
||||
CacheConnectionError: If Redis connection fails or times out.
|
||||
"""
|
||||
try:
|
||||
session_key = f"agent_sessions:{user_id}:{session_id}"
|
||||
|
||||
qa_entry = {
|
||||
"time": datetime.utcnow().isoformat(),
|
||||
"question": question,
|
||||
"context": context,
|
||||
"answer": answer,
|
||||
}
|
||||
|
||||
await self.async_redis.rpush(session_key, json.dumps(qa_entry))
|
||||
|
||||
if ttl is not None:
|
||||
await self.async_redis.expire(session_key, ttl)
|
||||
|
||||
except (redis.ConnectionError, redis.TimeoutError) as e:
|
||||
error_msg = f"Redis connection error while adding Q&A: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise CacheConnectionError(error_msg) from e
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error while adding Q&A to Redis: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise CacheConnectionError(error_msg) from e
|
||||
|
||||
async def get_latest_qa(self, user_id: str, session_id: str, last_n: int = 5):
    """
    Retrieve the most recent Q/A/context triplet(s) for the given session.

    Args:
        user_id: The user ID the session belongs to.
        session_id: Unique identifier for the session.
        last_n: How many of the newest entries to return (default 5).

    Returns:
        A list of decoded entries, oldest first. An empty list is returned
        when `last_n` is not positive or the session has no entries, except
        that `last_n == 1` on an empty session returns None (kept for
        backward compatibility with existing callers).
    """
    # Guard: without it, last_n == 0 becomes LRANGE key 0 -1 (the whole
    # session) and a negative last_n skips leading entries instead.
    if last_n <= 0:
        return []

    session_key = f"agent_sessions:{user_id}:{session_id}"
    if last_n == 1:
        data = await self.async_redis.lindex(session_key, -1)
        return [json.loads(data)] if data else None

    data = await self.async_redis.lrange(session_key, -last_n, -1)
    return [json.loads(d) for d in data] if data else []
|
||||
|
||||
async def get_all_qas(self, user_id: str, session_id: str):
    """
    Return every stored Q/A/context entry of the session, JSON-decoded,
    in the order the list holds them.
    """
    raw_entries = await self.async_redis.lrange(
        f"agent_sessions:{user_id}:{session_id}", 0, -1
    )
    decoded_entries = []
    for raw_entry in raw_entries:
        decoded_entries.append(json.loads(raw_entry))
    return decoded_entries
|
||||
|
||||
async def close(self):
    """
    Close the async Redis client by awaiting its `aclose()`.
    """
    async_client = self.async_redis
    await async_client.aclose()
|
||||
|
||||
|
||||
async def main():
    """
    Manual demo: connect to a Redis at localhost:6379, write a few sample
    Q/A entries into two sessions and print what is read back.

    The adapter is closed in a `finally` block so the async connection is
    released even if one of the demo calls raises.
    """
    HOST = "localhost"
    PORT = 6379

    adapter = RedisAdapter(host=HOST, port=PORT)
    try:
        session_id = "demo_session"
        user_id = "demo_user_id"

        print("\nAdding sample Q/A pairs...")
        await adapter.add_qa(
            user_id,
            session_id,
            "What is Redis?",
            "Basic DB context",
            "Redis is an in-memory data store.",
        )
        await adapter.add_qa(
            user_id,
            session_id,
            "Who created Redis?",
            "Historical context",
            "Salvatore Sanfilippo (antirez).",
        )

        print("\nLatest QA:")
        latest = await adapter.get_latest_qa(user_id, session_id)
        print(json.dumps(latest, indent=2))

        print("\nLast 2 QAs:")
        last_two = await adapter.get_latest_qa(user_id, session_id, last_n=2)
        print(json.dumps(last_two, indent=2))

        # Second session, written with the default ttl.
        session_id = "session_expire_demo"

        await adapter.add_qa(
            user_id,
            session_id,
            "What is Redis?",
            "Database context",
            "Redis is an in-memory data store.",
        )

        await adapter.add_qa(
            user_id,
            session_id,
            "Who created Redis?",
            "History context",
            "Salvatore Sanfilippo (antirez).",
        )

        print(await adapter.get_all_qas(user_id, session_id))
    finally:
        # Always release the async connection, including on failure.
        await adapter.close()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -11,4 +11,5 @@ from .exceptions import (
|
|||
EmbeddingException,
|
||||
MissingQueryParameterError,
|
||||
MutuallyExclusiveQueryParametersError,
|
||||
CacheConnectionError,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ class DatabaseNotCreatedError(CogneeSystemError):
|
|||
self,
|
||||
message: str = "The database has not been created yet. Please call `await setup()` first.",
|
||||
name: str = "DatabaseNotCreatedError",
|
||||
status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status_code: int = status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
|
@ -99,7 +99,7 @@ class EmbeddingException(CogneeConfigurationError):
|
|||
self,
|
||||
message: str = "Embedding Exception.",
|
||||
name: str = "EmbeddingException",
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
|
@ -132,3 +132,19 @@ class MutuallyExclusiveQueryParametersError(CogneeValidationError):
|
|||
):
|
||||
message = "The search function accepts either text or embedding as input, but not both."
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class CacheConnectionError(CogneeConfigurationError):
    """
    Signals that a connection to the cache database (e.g., Redis) could not be made.

    Raising it means the cache service is unavailable or misconfigured. The
    default status code is HTTP 503 (Service Unavailable).
    """

    def __init__(
        self,
        message: str = "Failed to connect to cache database. Please check your cache configuration.",
        name: str = "CacheConnectionError",
        status_code: int = status.HTTP_503_SERVICE_UNAVAILABLE,
    ):
        # Forward all three values positionally to the base exception.
        super().__init__(message, name, status_code)
|
||||
|
|
|
|||
|
|
@ -159,6 +159,11 @@ class GraphDBInterface(ABC):
|
|||
- get_connections
|
||||
"""
|
||||
|
||||
@abstractmethod
async def is_empty(self) -> bool:
    """
    Report whether the graph holds no data.

    Declared abstract, so concrete adapters must override it. This base body
    only logs a warning and answers ``True``.
    """
    logger.warning("is_empty() is not implemented")
    return True
|
||||
|
||||
@abstractmethod
|
||||
async def query(self, query: str, params: dict) -> List[Any]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -198,6 +198,15 @@ class KuzuAdapter(GraphDBInterface):
|
|||
except FileNotFoundError:
|
||||
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
|
||||
|
||||
async def is_empty(self) -> bool:
    """
    Check whether the graph contains no nodes.

    Runs a query that matches at most one node and returns True when that
    query yields no rows.
    """
    probe = """
    MATCH (n)
    RETURN true
    LIMIT 1;
    """
    rows = await self.query(probe)
    has_any_node = len(rows) != 0
    return not has_any_node
|
||||
|
||||
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
|
||||
"""
|
||||
Execute a Kuzu query asynchronously with automatic reconnection.
|
||||
|
|
@ -1357,9 +1366,15 @@ class KuzuAdapter(GraphDBInterface):
|
|||
params[param_name] = values
|
||||
|
||||
where_clause = " AND ".join(where_clauses)
|
||||
nodes_query = (
|
||||
f"MATCH (n:Node) WHERE {where_clause} RETURN n.id, {{properties: n.properties}}"
|
||||
)
|
||||
nodes_query = f"""
|
||||
MATCH (n:Node)
|
||||
WHERE {where_clause}
|
||||
RETURN n.id, {{
|
||||
name: n.name,
|
||||
type: n.type,
|
||||
properties: n.properties
|
||||
}}
|
||||
"""
|
||||
edges_query = f"""
|
||||
MATCH (n1:Node)-[r:EDGE]->(n2:Node)
|
||||
WHERE {where_clause.replace("n.", "n1.")} AND {where_clause.replace("n.", "n2.")}
|
||||
|
|
|
|||
|
|
@ -87,6 +87,15 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
async with self.driver.session(database=self.graph_database_name) as session:
|
||||
yield session
|
||||
|
||||
async def is_empty(self) -> bool:
    """
    Check whether the graph contains no nodes.

    Asks the database for a single ``node_exists`` flag (an EXISTS subquery
    over ``MATCH (n)``) and returns its negation.
    """
    existence_probe = """
    RETURN EXISTS {
        MATCH (n)
    } AS node_exists;
    """
    result_rows = await self.query(existence_probe)
    first_row = result_rows[0]
    return not first_row["node_exists"]
|
||||
|
||||
@deadlock_retry()
|
||||
async def query(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ def create_vector_engine(
|
|||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
if vector_db_provider == "pgvector":
|
||||
if vector_db_provider.lower() == "pgvector":
|
||||
from cognee.infrastructure.databases.relational import get_relational_config
|
||||
|
||||
# Get configuration for postgres database
|
||||
|
|
@ -78,7 +78,7 @@ def create_vector_engine(
|
|||
embedding_engine,
|
||||
)
|
||||
|
||||
elif vector_db_provider == "chromadb":
|
||||
elif vector_db_provider.lower() == "chromadb":
|
||||
try:
|
||||
import chromadb
|
||||
except ImportError:
|
||||
|
|
@ -94,7 +94,7 @@ def create_vector_engine(
|
|||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
elif vector_db_provider == "neptune_analytics":
|
||||
elif vector_db_provider.lower() == "neptune_analytics":
|
||||
try:
|
||||
from langchain_aws import NeptuneAnalyticsGraph
|
||||
except ImportError:
|
||||
|
|
@ -122,7 +122,7 @@ def create_vector_engine(
|
|||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
else:
|
||||
elif vector_db_provider.lower() == "lancedb":
|
||||
from .lancedb.LanceDBAdapter import LanceDBAdapter
|
||||
|
||||
return LanceDBAdapter(
|
||||
|
|
@ -130,3 +130,9 @@ def create_vector_engine(
|
|||
api_key=vector_db_key,
|
||||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Unsupported vector database provider: {vector_db_provider}. "
|
||||
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['LanceDB', 'PGVector', 'neptune_analytics', 'ChromaDB'])}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
|||
self.endpoint, json=payload, headers=headers, timeout=60.0
|
||||
) as response:
|
||||
data = await response.json()
|
||||
return data["embedding"]
|
||||
return data["embeddings"][0]
|
||||
|
||||
def get_vector_size(self) -> int:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ class CollectionNotFoundError(CogneeValidationError):
|
|||
self,
|
||||
message,
|
||||
name: str = "CollectionNotFoundError",
|
||||
status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status_code: int = status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||
log=True,
|
||||
log_level="DEBUG",
|
||||
):
|
||||
|
|
|
|||
|
|
@ -324,7 +324,6 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
|
||||
def get_data_point_schema(self, model_type: BaseModel):
|
||||
related_models_fields = []
|
||||
|
||||
for field_name, field_config in model_type.model_fields.items():
|
||||
if hasattr(field_config, "model_fields"):
|
||||
related_models_fields.append(field_name)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, field_validator
|
||||
from typing import Optional, Any, Dict
|
||||
|
||||
|
||||
|
|
@ -18,9 +18,21 @@ class Edge(BaseModel):
|
|||
|
||||
# Mixed usage
|
||||
has_items: (Edge(weight=0.5, weights={"confidence": 0.9}), list[Item])
|
||||
|
||||
# With edge_text for rich embedding representation
|
||||
contains: (Edge(relationship_type="contains", edge_text="relationship_name: contains; entity_description: Alice"), Entity)
|
||||
"""
|
||||
|
||||
weight: Optional[float] = None
|
||||
weights: Optional[Dict[str, float]] = None
|
||||
relationship_type: Optional[str] = None
|
||||
properties: Optional[Dict[str, Any]] = None
|
||||
edge_text: Optional[str] = None
|
||||
|
||||
@field_validator("edge_text", mode="before")
@classmethod
def ensure_edge_text(cls, v, info):
    """
    Fill edge_text from relationship_type when edge_text is passed as None.

    Runs before standard validation of ``edge_text``. ``info.data`` holds the
    fields validated so far, which includes ``relationship_type`` because it
    is declared before ``edge_text``.

    Pydantic does not call field validators for a field that falls back to
    its default, so this only covers an explicit ``edge_text=None``. The
    omitted-argument case is handled in ``model_post_init``.
    """
    if v is None and info.data.get("relationship_type"):
        return info.data["relationship_type"]
    return v

def model_post_init(self, _context: Any) -> None:
    """
    Auto-populate edge_text from relationship_type when edge_text was omitted.

    Called by pydantic after the model has been initialised and validated.
    Covers the case the field validator cannot see: ``edge_text`` left at its
    default of None.
    """
    if self.edge_text is None and self.relationship_type:
        self.edge_text = self.relationship_type
|
||||
|
|
|
|||
|
|
@ -8,6 +8,6 @@ class FileContentHashingError(Exception):
|
|||
self,
|
||||
message: str = "Failed to hash content of the file.",
|
||||
name: str = "FileContentHashingError",
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
|
|
|||
|
|
@ -82,16 +82,16 @@ class LocalFileStorage(Storage):
|
|||
self.ensure_directory_exists(file_dir_path)
|
||||
|
||||
if overwrite or not os.path.exists(full_file_path):
|
||||
with open(
|
||||
full_file_path,
|
||||
mode="w" if isinstance(data, str) else "wb",
|
||||
encoding="utf-8" if isinstance(data, str) else None,
|
||||
) as file:
|
||||
if hasattr(data, "read"):
|
||||
data.seek(0)
|
||||
file.write(data.read())
|
||||
else:
|
||||
if isinstance(data, str):
|
||||
with open(full_file_path, mode="w", encoding="utf-8", newline="\n") as file:
|
||||
file.write(data)
|
||||
else:
|
||||
with open(full_file_path, mode="wb") as file:
|
||||
if hasattr(data, "read"):
|
||||
data.seek(0)
|
||||
file.write(data.read())
|
||||
else:
|
||||
file.write(data)
|
||||
|
||||
file.close()
|
||||
|
||||
|
|
|
|||
|
|
@ -70,18 +70,18 @@ class S3FileStorage(Storage):
|
|||
if overwrite or not await self.file_exists(file_path):
|
||||
|
||||
def save_data_to_file():
|
||||
with self.s3.open(
|
||||
full_file_path,
|
||||
mode="w" if isinstance(data, str) else "wb",
|
||||
encoding="utf-8" if isinstance(data, str) else None,
|
||||
) as file:
|
||||
if hasattr(data, "read"):
|
||||
data.seek(0)
|
||||
file.write(data.read())
|
||||
else:
|
||||
if isinstance(data, str):
|
||||
with self.s3.open(
|
||||
full_file_path, mode="w", encoding="utf-8", newline="\n"
|
||||
) as file:
|
||||
file.write(data)
|
||||
|
||||
file.close()
|
||||
else:
|
||||
with self.s3.open(full_file_path, mode="wb") as file:
|
||||
if hasattr(data, "read"):
|
||||
data.seek(0)
|
||||
file.write(data.read())
|
||||
else:
|
||||
file.write(data)
|
||||
|
||||
await run_async(save_data_to_file)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import io
|
||||
import os.path
|
||||
from typing import BinaryIO, TypedDict
|
||||
from typing import BinaryIO, TypedDict, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
|
@ -27,7 +27,7 @@ class FileMetadata(TypedDict):
|
|||
file_size: int
|
||||
|
||||
|
||||
async def get_file_metadata(file: BinaryIO) -> FileMetadata:
|
||||
async def get_file_metadata(file: BinaryIO, name: Optional[str] = None) -> FileMetadata:
|
||||
"""
|
||||
Retrieve metadata from a file object.
|
||||
|
||||
|
|
@ -53,7 +53,7 @@ async def get_file_metadata(file: BinaryIO) -> FileMetadata:
|
|||
except io.UnsupportedOperation as error:
|
||||
logger.error(f"Error retrieving content hash for file: {file.name} \n{str(error)}\n\n")
|
||||
|
||||
file_type = guess_file_type(file)
|
||||
file_type = guess_file_type(file, name)
|
||||
|
||||
file_path = getattr(file, "name", None) or getattr(file, "full_name", None)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
from typing import BinaryIO
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import BinaryIO, Optional, Any
|
||||
import filetype
|
||||
from .is_text_content import is_text_content
|
||||
from tempfile import SpooledTemporaryFile
|
||||
from filetype.types.base import Type
|
||||
|
||||
|
||||
class FileTypeException(Exception):
|
||||
|
|
@ -22,90 +25,7 @@ class FileTypeException(Exception):
|
|||
self.message = message
|
||||
|
||||
|
||||
class TxtFileType(filetype.Type):
|
||||
"""
|
||||
Represents a text file type with specific MIME and extension properties.
|
||||
|
||||
Public methods:
|
||||
- match: Determines whether a given buffer matches the text file type.
|
||||
"""
|
||||
|
||||
MIME = "text/plain"
|
||||
EXTENSION = "txt"
|
||||
|
||||
def __init__(self):
|
||||
super(TxtFileType, self).__init__(mime=TxtFileType.MIME, extension=TxtFileType.EXTENSION)
|
||||
|
||||
def match(self, buf):
|
||||
"""
|
||||
Determine if the given buffer contains text content.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- buf: The buffer to check for text content.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns True if the buffer is identified as text content, otherwise False.
|
||||
"""
|
||||
return is_text_content(buf)
|
||||
|
||||
|
||||
txt_file_type = TxtFileType()
|
||||
|
||||
filetype.add_type(txt_file_type)
|
||||
|
||||
|
||||
class CustomPdfMatcher(filetype.Type):
|
||||
"""
|
||||
Match PDF file types based on MIME type and extension.
|
||||
|
||||
Public methods:
|
||||
- match
|
||||
|
||||
Instance variables:
|
||||
- MIME: The MIME type of the PDF.
|
||||
- EXTENSION: The file extension of the PDF.
|
||||
"""
|
||||
|
||||
MIME = "application/pdf"
|
||||
EXTENSION = "pdf"
|
||||
|
||||
def __init__(self):
|
||||
super(CustomPdfMatcher, self).__init__(
|
||||
mime=CustomPdfMatcher.MIME, extension=CustomPdfMatcher.EXTENSION
|
||||
)
|
||||
|
||||
def match(self, buf):
|
||||
"""
|
||||
Determine if the provided buffer is a PDF file.
|
||||
|
||||
This method checks for the presence of the PDF signature in the buffer.
|
||||
|
||||
Raises:
|
||||
- TypeError: If the buffer is not of bytes type.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- buf: The buffer containing the data to be checked.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns True if the buffer contains a PDF signature, otherwise returns False.
|
||||
"""
|
||||
return b"PDF-" in buf
|
||||
|
||||
|
||||
custom_pdf_matcher = CustomPdfMatcher()
|
||||
|
||||
filetype.add_type(custom_pdf_matcher)
|
||||
|
||||
|
||||
def guess_file_type(file: BinaryIO) -> filetype.Type:
|
||||
def guess_file_type(file: BinaryIO, name: Optional[str] = None) -> filetype.Type:
|
||||
"""
|
||||
Guess the file type from the given binary file stream.
|
||||
|
||||
|
|
@ -122,12 +42,23 @@ def guess_file_type(file: BinaryIO) -> filetype.Type:
|
|||
|
||||
- filetype.Type: The guessed file type, represented as filetype.Type.
|
||||
"""
|
||||
|
||||
# Note: If file has .txt or .text extension, consider it a plain text file as filetype.guess may not detect it properly
|
||||
# as it contains no magic number encoding
|
||||
ext = None
|
||||
if isinstance(file, str):
|
||||
ext = Path(file).suffix
|
||||
elif name is not None:
|
||||
ext = Path(name).suffix
|
||||
|
||||
if ext in [".txt", ".text"]:
|
||||
file_type = Type("text/plain", "txt")
|
||||
return file_type
|
||||
|
||||
file_type = filetype.guess(file)
|
||||
|
||||
# If file type could not be determined consider it a plain text file as they don't have magic number encoding
|
||||
if file_type is None:
|
||||
from filetype.types.base import Type
|
||||
|
||||
file_type = Type("text/plain", "txt")
|
||||
|
||||
if file_type is None:
|
||||
|
|
|
|||
|
|
@ -1,15 +1,13 @@
|
|||
For the purposes of identifying timestamps in a query, you are tasked with extracting relevant timestamps from the query.
|
||||
## Timestamp requirements
|
||||
- If the query contains interval extrack both starts_at and ends_at properties
|
||||
- If the query contains an instantaneous timestamp, starts_at and ends_at should be the same
|
||||
- If the query its open-ended (before 2009 or after 2009), the corresponding non defined end of the time should be none
|
||||
-For example: "before 2009" -- starts_at: None, ends_at: 2009 or "after 2009" -- starts_at: 2009, ends_at: None
|
||||
- Put always the data that comes first in time as starts_at and the timestamps that comes second in time as ends_at
|
||||
- If starts_at or ends_at cannot be extracted both of them has to be None
|
||||
## Output Format
|
||||
Your reply should be a JSON: list of dictionaries with the following structure:
|
||||
```python
|
||||
class QueryInterval(BaseModel):
|
||||
starts_at: Optional[Timestamp] = None
|
||||
ends_at: Optional[Timestamp] = None
|
||||
```
|
||||
You are tasked with identifying relevant time periods where the answer to a given query should be searched.
|
||||
Current date is: `{{ time_now }}`. Determine relevant period(s) and return structured intervals.
|
||||
|
||||
Extraction rules:
|
||||
|
||||
1. Query without specific timestamp: use the time period with starts_at set to None and ends_at set to now.
|
||||
2. Explicit time intervals: If the query specifies a range (e.g., from 2010 to 2020, between January and March 2023), extract both start and end dates. Always assign the earlier date to starts_at and the later date to ends_at.
|
||||
3. Single timestamp: If the query refers to one specific moment (e.g., in 2015, on March 5, 2022), set starts_at and ends_at to that same timestamp.
|
||||
4. Open-ended time references: For phrases such as "before X" or "after X", represent the unspecified side as None. For example: before 2009 → starts_at: None, ends_at: 2009; after 2009 → starts_at: 2009, ends_at: None.
|
||||
5. Current-time references ("now", "current", "today"): If the query explicitly refers to the present, set both starts_at and ends_at to now (the ingestion timestamp).
|
||||
6. "Who is" and "Who was" questions: These imply a general identity or biographical inquiry without a specific temporal scope. Set both starts_at and ends_at to None.
|
||||
7. Ordering rule: Always ensure the earlier date is assigned to starts_at and the later date to ends_at.
|
||||
8. No temporal information: If no valid or inferable time reference is found, set both starts_at and ends_at to None.
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
A question was previously answered, but the answer received negative feedback.
|
||||
Please reconsider and improve the response.
|
||||
|
||||
Question: {question}
|
||||
Context originally used: {context}
|
||||
Previous answer: {wrong_answer}
|
||||
Feedback on that answer: {negative_feedback}
|
||||
|
||||
Task: Provide a better response. The new answer should be short and direct.
|
||||
Then explain briefly why this answer is better.
|
||||
|
||||
Format your reply as:
|
||||
Answer: <improved answer>
|
||||
Explanation: <short explanation>
|
||||
13
cognee/infrastructure/llm/prompts/feedback_report_prompt.txt
Normal file
13
cognee/infrastructure/llm/prompts/feedback_report_prompt.txt
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
Write a concise, stand-alone paragraph that explains the correct answer to the question below.
|
||||
The paragraph should read naturally on its own, providing all necessary context and reasoning
|
||||
so the answer is clear and well-supported.
|
||||
|
||||
Question: {question}
|
||||
Correct answer: {improved_answer}
|
||||
Supporting context: {new_context}
|
||||
|
||||
Your paragraph should:
|
||||
- First sentence clearly states the correct answer as a full sentence
|
||||
- Remainder flows from first sentence and provides explanation based on context
|
||||
- Use simple, direct language that is easy to follow
|
||||
- Use shorter sentences, no long-winded explanations
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
Question: {question}
|
||||
Context: {context}
|
||||
|
||||
Provide a one paragraph human readable summary of this interaction context,
|
||||
listing all the relevant facts and information in a simple and direct way.
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
import filetype
|
||||
from typing import Dict, List, Optional, Any
|
||||
from .LoaderInterface import LoaderInterface
|
||||
from cognee.infrastructure.files.utils.guess_file_type import guess_file_type
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -64,7 +65,9 @@ class LoaderEngine:
|
|||
return True
|
||||
|
||||
def get_loader(
|
||||
self, file_path: str, preferred_loaders: List[str] = None
|
||||
self,
|
||||
file_path: str,
|
||||
preferred_loaders: dict[str, dict[str, Any]],
|
||||
) -> Optional[LoaderInterface]:
|
||||
"""
|
||||
Get appropriate loader for a file.
|
||||
|
|
@ -76,14 +79,21 @@ class LoaderEngine:
|
|||
Returns:
|
||||
LoaderInterface that can handle the file, or None if not found
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
file_info = filetype.guess(file_path)
|
||||
file_info = guess_file_type(file_path)
|
||||
|
||||
path_extension = Path(file_path).suffix.lstrip(".")
|
||||
|
||||
# Try preferred loaders first
|
||||
if preferred_loaders:
|
||||
for loader_name in preferred_loaders:
|
||||
if loader_name in self._loaders:
|
||||
loader = self._loaders[loader_name]
|
||||
# Try with path extension first (for text formats like html)
|
||||
if loader.can_handle(extension=path_extension, mime_type=file_info.mime):
|
||||
return loader
|
||||
# Fall back to content-detected extension
|
||||
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
|
||||
return loader
|
||||
else:
|
||||
|
|
@ -93,6 +103,10 @@ class LoaderEngine:
|
|||
for loader_name in self.default_loader_priority:
|
||||
if loader_name in self._loaders:
|
||||
loader = self._loaders[loader_name]
|
||||
# Try with path extension first (for text formats like html)
|
||||
if loader.can_handle(extension=path_extension, mime_type=file_info.mime):
|
||||
return loader
|
||||
# Fall back to content-detected extension
|
||||
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
|
||||
return loader
|
||||
else:
|
||||
|
|
@ -105,7 +119,7 @@ class LoaderEngine:
|
|||
async def load_file(
|
||||
self,
|
||||
file_path: str,
|
||||
preferred_loaders: Optional[List[str]] = None,
|
||||
preferred_loaders: dict[str, dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
|
@ -113,7 +127,7 @@ class LoaderEngine:
|
|||
|
||||
Args:
|
||||
file_path: Path to the file to be processed
|
||||
preferred_loaders: List of preferred loader names to try first
|
||||
preferred_loaders: Dict of loader names to their configurations
|
||||
**kwargs: Additional loader-specific configuration
|
||||
|
||||
Raises:
|
||||
|
|
@ -125,8 +139,16 @@ class LoaderEngine:
|
|||
raise ValueError(f"No loader found for file: {file_path}")
|
||||
|
||||
logger.debug(f"Loading {file_path} with {loader.loader_name}")
|
||||
# TODO: loading needs to be reworked to work with both file streams and file locations
|
||||
return await loader.load(file_path, **kwargs)
|
||||
|
||||
# Extract loader-specific config from preferred_loaders
|
||||
loader_config = {}
|
||||
if preferred_loaders and loader.loader_name in preferred_loaders:
|
||||
loader_config = preferred_loaders[loader.loader_name]
|
||||
|
||||
# Merge with any additional kwargs (kwargs take precedence)
|
||||
merged_kwargs = {**loader_config, **kwargs}
|
||||
|
||||
return await loader.load(file_path, **merged_kwargs)
|
||||
|
||||
def get_available_loaders(self) -> List[str]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ class AudioLoader(LoaderInterface):
|
|||
"audio/wav",
|
||||
"audio/amr",
|
||||
"audio/aiff",
|
||||
"audio/x-wav",
|
||||
]
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -27,3 +27,10 @@ try:
|
|||
__all__.append("AdvancedPdfLoader")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from .beautiful_soup_loader import BeautifulSoupLoader
|
||||
|
||||
__all__.append("BeautifulSoupLoader")
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
|
|||
310
cognee/infrastructure/loaders/external/beautiful_soup_loader.py
vendored
Normal file
310
cognee/infrastructure/loaders/external/beautiful_soup_loader.py
vendored
Normal file
|
|
@ -0,0 +1,310 @@
|
|||
"""BeautifulSoup-based web crawler for extracting content from web pages.
|
||||
|
||||
This module provides the BeautifulSoupLoader class for fetching and extracting content
|
||||
from web pages using BeautifulSoup or Playwright for JavaScript-rendered pages. It
|
||||
supports robots.txt handling, rate limiting, and custom extraction rules.
|
||||
"""
|
||||
|
||||
from typing import Union, Dict, Any, Optional, List
|
||||
from dataclasses import dataclass
|
||||
from bs4 import BeautifulSoup
|
||||
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ExtractionRule:
    """One normalized instruction for pulling content out of an HTML document.

    Attributes:
        selector: CSS selector to match, or None when not used.
        xpath: XPath expression to evaluate, or None when not used.
        attr: Name of the HTML attribute to extract, or None when not used.
        all: When True every matching element is extracted; otherwise only the first.
        join_with: Separator used to join multiple extracted elements.
    """

    selector: Optional[str] = None
    xpath: Optional[str] = None
    attr: Optional[str] = None
    all: bool = False
    join_with: str = " "
|
||||
|
||||
|
||||
class BeautifulSoupLoader(LoaderInterface):
|
||||
"""Crawler for fetching and extracting web content using BeautifulSoup.
|
||||
|
||||
Supports asynchronous HTTP requests, Playwright for JavaScript rendering, robots.txt
|
||||
compliance, and rate limiting. Extracts content using CSS selectors or XPath rules.
|
||||
|
||||
Attributes:
|
||||
concurrency: Number of concurrent requests allowed.
|
||||
crawl_delay: Minimum seconds between requests to the same domain.
|
||||
max_crawl_delay: Maximum crawl delay to respect from robots.txt (None = no limit).
|
||||
timeout: Per-request timeout in seconds.
|
||||
max_retries: Number of retries for failed requests.
|
||||
retry_delay_factor: Multiplier for exponential backoff on retries.
|
||||
headers: HTTP headers for requests (e.g., User-Agent).
|
||||
robots_cache_ttl: Time-to-live for robots.txt cache in seconds.
|
||||
"""
|
||||
|
||||
@property
def supported_extensions(self) -> List[str]:
    """File extensions (without a leading dot) this loader accepts."""
    extensions = ["html"]
    return extensions
|
||||
|
||||
@property
def supported_mime_types(self) -> List[str]:
    """MIME types this loader accepts."""
    mime_types = ["text/html", "text/plain"]
    return mime_types
|
||||
|
||||
@property
def loader_name(self) -> str:
    """Identifier string of this loader."""
    name = "beautiful_soup_loader"
    return name
|
||||
|
||||
def can_handle(self, extension: str, mime_type: str) -> bool:
    """Tell whether this loader accepts the given extension / MIME type pair.

    Both conditions must hold: the extension is listed in
    ``supported_extensions`` and the MIME type in ``supported_mime_types``.
    """
    if extension not in self.supported_extensions:
        return False
    return mime_type in self.supported_mime_types
|
||||
|
||||
def _get_default_extraction_rules(self):
|
||||
# Comprehensive default extraction rules for common HTML content
|
||||
return {
|
||||
# Meta information
|
||||
"title": {"selector": "title", "all": False},
|
||||
"meta_description": {
|
||||
"selector": "meta[name='description']",
|
||||
"attr": "content",
|
||||
"all": False,
|
||||
},
|
||||
"meta_keywords": {
|
||||
"selector": "meta[name='keywords']",
|
||||
"attr": "content",
|
||||
"all": False,
|
||||
},
|
||||
# Open Graph meta tags
|
||||
"og_title": {
|
||||
"selector": "meta[property='og:title']",
|
||||
"attr": "content",
|
||||
"all": False,
|
||||
},
|
||||
"og_description": {
|
||||
"selector": "meta[property='og:description']",
|
||||
"attr": "content",
|
||||
"all": False,
|
||||
},
|
||||
# Main content areas (prioritized selectors)
|
||||
"article": {"selector": "article", "all": True, "join_with": "\n\n"},
|
||||
"main": {"selector": "main", "all": True, "join_with": "\n\n"},
|
||||
# Semantic content sections
|
||||
"headers_h1": {"selector": "h1", "all": True, "join_with": "\n"},
|
||||
"headers_h2": {"selector": "h2", "all": True, "join_with": "\n"},
|
||||
"headers_h3": {"selector": "h3", "all": True, "join_with": "\n"},
|
||||
"headers_h4": {"selector": "h4", "all": True, "join_with": "\n"},
|
||||
"headers_h5": {"selector": "h5", "all": True, "join_with": "\n"},
|
||||
"headers_h6": {"selector": "h6", "all": True, "join_with": "\n"},
|
||||
# Text content
|
||||
"paragraphs": {"selector": "p", "all": True, "join_with": "\n\n"},
|
||||
"blockquotes": {"selector": "blockquote", "all": True, "join_with": "\n\n"},
|
||||
"preformatted": {"selector": "pre", "all": True, "join_with": "\n\n"},
|
||||
# Lists
|
||||
"ordered_lists": {"selector": "ol", "all": True, "join_with": "\n"},
|
||||
"unordered_lists": {"selector": "ul", "all": True, "join_with": "\n"},
|
||||
"list_items": {"selector": "li", "all": True, "join_with": "\n"},
|
||||
"definition_lists": {"selector": "dl", "all": True, "join_with": "\n"},
|
||||
# Tables
|
||||
"tables": {"selector": "table", "all": True, "join_with": "\n\n"},
|
||||
"table_captions": {
|
||||
"selector": "caption",
|
||||
"all": True,
|
||||
"join_with": "\n",
|
||||
},
|
||||
# Code blocks
|
||||
"code_blocks": {"selector": "code", "all": True, "join_with": "\n"},
|
||||
# Figures and media descriptions
|
||||
"figures": {"selector": "figure", "all": True, "join_with": "\n\n"},
|
||||
"figcaptions": {"selector": "figcaption", "all": True, "join_with": "\n"},
|
||||
"image_alts": {"selector": "img", "attr": "alt", "all": True, "join_with": " "},
|
||||
# Links (text content, not URLs to avoid clutter)
|
||||
"link_text": {"selector": "a", "all": True, "join_with": " "},
|
||||
# Emphasized text
|
||||
"strong": {"selector": "strong", "all": True, "join_with": " "},
|
||||
"emphasis": {"selector": "em", "all": True, "join_with": " "},
|
||||
"marked": {"selector": "mark", "all": True, "join_with": " "},
|
||||
# Time and data elements
|
||||
"time": {"selector": "time", "all": True, "join_with": " "},
|
||||
"data": {"selector": "data", "all": True, "join_with": " "},
|
||||
# Sections and semantic structure
|
||||
"sections": {"selector": "section", "all": True, "join_with": "\n\n"},
|
||||
"asides": {"selector": "aside", "all": True, "join_with": "\n\n"},
|
||||
"details": {"selector": "details", "all": True, "join_with": "\n"},
|
||||
"summary": {"selector": "summary", "all": True, "join_with": "\n"},
|
||||
# Navigation (may contain important links/structure)
|
||||
"nav": {"selector": "nav", "all": True, "join_with": "\n"},
|
||||
# Footer information
|
||||
"footer": {"selector": "footer", "all": True, "join_with": "\n"},
|
||||
# Divs with specific content roles
|
||||
"content_divs": {
|
||||
"selector": "div[role='main'], div[role='article'], div.content, div#content",
|
||||
"all": True,
|
||||
"join_with": "\n\n",
|
||||
},
|
||||
}
|
||||
|
||||
async def load(
    self,
    file_path: str,
    extraction_rules: Optional[Dict[str, Any]] = None,
    join_all_matches: bool = False,
    **kwargs,
):
    """Load an HTML file, extract content, and save it to storage.

    Args:
        file_path: Path to the HTML file; it is opened in binary mode.
        extraction_rules: Mapping of rule name to extraction rule (a CSS selector
            string or a rule dict). When None, the default rule set from
            ``_get_default_extraction_rules`` is used.
        join_all_matches: If True, every rule is forced to extract all matching
            elements instead of only the first.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The value returned by ``storage.store`` for the extracted text file
        (presumably its full path; confirm against the storage implementation).
    """
    if extraction_rules is None:
        extraction_rules = self._get_default_extraction_rules()
        logger.info("Using default comprehensive extraction rules for HTML content")

    logger.info(f"Processing HTML file: {file_path}")

    from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata
    from cognee.infrastructure.files.storage import get_file_storage, get_storage_config

    with open(file_path, "rb") as f:
        file_metadata = await get_file_metadata(f)
        # get_file_metadata received the open handle; rewind before reading the body.
        f.seek(0)
        html = f.read()

    # The output name is derived from the content hash of the source file.
    storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"

    # Normalize extraction rules
    normalized_rules: List[ExtractionRule] = []
    for rule in extraction_rules.values():
        r = self._normalize_rule(rule)
        if join_all_matches:
            r.all = True
        normalized_rules.append(r)

    pieces = []
    for rule in normalized_rules:
        text = self._extract_from_html(html, rule)
        if text:
            pieces.append(text)

    full_content = " ".join(pieces).strip()

    # remove after defaults for extraction rules
    # Fallback: If no content extracted, check if the file is plain text (not HTML)
    if not full_content:
        soup = BeautifulSoup(html, "html.parser")
        # If there are no HTML tags, treat as plain text
        if not soup.find():
            logger.warning(
                f"No HTML tags found in {file_path}. Treating as plain text. "
                "This may happen when content is pre-extracted (e.g., via Tavily with text format)."
            )
            # html is bytes (file opened with "rb"). Undecodable bytes are replaced
            # rather than raising, so a non-UTF-8 text file does not abort the load.
            full_content = html.decode("utf-8", errors="replace").strip()

    if not full_content:
        logger.warning(f"No content extracted from HTML file: {file_path}")

    # Store the extracted content
    storage_config = get_storage_config()
    data_root_directory = storage_config["data_root_directory"]
    storage = get_file_storage(data_root_directory)

    full_file_path = await storage.store(storage_file_name, full_content)

    logger.info(f"Extracted {len(full_content)} characters from HTML")
    return full_file_path
|
||||
|
||||
def _normalize_rule(self, rule: Union[str, Dict[str, Any]]) -> ExtractionRule:
    """Convert a user-supplied rule into an ExtractionRule instance.

    Args:
        rule: Either a CSS selector string, or a dict that may carry the keys
            "selector", "xpath", "attr", "all" and "join_with".

    Returns:
        ExtractionRule: The normalized rule. For dict input, "all" is coerced
        to bool (False when absent) and "join_with" falls back to a single space.

    Raises:
        ValueError: If the rule is neither a string nor a dict.
    """
    if not isinstance(rule, (str, dict)):
        raise ValueError(f"Invalid extraction rule: {rule}")

    # A bare string is shorthand for a CSS selector with every other field defaulted.
    if isinstance(rule, str):
        return ExtractionRule(selector=rule)

    collect_all = bool(rule.get("all", False))
    separator = rule.get("join_with", " ")
    return ExtractionRule(
        selector=rule.get("selector"),
        xpath=rule.get("xpath"),
        attr=rule.get("attr"),
        all=collect_all,
        join_with=separator,
    )
|
||||
|
||||
def _extract_from_html(self, html: str, rule: ExtractionRule) -> str:
    """Extract content from HTML using an lxml XPath or a BeautifulSoup CSS selector.

    When the rule carries an XPath it is used and the CSS selector is ignored.
    Otherwise the CSS selector is applied, collecting either every match
    (rule.all) or only the first one.

    Args:
        html: The HTML content to extract from.
        rule: The extraction rule to apply.

    Returns:
        str: The extracted content. Empty when the rule has neither an XPath
        nor a selector, or when nothing matches.

    Raises:
        RuntimeError: If XPath is used but lxml is not installed.
    """
    if rule.xpath:
        try:
            from lxml import html as lxml_html
        except ImportError as import_error:
            raise RuntimeError(
                "XPath requested but lxml is not available. Install lxml or use CSS selectors."
            ) from import_error

        doc = lxml_html.fromstring(html)
        result = doc.xpath(rule.xpath)
        # lxml returns a list for node-set expressions but a bare str/float/bool
        # for expressions such as string(...) or count(...). Wrap scalars so the
        # loop handles one result at a time instead of iterating a string
        # character by character (or failing on a number).
        nodes = result if isinstance(result, list) else [result]

        texts = []
        for n in nodes:
            if hasattr(n, "text_content"):
                texts.append(n.text_content().strip())
            else:
                # Attribute/text() results are plain string-like values.
                texts.append(str(n).strip())
        return rule.join_with.join(t for t in texts if t)

    if not rule.selector:
        return ""

    def _attr_text(value) -> str:
        # BeautifulSoup returns a list for multi-valued attributes such as
        # class or rel; flatten it so .strip() is always called on a str.
        if isinstance(value, (list, tuple)):
            value = " ".join(value)
        return value.strip()

    # Built only here: the XPath branch above never needs the BeautifulSoup tree.
    soup = BeautifulSoup(html, "html.parser")

    if rule.all:
        pieces = []
        for el in soup.select(rule.selector):
            if rule.attr:
                val = el.get(rule.attr)
                if val:
                    pieces.append(_attr_text(val))
            else:
                text = el.get_text(strip=True)
                if text:
                    pieces.append(text)
        return rule.join_with.join(pieces).strip()

    el = soup.select_one(rule.selector)
    if el is None:
        return ""
    if rule.attr:
        val = el.get(rule.attr)
        return _attr_text(val) if val else ""
    return el.get_text(strip=True)
|
||||
|
|
@ -23,3 +23,10 @@ try:
|
|||
supported_loaders[AdvancedPdfLoader.loader_name] = AdvancedPdfLoader
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Register BeautifulSoupLoader only when it can be imported; an ImportError
# is ignored and the loader is simply left out of supported_loaders.
try:
    from cognee.infrastructure.loaders.external import BeautifulSoupLoader
except ImportError:
    pass
else:
    supported_loaders[BeautifulSoupLoader.loader_name] = BeautifulSoupLoader
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import List, Union
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.engine.models.Edge import Edge
|
||||
from cognee.modules.data.processing.document_types import Document
|
||||
from cognee.modules.engine.models import Entity
|
||||
from cognee.tasks.temporal_graph.models import Event
|
||||
|
|
@ -31,6 +32,6 @@ class DocumentChunk(DataPoint):
|
|||
chunk_index: int
|
||||
cut_type: str
|
||||
is_part_of: Document
|
||||
contains: List[Union[Entity, Event]] = None
|
||||
contains: List[Union[Entity, Event, tuple[Edge, Entity]]] = None
|
||||
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ class UnstructuredLibraryImportError(CogneeConfigurationError):
|
|||
self,
|
||||
message: str = "Import error. Unstructured library is not installed.",
|
||||
name: str = "UnstructuredModuleImportError",
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
|
|
|||
|
|
@ -23,3 +23,6 @@ from .create_authorized_dataset import create_authorized_dataset
|
|||
|
||||
# Check
|
||||
from .check_dataset_name import check_dataset_name
|
||||
|
||||
# Boolean check
|
||||
from .has_dataset_data import has_dataset_data
|
||||
|
|
|
|||
21
cognee/modules/data/methods/has_dataset_data.py
Normal file
21
cognee/modules/data/methods/has_dataset_data.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.data.models import DatasetData
|
||||
|
||||
|
||||
async def has_dataset_data(dataset_id: UUID) -> bool:
    """Report whether any DatasetData row references the given dataset.

    Args:
        dataset_id: Identifier of the dataset to check.

    Returns:
        bool: True when the number of DatasetData rows whose dataset_id equals
        the argument is greater than zero, otherwise False.
    """
    # SELECT count(*) FROM <DatasetData> WHERE dataset_id = :dataset_id
    row_count_statement = (
        select(func.count()).select_from(DatasetData).where(DatasetData.dataset_id == dataset_id)
    )

    engine = get_relational_engine()
    async with engine.get_async_session() as session:
        result = await session.execute(row_count_statement)
        matching_rows = result.scalar_one()

    return matching_rows > 0
|
||||
|
|
@ -171,8 +171,10 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
|
||||
|
||||
for edge in self.edges:
|
||||
relationship_type = edge.attributes.get("relationship_type")
|
||||
distance = embedding_map.get(relationship_type, None)
|
||||
edge_key = edge.attributes.get("edge_text") or edge.attributes.get(
|
||||
"relationship_type"
|
||||
)
|
||||
distance = embedding_map.get(edge_key, None)
|
||||
if distance is not None:
|
||||
edge.attributes["vector_distance"] = distance
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Optional
|
||||
|
||||
from cognee.infrastructure.engine.models.Edge import Edge
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.modules.engine.models import Entity, EntityType
|
||||
from cognee.modules.engine.utils import (
|
||||
|
|
@ -243,10 +244,26 @@ def _process_graph_nodes(
|
|||
ontology_relationships,
|
||||
)
|
||||
|
||||
# Add entity to data chunk
|
||||
if data_chunk.contains is None:
|
||||
data_chunk.contains = []
|
||||
data_chunk.contains.append(entity_node)
|
||||
|
||||
edge_text = "; ".join(
|
||||
[
|
||||
"relationship_name: contains",
|
||||
f"entity_name: {entity_node.name}",
|
||||
f"entity_description: {entity_node.description}",
|
||||
]
|
||||
)
|
||||
|
||||
data_chunk.contains.append(
|
||||
(
|
||||
Edge(
|
||||
relationship_type="contains",
|
||||
edge_text=edge_text,
|
||||
),
|
||||
entity_node,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _process_graph_edges(
|
||||
|
|
|
|||
|
|
@ -1,71 +1,70 @@
|
|||
import string
|
||||
from typing import List
|
||||
from collections import Counter
|
||||
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
||||
|
||||
|
||||
def _get_top_n_frequent_words(
|
||||
text: str, stop_words: set = None, top_n: int = 3, separator: str = ", "
|
||||
) -> str:
|
||||
"""Concatenates the top N frequent words in text."""
|
||||
if stop_words is None:
|
||||
stop_words = DEFAULT_STOP_WORDS
|
||||
|
||||
words = [word.lower().strip(string.punctuation) for word in text.split()]
|
||||
words = [word for word in words if word and word not in stop_words]
|
||||
|
||||
top_words = [word for word, freq in Counter(words).most_common(top_n)]
|
||||
return separator.join(top_words)
|
||||
|
||||
|
||||
def _create_title_from_text(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
    """Build a short title: the leading words of the text plus its most frequent words.

    Args:
        text: Text to derive the title from.
        first_n_words: How many leading whitespace-separated words to keep.
        top_n_words: How many frequent words to append in square brackets.

    Returns:
        str: A title of the form "<leading words>... [<frequent words>]".
    """
    leading = " ".join(text.split()[:first_n_words])
    frequent = _get_top_n_frequent_words(text, top_n=top_n_words)
    return f"{leading}... [{frequent}]"
|
||||
|
||||
|
||||
def _extract_nodes_from_edges(retrieved_edges: List[Edge]) -> dict:
    """Collect the distinct endpoint nodes of the edges, keyed by node id.

    Each entry holds the node itself, a display name and its content. A node
    with a non-empty "text" attribute gets a generated title as name and the
    text as content. Otherwise the "name" attribute (default "Unnamed Node")
    and the "description" attribute (default: the name) are used.

    Args:
        retrieved_edges: Edges whose node1/node2 endpoints should be collected.

    Returns:
        dict: Mapping of node id to {"node", "name", "content"}; the first
        occurrence of an id wins.
    """
    collected = {}

    endpoints = (
        endpoint for edge in retrieved_edges for endpoint in (edge.node1, edge.node2)
    )
    for endpoint in endpoints:
        if endpoint.id not in collected:
            text = endpoint.attributes.get("text")
            if text:
                name, content = _create_title_from_text(text), text
            else:
                name = endpoint.attributes.get("name", "Unnamed Node")
                content = endpoint.attributes.get("description", name)

            collected[endpoint.id] = {"node": endpoint, "name": name, "content": content}

    return collected
|
||||
|
||||
|
||||
async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str:
|
||||
"""
|
||||
Converts retrieved graph edges into a human-readable string format.
|
||||
"""Converts retrieved graph edges into a human-readable string format."""
|
||||
nodes = _extract_nodes_from_edges(retrieved_edges)
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- retrieved_edges (list): A list of edges retrieved from the graph.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- str: A formatted string representation of the nodes and their connections.
|
||||
"""
|
||||
|
||||
def _get_nodes(retrieved_edges: List[Edge]) -> dict:
|
||||
def _get_title(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
|
||||
def _top_n_words(text, stop_words=None, top_n=3, separator=", "):
|
||||
"""Concatenates the top N frequent words in text."""
|
||||
if stop_words is None:
|
||||
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
||||
|
||||
stop_words = DEFAULT_STOP_WORDS
|
||||
|
||||
import string
|
||||
|
||||
words = [word.lower().strip(string.punctuation) for word in text.split()]
|
||||
|
||||
if stop_words:
|
||||
words = [word for word in words if word and word not in stop_words]
|
||||
|
||||
from collections import Counter
|
||||
|
||||
top_words = [word for word, freq in Counter(words).most_common(top_n)]
|
||||
|
||||
return separator.join(top_words)
|
||||
|
||||
"""Creates a title, by combining first words with most frequent words from the text."""
|
||||
first_words = text.split()[:first_n_words]
|
||||
top_words = _top_n_words(text, top_n=first_n_words)
|
||||
return f"{' '.join(first_words)}... [{top_words}]"
|
||||
|
||||
"""Creates a dictionary of nodes with their names and content."""
|
||||
nodes = {}
|
||||
for edge in retrieved_edges:
|
||||
for node in (edge.node1, edge.node2):
|
||||
if node.id not in nodes:
|
||||
text = node.attributes.get("text")
|
||||
if text:
|
||||
name = _get_title(text)
|
||||
content = text
|
||||
else:
|
||||
name = node.attributes.get("name", "Unnamed Node")
|
||||
content = node.attributes.get("description", name)
|
||||
nodes[node.id] = {"node": node, "name": name, "content": content}
|
||||
return nodes
|
||||
|
||||
nodes = _get_nodes(retrieved_edges)
|
||||
node_section = "\n".join(
|
||||
f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
|
||||
for info in nodes.values()
|
||||
)
|
||||
connection_section = "\n".join(
|
||||
f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
|
||||
for edge in retrieved_edges
|
||||
)
|
||||
|
||||
connections = []
|
||||
for edge in retrieved_edges:
|
||||
source_name = nodes[edge.node1.id]["name"]
|
||||
target_name = nodes[edge.node2.id]["name"]
|
||||
edge_label = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
|
||||
connections.append(f"{source_name} --[{edge_label}]--> {target_name}")
|
||||
|
||||
connection_section = "\n".join(connections)
|
||||
|
||||
return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class BinaryData(IngestionData):
|
|||
|
||||
async def ensure_metadata(self):
|
||||
if self.metadata is None:
|
||||
self.metadata = await get_file_metadata(self.data)
|
||||
self.metadata = await get_file_metadata(self.data, name=self.name)
|
||||
|
||||
if self.metadata["name"] is None:
|
||||
self.metadata["name"] = self.name
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
from typing import BinaryIO, Union
|
||||
from typing import BinaryIO, Union, Optional
|
||||
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
|
||||
from .classify import classify
|
||||
import hashlib
|
||||
|
||||
|
||||
async def save_data_to_file(data: Union[str, BinaryIO], filename: str = None):
|
||||
async def save_data_to_file(
|
||||
data: Union[str, BinaryIO], filename: str = None, file_extension: Optional[str] = None
|
||||
):
|
||||
storage_config = get_storage_config()
|
||||
|
||||
data_root_directory = storage_config["data_root_directory"]
|
||||
|
|
@ -21,6 +23,11 @@ async def save_data_to_file(data: Union[str, BinaryIO], filename: str = None):
|
|||
|
||||
file_name = file_metadata["name"]
|
||||
|
||||
if file_extension is not None:
|
||||
extension = file_extension.lstrip(".")
|
||||
file_name_without_ext = file_name.rsplit(".", 1)[0]
|
||||
file_name = f"{file_name_without_ext}.{extension}"
|
||||
|
||||
storage = get_file_storage(data_root_directory)
|
||||
|
||||
full_file_path = await storage.store(file_name, data)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,8 @@ def get_ontology_resolver_from_env(
|
|||
Supported value: "rdflib".
|
||||
matching_strategy (str): The matching strategy to apply.
|
||||
Supported value: "fuzzy".
|
||||
ontology_file_path (str): Path to the ontology file required for the resolver.
|
||||
ontology_file_path (str): Path to the ontology file(s) required for the resolver.
|
||||
Can be a single path or comma-separated paths for multiple files.
|
||||
|
||||
Returns:
|
||||
BaseOntologyResolver: An instance of the requested ontology resolver.
|
||||
|
|
@ -31,8 +32,13 @@ def get_ontology_resolver_from_env(
|
|||
or if required parameters are missing.
|
||||
"""
|
||||
if ontology_resolver == "rdflib" and matching_strategy == "fuzzy" and ontology_file_path:
|
||||
if "," in ontology_file_path:
|
||||
file_paths = [path.strip() for path in ontology_file_path.split(",")]
|
||||
else:
|
||||
file_paths = ontology_file_path
|
||||
|
||||
return RDFLibOntologyResolver(
|
||||
matching_strategy=FuzzyMatchingStrategy(), ontology_file=ontology_file_path
|
||||
matching_strategy=FuzzyMatchingStrategy(), ontology_file=file_paths
|
||||
)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import os
|
|||
import difflib
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from collections import deque
|
||||
from typing import List, Tuple, Dict, Optional, Any
|
||||
from typing import List, Tuple, Dict, Optional, Any, Union
|
||||
from rdflib import Graph, URIRef, RDF, RDFS, OWL
|
||||
|
||||
from cognee.modules.ontology.exceptions import (
|
||||
|
|
@ -26,22 +26,50 @@ class RDFLibOntologyResolver(BaseOntologyResolver):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
ontology_file: Optional[str] = None,
|
||||
ontology_file: Optional[Union[str, List[str]]] = None,
|
||||
matching_strategy: Optional[MatchingStrategy] = None,
|
||||
) -> None:
|
||||
super().__init__(matching_strategy)
|
||||
self.ontology_file = ontology_file
|
||||
try:
|
||||
if ontology_file and os.path.exists(ontology_file):
|
||||
files_to_load = []
|
||||
if ontology_file is not None:
|
||||
if isinstance(ontology_file, str):
|
||||
files_to_load = [ontology_file]
|
||||
elif isinstance(ontology_file, list):
|
||||
files_to_load = ontology_file
|
||||
else:
|
||||
raise ValueError(
|
||||
f"ontology_file must be a string, list of strings, or None. Got: {type(ontology_file)}"
|
||||
)
|
||||
|
||||
if files_to_load:
|
||||
self.graph = Graph()
|
||||
self.graph.parse(ontology_file)
|
||||
logger.info("Ontology loaded successfully from file: %s", ontology_file)
|
||||
loaded_files = []
|
||||
for file_path in files_to_load:
|
||||
if os.path.exists(file_path):
|
||||
self.graph.parse(file_path)
|
||||
loaded_files.append(file_path)
|
||||
logger.info("Ontology loaded successfully from file: %s", file_path)
|
||||
else:
|
||||
logger.warning(
|
||||
"Ontology file '%s' not found. Skipping this file.",
|
||||
file_path,
|
||||
)
|
||||
|
||||
if not loaded_files:
|
||||
logger.info(
|
||||
"No valid ontology files found. No owl ontology will be attached to the graph."
|
||||
)
|
||||
self.graph = None
|
||||
else:
|
||||
logger.info("Total ontology files loaded: %d", len(loaded_files))
|
||||
else:
|
||||
logger.info(
|
||||
"Ontology file '%s' not found. No owl ontology will be attached to the graph.",
|
||||
ontology_file,
|
||||
"No ontology file provided. No owl ontology will be attached to the graph."
|
||||
)
|
||||
self.graph = None
|
||||
|
||||
self.build_lookup()
|
||||
except Exception as e:
|
||||
logger.error("Failed to load ontology", exc_info=e)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,6 @@ class PipelineRunFailedError(CogneeSystemError):
|
|||
self,
|
||||
message: str = "Pipeline run failed.",
|
||||
name: str = "PipelineRunFailedError",
|
||||
status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status_code: int = status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
|
|||
from cognee.modules.pipelines.layers.check_pipeline_run_qualification import (
|
||||
check_pipeline_run_qualification,
|
||||
)
|
||||
from typing import Any
|
||||
|
||||
logger = get_logger("cognee.pipeline")
|
||||
|
||||
|
|
@ -80,7 +81,14 @@ async def run_pipeline_per_dataset(
|
|||
return
|
||||
|
||||
pipeline_run = run_tasks(
|
||||
tasks, dataset.id, data, user, pipeline_name, context, incremental_loading, data_per_batch
|
||||
tasks,
|
||||
dataset.id,
|
||||
data,
|
||||
user,
|
||||
pipeline_name,
|
||||
context,
|
||||
incremental_loading,
|
||||
data_per_batch,
|
||||
)
|
||||
|
||||
async for pipeline_run_info in pipeline_run:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import inspect
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee import __version__ as cognee_version
|
||||
|
||||
from ..tasks.task import Task
|
||||
|
||||
|
|
@ -25,6 +26,8 @@ async def handle_task(
|
|||
user_id=user.id,
|
||||
additional_properties={
|
||||
"task_name": running_task.executable.__name__,
|
||||
"cognee_version": cognee_version,
|
||||
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -46,6 +49,8 @@ async def handle_task(
|
|||
user_id=user.id,
|
||||
additional_properties={
|
||||
"task_name": running_task.executable.__name__,
|
||||
"cognee_version": cognee_version,
|
||||
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
|
||||
},
|
||||
)
|
||||
except Exception as error:
|
||||
|
|
@ -58,6 +63,8 @@ async def handle_task(
|
|||
user_id=user.id,
|
||||
additional_properties={
|
||||
"task_name": running_task.executable.__name__,
|
||||
"cognee_version": cognee_version,
|
||||
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
|
||||
},
|
||||
)
|
||||
raise error
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from cognee.modules.settings import get_current_settings
|
|||
from cognee.modules.users.models import User
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee import __version__ as cognee_version
|
||||
|
||||
from .run_tasks_base import run_tasks_base
|
||||
from ..tasks.task import Task
|
||||
|
|
@ -26,6 +27,8 @@ async def run_tasks_with_telemetry(
|
|||
user.id,
|
||||
additional_properties={
|
||||
"pipeline_name": str(pipeline_name),
|
||||
"cognee_version": cognee_version,
|
||||
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
|
||||
}
|
||||
| config,
|
||||
)
|
||||
|
|
@ -39,7 +42,10 @@ async def run_tasks_with_telemetry(
|
|||
user.id,
|
||||
additional_properties={
|
||||
"pipeline_name": str(pipeline_name),
|
||||
},
|
||||
"cognee_version": cognee_version,
|
||||
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
|
||||
}
|
||||
| config,
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
|
|
@ -53,6 +59,8 @@ async def run_tasks_with_telemetry(
|
|||
user.id,
|
||||
additional_properties={
|
||||
"pipeline_name": str(pipeline_name),
|
||||
"cognee_version": cognee_version,
|
||||
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
|
||||
}
|
||||
| config,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,17 @@
|
|||
import asyncio
|
||||
from typing import Any, Optional, List
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
|
||||
from cognee.infrastructure.context.BaseContextProvider import BaseContextProvider
|
||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
save_conversation_history,
|
||||
get_conversation_history,
|
||||
)
|
||||
from cognee.context_global_variables import session_user
|
||||
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||
|
||||
|
||||
logger = get_logger("entity_completion_retriever")
|
||||
|
|
@ -77,7 +84,9 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
logger.error(f"Context retrieval failed: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]:
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate completion using provided context or fetch new context.
|
||||
|
||||
|
|
@ -91,6 +100,8 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
- query (str): The query string for which completion is being generated.
|
||||
- context (Optional[Any]): Optional context to be used for generating completion;
|
||||
fetched if not provided. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
@ -105,12 +116,41 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
if context is None:
|
||||
return ["No relevant entities found for the query."]
|
||||
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
)
|
||||
# Check if we need to generate context summary for caching
|
||||
cache_config = CacheConfig()
|
||||
user = session_user.get()
|
||||
user_id = getattr(user, "id", None)
|
||||
session_save = user_id and cache_config.caching
|
||||
|
||||
if session_save:
|
||||
conversation_history = await get_conversation_history(session_id=session_id)
|
||||
|
||||
context_summary, completion = await asyncio.gather(
|
||||
summarize_text(str(context)),
|
||||
generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
conversation_history=conversation_history,
|
||||
),
|
||||
)
|
||||
else:
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
)
|
||||
|
||||
if session_save:
|
||||
await save_conversation_history(
|
||||
query=query,
|
||||
context_summary=context_summary,
|
||||
answer=completion,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return [completion]
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ class BaseGraphRetriever(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_completion(self, query: str, context: Optional[List[Edge]] = None) -> str:
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None
|
||||
) -> str:
|
||||
"""Generates a response using the query and optional context (triplets)."""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ class BaseRetriever(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> Any:
|
||||
"""Generates a response using the query and optional context."""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -61,7 +61,9 @@ class ChunksRetriever(BaseRetriever):
|
|||
logger.info(f"Returning {len(chunk_payloads)} chunk payloads")
|
||||
return chunk_payloads
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> Any:
|
||||
"""
|
||||
Generates a completion using document chunks context.
|
||||
|
||||
|
|
@ -74,6 +76,8 @@ class ChunksRetriever(BaseRetriever):
|
|||
- query (str): The query string to be used for generating a completion.
|
||||
- context (Optional[Any]): Optional pre-fetched context to use for generating the
|
||||
completion; if None, it retrieves the context for the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
|
|||
|
|
@ -207,8 +207,26 @@ class CodeRetriever(BaseRetriever):
|
|||
logger.info(f"Returning {len(result)} code file contexts")
|
||||
return result
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
"""Returns the code files context."""
|
||||
async def get_completion(
    self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
    """
    Returns the code files context.

    Parameters:
    -----------

        - query (str): The query string to retrieve code context for.
        - context (Optional[Any]): Pre-fetched context; when None, the context is
          retrieved with get_context(query). (default None)
        - session_id (Optional[str]): Accepted as part of the get_completion
          signature; this method does not use it. (default None)

    Returns:
    --------

        - Any: The provided context, or the freshly retrieved one.
    """
    # Only a missing (None) context triggers retrieval; an empty one is returned as is.
    return context if context is not None else await self.get_context(query)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,18 @@
|
|||
import asyncio
|
||||
from typing import Any, Optional
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
save_conversation_history,
|
||||
get_conversation_history,
|
||||
)
|
||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
from cognee.context_global_variables import session_user
|
||||
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||
|
||||
logger = get_logger("CompletionRetriever")
|
||||
|
||||
|
|
@ -67,7 +74,9 @@ class CompletionRetriever(BaseRetriever):
|
|||
logger.error("DocumentChunk_text collection not found")
|
||||
raise NoDataError("No data found in the system, please add data first.") from error
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> str:
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generates an LLM completion using the context.
|
||||
|
||||
|
|
@ -80,6 +89,8 @@ class CompletionRetriever(BaseRetriever):
|
|||
- query (str): The query string to be used for generating a completion.
|
||||
- context (Optional[Any]): Optional pre-fetched context to use for generating the
|
||||
completion; if None, it retrieves the context for the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
@ -89,11 +100,41 @@ class CompletionRetriever(BaseRetriever):
|
|||
if context is None:
|
||||
context = await self.get_context(query)
|
||||
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
# Check if we need to generate context summary for caching
|
||||
cache_config = CacheConfig()
|
||||
user = session_user.get()
|
||||
user_id = getattr(user, "id", None)
|
||||
session_save = user_id and cache_config.caching
|
||||
|
||||
if session_save:
|
||||
conversation_history = await get_conversation_history(session_id=session_id)
|
||||
|
||||
context_summary, completion = await asyncio.gather(
|
||||
summarize_text(context),
|
||||
generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
conversation_history=conversation_history,
|
||||
),
|
||||
)
|
||||
else:
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
|
||||
if session_save:
|
||||
await save_conversation_history(
|
||||
query=query,
|
||||
context_summary=context_summary,
|
||||
answer=completion,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return completion
|
||||
|
|
|
|||
|
|
@ -44,13 +44,21 @@ class CypherSearchRetriever(BaseRetriever):
|
|||
"""
|
||||
try:
|
||||
graph_engine = await get_graph_engine()
|
||||
is_empty = await graph_engine.is_empty()
|
||||
|
||||
if is_empty:
|
||||
logger.warning("Search attempt on an empty knowledge graph")
|
||||
return []
|
||||
|
||||
result = await graph_engine.query(query)
|
||||
except Exception as e:
|
||||
logger.error("Failed to execture cypher search retrieval: %s", str(e))
|
||||
raise CypherSearchError() from e
|
||||
return result
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> Any:
|
||||
"""
|
||||
Returns the graph connections context.
|
||||
|
||||
|
|
@ -62,6 +70,8 @@ class CypherSearchRetriever(BaseRetriever):
|
|||
- query (str): The query to retrieve context.
|
||||
- context (Optional[Any]): Optional context to use, otherwise fetched using the
|
||||
query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
|
|||
|
|
@ -1,8 +1,15 @@
|
|||
import asyncio
|
||||
from typing import Optional, List, Type
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
save_conversation_history,
|
||||
get_conversation_history,
|
||||
)
|
||||
from cognee.context_global_variables import session_user
|
||||
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -47,6 +54,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
self,
|
||||
query: str,
|
||||
context: Optional[List[Edge]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
context_extension_rounds=4,
|
||||
) -> List[str]:
|
||||
"""
|
||||
|
|
@ -64,6 +72,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
- query (str): The input query for which the completion is generated.
|
||||
- context (Optional[Any]): The existing context to use for enhancing the query; if
|
||||
None, it will be initialized from triplets generated for the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
- context_extension_rounds: The maximum number of rounds to extend the context with
|
||||
new triplets before halting. (default 4)
|
||||
|
||||
|
|
@ -115,17 +125,46 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
|
||||
round_idx += 1
|
||||
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
# Check if we need to generate context summary for caching
|
||||
cache_config = CacheConfig()
|
||||
user = session_user.get()
|
||||
user_id = getattr(user, "id", None)
|
||||
session_save = user_id and cache_config.caching
|
||||
|
||||
if session_save:
|
||||
conversation_history = await get_conversation_history(session_id=session_id)
|
||||
|
||||
context_summary, completion = await asyncio.gather(
|
||||
summarize_text(context_text),
|
||||
generate_completion(
|
||||
query=query,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
conversation_history=conversation_history,
|
||||
),
|
||||
)
|
||||
else:
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
|
||||
if self.save_interaction and context_text and triplets and completion:
|
||||
await self.save_qa(
|
||||
question=query, answer=completion, context=context_text, triplets=triplets
|
||||
)
|
||||
|
||||
if session_save:
|
||||
await save_conversation_history(
|
||||
query=query,
|
||||
context_summary=context_summary,
|
||||
answer=completion,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -1,15 +1,41 @@
|
|||
import asyncio
|
||||
import json
|
||||
from typing import Optional, List, Type, Any
|
||||
from pydantic import BaseModel
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||
from cognee.modules.retrieval.utils.completion import (
|
||||
generate_structured_completion,
|
||||
summarize_text,
|
||||
)
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
save_conversation_history,
|
||||
get_conversation_history,
|
||||
)
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
||||
from cognee.context_global_variables import session_user
|
||||
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def _as_answer_text(completion: Any) -> str:
|
||||
"""Convert completion to human-readable text for validation and follow-up prompts."""
|
||||
if isinstance(completion, str):
|
||||
return completion
|
||||
if isinstance(completion, BaseModel):
|
||||
# Add notice that this is a structured response
|
||||
json_str = completion.model_dump_json(indent=2)
|
||||
return f"[Structured Response]\n{json_str}"
|
||||
try:
|
||||
return json.dumps(completion, indent=2)
|
||||
except TypeError:
|
||||
return str(completion)
|
||||
|
||||
|
||||
class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||
"""
|
||||
Handles graph completion by generating responses based on a series of interactions with
|
||||
|
|
@ -18,6 +44,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
questions based on reasoning. The public methods are:
|
||||
|
||||
- get_completion
|
||||
- get_structured_completion
|
||||
|
||||
Instance variables include:
|
||||
- validation_system_prompt_path
|
||||
|
|
@ -54,33 +81,30 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
self.followup_system_prompt_path = followup_system_prompt_path
|
||||
self.followup_user_prompt_path = followup_user_prompt_path
|
||||
|
||||
async def get_completion(
|
||||
async def _run_cot_completion(
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[List[Edge]] = None,
|
||||
max_iter=4,
|
||||
) -> List[str]:
|
||||
conversation_history: str = "",
|
||||
max_iter: int = 4,
|
||||
response_model: Type = str,
|
||||
) -> tuple[Any, str, List[Edge]]:
|
||||
"""
|
||||
Generate completion responses based on a user query and contextual information.
|
||||
|
||||
This method interacts with a language model client to retrieve a structured response,
|
||||
using a series of iterations to refine the answers and generate follow-up questions
|
||||
based on reasoning derived from previous outputs. It raises exceptions if the context
|
||||
retrieval fails or if the model encounters issues in generating outputs.
|
||||
Run chain-of-thought completion with optional structured output.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The user's query to be processed and answered.
|
||||
- context (Optional[Any]): Optional context that may assist in answering the query.
|
||||
If not provided, it will be fetched based on the query. (default None)
|
||||
- max_iter: The maximum number of iterations to refine the answer and generate
|
||||
follow-up questions. (default 4)
|
||||
- query: User query
|
||||
- context: Optional pre-fetched context edges
|
||||
- conversation_history: Optional conversation history string
|
||||
- max_iter: Maximum CoT iterations
|
||||
- response_model: Type for structured output (str for plain text)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- List[str]: A list containing the generated answer to the user's query.
|
||||
- completion_result: The generated completion (string or structured model)
|
||||
- context_text: The resolved context text
|
||||
- triplets: The list of triplets used
|
||||
"""
|
||||
followup_question = ""
|
||||
triplets = []
|
||||
|
|
@ -97,16 +121,21 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
triplets += await self.get_context(followup_question)
|
||||
context_text = await self.resolve_edges_to_text(list(set(triplets)))
|
||||
|
||||
completion = await generate_completion(
|
||||
completion = await generate_structured_completion(
|
||||
query=query,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
conversation_history=conversation_history if conversation_history else None,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
||||
|
||||
if round_idx < max_iter:
|
||||
valid_args = {"query": query, "answer": completion, "context": context_text}
|
||||
answer_text = _as_answer_text(completion)
|
||||
valid_args = {"query": query, "answer": answer_text, "context": context_text}
|
||||
valid_user_prompt = render_prompt(
|
||||
filename=self.validation_user_prompt_path, context=valid_args
|
||||
)
|
||||
|
|
@ -119,7 +148,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
system_prompt=valid_system_prompt,
|
||||
response_model=str,
|
||||
)
|
||||
followup_args = {"query": query, "answer": completion, "reasoning": reasoning}
|
||||
followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning}
|
||||
followup_prompt = render_prompt(
|
||||
filename=self.followup_user_prompt_path, context=followup_args
|
||||
)
|
||||
|
|
@ -134,9 +163,110 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
|
||||
)
|
||||
|
||||
return completion, context_text, triplets
|
||||
|
||||
async def get_structured_completion(
    self,
    query: str,
    context: Optional[List[Edge]] = None,
    session_id: Optional[str] = None,
    max_iter: int = 4,
    response_model: Type = str,
) -> Any:
    """
    Run the chain-of-thought loop and return its answer in the requested shape.

    Wraps _run_cot_completion with the optional session-cache steps: the stored
    conversation history is loaded before the loop and the new exchange is written
    back afterwards. Both steps run only when the current session user has an id and
    caching is enabled in CacheConfig.

    Parameters:
    -----------

        - query (str): The user's query to be processed and answered.
        - context (Optional[List[Edge]]): Pre-fetched context edges, forwarded to
          _run_cot_completion. (default None)
        - session_id (Optional[str]): Session identifier passed to the session-cache
          helpers. (default None)
        - max_iter (int): Maximum number of chain-of-thought iterations, forwarded to
          _run_cot_completion. (default 4)
        - response_model (Type): Type requested for the completion. (default str)

    Returns:
    --------

        - Any: The completion produced by _run_cot_completion.
    """
    # Decide once whether this request participates in session caching.
    cache_config = CacheConfig()
    current_user = session_user.get()
    user_id = getattr(current_user, "id", None)
    session_save = user_id and cache_config.caching

    # History is only fetched for cached sessions; otherwise start from an empty one.
    history = await get_conversation_history(session_id=session_id) if session_save else ""

    cot_result = await self._run_cot_completion(
        query=query,
        context=context,
        conversation_history=history,
        max_iter=max_iter,
        response_model=response_model,
    )
    completion, context_text, triplets = cot_result

    # Note: the check uses the caller-supplied `context` argument, not `context_text`.
    should_record_interaction = self.save_interaction and context and triplets and completion
    if should_record_interaction:
        await self.save_qa(
            question=query,
            answer=str(completion),
            context=context_text,
            triplets=triplets,
        )

    if session_save:
        # Store a summary of the context, not the full text, with the exchange.
        summary = await summarize_text(context_text)
        await save_conversation_history(
            query=query,
            context_summary=summary,
            answer=str(completion),
            session_id=session_id,
        )

    return completion
|
||||
|
||||
async def get_completion(
    self,
    query: str,
    context: Optional[List[Edge]] = None,
    session_id: Optional[str] = None,
    max_iter=4,
) -> List[str]:
    """
    Answer a query with the chain-of-thought flow and return plain text.

    Delegates to get_structured_completion with response_model=str and wraps the
    single answer in a list.

    Parameters:
    -----------

        - query (str): The user's query to be processed and answered.
        - context (Optional[List[Edge]]): Pre-fetched context, passed through
          unchanged. (default None)
        - session_id (Optional[str]): Session identifier, passed through unchanged.
          (default None)
        - max_iter: Maximum number of chain-of-thought iterations, passed through
          unchanged. (default 4)

    Returns:
    --------

        - List[str]: A one-element list holding the generated answer.
    """
    delegated_arguments = {
        "query": query,
        "context": context,
        "session_id": session_id,
        "max_iter": max_iter,
        "response_model": str,
    }
    answer = await self.get_structured_completion(**delegated_arguments)
    return [answer]
|
||||
|
|
|
|||
|
|
@ -1,20 +1,26 @@
|
|||
import asyncio
|
||||
from typing import Any, Optional, Type, List
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||||
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
save_conversation_history,
|
||||
get_conversation_history,
|
||||
)
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
|
||||
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.context_global_variables import session_user
|
||||
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||
|
||||
logger = get_logger("GraphCompletionRetriever")
|
||||
|
||||
|
|
@ -118,6 +124,13 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
- str: A string representing the resolved context from the retrieved triplets, or an
|
||||
empty string if no triplets are found.
|
||||
"""
|
||||
graph_engine = await get_graph_engine()
|
||||
is_empty = await graph_engine.is_empty()
|
||||
|
||||
if is_empty:
|
||||
logger.warning("Search attempt on an empty knowledge graph")
|
||||
return []
|
||||
|
||||
triplets = await self.get_triplets(query)
|
||||
|
||||
if len(triplets) == 0:
|
||||
|
|
@ -132,6 +145,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
self,
|
||||
query: str,
|
||||
context: Optional[List[Edge]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generates a completion using graph connections context based on a query.
|
||||
|
|
@ -142,6 +156,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
- query (str): The query string for which a completion is generated.
|
||||
- context (Optional[Any]): Optional context to use for generating the completion; if
|
||||
not provided, context is retrieved based on the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
@ -155,19 +171,47 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
|
||||
context_text = await resolve_edges_to_text(triplets)
|
||||
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
cache_config = CacheConfig()
|
||||
user = session_user.get()
|
||||
user_id = getattr(user, "id", None)
|
||||
session_save = user_id and cache_config.caching
|
||||
|
||||
if session_save:
|
||||
conversation_history = await get_conversation_history(session_id=session_id)
|
||||
|
||||
context_summary, completion = await asyncio.gather(
|
||||
summarize_text(context_text),
|
||||
generate_completion(
|
||||
query=query,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
conversation_history=conversation_history,
|
||||
),
|
||||
)
|
||||
else:
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
|
||||
if self.save_interaction and context and triplets and completion:
|
||||
await self.save_qa(
|
||||
question=query, answer=completion, context=context_text, triplets=triplets
|
||||
)
|
||||
|
||||
if session_save:
|
||||
await save_conversation_history(
|
||||
query=query,
|
||||
context_summary=context_summary,
|
||||
answer=completion,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return [completion]
|
||||
|
||||
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
|
||||
|
|
|
|||
|
|
@ -116,8 +116,26 @@ class LexicalRetriever(BaseRetriever):
|
|||
else:
|
||||
return [self.payloads[chunk_id] for chunk_id, _ in top_results]
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
"""Returns context for the given query (retrieves if not provided)."""
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> Any:
|
||||
"""
|
||||
Returns context for the given query (retrieves if not provided).
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The query string to retrieve context for.
|
||||
- context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
|
||||
the context for the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Any: The context, either provided or retrieved.
|
||||
"""
|
||||
if context is None:
|
||||
context = await self.get_context(query)
|
||||
return context
|
||||
|
|
|
|||
|
|
@ -122,10 +122,17 @@ class NaturalLanguageRetriever(BaseRetriever):
|
|||
query.
|
||||
"""
|
||||
graph_engine = await get_graph_engine()
|
||||
is_empty = await graph_engine.is_empty()
|
||||
|
||||
if is_empty:
|
||||
logger.warning("Search attempt on an empty knowledge graph")
|
||||
return []
|
||||
|
||||
return await self._execute_cypher_query(query, graph_engine)
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> Any:
|
||||
"""
|
||||
Returns a completion based on the query and context.
|
||||
|
||||
|
|
@ -139,6 +146,8 @@ class NaturalLanguageRetriever(BaseRetriever):
|
|||
- query (str): The natural language query to get a completion from.
|
||||
- context (Optional[Any]): The context in which to base the completion; if not
|
||||
provided, it will be retrieved using the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
|
|||
|
|
@ -62,7 +62,9 @@ class SummariesRetriever(BaseRetriever):
|
|||
logger.info(f"Returning {len(summary_payloads)} summary payloads")
|
||||
return summary_payloads
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, **kwargs) -> Any:
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
Generates a completion using summaries context.
|
||||
|
||||
|
|
@ -75,6 +77,8 @@ class SummariesRetriever(BaseRetriever):
|
|||
- query (str): The search query for generating the completion.
|
||||
- context (Optional[Any]): Optional context for the completion; if not provided,
|
||||
will be retrieved based on the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
|
|||
|
|
@ -1,16 +1,22 @@
|
|||
import os
|
||||
import asyncio
|
||||
from typing import Any, Optional, List, Type
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from operator import itemgetter
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
save_conversation_history,
|
||||
get_conversation_history,
|
||||
)
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.llm.prompts import render_prompt
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from cognee.context_global_variables import session_user
|
||||
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||
|
||||
from cognee.tasks.temporal_graph.models import QueryInterval
|
||||
|
||||
|
|
@ -73,7 +79,11 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
else:
|
||||
base_directory = None
|
||||
|
||||
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
|
||||
time_now = datetime.now().strftime("%d-%m-%Y")
|
||||
|
||||
system_prompt = render_prompt(
|
||||
prompt_path, {"time_now": time_now}, base_directory=base_directory
|
||||
)
|
||||
|
||||
interval = await LLMGateway.acreate_structured_output(query, system_prompt, QueryInterval)
|
||||
|
||||
|
|
@ -102,8 +112,6 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
triplets = []
|
||||
|
||||
if time_from and time_to:
|
||||
ids = await graph_engine.collect_time_ids(time_from=time_from, time_to=time_to)
|
||||
elif time_from:
|
||||
|
|
@ -137,17 +145,63 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
|
||||
return self.descriptions_to_string(top_k_events)
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[str] = None) -> List[str]:
|
||||
"""Generates a response using the query and optional context."""
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[str] = None, session_id: Optional[str] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generates a response using the query and optional context.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The query string for which a completion is generated.
|
||||
- context (Optional[str]): Optional context to use; if None, it will be
|
||||
retrieved based on the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- List[str]: A list containing the generated completion.
|
||||
"""
|
||||
if not context:
|
||||
context = await self.get_context(query=query)
|
||||
|
||||
if context:
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
)
|
||||
# Check if we need to generate context summary for caching
|
||||
cache_config = CacheConfig()
|
||||
user = session_user.get()
|
||||
user_id = getattr(user, "id", None)
|
||||
session_save = user_id and cache_config.caching
|
||||
|
||||
if session_save:
|
||||
conversation_history = await get_conversation_history(session_id=session_id)
|
||||
|
||||
context_summary, completion = await asyncio.gather(
|
||||
summarize_text(context),
|
||||
generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
conversation_history=conversation_history,
|
||||
),
|
||||
)
|
||||
else:
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
)
|
||||
|
||||
if session_save:
|
||||
await save_conversation_history(
|
||||
query=query,
|
||||
context_summary=context_summary,
|
||||
answer=completion,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ async def get_memory_fragment(
|
|||
await memory_fragment.project_graph_from_db(
|
||||
graph_engine,
|
||||
node_properties_to_project=properties_to_project,
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
edge_properties_to_project=["relationship_name", "edge_text"],
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,23 +1,49 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, Type, Any
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
||||
|
||||
|
||||
async def generate_structured_completion(
    query: str,
    context: str,
    user_prompt_path: str,
    system_prompt_path: str,
    system_prompt: Optional[str] = None,
    conversation_history: Optional[str] = None,
    response_model: Type = str,
) -> Any:
    """
    Generate a completion from the LLM using the given context and prompts.

    Parameters:
    -----------

        - query (str): The question, rendered into the user prompt as "question".
        - context (str): The context, rendered into the user prompt as "context".
        - user_prompt_path (str): Template passed to render_prompt for the user prompt.
        - system_prompt_path (str): Prompt read with read_query_prompt when no
          system_prompt is supplied.
        - system_prompt (Optional[str]): Explicit system prompt; used instead of
          system_prompt_path when non-empty. (default None)
        - conversation_history (Optional[str]): When non-empty, prepended to the
          system prompt and separated from it by "\\nTASK:". (default None)
        - response_model (Type): Type handed to the LLM gateway as the response
          model. (default str)

    Returns:
    --------

        - Any: Whatever LLMGateway.acreate_structured_output returns.
    """
    rendered_user_prompt = render_prompt(
        user_prompt_path,
        {"question": query, "context": context},
    )

    # An explicit, non-empty system prompt wins over the one stored at the path.
    effective_system_prompt = system_prompt or read_query_prompt(system_prompt_path)

    if conversation_history:
        # TODO: consider keeping the history separate from the task prompt; needs
        # testing to see what works best with longer conversations.
        effective_system_prompt = conversation_history + "\nTASK:" + effective_system_prompt

    result = await LLMGateway.acreate_structured_output(
        text_input=rendered_user_prompt,
        system_prompt=effective_system_prompt,
        response_model=response_model,
    )
    return result
|
||||
|
||||
|
||||
async def generate_completion(
|
||||
query: str,
|
||||
context: str,
|
||||
user_prompt_path: str,
|
||||
system_prompt_path: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
conversation_history: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Generates a completion using LLM with given context and prompts."""
|
||||
args = {"question": query, "context": context}
|
||||
user_prompt = render_prompt(user_prompt_path, args)
|
||||
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
|
||||
|
||||
return await LLMGateway.acreate_structured_output(
|
||||
text_input=user_prompt,
|
||||
return await generate_structured_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=user_prompt_path,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
conversation_history=conversation_history,
|
||||
response_model=str,
|
||||
)
|
||||
|
||||
|
|
|
|||
156
cognee/modules/retrieval/utils/session_cache.py
Normal file
156
cognee/modules/retrieval/utils/session_cache.py
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
from typing import Optional, List, Dict, Any
|
||||
from cognee.context_global_variables import session_user
|
||||
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||
from cognee.infrastructure.databases.exceptions import CacheConnectionError
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger("session_cache")
|
||||
|
||||
|
||||
async def save_conversation_history(
    query: str,
    context_summary: str,
    answer: str,
    session_id: Optional[str] = None,
) -> bool:
    """
    Store one question/answer exchange in the session cache.

    Nothing is written unless the current session user has an id and caching is
    enabled in CacheConfig. Every failure is logged and reported through the return
    value; no exception leaves this function.

    Parameters:
    -----------

        - query (str): The user's query/question.
        - context_summary (str): Summarized context used for generating the answer.
        - answer (str): The generated answer/completion.
        - session_id (Optional[str]): Session identifier; 'default_session' is used
          when None. (default None)

    Returns:
    --------

        - bool: True if the exchange was saved to the cache, False otherwise.
    """
    try:
        config = CacheConfig()
        user_id = getattr(session_user.get(), "id", None)

        if not user_id or not config.caching:
            logger.debug("Session caching disabled or user not authenticated")
            return False

        session_id = "default_session" if session_id is None else session_id

        # Imported here, after the caching check, to match the original's lazy import.
        from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine

        engine = get_cache_engine()
        if engine is None:
            logger.warning("Cache engine not available, skipping session save")
            return False

        await engine.add_qa(
            str(user_id),
            session_id=session_id,
            question=query,
            context=context_summary,
            answer=answer,
        )
    except CacheConnectionError as error:
        # A missing cache backend is an expected condition, hence only a warning.
        logger.warning(f"Cache unavailable, continuing without session save: {error.message}")
        return False
    except Exception as error:
        logger.error(
            f"Unexpected error saving to session cache: {type(error).__name__}: {str(error)}. Continuing without caching."
        )
        return False
    else:
        logger.info(
            f"Successfully saved Q&A to session cache: user_id={user_id}, session_id={session_id}"
        )
        return True
|
||||
|
||||
|
||||
async def get_conversation_history(
    session_id: Optional[str] = None,
) -> str:
    """
    Fetch the cached Q&A entries for the current user's session and render them as text.

    The entries come from `cache_engine.get_latest_qa(user_id, session_id)`; how many
    are returned is decided by that call. An empty string is returned when no user id
    is available, caching is disabled, no cache engine exists, no entries are found,
    or any error occurs.

    Parameters:
    -----------

        - session_id (Optional[str]): Session to read from. None is replaced with
          'default_session'.

    Returns:
    --------

        - str: The rendered history, or an empty string.

    Format:
    -------

        Previous conversation:

        [2024-01-15 10:30:45]
        QUESTION: What is X?
        CONTEXT: X is a concept...
        ANSWER: X is...

        [2024-01-15 10:31:20]
        QUESTION: How does Y work?
        CONTEXT: Y is related to...
        ANSWER: Y works by...
    """
    try:
        cache_config = CacheConfig()
        user_id = getattr(session_user.get(), "id", None)

        # Both a known user and enabled caching are required to read anything.
        if not user_id or not cache_config.caching:
            logger.debug("Session caching disabled or user not authenticated")
            return ""

        session_id = "default_session" if session_id is None else session_id

        # Resolved at call time rather than at module import.
        from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine

        engine = get_cache_engine()
        if engine is None:
            logger.warning("Cache engine not available, skipping conversation history retrieval")
            return ""

        history_entries = await engine.get_latest_qa(str(user_id), session_id)
        if not history_entries:
            logger.debug("No conversation history found")
            return ""

        # Collect the pieces and join once; missing fields fall back to placeholders.
        pieces = ["Previous conversation:\n\n"]
        for entry in history_entries:
            pieces.extend(
                (
                    f"[{entry.get('time', 'Unknown time')}]\n",
                    f"QUESTION: {entry.get('question', '')}\n",
                    f"CONTEXT: {entry.get('context', '')}\n",
                    f"ANSWER: {entry.get('answer', '')}\n\n",
                )
            )
        history_text = "".join(pieces)

        logger.debug(f"Retrieved {len(history_entries)} conversation history entries")
        return history_text

    except CacheConnectionError as e:
        # The cache backend could not be reached: continue with no history.
        logger.warning(f"Cache unavailable, continuing without conversation history: {e.message}")
        return ""

    except Exception as e:
        # Any other failure is also non-fatal for the caller.
        logger.warning(
            f"Unexpected error retrieving conversation history: {type(e).__name__}: {str(e)}"
        )
        return ""
|
||||
|
|
@ -1,12 +1,16 @@
|
|||
from typing import Any, List, Optional, Tuple, Type, Union
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.data.models.Dataset import Dataset
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from .get_search_type_tools import get_search_type_tools
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def no_access_control_search(
|
||||
query_type: SearchType,
|
||||
|
|
@ -19,6 +23,7 @@ async def no_access_control_search(
|
|||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
|
||||
search_tools = await get_search_type_tools(
|
||||
query_type=query_type,
|
||||
|
|
@ -31,6 +36,12 @@ async def no_access_control_search(
|
|||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
)
|
||||
graph_engine = await get_graph_engine()
|
||||
is_empty = await graph_engine.is_empty()
|
||||
|
||||
if is_empty:
|
||||
# TODO: we can log here, but not all search types use graph. Still keeping this here for reviewer input
|
||||
logger.warning("Search attempt on an empty knowledge graph")
|
||||
if len(search_tools) == 2:
|
||||
[get_completion, get_context] = search_tools
|
||||
|
||||
|
|
@ -38,7 +49,7 @@ async def no_access_control_search(
|
|||
return None, await get_context(query_text), []
|
||||
|
||||
context = await get_context(query_text)
|
||||
result = await get_completion(query_text, context)
|
||||
result = await get_completion(query_text, context, session_id=session_id)
|
||||
else:
|
||||
unknown_tool = search_tools[0]
|
||||
result = await unknown_tool(query_text)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ from uuid import UUID
|
|||
from fastapi.encoders import jsonable_encoder
|
||||
from typing import Any, List, Optional, Tuple, Type, Union
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.context_global_variables import set_database_global_context_variables
|
||||
|
||||
|
|
@ -22,11 +24,13 @@ from cognee.modules.data.models import Dataset
|
|||
from cognee.modules.data.methods.get_authorized_existing_datasets import (
|
||||
get_authorized_existing_datasets,
|
||||
)
|
||||
|
||||
from cognee import __version__ as cognee_version
|
||||
from .get_search_type_tools import get_search_type_tools
|
||||
from .no_access_control_search import no_access_control_search
|
||||
from ..utils.prepare_search_result import prepare_search_result
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def search(
|
||||
query_text: str,
|
||||
|
|
@ -42,6 +46,7 @@ async def search(
|
|||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Union[CombinedSearchResult, List[SearchResult]]:
|
||||
"""
|
||||
|
||||
|
|
@ -59,7 +64,14 @@ async def search(
|
|||
Searching by dataset is only available in ENABLE_BACKEND_ACCESS_CONTROL mode
|
||||
"""
|
||||
query = await log_query(query_text, query_type.value, user.id)
|
||||
send_telemetry("cognee.search EXECUTION STARTED", user.id)
|
||||
send_telemetry(
|
||||
"cognee.search EXECUTION STARTED",
|
||||
user.id,
|
||||
additional_properties={
|
||||
"cognee_version": cognee_version,
|
||||
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
|
||||
},
|
||||
)
|
||||
|
||||
# Use search function filtered by permissions if access control is enabled
|
||||
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||
|
|
@ -77,6 +89,7 @@ async def search(
|
|||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
use_combined_context=use_combined_context,
|
||||
session_id=session_id,
|
||||
)
|
||||
else:
|
||||
search_results = [
|
||||
|
|
@ -91,10 +104,18 @@ async def search(
|
|||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
session_id=session_id,
|
||||
)
|
||||
]
|
||||
|
||||
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
|
||||
send_telemetry(
|
||||
"cognee.search EXECUTION COMPLETED",
|
||||
user.id,
|
||||
additional_properties={
|
||||
"cognee_version": cognee_version,
|
||||
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
|
||||
},
|
||||
)
|
||||
|
||||
await log_result(
|
||||
query.id,
|
||||
|
|
@ -195,6 +216,7 @@ async def authorized_search(
|
|||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Union[
|
||||
Tuple[Any, Union[List[Edge], str], List[Dataset]],
|
||||
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
|
||||
|
|
@ -221,6 +243,7 @@ async def authorized_search(
|
|||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
context = {}
|
||||
|
|
@ -263,7 +286,7 @@ async def authorized_search(
|
|||
return combined_context
|
||||
|
||||
combined_context = prepare_combined_context(context)
|
||||
completion = await get_completion(query_text, combined_context)
|
||||
completion = await get_completion(query_text, combined_context, session_id=session_id)
|
||||
|
||||
return completion, combined_context, datasets
|
||||
|
||||
|
|
@ -280,6 +303,7 @@ async def authorized_search(
|
|||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return search_results
|
||||
|
|
@ -298,6 +322,7 @@ async def search_in_datasets_context(
|
|||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
context: Optional[Any] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
|
||||
"""
|
||||
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
|
||||
|
|
@ -317,10 +342,30 @@ async def search_in_datasets_context(
|
|||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
context: Optional[Any] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
|
||||
# Set database configuration in async context for each dataset user has access for
|
||||
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
is_empty = await graph_engine.is_empty()
|
||||
|
||||
if is_empty:
|
||||
# TODO: we can log here, but not all search types use graph. Still keeping this here for reviewer input
|
||||
from cognee.modules.data.methods import get_dataset_data
|
||||
|
||||
dataset_data = await get_dataset_data(dataset.id)
|
||||
|
||||
if len(dataset_data) > 0:
|
||||
logger.warning(
|
||||
f"Dataset '{dataset.name}' has {len(dataset_data)} data item(s) but the knowledge graph is empty. "
|
||||
"Please run cognify to process the data before searching."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Search attempt on an empty knowledge graph - no data has been added to this dataset"
|
||||
)
|
||||
|
||||
specific_search_tools = await get_search_type_tools(
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
|
|
@ -340,7 +385,7 @@ async def search_in_datasets_context(
|
|||
return None, await get_context(query_text), [dataset]
|
||||
|
||||
search_context = context or await get_context(query_text)
|
||||
search_result = await get_completion(query_text, search_context)
|
||||
search_result = await get_completion(query_text, search_context, session_id=session_id)
|
||||
|
||||
return search_result, search_context, [dataset]
|
||||
else:
|
||||
|
|
@ -365,6 +410,7 @@ async def search_in_datasets_context(
|
|||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
context=context,
|
||||
session_id=session_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,12 +27,7 @@ async def get_default_user() -> SimpleNamespace:
|
|||
if user is None:
|
||||
return await create_default_user()
|
||||
|
||||
# We return a SimpleNamespace to have the same user type as our SaaS
|
||||
# SimpleNamespace is just a dictionary which can be accessed through attributes
|
||||
auth_data = SimpleNamespace(
|
||||
id=user.id, email=user.email, tenant_id=user.tenant_id, roles=[]
|
||||
)
|
||||
return auth_data
|
||||
return user
|
||||
except Exception as error:
|
||||
if "principals" in str(error.args):
|
||||
raise DatabaseNotCreatedError() from error
|
||||
|
|
|
|||
|
|
@ -16,17 +16,17 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
|
|||
|
||||
nodes_list = []
|
||||
color_map = {
|
||||
"Entity": "#f47710",
|
||||
"EntityType": "#6510f4",
|
||||
"DocumentChunk": "#801212",
|
||||
"TextSummary": "#1077f4",
|
||||
"TableRow": "#f47710",
|
||||
"TableType": "#6510f4",
|
||||
"ColumnValue": "#13613a",
|
||||
"SchemaTable": "#f47710",
|
||||
"DatabaseSchema": "#6510f4",
|
||||
"SchemaRelationship": "#13613a",
|
||||
"default": "#D3D3D3",
|
||||
"Entity": "#5C10F4",
|
||||
"EntityType": "#A550FF",
|
||||
"DocumentChunk": "#0DFF00",
|
||||
"TextSummary": "#5C10F4",
|
||||
"TableRow": "#A550FF",
|
||||
"TableType": "#5C10F4",
|
||||
"ColumnValue": "#757470",
|
||||
"SchemaTable": "#A550FF",
|
||||
"DatabaseSchema": "#5C10F4",
|
||||
"SchemaRelationship": "#323332",
|
||||
"default": "#D8D8D8",
|
||||
}
|
||||
|
||||
for node_id, node_info in nodes_data:
|
||||
|
|
@ -98,16 +98,19 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
|
|||
<head>
|
||||
<meta charset="utf-8">
|
||||
<script src="https://d3js.org/d3.v5.min.js"></script>
|
||||
<script src="https://d3js.org/d3-contour.v1.min.js"></script>
|
||||
<style>
|
||||
body, html { margin: 0; padding: 0; width: 100%; height: 100%; overflow: hidden; background: linear-gradient(90deg, #101010, #1a1a2e); color: white; font-family: 'Inter', sans-serif; }
|
||||
|
||||
svg { width: 100vw; height: 100vh; display: block; }
|
||||
.links line { stroke: rgba(255, 255, 255, 0.4); stroke-width: 2px; }
|
||||
.links line.weighted { stroke: rgba(255, 215, 0, 0.7); }
|
||||
.links line.multi-weighted { stroke: rgba(0, 255, 127, 0.8); }
|
||||
.nodes circle { stroke: white; stroke-width: 0.5px; filter: drop-shadow(0 0 5px rgba(255,255,255,0.3)); }
|
||||
.node-label { font-size: 5px; font-weight: bold; fill: white; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
|
||||
.edge-label { font-size: 3px; fill: rgba(255, 255, 255, 0.7); text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
|
||||
.links line { stroke: rgba(160, 160, 160, 0.25); stroke-width: 1.5px; stroke-linecap: round; }
|
||||
.links line.weighted { stroke: rgba(255, 215, 0, 0.4); }
|
||||
.links line.multi-weighted { stroke: rgba(0, 255, 127, 0.45); }
|
||||
.nodes circle { stroke: white; stroke-width: 0.5px; }
|
||||
.node-label { font-size: 5px; font-weight: bold; fill: #F4F4F4; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
|
||||
.edge-label { font-size: 3px; fill: #F4F4F4; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; paint-order: stroke; stroke: rgba(50,51,50,0.75); stroke-width: 1px; }
|
||||
|
||||
.density path { mix-blend-mode: screen; }
|
||||
|
||||
.tooltip {
|
||||
position: absolute;
|
||||
|
|
@ -125,11 +128,32 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
|
|||
max-width: 300px;
|
||||
word-wrap: break-word;
|
||||
}
|
||||
#info-panel {
|
||||
position: fixed;
|
||||
left: 12px;
|
||||
top: 12px;
|
||||
width: 340px;
|
||||
max-height: calc(100vh - 24px);
|
||||
overflow: auto;
|
||||
background: rgba(50, 51, 50, 0.7);
|
||||
backdrop-filter: blur(6px);
|
||||
border: 1px solid rgba(216, 216, 216, 0.35);
|
||||
border-radius: 8px;
|
||||
color: #F4F4F4;
|
||||
padding: 12px 14px;
|
||||
z-index: 1100;
|
||||
}
|
||||
#info-panel h3 { margin: 0 0 8px 0; font-size: 14px; color: #F4F4F4; }
|
||||
#info-panel .kv { font-size: 12px; line-height: 1.4; }
|
||||
#info-panel .kv .k { color: #D8D8D8; }
|
||||
#info-panel .kv .v { color: #F4F4F4; }
|
||||
#info-panel .placeholder { opacity: 0.7; font-size: 12px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<svg></svg>
|
||||
<div class="tooltip" id="tooltip"></div>
|
||||
<div id="info-panel"><div class="placeholder">Hover a node or edge to inspect details</div></div>
|
||||
<script>
|
||||
var nodes = {nodes};
|
||||
var links = {links};
|
||||
|
|
@ -140,19 +164,141 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
|
|||
|
||||
var container = svg.append("g");
|
||||
var tooltip = d3.select("#tooltip");
|
||||
var infoPanel = d3.select('#info-panel');
|
||||
|
||||
// Render a heading plus a list of key/value rows into the #info-panel element.
//   title   - text shown in the <h3> heading.
//   entries - array of {k, v} objects; each becomes one "k: v" row.
// All dynamic text is HTML-escaped before being written with infoPanel.html(),
// so node/edge data containing markup is displayed literally instead of being
// interpreted by the browser.
function renderInfo(title, entries){
  // Replace the HTML metacharacters with their entities. '&' must go first so
  // the ampersands introduced by the later replacements are not escaped again.
  function esc(s){
    return String(s)
      .replace(/&/g, '&amp;')
      .replace(/</g, '&lt;')
      .replace(/>/g, '&gt;');
  }
  var html = '<h3>' + esc(title) + '</h3>';
  html += '<div class="kv">';
  entries.forEach(function(e){
    html += '<div><span class="k">' + esc(e.k) + ':</span> <span class="v">' + esc(e.v) + '</span></div>';
  });
  html += '</div>';
  infoPanel.html(html);
}
|
||||
// Return the first non-blank string found under 'description', 'summary',
// 'text' or 'content' (checked in that order), trimmed; null when obj is
// falsy or none of those keys holds a non-blank string.
function pickDescription(obj){
  if (!obj) { return null; }
  var candidateKeys = ['description', 'summary', 'text', 'content'];
  var match = candidateKeys.find(function(key){
    var value = obj[key];
    return typeof value === 'string' && value.trim();
  });
  return match === undefined ? null : obj[match].trim();
}
|
||||
// Cut s down to its first n characters followed by an ellipsis when it is
// longer than n; falsy or short-enough values are returned unchanged.
function truncate(s, n){
  if (s && s.length > n) {
    return s.slice(0, n) + '…';
  }
  return s;
}
|
||||
// Show a node's details in the info panel: name, type, id, a shortened
// description (taken from the node itself or, failing that, its properties)
// and up to the first 12 property keys whose values are non-null scalars.
function renderNodeInfo(n){
  var rows = [];

  // Fixed header rows, each added only when the value is truthy.
  [['Name', n.name], ['Type', n.type], ['ID', n.id]].forEach(function(pair){
    if (pair[1]) { rows.push({k: pair[0], v: pair[1]}); }
  });

  var description = pickDescription(n) || pickDescription(n.properties);
  if (description) {
    // Collapse whitespace runs so the 280-character cap applies to readable text.
    rows.push({k: 'Description', v: truncate(description.replace(/\s+/g,' ').trim(), 280)});
  }

  var props = n.properties;
  if (props) {
    Object.keys(props).slice(0, 12)
      .filter(function(key){
        var value = props[key];
        return value !== undefined && value !== null && typeof value !== 'object';
      })
      .forEach(function(key){
        rows.push({k: key, v: String(props[key])});
      });
  }

  renderInfo(n.name || 'Node', rows);
}
|
||||
// Show an edge's details in the info panel: relation, weight, up to the first
// 8 entries of all_weights (prefixed "w."), relationship type and a shortened
// description taken from edge_info.
function renderEdgeInfo(e){
  var rows = [];
  function addRow(label, value){ rows.push({k: label, v: value}); }

  if (e.relation) { addRow('Relation', e.relation); }
  // Loose comparison keeps a weight of 0 while skipping both null and undefined.
  if (e.weight != null) { addRow('Weight', e.weight); }

  var weights = e.all_weights;
  if (weights) {
    Object.keys(weights).slice(0, 8).forEach(function(name){
      addRow('w.' + name, weights[name]);
    });
  }

  if (e.relationship_type) { addRow('Type', e.relationship_type); }

  var description = pickDescription(e.edge_info);
  if (description) {
    addRow('Description', truncate(description.replace(/\s+/g,' ').trim(), 280));
  }

  renderInfo('Edge', rows);
}
|
||||
|
||||
// Basic runtime diagnostics
|
||||
console.log('[Cognee Visualization] nodes:', nodes ? nodes.length : 0, 'links:', links ? links.length : 0);
|
||||
window.addEventListener('error', function(e){
|
||||
try {
|
||||
tooltip.html('<strong>Error:</strong> ' + e.message)
|
||||
.style('left', '12px')
|
||||
.style('top', '12px')
|
||||
.style('opacity', 1);
|
||||
} catch(_) {}
|
||||
});
|
||||
|
||||
// Normalize node IDs and link endpoints for robustness
|
||||
// Return the first truthy value among d.id, d.node_id, d.uuid, d.external_id
// and d.name (in that order); undefined when d is falsy or none is truthy.
function resolveId(d){
  if (!d) { return undefined; }
  var idKeys = ['id', 'node_id', 'uuid', 'external_id', 'name'];
  for (var i = 0; i < idKeys.length; i++) {
    var candidate = d[idKeys[i]];
    if (candidate) { return candidate; }
  }
  return undefined;
}
|
||||
if (Array.isArray(nodes)) {
|
||||
nodes.forEach(function(n){ var id = resolveId(n); if (id !== undefined) n.id = id; });
|
||||
}
|
||||
if (Array.isArray(links)) {
|
||||
links.forEach(function(l){
|
||||
if (typeof l.source === 'object') l.source = resolveId(l.source);
|
||||
if (typeof l.target === 'object') l.target = resolveId(l.target);
|
||||
});
|
||||
}
|
||||
|
||||
if (!nodes || nodes.length === 0) {
|
||||
container.append('text')
|
||||
.attr('x', width / 2)
|
||||
.attr('y', height / 2)
|
||||
.attr('fill', '#fff')
|
||||
.attr('font-size', 14)
|
||||
.attr('text-anchor', 'middle')
|
||||
.text('No graph data available');
|
||||
}
|
||||
|
||||
// Visual defs - reusable glow
|
||||
var defs = svg.append("defs");
|
||||
var glow = defs.append("filter").attr("id", "glow")
|
||||
.attr("x", "-30%")
|
||||
.attr("y", "-30%")
|
||||
.attr("width", "160%")
|
||||
.attr("height", "160%");
|
||||
glow.append("feGaussianBlur").attr("stdDeviation", 8).attr("result", "coloredBlur");
|
||||
var feMerge = glow.append("feMerge");
|
||||
feMerge.append("feMergeNode").attr("in", "coloredBlur");
|
||||
feMerge.append("feMergeNode").attr("in", "SourceGraphic");
|
||||
|
||||
// Stronger glow for hovered adjacency
|
||||
var glowStrong = defs.append("filter").attr("id", "glow-strong")
|
||||
.attr("x", "-40%")
|
||||
.attr("y", "-40%")
|
||||
.attr("width", "180%")
|
||||
.attr("height", "180%");
|
||||
glowStrong.append("feGaussianBlur").attr("stdDeviation", 14).attr("result", "coloredBlur");
|
||||
var feMerge2 = glowStrong.append("feMerge");
|
||||
feMerge2.append("feMergeNode").attr("in", "coloredBlur");
|
||||
feMerge2.append("feMergeNode").attr("in", "SourceGraphic");
|
||||
|
||||
var currentTransform = d3.zoomIdentity;
|
||||
var densityZoomTimer = null;
|
||||
var isInteracting = false;
|
||||
var labelBaseSize = 10;
|
||||
// Return the key used to group nodes: the first truthy value among d.type,
// d.category, d.group and d.color, or 'default' when d is falsy or has none.
function getGroupKey(d){
  if (d) {
    var groupFields = ['type', 'category', 'group', 'color'];
    for (var i = 0; i < groupFields.length; i++) {
      var value = d[groupFields[i]];
      if (value) { return value; }
    }
  }
  return 'default';
}
|
||||
|
||||
var simulation = d3.forceSimulation(nodes)
|
||||
.force("link", d3.forceLink(links).id(d => d.id).strength(0.1))
|
||||
.force("charge", d3.forceManyBody().strength(-275))
|
||||
.force("link", d3.forceLink(links).id(function(d){ return d.id; }).distance(100).strength(0.2))
|
||||
.force("charge", d3.forceManyBody().strength(-180))
|
||||
.force("collide", d3.forceCollide().radius(16).iterations(2))
|
||||
.force("center", d3.forceCenter(width / 2, height / 2))
|
||||
.force("x", d3.forceX().strength(0.1).x(width / 2))
|
||||
.force("y", d3.forceY().strength(0.1).y(height / 2));
|
||||
.force("x", d3.forceX().strength(0.06).x(width / 2))
|
||||
.force("y", d3.forceY().strength(0.06).y(height / 2))
|
||||
.alphaDecay(0.06)
|
||||
.velocityDecay(0.6);
|
||||
|
||||
// Density layer (sibling of container to avoid double transforms)
|
||||
var densityLayer = svg.append("g")
|
||||
.attr("class", "density")
|
||||
.style("pointer-events", "none");
|
||||
if (densityLayer.lower) densityLayer.lower();
|
||||
|
||||
var link = container.append("g")
|
||||
.attr("class", "links")
|
||||
.selectAll("line")
|
||||
.data(links)
|
||||
.enter().append("line")
|
||||
.style("opacity", 0)
|
||||
.style("pointer-events", "none")
|
||||
.attr("stroke-width", d => {
|
||||
if (d.weight) return Math.max(2, d.weight * 5);
|
||||
if (d.all_weights && Object.keys(d.all_weights).length > 0) {
|
||||
|
|
@ -168,6 +314,7 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
|
|||
})
|
||||
.on("mouseover", function(d) {
|
||||
// Create tooltip content for edge
|
||||
renderEdgeInfo(d);
|
||||
var content = "<strong>Edge Information</strong><br/>";
|
||||
content += "Relationship: " + d.relation + "<br/>";
|
||||
|
||||
|
|
@ -212,6 +359,7 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
|
|||
.data(links)
|
||||
.enter().append("text")
|
||||
.attr("class", "edge-label")
|
||||
.style("opacity", 0)
|
||||
.text(d => {
|
||||
var label = d.relation;
|
||||
if (d.all_weights && Object.keys(d.all_weights).length > 1) {
|
||||
|
|
@ -232,21 +380,225 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
|
|||
.data(nodes)
|
||||
.enter().append("g");
|
||||
|
||||
// Color fallback by type when d.color is missing
|
||||
var colorByType = {
|
||||
"Entity": "#5C10F4",
|
||||
"EntityType": "#A550FF",
|
||||
"DocumentChunk": "#0DFF00",
|
||||
"TextSummary": "#5C10F4",
|
||||
"TableRow": "#A550FF",
|
||||
"TableType": "#5C10F4",
|
||||
"ColumnValue": "#757470",
|
||||
"SchemaTable": "#A550FF",
|
||||
"DatabaseSchema": "#5C10F4",
|
||||
"SchemaRelationship": "#323332"
|
||||
};
|
||||
|
||||
var node = nodeGroup.append("circle")
|
||||
.attr("r", 13)
|
||||
.attr("fill", d => d.color)
|
||||
.attr("fill", function(d){ return d.color || colorByType[d.type] || "#D3D3D3"; })
|
||||
.style("filter", "url(#glow)")
|
||||
.attr("shape-rendering", "geometricPrecision")
|
||||
.call(d3.drag()
|
||||
.on("start", dragstarted)
|
||||
.on("drag", dragged)
|
||||
.on("drag", function(d){ dragged(d); updateDensity(); showAdjacency(d); })
|
||||
.on("end", dragended));
|
||||
|
||||
nodeGroup.append("text")
|
||||
// Show links only for hovered node adjacency
|
||||
// True when either endpoint of linkDatum refers to nodeId. An endpoint may be
// a node object (its .id is used) or a bare id value.
function isAdjacent(linkDatum, nodeId) {
  function endpointId(endpoint){
    return endpoint && (endpoint.id || endpoint);
  }
  var sourceId = linkDatum && endpointId(linkDatum.source);
  var targetId = linkDatum && endpointId(linkDatum.target);
  if (sourceId === nodeId) { return true; }
  return targetId === nodeId;
}
|
||||
|
||||
// Highlight node d together with its direct neighbours: adjacent links and
// their labels become visible, neighbouring nodes get the stronger glow and a
// larger radius, everything else is dimmed, and the density layer is faded.
// Does nothing when no id can be resolved for d.
function showAdjacency(d) {
  var nodeId = d && (d.id || d.node_id || d.uuid || d.external_id || d.name);
  if (!nodeId) { return; }

  // A link endpoint may be a node object or a bare id.
  function endpointId(endpoint){
    return endpoint && (endpoint.id || endpoint);
  }

  // Collect the ids of the hovered node and every node sharing a link with it.
  var neighborIds = {};
  neighborIds[nodeId] = true;
  links.forEach(function(l){
    var sid = l && endpointId(l.source);
    var tid = l && endpointId(l.target);
    if (sid === nodeId) { neighborIds[tid] = true; }
    if (tid === nodeId) { neighborIds[sid] = true; }
  });

  function touches(l){ return isAdjacent(l, nodeId); }
  function isNeighbor(n){ return neighborIds[n.id]; }

  // Links and edge labels: only those touching the node are shown.
  link.style("opacity", function(l){ return touches(l) ? 0.95 : 0; });
  link.style("stroke", function(l){ return touches(l) ? "rgba(255,255,255,0.95)" : null; });
  link.style("stroke-width", function(l){ return touches(l) ? 2.5 : 1.5; });
  edgeLabels.style("opacity", function(l){ return touches(l) ? 1 : 0; });
  densityLayer.style("opacity", 0.35);

  // Nodes: emphasise neighbours, dim the rest.
  node.style("opacity", function(n){ return isNeighbor(n) ? 1 : 0.25; });
  node.style("filter", function(n){ return isNeighbor(n) ? "url(#glow-strong)" : "url(#glow)"; });
  node.attr("r", function(n){ return isNeighbor(n) ? 15 : 13; });

  // NOTE(review): raise() re-inserts each element as the last child of its
  // parent; these circles are children of their own per-node <g>, so this
  // reorders within that group only — confirm that is the intended effect.
  node.filter(isNeighbor).raise();

  // Labels: neighbours brighter and up to 25% larger (capped at 22px).
  var labels = nodeGroup.select("text");
  labels.style("opacity", function(n){ return isNeighbor(n) ? 1 : 0.2; });
  labels.style("font-size", function(n){
    return (isNeighbor(n) ? Math.min(22, labelBaseSize * 1.25) : labelBaseSize) + "px";
  });
}
|
||||
|
||||
// Undo showAdjacency: hide all links and edge labels again, restore the
// density layer, and reset every node and label to its resting appearance.
function clearAdjacency() {
  link.style("opacity", 0);
  link.style("stroke", null);
  link.style("stroke-width", 1.5);

  edgeLabels.style("opacity", 0);
  densityLayer.style("opacity", 1);

  node.style("opacity", 1);
  node.style("filter", "url(#glow)");
  node.attr("r", 13);

  var labels = nodeGroup.select("text");
  labels.style("opacity", 1);
  labels.style("font-size", labelBaseSize + "px");
}
|
||||
|
||||
node.on("mouseover", function(d){ showAdjacency(d); })
|
||||
.on("mouseout", function(){ clearAdjacency(); });
|
||||
node.on("mouseover", function(d){ renderNodeInfo(d); tooltip.style('opacity', 0); });
|
||||
// Also bind on the group so labels trigger adjacency too
|
||||
nodeGroup.on("mouseover", function(d){ showAdjacency(d); })
|
||||
.on("mouseout", function(){ clearAdjacency(); });
|
||||
|
||||
// Density always on; no hover gating
|
||||
|
||||
// Add labels sparsely to reduce clutter (every 14th node), and truncate long text
|
||||
nodeGroup
|
||||
.filter(function(d, i){ return i % 14 === 0; })
|
||||
.append("text")
|
||||
.attr("class", "node-label")
|
||||
.attr("dy", 4)
|
||||
.attr("text-anchor", "middle")
|
||||
.text(d => d.name);
|
||||
.text(function(d){
|
||||
var s = d && d.name ? String(d.name) : '';
|
||||
return s.length > 40 ? (s.slice(0, 40) + "…") : s;
|
||||
})
|
||||
.style("font-size", labelBaseSize + "px");
|
||||
|
||||
node.append("title").text(d => JSON.stringify(d));
|
||||
// Recompute labelBaseSize from the current zoom factor and apply it to every
// node label. The size is 10/sqrt(k) clamped to the range 7..18 px; labels
// are hidden entirely once the zoom factor drops below 0.35.
function applyLabelSize() {
  var zoomFactor = (currentTransform && currentTransform.k) || 1;
  var scaled = 10 / Math.sqrt(zoomFactor);
  labelBaseSize = Math.max(7, Math.min(18, scaled));

  var labels = nodeGroup.select("text");
  labels.style("font-size", labelBaseSize + "px");
  labels.style("display", (zoomFactor < 0.35 ? "none" : null));
}
|
||||
|
||||
|
||||
|
||||
// Density cloud computation (throttled)
|
||||
var densityTick = 0;
|
||||
var geoPath = d3.geoPath().projection(null);
|
||||
var MAX_POINTS_PER_GROUP = 400;
|
||||
// Redraw the density "clouds" in densityLayer: nodes are bucketed with
// getGroupKey, each bucket of at least 3 positioned nodes is turned into
// d3.contourDensity contours, and the contours are drawn as blurred, filled
// paths whose opacity grows with density. Any runtime error is logged with
// console.warn instead of propagating.
function updateDensity() {
  try {
    if (isInteracting) { return; } // skip during interaction for smoother UX
    var contourAvailable = typeof d3 !== 'undefined' && typeof d3.contourDensity === 'function';
    if (!contourAvailable) { return; } // d3-contour not available; skip gracefully
    if (!nodes || nodes.length === 0) { return; }

    // Only nodes that already have finite coordinates can contribute.
    var positioned = nodes.filter(function(p){
      return p && typeof p.x === 'number' && isFinite(p.x) && typeof p.y === 'number' && isFinite(p.y);
    });
    if (positioned.length < 3) { return; } // not enough positioned points yet

    var t = currentTransform || d3.zoomIdentity;
    if (t.k && t.k < 0.08) {
      // Skip density at extreme zoom-out to avoid numerical instability/perf issues
      densityLayer.selectAll('*').remove();
      return;
    }

    // Parse '#rgb' / '#rrggbb' into channel values; a missing colour yields
    // the fixed fallback {0, 200, 255}.
    function hexToRgb(hex){
      if (!hex) { return {r: 0, g: 200, b: 255}; }
      var digits = hex.replace('#','');
      if (digits.length === 3) {
        digits = digits.split('').map(function(ch){ return ch + ch; }).join('');
      }
      var packed = parseInt(digits, 16);
      return { r: (packed >> 16) & 255, g: (packed >> 8) & 255, b: packed & 255 };
    }

    // Bucket the positioned nodes by group key.
    var groups = {};
    positioned.forEach(function(p){
      var groupKey = getGroupKey(p);
      (groups[groupKey] || (groups[groupKey] = [])).push(p);
    });

    // Start from an empty layer; every group appends its own <g>.
    densityLayer.selectAll('*').remove();

    Object.keys(groups).forEach(function(groupKey){
      var members = groups[groupKey];
      if (!members || members.length < 3) { return; }

      // Transform positions into screen space and sample to cap cost
      var stride = Math.max(1, Math.floor(members.length / MAX_POINTS_PER_GROUP));
      var screenPoints = [];
      members.forEach(function(m, idx){
        if (idx % stride !== 0) { return; }
        var sx = t.applyX(m.x);
        var sy = t.applyY(m.y);
        if (isFinite(sx) && isFinite(sy)) {
          screenPoints.push({ x: sx, y: sy, type: m.type, color: m.color });
        }
      });
      if (screenPoints.length < 3) { return; }

      // Adaptive bandwidth: mean distance of the points from their centroid,
      // clamped to 12..80 and divided by the zoom factor.
      var count = screenPoints.length;
      var sumX = 0, sumY = 0;
      screenPoints.forEach(function(p){ sumX += p.x; sumY += p.y; });
      var centerX = sumX / count;
      var centerY = sumY / count;
      var totalDistance = screenPoints.reduce(function(acc, p){
        var dx = p.x - centerX, dy = p.y - centerY;
        return acc + Math.sqrt(dx*dx + dy*dy);
      }, 0);
      var meanDistance = totalDistance / count;
      var clampedBandwidth = Math.max(12, Math.min(80, meanDistance));
      var densityBandwidth = clampedBandwidth / (t.k || 1);

      var estimator = d3.contourDensity()
        .x(function(p){ return p.x; })
        .y(function(p){ return p.y; })
        .size([width, height])
        .bandwidth(densityBandwidth)
        .thresholds(8);
      var contours = estimator(screenPoints);
      if (!contours || contours.length === 0) { return; }

      var peak = d3.max(contours, function(c){ return c.value; }) || 1;

      // Use the first node color in the group or fallback neon palette
      var colored = members.find(function(m){ return m && m.color; });
      var baseColor = (colored || {}).color || '#00c8ff';
      var rgb = hexToRgb(baseColor);
      var fill = 'rgb(' + rgb.r + ',' + rgb.g + ',' + rgb.b + ')';

      densityLayer.append('g').attr('data-group', groupKey)
        .selectAll('path')
        .data(contours)
        .enter()
        .append('path')
        .attr('d', geoPath)
        .attr('fill', fill)
        .attr('stroke', 'none')
        .style('opacity', function(c){
          var ratio = peak ? (c.value / peak) : 0;
          var alpha = Math.pow(Math.max(0, Math.min(1, ratio)), 1.6); // accentuate clusters
          return 0.65 * alpha; // up to 0.65 opacity at peak density
        })
        .style('filter', 'blur(2px)');
    });
  } catch (e) {
    // Reduce impact of any runtime errors during zoom
    console.warn('Density update failed:', e);
  }
}
|
||||
|
||||
simulation.on("tick", function() {
|
||||
link.attr("x1", d => d.source.x)
|
||||
|
|
@ -266,16 +618,29 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
|
|||
.attr("y", d => d.y)
|
||||
.attr("dy", 4)
|
||||
.attr("text-anchor", "middle");
|
||||
|
||||
densityTick += 1;
|
||||
if (densityTick % 24 === 0) updateDensity();
|
||||
});
|
||||
|
||||
svg.call(d3.zoom().on("zoom", function() {
|
||||
container.attr("transform", d3.event.transform);
|
||||
}));
|
||||
var zoomBehavior = d3.zoom()
|
||||
.on("start", function(){ isInteracting = true; densityLayer.style("opacity", 0.2); })
|
||||
.on("zoom", function(){
|
||||
currentTransform = d3.event.transform;
|
||||
container.attr("transform", currentTransform);
|
||||
})
|
||||
.on("end", function(){
|
||||
if (densityZoomTimer) clearTimeout(densityZoomTimer);
|
||||
densityZoomTimer = setTimeout(function(){ isInteracting = false; densityLayer.style("opacity", 1); updateDensity(); }, 140);
|
||||
});
|
||||
svg.call(zoomBehavior);
|
||||
|
||||
// Drag "start" handler for a node (d is the node datum being dragged).
// Pins the node at its current position, re-heats the simulation and dims the
// density layer while the interaction is in progress.
function dragstarted(d) {
  // A deferred density refresh may still be pending from a previous drag/zoom
  // "end" (it fires after 140 ms). Cancel it so it cannot flip isInteracting
  // back to false and restore the density layer in the middle of this drag.
  if (densityZoomTimer) clearTimeout(densityZoomTimer);

  // Only the first active drag gesture needs to restart the simulation.
  if (!d3.event.active) simulation.alphaTarget(0.3).restart();

  // Fix the node where it currently is.
  d.fx = d.x;
  d.fy = d.y;

  // Mark the interaction as active and fade the density layer while dragging.
  isInteracting = true;
  densityLayer.style("opacity", 0.2);
}
|
||||
|
||||
function dragged(d) {
|
||||
|
|
@ -287,6 +652,8 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
|
|||
if (!d3.event.active) simulation.alphaTarget(0);
|
||||
d.fx = null;
|
||||
d.fy = null;
|
||||
if (densityZoomTimer) clearTimeout(densityZoomTimer);
|
||||
densityZoomTimer = setTimeout(function(){ isInteracting = false; densityLayer.style("opacity", 1); updateDensity(); }, 140);
|
||||
}
|
||||
|
||||
window.addEventListener("resize", function() {
|
||||
|
|
@ -295,7 +662,13 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
|
|||
svg.attr("width", width).attr("height", height);
|
||||
simulation.force("center", d3.forceCenter(width / 2, height / 2));
|
||||
simulation.alpha(1).restart();
|
||||
updateDensity();
|
||||
applyLabelSize();
|
||||
});
|
||||
|
||||
// Initial density draw
|
||||
updateDensity();
|
||||
applyLabelSize();
|
||||
</script>
|
||||
|
||||
<svg style="position: fixed; bottom: 10px; right: 10px; width: 150px; height: auto; z-index: 9999;" viewBox="0 0 158 44" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
|
|
@ -305,8 +678,12 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
|
|||
</html>
|
||||
"""
|
||||
|
||||
html_content = html_template.replace("{nodes}", json.dumps(nodes_list))
|
||||
html_content = html_content.replace("{links}", json.dumps(links_list))
|
||||
# Safely embed JSON inside <script> by escaping </ to avoid prematurely closing the tag
|
||||
def _safe_json_embed(obj):
|
||||
return json.dumps(obj).replace("</", "<\\/")
|
||||
|
||||
html_content = html_template.replace("{nodes}", _safe_json_embed(nodes_list))
|
||||
html_content = html_content.replace("{links}", _safe_json_embed(links_list))
|
||||
|
||||
if not destination_file_path:
|
||||
home_dir = os.path.expanduser("~")
|
||||
|
|
|
|||
|
|
@ -7,6 +7,6 @@ class IngestionError(CogneeValidationError):
|
|||
self,
|
||||
message: str = "Failed to load data.",
|
||||
name: str = "IngestionError",
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import sys
|
||||
import logging
|
||||
import tempfile
|
||||
import structlog
|
||||
import traceback
|
||||
import platform
|
||||
|
|
@ -76,9 +77,38 @@ log_levels = {
|
|||
# Track if structlog logging has been configured
|
||||
_is_structlog_configured = False
|
||||
|
||||
# Path to logs directory
|
||||
LOGS_DIR = Path(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "logs"))
|
||||
LOGS_DIR.mkdir(exist_ok=True) # Create logs dir if it doesn't exist
|
||||
|
||||
def resolve_logs_dir():
    """Resolve a writable logs directory.

    Candidates are tried in priority order:
    1) BaseConfig.logs_root_directory (respects COGNEE_LOGS_DIR)
    2) /tmp/cognee_logs (default, best-effort create)

    Each candidate is created if missing (parents included) and accepted only when the
    current process has write access to it. Any failure for a candidate is swallowed on
    purpose and the next candidate is tried.

    Returns a Path or None if none are writable/creatable.
    """
    # Imported inside the function, as before (presumably to avoid a circular import
    # at module load time -- confirm before moving it to the top of the file).
    from cognee.base_config import get_base_config

    base_config = get_base_config()

    candidates = (
        base_config.logs_root_directory,
        os.path.join("/tmp", "cognee_logs"),
    )

    for candidate in candidates:
        try:
            # Path() is built inside the try so that an unset or invalid configured value
            # (e.g. None) falls through to the next candidate instead of raising.
            directory = Path(candidate)
            directory.mkdir(parents=True, exist_ok=True)
            if os.access(directory, os.W_OK):
                return directory
        except Exception:
            # Best effort: not creatable / not writable, so try the next candidate.
            continue

    return None
|
||||
|
||||
|
||||
# Maximum number of log files to keep
|
||||
MAX_LOG_FILES = 10
|
||||
|
|
@ -430,28 +460,38 @@ def setup_logging(log_level=None, name=None):
|
|||
stream_handler.setFormatter(console_formatter)
|
||||
stream_handler.setLevel(log_level)
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
if root_logger.hasHandlers():
|
||||
root_logger.handlers.clear()
|
||||
root_logger.addHandler(stream_handler)
|
||||
|
||||
# Note: root logger needs to be set at NOTSET to allow all messages through and specific stream and file handlers
|
||||
# can define their own levels.
|
||||
root_logger.setLevel(logging.NOTSET)
|
||||
|
||||
# Resolve logs directory with env and safe fallbacks
|
||||
logs_dir = resolve_logs_dir()
|
||||
|
||||
# Check if we already have a log file path from the environment
|
||||
# NOTE: environment variable must be used here as it allows us to
|
||||
# log to a single file with a name based on a timestamp in a multiprocess setting.
|
||||
# Without it, we would have a separate log file for every process.
|
||||
log_file_path = os.environ.get("LOG_FILE_NAME")
|
||||
if not log_file_path:
|
||||
if not log_file_path and logs_dir is not None:
|
||||
# Create a new log file name with the cognee start time
|
||||
start_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
log_file_path = os.path.join(LOGS_DIR, f"{start_time}.log")
|
||||
log_file_path = str((logs_dir / f"{start_time}.log").resolve())
|
||||
os.environ["LOG_FILE_NAME"] = log_file_path
|
||||
|
||||
# Create a file handler that uses our custom PlainFileHandler
|
||||
file_handler = PlainFileHandler(log_file_path, encoding="utf-8")
|
||||
file_handler.setLevel(DEBUG)
|
||||
|
||||
# Configure root logger
|
||||
root_logger = logging.getLogger()
|
||||
if root_logger.hasHandlers():
|
||||
root_logger.handlers.clear()
|
||||
root_logger.addHandler(stream_handler)
|
||||
root_logger.addHandler(file_handler)
|
||||
root_logger.setLevel(log_level)
|
||||
try:
|
||||
# Create a file handler that uses our custom PlainFileHandler
|
||||
file_handler = PlainFileHandler(log_file_path, encoding="utf-8")
|
||||
file_handler.setLevel(DEBUG)
|
||||
root_logger.addHandler(file_handler)
|
||||
except Exception as e:
|
||||
# Note: Exceptions happen in case of read-only file systems or a log file path pointing to a location where the process does
|
||||
# not have write permission. Logging to file is not mandatory so we just log a warning to console.
|
||||
root_logger.warning(f"Warning: Could not create log file handler at {log_file_path}: {e}")
|
||||
|
||||
if log_level > logging.DEBUG:
|
||||
import warnings
|
||||
|
|
@ -466,7 +506,8 @@ def setup_logging(log_level=None, name=None):
|
|||
)
|
||||
|
||||
# Clean up old log files, keeping only the most recent ones
|
||||
cleanup_old_logs(LOGS_DIR, MAX_LOG_FILES)
|
||||
if logs_dir is not None:
|
||||
cleanup_old_logs(logs_dir, MAX_LOG_FILES)
|
||||
|
||||
# Mark logging as configured
|
||||
_is_structlog_configured = True
|
||||
|
|
@ -490,6 +531,10 @@ def setup_logging(log_level=None, name=None):
|
|||
|
||||
# Get a configured logger and log system information
|
||||
logger = structlog.get_logger(name if name else __name__)
|
||||
|
||||
if logs_dir is not None:
|
||||
logger.info(f"Log file created at: {log_file_path}", log_file=log_file_path)
|
||||
|
||||
# Detailed initialization for regular usage
|
||||
logger.info(
|
||||
"Logging initialized",
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue