Compare commits

3 commits: main...LStromann-...

| Author | SHA1 | Date |
|---|---|---|
|  | 9c62911be9 |  |
|  | 9982d2f7ed |  |
|  | 5bee392e6c |  |

303 changed files with 21,511 additions and 23,722 deletions
**.coderabbit.yaml** — 127 deletions (file removed)

```diff
@@ -1,127 +0,0 @@
-# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
-# .coderabbit.yaml
-language: en
-early_access: false
-enable_free_tier: true
-reviews:
-  profile: chill
-  instructions: >-
-    # Code Review Instructions
-
-    - Ensure the code follows best practices and coding standards.
-    - For **Python** code, follow
-    [PEP 20](https://www.python.org/dev/peps/pep-0020/) and
-    [CEP-8](https://gist.github.com/reactive-firewall/b7ee98df9e636a51806e62ef9c4ab161)
-    standards.
-
-    # Documentation Review Instructions
-    - Verify that documentation and comments are clear and comprehensive.
-    - Verify that documentation and comments are free of spelling mistakes.
-
-    # Test Code Review Instructions
-    - Ensure that test code is automated, comprehensive, and follows testing best practices.
-    - Verify that all critical functionality is covered by tests.
-    - Ensure that test code follow
-    [CEP-8](https://gist.github.com/reactive-firewall/d840ee9990e65f302ce2a8d78ebe73f6)
-
-    # Misc.
-    - Confirm that the code meets the project's requirements and objectives.
-    - Confirm that copyright years are up-to date whenever a file is changed.
-  request_changes_workflow: false
-  high_level_summary: true
-  high_level_summary_placeholder: '@coderabbitai summary'
-  auto_title_placeholder: '@coderabbitai'
-  review_status: true
-  poem: false
-  collapse_walkthrough: false
-  sequence_diagrams: false
-  changed_files_summary: true
-  path_filters: ['!*.xc*/**', '!node_modules/**', '!dist/**', '!build/**', '!.git/**', '!venv/**', '!__pycache__/**']
-  path_instructions:
-    - path: README.md
-      instructions: >-
-        1. Consider the file 'README.md' the overview/introduction of the project.
-        Also consider the 'README.md' file the first place to look for project documentation.
-
-        2. When reviewing the file 'README.md' it should be linted with help
-        from the tools `markdownlint` and `languagetool`, pointing out any issues.
-
-        3. You may assume the file 'README.md' will contain GitHub flavor Markdown.
-    - path: '**/*.py'
-      instructions: >-
-        When reviewing Python code for this project:
-
-        1. Prioritize portability over clarity, especially when dealing with cross-Python compatibility. However, with the priority in mind, do still consider improvements to clarity when relevant.
-
-        2. As a general guideline, consider the code style advocated in the PEP 8 standard (excluding the use of spaces for indentation) and evaluate suggested changes for code style compliance.
-
-        3. As a style convention, consider the code style advocated in [CEP-8](https://gist.github.com/reactive-firewall/b7ee98df9e636a51806e62ef9c4ab161) and evaluate suggested changes for code style compliance.
-
-        4. As a general guideline, try to provide any relevant, official, and supporting documentation links to any tool's suggestions in review comments. This guideline is important for posterity.
-
-        5. As a general rule, undocumented function definitions and class definitions in the project's Python code are assumed incomplete. Please consider suggesting a short summary of the code for any of these incomplete definitions as docstrings when reviewing.
-    - path: cognee/tests/*
-      instructions: >-
-        When reviewing test code:
-
-        1. Prioritize portability over clarity, especially when dealing with cross-Python compatibility. However, with the priority in mind, do still consider improvements to clarity when relevant.
-
-        2. As a general guideline, consider the code style advocated in the PEP 8 standard (excluding the use of spaces for indentation) and evaluate suggested changes for code style compliance.
-
-        3. As a style convention, consider the code style advocated in [CEP-8](https://gist.github.com/reactive-firewall/b7ee98df9e636a51806e62ef9c4ab161) and evaluate suggested changes for code style compliance, pointing out any violations discovered.
-
-        4. As a general guideline, try to provide any relevant, official, and supporting documentation links to any tool's suggestions in review comments. This guideline is important for posterity.
-
-        5. As a project rule, Python source files with names prefixed by the string "test_" and located in the project's "tests" directory are the project's unit-testing code. It is safe, albeit a heuristic, to assume these are considered part of the project's minimal acceptance testing unless a justifying exception to this assumption is documented.
-
-        6. As a project rule, any files without extensions and with names prefixed by either the string "check_" or the string "test_", and located in the project's "tests" directory, are the project's non-unit test code. "Non-unit test" in this context refers to any type of testing other than unit testing, such as (but not limited to) functional testing, style linting, regression testing, etc. It can also be assumed that non-unit testing code is usually written as Bash shell scripts.
-    - path: requirements.txt
-      instructions: >-
-        * The project's own Python dependencies are recorded in 'requirements.txt' for production code.
-
-        * The project's testing-specific Python dependencies are recorded in 'tests/requirements.txt' and are used for testing the project.
-
-        * The project's documentation-specific Python dependencies are recorded in 'docs/requirements.txt' and are used only for generating Python-focused documentation for the project. 'docs/requirements.txt' may be absent if not applicable.
-
-        Consider these 'requirements.txt' files the records of truth regarding project dependencies.
-    - path: .github/**
-      instructions: >-
-        * When the project is hosted on GitHub: All GitHub-specific configurations, templates, and tools should be found in the '.github' directory tree.
-
-        * 'actionlint' erroneously generates false positives when dealing with GitHub's `${{ ... }}` syntax in conditionals.
-
-        * 'actionlint' erroneously generates incorrect solutions when suggesting the removal of valid `${{ ... }}` syntax.
-  abort_on_close: true
-  auto_review:
-    enabled: true
-    auto_incremental_review: true
-    ignore_title_keywords: []
-    labels: []
-    drafts: false
-    base_branches:
-      - dev
-      - main
-  tools:
-    shellcheck:
-      enabled: true
-    ruff:
-      enabled: true
-      configuration:
-        extend_select:
-          - E # Pycodestyle errors (style issues)
-          - F # PyFlakes codes (logical errors)
-          - W # Pycodestyle warnings
-          - N # PEP 8 naming conventions
-        ignore:
-          - W191
-          - W391
-          - E117
-          - D208
-        line_length: 100
-        dummy_variable_rgx: '^(_.*|junk|extra)$' # Variables starting with '_' or named 'junk' or 'extras', are considered dummy variables
-    markdownlint:
-      enabled: true
-    yamllint:
-      enabled: true
-chat:
-  auto_reply: true
```
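The Ruff `dummy_variable_rgx` in the deleted config above is easy to misread: its inline comment mentions "'junk' or 'extras'", but the anchored pattern only matches the exact name `extra`. A standalone check with Python's `re` module (illustrative, not project code):

```python
import re

# The dummy-variable pattern from the deleted .coderabbit.yaml above.
DUMMY = re.compile(r"^(_.*|junk|extra)$")

# ^...$ anchors force a full match, so "extras" is NOT a dummy variable.
for name in ["_", "_tmp", "junk", "extra", "extras", "value"]:
    print(f"{name}: {bool(DUMMY.match(name))}")
```

Only `_`-prefixed names and the exact names `junk` and `extra` match; `extras` and ordinary names like `value` do not.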
Environment template — four hunks:

```
@@ -21,10 +21,6 @@ LLM_PROVIDER="openai"
LLM_ENDPOINT=""
LLM_API_VERSION=""
LLM_MAX_TOKENS="16384"
# Instructor's modes determine how structured data is requested from and extracted from LLM responses
# You can change this type (i.e. mode) via this env variable
# Each LLM has its own default value, e.g. gpt-5 models have "json_schema_mode"
LLM_INSTRUCTOR_MODE=""

EMBEDDING_PROVIDER="openai"
EMBEDDING_MODEL="openai/text-embedding-3-large"

@@ -97,8 +93,6 @@ DB_NAME=cognee_db

# Default (local file-based)
GRAPH_DATABASE_PROVIDER="kuzu"
# Handler for multi-user access control mode, it handles how should the mapping/creation of separate DBs be handled per Cognee dataset
GRAPH_DATASET_DATABASE_HANDLER="kuzu"

# -- To switch to Remote Kuzu uncomment and fill these: -------------------------------------------------------------
#GRAPH_DATABASE_PROVIDER="kuzu"

@@ -123,8 +117,6 @@ VECTOR_DB_PROVIDER="lancedb"
# Not needed if a cloud vector database is not used
VECTOR_DB_URL=
VECTOR_DB_KEY=
# Handler for multi-user access control mode, it handles how should the mapping/creation of separate DBs be handled per Cognee dataset
VECTOR_DATASET_DATABASE_HANDLER="lancedb"

################################################################################
# 🧩 Ontology resolver settings

@@ -177,9 +169,8 @@ REQUIRE_AUTHENTICATION=False
# Vector: LanceDB
# Graph: KuzuDB
#
# It enforces creation of databases per Cognee user + dataset. Does not work with some graph and database providers.
# Disable mode when using not supported graph/vector databases.
ENABLE_BACKEND_ACCESS_CONTROL=True
# It enforces LanceDB and KuzuDB use and uses them to create databases per Cognee user + dataset
ENABLE_BACKEND_ACCESS_CONTROL=False

################################################################################
# ☁️ Cloud Sync Settings
```
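The first hunk above notes that an empty `LLM_INSTRUCTOR_MODE=""` leaves each model on its own default mode. A minimal sketch of that fallback pattern, assuming a hypothetical default table (the real mapping lives inside the library, not in this snippet):

```python
import os

# Hypothetical per-model defaults for illustration only.
DEFAULT_INSTRUCTOR_MODES = {"gpt-5": "json_schema_mode"}

def resolve_instructor_mode(model: str) -> str:
    # An empty string in the env file counts as "unset", so `or` skips it
    # and falls through to the model's own default.
    override = os.getenv("LLM_INSTRUCTOR_MODE", "")
    return override or DEFAULT_INSTRUCTOR_MODES.get(model, "json_schema_mode")

os.environ["LLM_INSTRUCTOR_MODE"] = ""
print(resolve_instructor_mode("gpt-5"))  # falls back to the model default
```

Setting the variable to any non-empty value overrides the per-model default.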
**.github/actions/cognee_setup/action.yml** — 10 changes

```
@@ -10,10 +10,6 @@ inputs:
    description: "Additional extra dependencies to install (space-separated)"
    required: false
    default: ""
  rebuild-lockfile:
    description: "Whether to rebuild the uv lockfile"
    required: false
    default: "false"

runs:
  using: "composite"

@@ -30,7 +26,6 @@ runs:
      enable-cache: true

  - name: Rebuild uv lockfile
    if: ${{ inputs.rebuild-lockfile == 'true' }}
    shell: bash
    run: |
      rm uv.lock

@@ -47,8 +42,3 @@ runs:
      done
    fi
    uv sync --extra api --extra docs --extra evals --extra codegraph --extra ollama --extra dev --extra neo4j --extra redis $EXTRA_ARGS

  - name: Add telemetry identifier for telemetry test and in case telemetry is enabled by accident
    shell: bash
    run: |
      echo "test-machine" > .anon_id
```
**.github/pull_request_template.md** — 8 changes

```
@@ -6,14 +6,6 @@ Please provide a clear, human-generated description of the changes in this PR.
DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning.
-->

## Acceptance Criteria
<!--
* Key requirements to the new feature or modification;
* Proof that the changes work and meet the requirements;
* Include instructions on how to verify the changes. Describe how to test it locally;
* Proof that it's sufficiently tested.
-->

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
```
**.github/release-drafter.yml** — 20 deletions (file removed)

```diff
@@ -1,20 +0,0 @@
-name-template: 'v$NEXT_PATCH_VERSION'
-tag-template: 'v$NEXT_PATCH_VERSION'
-
-categories:
-  - title: 'Features'
-    labels: ['feature', 'enhancement']
-  - title: 'Bug Fixes'
-    labels: ['bug', 'fix']
-  - title: 'Maintenance'
-    labels: ['chore', 'refactor', 'ci']
-
-change-template: '- $TITLE (#$NUMBER) @$AUTHOR'
-template: |
-  ## What’s Changed
-
-  $CHANGES
-
-  ## Contributors
-
-  $CONTRIBUTORS
```
**.github/workflows/basic_tests.yml** — 33 changes

```
@@ -75,7 +75,6 @@ jobs:
  name: Run Unit Tests
  runs-on: ubuntu-22.04
  env:
    ENV: 'dev'
    LLM_PROVIDER: openai
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}

@@ -105,7 +104,6 @@ jobs:
  name: Run Integration Tests
  runs-on: ubuntu-22.04
  env:
    ENV: 'dev'
    LLM_PROVIDER: openai
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}

@@ -134,7 +132,6 @@ jobs:
  name: Run Simple Examples
  runs-on: ubuntu-22.04
  env:
    ENV: 'dev'
    LLM_PROVIDER: openai
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}

@@ -164,7 +161,6 @@ jobs:
  name: Run Simple Examples BAML
  runs-on: ubuntu-22.04
  env:
    ENV: 'dev'
    STRUCTURED_OUTPUT_FRAMEWORK: "BAML"
    BAML_LLM_PROVIDER: openai
    BAML_LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
```

New job added at the end of the file:

```diff
@@ -197,3 +193,32 @@ jobs:

   - name: Run Simple Examples
     run: uv run python ./examples/python/simple_example.py
+
+  graph-tests:
+    name: Run Basic Graph Tests
+    runs-on: ubuntu-22.04
+    env:
+      LLM_PROVIDER: openai
+      LLM_MODEL: ${{ secrets.LLM_MODEL }}
+      LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+      LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+      LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+
+      EMBEDDING_PROVIDER: openai
+      EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+      EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+      EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+      EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
+    steps:
+      - name: Check out repository
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Cognee Setup
+        uses: ./.github/actions/cognee_setup
+        with:
+          python-version: ${{ inputs.python-version }}
+
+      - name: Run Graph Tests
+        run: uv run python ./examples/python/code_graph_example.py --repo_path ./cognee/tasks/graph
```
**.github/workflows/cli_tests.yml** — 3 changes

```
@@ -39,7 +39,6 @@ jobs:
  name: CLI Unit Tests
  runs-on: ubuntu-22.04
  env:
    ENV: 'dev'
    LLM_PROVIDER: openai
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}

@@ -67,7 +66,6 @@ jobs:
  name: CLI Integration Tests
  runs-on: ubuntu-22.04
  env:
    ENV: 'dev'
    LLM_PROVIDER: openai
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}

@@ -95,7 +93,6 @@ jobs:
  name: CLI Functionality Tests
  runs-on: ubuntu-22.04
  env:
    ENV: 'dev'
    LLM_PROVIDER: openai
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
```
**.github/workflows/db_examples_tests.yml** — 8 changes

```
@@ -60,8 +60,7 @@ jobs:

- name: Run Neo4j Example
  env:
    ENV: 'dev'
    ENABLE_BACKEND_ACCESS_CONTROL: 'false'
    ENV: dev
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
    LLM_API_KEY: ${{ secrets.LLM_API_KEY }}

@@ -96,7 +95,7 @@ jobs:

- name: Run Kuzu Example
  env:
    ENV: 'dev'
    ENV: dev
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
    LLM_API_KEY: ${{ secrets.LLM_API_KEY }}

@@ -142,8 +141,7 @@ jobs:

- name: Run PGVector Example
  env:
    ENV: 'dev'
    ENABLE_BACKEND_ACCESS_CONTROL: 'false'
    ENV: dev
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
    LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
```
**.github/workflows/distributed_test.yml** — 1 change

```
@@ -47,7 +47,6 @@ jobs:
- name: Run Distributed Cognee (Modal)
  env:
    ENV: 'dev'
    ENABLE_BACKEND_ACCESS_CONTROL: 'false'
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
    LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
```
**.github/workflows/e2e_tests.yml** — 197 changes

```
@@ -147,7 +147,6 @@ jobs:
- name: Run Deduplication Example
  env:
    ENV: 'dev'
    ENABLE_BACKEND_ACCESS_CONTROL: 'false'
    LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} # Test needs OpenAI endpoint to handle multimedia
    OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
    EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
```

Jobs `test-dataset-database-handler` and `test-dataset-database-deletion` removed:

```diff
@@ -212,56 +211,6 @@ jobs:
       EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
     run: uv run python ./cognee/tests/test_parallel_databases.py
-
-  test-dataset-database-handler:
-    name: Test dataset database handlers in Cognee
-    runs-on: ubuntu-22.04
-    steps:
-      - name: Check out repository
-        uses: actions/checkout@v4
-
-      - name: Cognee Setup
-        uses: ./.github/actions/cognee_setup
-        with:
-          python-version: '3.11.x'
-
-      - name: Run dataset databases handler test
-        env:
-          ENV: 'dev'
-          LLM_MODEL: ${{ secrets.LLM_MODEL }}
-          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
-          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
-          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
-          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
-          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
-          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
-          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
-        run: uv run python ./cognee/tests/test_dataset_database_handler.py
-
-  test-dataset-database-deletion:
-    name: Test dataset database deletion in Cognee
-    runs-on: ubuntu-22.04
-    steps:
-      - name: Check out repository
-        uses: actions/checkout@v4
-
-      - name: Cognee Setup
-        uses: ./.github/actions/cognee_setup
-        with:
-          python-version: '3.11.x'
-
-      - name: Run dataset databases deletion test
-        env:
-          ENV: 'dev'
-          LLM_MODEL: ${{ secrets.LLM_MODEL }}
-          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
-          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
-          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
-          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
-          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
-          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
-          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
-        run: uv run python ./cognee/tests/test_dataset_delete.py
-
   test-permissions:
     name: Test permissions with different situations in Cognee
     runs-on: ubuntu-22.04
```

Step renamed:

```diff
@@ -277,7 +226,7 @@ jobs:
   - name: Dependencies already installed
     run: echo "Dependencies already installed in setup"

-  - name: Run permissions test
+  - name: Run parallel databases test
     env:
       ENV: 'dev'
       LLM_MODEL: ${{ secrets.LLM_MODEL }}
```

Job `test-multi-tenancy` removed:

```diff
@@ -290,31 +239,6 @@ jobs:
       EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
     run: uv run python ./cognee/tests/test_permissions.py
-
-  test-multi-tenancy:
-    name: Test multi tenancy with different situations in Cognee
-    runs-on: ubuntu-22.04
-    steps:
-      - name: Check out repository
-        uses: actions/checkout@v4
-
-      - name: Cognee Setup
-        uses: ./.github/actions/cognee_setup
-        with:
-          python-version: '3.11.x'
-
-      - name: Run multi tenancy test
-        env:
-          ENV: 'dev'
-          LLM_MODEL: ${{ secrets.LLM_MODEL }}
-          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
-          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
-          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
-          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
-          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
-          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
-          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
-        run: uv run python ./cognee/tests/test_multi_tenancy.py
-
   test-graph-edges:
     name: Test graph edge ingestion
     runs-on: ubuntu-22.04
```

Step renamed:

```diff
@@ -384,7 +308,7 @@ jobs:
       python-version: '3.11.x'
       extra-dependencies: "postgres redis"

-  - name: Run Concurrent subprocess access test (Kuzu/Lancedb/Postgres/Redis)
+  - name: Run Concurrent subprocess access test (Kuzu/Lancedb/Postgres)
     env:
       ENV: dev
       LLM_MODEL: ${{ secrets.LLM_MODEL }}
```

```
@@ -397,7 +321,6 @@ jobs:
    EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
    GRAPH_DATABASE_PROVIDER: 'kuzu'
    CACHING: true
    CACHE_BACKEND: 'redis'
    SHARED_KUZU_LOCK: true
    DB_PROVIDER: 'postgres'
    DB_NAME: 'cognee_db'
```

Job `test-edge-centered-payload` removed and the conversation-sessions job renamed:

```diff
@@ -463,37 +386,8 @@ jobs:
       EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
     run: uv run python ./cognee/tests/test_feedback_enrichment.py
-
-  test-edge-centered-payload:
-    name: Test Cognify - Edge Centered Payload
-    runs-on: ubuntu-22.04
-    steps:
-      - name: Check out repository
-        uses: actions/checkout@v4
-
-      - name: Cognee Setup
-        uses: ./.github/actions/cognee_setup
-        with:
-          python-version: '3.11.x'
-
-      - name: Dependencies already installed
-        run: echo "Dependencies already installed in setup"
-
-      - name: Run Edge Centered Payload Test
-        env:
-          ENV: 'dev'
-          TRIPLET_EMBEDDING: True
-          LLM_MODEL: ${{ secrets.LLM_MODEL }}
-          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
-          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
-          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
-          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
-          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
-          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
-          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
-        run: uv run python ./cognee/tests/test_edge_centered_payload.py
-
-  run_conversation_sessions_test_redis:
-    name: Conversation sessions test (Redis)
+  run_conversation_sessions_test:
+    name: Conversation sessions test
     runs-on: ubuntu-latest
     defaults:
       run:
```

Redis variant of the conversation-sessions test removed; the FS variant becomes the only one:

```diff
@@ -533,60 +427,7 @@ jobs:
       python-version: '3.11.x'
       extra-dependencies: "postgres redis"

-  - name: Run Conversation session tests (Redis)
-    env:
-      ENV: 'dev'
-      LLM_MODEL: ${{ secrets.LLM_MODEL }}
-      LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
-      LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
-      LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
-      EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
-      EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
-      EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
-      EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
-      GRAPH_DATABASE_PROVIDER: 'kuzu'
-      CACHING: true
-      CACHE_BACKEND: 'redis'
-      DB_PROVIDER: 'postgres'
-      DB_NAME: 'cognee_db'
-      DB_HOST: '127.0.0.1'
-      DB_PORT: 5432
-      DB_USERNAME: cognee
-      DB_PASSWORD: cognee
-    run: uv run python ./cognee/tests/test_conversation_history.py
-
-  run_conversation_sessions_test_fs:
-    name: Conversation sessions test (FS)
-    runs-on: ubuntu-latest
-    defaults:
-      run:
-        shell: bash
-    services:
-      postgres:
-        image: pgvector/pgvector:pg17
-        env:
-          POSTGRES_USER: cognee
-          POSTGRES_PASSWORD: cognee
-          POSTGRES_DB: cognee_db
-        options: >-
-          --health-cmd pg_isready
-          --health-interval 10s
-          --health-timeout 5s
-          --health-retries 5
-        ports:
-          - 5432:5432
-
-    steps:
-      - name: Checkout repository
-        uses: actions/checkout@v4
-
-      - name: Cognee Setup
-        uses: ./.github/actions/cognee_setup
-        with:
-          python-version: '3.11.x'
-          extra-dependencies: "postgres"
-
-  - name: Run Conversation session tests (FS)
+  - name: Run Conversation session tests
     env:
       ENV: dev
       LLM_MODEL: ${{ secrets.LLM_MODEL }}
```

```
@@ -599,7 +440,6 @@ jobs:
    EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
    GRAPH_DATABASE_PROVIDER: 'kuzu'
    CACHING: true
    CACHE_BACKEND: 'fs'
    DB_PROVIDER: 'postgres'
    DB_NAME: 'cognee_db'
    DB_HOST: '127.0.0.1'
```

Job `run-pipeline-cache-test` removed:

```diff
@@ -607,30 +447,3 @@ jobs:
       DB_USERNAME: cognee
       DB_PASSWORD: cognee
     run: uv run python ./cognee/tests/test_conversation_history.py
-
-  run-pipeline-cache-test:
-    name: Test Pipeline Caching
-    runs-on: ubuntu-22.04
-    steps:
-      - name: Check out
-        uses: actions/checkout@v4
-        with:
-          fetch-depth: 0
-
-      - name: Cognee Setup
-        uses: ./.github/actions/cognee_setup
-        with:
-          python-version: '3.11.x'
-
-      - name: Run Pipeline Cache Test
-        env:
-          ENV: 'dev'
-          LLM_MODEL: ${{ secrets.LLM_MODEL }}
-          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
-          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
-          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
-          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
-          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
-          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
-          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
-        run: uv run python ./cognee/tests/test_pipeline_cache.py
```
**.github/workflows/examples_tests.yml** — 37 changes

```
@@ -21,7 +21,6 @@ jobs:

- name: Run Multimedia Example
  env:
    ENV: 'dev'
    LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
    OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  run: uv run python ./examples/python/multimedia_example.py

@@ -41,7 +40,6 @@ jobs:

- name: Run Evaluation Framework Example
  env:
    ENV: 'dev'
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
    LLM_API_KEY: ${{ secrets.LLM_API_KEY }}

@@ -71,8 +69,6 @@ jobs:

- name: Run Descriptive Graph Metrics Example
  env:
    ENV: 'dev'
    ENABLE_BACKEND_ACCESS_CONTROL: 'false'
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
    LLM_API_KEY: ${{ secrets.LLM_API_KEY }}

@@ -103,7 +99,6 @@ jobs:

- name: Run Dynamic Steps Tests
  env:
    ENV: 'dev'
    OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}

@@ -129,7 +124,6 @@ jobs:

- name: Run Temporal Example
  env:
    ENV: 'dev'
    OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}

@@ -155,7 +149,6 @@ jobs:

- name: Run Ontology Demo Example
  env:
    ENV: 'dev'
    OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}

@@ -181,7 +174,6 @@ jobs:

- name: Run Agentic Reasoning Example
  env:
    ENV: 'dev'
    OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}

@@ -207,7 +199,6 @@ jobs:

- name: Run Memify Tests
  env:
    ENV: 'dev'
    OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
```

Job `test-custom-pipeline` removed:

```diff
@@ -219,32 +210,6 @@ jobs:
       EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
     run: uv run python ./examples/python/memify_coding_agent_example.py
-
-  test-custom-pipeline:
-    name: Run Custom Pipeline Example
-    runs-on: ubuntu-22.04
-    steps:
-      - name: Check out repository
-        uses: actions/checkout@v4
-
-      - name: Cognee Setup
-        uses: ./.github/actions/cognee_setup
-        with:
-          python-version: '3.11.x'
-
-      - name: Run Custom Pipeline Example
-        env:
-          ENV: 'dev'
-          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          LLM_MODEL: ${{ secrets.LLM_MODEL }}
-          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
-          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
-          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
-          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
-          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
-          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
-          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
-        run: uv run python ./examples/python/run_custom_pipeline_example.py
-
   test-permissions-example:
     name: Run Permissions Example
     runs-on: ubuntu-22.04
```

```
@@ -259,7 +224,6 @@ jobs:

- name: Run Memify Tests
  env:
    ENV: 'dev'
    OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}

@@ -285,7 +249,6 @@ jobs:

- name: Run Docling Test
  env:
    ENV: 'dev'
    OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
```
**.github/workflows/graph_db_tests.yml** — 1 change

```
@@ -78,7 +78,6 @@ jobs:
- name: Run default Neo4j
  env:
    ENV: 'dev'
    ENABLE_BACKEND_ACCESS_CONTROL: 'false'
    LLM_MODEL: ${{ secrets.LLM_MODEL }}
    LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
    LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
```
**.github/workflows/load_tests.yml** — 70 deletions (file removed)

```diff
@@ -1,70 +0,0 @@
-name: Load tests
-
-permissions:
-  contents: read
-
-on:
-  workflow_dispatch:
-  workflow_call:
-    secrets:
-      LLM_MODEL:
-        required: true
-      LLM_ENDPOINT:
-        required: true
-      LLM_API_KEY:
-        required: true
-      LLM_API_VERSION:
-        required: true
-      EMBEDDING_MODEL:
-        required: true
-      EMBEDDING_ENDPOINT:
-        required: true
-      EMBEDDING_API_KEY:
-        required: true
-      EMBEDDING_API_VERSION:
-        required: true
-      OPENAI_API_KEY:
-        required: true
-      AWS_ACCESS_KEY_ID:
-        required: true
-      AWS_SECRET_ACCESS_KEY:
-        required: true
-
-jobs:
-  test-load:
-    name: Test Load
-    runs-on: ubuntu-22.04
-    timeout-minutes: 60
-    steps:
-      - name: Check out repository
-        uses: actions/checkout@v4
-
-      - name: Cognee Setup
-        uses: ./.github/actions/cognee_setup
-        with:
-          python-version: '3.11.x'
-          extra-dependencies: "aws"
-
-      - name: Verify File Descriptor Limit
-        run: ulimit -n
-
-      - name: Run Load Test
-        env:
-          ENV: 'dev'
-          ENABLE_BACKEND_ACCESS_CONTROL: True
-          LLM_MODEL: ${{ secrets.LLM_MODEL }}
-          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
-          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
-          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
-          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
-          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
-          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
-          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
-          STORAGE_BACKEND: s3
-          AWS_REGION: eu-west-1
-          AWS_ENDPOINT_URL: https://s3-eu-west-1.amazonaws.com
-          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_DEV_USER_KEY_ID }}
-          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_DEV_USER_SECRET_KEY }}
-        run: uv run python ./cognee/tests/test_load.py
```
22  .github/workflows/pre_test.yml (vendored)

@@ -1,22 +0,0 @@
on:
  workflow_call:

permissions:
  contents: read
jobs:
  check-uv-lock:
    name: Validate uv lockfile and project metadata
    runs-on: ubuntu-22.04
    steps:
      - name: Check out repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Install uv
        uses: astral-sh/setup-uv@v4
        with:
          enable-cache: true

      - name: Validate uv lockfile and project metadata
        run: uv lock --check || { echo "'uv lock --check' failed."; echo "Run 'uv lock' and push your changes."; exit 1; }
138  .github/workflows/release.yml (vendored)

@@ -1,138 +0,0 @@
name: release.yml
on:
  workflow_dispatch:
    inputs:
      flavour:
        required: true
        default: dev
        type: choice
        options:
          - dev
          - main
        description: Dev or Main release

jobs:
  release-github:
    name: Create GitHub Release from ${{ inputs.flavour }}
    outputs:
      tag: ${{ steps.create_tag.outputs.tag }}
      version: ${{ steps.create_tag.outputs.version }}
    permissions:
      contents: write
    runs-on: ubuntu-latest

    steps:
      - name: Check out ${{ inputs.flavour }}
        uses: actions/checkout@v4
        with:
          ref: ${{ inputs.flavour }}
      - name: Install uv
        uses: astral-sh/setup-uv@v7

      - name: Create and push git tag
        id: create_tag
        run: |
          VERSION="$(uv version --short)"
          TAG="v${VERSION}"

          echo "Tag to create: ${TAG}"

          git config user.name "github-actions[bot]"
          git config user.email "41898282+github-actions[bot]@users.noreply.github.com"

          echo "tag=${TAG}" >> "$GITHUB_OUTPUT"
          echo "version=${VERSION}" >> "$GITHUB_OUTPUT"

          git tag "${TAG}"
          git push origin "${TAG}"

      - name: Create GitHub Release
        uses: softprops/action-gh-release@v2
        with:
          tag_name: ${{ steps.create_tag.outputs.tag }}
          generate_release_notes: true
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

  release-pypi-package:
    needs: release-github
    name: Release PyPI Package from ${{ inputs.flavour }}
    permissions:
      contents: read
    runs-on: ubuntu-latest

    steps:
      - name: Check out ${{ inputs.flavour }}
        uses: actions/checkout@v4
        with:
          ref: ${{ inputs.flavour }}

      - name: Install uv
        uses: astral-sh/setup-uv@v7

      - name: Install Python
        run: uv python install

      - name: Install dependencies
        run: uv sync --locked --all-extras

      - name: Build distributions
        run: uv build

      - name: Publish ${{ inputs.flavour }} release to PyPI
        env:
          UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
        run: uv publish

  release-docker-image:
    needs: release-github
    name: Release Docker Image from ${{ inputs.flavour }}
    permissions:
      contents: read
    runs-on: ubuntu-latest

    steps:
      - name: Check out ${{ inputs.flavour }}
        uses: actions/checkout@v4
        with:
          ref: ${{ inputs.flavour }}

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_PASSWORD }}

      - name: Build and push Dev Docker Image
        if: ${{ inputs.flavour == 'dev' }}
        uses: docker/build-push-action@v5
        with:
          context: .
          platforms: linux/amd64,linux/arm64
          push: true
          tags: cognee/cognee:${{ needs.release-github.outputs.version }}
          labels: |
            version=${{ needs.release-github.outputs.version }}
            flavour=${{ inputs.flavour }}
          cache-from: type=registry,ref=cognee/cognee:buildcache
          cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max

      - name: Build and push Main Docker Image
        if: ${{ inputs.flavour == 'main' }}
        uses: docker/build-push-action@v5
        with:
          context: .
          platforms: linux/amd64,linux/arm64
          push: true
          tags: |
            cognee/cognee:${{ needs.release-github.outputs.version }}
            cognee/cognee:latest
          labels: |
            version=${{ needs.release-github.outputs.version }}
            flavour=${{ inputs.flavour }}
          cache-from: type=registry,ref=cognee/cognee:buildcache
          cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max
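The tag-creation step in the release workflow above derives a `v`-prefixed git tag from the package version. A minimal sketch of that derivation, with a hard-coded stand-in value where CI would call `uv version --short`:

```python
# Stand-in for the version string that CI reads via `uv version --short`.
version = "1.2.3"

# Same derivation the workflow performs: TAG="v${VERSION}".
tag = f"v{version}"
print(f"Tag to create: {tag}")  # → Tag to create: v1.2.3
```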
17  .github/workflows/release_test.yml (vendored)

@@ -1,17 +0,0 @@
# Long-running, heavy and resource-consuming tests for release validation
name: Release Test Workflow

permissions:
  contents: read

on:
  workflow_dispatch:
  pull_request:
    branches:
      - main

jobs:
  load-tests:
    name: Load Tests
    uses: ./.github/workflows/load_tests.yml
    secrets: inherit
3  .github/workflows/search_db_tests.yml (vendored)

@@ -84,7 +84,6 @@ jobs:
          GRAPH_DATABASE_PROVIDER: 'neo4j'
          VECTOR_DB_PROVIDER: 'lancedb'
          DB_PROVIDER: 'sqlite'
          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
          GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
          GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
          GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}

@@ -136,7 +135,6 @@ jobs:
          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
          GRAPH_DATABASE_PROVIDER: 'kuzu'
          VECTOR_DB_PROVIDER: 'pgvector'
          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
          DB_PROVIDER: 'postgres'
          DB_NAME: 'cognee_db'
          DB_HOST: '127.0.0.1'

@@ -199,7 +197,6 @@ jobs:
          GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
          GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
          GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
          DB_NAME: cognee_db
          DB_HOST: 127.0.0.1
          DB_PORT: 5432
3  .github/workflows/temporal_graph_tests.yml (vendored)

@@ -72,7 +72,6 @@ jobs:
      - name: Run Temporal Graph with Neo4j (lancedb + sqlite)
        env:
          ENV: 'dev'
          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
          LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
          LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
          LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}

@@ -124,7 +123,6 @@ jobs:
      - name: Run Temporal Graph with Kuzu (postgres + pgvector)
        env:
          ENV: dev
          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
          LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
          LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
          LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}

@@ -191,7 +189,6 @@ jobs:
      - name: Run Temporal Graph with Neo4j (postgres + pgvector)
        env:
          ENV: dev
          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
          LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
          LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
          LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -10,10 +10,6 @@ on:
        required: false
        type: string
        default: '["3.10.x", "3.12.x", "3.13.x"]'
      os:
        required: false
        type: string
        default: '["ubuntu-22.04", "macos-15", "windows-latest"]'
    secrets:
      LLM_PROVIDER:
        required: true

@@ -44,11 +40,10 @@ jobs:
  run-unit-tests:
    name: Unit tests ${{ matrix.python-version }} on ${{ matrix.os }}
    runs-on: ${{ matrix.os }}
    timeout-minutes: 60
    strategy:
      matrix:
        python-version: ${{ fromJSON(inputs.python-versions) }}
        os: ${{ fromJSON(inputs.os) }}
        os: [ubuntu-22.04, macos-15, windows-latest]
      fail-fast: false
    steps:
      - name: Check out

@@ -81,11 +76,10 @@ jobs:
  run-integration-tests:
    name: Integration tests ${{ matrix.python-version }} on ${{ matrix.os }}
    runs-on: ${{ matrix.os }}
    timeout-minutes: 60
    strategy:
      matrix:
        python-version: ${{ fromJSON(inputs.python-versions) }}
        os: ${{ fromJSON(inputs.os) }}
        os: [ ubuntu-22.04, macos-15, windows-latest ]
      fail-fast: false
    steps:
      - name: Check out

@@ -118,11 +112,10 @@ jobs:
  run-library-test:
    name: Library test ${{ matrix.python-version }} on ${{ matrix.os }}
    runs-on: ${{ matrix.os }}
    timeout-minutes: 60
    strategy:
      matrix:
        python-version: ${{ fromJSON(inputs.python-versions) }}
        os: ${{ fromJSON(inputs.os) }}
        os: [ ubuntu-22.04, macos-15, windows-latest ]
      fail-fast: false
    steps:
      - name: Check out

@@ -155,11 +148,10 @@ jobs:
  run-build-test:
    name: Build test ${{ matrix.python-version }} on ${{ matrix.os }}
    runs-on: ${{ matrix.os }}
    timeout-minutes: 60
    strategy:
      matrix:
        python-version: ${{ fromJSON(inputs.python-versions) }}
        os: ${{ fromJSON(inputs.os) }}
        os: [ ubuntu-22.04, macos-15, windows-latest ]
      fail-fast: false
    steps:
      - name: Check out

@@ -185,11 +177,10 @@ jobs:
  run-soft-deletion-test:
    name: Soft Delete test ${{ matrix.python-version }} on ${{ matrix.os }}
    runs-on: ${{ matrix.os }}
    timeout-minutes: 60
    strategy:
      matrix:
        python-version: ${{ fromJSON(inputs.python-versions) }}
        os: ${{ fromJSON(inputs.os) }}
        os: [ ubuntu-22.04, macos-15, windows-latest ]
      fail-fast: false
    steps:
      - name: Check out

@@ -223,11 +214,10 @@ jobs:
  run-hard-deletion-test:
    name: Hard Delete test ${{ matrix.python-version }} on ${{ matrix.os }}
    runs-on: ${{ matrix.os }}
    timeout-minutes: 60
    strategy:
      matrix:
        python-version: ${{ fromJSON(inputs.python-versions) }}
        os: ${{ fromJSON(inputs.os) }}
        os: [ ubuntu-22.04, macos-15, windows-latest ]
      fail-fast: false
    steps:
      - name: Check out
90  .github/workflows/test_llms.yml (vendored)

@@ -84,93 +84,3 @@ jobs:
          EMBEDDING_DIMENSIONS: "3072"
          EMBEDDING_MAX_TOKENS: "8191"
        run: uv run python ./examples/python/simple_example.py

  test-bedrock-api-key:
    name: Run Bedrock API Key Test
    runs-on: ubuntu-22.04
    steps:
      - name: Check out repository
        uses: actions/checkout@v4

      - name: Cognee Setup
        uses: ./.github/actions/cognee_setup
        with:
          python-version: '3.11.x'
          extra-dependencies: "aws"

      - name: Run Bedrock API Key Simple Example
        env:
          LLM_PROVIDER: "bedrock"
          LLM_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
          LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
          LLM_MAX_TOKENS: "16384"
          AWS_REGION_NAME: "eu-west-1"
          EMBEDDING_PROVIDER: "bedrock"
          EMBEDDING_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
          EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
          EMBEDDING_DIMENSIONS: "1024"
          EMBEDDING_MAX_TOKENS: "8191"
        run: uv run python ./examples/python/simple_example.py

  test-bedrock-aws-credentials:
    name: Run Bedrock AWS Credentials Test
    runs-on: ubuntu-22.04
    steps:
      - name: Check out repository
        uses: actions/checkout@v4

      - name: Cognee Setup
        uses: ./.github/actions/cognee_setup
        with:
          python-version: '3.11.x'
          extra-dependencies: "aws"

      - name: Run Bedrock AWS Credentials Simple Example
        env:
          LLM_PROVIDER: "bedrock"
          LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
          LLM_MAX_TOKENS: "16384"
          AWS_REGION_NAME: "eu-west-1"
          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          EMBEDDING_PROVIDER: "bedrock"
          EMBEDDING_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
          EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
          EMBEDDING_DIMENSIONS: "1024"
          EMBEDDING_MAX_TOKENS: "8191"
        run: uv run python ./examples/python/simple_example.py

  test-bedrock-aws-profile:
    name: Run Bedrock AWS Profile Test
    runs-on: ubuntu-22.04
    steps:
      - name: Check out repository
        uses: actions/checkout@v4

      - name: Cognee Setup
        uses: ./.github/actions/cognee_setup
        with:
          python-version: '3.11.x'
          extra-dependencies: "aws"

      - name: Configure AWS Profile
        run: |
          mkdir -p ~/.aws
          cat > ~/.aws/credentials << EOF
          [bedrock-test]
          aws_access_key_id = ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws_secret_access_key = ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          EOF

      - name: Run Bedrock AWS Profile Simple Example
        env:
          LLM_PROVIDER: "bedrock"
          LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
          LLM_MAX_TOKENS: "16384"
          AWS_PROFILE_NAME: "bedrock-test"
          AWS_REGION_NAME: "eu-west-1"
          EMBEDDING_PROVIDER: "bedrock"
          EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
          EMBEDDING_DIMENSIONS: "1024"
          EMBEDDING_MAX_TOKENS: "8191"
        run: uv run python ./examples/python/simple_example.py
17  .github/workflows/test_ollama.yml (vendored)

@@ -7,8 +7,13 @@ jobs:
  run_ollama_test:

    # needs 32 Gb RAM for phi4 in a container
    runs-on: buildjet-8vcpu-ubuntu-2204
    # needs 16 Gb RAM for phi4
    runs-on: buildjet-4vcpu-ubuntu-2204
    # services:
    #   ollama:
    #     image: ollama/ollama
    #     ports:
    #       - 11434:11434

    steps:
      - name: Checkout repository

@@ -23,6 +28,14 @@ jobs:
        run: |
          uv add torch

      # - name: Install ollama
      #   run: curl -fsSL https://ollama.com/install.sh | sh
      # - name: Run ollama
      #   run: |
      #     ollama serve --openai &
      #     ollama pull llama3.2 &
      #     ollama pull avr/sfr-embedding-mistral:latest

      - name: Start Ollama container
        run: |
          docker run -d --name ollama -p 11434:11434 ollama/ollama
33  .github/workflows/test_suites.yml (vendored)

@@ -1,6 +1,4 @@
name: Test Suites
permissions:
  contents: read

on:
  push:

@@ -18,21 +16,15 @@ env:
  RUNTIME__LOG_LEVEL: ERROR
  ENV: 'dev'

jobs:
  pre-test:
    name: basic checks
    uses: ./.github/workflows/pre_test.yml

jobs:
  basic-tests:
    name: Basic Tests
    uses: ./.github/workflows/basic_tests.yml
    needs: [ pre-test ]
    secrets: inherit

  e2e-tests:
    name: End-to-End Tests
    uses: ./.github/workflows/e2e_tests.yml
    needs: [ pre-test ]
    secrets: inherit

  distributed-tests:

@@ -88,22 +80,12 @@ jobs:
    uses: ./.github/workflows/notebooks_tests.yml
    secrets: inherit

  different-os-tests-basic:
    name: OS and Python Tests Ubuntu
  different-operating-systems-tests:
    name: Operating System and Python Tests
    needs: [basic-tests, e2e-tests]
    uses: ./.github/workflows/test_different_operating_systems.yml
    with:
      python-versions: '["3.10.x", "3.11.x", "3.12.x", "3.13.x"]'
      os: '["ubuntu-22.04"]'
    secrets: inherit

  different-os-tests-extended:
    name: OS and Python Tests Extended
    needs: [basic-tests, e2e-tests]
    uses: ./.github/workflows/test_different_operating_systems.yml
    with:
      python-versions: '["3.13.x"]'
      os: '["macos-15", "windows-latest"]'
    secrets: inherit

  # Matrix-based vector database tests

@@ -153,8 +135,7 @@ jobs:
      e2e-tests,
      graph-db-tests,
      notebook-tests,
      different-os-tests-basic,
      different-os-tests-extended,
      different-operating-systems-tests,
      vector-db-tests,
      example-tests,
      llm-tests,

@@ -174,8 +155,7 @@ jobs:
      cli-tests,
      graph-db-tests,
      notebook-tests,
      different-os-tests-basic,
      different-os-tests-extended,
      different-operating-systems-tests,
      vector-db-tests,
      example-tests,
      db-examples-tests,

@@ -196,8 +176,7 @@ jobs:
      "${{ needs.cli-tests.result }}" == "success" &&
      "${{ needs.graph-db-tests.result }}" == "success" &&
      "${{ needs.notebook-tests.result }}" == "success" &&
      "${{ needs.different-os-tests-basic.result }}" == "success" &&
      "${{ needs.different-os-tests-extended.result }}" == "success" &&
      "${{ needs.different-operating-systems-tests.result }}" == "success" &&
      "${{ needs.vector-db-tests.result }}" == "success" &&
      "${{ needs.example-tests.result }}" == "success" &&
      "${{ needs.db-examples-tests.result }}" == "success" &&
3  .github/workflows/vector_db_tests.yml (vendored)

@@ -92,7 +92,6 @@ jobs:
      - name: Run PGVector Tests
        env:
          ENV: 'dev'
          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
          LLM_MODEL: ${{ secrets.LLM_MODEL }}
          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}

@@ -128,4 +127,4 @@ jobs:
          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
        run: uv run python ./cognee/tests/test_lancedb.py
        run: uv run python ./cognee/tests/test_lancedb.py
36  .github/workflows/weighted_edges_tests.yml (vendored)

@@ -2,7 +2,7 @@ name: Weighted Edges Tests

on:
  push:
    branches: [ main, dev, weighted_edges ]
    branches: [ main, weighted_edges ]
    paths:
      - 'cognee/modules/graph/utils/get_graph_from_model.py'
      - 'cognee/infrastructure/engine/models/Edge.py'

@@ -10,7 +10,7 @@ on:
      - 'examples/python/weighted_edges_example.py'
      - '.github/workflows/weighted_edges_tests.yml'
  pull_request:
    branches: [ main, dev ]
    branches: [ main ]
    paths:
      - 'cognee/modules/graph/utils/get_graph_from_model.py'
      - 'cognee/infrastructure/engine/models/Edge.py'

@@ -32,7 +32,7 @@ jobs:
    env:
      LLM_PROVIDER: openai
      LLM_MODEL: gpt-5-mini
      LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
      LLM_API_KEY: ${{ secrets.LLM_API_KEY }}

    steps:
      - name: Check out repository

@@ -67,13 +67,14 @@ jobs:
    env:
      LLM_PROVIDER: openai
      LLM_MODEL: gpt-5-mini
      LLM_ENDPOINT: https://api.openai.com/v1
      LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
      LLM_ENDPOINT: https://api.openai.com/v1/
      LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
      LLM_API_VERSION: "2024-02-01"
      EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
      EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
      EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
      EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
      EMBEDDING_PROVIDER: openai
      EMBEDDING_MODEL: text-embedding-3-small
      EMBEDDING_ENDPOINT: https://api.openai.com/v1/
      EMBEDDING_API_KEY: ${{ secrets.LLM_API_KEY }}
      EMBEDDING_API_VERSION: "2024-02-01"
    steps:
      - name: Check out repository
        uses: actions/checkout@v4

@@ -94,7 +95,6 @@ jobs:

      - name: Run Weighted Edges Tests
        env:
          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
          GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }}
          GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }}
          GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }}

@@ -108,14 +108,14 @@ jobs:
    env:
      LLM_PROVIDER: openai
      LLM_MODEL: gpt-5-mini
      LLM_ENDPOINT: https://api.openai.com/v1
      LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
      LLM_ENDPOINT: https://api.openai.com/v1/
      LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
      LLM_API_VERSION: "2024-02-01"
      EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
      EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
      EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
      EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}

      EMBEDDING_PROVIDER: openai
      EMBEDDING_MODEL: text-embedding-3-small
      EMBEDDING_ENDPOINT: https://api.openai.com/v1/
      EMBEDDING_API_KEY: ${{ secrets.LLM_API_KEY }}
      EMBEDDING_API_VERSION: "2024-02-01"
    steps:
      - name: Check out repository
        uses: actions/checkout@v4

@@ -166,3 +166,5 @@ jobs:
    uses: astral-sh/ruff-action@v2
    with:
      args: "format --check cognee/modules/graph/utils/get_graph_from_model.py cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py"
@@ -1,9 +0,0 @@
pull_request_rules:
  - name: Backport to main when backport_main label is set
    conditions:
      - label=backport_main
      - base=dev
    actions:
      backport:
        branches:
          - main
@@ -71,7 +71,7 @@ git clone https://github.com/<your-github-username>/cognee.git
cd cognee
```
In case you are working on Vector and Graph Adapters
1. Fork the [**cognee-community**](https://github.com/topoteretes/cognee-community) repository
1. Fork the [**cognee**](https://github.com/topoteretes/cognee-community) repository
2. Clone your fork:
```shell
git clone https://github.com/<your-github-username>/cognee-community.git
```

@@ -97,21 +97,6 @@ git checkout -b feature/your-feature-name
python cognee/cognee/tests/test_library.py
```

### Running Simple Example

Change .env.example into .env and provide your OPENAI_API_KEY as LLM_API_KEY

Make sure to run ```shell uv sync ``` in the root cloned folder or set up a virtual environment to run cognee

```shell
python cognee/cognee/examples/python/simple_example.py
```
or

```shell
uv run python cognee/cognee/examples/python/simple_example.py
```

## 4. 📤 Submitting Changes

1. Install ruff on your system
15  README.md

@@ -66,10 +66,13 @@ Use your data to build personalized and dynamic memory for AI Agents. Cognee let
## About Cognee

Cognee is an open-source tool and platform that transforms your raw data into persistent and dynamic AI memory for Agents. It combines vector search with graph databases to make your documents both searchable by meaning and connected by relationships.
Cognee offers default memory creation and search which we describe bellow. But with Cognee you can build your own!

You can use Cognee in two ways:

### Cognee Open Source:
1. [Self-host Cognee Open Source](https://docs.cognee.ai/getting-started/installation), which stores all data locally by default.
2. [Connect to Cognee Cloud](https://platform.cognee.ai/), and get the same OSS stack on managed infrastructure for easier development and productionization.

### Cognee Open Source (self-hosted):

- Interconnects any type of data — including past conversations, files, images, and audio transcriptions
- Replaces traditional RAG systems with a unified memory layer built on graphs and vectors

@@ -77,6 +80,11 @@ Cognee offers default memory creation and search which we describe bellow. But w
- Provides Pythonic data pipelines for ingestion from 30+ data sources
- Offers high customizability through user-defined tasks, modular pipelines, and built-in search endpoints

### Cognee Cloud (managed):
- Hosted web UI dashboard
- Automatic version updates
- Resource usage analytics
- GDPR compliant, enterprise-grade security

## Basic Usage & Feature Guide

@@ -118,7 +126,6 @@ Now, run a minimal pipeline:
```python
import cognee
import asyncio
from pprint import pprint


async def main():

@@ -136,7 +143,7 @@ async def main():

    # Display the results
    for result in results:
        pprint(result)
        print(result)


if __name__ == '__main__':
@@ -87,6 +87,11 @@ db_engine = get_relational_engine()

print("Using database:", db_engine.db_uri)

if "sqlite" in db_engine.db_uri:
    from cognee.infrastructure.utils.run_sync import run_sync

    run_sync(db_engine.create_database())

config.set_section_option(
    config.config_ini_section,
    "SQLALCHEMY_DATABASE_URI",
@@ -10,7 +10,6 @@ from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.

@@ -27,34 +26,7 @@ def upgrade() -> None:
    connection = op.get_bind()
    inspector = sa.inspect(connection)

    if op.get_context().dialect.name == "postgresql":
        syncstatus_enum = postgresql.ENUM(
            "STARTED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED", name="syncstatus"
        )
        syncstatus_enum.create(op.get_bind(), checkfirst=True)

    if "sync_operations" not in inspector.get_table_names():
        if op.get_context().dialect.name == "postgresql":
            syncstatus = postgresql.ENUM(
                "STARTED",
                "IN_PROGRESS",
                "COMPLETED",
                "FAILED",
                "CANCELLED",
                name="syncstatus",
                create_type=False,
            )
        else:
            syncstatus = sa.Enum(
                "STARTED",
                "IN_PROGRESS",
                "COMPLETED",
                "FAILED",
                "CANCELLED",
                name="syncstatus",
                create_type=False,
            )

        # Table doesn't exist, create it normally
        op.create_table(
            "sync_operations",

@@ -62,7 +34,15 @@ def upgrade() -> None:
            sa.Column("run_id", sa.Text(), nullable=True),
            sa.Column(
                "status",
                syncstatus,
                sa.Enum(
                    "STARTED",
                    "IN_PROGRESS",
                    "COMPLETED",
                    "FAILED",
                    "CANCELLED",
                    name="syncstatus",
                    create_type=False,
                ),
                nullable=True,
            ),
            sa.Column("progress_percentage", sa.Integer(), nullable=True),
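The migration above declares the same five `syncstatus` states twice, once per dialect branch, with `create_type=False` so SQLAlchemy does not try to re-create the PostgreSQL type the migration already created with `checkfirst=True`. As a plain-Python sketch of those values (stdlib only, class name taken from the diff):

```python
from enum import Enum


class SyncStatus(str, Enum):
    """The five states the sync_operations.status column accepts on any dialect."""

    STARTED = "STARTED"
    IN_PROGRESS = "IN_PROGRESS"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"


# Membership check mirrors what the database enum enforces.
print([s.value for s in SyncStatus])
```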
|
|
@ -1,333 +0,0 @@
|
|||
"""Expand dataset database with json connection field
|
||||
|
||||
Revision ID: 46a6ce2bd2b2
|
||||
Revises: 76625596c5c3
|
||||
Create Date: 2025-11-25 17:56:28.938931
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "46a6ce2bd2b2"
|
||||
down_revision: Union[str, None] = "76625596c5c3"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
graph_constraint_name = "dataset_database_graph_database_name_key"
|
||||
vector_constraint_name = "dataset_database_vector_database_name_key"
|
||||
TABLE_NAME = "dataset_database"
|
||||
|
||||
|
||||
def _get_column(inspector, table, name, schema=None):
|
||||
for col in inspector.get_columns(table, schema=schema):
|
||||
if col["name"] == name:
|
||||
return col
|
||||
return None
|
||||
|
||||
|
||||
def _recreate_table_without_unique_constraint_sqlite(op, insp):
|
||||
"""
|
||||
SQLite cannot drop unique constraints on individual columns. We must:
|
||||
1. Create a new table without the unique constraints.
|
||||
2. Copy data from the old table.
|
||||
3. Drop the old table.
|
||||
4. Rename the new table.
|
||||
"""
|
||||
conn = op.get_bind()
|
||||
|
||||
# Create new table definition (without unique constraints)
|
||||
op.create_table(
|
||||
f"{TABLE_NAME}_new",
|
||||
sa.Column("owner_id", sa.UUID()),
|
||||
sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False),
|
||||
sa.Column("vector_database_name", sa.String(), nullable=False),
|
||||
sa.Column("graph_database_name", sa.String(), nullable=False),
|
||||
sa.Column("vector_database_provider", sa.String(), nullable=False),
|
||||
sa.Column("graph_database_provider", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"vector_dataset_database_handler",
|
||||
sa.String(),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
server_default="lancedb",
|
||||
),
|
||||
sa.Column(
|
||||
"graph_dataset_database_handler",
|
||||
sa.String(),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
server_default="kuzu",
|
||||
),
|
||||
sa.Column("vector_database_url", sa.String()),
|
||||
sa.Column("graph_database_url", sa.String()),
|
||||
sa.Column("vector_database_key", sa.String()),
|
||||
sa.Column("graph_database_key", sa.String()),
|
||||
sa.Column(
|
||||
"graph_database_connection_info",
|
||||
sa.JSON(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'{}'"),
|
||||
),
|
||||
sa.Column(
|
||||
"vector_database_connection_info",
|
||||
sa.JSON(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'{}'"),
|
||||
),
|
||||
sa.Column("created_at", sa.DateTime()),
|
||||
sa.Column("updated_at", sa.DateTime()),
|
||||
sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"),
|
||||
)
|
||||
|
||||
# Copy data into new table
|
||||
conn.execute(
|
||||
sa.text(f"""
|
||||
INSERT INTO {TABLE_NAME}_new
|
||||
SELECT
|
||||
owner_id,
|
||||
dataset_id,
|
||||
vector_database_name,
|
||||
graph_database_name,
|
||||
vector_database_provider,
|
||||
graph_database_provider,
|
||||
vector_dataset_database_handler,
|
||||
graph_dataset_database_handler,
|
||||
vector_database_url,
|
||||
graph_database_url,
|
||||
vector_database_key,
|
||||
graph_database_key,
|
||||
COALESCE(graph_database_connection_info, '{{}}'),
|
||||
COALESCE(vector_database_connection_info, '{{}}'),
|
||||
created_at,
|
||||
updated_at
|
||||
FROM {TABLE_NAME}
|
||||
""")
|
||||
)
|
||||
|
||||
# Drop old table
|
||||
op.drop_table(TABLE_NAME)
|
||||
|
||||
# Rename new table
|
||||
op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME)
|
||||


def _recreate_table_with_unique_constraint_sqlite(op, insp):
    """
    SQLite cannot drop or re-add unique constraints on individual columns. We must:
    1. Create a new table with the unique constraints restored.
    2. Copy data from the old table.
    3. Drop the old table.
    4. Rename the new table.
    """
    conn = op.get_bind()

    # Create new table definition (with the unique constraints restored)
    op.create_table(
        f"{TABLE_NAME}_new",
        sa.Column("owner_id", sa.UUID()),
        sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False),
        sa.Column("vector_database_name", sa.String(), nullable=False, unique=True),
        sa.Column("graph_database_name", sa.String(), nullable=False, unique=True),
        sa.Column("vector_database_provider", sa.String(), nullable=False),
        sa.Column("graph_database_provider", sa.String(), nullable=False),
        sa.Column(
            "vector_dataset_database_handler",
            sa.String(),
            unique=False,
            nullable=False,
            server_default="lancedb",
        ),
        sa.Column(
            "graph_dataset_database_handler",
            sa.String(),
            unique=False,
            nullable=False,
            server_default="kuzu",
        ),
        sa.Column("vector_database_url", sa.String()),
        sa.Column("graph_database_url", sa.String()),
        sa.Column("vector_database_key", sa.String()),
        sa.Column("graph_database_key", sa.String()),
        sa.Column(
            "graph_database_connection_info",
            sa.JSON(),
            nullable=False,
            server_default=sa.text("'{}'"),
        ),
        sa.Column(
            "vector_database_connection_info",
            sa.JSON(),
            nullable=False,
            server_default=sa.text("'{}'"),
        ),
        sa.Column("created_at", sa.DateTime()),
        sa.Column("updated_at", sa.DateTime()),
        sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"),
    )

    # Copy data into new table
    conn.execute(
        sa.text(f"""
            INSERT INTO {TABLE_NAME}_new
            SELECT
                owner_id,
                dataset_id,
                vector_database_name,
                graph_database_name,
                vector_database_provider,
                graph_database_provider,
                vector_dataset_database_handler,
                graph_dataset_database_handler,
                vector_database_url,
                graph_database_url,
                vector_database_key,
                graph_database_key,
                COALESCE(graph_database_connection_info, '{{}}'),
                COALESCE(vector_database_connection_info, '{{}}'),
                created_at,
                updated_at
            FROM {TABLE_NAME}
        """)
    )

    # Drop old table
    op.drop_table(TABLE_NAME)

    # Rename new table
    op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME)
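The create/copy/drop/rename dance both helpers perform is the standard SQLite workaround for any constraint change. A distilled sketch against a hypothetical one-column table (not the migration's real schema) shows the same four steps:

```python
import sqlite3

# SQLite cannot ALTER a column's constraints in place, so the only way to
# change them is: create a replacement table, copy rows, drop, rename.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (name TEXT UNIQUE)")
conn.execute("INSERT INTO t VALUES ('a')")
conn.execute("INSERT INTO t VALUES ('b')")

conn.execute("CREATE TABLE t_new (name TEXT)")        # same shape, no UNIQUE
conn.execute("INSERT INTO t_new SELECT name FROM t")  # copy data
conn.execute("DROP TABLE t")                          # drop old table
conn.execute("ALTER TABLE t_new RENAME TO t")         # rename new table

# All rows survive, and the hidden unique auto-index is gone.
print(conn.execute("SELECT COUNT(*) FROM t").fetchone()[0])
```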


def upgrade() -> None:
    conn = op.get_bind()
    insp = sa.inspect(conn)

    unique_constraints = insp.get_unique_constraints(TABLE_NAME)

    vector_database_connection_info_column = _get_column(
        insp, "dataset_database", "vector_database_connection_info"
    )
    if not vector_database_connection_info_column:
        op.add_column(
            "dataset_database",
            sa.Column(
                "vector_database_connection_info",
                sa.JSON(),
                unique=False,
                nullable=False,
                server_default=sa.text("'{}'"),
            ),
        )

    vector_dataset_database_handler = _get_column(
        insp, "dataset_database", "vector_dataset_database_handler"
    )
    if not vector_dataset_database_handler:
        # Add LanceDB as the default vector dataset database handler
        op.add_column(
            "dataset_database",
            sa.Column(
                "vector_dataset_database_handler",
                sa.String(),
                unique=False,
                nullable=False,
                server_default="lancedb",
            ),
        )

    graph_database_connection_info_column = _get_column(
        insp, "dataset_database", "graph_database_connection_info"
    )
    if not graph_database_connection_info_column:
        op.add_column(
            "dataset_database",
            sa.Column(
                "graph_database_connection_info",
                sa.JSON(),
                unique=False,
                nullable=False,
                server_default=sa.text("'{}'"),
            ),
        )

    graph_dataset_database_handler = _get_column(
        insp, "dataset_database", "graph_dataset_database_handler"
    )
    if not graph_dataset_database_handler:
        # Add Kuzu as the default graph dataset database handler
        op.add_column(
            "dataset_database",
            sa.Column(
                "graph_dataset_database_handler",
                sa.String(),
                unique=False,
                nullable=False,
                server_default="kuzu",
            ),
        )

    with op.batch_alter_table("dataset_database", schema=None) as batch_op:
        # Drop the unique constraints to make unique=False
        graph_constraint_to_drop = None
        for uc in unique_constraints:
            # Check if the constraint covers ONLY the target column
            if uc["name"] == graph_constraint_name:
                graph_constraint_to_drop = uc["name"]
                break

        vector_constraint_to_drop = None
        for uc in unique_constraints:
            # Check if the constraint covers ONLY the target column
            if uc["name"] == vector_constraint_name:
                vector_constraint_to_drop = uc["name"]
                break

        if (
            vector_constraint_to_drop
            and graph_constraint_to_drop
            and op.get_context().dialect.name == "postgresql"
        ):
            # PostgreSQL can drop the named unique constraints directly
            batch_op.drop_constraint(graph_constraint_name, type_="unique")
            batch_op.drop_constraint(vector_constraint_name, type_="unique")

    if op.get_context().dialect.name == "sqlite":
        conn = op.get_bind()
        # Fun fact: SQLite has hidden auto-indexes for unique constraints that can't be
        # dropped or accessed directly, so we need to check for them and drop them by
        # recreating the table (altering the column also won't work).
        result = conn.execute(sa.text("PRAGMA index_list('dataset_database')"))
        rows = result.fetchall()
        unique_auto_indexes = [row for row in rows if row[3] == "u"]
        for row in unique_auto_indexes:
            result = conn.execute(sa.text(f"PRAGMA index_info('{row[1]}')"))
            index_info = result.fetchall()
            if index_info[0][2] == "vector_database_name":
                # A unique index exists on vector_database_name, so drop it and the
                # graph_database_name one by rebuilding the table
                _recreate_table_without_unique_constraint_sqlite(op, insp)
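The `PRAGMA` probing above can be exercised on its own. A minimal sketch with a hypothetical in-memory table: SQLite materializes column-level `UNIQUE` constraints as hidden `sqlite_autoindex_*` indexes, which `PRAGMA index_list` reports with origin `"u"` and which can only be removed by rebuilding the table:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE dataset_database (vector_database_name TEXT UNIQUE)")

# PRAGMA index_list rows: (seq, name, unique, origin, partial);
# origin "u" marks an index auto-created for a UNIQUE constraint.
rows = conn.execute("PRAGMA index_list('dataset_database')").fetchall()
unique_auto_indexes = [row for row in rows if row[3] == "u"]

for row in unique_auto_indexes:
    # PRAGMA index_info rows: (seqno, cid, column_name)
    index_info = conn.execute(f"PRAGMA index_info('{row[1]}')").fetchall()
    print(row[1], "->", index_info[0][2])
```

Running this prints one hidden index mapped to `vector_database_name`, which is exactly the signal the migration keys on before rebuilding the table.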


def downgrade() -> None:
    conn = op.get_bind()
    insp = sa.inspect(conn)

    if op.get_context().dialect.name == "sqlite":
        _recreate_table_with_unique_constraint_sqlite(op, insp)
    elif op.get_context().dialect.name == "postgresql":
        with op.batch_alter_table("dataset_database", schema=None) as batch_op:
            # Re-add the unique constraint to return to unique=True
            batch_op.create_unique_constraint(graph_constraint_name, ["graph_database_name"])

        with op.batch_alter_table("dataset_database", schema=None) as batch_op:
            # Re-add the unique constraint to return to unique=True
            batch_op.create_unique_constraint(vector_constraint_name, ["vector_database_name"])

    op.drop_column("dataset_database", "vector_database_connection_info")
    op.drop_column("dataset_database", "graph_database_connection_info")
    op.drop_column("dataset_database", "vector_dataset_database_handler")
    op.drop_column("dataset_database", "graph_dataset_database_handler")
@@ -23,8 +23,11 @@ depends_on: Union[str, Sequence[str], None] = "8057ae7329c2"
 
 
 def upgrade() -> None:
-    pass
+    try:
+        await_only(create_default_user())
+    except UserAlreadyExists:
+        pass  # It's fine if the default user already exists
 
 
 def downgrade() -> None:
-    pass
+    await_only(delete_user("default_user@example.com"))

@@ -1,98 +0,0 @@
"""Expand dataset database for multi user

Revision ID: 76625596c5c3
Revises: 211ab850ef3d
Create Date: 2025-10-30 12:55:20.239562

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "76625596c5c3"
down_revision: Union[str, None] = "c946955da633"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def _get_column(inspector, table, name, schema=None):
    for col in inspector.get_columns(table, schema=schema):
        if col["name"] == name:
            return col
    return None


def upgrade() -> None:
    conn = op.get_bind()
    insp = sa.inspect(conn)

    vector_database_provider_column = _get_column(
        insp, "dataset_database", "vector_database_provider"
    )
    if not vector_database_provider_column:
        op.add_column(
            "dataset_database",
            sa.Column(
                "vector_database_provider",
                sa.String(),
                unique=False,
                nullable=False,
                server_default="lancedb",
            ),
        )

    graph_database_provider_column = _get_column(
        insp, "dataset_database", "graph_database_provider"
    )
    if not graph_database_provider_column:
        op.add_column(
            "dataset_database",
            sa.Column(
                "graph_database_provider",
                sa.String(),
                unique=False,
                nullable=False,
                server_default="kuzu",
            ),
        )

    vector_database_url_column = _get_column(insp, "dataset_database", "vector_database_url")
    if not vector_database_url_column:
        op.add_column(
            "dataset_database",
            sa.Column("vector_database_url", sa.String(), unique=False, nullable=True),
        )

    graph_database_url_column = _get_column(insp, "dataset_database", "graph_database_url")
    if not graph_database_url_column:
        op.add_column(
            "dataset_database",
            sa.Column("graph_database_url", sa.String(), unique=False, nullable=True),
        )

    vector_database_key_column = _get_column(insp, "dataset_database", "vector_database_key")
    if not vector_database_key_column:
        op.add_column(
            "dataset_database",
            sa.Column("vector_database_key", sa.String(), unique=False, nullable=True),
        )

    graph_database_key_column = _get_column(insp, "dataset_database", "graph_database_key")
    if not graph_database_key_column:
        op.add_column(
            "dataset_database",
            sa.Column("graph_database_key", sa.String(), unique=False, nullable=True),
        )


def downgrade() -> None:
    op.drop_column("dataset_database", "vector_database_provider")
    op.drop_column("dataset_database", "graph_database_provider")
    op.drop_column("dataset_database", "vector_database_url")
    op.drop_column("dataset_database", "graph_database_url")
    op.drop_column("dataset_database", "vector_database_key")
    op.drop_column("dataset_database", "graph_database_key")
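Each `add_column` above is guarded by the `_get_column` inspector lookup so the migration is idempotent: re-running it against a database that already has the column is a no-op. The same guard can be sketched with plain `sqlite3` (hypothetical table; the migration itself uses SQLAlchemy's inspector instead of `PRAGMA`):

```python
import sqlite3

def _get_column(conn, table, name):
    # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
    for col in conn.execute(f"PRAGMA table_info('{table}')").fetchall():
        if col[1] == name:
            return col
    return None

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE dataset_database (id INTEGER PRIMARY KEY)")

# Only add the column when it is missing, mirroring the migration's guard.
if not _get_column(conn, "dataset_database", "vector_database_url"):
    conn.execute("ALTER TABLE dataset_database ADD COLUMN vector_database_url TEXT")

print(_get_column(conn, "dataset_database", "vector_database_url") is not None)
```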

@@ -18,8 +18,11 @@ depends_on: Union[str, Sequence[str], None] = None
 
 
 def upgrade() -> None:
-    pass
+    db_engine = get_relational_engine()
+    # we might want to delete this
+    await_only(db_engine.create_database())
 
 
 def downgrade() -> None:
-    pass
+    db_engine = get_relational_engine()
+    await_only(db_engine.delete_database())

@@ -144,58 +144,44 @@ def _create_data_permission(conn, user_id, data_id, permission_name):
     )
 
 
-def _get_column(inspector, table, name, schema=None):
-    for col in inspector.get_columns(table, schema=schema):
-        if col["name"] == name:
-            return col
-    return None
-
-
 def upgrade() -> None:
     conn = op.get_bind()
     insp = sa.inspect(conn)
 
-    dataset_id_column = _get_column(insp, "acls", "dataset_id")
-    if not dataset_id_column:
-        # Recreate ACLs table with default permissions set to datasets instead of documents
-        op.drop_table("acls")
+    # Recreate ACLs table with default permissions set to datasets instead of documents
+    op.drop_table("acls")
 
-        acls_table = op.create_table(
-            "acls",
-            sa.Column("id", UUID, primary_key=True, default=uuid4),
-            sa.Column(
-                "created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
-            ),
-            sa.Column(
-                "updated_at",
-                sa.DateTime(timezone=True),
-                onupdate=lambda: datetime.now(timezone.utc),
-            ),
-            sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")),
-            sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")),
-            sa.Column("dataset_id", UUID, sa.ForeignKey("datasets.id", ondelete="CASCADE")),
-        )
+    acls_table = op.create_table(
+        "acls",
+        sa.Column("id", UUID, primary_key=True, default=uuid4),
+        sa.Column(
+            "created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
+        ),
+        sa.Column(
+            "updated_at", sa.DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)
+        ),
+        sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")),
+        sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")),
+        sa.Column("dataset_id", UUID, sa.ForeignKey("datasets.id", ondelete="CASCADE")),
+    )
 
-        # Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
-        # definition or load what is in the database
-        dataset_table = _define_dataset_table()
-        datasets = conn.execute(sa.select(dataset_table)).fetchall()
+    # Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
+    # definition or load what is in the database
+    dataset_table = _define_dataset_table()
+    datasets = conn.execute(sa.select(dataset_table)).fetchall()
 
-        if not datasets:
-            return
+    if not datasets:
+        return
 
-        acl_list = []
+    acl_list = []
 
-        for dataset in datasets:
-            acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "read"))
-            acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "write"))
-            acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "share"))
-            acl_list.append(
-                _create_dataset_permission(conn, dataset.owner_id, dataset.id, "delete")
-            )
+    for dataset in datasets:
+        acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "read"))
+        acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "write"))
+        acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "share"))
+        acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "delete"))
 
-        if acl_list:
-            op.bulk_insert(acls_table, acl_list)
+    if acl_list:
+        op.bulk_insert(acls_table, acl_list)
 
 
 def downgrade() -> None:

@@ -1,137 +0,0 @@
"""Multi Tenant Support

Revision ID: c946955da633
Revises: 211ab850ef3d
Create Date: 2025-11-04 18:11:09.325158

"""

from typing import Sequence, Union
from datetime import datetime, timezone
from uuid import uuid4

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision: str = "c946955da633"
down_revision: Union[str, None] = "211ab850ef3d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def _now():
    return datetime.now(timezone.utc)


def _define_user_table() -> sa.Table:
    table = sa.Table(
        "users",
        sa.MetaData(),
        sa.Column(
            "id",
            sa.UUID,
            sa.ForeignKey("principals.id", ondelete="CASCADE"),
            primary_key=True,
            nullable=False,
        ),
        sa.Column("tenant_id", sa.UUID, sa.ForeignKey("tenants.id"), index=True, nullable=True),
    )
    return table


def _define_dataset_table() -> sa.Table:
    # Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
    # definition or load what is in the database
    table = sa.Table(
        "datasets",
        sa.MetaData(),
        sa.Column("id", sa.UUID, primary_key=True, default=uuid4),
        sa.Column("name", sa.Text),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            default=lambda: datetime.now(timezone.utc),
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            onupdate=lambda: datetime.now(timezone.utc),
        ),
        sa.Column("owner_id", sa.UUID(), sa.ForeignKey("principals.id"), index=True),
        sa.Column("tenant_id", sa.UUID(), sa.ForeignKey("tenants.id"), index=True, nullable=True),
    )

    return table


def _get_column(inspector, table, name, schema=None):
    for col in inspector.get_columns(table, schema=schema):
        if col["name"] == name:
            return col
    return None


def upgrade() -> None:
    conn = op.get_bind()
    insp = sa.inspect(conn)

    dataset = _define_dataset_table()
    user = _define_user_table()

    if "user_tenants" not in insp.get_table_names():
        # Define table with all necessary columns including primary key
        user_tenants = op.create_table(
            "user_tenants",
            sa.Column("user_id", sa.UUID, sa.ForeignKey("users.id"), primary_key=True),
            sa.Column("tenant_id", sa.UUID, sa.ForeignKey("tenants.id"), primary_key=True),
            sa.Column(
                "created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
            ),
        )

        # Get all users with their tenant_id
        user_data = conn.execute(
            sa.select(user.c.id, user.c.tenant_id).where(user.c.tenant_id.isnot(None))
        ).fetchall()

        # Insert into user_tenants table
        if user_data:
            op.bulk_insert(
                user_tenants,
                [
                    {"user_id": user_id, "tenant_id": tenant_id, "created_at": _now()}
                    for user_id, tenant_id in user_data
                ],
            )

    tenant_id_column = _get_column(insp, "datasets", "tenant_id")
    if not tenant_id_column:
        op.add_column("datasets", sa.Column("tenant_id", sa.UUID(), nullable=True))

        # Build subquery: select users.tenant_id for each dataset.owner_id
        tenant_id_from_dataset_owner = (
            sa.select(user.c.tenant_id).where(user.c.id == dataset.c.owner_id).scalar_subquery()
        )

        if op.get_context().dialect.name == "sqlite":
            # If column doesn't exist create new original_extension column and update from values of extension column
            with op.batch_alter_table("datasets") as batch_op:
                batch_op.execute(
                    dataset.update().values(
                        tenant_id=tenant_id_from_dataset_owner,
                    )
                )
        else:
            conn = op.get_bind()
            conn.execute(dataset.update().values(tenant_id=tenant_id_from_dataset_owner))

        op.create_index(op.f("ix_datasets_tenant_id"), "datasets", ["tenant_id"])


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table("user_tenants")
    op.drop_index(op.f("ix_datasets_tenant_id"), table_name="datasets")
    op.drop_column("datasets", "tenant_id")
    # ### end Alembic commands ###
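The tenant backfill above uses a correlated scalar subquery: for each `datasets` row, SQLAlchemy emits an `UPDATE` whose new `tenant_id` is looked up from the owning user. The same statement in raw SQL, sketched against a hypothetical minimal schema with integer keys instead of UUIDs:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, tenant_id INTEGER)")
conn.execute(
    "CREATE TABLE datasets (id INTEGER PRIMARY KEY, owner_id INTEGER, tenant_id INTEGER)"
)
conn.execute("INSERT INTO users VALUES (1, 10), (2, 20)")
conn.execute("INSERT INTO datasets VALUES (100, 1, NULL), (200, 2, NULL)")

# Correlated scalar subquery: each dataset row pulls tenant_id from its owner.
conn.execute(
    "UPDATE datasets SET tenant_id = "
    "(SELECT tenant_id FROM users WHERE users.id = datasets.owner_id)"
)

print(conn.execute("SELECT id, tenant_id FROM datasets ORDER BY id").fetchall())
# [(100, 10), (200, 20)]
```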
2512 cognee-frontend/package-lock.json (generated)
File diff suppressed because it is too large.

@@ -9,13 +9,13 @@
     "lint": "next lint"
   },
   "dependencies": {
-    "@auth0/nextjs-auth0": "^4.13.1",
+    "@auth0/nextjs-auth0": "^4.6.0",
     "classnames": "^2.5.1",
     "culori": "^4.0.1",
     "d3-force-3d": "^3.0.6",
-    "next": "16.1.1",
-    "react": "^19.2.0",
-    "react-dom": "^19.2.0",
+    "next": "15.3.3",
+    "react": "^19.0.0",
+    "react-dom": "^19.0.0",
     "react-force-graph-2d": "^1.27.1",
     "uuid": "^9.0.1"
   },
@@ -24,11 +24,11 @@
     "@tailwindcss/postcss": "^4.1.7",
     "@types/culori": "^4.0.0",
     "@types/node": "^20",
-    "@types/react": "^19",
-    "@types/react-dom": "^19",
+    "@types/react": "^18",
+    "@types/react-dom": "^18",
     "@types/uuid": "^9.0.8",
     "eslint": "^9",
-    "eslint-config-next": "^16.0.4",
+    "eslint-config-next": "^15.3.3",
     "eslint-config-prettier": "^10.1.5",
     "tailwindcss": "^4.1.7",
     "typescript": "^5"

119 cognee-frontend/src/app/(graph)/CrewAITrigger.tsx (new file)

@@ -0,0 +1,119 @@
import { useState } from "react";
import { fetch } from "@/utils";
import { v4 as uuid4 } from "uuid";
import { LoadingIndicator } from "@/ui/App";
import { CTAButton, Input } from "@/ui/elements";

interface CrewAIFormPayload extends HTMLFormElement {
  username1: HTMLInputElement;
  username2: HTMLInputElement;
}

interface CrewAITriggerProps {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  onData: (data: any) => void;
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  onActivity: (activities: any) => void;
}

export default function CrewAITrigger({ onData, onActivity }: CrewAITriggerProps) {
  const [isCrewAIRunning, setIsCrewAIRunning] = useState(false);

  const handleRunCrewAI = (event: React.FormEvent<CrewAIFormPayload>) => {
    event.preventDefault();
    const formElements = event.currentTarget;

    const crewAIConfig = {
      username1: formElements.username1.value,
      username2: formElements.username2.value,
    };

    const backendApiUrl = process.env.NEXT_PUBLIC_BACKEND_API_URL;
    const wsUrl = backendApiUrl.replace(/^http(s)?/, "ws");

    const websocket = new WebSocket(`${wsUrl}/v1/crewai/subscribe`);

    onActivity([{ id: uuid4(), timestamp: Date.now(), activity: "Dispatching hiring crew agents" }]);

    websocket.onmessage = (event) => {
      const data = JSON.parse(event.data);

      if (data.status === "PipelineRunActivity") {
        onActivity([data.payload]);
        return;
      }

      onData({
        nodes: data.payload.nodes,
        links: data.payload.edges,
      });

      const nodes_type_map: { [key: string]: number } = {};

      for (let i = 0; i < data.payload.nodes.length; i++) {
        const node = data.payload.nodes[i];
        if (!nodes_type_map[node.type]) {
          nodes_type_map[node.type] = 0;
        }
        nodes_type_map[node.type] += 1;
      }

      const activityMessage = Object.entries(nodes_type_map).reduce((message, [type, count]) => {
        return `${message}\n | ${type}: ${count}`;
      }, "Graph updated:");

      onActivity([{
        id: uuid4(),
        timestamp: Date.now(),
        activity: activityMessage,
      }]);

      if (data.status === "PipelineRunCompleted") {
        websocket.close();
      }
    };

    onData(null);
    setIsCrewAIRunning(true);

    return fetch("/v1/crewai/run", {
      method: "POST",
      body: JSON.stringify(crewAIConfig),
      headers: {
        "Content-Type": "application/json",
      },
    })
      .then(response => response.json())
      .then(() => {
        onActivity([{ id: uuid4(), timestamp: Date.now(), activity: "Hiring crew agents made a decision" }]);
      })
      .catch(() => {
        onActivity([{ id: uuid4(), timestamp: Date.now(), activity: "Hiring crew agents had problems while executing" }]);
      })
      .finally(() => {
        websocket.close();
        setIsCrewAIRunning(false);
      });
  };

  return (
    <form className="w-full flex flex-col gap-2" onSubmit={handleRunCrewAI}>
      <h1 className="text-2xl text-white">Cognee Dev Mexican Standoff</h1>
      <span className="text-white">Agents compare GitHub profiles, and make a decision who is a better developer</span>
      <div className="flex flex-row gap-2">
        <div className="flex flex-col w-full flex-1/2">
          <label className="block mb-1 text-white" htmlFor="username1">GitHub username</label>
          <Input name="username1" type="text" placeholder="Github Username" required defaultValue="hajdul88" />
        </div>
        <div className="flex flex-col w-full flex-1/2">
          <label className="block mb-1 text-white" htmlFor="username2">GitHub username</label>
          <Input name="username2" type="text" placeholder="Github Username" required defaultValue="lxobr" />
        </div>
      </div>
      <CTAButton type="submit" disabled={isCrewAIRunning} className="whitespace-nowrap">
        Start Mexican Standoff
        {isCrewAIRunning && <LoadingIndicator />}
      </CTAButton>
    </form>
  );
}
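The `onmessage` handler above tallies incoming graph nodes by type and folds the counts into an activity message. The same aggregation, sketched in Python with a hypothetical payload (the component does this with a plain object and `Object.entries(...).reduce`):

```python
# Hypothetical node payload, mirroring data.payload.nodes in the component.
nodes = [{"type": "Person"}, {"type": "Repo"}, {"type": "Person"}]

# Count nodes per type, as the for-loop over data.payload.nodes does.
nodes_type_map = {}
for node in nodes:
    nodes_type_map[node["type"]] = nodes_type_map.get(node["type"], 0) + 1

# Fold the counts into one activity string, as the reduce does.
message = "Graph updated:"
for node_type, count in nodes_type_map.items():
    message += f"\n | {node_type}: {count}"

print(message)
# Graph updated:
#  | Person: 2
#  | Repo: 1
```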
|
@@ -6,6 +6,7 @@ import { NodeObject, LinkObject } from "react-force-graph-2d";
 import { ChangeEvent, useEffect, useImperativeHandle, useRef, useState } from "react";
 
 import { DeleteIcon } from "@/ui/Icons";
+// import { FeedbackForm } from "@/ui/Partials";
 import { CTAButton, Input, NeutralButton, Select } from "@/ui/elements";
 
 interface GraphControlsProps {

@@ -110,7 +111,7 @@ export default function GraphControls({ data, isAddNodeFormOpen, onGraphShapeCha
   };
 
   const [isAuthShapeChangeEnabled, setIsAuthShapeChangeEnabled] = useState(true);
-  const shapeChangeTimeout = useRef<number | null>(null);
+  const shapeChangeTimeout = useRef<number | null>();
 
   useEffect(() => {
     onGraphShapeChange(DEFAULT_GRAPH_SHAPE);

@@ -229,6 +230,12 @@ export default function GraphControls({ data, isAddNodeFormOpen, onGraphShapeCha
         )}
       </>
+      {/* )} */}
 
+      {/* {selectedTab === "feedback" && (
+        <div className="flex flex-col gap-2">
+          <FeedbackForm onSuccess={() => {}} />
+        </div>
+      )} */}
     </div>
   </>
 );


@@ -1,6 +1,6 @@
 "use client";
 
-import { useCallback, useRef, useState, RefObject } from "react";
+import { useCallback, useRef, useState, MutableRefObject } from "react";
 
 import Link from "next/link";
 import { TextLogo } from "@/ui/App";

@@ -47,11 +47,11 @@ export default function GraphView() {
     updateData(newData);
   }, []);
 
-  const graphRef = useRef<GraphVisualizationAPI>(null);
+  const graphRef = useRef<GraphVisualizationAPI>();
 
-  const graphControls = useRef<GraphControlsAPI>(null);
+  const graphControls = useRef<GraphControlsAPI>();
 
-  const activityLog = useRef<ActivityLogAPI>(null);
+  const activityLog = useRef<ActivityLogAPI>();
 
   return (
     <main className="flex flex-col h-full">

@@ -74,18 +74,21 @@ export default function GraphView() {
       <div className="w-full h-full relative overflow-hidden">
         <GraphVisualization
           key={data?.nodes.length}
-          ref={graphRef as RefObject<GraphVisualizationAPI>}
+          ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
           data={data}
-          graphControls={graphControls as RefObject<GraphControlsAPI>}
+          graphControls={graphControls as MutableRefObject<GraphControlsAPI>}
         />
 
         <div className="absolute top-2 left-2 flex flex-col gap-2">
           <div className="bg-gray-500 pt-4 pr-4 pb-4 pl-4 rounded-md w-sm">
             <CogneeAddWidget onData={onDataChange} />
           </div>
+          {/* <div className="bg-gray-500 pt-4 pr-4 pb-4 pl-4 rounded-md w-sm">
+            <CrewAITrigger onData={onDataChange} onActivity={(activities) => activityLog.current?.updateActivityLog(activities)} />
+          </div> */}
           <div className="bg-gray-500 pt-4 pr-4 pb-4 pl-4 rounded-md w-sm">
             <h2 className="text-xl text-white mb-4">Activity Log</h2>
-            <ActivityLog ref={activityLog as RefObject<ActivityLogAPI>} />
+            <ActivityLog ref={activityLog as MutableRefObject<ActivityLogAPI>} />
           </div>
         </div>

@@ -93,7 +96,7 @@ export default function GraphView() {
       <div className="bg-gray-500 pt-4 pr-4 pb-4 pl-4 rounded-md w-110">
         <GraphControls
           data={data}
-          ref={graphControls as RefObject<GraphControlsAPI>}
+          ref={graphControls as MutableRefObject<GraphControlsAPI>}
           isAddNodeFormOpen={isAddNodeFormOpen}
           onFitIntoView={() => graphRef.current!.zoomToFit(1000, 50)}
           onGraphShapeChange={(shape) => graphRef.current!.setGraphShape(shape)}


@@ -1,7 +1,7 @@
 "use client";
 
 import classNames from "classnames";
-import { RefObject, useEffect, useImperativeHandle, useRef, useState, useCallback } from "react";
+import { MutableRefObject, useEffect, useImperativeHandle, useRef, useState, useCallback } from "react";
 import { forceCollide, forceManyBody } from "d3-force-3d";
 import dynamic from "next/dynamic";
 import { GraphControlsAPI } from "./GraphControls";

@@ -16,9 +16,9 @@ const ForceGraph = dynamic(() => import("react-force-graph-2d"), {
 import type { ForceGraphMethods, GraphData, LinkObject, NodeObject } from "react-force-graph-2d";
 
 interface GraphVisuzaliationProps {
-  ref: RefObject<GraphVisualizationAPI>;
+  ref: MutableRefObject<GraphVisualizationAPI>;
   data?: GraphData<NodeObject, LinkObject>;
-  graphControls: RefObject<GraphControlsAPI>;
+  graphControls: MutableRefObject<GraphControlsAPI>;
   className?: string;
 }

@@ -205,7 +205,7 @@ export default function GraphVisualization({ ref, data, graphControls, className
   // eslint-disable-next-line @typescript-eslint/no-unused-vars
   function handleDagError(loopNodeIds: (string | number)[]) {}
 
-  const graphRef = useRef<ForceGraphMethods>(null);
+  const graphRef = useRef<ForceGraphMethods>();
 
   useEffect(() => {
     if (data && graphRef.current) {

@@ -224,7 +224,6 @@ export default function GraphVisualization({ ref, data, graphControls, className
   ) => {
     if (!graphRef.current) {
       console.warn("GraphVisualization: graphRef not ready yet");
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
       return undefined as any;
     }
 

@@ -240,7 +239,7 @@ export default function GraphVisualization({ ref, data, graphControls, className
   return (
     <div ref={containerRef} className={classNames("w-full h-full", className)} id="graph-container">
       <ForceGraph
-        ref={graphRef as RefObject<ForceGraphMethods>}
+        ref={graphRef}
         width={dimensions.width}
         height={dimensions.height}
         dagMode={graphShape as unknown as undefined}

@@ -1,8 +1,8 @@
-"use client";
+"use server";
 
 import Dashboard from "./Dashboard";
 
-export default function Page() {
+export default async function Page() {
   const accessToken = "";
 
   return (

|
|||
|
|
@@ -1,3 +1,3 @@
 export { default } from "./dashboard/page";

-export const dynamic = "force-dynamic";
+// export const dynamic = "force-dynamic";
@@ -2,7 +2,7 @@ import { NextResponse, type NextRequest } from "next/server";
 // import { auth0 } from "./modules/auth/auth0";

 // eslint-disable-next-line @typescript-eslint/no-unused-vars
-export async function proxy(request: NextRequest) {
+export async function middleware(request: NextRequest) {
   // if (process.env.USE_AUTH0_AUTHORIZATION?.toLowerCase() === "true") {
   //   if (request.nextUrl.pathname === "/auth/token") {
   //     return NextResponse.next();
@@ -13,6 +13,7 @@ export interface Dataset {

 function useDatasets(useCloud = false) {
   const [datasets, setDatasets] = useState<Dataset[]>([]);
   // eslint-disable-next-line @typescript-eslint/no-explicit-any
   // const statusTimeout = useRef<any>(null);

   // const fetchDatasetStatuses = useCallback((datasets: Dataset[]) => {
cognee-frontend/src/ui/Partials/FeedbackForm.tsx (new file, 69 lines)

@@ -0,0 +1,69 @@
"use client";

import { useState } from "react";
import { LoadingIndicator } from "@/ui/App";
import { fetch, useBoolean } from "@/utils";
import { CTAButton, TextArea } from "@/ui/elements";

interface SignInFormPayload extends HTMLFormElement {
  feedback: HTMLTextAreaElement;
}

interface FeedbackFormProps {
  onSuccess: () => void;
}

export default function FeedbackForm({ onSuccess }: FeedbackFormProps) {
  const {
    value: isSubmittingFeedback,
    setTrue: disableFeedbackSubmit,
    setFalse: enableFeedbackSubmit,
  } = useBoolean(false);

  const [feedbackError, setFeedbackError] = useState<string | null>(null);

  const signIn = (event: React.FormEvent<SignInFormPayload>) => {
    event.preventDefault();
    const formElements = event.currentTarget;

    setFeedbackError(null);
    disableFeedbackSubmit();

    fetch("/v1/crewai/feedback", {
      method: "POST",
      body: JSON.stringify({
        feedback: formElements.feedback.value,
      }),
      headers: {
        "Content-Type": "application/json",
      },
    })
      .then(response => response.json())
      .then(() => {
        onSuccess();
        formElements.feedback.value = "";
      })
      .catch(error => setFeedbackError(error.detail))
      .finally(() => enableFeedbackSubmit());
  };

  return (
    <form onSubmit={signIn} className="flex flex-col gap-2">
      <div className="flex flex-col gap-2">
        <div className="mb-4">
          <label className="block text-white" htmlFor="feedback">Feedback on agent's reasoning</label>
          <TextArea id="feedback" name="feedback" type="text" placeholder="Your feedback" />
        </div>
      </div>

      <CTAButton type="submit">
        <span>Submit feedback</span>
        {isSubmittingFeedback && <LoadingIndicator />}
      </CTAButton>

      {feedbackError && (
        <span className="text-s text-white">{feedbackError}</span>
      )}
    </form>
  )
}
@@ -3,3 +3,4 @@ export { default as Footer } from "./Footer/Footer";
 export { default as SearchView } from "./SearchView/SearchView";
 export { default as IFrameView } from "./IFrameView/IFrameView";
 // export { default as Explorer } from "./Explorer/Explorer";
+export { default as FeedbackForm } from "./FeedbackForm";
@@ -2,7 +2,7 @@

 import { v4 as uuid4 } from "uuid";
 import classNames from "classnames";
-import { Fragment, MouseEvent, RefObject, useCallback, useEffect, useRef, useState } from "react";
+import { Fragment, MouseEvent, MutableRefObject, useCallback, useEffect, useRef, useState } from "react";

 import { useModal } from "@/ui/elements/Modal";
 import { CaretIcon, CloseIcon, PlusIcon } from "@/ui/Icons";
@@ -282,7 +282,7 @@ export default function Notebook({ notebook, updateNotebook, runCell }: Notebook
 function CellResult({ content }: { content: [] }) {
   const parsedContent = [];

-  const graphRef = useRef<GraphVisualizationAPI>(null);
+  const graphRef = useRef<GraphVisualizationAPI>();
   const graphControls = useRef<GraphControlsAPI>({
     setSelectedNode: () => {},
     getSelectedNode: () => null,
@@ -298,7 +298,7 @@ function CellResult({ content }: { content: [] }) {
   <span className="text-sm pl-2 mb-4">reasoning graph</span>
   <GraphVisualization
     data={transformInsightsGraphData(line)}
-    ref={graphRef as RefObject<GraphVisualizationAPI>}
+    ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
     graphControls={graphControls}
     className="min-h-80"
   />
@@ -346,7 +346,7 @@ function CellResult({ content }: { content: [] }) {
   <span className="text-sm pl-2 mb-4">reasoning graph (datasets: {datasetName})</span>
   <GraphVisualization
     data={transformToVisualizationData(graph)}
-    ref={graphRef as RefObject<GraphVisualizationAPI>}
+    ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
     graphControls={graphControls}
     className="min-h-80"
   />
@@ -377,7 +377,7 @@ function CellResult({ content }: { content: [] }) {
   <span className="text-sm pl-2 mb-4">reasoning graph (datasets: {datasetName})</span>
   <GraphVisualization
     data={transformToVisualizationData(graph)}
-    ref={graphRef as RefObject<GraphVisualizationAPI>}
+    ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
     graphControls={graphControls}
     className="min-h-80"
   />
@@ -33,7 +33,7 @@ export default function NotebookCellHeader({
   setFalse: setIsNotRunningCell,
 } = useBoolean(false);

-const [runInstance] = useState<string>(isCloudEnvironment() ? "cloud" : "local");
+const [runInstance, setRunInstance] = useState<string>(isCloudEnvironment() ? "cloud" : "local");

 const handleCellRun = () => {
   if (runCell) {
@@ -14,7 +14,7 @@
 "moduleResolution": "bundler",
 "resolveJsonModule": true,
 "isolatedModules": true,
-"jsx": "react-jsx",
+"jsx": "preserve",
 "incremental": true,
 "plugins": [
   {
@@ -32,8 +32,7 @@
 "next-env.d.ts",
 "**/*.ts",
 "**/*.tsx",
-".next/types/**/*.ts",
-".next/dev/types/**/*.ts"
+".next/types/**/*.ts"
 ],
 "exclude": [
   "node_modules"
@@ -445,22 +445,16 @@ The MCP server exposes its functionality through tools. Call them from any MCP client:

 - **cognify**: Turns your data into a structured knowledge graph and stores it in memory

 - **cognee_add_developer_rules**: Ingest core developer rule files into memory

 - **codify**: Analyses a code repository, builds a code graph, and stores it in memory

 - **delete**: Delete specific data from a dataset (supports soft/hard deletion modes)

 - **get_developer_rules**: Retrieve all developer rules that were generated based on previous interactions
 - **search**: Query memory – supports GRAPH_COMPLETION, RAG_COMPLETION, CODE, CHUNKS

 - **list_data**: List all datasets and their data items with IDs for deletion operations

 - **save_interaction**: Logs user-agent interactions and query-answer pairs
 - **delete**: Delete specific data from a dataset (supports soft/hard deletion modes)

 - **prune**: Reset cognee for a fresh start (removes all data)

 - **search**: Query memory – supports GRAPH_COMPLETION, RAG_COMPLETION, CODE, CHUNKS, SUMMARIES, CYPHER, and FEELING_LUCKY

 - **cognify_status / codify_status**: Track pipeline progress

 **Data Management Examples:**
@@ -1,6 +1,6 @@
 [project]
 name = "cognee-mcp"
-version = "0.5.0"
+version = "0.4.0"
 description = "Cognee MCP server"
 readme = "README.md"
 requires-python = ">=3.10"
@@ -9,7 +9,7 @@ dependencies = [
     # For local cognee repo usage remove comment below and add absolute path to cognee. Then run `uv sync --reinstall` in the mcp folder on local cognee changes.
     #"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j] @ file:/Users/igorilic/Desktop/cognee",
     # TODO: Remove gemini from optional dependencies for new Cognee version after 0.3.4
-    "cognee[postgres,docs,neo4j]==0.5.0",
+    "cognee[postgres,docs,neo4j]==0.3.7",
     "fastmcp>=2.10.0,<3.0.0",
     "mcp>=1.12.0,<2.0.0",
     "uv>=0.6.3,<1.0.0",
@@ -151,7 +151,7 @@ class CogneeClient:
     query_type: str,
     datasets: Optional[List[str]] = None,
     system_prompt: Optional[str] = None,
-    top_k: int = 5,
+    top_k: int = 10,
 ) -> Any:
     """
     Search the knowledge graph.
@@ -192,7 +192,7 @@ class CogneeClient:

 with redirect_stdout(sys.stderr):
     results = await self.cognee.search(
-        query_type=SearchType[query_type.upper()], query_text=query_text, top_k=top_k
+        query_type=SearchType[query_type.upper()], query_text=query_text
     )
     return results
@@ -90,6 +90,97 @@ async def health_check(request):
    return JSONResponse({"status": "ok"})


@mcp.tool()
async def cognee_add_developer_rules(
    base_path: str = ".", graph_model_file: str = None, graph_model_name: str = None
) -> list:
    """
    Ingest core developer rule files into Cognee's memory layer.

    This function loads a predefined set of developer-related configuration,
    rule, and documentation files from the base repository and assigns them
    to the special 'developer_rules' node set in Cognee. It ensures these
    foundational files are always part of the structured memory graph.

    Parameters
    ----------
    base_path : str
        Root path to resolve relative file paths. Defaults to current directory.

    graph_model_file : str, optional
        Optional path to a custom schema file for knowledge graph generation.

    graph_model_name : str, optional
        Optional class name to use from the graph_model_file schema.

    Returns
    -------
    list
        A message indicating how many rule files were scheduled for ingestion,
        and how to check their processing status.

    Notes
    -----
    - Each file is processed asynchronously in the background.
    - Files are attached to the 'developer_rules' node set.
    - Missing files are skipped with a logged warning.
    """

    developer_rule_paths = [
        ".cursorrules",
        ".cursor/rules",
        ".same/todos.md",
        ".windsurfrules",
        ".clinerules",
        "CLAUDE.md",
        ".sourcegraph/memory.md",
        "AGENT.md",
        "AGENTS.md",
    ]

    async def cognify_task(file_path: str) -> None:
        with redirect_stdout(sys.stderr):
            logger.info(f"Starting cognify for: {file_path}")
            try:
                await cognee_client.add(file_path, node_set=["developer_rules"])

                model = None
                if graph_model_file and graph_model_name:
                    if cognee_client.use_api:
                        logger.warning(
                            "Custom graph models are not supported in API mode, ignoring."
                        )
                    else:
                        from cognee.shared.data_models import KnowledgeGraph

                        model = load_class(graph_model_file, graph_model_name)

                await cognee_client.cognify(graph_model=model)
                logger.info(f"Cognify finished for: {file_path}")
            except Exception as e:
                logger.error(f"Cognify failed for {file_path}: {str(e)}")
                raise ValueError(f"Failed to cognify: {str(e)}")

    tasks = []
    for rel_path in developer_rule_paths:
        abs_path = os.path.join(base_path, rel_path)
        if os.path.isfile(abs_path):
            tasks.append(asyncio.create_task(cognify_task(abs_path)))
        else:
            logger.warning(f"Skipped missing developer rule file: {abs_path}")

    log_file = get_log_file_location()
    return [
        types.TextContent(
            type="text",
            text=(
                f"Started cognify for {len(tasks)} developer rule files in background.\n"
                f"All are added to the `developer_rules` node set.\n"
                f"Use `cognify_status` or check logs at {log_file} to monitor progress."
            ),
        )
    ]


@mcp.tool()
async def cognify(
    data: str, graph_model_file: str = None, graph_model_name: str = None, custom_prompt: str = None
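The file-collection step in `cognee_add_developer_rules` boils down to resolving each relative rule path against `base_path` and keeping only the ones that exist as files. A minimal standalone sketch of just that filtering (the helper name `collect_existing_rules` is ours, not the server's, and the path list is a subset):

```python
import os
import tempfile

# Subset of the rule paths the tool looks for.
developer_rule_paths = [".cursorrules", "CLAUDE.md", "AGENTS.md"]

def collect_existing_rules(base_path):
    # Mirror of the tool's loop: schedule files that exist, skip the rest
    # (the real tool logs a warning for each skipped path).
    found, skipped = [], []
    for rel_path in developer_rule_paths:
        abs_path = os.path.join(base_path, rel_path)
        (found if os.path.isfile(abs_path) else skipped).append(abs_path)
    return found, skipped

with tempfile.TemporaryDirectory() as repo:
    with open(os.path.join(repo, "CLAUDE.md"), "w") as f:
        f.write("# rules")
    found, skipped = collect_existing_rules(repo)

print(len(found), len(skipped))  # 1 2
```

In the server, each `found` path then becomes an `asyncio.create_task(...)` so the tool can return immediately while ingestion runs in the background.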
@@ -103,6 +194,7 @@ async def cognify(

 Prerequisites:
 - **LLM_API_KEY**: Must be configured (required for entity extraction and graph generation)
 - **Data Added**: Must have data previously added via `cognee.add()`
 - **Vector Database**: Must be accessible for embeddings storage
 - **Graph Database**: Must be accessible for relationship storage
@@ -316,7 +408,76 @@ async def save_interaction(data: str) -> list:


 @mcp.tool()
-async def search(search_query: str, search_type: str, top_k: int = 10) -> list:
+async def codify(repo_path: str) -> list:
    """
    Analyze and generate a code-specific knowledge graph from a software repository.

    This function launches a background task that processes the provided repository
    and builds a code knowledge graph. The function returns immediately while
    the processing continues in the background due to MCP timeout constraints.

    Parameters
    ----------
    repo_path : str
        Path to the code repository to analyze. This can be a local file path or a
        relative path to a repository. The path should point to the root of the
        repository or a specific directory within it.

    Returns
    -------
    list
        A list containing a single TextContent object with information about the
        background task launch and how to check its status.

    Notes
    -----
    - The function launches a background task and returns immediately
    - The code graph generation may take significant time for larger repositories
    - Use the codify_status tool to check the progress of the operation
    - Process results are logged to the standard Cognee log file
    - All stdout is redirected to stderr to maintain MCP communication integrity
    """

    if cognee_client.use_api:
        error_msg = "❌ Codify operation is not available in API mode. Please use direct mode for code graph pipeline."
        logger.error(error_msg)
        return [types.TextContent(type="text", text=error_msg)]

    async def codify_task(repo_path: str):
        # NOTE: MCP uses stdout to communicate, we must redirect all output
        # going to stdout (like the print function) to stderr.
        with redirect_stdout(sys.stderr):
            logger.info("Codify process starting.")
            from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline

            results = []
            async for result in run_code_graph_pipeline(repo_path, False):
                results.append(result)
                logger.info(result)
            if all(results):
                logger.info("Codify process finished successfully.")
            else:
                logger.info("Codify process failed.")

    asyncio.create_task(codify_task(repo_path))

    log_file = get_log_file_location()
    text = (
        f"Background process launched due to MCP timeout limitations.\n"
        f"To check current codify status use the codify_status tool\n"
        f"or you can check the log file at: {log_file}"
    )

    return [
        types.TextContent(
            type="text",
            text=text,
        )
    ]


@mcp.tool()
async def search(search_query: str, search_type: str) -> list:
    """
    Search and query the knowledge graph for insights, information, and connections.
@@ -389,13 +550,6 @@ async def search(search_query: str, search_type: str, top_k: int = 10) -> list:

     The search_type is case-insensitive and will be converted to uppercase.

-    top_k : int, optional
-        Maximum number of results to return (default: 10).
-        Controls the amount of context retrieved from the knowledge graph.
-        - Lower values (3-5): Faster, more focused results
-        - Higher values (10-20): More comprehensive, but slower and more context-heavy
-        Helps manage response size and context window usage in MCP clients.

     Returns
     -------
     list
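The `top_k` docstring removed here described a simple cap on how much ranked context the tool hands back to the client. As a hedged illustration of the trade-off it documented (a toy list stands in for real graph search results; `apply_top_k` is our name, not the server's):

```python
# Toy stand-in for ranked search results; the real ones come from the graph.
ranked_results = [f"chunk-{i}" for i in range(20)]

def apply_top_k(results, top_k=10):
    # Lower top_k (3-5): faster, more focused results.
    # Higher top_k (10-20): more comprehensive, but heavier on the
    # MCP client's response size and context window.
    return results[:top_k]

print(len(apply_top_k(ranked_results, top_k=5)))   # 5
print(len(apply_top_k(ranked_results)))            # 10
```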
@@ -432,32 +586,13 @@ async def search(search_query: str, search_type: str, top_k: int = 10) -> list:
    """

-    async def search_task(search_query: str, search_type: str, top_k: int) -> str:
-        """
-        Internal task to execute knowledge graph search with result formatting.
-
-        Handles the actual search execution and formats results appropriately
-        for MCP clients based on the search type and execution mode (API vs direct).
-
-        Parameters
-        ----------
-        search_query : str
-            The search query in natural language
-        search_type : str
-            Type of search to perform (GRAPH_COMPLETION, CHUNKS, etc.)
-        top_k : int
-            Maximum number of results to return
-
-        Returns
-        -------
-        str
-            Formatted search results as a string, with format depending on search_type
-        """
+    async def search_task(search_query: str, search_type: str) -> str:
+        """Search the knowledge graph"""
        # NOTE: MCP uses stdout to communicate, we must redirect all output
        # going to stdout (like the print function) to stderr.
        with redirect_stdout(sys.stderr):
            search_results = await cognee_client.search(
-                query_text=search_query, query_type=search_type, top_k=top_k
+                query_text=search_query, query_type=search_type
            )

        # Handle different result formats based on API vs direct mode
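The NOTE in `search_task` is worth unpacking: MCP servers speak their JSON-RPC protocol over stdout, so any stray `print` inside a tool would corrupt the stream. `redirect_stdout(sys.stderr)` diverts such output to stderr for the duration of the block. A self-contained demonstration (the fake-stderr swap is only there so we can observe where the output lands):

```python
import io
import sys
from contextlib import redirect_stdout

# Swap in a fake stderr so we can capture what gets diverted.
fake_stderr = io.StringIO()
real_stderr, sys.stderr = sys.stderr, fake_stderr
try:
    with redirect_stdout(sys.stderr):
        print("diagnostic message")  # would normally go to stdout
finally:
    sys.stderr = real_stderr

print(fake_stderr.getvalue().strip())  # diagnostic message
```

Inside the server every tool body wraps its work in this context manager, which is why logs and debug prints never collide with protocol traffic.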
@@ -491,10 +626,49 @@ async def search(search_query: str, search_type: str, top_k: int = 10) -> list:
        else:
            return str(search_results)

-    search_results = await search_task(search_query, search_type, top_k)
+    search_results = await search_task(search_query, search_type)
    return [types.TextContent(type="text", text=search_results)]


@mcp.tool()
async def get_developer_rules() -> list:
    """
    Retrieve all developer rules that were generated based on previous interactions.

    This tool queries the Cognee knowledge graph and returns a list of developer
    rules.

    Parameters
    ----------
    None

    Returns
    -------
    list
        A list containing a single TextContent object with the retrieved developer rules.
        The format is plain text containing the developer rules in bullet points.

    Notes
    -----
    - The specific logic for fetching rules is handled internally.
    - This tool does not accept any parameters and is intended for simple rule inspection use cases.
    """

    async def fetch_rules_from_cognee() -> str:
        """Collect all developer rules from Cognee"""
        with redirect_stdout(sys.stderr):
            if cognee_client.use_api:
                logger.warning("Developer rules retrieval is not available in API mode")
                return "Developer rules retrieval is not available in API mode"

            developer_rules = await get_existing_rules(rules_nodeset_name="coding_agent_rules")
            return developer_rules

    rules_text = await fetch_rules_from_cognee()

    return [types.TextContent(type="text", text=rules_text)]


@mcp.tool()
async def list_data(dataset_id: str = None) -> list:
    """
@@ -780,6 +954,48 @@ async def cognify_status():
        return [types.TextContent(type="text", text=error_msg)]


@mcp.tool()
async def codify_status():
    """
    Get the current status of the codify pipeline.

    This function retrieves information about current and recently completed codify operations
    in the codebase dataset. It provides details on progress, success/failure status, and statistics
    about the processed code repositories.

    Returns
    -------
    list
        A list containing a single TextContent object with the status information as a string.
        The status includes information about active and completed jobs for the cognify_code_pipeline.

    Notes
    -----
    - The function retrieves pipeline status specifically for the "cognify_code_pipeline" on the "codebase" dataset
    - Status information includes job progress, execution time, and completion status
    - The status is returned in string format for easy reading
    - This operation is not available in API mode
    """
    with redirect_stdout(sys.stderr):
        try:
            from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
            from cognee.modules.users.methods import get_default_user

            user = await get_default_user()
            status = await cognee_client.get_pipeline_status(
                [await get_unique_dataset_id("codebase", user)], "cognify_code_pipeline"
            )
            return [types.TextContent(type="text", text=str(status))]
        except NotImplementedError:
            error_msg = "❌ Pipeline status is not available in API mode"
            logger.error(error_msg)
            return [types.TextContent(type="text", text=error_msg)]
        except Exception as e:
            error_msg = f"❌ Failed to get codify status: {str(e)}"
            logger.error(error_msg)
            return [types.TextContent(type="text", text=error_msg)]


def node_to_string(node):
    node_data = ", ".join(
        [f'{key}: "{value}"' for key, value in node.items() if key in ["id", "name"]]
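`node_to_string` is truncated by the diff at the list comprehension, but the visible part already shows the idea: render a graph node as `key: "value"` pairs, keeping only `id` and `name`. A standalone sketch of that visible logic (the bare `return node_data` is an assumption; the real helper's final string shape is cut off here):

```python
def node_to_string(node):
    # Keep only the id/name attributes, rendered as `key: "value"` pairs.
    node_data = ", ".join(
        [f'{key}: "{value}"' for key, value in node.items() if key in ["id", "name"]]
    )
    return node_data  # assumption: the real helper may wrap this further

print(node_to_string({"id": "n1", "name": "Alice", "weight": 3}))
# id: "n1", name: "Alice"
```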
@@ -880,10 +1096,6 @@ async def main():

    # Skip migrations when in API mode (the API server handles its own database)
    if not args.no_migration and not args.api_url:
-        from cognee.modules.engine.operations.setup import setup
-
-        await setup()

        # Run Alembic migrations from the main cognee directory where alembic.ini is located
        logger.info("Running database migrations...")
        migration_result = subprocess.run(
@@ -3,7 +3,7 @@
 Test client for Cognee MCP Server functionality.

 This script tests all the tools and functions available in the Cognee MCP server,
-including cognify, search, prune, status checks, and utility functions.
+including cognify, codify, search, prune, status checks, and utility functions.

 Usage:
     # Set your OpenAI API key first
@@ -23,7 +23,6 @@ import tempfile
 import time
 from contextlib import asynccontextmanager
 from cognee.shared.logging_utils import setup_logging
 from logging import ERROR, INFO

 from mcp import ClientSession, StdioServerParameters
 from mcp.client.stdio import stdio_client
@@ -36,7 +35,7 @@ from src.server import (
     load_class,
 )

-# Set timeout for cognify to complete in
+# Set timeout for cognify/codify to complete in
 TIMEOUT = 5 * 60  # 5 min in seconds
@@ -152,9 +151,12 @@ DEBUG = True

 expected_tools = {
     "cognify",
+    "codify",
     "search",
     "prune",
     "cognify_status",
+    "codify_status",
+    "cognee_add_developer_rules",
     "list_data",
     "delete",
 }
@@ -245,6 +247,106 @@ DEBUG = True
            }
            print(f"❌ {test_name} test failed: {e}")

    async def test_codify(self):
        """Test the codify functionality using MCP client."""
        print("\n🧪 Testing codify functionality...")
        try:
            async with self.mcp_server_session() as session:
                codify_result = await session.call_tool(
                    "codify", arguments={"repo_path": self.test_repo_dir}
                )

                start = time.time()  # mark the start
                while True:
                    try:
                        # Wait a moment
                        await asyncio.sleep(5)

                        # Check if codify processing is finished
                        status_result = await session.call_tool("codify_status", arguments={})
                        if hasattr(status_result, "content") and status_result.content:
                            status_text = (
                                status_result.content[0].text
                                if status_result.content
                                else str(status_result)
                            )
                        else:
                            status_text = str(status_result)

                        if str(PipelineRunStatus.DATASET_PROCESSING_COMPLETED) in status_text:
                            break
                        elif time.time() - start > TIMEOUT:
                            raise TimeoutError("Codify did not complete in 5min")
                    except DatabaseNotCreatedError:
                        if time.time() - start > TIMEOUT:
                            raise TimeoutError("Database was not created in 5min")

            self.test_results["codify"] = {
                "status": "PASS",
                "result": codify_result,
                "message": "Codify executed successfully",
            }
            print("✅ Codify test passed")

        except Exception as e:
            self.test_results["codify"] = {
                "status": "FAIL",
                "error": str(e),
                "message": "Codify test failed",
            }
            print(f"❌ Codify test failed: {e}")

    async def test_cognee_add_developer_rules(self):
        """Test the cognee_add_developer_rules functionality using MCP client."""
        print("\n🧪 Testing cognee_add_developer_rules functionality...")
        try:
            async with self.mcp_server_session() as session:
                result = await session.call_tool(
                    "cognee_add_developer_rules", arguments={"base_path": self.test_data_dir}
                )

                start = time.time()  # mark the start
                while True:
                    try:
                        # Wait a moment
                        await asyncio.sleep(5)

                        # Check if developer rule cognify processing is finished
                        status_result = await session.call_tool("cognify_status", arguments={})
                        if hasattr(status_result, "content") and status_result.content:
                            status_text = (
                                status_result.content[0].text
                                if status_result.content
                                else str(status_result)
                            )
                        else:
                            status_text = str(status_result)

                        if str(PipelineRunStatus.DATASET_PROCESSING_COMPLETED) in status_text:
                            break
                        elif time.time() - start > TIMEOUT:
                            raise TimeoutError(
                                "Cognify of developer rules did not complete in 5min"
                            )
                    except DatabaseNotCreatedError:
                        if time.time() - start > TIMEOUT:
                            raise TimeoutError("Database was not created in 5min")

            self.test_results["cognee_add_developer_rules"] = {
                "status": "PASS",
                "result": result,
                "message": "Developer rules addition executed successfully",
            }
            print("✅ Developer rules test passed")

        except Exception as e:
            self.test_results["cognee_add_developer_rules"] = {
                "status": "FAIL",
                "error": str(e),
                "message": "Developer rules test failed",
            }
            print(f"❌ Developer rules test failed: {e}")

    async def test_search_functionality(self):
        """Test the search functionality with different search types using MCP client."""
        print("\n🧪 Testing search functionality...")
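Both new tests share the same wait loop: sleep, poll the status tool, and give up once `TIMEOUT` has elapsed. The pattern extracted into a generic helper (`wait_until` is our naming; the test client inlines this logic with `asyncio.sleep` and MCP calls):

```python
import time

def wait_until(check, timeout, interval=0.01):
    # Poll `check` until it returns True; raise TimeoutError once the
    # deadline passes — the same shape as the codify_status wait loop.
    start = time.time()
    while not check():
        if time.time() - start > timeout:
            raise TimeoutError(f"condition not met within {timeout}s")
        time.sleep(interval)
    return True

# Simulate a pipeline that reports completion on the third status poll.
polls = {"count": 0}
def pipeline_completed():
    polls["count"] += 1
    return polls["count"] >= 3

print(wait_until(pipeline_completed, timeout=5))  # True
```

The real loop also swallows `DatabaseNotCreatedError` while waiting, since the status check can race the database setup during early polls.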
@@ -257,11 +359,7 @@ DEBUG = True
 # Go through all Cognee search types
 for search_type in SearchType:
     # Don't test these search types
-    if search_type in [
-        SearchType.NATURAL_LANGUAGE,
-        SearchType.CYPHER,
-        SearchType.TRIPLET_COMPLETION,
-    ]:
+    if search_type in [SearchType.NATURAL_LANGUAGE, SearchType.CYPHER]:
         break
     try:
         async with self.mcp_server_session() as session:
@@ -583,6 +681,9 @@ class TestModel:
     test_name="Cognify2",
 )

+await self.test_codify()
+await self.test_cognee_add_developer_rules()

 # Test list_data and delete functionality
 await self.test_list_data()
 await self.test_delete()
@@ -638,5 +739,7 @@ async def main():


 if __name__ == "__main__":
+    from logging import ERROR
+
     logger = setup_logging(log_level=ERROR)
     asyncio.run(main())
cognee-mcp/uv.lock (generated, 7633 lines changed) — file diff suppressed because it is too large.
@@ -19,7 +19,6 @@ from .api.v1.add import add
 from .api.v1.delete import delete
 from .api.v1.cognify import cognify
-from .modules.memify import memify
 from .modules.run_custom_pipeline import run_custom_pipeline
 from .api.v1.update import update
 from .api.v1.config.config import config
 from .api.v1.datasets.datasets import datasets
@@ -21,9 +21,8 @@ from cognee.api.v1.notebooks.routers import get_notebooks_router
 from cognee.api.v1.permissions.routers import get_permissions_router
 from cognee.api.v1.settings.routers import get_settings_router
 from cognee.api.v1.datasets.routers import get_datasets_router
-from cognee.api.v1.cognify.routers import get_cognify_router
+from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router
 from cognee.api.v1.search.routers import get_search_router
 from cognee.api.v1.ontologies.routers.get_ontology_router import get_ontology_router
 from cognee.api.v1.memify.routers import get_memify_router
 from cognee.api.v1.add.routers import get_add_router
 from cognee.api.v1.delete.routers import get_delete_router
@@ -40,8 +39,6 @@ from cognee.api.v1.users.routers import (
 )
 from cognee.modules.users.methods.get_authenticated_user import REQUIRE_AUTHENTICATION

 # Ensure application logging is configured for container stdout/stderr
 setup_logging()
 logger = get_logger()

 if os.getenv("ENV", "prod") == "prod":
@@ -77,9 +74,6 @@ async def lifespan(app: FastAPI):

     await get_default_user()

     # Emit a clear startup message for docker logs
     logger.info("Backend server has started")

     yield

@@ -264,8 +258,6 @@ app.include_router(

 app.include_router(get_datasets_router(), prefix="/api/v1/datasets", tags=["datasets"])

 app.include_router(get_ontology_router(), prefix="/api/v1/ontologies", tags=["ontologies"])

 app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["settings"])

 app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"])
@@ -278,6 +270,10 @@ app.include_router(get_responses_router(), prefix="/api/v1/responses", tags=["responses"])

 app.include_router(get_sync_router(), prefix="/api/v1/sync", tags=["sync"])

+codegraph_routes = get_code_pipeline_router()
+if codegraph_routes:
+    app.include_router(codegraph_routes, prefix="/api/v1/code-pipeline", tags=["code-pipeline"])

 app.include_router(
     get_users_router(),
     prefix="/api/v1/users",
@@ -155,7 +155,7 @@ async def add(
    - LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)

    Optional:
    - LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral", "bedrock"
    - LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral"
    - LLM_MODEL: Model name (default: "gpt-5-mini")
    - DEFAULT_USER_EMAIL: Custom default user email
    - DEFAULT_USER_PASSWORD: Custom default user password

@@ -205,7 +205,6 @@ async def add(
        pipeline_name="add_pipeline",
        vector_db_config=vector_db_config,
        graph_db_config=graph_db_config,
        use_pipeline_cache=True,
        incremental_loading=incremental_loading,
        data_per_batch=data_per_batch,
    ):

@@ -82,9 +82,7 @@ def get_add_router() -> APIRouter:
                datasetName,
                user=user,
                dataset_id=datasetId,
                node_set=node_set
                if node_set != [""]
                else None,  # Transform default node_set endpoint value to None
                node_set=node_set if node_set else None,
            )

            if isinstance(add_run, PipelineRunErrored):
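The hunk above replaces the old `node_set != [""]` check with a plain truthiness test, which also maps an empty list to `None`. A minimal, self-contained sketch of the behavioral difference (function names are hypothetical, for illustration only):

```python
def normalize_node_set_old(node_set):
    # Old behavior: only the endpoint's default value [""] became None.
    return node_set if node_set != [""] else None


def normalize_node_set_new(node_set):
    # New behavior: any falsy value ([], None, "") becomes None.
    return node_set if node_set else None


print(normalize_node_set_old([]))    # [] -> kept by the old check
print(normalize_node_set_new([]))    # None -> dropped by the new check
print(normalize_node_set_new(["a"]))  # ['a']
```

The practical change is that an explicitly empty list no longer reaches the pipeline as `[]`; it is normalized to `None` like the default value.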
119  cognee/api/v1/cognify/code_graph_pipeline.py  Normal file
@@ -0,0 +1,119 @@
import os
import pathlib
import asyncio
from typing import Optional
from cognee.shared.logging_utils import get_logger, setup_logging
from cognee.modules.observability.get_observe import get_observe

from cognee.api.v1.search import SearchType, search
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.methods import get_default_user
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.data.methods import create_dataset
from cognee.tasks.documents import classify_documents, extract_chunks_from_documents
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.ingestion import ingest_data
from cognee.tasks.repo_processor import get_non_py_files, get_repo_file_dependencies

from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_text
from cognee.infrastructure.llm import get_max_chunk_tokens
from cognee.infrastructure.databases.relational import get_relational_engine

observe = get_observe()

logger = get_logger("code_graph_pipeline")


@observe
async def run_code_graph_pipeline(
    repo_path,
    include_docs=False,
    excluded_paths: Optional[list[str]] = None,
    supported_languages: Optional[list[str]] = None,
):
    import cognee
    from cognee.low_level import setup

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    cognee_config = get_cognify_config()
    user = await get_default_user()
    detailed_extraction = True

    tasks = [
        Task(
            get_repo_file_dependencies,
            detailed_extraction=detailed_extraction,
            supported_languages=supported_languages,
            excluded_paths=excluded_paths,
        ),
        # Task(summarize_code, task_config={"batch_size": 500}),  # This task takes a long time to complete
        Task(add_data_points, task_config={"batch_size": 30}),
    ]

    if include_docs:
        # These tasks take a long time to complete
        non_code_tasks = [
            Task(get_non_py_files, task_config={"batch_size": 50}),
            Task(ingest_data, dataset_name="repo_docs", user=user),
            Task(classify_documents),
            Task(extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()),
            Task(
                extract_graph_from_data,
                graph_model=KnowledgeGraph,
                task_config={"batch_size": 50},
            ),
            Task(
                summarize_text,
                summarization_model=cognee_config.summarization_model,
                task_config={"batch_size": 50},
            ),
        ]

    dataset_name = "codebase"

    # Save dataset to database
    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        dataset = await create_dataset(dataset_name, user, session)

    if include_docs:
        non_code_pipeline_run = run_tasks(
            non_code_tasks, dataset.id, repo_path, user, "cognify_pipeline"
        )
        async for run_status in non_code_pipeline_run:
            yield run_status

    async for run_status in run_tasks(
        tasks, dataset.id, repo_path, user, "cognify_code_pipeline", incremental_loading=False
    ):
        yield run_status


if __name__ == "__main__":

    async def main():
        async for run_status in run_code_graph_pipeline("REPO_PATH"):
            print(f"{run_status.pipeline_run_id}: {run_status.status}")

        file_path = os.path.join(
            pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html"
        )
        await visualize_graph(file_path)

        search_results = await search(
            query_type=SearchType.CODE,
            query_text="How is Relationship weight calculated?",
        )

        for file in search_results:
            print(file["name"])

    logger = setup_logging(name="code_graph_pipeline")
    asyncio.run(main())
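The new `run_code_graph_pipeline` above is an async generator that yields run statuses as stages complete. A hedged, self-contained sketch of that consumption pattern (a toy stand-in pipeline, no cognee imports; names are illustrative only):

```python
import asyncio
from dataclasses import dataclass


@dataclass
class RunStatus:
    # Minimal stand-in for cognee's pipeline run status objects.
    pipeline_run_id: str
    status: str


async def toy_pipeline(repo_path):
    # Stand-in for run_code_graph_pipeline: yield one status per stage.
    for stage in ("started", "processing", "completed"):
        yield RunStatus(pipeline_run_id=repo_path, status=stage)


async def main():
    statuses = []
    async for run_status in toy_pipeline("REPO_PATH"):
        statuses.append(run_status.status)
    return statuses


statuses = asyncio.run(main())
print(statuses)  # ['started', 'processing', 'completed']
```

Because statuses stream out incrementally, callers (like the `__main__` block above) can log or display progress before the pipeline finishes.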
@@ -3,7 +3,6 @@ from pydantic import BaseModel
from typing import Union, Optional
from uuid import UUID

from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
from cognee.shared.logging_utils import get_logger
from cognee.shared.data_models import KnowledgeGraph

@@ -20,6 +19,7 @@ from cognee.modules.ontology.get_default_ontology_resolver import (
from cognee.modules.users.models import User

from cognee.tasks.documents import (
    check_permissions_on_dataset,
    classify_documents,
    extract_chunks_from_documents,
)

@@ -53,7 +53,6 @@ async def cognify(
    custom_prompt: Optional[str] = None,
    temporal_cognify: bool = False,
    data_per_batch: int = 20,
    **kwargs,
):
    """
    Transform ingested data into a structured knowledge graph.

@@ -79,11 +78,12 @@ async def cognify(

    Processing Pipeline:
    1. **Document Classification**: Identifies document types and structures
    2. **Text Chunking**: Breaks content into semantically meaningful segments
    3. **Entity Extraction**: Identifies key concepts, people, places, organizations
    4. **Relationship Detection**: Discovers connections between entities
    5. **Graph Construction**: Builds semantic knowledge graph with embeddings
    6. **Content Summarization**: Creates hierarchical summaries for navigation
    2. **Permission Validation**: Ensures user has processing rights
    3. **Text Chunking**: Breaks content into semantically meaningful segments
    4. **Entity Extraction**: Identifies key concepts, people, places, organizations
    5. **Relationship Detection**: Discovers connections between entities
    6. **Graph Construction**: Builds semantic knowledge graph with embeddings
    7. **Content Summarization**: Creates hierarchical summaries for navigation

    Graph Model Customization:
    The `graph_model` parameter allows custom knowledge structures:

@@ -224,7 +224,6 @@ async def cognify(
        config=config,
        custom_prompt=custom_prompt,
        chunks_per_batch=chunks_per_batch,
        **kwargs,
    )

    # By calling the get pipeline executor we get a function that runs run_pipeline in the background, or a function that we will need to await

@@ -239,7 +238,6 @@ async def cognify(
        vector_db_config=vector_db_config,
        graph_db_config=graph_db_config,
        incremental_loading=incremental_loading,
        use_pipeline_cache=True,
        pipeline_name="cognify_pipeline",
        data_per_batch=data_per_batch,
    )

@@ -253,7 +251,6 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
    config: Config = None,
    custom_prompt: Optional[str] = None,
    chunks_per_batch: int = 100,
    **kwargs,
) -> list[Task]:
    if config is None:
        ontology_config = get_ontology_env_config()

@@ -275,11 +272,9 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
    if chunks_per_batch is None:
        chunks_per_batch = 100

    cognify_config = get_cognify_config()
    embed_triplets = cognify_config.triplet_embedding

    default_tasks = [
        Task(classify_documents),
        Task(check_permissions_on_dataset, user=user, permissions=["write"]),
        Task(
            extract_chunks_from_documents,
            max_chunk_size=chunk_size or get_max_chunk_tokens(),

@@ -291,17 +286,12 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
            config=config,
            custom_prompt=custom_prompt,
            task_config={"batch_size": chunks_per_batch},
            **kwargs,
        ),  # Generate knowledge graphs from the document chunks.
        Task(
            summarize_text,
            task_config={"batch_size": chunks_per_batch},
        ),
        Task(
            add_data_points,
            embed_triplets=embed_triplets,
            task_config={"batch_size": chunks_per_batch},
        ),
        Task(add_data_points, task_config={"batch_size": chunks_per_batch}),
    ]

    return default_tasks
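`get_default_tasks` above composes the pipeline as a list of `Task` objects, each carrying a `task_config={"batch_size": ...}`. A toy sketch of that composition pattern (this `Task` class is a simplified stand-in, not cognee's implementation):

```python
class Task:
    # Toy stand-in: wraps a callable plus per-task batching config.
    def __init__(self, executable, task_config=None, **defaults):
        self.executable = executable
        self.task_config = task_config or {"batch_size": 1}
        self.defaults = defaults

    def run(self, items):
        size = self.task_config["batch_size"]
        out = []
        # Feed the callable one batch at a time, as the pipeline runner would.
        for i in range(0, len(items), size):
            out.extend(self.executable(items[i : i + size], **self.defaults))
        return out


classify = Task(lambda batch: [f"doc:{x}" for x in batch], task_config={"batch_size": 2})
chunks = classify.run(["a", "b", "c"])
print(chunks)  # ['doc:a', 'doc:b', 'doc:c']
```

Keeping batch size in per-task config (rather than hard-coding it in each task body) is what lets `chunks_per_batch` tune the whole pipeline from one parameter.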
@@ -315,13 +305,14 @@ async def get_temporal_tasks(

    The pipeline includes:
    1. Document classification.
    2. Document chunking with a specified or default chunk size.
    3. Event and timestamp extraction from chunks.
    4. Knowledge graph extraction from events.
    5. Batched insertion of data points.
    2. Dataset permission checks (requires "write" access).
    3. Document chunking with a specified or default chunk size.
    4. Event and timestamp extraction from chunks.
    5. Knowledge graph extraction from events.
    6. Batched insertion of data points.

    Args:
        user (User, optional): The user requesting task execution.
        user (User, optional): The user requesting task execution, used for permission checks.
        chunker (Callable, optional): A text chunking function/class to split documents. Defaults to TextChunker.
        chunk_size (int, optional): Maximum token size per chunk. If not provided, uses system default.
        chunks_per_batch (int, optional): Number of chunks to process in a single batch in Cognify

@@ -334,6 +325,7 @@ async def get_temporal_tasks(

    temporal_tasks = [
        Task(classify_documents),
        Task(check_permissions_on_dataset, user=user, permissions=["write"]),
        Task(
            extract_chunks_from_documents,
            max_chunk_size=chunk_size or get_max_chunk_tokens(),

@@ -1 +1,2 @@
from .get_cognify_router import get_cognify_router
from .get_code_pipeline_router import get_code_pipeline_router
90  cognee/api/v1/cognify/routers/get_code_pipeline_router.py  Normal file
@@ -0,0 +1,90 @@
import json
from cognee.shared.logging_utils import get_logger
from fastapi import APIRouter
from fastapi.responses import JSONResponse
from cognee.api.DTO import InDTO
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.storage.utils import JSONEncoder


logger = get_logger()


class CodePipelineIndexPayloadDTO(InDTO):
    repo_path: str
    include_docs: bool = False


class CodePipelineRetrievePayloadDTO(InDTO):
    query: str
    full_input: str


def get_code_pipeline_router() -> APIRouter:
    try:
        import cognee.api.v1.cognify.code_graph_pipeline
    except ModuleNotFoundError:
        logger.error("codegraph dependencies not found. Skipping codegraph API routes.")
        return None

    router = APIRouter()

    @router.post("/index", response_model=None)
    async def code_pipeline_index(payload: CodePipelineIndexPayloadDTO):
        """
        Run indexation on a code repository.

        This endpoint processes a code repository to create a knowledge graph
        of the codebase structure, dependencies, and relationships.

        ## Request Parameters
        - **repo_path** (str): Path to the code repository
        - **include_docs** (bool): Whether to include documentation files (default: false)

        ## Response
        No content returned. Processing results are logged.

        ## Error Codes
        - **409 Conflict**: Error during indexation process
        """
        from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline

        try:
            async for result in run_code_graph_pipeline(payload.repo_path, payload.include_docs):
                logger.info(result)
        except Exception as error:
            return JSONResponse(status_code=409, content={"error": str(error)})

    @router.post("/retrieve", response_model=list[dict])
    async def code_pipeline_retrieve(payload: CodePipelineRetrievePayloadDTO):
        """
        Retrieve context from the code knowledge graph.

        This endpoint searches the indexed code repository to find relevant
        context based on the provided query.

        ## Request Parameters
        - **query** (str): Search query for code context
        - **full_input** (str): Full input text for processing

        ## Response
        Returns a list of relevant code files and context as JSON.

        ## Error Codes
        - **409 Conflict**: Error during retrieval process
        """
        try:
            query = (
                payload.full_input.replace("cognee ", "")
                if payload.full_input.startswith("cognee ")
                else payload.full_input
            )

            retriever = CodeRetriever()
            retrieved_files = await retriever.get_context(query)

            return json.dumps(retrieved_files, cls=JSONEncoder)
        except Exception as error:
            return JSONResponse(status_code=409, content={"error": str(error)})

    return router
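`get_code_pipeline_router` returns `None` when its optional dependencies are missing, and main.py registers the router only when the return value is truthy. A self-contained sketch of that optional-router guard (the module name probed below is deliberately nonexistent, and the dict is a stand-in for `APIRouter()` so the example carries no FastAPI dependency):

```python
import importlib


def get_optional_router(dependency: str):
    # Mirror the pattern above: probe the optional dependency,
    # and skip the routes entirely if it is absent.
    try:
        importlib.import_module(dependency)
    except ModuleNotFoundError:
        return None
    return {"routes": []}  # stand-in for APIRouter()


router = get_optional_router("totally_missing_module_xyz")
if router:
    print("registering routes")
else:
    print("dependency missing, skipping routes")
```

This keeps the heavy codegraph dependencies optional: installations without them still boot, they just expose fewer endpoints.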
@@ -41,11 +41,6 @@ class CognifyPayloadDTO(InDTO):
    custom_prompt: Optional[str] = Field(
        default="", description="Custom prompt for entity extraction and graph generation"
    )
    ontology_key: Optional[List[str]] = Field(
        default=None,
        examples=[[]],
        description="Reference to one or more previously uploaded ontologies",
    )


def get_cognify_router() -> APIRouter:

@@ -73,7 +68,6 @@ def get_cognify_router() -> APIRouter:
    - **dataset_ids** (Optional[List[UUID]]): List of existing dataset UUIDs to process. UUIDs allow processing of datasets not owned by the user (if permitted).
    - **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking).
    - **custom_prompt** (Optional[str]): Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts for knowledge graph extraction.
    - **ontology_key** (Optional[List[str]]): Reference to one or more previously uploaded ontology files to use for knowledge graph construction.

    ## Response
    - **Blocking execution**: Complete pipeline run information with entity counts, processing duration, and success/failure status

@@ -88,8 +82,7 @@ def get_cognify_router() -> APIRouter:
    {
        "datasets": ["research_papers", "documentation"],
        "run_in_background": false,
        "custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections.",
        "ontology_key": ["medical_ontology_v1"]
        "custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections."
    }
    ```

@@ -115,35 +108,13 @@ def get_cognify_router() -> APIRouter:
        )

        from cognee.api.v1.cognify import cognify as cognee_cognify
        from cognee.api.v1.ontologies.ontologies import OntologyService

        try:
            datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets
            config_to_use = None

            if payload.ontology_key:
                ontology_service = OntologyService()
                ontology_contents = ontology_service.get_ontology_contents(
                    payload.ontology_key, user
                )

                from cognee.modules.ontology.ontology_config import Config
                from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import (
                    RDFLibOntologyResolver,
                )
                from io import StringIO

                ontology_streams = [StringIO(content) for content in ontology_contents]
                config_to_use: Config = {
                    "ontology_config": {
                        "ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_streams)
                    }
                }

            cognify_run = await cognee_cognify(
                datasets,
                user,
                config=config_to_use,
                run_in_background=payload.run_in_background,
                custom_prompt=payload.custom_prompt,
            )
@@ -208,14 +208,14 @@ def get_datasets_router() -> APIRouter:
            },
        )

        from cognee.modules.data.methods import delete_dataset
        from cognee.modules.data.methods import get_dataset, delete_dataset

        dataset = await get_authorized_existing_datasets([dataset_id], "delete", user)
        dataset = await get_dataset(user.id, dataset_id)

        if dataset is None:
            raise DatasetNotFoundError(message=f"Dataset ({str(dataset_id)}) not found.")

        await delete_dataset(dataset[0])
        await delete_dataset(dataset)

    @router.delete(
        "/{dataset_id}/data/{data_id}",
@@ -1,4 +0,0 @@
from .ontologies import OntologyService
from .routers.get_ontology_router import get_ontology_router

__all__ = ["OntologyService", "get_ontology_router"]
@@ -1,158 +0,0 @@
import os
import json
import tempfile
from pathlib import Path
from datetime import datetime, timezone
from typing import Optional, List
from dataclasses import dataclass
from fastapi import UploadFile


@dataclass
class OntologyMetadata:
    ontology_key: str
    filename: str
    size_bytes: int
    uploaded_at: str
    description: Optional[str] = None


class OntologyService:
    def __init__(self):
        pass

    @property
    def base_dir(self) -> Path:
        return Path(tempfile.gettempdir()) / "ontologies"

    def _get_user_dir(self, user_id: str) -> Path:
        user_dir = self.base_dir / str(user_id)
        user_dir.mkdir(parents=True, exist_ok=True)
        return user_dir

    def _get_metadata_path(self, user_dir: Path) -> Path:
        return user_dir / "metadata.json"

    def _load_metadata(self, user_dir: Path) -> dict:
        metadata_path = self._get_metadata_path(user_dir)
        if metadata_path.exists():
            with open(metadata_path, "r") as f:
                return json.load(f)
        return {}

    def _save_metadata(self, user_dir: Path, metadata: dict):
        metadata_path = self._get_metadata_path(user_dir)
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)

    async def upload_ontology(
        self, ontology_key: str, file: UploadFile, user, description: Optional[str] = None
    ) -> OntologyMetadata:
        if not file.filename:
            raise ValueError("File must have a filename")
        if not file.filename.lower().endswith(".owl"):
            raise ValueError("File must be in .owl format")

        user_dir = self._get_user_dir(str(user.id))
        metadata = self._load_metadata(user_dir)

        if ontology_key in metadata:
            raise ValueError(f"Ontology key '{ontology_key}' already exists")

        content = await file.read()

        file_path = user_dir / f"{ontology_key}.owl"
        with open(file_path, "wb") as f:
            f.write(content)

        ontology_metadata = {
            "filename": file.filename,
            "size_bytes": len(content),
            "uploaded_at": datetime.now(timezone.utc).isoformat(),
            "description": description,
        }
        metadata[ontology_key] = ontology_metadata
        self._save_metadata(user_dir, metadata)

        return OntologyMetadata(
            ontology_key=ontology_key,
            filename=file.filename,
            size_bytes=len(content),
            uploaded_at=ontology_metadata["uploaded_at"],
            description=description,
        )

    async def upload_ontologies(
        self,
        ontology_key: List[str],
        files: List[UploadFile],
        user,
        descriptions: Optional[List[str]] = None,
    ) -> List[OntologyMetadata]:
        """
        Upload ontology files with their respective keys.

        Args:
            ontology_key: List of unique keys for each ontology
            files: List of UploadFile objects (same length as keys)
            user: Authenticated user
            descriptions: Optional list of descriptions for each file

        Returns:
            List of OntologyMetadata objects for uploaded files

        Raises:
            ValueError: If keys duplicate, file format invalid, or array lengths don't match
        """
        if len(ontology_key) != len(files):
            raise ValueError("Number of keys must match number of files")

        if len(set(ontology_key)) != len(ontology_key):
            raise ValueError("Duplicate ontology keys not allowed")

        results = []

        for i, (key, file) in enumerate(zip(ontology_key, files)):
            results.append(
                await self.upload_ontology(
                    ontology_key=key,
                    file=file,
                    user=user,
                    description=descriptions[i] if descriptions else None,
                )
            )
        return results

    def get_ontology_contents(self, ontology_key: List[str], user) -> List[str]:
        """
        Retrieve ontology content for one or more keys.

        Args:
            ontology_key: List of ontology keys to retrieve (can contain single item)
            user: Authenticated user

        Returns:
            List of ontology content strings

        Raises:
            ValueError: If any ontology key not found
        """
        user_dir = self._get_user_dir(str(user.id))
        metadata = self._load_metadata(user_dir)

        contents = []
        for key in ontology_key:
            if key not in metadata:
                raise ValueError(f"Ontology key '{key}' not found")

            file_path = user_dir / f"{key}.owl"
            if not file_path.exists():
                raise ValueError(f"Ontology file for key '{key}' not found")

            with open(file_path, "r", encoding="utf-8") as f:
                contents.append(f.read())
        return contents

    def list_ontologies(self, user) -> dict:
        user_dir = self._get_user_dir(str(user.id))
        return self._load_metadata(user_dir)
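The removed `OntologyService` above persisted a per-user `metadata.json` beside the uploaded `.owl` files. A hedged, self-contained sketch of that load/save pattern (helper names are illustrative, and a fresh temp directory stands in for the service's base dir):

```python
import json
import tempfile
from pathlib import Path


def load_metadata(user_dir: Path) -> dict:
    # Return the stored metadata, or an empty dict on first use.
    path = user_dir / "metadata.json"
    return json.loads(path.read_text()) if path.exists() else {}


def save_metadata(user_dir: Path, metadata: dict) -> None:
    (user_dir / "metadata.json").write_text(json.dumps(metadata, indent=2))


user_dir = Path(tempfile.mkdtemp()) / "ontologies" / "user-1"
user_dir.mkdir(parents=True, exist_ok=True)

metadata = load_metadata(user_dir)  # {} on first use
metadata["medical_v1"] = {"filename": "medical.owl", "size_bytes": 123}
save_metadata(user_dir, metadata)
print(load_metadata(user_dir)["medical_v1"]["filename"])  # medical.owl
```

One trade-off of this design, visible in the removed code, is that `tempfile.gettempdir()` ties uploaded ontologies to the host's temp storage, so they do not survive a container restart.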
@@ -1,109 +0,0 @@
from fastapi import APIRouter, File, Form, UploadFile, Depends, Request
from fastapi.responses import JSONResponse
from typing import Optional, List

from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
from ..ontologies import OntologyService


def get_ontology_router() -> APIRouter:
    router = APIRouter()
    ontology_service = OntologyService()

    @router.post("", response_model=dict)
    async def upload_ontology(
        request: Request,
        ontology_key: str = Form(...),
        ontology_file: UploadFile = File(...),
        description: Optional[str] = Form(None),
        user: User = Depends(get_authenticated_user),
    ):
        """
        Upload a single ontology file for later use in cognify operations.

        ## Request Parameters
        - **ontology_key** (str): User-defined identifier for the ontology.
        - **ontology_file** (UploadFile): Single OWL format ontology file
        - **description** (Optional[str]): Optional description for the ontology.

        ## Response
        Returns metadata about the uploaded ontology including key, filename, size, and upload timestamp.

        ## Error Codes
        - **400 Bad Request**: Invalid file format, duplicate key, multiple files uploaded
        - **500 Internal Server Error**: File system or processing errors
        """
        send_telemetry(
            "Ontology Upload API Endpoint Invoked",
            user.id,
            additional_properties={
                "endpoint": "POST /api/v1/ontologies",
                "cognee_version": cognee_version,
            },
        )

        try:
            # Enforce: exactly one uploaded file for "ontology_file"
            form = await request.form()
            uploaded_files = form.getlist("ontology_file")
            if len(uploaded_files) != 1:
                raise ValueError("Only one ontology_file is allowed")

            if ontology_key.strip().startswith(("[", "{")):
                raise ValueError("ontology_key must be a string")
            if description is not None and description.strip().startswith(("[", "{")):
                raise ValueError("description must be a string")

            result = await ontology_service.upload_ontology(
                ontology_key=ontology_key,
                file=ontology_file,
                user=user,
                description=description,
            )

            return {
                "uploaded_ontologies": [
                    {
                        "ontology_key": result.ontology_key,
                        "filename": result.filename,
                        "size_bytes": result.size_bytes,
                        "uploaded_at": result.uploaded_at,
                        "description": result.description,
                    }
                ]
            }
        except ValueError as e:
            return JSONResponse(status_code=400, content={"error": str(e)})
        except Exception as e:
            return JSONResponse(status_code=500, content={"error": str(e)})

    @router.get("", response_model=dict)
    async def list_ontologies(user: User = Depends(get_authenticated_user)):
        """
        List all uploaded ontologies for the authenticated user.

        ## Response
        Returns a dictionary mapping ontology keys to their metadata including filename, size, and upload timestamp.

        ## Error Codes
        - **500 Internal Server Error**: File system or processing errors
        """
        send_telemetry(
            "Ontology List API Endpoint Invoked",
            user.id,
            additional_properties={
                "endpoint": "GET /api/v1/ontologies",
                "cognee_version": cognee_version,
            },
        )

        try:
            metadata = ontology_service.list_ontologies(user)
            return metadata
        except Exception as e:
            return JSONResponse(status_code=500, content={"error": str(e)})

    return router
@@ -1,20 +1,15 @@
from uuid import UUID
from typing import List, Union
from typing import List

from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse

from cognee.modules.users.models import User
from cognee.api.DTO import InDTO
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version


class SelectTenantDTO(InDTO):
    tenant_id: UUID | None = None


def get_permissions_router() -> APIRouter:
    permissions_router = APIRouter()

@@ -231,39 +226,4 @@ def get_permissions_router() -> APIRouter:
        status_code=200, content={"message": "Tenant created.", "tenant_id": str(tenant_id)}
    )

    @permissions_router.post("/tenants/select")
    async def select_tenant(payload: SelectTenantDTO, user: User = Depends(get_authenticated_user)):
        """
        Select current tenant.

        This endpoint selects a tenant with the specified UUID. Tenants are used
        to organize users and resources in multi-tenant environments, providing
        isolation and access control between different groups or organizations.

        Sending a null/None value as tenant_id selects the user's default single-user tenant.

        ## Request Parameters
        - **tenant_id** (Union[UUID, None]): UUID of the tenant to select. If null/None is provided, the default single-user tenant is used.

        ## Response
        Returns a success message along with the selected tenant id.
        """
        send_telemetry(
            "Permissions API Endpoint Invoked",
            user.id,
            additional_properties={
                "endpoint": f"POST /v1/permissions/tenants/{str(payload.tenant_id)}",
                "tenant_id": str(payload.tenant_id),
            },
        )

        from cognee.modules.users.tenants.methods import select_tenant as select_tenant_method

        await select_tenant_method(user_id=user.id, tenant_id=payload.tenant_id)

        return JSONResponse(
            status_code=200,
            content={"message": "Tenant selected.", "tenant_id": str(payload.tenant_id)},
        )

    return permissions_router
@@ -31,9 +31,6 @@ async def search(
    only_context: bool = False,
    use_combined_context: bool = False,
    session_id: Optional[str] = None,
    wide_search_top_k: Optional[int] = 100,
    triplet_distance_penalty: Optional[float] = 3.5,
    verbose: bool = False,
) -> Union[List[SearchResult], CombinedSearchResult]:
    """
    Search and query the knowledge graph for insights, information, and connections.

@@ -124,8 +121,6 @@ async def search(

        session_id: Optional session identifier for caching Q&A interactions. Defaults to 'default_session' if None.

        verbose: If True, returns detailed result information including graph representation (when possible).

    Returns:
        list: Search results in format determined by query_type:

@@ -205,9 +200,6 @@ async def search(
        only_context=only_context,
        use_combined_context=use_combined_context,
        session_id=session_id,
        wide_search_top_k=wide_search_top_k,
        triplet_distance_penalty=triplet_distance_penalty,
        verbose=verbose,
    )

    return filtered_search_results
@@ -1,360 +0,0 @@
import os
import platform
import subprocess
import tempfile
from pathlib import Path

import requests

from cognee.shared.logging_utils import get_logger

logger = get_logger()


def get_nvm_dir() -> Path:
    """
    Get the nvm directory path following standard nvm installation logic.
    Uses XDG_CONFIG_HOME if set, otherwise falls back to ~/.nvm.
    """
    xdg_config_home = os.environ.get("XDG_CONFIG_HOME")
    if xdg_config_home:
        return Path(xdg_config_home) / "nvm"
    return Path.home() / ".nvm"


def get_nvm_sh_path() -> Path:
    """
    Get the path to nvm.sh following standard nvm installation logic.
    """
    return get_nvm_dir() / "nvm.sh"


def check_nvm_installed() -> bool:
    """
    Check if nvm (Node Version Manager) is installed.
    """
    try:
        # Check if nvm is available in the shell.
        # nvm is typically sourced in shell config files, so we need to check via shell.
        if platform.system() == "Windows":
            # On Windows, nvm-windows uses a different approach
            result = subprocess.run(
                ["nvm", "version"],
                capture_output=True,
                text=True,
                timeout=10,
                shell=True,
            )
        else:
            # On Unix-like systems, nvm is a shell function, so we need to source it.
            # First check if nvm.sh exists.
            nvm_path = get_nvm_sh_path()
            if not nvm_path.exists():
                logger.debug(f"nvm.sh not found at {nvm_path}")
                return False

            # Try to source nvm and check version, capturing errors
            result = subprocess.run(
                ["bash", "-c", f"source {nvm_path} && nvm --version"],
                capture_output=True,
                text=True,
                timeout=10,
            )

            if result.returncode != 0:
                # Log the error to help diagnose configuration issues
                if result.stderr:
                    logger.debug(f"nvm check failed: {result.stderr.strip()}")
                return False

        return result.returncode == 0
    except Exception as e:
        logger.debug(f"Exception checking nvm: {str(e)}")
        return False


def install_nvm() -> bool:
    """
    Install nvm (Node Version Manager) on Unix-like systems.
    """
    if platform.system() == "Windows":
        logger.error("nvm installation on Windows requires nvm-windows.")
        logger.error(
            "Please install nvm-windows manually from: https://github.com/coreybutler/nvm-windows"
        )
        return False

    logger.info("Installing nvm (Node Version Manager)...")

    try:
        # Download and install nvm
        nvm_install_script = "https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.3/install.sh"
        logger.info(f"Downloading nvm installer from {nvm_install_script}...")

        response = requests.get(nvm_install_script, timeout=60)
        response.raise_for_status()

        # Create a temporary script file
        with tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False) as f:
            f.write(response.text)
            install_script_path = f.name

        try:
            # Make the script executable and run it
            os.chmod(install_script_path, 0o755)
            result = subprocess.run(
                ["bash", install_script_path],
                capture_output=True,
                text=True,
                timeout=120,
            )

            if result.returncode == 0:
                logger.info("✓ nvm installed successfully")
                # Source nvm in current shell session
                nvm_dir = get_nvm_dir()
                if nvm_dir.exists():
                    return True
                else:
                    logger.warning(
                        f"nvm installation completed but nvm directory not found at {nvm_dir}"
                    )
                    return False
            else:
                logger.error(f"nvm installation failed: {result.stderr}")
                return False
        finally:
            # Clean up temporary script
            try:
                os.unlink(install_script_path)
            except Exception:
                pass

    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to download nvm installer: {str(e)}")
        return False
    except Exception as e:
        logger.error(f"Failed to install nvm: {str(e)}")
        return False


def install_node_with_nvm() -> bool:
    """
    Install the latest Node.js version using nvm.
    Returns True if installation succeeds, False otherwise.
    """
    if platform.system() == "Windows":
        logger.error("Node.js installation via nvm on Windows requires nvm-windows.")
        logger.error("Please install Node.js manually from: https://nodejs.org/")
        return False

    logger.info("Installing latest Node.js version using nvm...")

    try:
        # Source nvm and install latest Node.js
        nvm_path = get_nvm_sh_path()
        if not nvm_path.exists():
            logger.error(f"nvm.sh not found at {nvm_path}. nvm may not be properly installed.")
            return False

        nvm_source_cmd = f"source {nvm_path}"
        install_cmd = f"{nvm_source_cmd} && nvm install node"

        result = subprocess.run(
            ["bash", "-c", install_cmd],
            capture_output=True,
            text=True,
            timeout=300,  # 5 minutes timeout for Node.js installation
        )

        if result.returncode == 0:
            logger.info("✓ Node.js installed successfully via nvm")

            # Set as default version
            use_cmd = f"{nvm_source_cmd} && nvm alias default node"
            subprocess.run(
                ["bash", "-c", use_cmd],
                capture_output=True,
                text=True,
                timeout=30,
            )

            # Add nvm to PATH for the current session.
            # This ensures node/npm are available in subsequent commands.
            nvm_dir = get_nvm_dir()
            if nvm_dir.exists():
                # Update PATH for current process
                nvm_bin = nvm_dir / "versions" / "node"
                # Find the latest installed version
                if nvm_bin.exists():
                    versions = sorted(nvm_bin.iterdir(), reverse=True)
                    if versions:
                        latest_node_bin = versions[0] / "bin"
                        if latest_node_bin.exists():
                            current_path = os.environ.get("PATH", "")
                            os.environ["PATH"] = f"{latest_node_bin}:{current_path}"

            return True
        else:
            logger.error(f"Failed to install Node.js: {result.stderr}")
            return False

    except subprocess.TimeoutExpired:
        logger.error("Timeout installing Node.js (this can take several minutes)")
        return False
    except Exception as e:
        logger.error(f"Error installing Node.js: {str(e)}")
        return False


def check_node_npm() -> tuple[bool, str]:  # (is_available, error_message)
    """
    Check if Node.js and npm are available.
    If not available, attempts to install nvm and Node.js automatically.
    """

    try:
        # Check Node.js - try direct command first, then with nvm if needed
        result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=10)
        if result.returncode != 0:
            # If the direct command fails, try with nvm sourced (in case nvm is installed but not in PATH)
            nvm_path = get_nvm_sh_path()
            if nvm_path.exists():
                result = subprocess.run(
                    ["bash", "-c", f"source {nvm_path} && node --version"],
                    capture_output=True,
                    text=True,
                    timeout=10,
                )
                if result.returncode != 0 and result.stderr:
                    logger.debug(f"Failed to source nvm or run node: {result.stderr.strip()}")
        if result.returncode != 0:
            # Node.js is not installed, try to install it
            logger.info("Node.js is not installed. Attempting to install automatically...")

            # Check if nvm is installed
            if not check_nvm_installed():
                logger.info("nvm is not installed. Installing nvm first...")
                if not install_nvm():
                    return (
                        False,
                        "Failed to install nvm. Please install Node.js manually from https://nodejs.org/",
                    )

            # Install Node.js using nvm
            if not install_node_with_nvm():
                return (
                    False,
                    "Failed to install Node.js. Please install Node.js manually from https://nodejs.org/",
                )

            # Verify installation after automatic setup.
            # Try with nvm sourced first.
            nvm_path = get_nvm_sh_path()
            if nvm_path.exists():
                result = subprocess.run(
                    ["bash", "-c", f"source {nvm_path} && node --version"],
                    capture_output=True,
                    text=True,
                    timeout=10,
                )
                if result.returncode != 0 and result.stderr:
                    logger.debug(
                        f"Failed to verify node after installation: {result.stderr.strip()}"
                    )
            else:
                result = subprocess.run(
                    ["node", "--version"], capture_output=True, text=True, timeout=10
                )
            if result.returncode != 0:
                nvm_path = get_nvm_sh_path()
                return (
                    False,
                    f"Node.js installation completed but node command is not available. Please restart your terminal or source {nvm_path}",
                )

        node_version = result.stdout.strip()
        logger.debug(f"Found Node.js version: {node_version}")

        # Check npm - handle Windows PowerShell scripts
        if platform.system() == "Windows":
            # On Windows, npm might be a PowerShell script, so we need to use shell=True
            result = subprocess.run(
                ["npm", "--version"], capture_output=True, text=True, timeout=10, shell=True
            )
        else:
            # On Unix-like systems, if we just installed via nvm, we may need to source nvm.
            # Try the direct command first.
            result = subprocess.run(
                ["npm", "--version"], capture_output=True, text=True, timeout=10
            )
            if result.returncode != 0:
                # Try with nvm sourced
                nvm_path = get_nvm_sh_path()
                if nvm_path.exists():
                    result = subprocess.run(
                        ["bash", "-c", f"source {nvm_path} && npm --version"],
                        capture_output=True,
                        text=True,
                        timeout=10,
                    )
                    if result.returncode != 0 and result.stderr:
                        logger.debug(f"Failed to source nvm or run npm: {result.stderr.strip()}")

        if result.returncode != 0:
            return False, "npm is not installed or not in PATH"

        npm_version = result.stdout.strip()
        logger.debug(f"Found npm version: {npm_version}")

        return True, f"Node.js {node_version}, npm {npm_version}"

    except subprocess.TimeoutExpired:
        return False, "Timeout checking Node.js/npm installation"
    except FileNotFoundError:
        # Node.js is not installed, try to install it
        logger.info("Node.js is not found. Attempting to install automatically...")

        # Check if nvm is installed
        if not check_nvm_installed():
            logger.info("nvm is not installed. Installing nvm first...")
            if not install_nvm():
                return (
                    False,
                    "Failed to install nvm. Please install Node.js manually from https://nodejs.org/",
                )

        # Install Node.js using nvm
        if not install_node_with_nvm():
            return (
                False,
                "Failed to install Node.js. Please install Node.js manually from https://nodejs.org/",
            )

        # Retry checking Node.js after installation
        try:
            result = subprocess.run(
                ["node", "--version"], capture_output=True, text=True, timeout=10
            )
            if result.returncode == 0:
                node_version = result.stdout.strip()
                # Check npm
                nvm_path = get_nvm_sh_path()
                if nvm_path.exists():
                    result = subprocess.run(
                        ["bash", "-c", f"source {nvm_path} && npm --version"],
                        capture_output=True,
                        text=True,
                        timeout=10,
                    )
                    if result.returncode == 0:
                        npm_version = result.stdout.strip()
                        return True, f"Node.js {node_version}, npm {npm_version}"
                    elif result.stderr:
                        logger.debug(f"Failed to source nvm or run npm: {result.stderr.strip()}")
        except Exception as e:
            logger.debug(f"Exception retrying node/npm check: {str(e)}")

        return False, "Node.js/npm not found. Please install Node.js from https://nodejs.org/"
    except Exception as e:
        return False, f"Error checking Node.js/npm: {str(e)}"
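The `get_nvm_dir` logic in the deleted module above (prefer `$XDG_CONFIG_HOME/nvm`, else `~/.nvm`) can be sketched as a pure function; the injected `env` dict is an assumption added here for testability, since the original reads `os.environ` directly:

```python
from pathlib import Path


def resolve_nvm_dir(env: dict) -> Path:
    # Prefer $XDG_CONFIG_HOME/nvm when set, matching standard nvm installs;
    # otherwise fall back to ~/.nvm in the user's home directory.
    xdg_config_home = env.get("XDG_CONFIG_HOME")
    if xdg_config_home:
        return Path(xdg_config_home) / "nvm"
    return Path.home() / ".nvm"
```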
@@ -1,50 +0,0 @@
import platform
import subprocess
from pathlib import Path
from typing import List

from cognee.shared.logging_utils import get_logger
from .node_setup import get_nvm_sh_path

logger = get_logger()


def run_npm_command(cmd: List[str], cwd: Path, timeout: int = 300) -> subprocess.CompletedProcess:
    """
    Run an npm command, ensuring nvm is sourced if needed (Unix-like systems only).
    Returns the CompletedProcess result.
    """
    if platform.system() == "Windows":
        # On Windows, use shell=True for npm commands
        return subprocess.run(
            cmd,
            cwd=cwd,
            capture_output=True,
            text=True,
            timeout=timeout,
            shell=True,
        )
    else:
        # On Unix-like systems, try the direct command first
        result = subprocess.run(
            cmd,
            cwd=cwd,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        # If it fails and nvm might be installed, try with nvm sourced
        if result.returncode != 0:
            nvm_path = get_nvm_sh_path()
            if nvm_path.exists():
                nvm_cmd = f"source {nvm_path} && {' '.join(cmd)}"
                result = subprocess.run(
                    ["bash", "-c", nvm_cmd],
                    cwd=cwd,
                    capture_output=True,
                    text=True,
                    timeout=timeout,
                )
                if result.returncode != 0 and result.stderr:
                    logger.debug(f"npm command failed with nvm: {result.stderr.strip()}")
        return result
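The retry pattern in `run_npm_command` above — run the command directly, then once more through `bash -c` with nvm sourced — can be sketched with an injected runner; the `run` and `nvm_sh_exists` parameters are assumptions added here for testability and are not part of the original module:

```python
def run_with_nvm_fallback(cmd, run, nvm_sh="~/.nvm/nvm.sh", nvm_sh_exists=False):
    # First attempt: invoke the command directly (works when node/npm
    # are already on PATH).
    result = run(cmd)
    # Second attempt: source nvm.sh so shell-managed node/npm become visible.
    if result.returncode != 0 and nvm_sh_exists:
        result = run(["bash", "-c", f"source {nvm_sh} && {' '.join(cmd)}"])
    return result
```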
@@ -15,8 +15,6 @@ import shutil

from cognee.shared.logging_utils import get_logger
from cognee.version import get_cognee_version
from .node_setup import check_node_npm, get_nvm_dir, get_nvm_sh_path
from .npm_utils import run_npm_command

logger = get_logger()

@@ -287,6 +285,48 @@ def find_frontend_path() -> Optional[Path]:
    return None


def check_node_npm() -> tuple[bool, str]:
    """
    Check if Node.js and npm are available.
    Returns (is_available, error_message)
    """

    try:
        # Check Node.js
        result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=10)
        if result.returncode != 0:
            return False, "Node.js is not installed or not in PATH"

        node_version = result.stdout.strip()
        logger.debug(f"Found Node.js version: {node_version}")

        # Check npm - handle Windows PowerShell scripts
        if platform.system() == "Windows":
            # On Windows, npm might be a PowerShell script, so we need to use shell=True
            result = subprocess.run(
                ["npm", "--version"], capture_output=True, text=True, timeout=10, shell=True
            )
        else:
            result = subprocess.run(
                ["npm", "--version"], capture_output=True, text=True, timeout=10
            )

        if result.returncode != 0:
            return False, "npm is not installed or not in PATH"

        npm_version = result.stdout.strip()
        logger.debug(f"Found npm version: {npm_version}")

        return True, f"Node.js {node_version}, npm {npm_version}"

    except subprocess.TimeoutExpired:
        return False, "Timeout checking Node.js/npm installation"
    except FileNotFoundError:
        return False, "Node.js/npm not found. Please install Node.js from https://nodejs.org/"
    except Exception as e:
        return False, f"Error checking Node.js/npm: {str(e)}"


def install_frontend_dependencies(frontend_path: Path) -> bool:
    """
    Install frontend dependencies if node_modules doesn't exist.

@@ -301,7 +341,24 @@ def install_frontend_dependencies(frontend_path: Path) -> bool:
    logger.info("Installing frontend dependencies (this may take a few minutes)...")

    try:
        result = run_npm_command(["npm", "install"], frontend_path, timeout=300)
        # Use shell=True on Windows for npm commands
        if platform.system() == "Windows":
            result = subprocess.run(
                ["npm", "install"],
                cwd=frontend_path,
                capture_output=True,
                text=True,
                timeout=300,  # 5 minutes timeout
                shell=True,
            )
        else:
            result = subprocess.run(
                ["npm", "install"],
                cwd=frontend_path,
                capture_output=True,
                text=True,
                timeout=300,  # 5 minutes timeout
            )

        if result.returncode == 0:
            logger.info("Frontend dependencies installed successfully")
@@ -585,21 +642,6 @@ def start_ui(
    env["HOST"] = "localhost"
    env["PORT"] = str(port)

    # If nvm is installed, ensure it's available in the environment
    nvm_path = get_nvm_sh_path()
    if platform.system() != "Windows" and nvm_path.exists():
        # Add nvm to PATH for the subprocess
        nvm_dir = get_nvm_dir()
        # Find the latest Node.js version installed via nvm
        nvm_versions = nvm_dir / "versions" / "node"
        if nvm_versions.exists():
            versions = sorted(nvm_versions.iterdir(), reverse=True)
            if versions:
                latest_node_bin = versions[0] / "bin"
                if latest_node_bin.exists():
                    current_path = env.get("PATH", "")
                    env["PATH"] = f"{latest_node_bin}:{current_path}"

    # Start the development server
    logger.info(f"Starting frontend server at http://localhost:{port}")
    logger.info("This may take a moment to compile and start...")
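A caveat worth noting about the PATH setup removed above: `sorted(nvm_versions.iterdir(), reverse=True)` compares version directory names lexicographically, so `v9.x` sorts after `v10.x`. A hedged sketch of a numeric comparison that avoids this pitfall (the helper name is illustrative, not part of the codebase):

```python
def newest_node_version(dir_names):
    # Parse "vMAJOR.MINOR.PATCH" into an int tuple so 10 > 9,
    # which plain string comparison gets wrong ("v9..." > "v10...").
    def key(name):
        return tuple(int(part) for part in name.lstrip("v").split("."))

    return max(dir_names, key=key)
```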
@@ -617,26 +659,14 @@ def start_ui(
            shell=True,
        )
    else:
        # On Unix-like systems, use bash with nvm sourced if available
        if nvm_path.exists():
            # Use bash to source nvm and run npm
            process = subprocess.Popen(
                ["bash", "-c", f"source {nvm_path} && npm run dev"],
                cwd=frontend_path,
                env=env,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                preexec_fn=os.setsid if hasattr(os, "setsid") else None,
            )
        else:
            process = subprocess.Popen(
                ["npm", "run", "dev"],
                cwd=frontend_path,
                env=env,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                preexec_fn=os.setsid if hasattr(os, "setsid") else None,
            )
        process = subprocess.Popen(
            ["npm", "run", "dev"],
            cwd=frontend_path,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            preexec_fn=os.setsid if hasattr(os, "setsid") else None,
        )

    # Start threads to stream frontend output with prefix
    _stream_process_output(process, "stdout", "[FRONTEND]", "\033[33m")  # Yellow
@@ -22,7 +22,7 @@ relationships, and creates semantic connections for enhanced search and reasonin

Processing Pipeline:
1. **Document Classification**: Identifies document types and structures
2. **Permission Validation**: Ensures user has processing rights
3. **Text Chunking**: Breaks content into semantically meaningful segments
4. **Entity Extraction**: Identifies key concepts, people, places, organizations
5. **Relationship Detection**: Discovers connections between entities
@@ -97,13 +97,6 @@ After successful cognify processing, use `cognee search` to query the knowledge
            chunker_class = LangchainChunker
        except ImportError:
            fmt.warning("LangchainChunker not available, using TextChunker")
    elif args.chunker == "CsvChunker":
        try:
            from cognee.modules.chunking.CsvChunker import CsvChunker

            chunker_class = CsvChunker
        except ImportError:
            fmt.warning("CsvChunker not available, using TextChunker")

    result = await cognee.cognify(
        datasets=datasets,
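The chunker selection in the hunk above follows a try-import-with-fallback pattern. A generic sketch of that pattern (the `load_optional` helper is hypothetical, not part of the CLI):

```python
import importlib


def load_optional(module_name, attr, fallback, warn=print):
    # Import an optional dependency; if it is missing, warn and return
    # the fallback instead of failing (mirrors the CLI falling back to
    # TextChunker on ImportError).
    try:
        return getattr(importlib.import_module(module_name), attr)
    except (ImportError, AttributeError):
        warn(f"{attr} not available, using fallback")
        return fallback
```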
@@ -26,7 +26,7 @@ SEARCH_TYPE_CHOICES = [
]

# Chunker choices
CHUNKER_CHOICES = ["TextChunker", "LangchainChunker", "CsvChunker"]
CHUNKER_CHOICES = ["TextChunker", "LangchainChunker"]

# Output format choices
OUTPUT_FORMAT_CHOICES = ["json", "pretty", "simple"]
@@ -4,10 +4,7 @@ from typing import Union
from uuid import UUID

from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector.config import get_vectordb_config
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
from cognee.infrastructure.databases.utils import resolve_dataset_database_connection_info
from cognee.infrastructure.files.storage.config import file_storage_config
from cognee.modules.users.methods import get_user
@@ -22,67 +19,6 @@ async def set_session_user_context_variable(user):
    session_user.set(user)


def multi_user_support_possible():
    graph_db_config = get_graph_config()
    vector_db_config = get_vectordb_config()

    graph_handler = graph_db_config.graph_dataset_database_handler
    vector_handler = vector_db_config.vector_dataset_database_handler
    from cognee.infrastructure.databases.dataset_database_handler import (
        supported_dataset_database_handlers,
    )

    if graph_handler not in supported_dataset_database_handlers:
        raise EnvironmentError(
            "Unsupported graph dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variable ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
            f"Selected graph dataset to database handler: {graph_handler}\n"
            f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
        )

    if vector_handler not in supported_dataset_database_handlers:
        raise EnvironmentError(
            "Unsupported vector dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variable ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
            f"Selected vector dataset to database handler: {vector_handler}\n"
            f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
        )

    if (
        supported_dataset_database_handlers[graph_handler]["handler_provider"]
        != graph_db_config.graph_database_provider
    ):
        raise EnvironmentError(
            "The selected graph dataset to database handler does not work with the configured graph database provider. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variable ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
            f"Selected graph database provider: {graph_db_config.graph_database_provider}\n"
            f"Selected graph dataset to database handler: {graph_handler}\n"
            f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
        )

    if (
        supported_dataset_database_handlers[vector_handler]["handler_provider"]
        != vector_db_config.vector_db_provider
    ):
        raise EnvironmentError(
            "The selected vector dataset to database handler does not work with the configured vector database provider. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variable ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
            f"Selected vector database provider: {vector_db_config.vector_db_provider}\n"
            f"Selected vector dataset to database handler: {vector_handler}\n"
            f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
        )

    return True


def backend_access_control_enabled():
    backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
    if backend_access_control is None:
        # If backend access control is not defined in environment variables,
        # enable it by default if graph and vector DBs can support it, otherwise disable it.
        return multi_user_support_possible()
    elif backend_access_control.lower() == "true":
        # If enabled, ensure that the current graph and vector DBs can support it
        return multi_user_support_possible()
    return False


async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
    """
    If backend access control is enabled this function will ensure all datasets have their own databases,
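`backend_access_control_enabled` above implements a tri-state flag: an unset variable means auto-detect, `"true"` means enable (still verifying database support), and anything else disables. A minimal sketch with the environment value and capability check injected (both parameters are assumptions added here for testability):

```python
def access_control_enabled(env_value, multi_user_possible):
    # None -> auto-detect: enable only when the configured databases support it.
    # "true" (case-insensitive) -> enable, still verifying database support.
    # Any other value -> disabled.
    if env_value is None or env_value.lower() == "true":
        return multi_user_possible()
    return False
```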
@@ -102,17 +38,16 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_

    """

    if not backend_access_control_enabled():
    base_config = get_base_config()

    if not os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
        return

    user = await get_user(user_id)

    # To ensure permissions are enforced properly all datasets will have their own databases
    dataset_database = await get_or_create_dataset_database(dataset, user)
    # Ensure that all connection info is resolved properly
    dataset_database = await resolve_dataset_database_connection_info(dataset_database)

    base_config = get_base_config()
    data_root_directory = os.path.join(
        base_config.data_root_directory, str(user.tenant_id or user.id)
    )
@@ -121,31 +56,19 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
    )

    # Set vector and graph database configuration based on dataset database information.
    # TODO: Add better handling of vector and graph config across Cognee.
    # LRU_CACHE takes into account the order of inputs; if the order of inputs is changed it will be registered as a new DB adapter.
    vector_config = {
        "vector_db_provider": dataset_database.vector_database_provider,
        "vector_db_url": dataset_database.vector_database_url,
        "vector_db_key": dataset_database.vector_database_key,
        "vector_db_name": dataset_database.vector_database_name,
        "vector_db_url": os.path.join(
            databases_directory_path, dataset_database.vector_database_name
        ),
        "vector_db_key": "",
        "vector_db_provider": "lancedb",
    }

    graph_config = {
        "graph_database_provider": dataset_database.graph_database_provider,
        "graph_database_url": dataset_database.graph_database_url,
        "graph_database_name": dataset_database.graph_database_name,
        "graph_database_key": dataset_database.graph_database_key,
        "graph_database_provider": "kuzu",
        "graph_file_path": os.path.join(
            databases_directory_path, dataset_database.graph_database_name
        ),
        "graph_database_username": dataset_database.graph_database_connection_info.get(
            "graph_database_username", ""
        ),
        "graph_database_password": dataset_database.graph_database_connection_info.get(
            "graph_database_password", ""
        ),
        "graph_dataset_database_handler": "",
        "graph_database_port": "",
    }

    storage_config = {
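The TODO above about LRU_CACHE and input order refers to how `functools.lru_cache` keys keyword arguments: the same configuration supplied in a different kwarg order produces a distinct cache key, and hence a second adapter instance. A small demonstration (the `make_adapter` function is illustrative, not the real adapter factory):

```python
from functools import lru_cache

instantiations = []


@lru_cache(maxsize=None)
def make_adapter(**config):
    # Each cache miss represents a freshly constructed DB adapter.
    # lru_cache hashes kwargs in the order they were passed, so
    # reordering the same config triggers a second miss.
    instantiations.append(dict(config))
    return tuple(sorted(config.items()))
```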
@@ -1,29 +0,0 @@
FROM python:3.11-slim

# Set environment variables
ENV PIP_NO_CACHE_DIR=true
ENV PATH="${PATH}:/root/.poetry/bin"
ENV PYTHONPATH=/app
ENV SKIP_MIGRATIONS=true

# System dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    libpq-dev \
    git \
    curl \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

COPY pyproject.toml poetry.lock README.md /app/

RUN pip install poetry

RUN poetry config virtualenvs.create false

RUN poetry install --extras distributed --extras evals --extras deepeval --no-root

COPY cognee/ /app/cognee
COPY distributed/ /app/distributed
@@ -35,16 +35,6 @@ class AnswerGeneratorExecutor:
        retrieval_context = await retriever.get_context(query_text)
        search_results = await retriever.get_completion(query_text, retrieval_context)

        ############
        # TODO: Quick fix until retriever results are structured properly; let's not leave it like this.
        # It is needed now because of the changed combined retriever structure.
        if isinstance(retrieval_context, list):
            retrieval_context = await retriever.convert_retrieved_objects_to_context(
                triplets=retrieval_context
            )

        if isinstance(search_results, str):
            search_results = [search_results]
        #############
        answer = {
            "question": query_text,
            "answer": search_results[0],
@@ -35,7 +35,7 @@ async def create_and_insert_answers_table(questions_payload):


async def run_question_answering(
    params: dict, system_prompt="answer_simple_question_benchmark.txt", top_k: Optional[int] = None
    params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None
) -> List[dict]:
    if params.get("answering_questions"):
        logger.info("Question answering started...")
@@ -8,6 +8,7 @@ from cognee.modules.users.models import User
from cognee.shared.data_models import KnowledgeGraph
from cognee.shared.utils import send_telemetry
from cognee.tasks.documents import (
    check_permissions_on_dataset,
    classify_documents,
    extract_chunks_from_documents,
)

@@ -30,6 +31,7 @@ async def get_cascade_graph_tasks(
    cognee_config = get_cognify_config()
    default_tasks = [
        Task(classify_documents),
        Task(check_permissions_on_dataset, user=user, permissions=["write"]),
        Task(
            extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()
        ),  # Extract text chunks based on the document type.
@@ -30,8 +30,8 @@ async def get_no_summary_tasks(
    ontology_file_path=None,
) -> List[Task]:
    """Returns default tasks without summarization tasks."""
    # Get base tasks (0=classify, 1=extract_chunks)
    base_tasks = await get_default_tasks_by_indices([0, 1], chunk_size, chunker)
    # Get base tasks (0=classify, 1=check_permissions, 2=extract_chunks)
    base_tasks = await get_default_tasks_by_indices([0, 1, 2], chunk_size, chunker)

    ontology_adapter = RDFLibOntologyResolver(ontology_file=ontology_file_path)
@@ -51,8 +51,8 @@ async def get_just_chunks_tasks(
    chunk_size: int = None, chunker=TextChunker, user=None
) -> List[Task]:
    """Returns default tasks with only chunk extraction and data points addition."""
    # Get base tasks (0=classify, 1=extract_chunks)
    base_tasks = await get_default_tasks_by_indices([0, 1], chunk_size, chunker)
    # Get base tasks (0=classify, 1=check_permissions, 2=extract_chunks)
    base_tasks = await get_default_tasks_by_indices([0, 1, 2], chunk_size, chunker)

    add_data_points_task = Task(add_data_points, task_config={"batch_size": 10})
@@ -14,7 +14,7 @@ class EvalConfig(BaseSettings):
 
     # Question answering params
     answering_questions: bool = True
-    qa_engine: str = "cognee_graph_completion"  # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
+    qa_engine: str = "cognee_completion"  # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
 
     # Evaluation params
     evaluating_answers: bool = True

@@ -25,7 +25,7 @@ class EvalConfig(BaseSettings):
         "EM",
         "f1",
     ]  # Use only 'correctness' for DirectLLM
-    deepeval_model: str = "gpt-4o-mini"
+    deepeval_model: str = "gpt-5-mini"
 
     # Metrics params
     calculate_metrics: bool = True
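The two hunks above only change default values on `EvalConfig`, a pydantic `BaseSettings` subclass. A sketch of the same defaults-with-override pattern using a stdlib dataclass as a stand-in for `BaseSettings` (field names and defaults mirror the "+" side of the hunks; the real class additionally reads overrides from the environment):

```python
from dataclasses import dataclass, field


@dataclass
class EvalConfigSketch:
    # Defaults as set on the new side of the diff above.
    answering_questions: bool = True
    qa_engine: str = "cognee_completion"
    deepeval_model: str = "gpt-5-mini"
    evaluation_metrics: list = field(default_factory=lambda: ["EM", "f1"])


# Per-run overrides, as done in main() further down in this diff.
cfg = EvalConfigSketch(qa_engine="cognee_graph_completion")
```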
@@ -2,6 +2,7 @@ import modal
 import os
 import asyncio
 import datetime
 import hashlib
+import json
 from cognee.shared.logging_utils import get_logger
 from cognee.eval_framework.eval_config import EvalConfig

@@ -9,9 +10,6 @@ from cognee.eval_framework.corpus_builder.run_corpus_builder import run_corpus_b
 from cognee.eval_framework.answer_generation.run_question_answering_module import (
     run_question_answering,
 )
-import pathlib
-from os import path
-from modal import Image
 from cognee.eval_framework.evaluation.run_evaluation_module import run_evaluation
 from cognee.eval_framework.metrics_dashboard import create_dashboard

@@ -40,19 +38,22 @@ def read_and_combine_metrics(eval_params: dict) -> dict:
 
 app = modal.App("modal-run-eval")
 
-image = Image.from_dockerfile(
-    path=pathlib.Path(path.join(path.dirname(__file__), "Dockerfile")).resolve(),
-    force_build=False,
-).add_local_python_source("cognee")
-
-
-@app.function(
-    image=image,
-    max_containers=10,
-    timeout=86400,
-    volumes={"/data": vol},
-    secrets=[modal.Secret.from_name("eval_secrets")],
+image = (
+    modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False)
+    .copy_local_file("pyproject.toml", "pyproject.toml")
+    .copy_local_file("poetry.lock", "poetry.lock")
+    .env(
+        {
+            "ENV": os.getenv("ENV"),
+            "LLM_API_KEY": os.getenv("LLM_API_KEY"),
+            "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
+        }
+    )
+    .pip_install("protobuf", "h2", "deepeval", "gdown", "plotly")
 )
+
+
+@app.function(image=image, concurrency_limit=10, timeout=86400, volumes={"/data": vol})
 async def modal_run_eval(eval_params=None):
     """Runs evaluation pipeline and returns combined metrics results."""
     if eval_params is None:

@@ -104,7 +105,18 @@ async def main():
     configs = [
         EvalConfig(
             task_getter_type="Default",
-            number_of_samples_in_corpus=25,
+            number_of_samples_in_corpus=10,
             benchmark="HotPotQA",
             qa_engine="cognee_graph_completion",
             building_corpus_from_scratch=True,
+            answering_questions=True,
+            evaluating_answers=True,
+            calculate_metrics=True,
+            dashboard=True,
+        ),
+        EvalConfig(
+            task_getter_type="Default",
+            number_of_samples_in_corpus=10,
+            benchmark="TwoWikiMultiHop",
+            qa_engine="cognee_graph_completion",
+            building_corpus_from_scratch=True,

@@ -115,7 +127,7 @@ async def main():
         ),
         EvalConfig(
             task_getter_type="Default",
-            number_of_samples_in_corpus=25,
+            number_of_samples_in_corpus=10,
             benchmark="Musique",
             qa_engine="cognee_graph_completion",
             building_corpus_from_scratch=True,
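The eval script above imports `hashlib` and `json` alongside `read_and_combine_metrics(eval_params)`, which combines per-run metrics keyed by their parameters. A hedged guess at why those imports are useful, not the script's actual logic: deriving a stable, order-independent fingerprint of a parameter dict so each config's results can be stored and looked up under a short key. All names below are hypothetical:

```python
import hashlib
import json


def params_fingerprint(eval_params: dict) -> str:
    """Stable short hash of a parameter dict (independent of key order)."""
    # sort_keys makes the serialization canonical, so equal dicts hash equally.
    canonical = json.dumps(eval_params, sort_keys=True)
    return hashlib.md5(canonical.encode("utf-8")).hexdigest()[:12]


a = params_fingerprint({"benchmark": "HotPotQA", "number_of_samples_in_corpus": 10})
b = params_fingerprint({"number_of_samples_in_corpus": 10, "benchmark": "HotPotQA"})
```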
@@ -1,6 +1,6 @@
 from pydantic_settings import BaseSettings, SettingsConfigDict
 from functools import lru_cache
-from typing import Optional, Literal
+from typing import Optional
 
 
 class CacheConfig(BaseSettings):

@@ -15,7 +15,6 @@ class CacheConfig(BaseSettings):
     - agentic_lock_timeout: Maximum time (in seconds) to wait for the lock release.
     """
 
-    cache_backend: Literal["redis", "fs"] = "fs"
     caching: bool = False
     shared_kuzu_lock: bool = False
     cache_host: str = "localhost"

@@ -29,7 +28,6 @@ class CacheConfig(BaseSettings):
 
     def to_dict(self) -> dict:
         return {
-            "cache_backend": self.cache_backend,
            "caching": self.caching,
            "shared_kuzu_lock": self.shared_kuzu_lock,
            "cache_host": self.cache_host,
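The removed `cache_backend: Literal["redis", "fs"]` field constrained the backend name to a closed set at validation time. The same constraint can be sketched with the stdlib alone, using `typing.get_args` to recover the allowed values from the `Literal` (a sketch of the pattern, not cognee's pydantic validation):

```python
from typing import Literal, get_args

# Mirrors the field removed in the hunk above.
CacheBackend = Literal["redis", "fs"]


def validate_backend(value: str) -> str:
    """Reject any backend name outside the Literal's allowed values."""
    if value not in get_args(CacheBackend):
        raise ValueError(f"Unsupported cache backend: {value!r}")
    return value
```

Pydantic performs an equivalent membership check automatically for `Literal`-typed fields, which is what the config loses by widening the field away.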
@@ -1,151 +0,0 @@
-import asyncio
-import json
-import os
-from datetime import datetime
-import time
-import threading
-import diskcache as dc
-
-from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
-from cognee.infrastructure.databases.exceptions.exceptions import (
-    CacheConnectionError,
-    SharedKuzuLockRequiresRedisError,
-)
-from cognee.infrastructure.files.storage.get_storage_config import get_storage_config
-from cognee.shared.logging_utils import get_logger
-
-logger = get_logger("FSCacheAdapter")
-
-
-class FSCacheAdapter(CacheDBInterface):
-    def __init__(self):
-        default_key = "sessions_db"
-
-        storage_config = get_storage_config()
-        data_root_directory = storage_config["data_root_directory"]
-        cache_directory = os.path.join(data_root_directory, ".cognee_fs_cache", default_key)
-        os.makedirs(cache_directory, exist_ok=True)
-        self.cache = dc.Cache(directory=cache_directory)
-        self.cache.expire()
-
-        logger.debug(f"FSCacheAdapter initialized with cache directory: {cache_directory}")
-
-    def acquire_lock(self):
-        """Lock acquisition is not available for filesystem cache backend."""
-        message = "Shared Kuzu lock requires Redis cache backend."
-        logger.error(message)
-        raise SharedKuzuLockRequiresRedisError()
-
-    def release_lock(self):
-        """Lock release is not available for filesystem cache backend."""
-        message = "Shared Kuzu lock requires Redis cache backend."
-        logger.error(message)
-        raise SharedKuzuLockRequiresRedisError()
-
-    async def add_qa(
-        self,
-        user_id: str,
-        session_id: str,
-        question: str,
-        context: str,
-        answer: str,
-        ttl: int | None = 86400,
-    ):
-        try:
-            session_key = f"agent_sessions:{user_id}:{session_id}"
-
-            qa_entry = {
-                "time": datetime.utcnow().isoformat(),
-                "question": question,
-                "context": context,
-                "answer": answer,
-            }
-
-            existing_value = self.cache.get(session_key)
-            if existing_value is not None:
-                value: list = json.loads(existing_value)
-                value.append(qa_entry)
-            else:
-                value = [qa_entry]
-
-            self.cache.set(session_key, json.dumps(value), expire=ttl)
-        except Exception as e:
-            error_msg = f"Unexpected error while adding Q&A to diskcache: {str(e)}"
-            logger.error(error_msg)
-            raise CacheConnectionError(error_msg) from e
-
-    async def get_latest_qa(self, user_id: str, session_id: str, last_n: int = 5):
-        session_key = f"agent_sessions:{user_id}:{session_id}"
-        value = self.cache.get(session_key)
-        if value is None:
-            return None
-        entries = json.loads(value)
-        return entries[-last_n:] if len(entries) > last_n else entries
-
-    async def get_all_qas(self, user_id: str, session_id: str):
-        session_key = f"agent_sessions:{user_id}:{session_id}"
-        value = self.cache.get(session_key)
-        if value is None:
-            return None
-        return json.loads(value)
-
-    async def close(self):
-        if self.cache is not None:
-            self.cache.expire()
-            self.cache.close()
-
-
-async def main():
-    adapter = FSCacheAdapter()
-    session_id = "demo_session"
-    user_id = "demo_user_id"
-
-    print("\nAdding sample Q/A pairs...")
-    await adapter.add_qa(
-        user_id,
-        session_id,
-        "What is Redis?",
-        "Basic DB context",
-        "Redis is an in-memory data store.",
-    )
-    await adapter.add_qa(
-        user_id,
-        session_id,
-        "Who created Redis?",
-        "Historical context",
-        "Salvatore Sanfilippo (antirez).",
-    )
-
-    print("\nLatest QA:")
-    latest = await adapter.get_latest_qa(user_id, session_id)
-    print(json.dumps(latest, indent=2))
-
-    print("\nLast 2 QAs:")
-    last_two = await adapter.get_latest_qa(user_id, session_id, last_n=2)
-    print(json.dumps(last_two, indent=2))
-
-    session_id = "session_expire_demo"
-
-    await adapter.add_qa(
-        user_id,
-        session_id,
-        "What is Redis?",
-        "Database context",
-        "Redis is an in-memory data store.",
-    )
-
-    await adapter.add_qa(
-        user_id,
-        session_id,
-        "Who created Redis?",
-        "History context",
-        "Salvatore Sanfilippo (antirez).",
-    )
-
-    print(await adapter.get_all_qas(user_id, session_id))
-
-    await adapter.close()
-
-
-if __name__ == "__main__":
-    asyncio.run(main())
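The deleted `FSCacheAdapter` stores each session's Q&A history as a JSON-serialized list under one key, appending on every write. The same pattern can be sketched with a plain dict standing in for `diskcache.Cache` (hypothetical class, reduced fields, no TTL handling):

```python
import json
from datetime import datetime, timezone


class InMemoryQACache:
    """Append-serialized-list pattern from the deleted FSCacheAdapter,
    with a plain dict standing in for diskcache.Cache."""

    def __init__(self):
        self.store = {}

    def add_qa(self, user_id, session_id, question, answer):
        key = f"agent_sessions:{user_id}:{session_id}"
        entry = {
            "time": datetime.now(timezone.utc).isoformat(),
            "question": question,
            "answer": answer,
        }
        existing = self.store.get(key)
        history = json.loads(existing) if existing is not None else []
        history.append(entry)
        # Re-serialize the whole list, as the adapter does on each write.
        self.store[key] = json.dumps(history)

    def get_latest_qa(self, user_id, session_id, last_n=5):
        value = self.store.get(f"agent_sessions:{user_id}:{session_id}")
        if value is None:
            return None
        return json.loads(value)[-last_n:]


cache = InMemoryQACache()
cache.add_qa("u1", "s1", "What is Redis?", "An in-memory data store.")
cache.add_qa("u1", "s1", "Who created Redis?", "Salvatore Sanfilippo.")
```

Serializing the whole history on every write is simple but O(n) per append; it works here because agent sessions stay short.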
@@ -1,11 +1,9 @@
 """Factory to get the appropriate cache coordination engine (e.g., Redis)."""
 
 from functools import lru_cache
-import os
 from typing import Optional
 from cognee.infrastructure.databases.cache.config import get_cache_config
 from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
-from cognee.infrastructure.databases.cache.fscache.FsCacheAdapter import FSCacheAdapter
 
 config = get_cache_config()
 

@@ -35,28 +33,20 @@ def create_cache_engine(
 
     Returns:
     --------
-    - CacheDBInterface: An instance of the appropriate cache adapter.
+    - CacheDBInterface: An instance of the appropriate cache adapter. :TODO: Now we support only Redis. later if we add more here we can split the logic
     """
     if config.caching:
         from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter
 
-        if config.cache_backend == "redis":
-            return RedisAdapter(
-                host=cache_host,
-                port=cache_port,
-                username=cache_username,
-                password=cache_password,
-                lock_name=lock_key,
-                timeout=agentic_lock_expire,
-                blocking_timeout=agentic_lock_timeout,
-            )
-        elif config.cache_backend == "fs":
-            return FSCacheAdapter()
-        else:
-            raise ValueError(
-                f"Unsupported cache backend: '{config.cache_backend}'. "
-                f"Supported backends are: 'redis', 'fs'"
-            )
+        return RedisAdapter(
+            host=cache_host,
+            port=cache_port,
+            username=cache_username,
+            password=cache_password,
+            lock_name=lock_key,
+            timeout=agentic_lock_expire,
+            blocking_timeout=agentic_lock_timeout,
+        )
     else:
         return None
 
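The removed branch dispatched on `cache_backend` and raised on unknown names. The dispatch shape, sketched with string stand-ins for the adapter classes so it runs without Redis or diskcache installed:

```python
def create_cache_engine_sketch(caching: bool, cache_backend: str):
    """Backend dispatch from the removed code path; returns None when
    caching is disabled. Adapter classes are stand-in strings here."""
    if not caching:
        return None
    if cache_backend == "redis":
        return "RedisAdapter"
    if cache_backend == "fs":
        return "FSCacheAdapter"
    raise ValueError(
        f"Unsupported cache backend: {cache_backend!r}. "
        f"Supported backends are: 'redis', 'fs'"
    )
```

Failing loudly on an unknown backend (rather than silently falling back to one of them) is what the `else: raise ValueError(...)` arm in the deleted code provided.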
@@ -1,3 +0,0 @@
-from .dataset_database_handler_interface import DatasetDatabaseHandlerInterface
-from .supported_dataset_database_handlers import supported_dataset_database_handlers
-from .use_dataset_database_handler import use_dataset_database_handler
@@ -1,80 +0,0 @@
-from typing import Optional
-from uuid import UUID
-from abc import ABC, abstractmethod
-
-from cognee.modules.users.models.User import User
-from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
-
-
-class DatasetDatabaseHandlerInterface(ABC):
-    @classmethod
-    @abstractmethod
-    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
-        """
-        Return a dictionary with database connection/resolution info for a graph or vector database for the given dataset.
-        Function can auto handle deploying of the actual database if needed, but is not necessary.
-        Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future.
-        Needed for Cognee multi-tenant/multi-user and backend access control support.
-
-        Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database.
-        From which internal mapping of dataset -> database connection info will be done.
-
-        The returned dictionary is stored verbatim in the relational database and is later passed to
-        resolve_dataset_connection_info() at connection time. For safe credential handling, prefer
-        returning only references to secrets or role identifiers, not plaintext credentials.
-
-        Each dataset needs to map to a unique graph or vector database when backend access control is enabled to facilitate a separation of concern for data.
-
-        Args:
-            dataset_id: UUID of the dataset if needed by the database creation logic
-            user: User object if needed by the database creation logic
-        Returns:
-            dict: Connection info for the created graph or vector database instance.
-        """
-        pass
-
-    @classmethod
-    async def resolve_dataset_connection_info(
-        cls, dataset_database: DatasetDatabase
-    ) -> DatasetDatabase:
-        """
-        Resolve runtime connection details for a dataset’s backing graph/vector database.
-        Function is intended to be overwritten to implement custom logic for resolving connection info.
-
-        This method is invoked right before the application opens a connection for a given dataset.
-        It receives the DatasetDatabase row that was persisted when create_dataset() ran and must
-        return a modified instance of DatasetDatabase with concrete connection parameters that the client/driver can use.
-        Do not update these new DatasetDatabase values in the relational database to avoid storing secure credentials.
-
-        In case of separate graph and vector database handlers, each handler should implement its own logic for resolving
-        connection info and only change parameters related to its appropriate database, the resolution function will then
-        be called one after another with the updated DatasetDatabase value from the previous function as the input.
-
-        Typical behavior:
-        - If the DatasetDatabase row already contains raw connection fields (e.g., host/port/db/user/password
-          or api_url/api_key), return them as-is.
-        - If the row stores only references (e.g., secret IDs, vault paths, cloud resource ARNs/IDs, IAM
-          roles, SSO tokens), resolve those references by calling the appropriate secret manager or provider
-          API to obtain short-lived credentials and assemble the final connection DatasetDatabase object.
-        - Do not persist any resolved or decrypted secrets back to the relational database. Return them only
-          to the caller.
-
-        Args:
-            dataset_database: DatasetDatabase row from the relational database
-        Returns:
-            DatasetDatabase: Updated instance with resolved connection info
-        """
-        return dataset_database
-
-    @classmethod
-    @abstractmethod
-    async def delete_dataset(cls, dataset_database: DatasetDatabase) -> None:
-        """
-        Delete the graph or vector database for the given dataset.
-        Function should auto handle deleting of the actual database or send a request to the proper service to delete/mark the database as not needed for the given dataset.
-        Needed for maintaining a database for Cognee multi-tenant/multi-user and backend access control.
-
-        Args:
-            dataset_database: DatasetDatabase row containing connection/resolution info for the graph or vector database to delete.
-        """
-        pass
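The deleted interface combines `@classmethod` and `@abstractmethod` on async methods, with `resolve_dataset_connection_info` providing a non-abstract pass-through default. The shape of that contract, sketched with simplified signatures (dicts instead of `DatasetDatabase` rows; all names hypothetical):

```python
import asyncio
from abc import ABC, abstractmethod


class HandlerInterfaceSketch(ABC):
    """Abstract handler: subclasses must implement create_dataset, and may
    override the pass-through connection-info resolution."""

    @classmethod
    @abstractmethod
    async def create_dataset(cls, dataset_id) -> dict: ...

    @classmethod
    async def resolve_dataset_connection_info(cls, row: dict) -> dict:
        # Default: stored connection info is already usable as-is.
        return row


class DummyHandler(HandlerInterfaceSketch):
    @classmethod
    async def create_dataset(cls, dataset_id) -> dict:
        # Return a secret reference rather than plaintext credentials,
        # as the interface docstring advises.
        return {"graph_database_name": f"{dataset_id}.db", "secret_ref": "vault:path"}


info = asyncio.run(DummyHandler.create_dataset("1234"))
```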
@@ -1,18 +0,0 @@
-from cognee.infrastructure.databases.graph.neo4j_driver.Neo4jAuraDevDatasetDatabaseHandler import (
-    Neo4jAuraDevDatasetDatabaseHandler,
-)
-from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandler import (
-    LanceDBDatasetDatabaseHandler,
-)
-from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import (
-    KuzuDatasetDatabaseHandler,
-)
-
-supported_dataset_database_handlers = {
-    "neo4j_aura_dev": {
-        "handler_instance": Neo4jAuraDevDatasetDatabaseHandler,
-        "handler_provider": "neo4j",
-    },
-    "lancedb": {"handler_instance": LanceDBDatasetDatabaseHandler, "handler_provider": "lancedb"},
-    "kuzu": {"handler_instance": KuzuDatasetDatabaseHandler, "handler_provider": "kuzu"},
-}
@@ -1,10 +0,0 @@
-from .supported_dataset_database_handlers import supported_dataset_database_handlers
-
-
-def use_dataset_database_handler(
-    dataset_database_handler_name, dataset_database_handler, dataset_database_provider
-):
-    supported_dataset_database_handlers[dataset_database_handler_name] = {
-        "handler_instance": dataset_database_handler,
-        "handler_provider": dataset_database_provider,
-    }
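Together, the two deleted modules above implement a plugin registry: a dict of built-in handlers plus a one-line registration function that lets callers add or override entries at runtime. The pattern in miniature (stand-in names):

```python
# Built-in entries, as in the deleted supported_dataset_database_handlers dict.
supported_handlers = {
    "kuzu": {"handler_instance": object, "handler_provider": "kuzu"},
}


def use_handler(name, handler, provider):
    """Register (or override) a handler under the given name."""
    supported_handlers[name] = {"handler_instance": handler, "handler_provider": provider}


class MyHandler: ...


use_handler("my_backend", MyHandler, "neo4j")
```

Because registration just writes into the dict, a user-supplied handler can also shadow a built-in one by reusing its name, which is useful for swapping in custom deployment logic.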
@@ -148,19 +148,3 @@ class CacheConnectionError(CogneeConfigurationError):
         status_code: int = status.HTTP_503_SERVICE_UNAVAILABLE,
     ):
         super().__init__(message, name, status_code)
-
-
-class SharedKuzuLockRequiresRedisError(CogneeConfigurationError):
-    """
-    Raised when shared Kuzu locking is requested without configuring the Redis backend.
-    """
-
-    def __init__(
-        self,
-        message: str = (
-            "Shared Kuzu lock requires Redis cache backend. Configure Redis to enable shared Kuzu locking."
-        ),
-        name: str = "SharedKuzuLockRequiresRedisError",
-        status_code: int = status.HTTP_400_BAD_REQUEST,
-    ):
-        super().__init__(message, name, status_code)
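The deleted exception follows the project's pattern of subclassing a configuration-error base and baking the message, name, and HTTP status into `__init__` defaults. A self-contained sketch of that hierarchy (simplified base class, plain integer status instead of the `status` constants):

```python
class CogneeConfigurationErrorSketch(Exception):
    """Base carrying a message, an error name, and an HTTP status code."""

    def __init__(self, message, name, status_code):
        super().__init__(message)
        self.name = name
        self.status_code = status_code


class SharedKuzuLockRequiresRedisSketch(CogneeConfigurationErrorSketch):
    # Defaults mean call sites can simply `raise SharedKuzuLockRequiresRedisSketch()`.
    def __init__(
        self,
        message="Shared Kuzu lock requires Redis cache backend.",
        name="SharedKuzuLockRequiresRedisError",
        status_code=400,
    ):
        super().__init__(message, name, status_code)


try:
    raise SharedKuzuLockRequiresRedisSketch()
except CogneeConfigurationErrorSketch as err:
    caught = err
```

Catching on the base class, as above, is how an API layer can translate any configuration error into a response using `status_code` without knowing each concrete subclass.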
@@ -26,7 +26,6 @@ class GraphConfig(BaseSettings):
     - graph_database_username
     - graph_database_password
     - graph_database_port
-    - graph_database_key
     - graph_file_path
     - graph_model
     - graph_topology

@@ -42,12 +41,10 @@ class GraphConfig(BaseSettings):
     graph_database_username: str = ""
     graph_database_password: str = ""
     graph_database_port: int = 123
-    graph_database_key: str = ""
     graph_file_path: str = ""
     graph_filename: str = ""
     graph_model: object = KnowledgeGraph
     graph_topology: object = KnowledgeGraph
-    graph_dataset_database_handler: str = "kuzu"
     model_config = SettingsConfigDict(env_file=".env", extra="allow", populate_by_name=True)
 
     # Model validator updates graph_filename and path dynamically after class creation based on current database provider

@@ -93,12 +90,10 @@ class GraphConfig(BaseSettings):
             "graph_database_username": self.graph_database_username,
             "graph_database_password": self.graph_database_password,
             "graph_database_port": self.graph_database_port,
-            "graph_database_key": self.graph_database_key,
             "graph_file_path": self.graph_file_path,
             "graph_model": self.graph_model,
             "graph_topology": self.graph_topology,
             "model_config": self.model_config,
-            "graph_dataset_database_handler": self.graph_dataset_database_handler,
         }
 
     def to_hashable_dict(self) -> dict:

@@ -121,9 +116,7 @@ class GraphConfig(BaseSettings):
             "graph_database_username": self.graph_database_username,
             "graph_database_password": self.graph_database_password,
             "graph_database_port": self.graph_database_port,
-            "graph_database_key": self.graph_database_key,
             "graph_file_path": self.graph_file_path,
-            "graph_dataset_database_handler": self.graph_dataset_database_handler,
         }
 
 
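`GraphConfig` exposes both `to_dict` and `to_hashable_dict`, and the hunks above keep the two in sync as fields are removed. The point of a hashable variant is that a plain dict cannot be a dict key or cache key; a sketch of one way to get a deterministic hashable view of a flat config (a guess at the purpose, not cognee's implementation):

```python
def to_hashable(config: dict) -> tuple:
    """Deterministic, hashable view of a flat config dict."""
    # Sorting makes the tuple independent of insertion order.
    return tuple(sorted(config.items()))


cfg = {"graph_database_port": 123, "graph_database_username": ""}
key = to_hashable(cfg)
engine_cache = {key: "engine-instance"}  # usable as a cache key now
```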
@@ -33,8 +33,6 @@ def create_graph_engine(
     graph_database_username="",
     graph_database_password="",
     graph_database_port="",
-    graph_database_key="",
-    graph_dataset_database_handler="",
 ):
     """
     Create a graph engine based on the specified provider type.

@@ -71,7 +69,6 @@ def create_graph_engine(
         graph_database_url=graph_database_url,
         graph_database_username=graph_database_username,
         graph_database_password=graph_database_password,
-        database_name=graph_database_name,
     )
 
     if graph_database_provider == "neo4j":
@@ -398,18 +398,3 @@ class GraphDBInterface(ABC):
         - node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections.
         """
         raise NotImplementedError
-
-    @abstractmethod
-    async def get_filtered_graph_data(
-        self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
-    ) -> Tuple[List[Node], List[EdgeData]]:
-        """
-        Retrieve nodes and edges filtered by the provided attribute criteria.
-
-        Parameters:
-        -----------
-
-        - attribute_filters: A list of dictionaries where keys are attribute names and values
-          are lists of attribute values to filter by.
-        """
-        raise NotImplementedError
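The removed `get_filtered_graph_data` takes `attribute_filters`: a list of dicts, each mapping an attribute name to the list of accepted values. One plausible reading of that signature, sketched against plain node dicts (a hypothetical in-memory version, not the database-side implementation):

```python
def matches_filters(node: dict, attribute_filters) -> bool:
    """True if the node satisfies at least one filter dict; within a filter,
    every attribute must take one of its listed values."""
    return any(
        all(node.get(attr) in allowed for attr, allowed in flt.items())
        for flt in attribute_filters
    )


nodes = [{"id": "a", "type": "Person"}, {"id": "b", "type": "City"}]
filtered = [n for n in nodes if matches_filters(n, [{"type": ["Person"]}])]
```

Under this reading the outer list is an OR over filters and each dict an AND over its attributes, which matches the "lists of attribute values to filter by" wording in the removed docstring.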
@@ -1,81 +0,0 @@
-import os
-from uuid import UUID
-from typing import Optional
-
-from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
-from cognee.base_config import get_base_config
-from cognee.modules.users.models import User
-from cognee.modules.users.models import DatasetDatabase
-from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
-
-
-class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
-    """
-    Handler for interacting with Kuzu Dataset databases.
-    """
-
-    @classmethod
-    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
-        """
-        Create a new Kuzu instance for the dataset. Return connection info that will be mapped to the dataset.
-
-        Args:
-            dataset_id: Dataset UUID
-            user: User object who owns the dataset and is making the request
-
-        Returns:
-            dict: Connection details for the created Kuzu instance
-
-        """
-        from cognee.infrastructure.databases.graph.config import get_graph_config
-
-        graph_config = get_graph_config()
-
-        if graph_config.graph_database_provider != "kuzu":
-            raise ValueError(
-                "KuzuDatasetDatabaseHandler can only be used with Kuzu graph database provider."
-            )
-
-        graph_db_name = f"{dataset_id}.pkl"
-        graph_db_url = graph_config.graph_database_url
-        graph_db_key = graph_config.graph_database_key
-        graph_db_username = graph_config.graph_database_username
-        graph_db_password = graph_config.graph_database_password
-
-        return {
-            "graph_database_name": graph_db_name,
-            "graph_database_url": graph_db_url,
-            "graph_database_provider": graph_config.graph_database_provider,
-            "graph_database_key": graph_db_key,
-            "graph_dataset_database_handler": "kuzu",
-            "graph_database_connection_info": {
-                "graph_database_username": graph_db_username,
-                "graph_database_password": graph_db_password,
-            },
-        }
-
-    @classmethod
-    async def delete_dataset(cls, dataset_database: DatasetDatabase):
-        base_config = get_base_config()
-        databases_directory_path = os.path.join(
-            base_config.system_root_directory, "databases", str(dataset_database.owner_id)
-        )
-        graph_file_path = os.path.join(
-            databases_directory_path, dataset_database.graph_database_name
-        )
-        graph_engine = create_graph_engine(
-            graph_database_provider=dataset_database.graph_database_provider,
-            graph_database_url=dataset_database.graph_database_url,
-            graph_database_name=dataset_database.graph_database_name,
-            graph_database_key=dataset_database.graph_database_key,
-            graph_file_path=graph_file_path,
-            graph_database_username=dataset_database.graph_database_connection_info.get(
-                "graph_database_username", ""
-            ),
-            graph_database_password=dataset_database.graph_database_connection_info.get(
-                "graph_database_password", ""
-            ),
-            graph_dataset_database_handler="",
-            graph_database_port="",
-        )
-        await graph_engine.delete_graph()
@@ -12,7 +12,6 @@ from contextlib import asynccontextmanager
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, Any, List, Union, Optional, Tuple, Type
 
-from cognee.exceptions import CogneeValidationError
 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.utils.run_sync import run_sync
 from cognee.infrastructure.files.storage import get_file_storage
@@ -1187,11 +1186,6 @@ class KuzuAdapter(GraphDBInterface):
         A tuple with two elements: a list of tuples of (node_id, properties) and a list of
         tuples of (source_id, target_id, relationship_name, properties).
         """
-
-        import time
-
-        start_time = time.time()
-
         try:
             nodes_query = """
             MATCH (n:Node)
@@ -1255,11 +1249,6 @@ class KuzuAdapter(GraphDBInterface):
                     },
                 )
             )
-
-            retrieval_time = time.time() - start_time
-            logger.info(
-                f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds"
-            )
             return formatted_nodes, formatted_edges
         except Exception as e:
             logger.error(f"Failed to get graph data: {e}")
@@ -1428,92 +1417,6 @@ class KuzuAdapter(GraphDBInterface):
                 formatted_edges.append((source_id, target_id, rel_type, props))
         return formatted_nodes, formatted_edges
 
-    async def get_id_filtered_graph_data(self, target_ids: list[str]):
-        """
-        Retrieve graph data filtered by specific node IDs, including their direct neighbors
-        and only edges where one endpoint matches those IDs.
-
-        Returns:
-            nodes: List[dict] -> Each dict includes "id" and all node properties
-            edges: List[dict] -> Each dict includes "source", "target", "type", "properties"
-        """
-        import time
-
-        start_time = time.time()
-
-        try:
-            if not target_ids:
-                logger.warning("No target IDs provided for ID-filtered graph retrieval.")
-                return [], []
-
-            if not all(isinstance(x, str) for x in target_ids):
-                raise CogneeValidationError("target_ids must be a list of strings")
-
-            query = """
-            MATCH (n:Node)-[r]->(m:Node)
-            WHERE n.id IN $target_ids OR m.id IN $target_ids
-            RETURN n.id, {
-                name: n.name,
-                type: n.type,
-                properties: n.properties
-            }, m.id, {
-                name: m.name,
-                type: m.type,
-                properties: m.properties
-            }, r.relationship_name, r.properties
-            """
-
-            result = await self.query(query, {"target_ids": target_ids})
-
-            if not result:
-                logger.info("No data returned for the supplied IDs")
-                return [], []
-
-            nodes_dict = {}
-            edges = []
-
-            for n_id, n_props, m_id, m_props, r_type, r_props_raw in result:
-                if n_props.get("properties"):
-                    try:
-                        additional_props = json.loads(n_props["properties"])
-                        n_props.update(additional_props)
-                        del n_props["properties"]
-                    except json.JSONDecodeError:
-                        logger.warning(f"Failed to parse properties JSON for node {n_id}")
-
-                if m_props.get("properties"):
-                    try:
-                        additional_props = json.loads(m_props["properties"])
-                        m_props.update(additional_props)
-                        del m_props["properties"]
-                    except json.JSONDecodeError:
-                        logger.warning(f"Failed to parse properties JSON for node {m_id}")
-
-                nodes_dict[n_id] = (n_id, n_props)
-                nodes_dict[m_id] = (m_id, m_props)
-
-                edge_props = {}
-                if r_props_raw:
-                    try:
-                        edge_props = json.loads(r_props_raw)
-                    except (json.JSONDecodeError, TypeError):
-                        logger.warning(f"Failed to parse edge properties for {n_id}->{m_id}")
-
-                source_id = edge_props.get("source_node_id", n_id)
-                target_id = edge_props.get("target_node_id", m_id)
-                edges.append((source_id, target_id, r_type, edge_props))
-
-            retrieval_time = time.time() - start_time
-            logger.info(
-                f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
-            )
-
-            return list(nodes_dict.values()), edges
-
-        except Exception as e:
-            logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
-            raise
-
     async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
         """
         Get metrics on graph structure and connectivity.
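The removed method repeats one small transformation for every node it touches: Kuzu stores extra node attributes as a JSON string in a `properties` column, and the adapter folds that string back into the node dict, tolerating malformed JSON. That step in isolation (a sketch of the pattern, not the adapter's exact helper):

```python
import json


def parse_node_properties(props: dict) -> dict:
    """Fold a JSON-encoded 'properties' value into the node dict itself."""
    raw = props.get("properties")
    if raw:
        try:
            props.update(json.loads(raw))
            del props["properties"]
        except json.JSONDecodeError:
            # The adapter logs a warning and leaves the raw value in place.
            pass
    return props


node = parse_node_properties({"name": "n1", "properties": '{"color": "red"}'})
```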
@ -2005,134 +1908,3 @@ class KuzuAdapter(GraphDBInterface):
|
|||
time_ids_list = [item[0] for item in time_nodes]
|
||||
|
||||
return ", ".join(f"'{uid}'" for uid in time_ids_list)
|
||||
|
||||
async def get_triplets_batch(self, offset: int, limit: int) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Retrieve a batch of triplets (start_node, relationship, end_node) from the graph.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
- offset (int): Number of triplets to skip before returning results.
|
||||
- limit (int): Maximum number of triplets to return.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
- list[dict[str, Any]]: A list of triplets, where each triplet is a dictionary
|
||||
with keys: 'start_node', 'relationship_properties', 'end_node'.
|
||||
|
||||
Raises:
|
||||
-------
|
||||
- ValueError: If offset or limit are negative.
|
||||
- Exception: Re-raises any exceptions from query execution.
|
||||
"""
|
||||
if offset < 0:
|
||||
raise ValueError(f"Offset must be non-negative, got {offset}")
|
||||
if limit < 0:
|
||||
raise ValueError(f"Limit must be non-negative, got {limit}")
|
||||
|
||||
query = """
|
||||
MATCH (start_node:Node)-[relationship:EDGE]->(end_node:Node)
|
||||
RETURN {
|
||||
start_node: {
|
||||
id: start_node.id,
|
||||
name: start_node.name,
|
||||
type: start_node.type,
|
||||
properties: start_node.properties
|
||||
},
|
||||
relationship_properties: {
|
||||
relationship_name: relationship.relationship_name,
|
||||
properties: relationship.properties
|
||||
},
|
||||
end_node: {
|
||||
id: end_node.id,
|
||||
name: end_node.name,
|
||||
type: end_node.type,
|
||||
properties: end_node.properties
|
||||
}
|
||||
} AS triplet
|
||||
SKIP $offset LIMIT $limit
|
||||
"""
        try:
            results = await self.query(query, {"offset": offset, "limit": limit})
        except Exception as e:
            logger.error(f"Failed to execute triplet query: {str(e)}")
            logger.error(f"Query: {query}")
            logger.error(f"Parameters: offset={offset}, limit={limit}")
            raise

        triplets = []
        for idx, row in enumerate(results):
            try:
                if not row or len(row) == 0:
                    logger.warning(f"Skipping empty row at index {idx} in triplet batch")
                    continue

                if not isinstance(row[0], dict):
                    logger.warning(
                        f"Skipping invalid row at index {idx}: expected dict, got {type(row[0])}"
                    )
                    continue

                triplet = row[0]

                if "start_node" not in triplet:
                    logger.warning(f"Skipping triplet at index {idx}: missing 'start_node' key")
                    continue

                if not isinstance(triplet["start_node"], dict):
                    logger.warning(f"Skipping triplet at index {idx}: 'start_node' is not a dict")
                    continue

                triplet["start_node"] = self._parse_node_properties(triplet["start_node"].copy())

                if "relationship_properties" not in triplet:
                    logger.warning(
                        f"Skipping triplet at index {idx}: missing 'relationship_properties' key"
                    )
                    continue

                if not isinstance(triplet["relationship_properties"], dict):
                    logger.warning(
                        f"Skipping triplet at index {idx}: 'relationship_properties' is not a dict"
                    )
                    continue

                rel_props = triplet["relationship_properties"].copy()
                relationship_name = rel_props.get("relationship_name") or ""

                if rel_props.get("properties"):
                    try:
                        parsed_props = json.loads(rel_props["properties"])
                        if isinstance(parsed_props, dict):
                            rel_props.update(parsed_props)
                            del rel_props["properties"]
                        else:
                            logger.warning(
                                f"Parsed relationship properties is not a dict for triplet at index {idx}"
                            )
                    except (json.JSONDecodeError, TypeError) as e:
                        logger.warning(
                            f"Failed to parse relationship properties JSON for triplet at index {idx}: {e}"
                        )

                rel_props["relationship_name"] = relationship_name
                triplet["relationship_properties"] = rel_props

                if "end_node" not in triplet:
                    logger.warning(f"Skipping triplet at index {idx}: missing 'end_node' key")
                    continue

                if not isinstance(triplet["end_node"], dict):
                    logger.warning(f"Skipping triplet at index {idx}: 'end_node' is not a dict")
                    continue

                triplet["end_node"] = self._parse_node_properties(triplet["end_node"].copy())

                triplets.append(triplet)

            except Exception as e:
                logger.error(f"Error processing triplet at index {idx}: {e}", exc_info=True)
                continue

        return triplets
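Review note: the loop above flattens a JSON-encoded `properties` string into the relationship dict, keeping the raw value when it does not parse. That step can be isolated as a small pure helper — `merge_json_properties` is a hypothetical name for illustration, not a function in this codebase:

```python
import json
from typing import Any


def merge_json_properties(props: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of props with a JSON-encoded 'properties' string
    flattened into top-level keys; malformed JSON is left untouched."""
    merged = props.copy()
    raw = merged.get("properties")
    if raw:
        try:
            parsed = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return merged  # keep the raw string if it does not parse
        if isinstance(parsed, dict):
            merged.update(parsed)
            del merged["properties"]
    return merged


clean = merge_json_properties(
    {"relationship_name": "knows", "properties": '{"weight": 0.7}'}
)
print(clean)  # → {'relationship_name': 'knows', 'weight': 0.7}
```

Working on a copy keeps the caller's dict untouched, mirroring the `.copy()` calls in the adapter code.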
@@ -1,168 +0,0 @@
import os
import asyncio
import requests
import base64
import hashlib
from uuid import UUID
from typing import Optional
from cryptography.fernet import Fernet

from cognee.infrastructure.databases.graph import get_graph_config
from cognee.modules.users.models import User, DatasetDatabase
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface

class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """
    Handler for a quick development PoC integrating Cognee's multi-user and permission mode with Neo4j Aura databases.
    This handler creates a new Neo4j Aura instance for each Cognee dataset that is created.

    Improvements needed to be production ready:
    - Secret management for client credentials: secrets are currently encrypted and stored in the
      Cognee relational database; a secret manager or similar system should be used instead.

    Quality-of-life improvements:
    - Allow configuration of different Neo4j Aura plans and regions.
    - Make requests asynchronously; the blocking `requests` library is currently used.
    """
    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Create a new Neo4j Aura instance for the dataset. Return connection info that will be mapped to the dataset.

        Args:
            dataset_id: Dataset UUID
            user: User object who owns the dataset and is making the request

        Returns:
            dict: Connection details for the created Neo4j instance
        """
        graph_config = get_graph_config()

        if graph_config.graph_database_provider != "neo4j":
            raise ValueError(
                "Neo4jAuraDevDatasetDatabaseHandler can only be used with Neo4j graph database provider."
            )

        graph_db_name = f"{dataset_id}"

        # Client credentials and encryption
        client_id = os.environ.get("NEO4J_CLIENT_ID", None)
        client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
        tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
        encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
        encryption_key = base64.urlsafe_b64encode(
            hashlib.sha256(encryption_env_key.encode()).digest()
        )
        cipher = Fernet(encryption_key)

        if client_id is None or client_secret is None or tenant_id is None:
            raise ValueError(
                "NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling."
            )
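Review note: the key derivation above hashes an environment passphrase and base64-encodes the digest because Fernet requires exactly 32 url-safe base64-encoded bytes. A stdlib-only sketch of that derivation (the `cryptography` import is omitted here; in production a proper KDF such as PBKDF2 would be preferable to a bare SHA-256):

```python
import base64
import hashlib


def derive_fernet_key(passphrase: str) -> bytes:
    """Hash an arbitrary-length passphrase down to the 32 url-safe
    base64-encoded bytes that cryptography.fernet.Fernet expects."""
    digest = hashlib.sha256(passphrase.encode()).digest()  # always 32 bytes
    return base64.urlsafe_b64encode(digest)


key = derive_fernet_key("test_key")
print(len(key))  # → 44 (base64 of 32 bytes: 43 chars + '=' padding)
```

Note also that defaulting `NEO4J_ENCRYPTION_KEY` to `"test_key"` means a missing env var silently yields a well-known key; failing loudly would be safer outside a PoC.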

        # Make the request with HTTP Basic Auth
        def get_aura_token(client_id: str, client_secret: str) -> dict:
            url = "https://api.neo4j.io/oauth/token"
            data = {"grant_type": "client_credentials"}  # sent as application/x-www-form-urlencoded

            resp = requests.post(url, data=data, auth=(client_id, client_secret))
            resp.raise_for_status()  # raises if the request failed
            return resp.json()
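Review note: `requests` turns the `auth=(client_id, client_secret)` tuple into an HTTP Basic `Authorization` header. What goes over the wire can be reproduced with the stdlib alone — the credentials below are placeholders:

```python
import base64


def basic_auth_header(client_id: str, client_secret: str) -> str:
    """Reproduce the header that requests' auth=(id, secret) tuple
    sends: 'Basic ' + base64("id:secret")."""
    token = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
    return f"Basic {token}"


header = basic_auth_header("my-client", "my-secret")
print(header)  # → Basic bXktY2xpZW50Om15LXNlY3JldA==
```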

        resp = get_aura_token(client_id, client_secret)

        url = "https://api.neo4j.io/v1/instances"

        headers = {
            "accept": "application/json",
            "Authorization": f"Bearer {resp['access_token']}",
            "Content-Type": "application/json",
        }

        # TODO: Maybe we can allow **kwargs parameter forwarding for cases like these,
        # to allow different configurations between datasets.
        payload = {
            "version": "5",
            "region": "europe-west1",
            "memory": "1GB",
            "name": graph_db_name[0:29],  # TODO: Find a better way to name the Neo4j instance within the 30-character limit
            "type": "professional-db",
            "tenant_id": tenant_id,
            "cloud_provider": "gcp",
        }

        response = requests.post(url, headers=headers, json=payload)

        graph_db_name = "neo4j"  # Has to be 'neo4j' for Aura
        graph_db_url = response.json()["data"]["connection_url"]
        graph_db_key = resp["access_token"]
        graph_db_username = response.json()["data"]["username"]
        graph_db_password = response.json()["data"]["password"]

        async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
            # Poll until the instance is running
            status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
            status = ""
            for attempt in range(30):  # Try for up to ~5 minutes
                status_resp = requests.get(status_url, headers=headers)  # TODO: Use async requests with httpx
                status = status_resp.json()["data"]["status"]
                if status.lower() == "running":
                    return
                await asyncio.sleep(10)
            raise TimeoutError(
                f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}"
            )
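Review note: the provisioning helper above is a fixed-interval poll with a hard attempt cap. The pattern generalizes to any "wait until ready" check; a sketch with an injectable probe (all names here are hypothetical, not adapter APIs):

```python
import asyncio
from typing import Awaitable, Callable


async def wait_until(
    probe: Callable[[], Awaitable[bool]],
    attempts: int = 30,
    interval: float = 10.0,
) -> None:
    """Call probe up to `attempts` times, sleeping `interval` seconds
    between tries; raise TimeoutError if it never reports ready."""
    for _ in range(attempts):
        if await probe():
            return
        await asyncio.sleep(interval)
    raise TimeoutError(f"Not ready after {attempts} attempts")


# Fake probe that becomes ready on the third call.
calls = 0


async def fake_probe() -> bool:
    global calls
    calls += 1
    return calls >= 3


asyncio.run(wait_until(fake_probe, attempts=5, interval=0))
print(calls)  # → 3
```

Separating the probe from the loop makes the timeout logic testable without hitting the Aura API, and swapping `requests.get` for an async HTTP client would then only touch the probe.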

        instance_id = response.json()["data"]["id"]
        await _wait_for_neo4j_instance_provisioning(instance_id, headers)

        encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
        encrypted_db_password_string = encrypted_db_password_bytes.decode()

        return {
            "graph_database_name": graph_db_name,
            "graph_database_url": graph_db_url,
            "graph_database_provider": "neo4j",
            "graph_database_key": graph_db_key,
            "graph_dataset_database_handler": "neo4j_aura_dev",
            "graph_database_connection_info": {
                "graph_database_username": graph_db_username,
                "graph_database_password": encrypted_db_password_string,
            },
        }

    @classmethod
    async def resolve_dataset_connection_info(
        cls, dataset_database: DatasetDatabase
    ) -> DatasetDatabase:
        """
        Resolve and decrypt connection info for the Neo4j dataset database.
        In this case, decrypt the password stored in the database.

        Args:
            dataset_database: DatasetDatabase instance containing encrypted connection info.
        """
        encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
        encryption_key = base64.urlsafe_b64encode(
            hashlib.sha256(encryption_env_key.encode()).digest()
        )
        cipher = Fernet(encryption_key)
        graph_db_password = cipher.decrypt(
            dataset_database.graph_database_connection_info["graph_database_password"].encode()
        ).decode()

        dataset_database.graph_database_connection_info["graph_database_password"] = (
            graph_db_password
        )
        return dataset_database

    @classmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase):
        pass

@@ -8,7 +8,7 @@ from neo4j import AsyncSession
from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError
from contextlib import asynccontextmanager
from typing import Optional, Any, List, Dict, Type, Tuple, Coroutine
from typing import Optional, Any, List, Dict, Type, Tuple

from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int

@@ -964,63 +964,6 @@ class Neo4jAdapter(GraphDBInterface):
            logger.error(f"Error during graph data retrieval: {str(e)}")
            raise

    async def get_id_filtered_graph_data(self, target_ids: list[str]):
        """
        Retrieve graph data filtered by specific node IDs, including their direct neighbors
        and only edges where one endpoint matches those IDs.

        This version uses a single Cypher query for efficiency.
        """
        import time

        start_time = time.time()

        try:
            if not target_ids:
                logger.warning("No target IDs provided for ID-filtered graph retrieval.")
                return [], []

            query = """
            MATCH ()-[r]-()
            WHERE startNode(r).id IN $target_ids
               OR endNode(r).id IN $target_ids
            WITH DISTINCT r, startNode(r) AS a, endNode(r) AS b
            RETURN
                properties(a) AS n_properties,
                properties(b) AS m_properties,
                type(r) AS type,
                properties(r) AS properties
            """

            result = await self.query(query, {"target_ids": target_ids})

            nodes_dict = {}
            edges = []

            for record in result:
                n_props = record["n_properties"]
                m_props = record["m_properties"]
                r_props = record["properties"]
                r_type = record["type"]

                nodes_dict[n_props["id"]] = (n_props["id"], n_props)
                nodes_dict[m_props["id"]] = (m_props["id"], m_props)

                source_id = r_props.get("source_node_id", n_props["id"])
                target_id = r_props.get("target_node_id", m_props["id"])
                edges.append((source_id, target_id, r_type, r_props))

            retrieval_time = time.time() - start_time
            logger.info(
                f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
            )

            return list(nodes_dict.values()), edges

        except Exception as e:
            logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
            raise
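Review note: because each matched relationship reports both endpoints, the loop above deduplicates nodes by keying them on `id` while edges accumulate in a plain list. The pattern in isolation, with hypothetical records standing in for Neo4j rows:

```python
# Two relationships sharing node "a": three distinct nodes, two edges.
records = [
    {"n_properties": {"id": "a"}, "m_properties": {"id": "b"},
     "type": "KNOWS", "properties": {}},
    {"n_properties": {"id": "a"}, "m_properties": {"id": "c"},
     "type": "KNOWS", "properties": {}},
]

nodes_dict = {}  # id -> (id, props): later sightings overwrite, so each node appears once
edges = []
for record in records:
    n, m = record["n_properties"], record["m_properties"]
    nodes_dict[n["id"]] = (n["id"], n)
    nodes_dict[m["id"]] = (m["id"], m)
    edges.append((n["id"], m["id"], record["type"], record["properties"]))

print(len(nodes_dict), len(edges))  # → 3 2
```

Using a dict rather than a set keeps the last-seen properties for each node id, which is what the adapter returns via `list(nodes_dict.values())`.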

    async def get_nodeset_subgraph(
        self, node_type: Type[Any], node_name: List[str]
    ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:

@@ -1527,25 +1470,3 @@ class Neo4jAdapter(GraphDBInterface):
        time_ids_list = [item["id"] for item in time_nodes if "id" in item]

        return ", ".join(f"'{uid}'" for uid in time_ids_list)

    async def get_triplets_batch(self, offset: int, limit: int) -> list[dict[str, Any]]:
        """
        Retrieve a batch of triplets (start_node, relationship, end_node) from the graph.

        Parameters:
        -----------
        - offset (int): Number of triplets to skip before returning results.
        - limit (int): Maximum number of triplets to return.

        Returns:
        --------
        - list[dict[str, Any]]: A list of triplets.
        """
        query = f"""
        MATCH (start_node:`{BASE_LABEL}`)-[relationship]->(end_node:`{BASE_LABEL}`)
        RETURN start_node, properties(relationship) AS relationship_properties, end_node
        SKIP $offset LIMIT $limit
        """
        results = await self.query(query, {"offset": offset, "limit": limit})

        return results