Compare commits
243 commits (ci_ollamat...main)
Commit SHAs in this range:

a505eef95b, 8cd3aab1ef, 8a96a351e2, 39613997d6, a5fc6165c1, 69fe35bdee, be738df88a, 01a39dff22,
53f96f3e29, b339529621, 6a5ba70ced, 7ee36f883b, 5b42b21af5, 1061258fde, 7019a91f7c, e0644285d4,
f1526a6660, d8d3844805, 23d55a45d4, 1724997683, eda9f26b2b, cc41ef853c, 4caac4b8f0, b5949580de,
986b93fee4, 31e491bc88, f2bc7ca992, dd9aad90cb, c7e94e296e, 109b889b68, fdbee66985, 077e0a4937,
40e1b9bfed, 4424ebf3c9, 96107b9f2f, b42ea166a8, 4936551a90, 405c925b70, 78028b819f, 67af8a7cb4,
622f8fa79e, 14d9540d1b, 6b86f423ff, 821852f2c3, 433170fe09, bad22ba26b, 69e36cc834, d5bd75c627,
c94225f505, a6bc27afaa, 14ff94f269, 116b6f1eeb, a225d7fc61, 127d9860df, ede884e0b0, a337f4e54c,
bce6094010, c48b274571, 3b8a607b5f, 7b3d997a06, 59f8d12fa3, e211e66275, 0f50c993ac, 248ba74592,
af8c5bedcc, 46ddd4fd12, 0a1ed79340, fe7e97be45, 88f61f9bdb, 001fbe699e, 2ca194c28f, d932ee4bd9,
d0b914acaa, 49f7c5188c, c04d255aca, 75fea8dcc8, 7a3138edf8, 40bbdd1ac7, 52b0029fbf, 67a4d40257,
2f572ae509, a66b2ceeca, 7deaa6e8e9, 0c97a400b0, 5d0586da28, d5bf5cf4e9, 9571641199, 4afde917a9,
386e0c8234, 7033b63cd4, 6d1f1d183a, ed0c0d1823, a50810240b, c7a2138966, 7d7f8a249a, f1c5b9a55f,
fd84edeb74, 8cad9ef225, 45f32f8bfd, 1961efcc33, f4078d1247, 5698c609f5, 0d2e84f58e, 3288ef01a4,
d4d190ac2b, 4e8b2ffc3e, c7810e9fdb, 1282905888, 2d45db9e0d, 92448767fe, 702cdb45be, dbcb35a6da,
0ff836b6dd, 5fe6a17cfd, 5ee5ae294a, d473ef12ae, 8e67471d1e, c17f838034, 2df84dba27, 5cfc7b1761,
e480acaa7c, ba9ca46574, 362aa8df5c, 2e493cea4c, 524e3e8232, 7c9a78abea, 859e98b494, 76d054b6a5,
0bb4ece4d8, 5ce1af8cc0, d81d63390f, a0c5867977, 7e0be8f167, 7844b9a3a5, ed9b774448, 0c825b96ff,
00b60aed6c, 45841824d0, ddf802ff54, aa8afefe8a, c649900042, c1857a50fa, f776f04ee0, 0fd939ca2b,
8b61e1baa2, 1ff6a72fc7, 700362a233, ca271c5dbb, 02b9fa485c, 0fe16939c1, 2f06c3a97e, 5a2a5f64d2,
7c5a17ecb5, 39c6eba571, cf9edf2663, 69777ef0a5, 0f8cec64d5, ff20f021cc, e46c0c4f6c, 5f3b776406,
2e02aafbae, 593f17fcdc, 9e652a3a93, 4c6bed885e, f22330a7b6, e0d48c043a, 64a3ee96c4, d97acba78e,
f732fbf55f, 3b78eb88bd, d99a7fffdd, a7d0132e16, a472db5db5, dd87acc004, efb9739ac6, 83e0876066,
0b20f13117, 0800810713, 68d81a9125, 3fe354e34e, a8950b244f, f996ecffcc, 7360729db1, 0a4b1068a2,
4fcbc96e85, 53532921cf, a5aa58a7d7, 900dd38625, 3acb581bd0, d566541a4b, 6bb642d6b8, 0176cd5a68,
b017fcc8d0, a0a14e7ccc, a06c58a8ae, 432d4a1578, a8706a2fd2, 6a64023876, ac6c3ef9de, cfc131307f,
4f5771230e, 41b844a31c, 20d49eeb76, bb8de7b336, 011a7fb60b, ed2d687135, 322c134e7e, 487635b71b,
6192d49dd6, 44f0498b75, 7557e6f10b, 007c7d403e, 21b1f6b39c, 05c984f98f, c069dd276e, 26c18e9db2,
a6db2811d5, 41b973a61e, 18e4bb48fd, c481b87d58, 72c20c256d, 1e531385a6, b1fdb71b1e, e69ff7c855,
4e03406cb6, 0427586211, 3472f62a84, a068f3536b, 79f5201d6a, 7ee6cc8eb8, 4b0b9bfc53, 28f28f06dd,
ce925615fe, 908d329127, 70f3ced15a, c3f0cb95da, 9c9395851c, bbcd8baf3a, 813ee94836, c91d1ff0ae,
06a3458982, 3821966d89, 2a46208569
171 changed files with 17553 additions and 13283 deletions
.coderabbit.yaml (new file, +127 lines)

```yaml
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
# .coderabbit.yaml
language: en
early_access: false
enable_free_tier: true
reviews:
  profile: chill
  instructions: >-
    # Code Review Instructions

    - Ensure the code follows best practices and coding standards.

    - For **Python** code, follow
    [PEP 20](https://www.python.org/dev/peps/pep-0020/) and
    [CEP-8](https://gist.github.com/reactive-firewall/b7ee98df9e636a51806e62ef9c4ab161)
    standards.

    # Documentation Review Instructions

    - Verify that documentation and comments are clear and comprehensive.

    - Verify that documentation and comments are free of spelling mistakes.

    # Test Code Review Instructions

    - Ensure that test code is automated, comprehensive, and follows testing best practices.

    - Verify that all critical functionality is covered by tests.

    - Ensure that test code follows
    [CEP-8](https://gist.github.com/reactive-firewall/d840ee9990e65f302ce2a8d78ebe73f6)

    # Misc.

    - Confirm that the code meets the project's requirements and objectives.

    - Confirm that copyright years are up to date whenever a file is changed.
  request_changes_workflow: false
  high_level_summary: true
  high_level_summary_placeholder: '@coderabbitai summary'
  auto_title_placeholder: '@coderabbitai'
  review_status: true
  poem: false
  collapse_walkthrough: false
  sequence_diagrams: false
  changed_files_summary: true
  path_filters: ['!*.xc*/**', '!node_modules/**', '!dist/**', '!build/**', '!.git/**', '!venv/**', '!__pycache__/**']
  path_instructions:
    - path: README.md
      instructions: >-
        1. Consider the file 'README.md' the overview/introduction of the project.
        Also consider the 'README.md' file the first place to look for project documentation.

        2. When reviewing the file 'README.md' it should be linted with help
        from the tools `markdownlint` and `languagetool`, pointing out any issues.

        3. You may assume the file 'README.md' will contain GitHub flavor Markdown.
    - path: '**/*.py'
      instructions: >-
        When reviewing Python code for this project:

        1. Prioritize portability over clarity, especially when dealing with cross-Python compatibility. However, with that priority in mind, do still consider improvements to clarity when relevant.

        2. As a general guideline, consider the code style advocated in the PEP 8 standard (excluding the use of spaces for indentation) and evaluate suggested changes for code style compliance.

        3. As a style convention, consider the code style advocated in [CEP-8](https://gist.github.com/reactive-firewall/b7ee98df9e636a51806e62ef9c4ab161) and evaluate suggested changes for code style compliance.

        4. As a general guideline, try to provide any relevant, official, and supporting documentation links for any tool's suggestions in review comments. This guideline is important for posterity.

        5. As a general rule, undocumented function definitions and class definitions in the project's Python code are assumed incomplete. Please consider suggesting a short summary of the code for any of these incomplete definitions as docstrings when reviewing.
    - path: cognee/tests/*
      instructions: >-
        When reviewing test code:

        1. Prioritize portability over clarity, especially when dealing with cross-Python compatibility. However, with that priority in mind, do still consider improvements to clarity when relevant.

        2. As a general guideline, consider the code style advocated in the PEP 8 standard (excluding the use of spaces for indentation) and evaluate suggested changes for code style compliance.

        3. As a style convention, consider the code style advocated in [CEP-8](https://gist.github.com/reactive-firewall/b7ee98df9e636a51806e62ef9c4ab161) and evaluate suggested changes for code style compliance, pointing out any violations discovered.

        4. As a general guideline, try to provide any relevant, official, and supporting documentation links for any tool's suggestions in review comments. This guideline is important for posterity.

        5. As a project rule, Python source files with names prefixed by the string "test_" and located in the project's "tests" directory are the project's unit-testing code. It is safe, albeit a heuristic, to assume these are considered part of the project's minimal acceptance testing unless a justifying exception to this assumption is documented.

        6. As a project rule, any files without extensions and with names prefixed by either the string "check_" or the string "test_", and located in the project's "tests" directory, are the project's non-unit test code. "Non-unit test" in this context refers to any type of testing other than unit testing, such as (but not limited to) functional testing, style linting, regression testing, etc. It can also be assumed that non-unit testing code is usually written as Bash shell scripts.
    - path: requirements.txt
      instructions: >-
        * The project's own Python dependencies are recorded in 'requirements.txt' for production code.

        * The project's testing-specific Python dependencies are recorded in 'tests/requirements.txt' and are used for testing the project.

        * The project's documentation-specific Python dependencies are recorded in 'docs/requirements.txt' and are used only for generating Python-focused documentation for the project. 'docs/requirements.txt' may be absent if not applicable.

        Consider these 'requirements.txt' files the records of truth regarding project dependencies.
    - path: .github/**
      instructions: >-
        * When the project is hosted on GitHub: all GitHub-specific configurations, templates, and tools should be found in the '.github' directory tree.

        * 'actionlint' erroneously generates false positives when dealing with GitHub's `${{ ... }}` syntax in conditionals.

        * 'actionlint' erroneously generates incorrect solutions when suggesting the removal of valid `${{ ... }}` syntax.
  abort_on_close: true
  auto_review:
    enabled: true
    auto_incremental_review: true
    ignore_title_keywords: []
    labels: []
    drafts: false
    base_branches:
      - dev
      - main
  tools:
    shellcheck:
      enabled: true
    ruff:
      enabled: true
      configuration:
        extend_select:
          - E  # Pycodestyle errors (style issues)
          - F  # PyFlakes codes (logical errors)
          - W  # Pycodestyle warnings
          - N  # PEP 8 naming conventions
        ignore:
          - W191
          - W391
          - E117
          - D208
        line_length: 100
        dummy_variable_rgx: '^(_.*|junk|extra)$'  # Variables starting with '_' or named 'junk' or 'extra' are considered dummy variables
    markdownlint:
      enabled: true
    yamllint:
      enabled: true
chat:
  auto_reply: true
```
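The `dummy_variable_rgx` entry in the ruff configuration above decides which assigned-but-unused names ruff tolerates. As a minimal sketch (an illustration only, not part of the diff), this is how that pattern classifies a few names:

```python
import re

# Pattern copied from the ruff configuration in .coderabbit.yaml above.
DUMMY_VARIABLE_RGX = re.compile(r"^(_.*|junk|extra)$")

# Names starting with "_", or exactly "junk" / "extra", count as dummy
# variables; any other unused name would still be flagged.
for name in ["_", "_tmp", "junk", "extra", "extras", "result"]:
    print(f"{name!r} -> dummy: {bool(DUMMY_VARIABLE_RGX.match(name))}")
```

Note that `extras` does not match: the `^...$` anchors require an exact match against one of the alternatives.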
Environment template (file header not shown):

```diff
@@ -97,6 +97,8 @@ DB_NAME=cognee_db
 
 # Default (local file-based)
 GRAPH_DATABASE_PROVIDER="kuzu"
+# Handler for multi-user access control mode; it controls how the mapping/creation of separate DBs is handled per Cognee dataset
+GRAPH_DATASET_DATABASE_HANDLER="kuzu"
 
 # -- To switch to Remote Kuzu uncomment and fill these: -------------------------------------------------------------
 #GRAPH_DATABASE_PROVIDER="kuzu"
@@ -121,6 +123,8 @@ VECTOR_DB_PROVIDER="lancedb"
 # Not needed if a cloud vector database is not used
 VECTOR_DB_URL=
 VECTOR_DB_KEY=
+# Handler for multi-user access control mode; it controls how the mapping/creation of separate DBs is handled per Cognee dataset
+VECTOR_DATASET_DATABASE_HANDLER="lancedb"
 
 ################################################################################
 # 🧩 Ontology resolver settings
```
.github/actions/cognee_setup/action.yml (5 changes)

```diff
@@ -10,6 +10,10 @@ inputs:
     description: "Additional extra dependencies to install (space-separated)"
     required: false
     default: ""
+  rebuild-lockfile:
+    description: "Whether to rebuild the uv lockfile"
+    required: false
+    default: "false"
 
 runs:
   using: "composite"
@@ -26,6 +30,7 @@ runs:
         enable-cache: true
 
     - name: Rebuild uv lockfile
+      if: ${{ inputs.rebuild-lockfile == 'true' }}
      shell: bash
      run: |
        rm uv.lock
```
.github/release-drafter.yml (new file, +20 lines)

```yaml
name-template: 'v$NEXT_PATCH_VERSION'
tag-template: 'v$NEXT_PATCH_VERSION'

categories:
  - title: 'Features'
    labels: ['feature', 'enhancement']
  - title: 'Bug Fixes'
    labels: ['bug', 'fix']
  - title: 'Maintenance'
    labels: ['chore', 'refactor', 'ci']

change-template: '- $TITLE (#$NUMBER) @$AUTHOR'
template: |
  ## What’s Changed

  $CHANGES

  ## Contributors

  $CONTRIBUTORS
```
.github/workflows/basic_tests.yml (30 changes)

```diff
@@ -197,33 +197,3 @@ jobs:
 
     - name: Run Simple Examples
       run: uv run python ./examples/python/simple_example.py
-
-  graph-tests:
-    name: Run Basic Graph Tests
-    runs-on: ubuntu-22.04
-    env:
-      ENV: 'dev'
-      LLM_PROVIDER: openai
-      LLM_MODEL: ${{ secrets.LLM_MODEL }}
-      LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
-      LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
-      LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
-
-      EMBEDDING_PROVIDER: openai
-      EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
-      EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
-      EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
-      EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
-    steps:
-      - name: Check out repository
-        uses: actions/checkout@v4
-        with:
-          fetch-depth: 0
-
-      - name: Cognee Setup
-        uses: ./.github/actions/cognee_setup
-        with:
-          python-version: ${{ inputs.python-version }}
-
-      - name: Run Graph Tests
-        run: uv run python ./examples/python/code_graph_example.py --repo_path ./cognee/tasks/graph
```
.github/workflows/db_examples_tests.yml (2 changes)

```diff
@@ -61,6 +61,7 @@ jobs:
       - name: Run Neo4j Example
         env:
           ENV: 'dev'
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.LLM_MODEL }}
           LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
@@ -142,6 +143,7 @@ jobs:
       - name: Run PGVector Example
         env:
           ENV: 'dev'
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.LLM_MODEL }}
           LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
```
.github/workflows/distributed_test.yml (1 change)

```diff
@@ -47,6 +47,7 @@ jobs:
       - name: Run Distributed Cognee (Modal)
         env:
           ENV: 'dev'
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.LLM_MODEL }}
           LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
```
.github/workflows/e2e_tests.yml (107 changes)

```diff
@@ -147,6 +147,7 @@ jobs:
       - name: Run Deduplication Example
         env:
           ENV: 'dev'
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} # Test needs OpenAI endpoint to handle multimedia
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
           EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
@@ -211,6 +212,56 @@ jobs:
           EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
         run: uv run python ./cognee/tests/test_parallel_databases.py
 
+  test-dataset-database-handler:
+    name: Test dataset database handlers in Cognee
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Check out repository
+        uses: actions/checkout@v4
+
+      - name: Cognee Setup
+        uses: ./.github/actions/cognee_setup
+        with:
+          python-version: '3.11.x'
+
+      - name: Run dataset databases handler test
+        env:
+          ENV: 'dev'
+          LLM_MODEL: ${{ secrets.LLM_MODEL }}
+          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
+        run: uv run python ./cognee/tests/test_dataset_database_handler.py
+
+  test-dataset-database-deletion:
+    name: Test dataset database deletion in Cognee
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Check out repository
+        uses: actions/checkout@v4
+
+      - name: Cognee Setup
+        uses: ./.github/actions/cognee_setup
+        with:
+          python-version: '3.11.x'
+
+      - name: Run dataset databases deletion test
+        env:
+          ENV: 'dev'
+          LLM_MODEL: ${{ secrets.LLM_MODEL }}
+          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
+        run: uv run python ./cognee/tests/test_dataset_delete.py
+
   test-permissions:
     name: Test permissions with different situations in Cognee
     runs-on: ubuntu-22.04
@@ -412,6 +463,35 @@ jobs:
           EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
         run: uv run python ./cognee/tests/test_feedback_enrichment.py
 
+  test-edge-centered-payload:
+    name: Test Cognify - Edge Centered Payload
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Check out repository
+        uses: actions/checkout@v4
+
+      - name: Cognee Setup
+        uses: ./.github/actions/cognee_setup
+        with:
+          python-version: '3.11.x'
+
+      - name: Dependencies already installed
+        run: echo "Dependencies already installed in setup"
+
+      - name: Run Edge Centered Payload Test
+        env:
+          ENV: 'dev'
+          TRIPLET_EMBEDDING: True
+          LLM_MODEL: ${{ secrets.LLM_MODEL }}
+          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
+        run: uv run python ./cognee/tests/test_edge_centered_payload.py
+
   run_conversation_sessions_test_redis:
     name: Conversation sessions test (Redis)
     runs-on: ubuntu-latest
@@ -527,3 +607,30 @@ jobs:
           DB_USERNAME: cognee
           DB_PASSWORD: cognee
         run: uv run python ./cognee/tests/test_conversation_history.py
+
+  run-pipeline-cache-test:
+    name: Test Pipeline Caching
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Check out
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Cognee Setup
+        uses: ./.github/actions/cognee_setup
+        with:
+          python-version: '3.11.x'
+
+      - name: Run Pipeline Cache Test
+        env:
+          ENV: 'dev'
+          LLM_MODEL: ${{ secrets.LLM_MODEL }}
+          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
+        run: uv run python ./cognee/tests/test_pipeline_cache.py
```
.github/workflows/examples_tests.yml (1 change)

```diff
@@ -72,6 +72,7 @@ jobs:
       - name: Run Descriptive Graph Metrics Example
        env:
           ENV: 'dev'
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.LLM_MODEL }}
           LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
```
.github/workflows/graph_db_tests.yml (1 change)

```diff
@@ -78,6 +78,7 @@ jobs:
       - name: Run default Neo4j
         env:
           ENV: 'dev'
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.LLM_MODEL }}
           LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
```
.github/workflows/pre_test.yml (new file, +22 lines)

```yaml
on:
  workflow_call:

permissions:
  contents: read
jobs:
  check-uv-lock:
    name: Validate uv lockfile and project metadata
    runs-on: ubuntu-22.04
    steps:
      - name: Check out repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Install uv
        uses: astral-sh/setup-uv@v4
        with:
          enable-cache: true

      - name: Validate uv lockfile and project metadata
        run: uv lock --check || { echo "'uv lock --check' failed."; echo "Run 'uv lock' and push your changes."; exit 1; }
```
.github/workflows/release.yml (new file, +138 lines)

```yaml
name: release.yml
on:
  workflow_dispatch:
    inputs:
      flavour:
        required: true
        default: dev
        type: choice
        options:
          - dev
          - main
        description: Dev or Main release

jobs:
  release-github:
    name: Create GitHub Release from ${{ inputs.flavour }}
    outputs:
      tag: ${{ steps.create_tag.outputs.tag }}
      version: ${{ steps.create_tag.outputs.version }}
    permissions:
      contents: write
    runs-on: ubuntu-latest

    steps:
      - name: Check out ${{ inputs.flavour }}
        uses: actions/checkout@v4
        with:
          ref: ${{ inputs.flavour }}
      - name: Install uv
        uses: astral-sh/setup-uv@v7

      - name: Create and push git tag
        id: create_tag
        run: |
          VERSION="$(uv version --short)"
          TAG="v${VERSION}"

          echo "Tag to create: ${TAG}"

          git config user.name "github-actions[bot]"
          git config user.email "41898282+github-actions[bot]@users.noreply.github.com"

          echo "tag=${TAG}" >> "$GITHUB_OUTPUT"
          echo "version=${VERSION}" >> "$GITHUB_OUTPUT"

          git tag "${TAG}"
          git push origin "${TAG}"

      - name: Create GitHub Release
        uses: softprops/action-gh-release@v2
        with:
          tag_name: ${{ steps.create_tag.outputs.tag }}
          generate_release_notes: true
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

  release-pypi-package:
    needs: release-github
    name: Release PyPI Package from ${{ inputs.flavour }}
    permissions:
      contents: read
    runs-on: ubuntu-latest

    steps:
      - name: Check out ${{ inputs.flavour }}
        uses: actions/checkout@v4
        with:
          ref: ${{ inputs.flavour }}

      - name: Install uv
        uses: astral-sh/setup-uv@v7

      - name: Install Python
        run: uv python install

      - name: Install dependencies
        run: uv sync --locked --all-extras

      - name: Build distributions
        run: uv build

      - name: Publish ${{ inputs.flavour }} release to PyPI
        env:
          UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
        run: uv publish

  release-docker-image:
    needs: release-github
    name: Release Docker Image from ${{ inputs.flavour }}
    permissions:
      contents: read
    runs-on: ubuntu-latest

    steps:
      - name: Check out ${{ inputs.flavour }}
        uses: actions/checkout@v4
        with:
          ref: ${{ inputs.flavour }}

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_PASSWORD }}

      - name: Build and push Dev Docker Image
        if: ${{ inputs.flavour == 'dev' }}
        uses: docker/build-push-action@v5
        with:
          context: .
          platforms: linux/amd64,linux/arm64
          push: true
          tags: cognee/cognee:${{ needs.release-github.outputs.version }}
          labels: |
            version=${{ needs.release-github.outputs.version }}
            flavour=${{ inputs.flavour }}
          cache-from: type=registry,ref=cognee/cognee:buildcache
          cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max

      - name: Build and push Main Docker Image
        if: ${{ inputs.flavour == 'main' }}
        uses: docker/build-push-action@v5
        with:
          context: .
          platforms: linux/amd64,linux/arm64
          push: true
          tags: |
            cognee/cognee:${{ needs.release-github.outputs.version }}
            cognee/cognee:latest
          labels: |
            version=${{ needs.release-github.outputs.version }}
            flavour=${{ inputs.flavour }}
          cache-from: type=registry,ref=cognee/cognee:buildcache
          cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max
```
.github/workflows/temporal_graph_tests.yml (3 changes)

```diff
@@ -72,6 +72,7 @@ jobs:
       - name: Run Temporal Graph with Neo4j (lancedb + sqlite)
         env:
           ENV: 'dev'
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
           LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -123,6 +124,7 @@ jobs:
       - name: Run Temporal Graph with Kuzu (postgres + pgvector)
         env:
           ENV: dev
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
           LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -189,6 +191,7 @@ jobs:
       - name: Run Temporal Graph with Neo4j (postgres + pgvector)
         env:
           ENV: dev
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
           LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
```
.github/workflows/test_llms.yml (90 changes)

```diff
@@ -84,3 +84,93 @@ jobs:
           EMBEDDING_DIMENSIONS: "3072"
           EMBEDDING_MAX_TOKENS: "8191"
         run: uv run python ./examples/python/simple_example.py
+
+  test-bedrock-api-key:
+    name: Run Bedrock API Key Test
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Check out repository
+        uses: actions/checkout@v4
+
+      - name: Cognee Setup
+        uses: ./.github/actions/cognee_setup
+        with:
+          python-version: '3.11.x'
+          extra-dependencies: "aws"
+
+      - name: Run Bedrock API Key Simple Example
+        env:
+          LLM_PROVIDER: "bedrock"
+          LLM_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
+          LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
+          LLM_MAX_TOKENS: "16384"
+          AWS_REGION_NAME: "eu-west-1"
+          EMBEDDING_PROVIDER: "bedrock"
+          EMBEDDING_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
+          EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
+          EMBEDDING_DIMENSIONS: "1024"
+          EMBEDDING_MAX_TOKENS: "8191"
+        run: uv run python ./examples/python/simple_example.py
+
+  test-bedrock-aws-credentials:
+    name: Run Bedrock AWS Credentials Test
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Check out repository
+        uses: actions/checkout@v4
+
+      - name: Cognee Setup
+        uses: ./.github/actions/cognee_setup
+        with:
+          python-version: '3.11.x'
+          extra-dependencies: "aws"
+
+      - name: Run Bedrock AWS Credentials Simple Example
+        env:
+          LLM_PROVIDER: "bedrock"
+          LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
+          LLM_MAX_TOKENS: "16384"
+          AWS_REGION_NAME: "eu-west-1"
+          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          EMBEDDING_PROVIDER: "bedrock"
+          EMBEDDING_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
+          EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
+          EMBEDDING_DIMENSIONS: "1024"
+          EMBEDDING_MAX_TOKENS: "8191"
+        run: uv run python ./examples/python/simple_example.py
+
+  test-bedrock-aws-profile:
+    name: Run Bedrock AWS Profile Test
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Check out repository
+        uses: actions/checkout@v4
+
+      - name: Cognee Setup
+        uses: ./.github/actions/cognee_setup
+        with:
+          python-version: '3.11.x'
+          extra-dependencies: "aws"
+
+      - name: Configure AWS Profile
+        run: |
+          mkdir -p ~/.aws
+          cat > ~/.aws/credentials << EOF
+          [bedrock-test]
+          aws_access_key_id = ${{ secrets.AWS_ACCESS_KEY_ID }}
+          aws_secret_access_key = ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          EOF
+
+      - name: Run Bedrock AWS Profile Simple Example
+        env:
+          LLM_PROVIDER: "bedrock"
+          LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
+          LLM_MAX_TOKENS: "16384"
+          AWS_PROFILE_NAME: "bedrock-test"
+          AWS_REGION_NAME: "eu-west-1"
+          EMBEDDING_PROVIDER: "bedrock"
+          EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
+          EMBEDDING_DIMENSIONS: "1024"
+          EMBEDDING_MAX_TOKENS: "8191"
+        run: uv run python ./examples/python/simple_example.py
```
.github/workflows/test_ollama.yml (17 changes)

```diff
@@ -7,13 +7,8 @@ jobs:
 
   run_ollama_test:
 
-    # needs 16 Gb RAM for phi4
-    runs-on: buildjet-4vcpu-ubuntu-2204
-    # services:
-    #   ollama:
-    #     image: ollama/ollama
-    #     ports:
-    #       - 11434:11434
+    # needs 32 Gb RAM for phi4 in a container
+    runs-on: buildjet-8vcpu-ubuntu-2204
 
     steps:
       - name: Checkout repository
@@ -28,14 +23,6 @@ jobs:
         run: |
          uv add torch
 
-      # - name: Install ollama
-      #   run: curl -fsSL https://ollama.com/install.sh | sh
-      # - name: Run ollama
-      #   run: |
-      #     ollama serve --openai &
-      #     ollama pull llama3.2 &
-      #     ollama pull avr/sfr-embedding-mistral:latest
-
       - name: Start Ollama container
         run: |
           docker run -d --name ollama -p 11434:11434 ollama/ollama
```
.github/workflows/test_suites.yml (8 changes)

```diff
@@ -18,15 +18,21 @@ env:
   RUNTIME__LOG_LEVEL: ERROR
   ENV: 'dev'
 
 jobs:
+  pre-test:
+    name: basic checks
+    uses: ./.github/workflows/pre_test.yml
+
   basic-tests:
     name: Basic Tests
     uses: ./.github/workflows/basic_tests.yml
+    needs: [ pre-test ]
     secrets: inherit
 
   e2e-tests:
     name: End-to-End Tests
     uses: ./.github/workflows/e2e_tests.yml
+    needs: [ pre-test ]
     secrets: inherit
 
   distributed-tests:
```
.github/workflows/vector_db_tests.yml (3 changes)

```diff
@@ -92,6 +92,7 @@ jobs:
       - name: Run PGVector Tests
         env:
           ENV: 'dev'
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.LLM_MODEL }}
           LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
@@ -127,4 +128,4 @@ jobs:
           EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
           EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
           EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
-        run: uv run python ./cognee/tests/test_lancedb.py
+        run: uv run python ./cognee/tests/test_lancedb.py
```
.github/workflows/weighted_edges_tests.yml (3 changes)

```diff
@@ -94,6 +94,7 @@ jobs:
 
       - name: Run Weighted Edges Tests
         env:
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }}
           GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }}
           GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }}
@@ -165,5 +166,3 @@ jobs:
         uses: astral-sh/ruff-action@v2
         with:
           args: "format --check cognee/modules/graph/utils/get_graph_from_model.py cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py"
-
-
```
.mergify.yml (new file, +9 lines)

```yaml
pull_request_rules:
  - name: Backport to main when backport_main label is set
    conditions:
      - label=backport_main
      - base=dev
    actions:
      backport:
        branches:
          - main
```
Contributing guide (file header not shown):

````diff
@@ -71,7 +71,7 @@ git clone https://github.com/<your-github-username>/cognee.git
 cd cognee
 ```
 In case you are working on Vector and Graph Adapters
-1. Fork the [**cognee**](https://github.com/topoteretes/cognee-community) repository
+1. Fork the [**cognee-community**](https://github.com/topoteretes/cognee-community) repository
 2. Clone your fork:
 ```shell
 git clone https://github.com/<your-github-username>/cognee-community.git
@@ -97,6 +97,21 @@ git checkout -b feature/your-feature-name
 python cognee/cognee/tests/test_library.py
 ```
 
+### Running a Simple Example
+
+Rename `.env.example` to `.env` and provide your `OPENAI_API_KEY` as `LLM_API_KEY`.
+
+Make sure to run `uv sync` in the root of the cloned folder, or set up a virtual environment, before running cognee.
+
+```shell
+python cognee/cognee/examples/python/simple_example.py
+```
+or
+
+```shell
+uv run python cognee/cognee/examples/python/simple_example.py
+```
+
 ## 4. 📤 Submitting Changes
 
 1. Install ruff on your system
````
|||
163
README.md
163
README.md
|
|
@ -5,27 +5,27 @@
|
|||
|
||||
<br />
|
||||
|
||||
cognee - Memory for AI Agents in 6 lines of code
|
||||
Cognee - Accurate and Persistent AI Memory
|
||||
|
||||
<p align="center">
|
||||
<a href="https://www.youtube.com/watch?v=1bezuvLwJmw&t=2s">Demo</a>
|
||||
.
|
||||
<a href="https://cognee.ai">Learn more</a>
|
||||
<a href="https://docs.cognee.ai/">Docs</a>
|
||||
.
|
||||
<a href="https://cognee.ai">Learn More</a>
|
||||
·
|
||||
<a href="https://discord.gg/NQPKmU5CCg">Join Discord</a>
|
||||
·
|
||||
<a href="https://www.reddit.com/r/AIMemory/">Join r/AIMemory</a>
|
||||
.
|
||||
<a href="https://docs.cognee.ai/">Docs</a>
|
||||
.
|
||||
<a href="https://github.com/topoteretes/cognee-community">cognee community repo</a>
|
||||
<a href="https://github.com/topoteretes/cognee-community">Community Plugins & Add-ons</a>
|
||||
</p>
|
||||
|
||||
|
||||
[](https://GitHub.com/topoteretes/cognee/network/)
|
||||
[](https://GitHub.com/topoteretes/cognee/stargazers/)
|
||||
[](https://GitHub.com/topoteretes/cognee/commit/)
|
||||
[](https://github.com/topoteretes/cognee/tags/)
|
||||
[](https://github.com/topoteretes/cognee/tags/)
|
||||
[](https://pepy.tech/project/cognee)
|
||||
[](https://github.com/topoteretes/cognee/blob/main/LICENSE)
|
||||
[](https://github.com/topoteretes/cognee/graphs/contributors)
|
||||
|
|
@ -41,11 +41,7 @@
|
|||
</a>
|
||||
</p>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Build dynamic memory for Agents and replace RAG using scalable, modular ECL (Extract, Cognify, Load) pipelines.
|
||||
Use your data to build personalized and dynamic memory for AI Agents. Cognee lets you replace RAG with scalable and modular ECL (Extract, Cognify, Load) pipelines.
|
||||
|
||||
<p align="center">
|
||||
🌐 Available Languages
|
||||
|
|
@ -53,7 +49,7 @@ Build dynamic memory for Agents and replace RAG using scalable, modular ECL (Ext
|
|||
<!-- Keep these links. Translations will automatically update with the README. -->
|
||||
<a href="https://www.readme-i18n.com/topoteretes/cognee?lang=de">Deutsch</a> |
|
||||
<a href="https://www.readme-i18n.com/topoteretes/cognee?lang=es">Español</a> |
|
||||
<a href="https://www.readme-i18n.com/topoteretes/cognee?lang=fr">français</a> |
|
||||
<a href="https://www.readme-i18n.com/topoteretes/cognee?lang=fr">Français</a> |
|
||||
<a href="https://www.readme-i18n.com/topoteretes/cognee?lang=ja">日本語</a> |
|
||||
<a href="https://www.readme-i18n.com/topoteretes/cognee?lang=ko">한국어</a> |
|
||||
<a href="https://www.readme-i18n.com/topoteretes/cognee?lang=pt">Português</a> |
|
||||
|
|
@ -67,73 +63,62 @@ Build dynamic memory for Agents and replace RAG using scalable, modular ECL (Ext
|
|||
</div>
|
||||
</div>
|
||||
|
||||
## About Cognee
|
||||
|
||||
Cognee is an open-source tool and platform that transforms your raw data into persistent and dynamic AI memory for Agents. It combines vector search with graph databases to make your documents both searchable by meaning and connected by relationships.
|
||||
Cognee offers default memory creation and search which we describe bellow. But with Cognee you can build your own!
|
||||
|
||||
|
||||
## Get Started
|
||||
### Cognee Open Source:
|
||||
|
||||
Get started quickly with a Google Colab <a href="https://colab.research.google.com/drive/12Vi9zID-M3fpKpKiaqDBvkk98ElkRPWy?usp=sharing">notebook</a> , <a href="https://deepnote.com/workspace/cognee-382213d0-0444-4c89-8265-13770e333c02/project/cognee-demo-78ffacb9-5832-4611-bb1a-560386068b30/notebook/Notebook-1-75b24cda566d4c24ab348f7150792601?utm_source=share-modal&utm_medium=product-shared-content&utm_campaign=notebook&utm_content=78ffacb9-5832-4611-bb1a-560386068b30">Deepnote notebook</a> or <a href="https://github.com/topoteretes/cognee/tree/main/cognee-starter-kit">starter repo</a>
|
||||
- Interconnects any type of data — including past conversations, files, images, and audio transcriptions
|
||||
- Replaces traditional RAG systems with a unified memory layer built on graphs and vectors
|
||||
- Reduces developer effort and infrastructure cost while improving quality and precision
|
||||
- Provides Pythonic data pipelines for ingestion from 30+ data sources
|
||||
- Offers high customizability through user-defined tasks, modular pipelines, and built-in search endpoints
|
||||
|
||||
|
||||
## About cognee
|
||||
## Basic Usage & Feature Guide
|
||||
|
||||
cognee works locally and stores your data on your device.
|
||||
Our hosted solution is just our deployment of OSS cognee on Modal, with the goal of making development and productionization easier.
|
||||
To learn more, [check out this short, end-to-end Colab walkthrough](https://colab.research.google.com/drive/12Vi9zID-M3fpKpKiaqDBvkk98ElkRPWy?usp=sharing) of Cognee's core features.
|
||||
|
||||
Self-hosted package:
|
||||
[](https://colab.research.google.com/drive/12Vi9zID-M3fpKpKiaqDBvkk98ElkRPWy?usp=sharing)
|
||||
|
||||
- Interconnects any kind of documents: past conversations, files, images, and audio transcriptions
|
||||
- Replaces RAG systems with a memory layer based on graphs and vectors
|
||||
- Reduces developer effort and cost, while increasing quality and precision
|
||||
- Provides Pythonic data pipelines that manage data ingestion from 30+ data sources
|
||||
- Is highly customizable with custom tasks, pipelines, and a set of built-in search endpoints
|
||||
## Quickstart
|
||||
|
||||
Hosted platform:
|
||||
- Includes a managed UI and a [hosted solution](https://www.cognee.ai)
|
||||
Let’s try Cognee in just a few lines of code. For detailed setup and configuration, see the [Cognee Docs](https://docs.cognee.ai/getting-started/installation#environment-configuration).
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.10 to 3.13
|
||||
|
||||
## Self-Hosted (Open Source)
|
||||
### Step 1: Install Cognee
|
||||
|
||||
|
||||
### 📦 Installation
|
||||
|
||||
You can install Cognee using either **pip**, **poetry**, **uv** or any other python package manager..
|
||||
|
||||
Cognee supports Python 3.10 to 3.12
|
||||
|
||||
#### With uv
|
||||
You can install Cognee with **pip**, **poetry**, **uv**, or your preferred Python package manager.
|
||||
|
||||
```bash
|
||||
uv pip install cognee
|
||||
```
|
||||
|
||||
Detailed instructions can be found in our [docs](https://docs.cognee.ai/getting-started/installation#environment-configuration)
|
||||
|
||||
### 💻 Basic Usage
|
||||
|
||||
#### Setup
|
||||
|
||||
```
|
||||
### Step 2: Configure the LLM
|
||||
```python
|
||||
import os
|
||||
os.environ["LLM_API_KEY"] = "YOUR OPENAI_API_KEY"
|
||||
|
||||
```
|
||||
Alternatively, create a `.env` file using our [template](https://github.com/topoteretes/cognee/blob/main/.env.template).
|
||||
|
||||
You can also set the variables by creating .env file, using our <a href="https://github.com/topoteretes/cognee/blob/main/.env.template">template.</a>
|
||||
To use different LLM providers, for more info check out our <a href="https://docs.cognee.ai/setup-configuration/llm-providers">documentation</a>
|
||||
To integrate other LLM providers, see our [LLM Provider Documentation](https://docs.cognee.ai/setup-configuration/llm-providers).
|
||||
|
||||
### Step 3: Run the Pipeline
|
||||
|
||||
#### Simple example
|
||||
Cognee will take your documents, generate a knowledge graph from them and then query the graph based on combined relationships.
|
||||
|
||||
|
||||
|
||||
##### Python
|
||||
|
||||
This script will run the default pipeline:
|
||||
Now, run a minimal pipeline:
|
||||
|
||||
```python
|
||||
import cognee
|
||||
import asyncio
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
async def main():
|
||||
|
|
@ -147,80 +132,72 @@ async def main():
|
|||
await cognee.memify()
|
||||
|
||||
# Query the knowledge graph
|
||||
results = await cognee.search("What does cognee do?")
|
||||
results = await cognee.search("What does Cognee do?")
|
||||
|
||||
# Display the results
|
||||
for result in results:
|
||||
print(result)
|
||||
pprint(result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
|
||||
```
|
||||
Example output:
|
||||
```
|
||||
|
||||
As you can see, the output is generated from the document we previously stored in Cognee:
|
||||
|
||||
```bash
|
||||
Cognee turns documents into AI memory.
|
||||
|
||||
```
|
||||
##### Via CLI
|
||||
|
||||
Let's get the basics covered
|
||||
### Use the Cognee CLI
|
||||
|
||||
```
|
||||
As an alternative, you can get started with these essential commands:
|
||||
|
||||
```bash
|
||||
cognee-cli add "Cognee turns documents into AI memory."
|
||||
|
||||
cognee-cli cognify
|
||||
|
||||
cognee-cli search "What does cognee do?"
|
||||
cognee-cli search "What does Cognee do?"
|
||||
cognee-cli delete --all
|
||||
|
||||
```
|
||||
or run
|
||||
```
|
||||
|
||||
To open the local UI, run:
|
||||
```bash
|
||||
cognee-cli -ui
|
||||
```
|
||||
|
||||
## Demos & Examples
|
||||
|
||||
</div>
|
||||
See Cognee in action:
|
||||
|
||||
### Persistent Agent Memory
|
||||
|
||||
[Cognee Memory for LangGraph Agents](https://github.com/user-attachments/assets/e113b628-7212-4a2b-b288-0be39a93a1c3)
|
||||
|
||||
### Simple GraphRAG
|
||||
|
||||
[Watch Demo](https://github.com/user-attachments/assets/f2186b2e-305a-42b0-9c2d-9f4473f15df8)
|
||||
|
||||
### Cognee with Ollama
|
||||
|
||||
[Watch Demo](https://github.com/user-attachments/assets/39672858-f774-4136-b957-1e2de67b8981)
|
||||
|
||||
|
||||
### Hosted Platform
|
||||
## Community & Support
|
||||
|
||||
Get up and running in minutes with automatic updates, analytics, and enterprise security.
|
||||
### Contributing
|
||||
We welcome contributions from the community! Your input helps make Cognee better for everyone. See [`CONTRIBUTING.md`](CONTRIBUTING.md) to get started.
|
||||
|
||||
1. Sign up on [cogwit](https://www.cognee.ai)
|
||||
2. Add your API key to local UI and sync your data to Cogwit
|
||||
### Code of Conduct
|
||||
|
||||
We're committed to fostering an inclusive and respectful community. Read our [Code of Conduct](https://github.com/topoteretes/cognee/blob/main/CODE_OF_CONDUCT.md) for guidelines.
|
||||
|
||||
## Research & Citation
|
||||
|
||||
|
||||
## Demos
|
||||
|
||||
1. Cogwit Beta demo:
|
||||
|
||||
[Cogwit Beta](https://github.com/user-attachments/assets/fa520cd2-2913-4246-a444-902ea5242cb0)
|
||||
|
||||
2. Simple GraphRAG demo
|
||||
|
||||
[Simple GraphRAG demo](https://github.com/user-attachments/assets/d80b0776-4eb9-4b8e-aa22-3691e2d44b8f)
|
||||
|
||||
3. cognee with Ollama
|
||||
|
||||
[cognee with local models](https://github.com/user-attachments/assets/8621d3e8-ecb8-4860-afb2-5594f2ee17db)
|
||||
|
||||
|
||||
## Contributing
|
||||
Your contributions are at the core of making this a true open source project. Any contributions you make are **greatly appreciated**. See [`CONTRIBUTING.md`](CONTRIBUTING.md) for more information.
|
||||
|
||||
|
||||
## Code of Conduct
|
||||
|
||||
We are committed to making open source an enjoyable and respectful experience for our community. See <a href="https://github.com/topoteretes/cognee/blob/main/CODE_OF_CONDUCT.md"><code>CODE_OF_CONDUCT</code></a> for more information.
|
||||
|
||||
## Citation
|
||||
|
||||
We now have a paper you can cite:
|
||||
We recently published a research paper on optimizing knowledge graphs for LLM reasoning:
|
||||
|
||||
```bibtex
|
||||
@misc{markovic2025optimizinginterfaceknowledgegraphs,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,333 @@
|
|||
"""Expand dataset database with json connection field
|
||||
|
||||
Revision ID: 46a6ce2bd2b2
|
||||
Revises: 76625596c5c3
|
||||
Create Date: 2025-11-25 17:56:28.938931
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "46a6ce2bd2b2"
|
||||
down_revision: Union[str, None] = "76625596c5c3"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
graph_constraint_name = "dataset_database_graph_database_name_key"
|
||||
vector_constraint_name = "dataset_database_vector_database_name_key"
|
||||
TABLE_NAME = "dataset_database"
|
||||
|
||||
|
||||
def _get_column(inspector, table, name, schema=None):
|
||||
for col in inspector.get_columns(table, schema=schema):
|
||||
if col["name"] == name:
|
||||
return col
|
||||
return None
|
||||
|
||||
|
||||
def _recreate_table_without_unique_constraint_sqlite(op, insp):
|
||||
"""
|
||||
SQLite cannot drop unique constraints on individual columns. We must:
|
||||
1. Create a new table without the unique constraints.
|
||||
2. Copy data from the old table.
|
||||
3. Drop the old table.
|
||||
4. Rename the new table.
|
||||
"""
|
||||
conn = op.get_bind()
|
||||
|
||||
# Create new table definition (without unique constraints)
|
||||
op.create_table(
|
||||
f"{TABLE_NAME}_new",
|
||||
sa.Column("owner_id", sa.UUID()),
|
||||
sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False),
|
||||
sa.Column("vector_database_name", sa.String(), nullable=False),
|
||||
sa.Column("graph_database_name", sa.String(), nullable=False),
|
||||
sa.Column("vector_database_provider", sa.String(), nullable=False),
|
||||
sa.Column("graph_database_provider", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"vector_dataset_database_handler",
|
||||
sa.String(),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
server_default="lancedb",
|
||||
),
|
||||
sa.Column(
|
||||
"graph_dataset_database_handler",
|
||||
sa.String(),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
server_default="kuzu",
|
||||
),
|
||||
sa.Column("vector_database_url", sa.String()),
|
||||
sa.Column("graph_database_url", sa.String()),
|
||||
sa.Column("vector_database_key", sa.String()),
|
||||
sa.Column("graph_database_key", sa.String()),
|
||||
sa.Column(
|
||||
"graph_database_connection_info",
|
||||
sa.JSON(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'{}'"),
|
||||
),
|
||||
sa.Column(
|
||||
"vector_database_connection_info",
|
||||
sa.JSON(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'{}'"),
|
||||
),
|
||||
sa.Column("created_at", sa.DateTime()),
|
||||
sa.Column("updated_at", sa.DateTime()),
|
||||
sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"),
|
||||
)
|
||||
|
||||
# Copy data into new table
|
||||
conn.execute(
|
||||
sa.text(f"""
|
||||
INSERT INTO {TABLE_NAME}_new
|
||||
SELECT
|
||||
owner_id,
|
||||
dataset_id,
|
||||
vector_database_name,
|
||||
graph_database_name,
|
||||
vector_database_provider,
|
||||
graph_database_provider,
|
||||
vector_dataset_database_handler,
|
||||
graph_dataset_database_handler,
|
||||
vector_database_url,
|
||||
graph_database_url,
|
||||
vector_database_key,
|
||||
graph_database_key,
|
||||
COALESCE(graph_database_connection_info, '{{}}'),
|
||||
COALESCE(vector_database_connection_info, '{{}}'),
|
||||
created_at,
|
||||
updated_at
|
||||
FROM {TABLE_NAME}
|
||||
""")
|
||||
)
|
||||
|
||||
# Drop old table
|
||||
op.drop_table(TABLE_NAME)
|
||||
|
||||
# Rename new table
|
||||
op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME)
|
||||
|
||||
|
||||
def _recreate_table_with_unique_constraint_sqlite(op, insp):
|
||||
"""
|
||||
SQLite cannot drop unique constraints on individual columns. We must:
|
||||
1. Create a new table without the unique constraints.
|
||||
2. Copy data from the old table.
|
||||
3. Drop the old table.
|
||||
4. Rename the new table.
|
||||
"""
|
||||
conn = op.get_bind()
|
||||
|
||||
# Create new table definition (without unique constraints)
|
||||
op.create_table(
|
||||
f"{TABLE_NAME}_new",
|
||||
sa.Column("owner_id", sa.UUID()),
|
||||
sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False),
|
||||
sa.Column("vector_database_name", sa.String(), nullable=False, unique=True),
|
||||
sa.Column("graph_database_name", sa.String(), nullable=False, unique=True),
|
||||
sa.Column("vector_database_provider", sa.String(), nullable=False),
|
||||
sa.Column("graph_database_provider", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"vector_dataset_database_handler",
|
||||
sa.String(),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
server_default="lancedb",
|
||||
),
|
||||
sa.Column(
|
||||
"graph_dataset_database_handler",
|
||||
sa.String(),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
server_default="kuzu",
|
||||
),
|
||||
sa.Column("vector_database_url", sa.String()),
|
||||
sa.Column("graph_database_url", sa.String()),
|
||||
sa.Column("vector_database_key", sa.String()),
|
||||
sa.Column("graph_database_key", sa.String()),
|
||||
sa.Column(
|
||||
"graph_database_connection_info",
|
||||
sa.JSON(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'{}'"),
|
||||
),
|
||||
sa.Column(
|
||||
"vector_database_connection_info",
|
||||
sa.JSON(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'{}'"),
|
||||
),
|
||||
sa.Column("created_at", sa.DateTime()),
|
||||
sa.Column("updated_at", sa.DateTime()),
|
||||
sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"),
|
||||
)
    # Copy data into new table
    conn.execute(
        sa.text(f"""
        INSERT INTO {TABLE_NAME}_new
        SELECT
            owner_id,
            dataset_id,
            vector_database_name,
            graph_database_name,
            vector_database_provider,
            graph_database_provider,
            vector_dataset_database_handler,
            graph_dataset_database_handler,
            vector_database_url,
            graph_database_url,
            vector_database_key,
            graph_database_key,
            COALESCE(graph_database_connection_info, '{{}}'),
            COALESCE(vector_database_connection_info, '{{}}'),
            created_at,
            updated_at
        FROM {TABLE_NAME}
        """)
    )

    # Drop old table
    op.drop_table(TABLE_NAME)

    # Rename new table
    op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME)


def upgrade() -> None:
    conn = op.get_bind()
    insp = sa.inspect(conn)

    unique_constraints = insp.get_unique_constraints(TABLE_NAME)

    vector_database_connection_info_column = _get_column(
        insp, "dataset_database", "vector_database_connection_info"
    )
    if not vector_database_connection_info_column:
        op.add_column(
            "dataset_database",
            sa.Column(
                "vector_database_connection_info",
                sa.JSON(),
                unique=False,
                nullable=False,
                server_default=sa.text("'{}'"),
            ),
        )

    vector_dataset_database_handler = _get_column(
        insp, "dataset_database", "vector_dataset_database_handler"
    )
    if not vector_dataset_database_handler:
        # Add LanceDB as the default vector dataset database handler
        op.add_column(
            "dataset_database",
            sa.Column(
                "vector_dataset_database_handler",
                sa.String(),
                unique=False,
                nullable=False,
                server_default="lancedb",
            ),
        )

    graph_database_connection_info_column = _get_column(
        insp, "dataset_database", "graph_database_connection_info"
    )
    if not graph_database_connection_info_column:
        op.add_column(
            "dataset_database",
            sa.Column(
                "graph_database_connection_info",
                sa.JSON(),
                unique=False,
                nullable=False,
                server_default=sa.text("'{}'"),
            ),
        )

    graph_dataset_database_handler = _get_column(
        insp, "dataset_database", "graph_dataset_database_handler"
    )
    if not graph_dataset_database_handler:
        # Add Kuzu as the default graph dataset database handler
        op.add_column(
            "dataset_database",
            sa.Column(
                "graph_dataset_database_handler",
                sa.String(),
                unique=False,
                nullable=False,
                server_default="kuzu",
            ),
        )

    with op.batch_alter_table("dataset_database", schema=None) as batch_op:
        # Drop the unique constraint to make unique=False
        graph_constraint_to_drop = None
        for uc in unique_constraints:
            # Check if the constraint covers ONLY the target column
            if uc["name"] == graph_constraint_name:
                graph_constraint_to_drop = uc["name"]
                break

        vector_constraint_to_drop = None
        for uc in unique_constraints:
            # Check if the constraint covers ONLY the target column
            if uc["name"] == vector_constraint_name:
                vector_constraint_to_drop = uc["name"]
                break

        if (
            vector_constraint_to_drop
            and graph_constraint_to_drop
            and op.get_context().dialect.name == "postgresql"
        ):
            # PostgreSQL
            batch_op.drop_constraint(graph_constraint_name, type_="unique")
            batch_op.drop_constraint(vector_constraint_name, type_="unique")

    if op.get_context().dialect.name == "sqlite":
        conn = op.get_bind()
        # Fun fact: SQLite has hidden auto indexes for unique constraints that can't be dropped or accessed directly,
        # so we need to check for them and drop them by recreating the table (altering the column also won't work).
        result = conn.execute(sa.text("PRAGMA index_list('dataset_database')"))
        rows = result.fetchall()
        unique_auto_indexes = [row for row in rows if row[3] == "u"]
        for row in unique_auto_indexes:
            result = conn.execute(sa.text(f"PRAGMA index_info('{row[1]}')"))
            index_info = result.fetchall()
            if index_info[0][2] == "vector_database_name":
                # If a unique index exists on vector_database_name, drop it and the graph_database_name one
                _recreate_table_without_unique_constraint_sqlite(op, insp)


def downgrade() -> None:
    conn = op.get_bind()
    insp = sa.inspect(conn)

    if op.get_context().dialect.name == "sqlite":
        _recreate_table_with_unique_constraint_sqlite(op, insp)
    elif op.get_context().dialect.name == "postgresql":
        with op.batch_alter_table("dataset_database", schema=None) as batch_op:
            # Re-add the unique constraint to return to unique=True
            batch_op.create_unique_constraint(graph_constraint_name, ["graph_database_name"])

        with op.batch_alter_table("dataset_database", schema=None) as batch_op:
            # Re-add the unique constraint to return to unique=True
            batch_op.create_unique_constraint(vector_constraint_name, ["vector_database_name"])

    op.drop_column("dataset_database", "vector_database_connection_info")
    op.drop_column("dataset_database", "graph_database_connection_info")
    op.drop_column("dataset_database", "vector_dataset_database_handler")
    op.drop_column("dataset_database", "graph_dataset_database_handler")
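For reference, the hidden-index check that `upgrade()` performs through the Alembic bind can be reproduced outside a migration. A minimal sketch against a throwaway in-memory database (table and column names mirror the migration; everything else is illustrative):

```python
import sqlite3

# Throwaway database with a UNIQUE column, mimicking the old schema.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE dataset_database (vector_database_name TEXT UNIQUE)")

# origin == "u" marks indexes SQLite created automatically for UNIQUE constraints;
# they cannot be dropped directly, which is why the migration recreates the table.
for _, name, _, origin, _ in conn.execute("PRAGMA index_list('dataset_database')"):
    if origin == "u":
        first_column = conn.execute(f"PRAGMA index_info('{name}')").fetchone()[2]
        print(f"auto index {name} covers column {first_column}")
```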
cognee-frontend/package-lock.json (generated, 2522 lines changed; diff suppressed because it is too large)
@@ -9,13 +9,13 @@
     "lint": "next lint"
   },
   "dependencies": {
-    "@auth0/nextjs-auth0": "^4.6.0",
+    "@auth0/nextjs-auth0": "^4.13.1",
     "classnames": "^2.5.1",
     "culori": "^4.0.1",
     "d3-force-3d": "^3.0.6",
-    "next": "15.3.3",
-    "react": "^19.0.0",
-    "react-dom": "^19.0.0",
+    "next": "16.1.1",
+    "react": "^19.2.0",
+    "react-dom": "^19.2.0",
     "react-force-graph-2d": "^1.27.1",
     "uuid": "^9.0.1"
   },

@@ -24,11 +24,11 @@
     "@tailwindcss/postcss": "^4.1.7",
     "@types/culori": "^4.0.0",
     "@types/node": "^20",
-    "@types/react": "^18",
-    "@types/react-dom": "^18",
+    "@types/react": "^19",
+    "@types/react-dom": "^19",
     "@types/uuid": "^9.0.8",
     "eslint": "^9",
-    "eslint-config-next": "^15.3.3",
+    "eslint-config-next": "^16.0.4",
     "eslint-config-prettier": "^10.1.5",
     "tailwindcss": "^4.1.7",
     "typescript": "^5"
@@ -1,119 +0,0 @@
import { useState } from "react";
import { fetch } from "@/utils";
import { v4 as uuid4 } from "uuid";
import { LoadingIndicator } from "@/ui/App";
import { CTAButton, Input } from "@/ui/elements";

interface CrewAIFormPayload extends HTMLFormElement {
  username1: HTMLInputElement;
  username2: HTMLInputElement;
}

interface CrewAITriggerProps {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  onData: (data: any) => void;
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  onActivity: (activities: any) => void;
}

export default function CrewAITrigger({ onData, onActivity }: CrewAITriggerProps) {
  const [isCrewAIRunning, setIsCrewAIRunning] = useState(false);

  const handleRunCrewAI = (event: React.FormEvent<CrewAIFormPayload>) => {
    event.preventDefault();
    const formElements = event.currentTarget;

    const crewAIConfig = {
      username1: formElements.username1.value,
      username2: formElements.username2.value,
    };

    const backendApiUrl = process.env.NEXT_PUBLIC_BACKEND_API_URL;
    const wsUrl = backendApiUrl.replace(/^http(s)?/, "ws");

    const websocket = new WebSocket(`${wsUrl}/v1/crewai/subscribe`);

    onActivity([{ id: uuid4(), timestamp: Date.now(), activity: "Dispatching hiring crew agents" }]);

    websocket.onmessage = (event) => {
      const data = JSON.parse(event.data);

      if (data.status === "PipelineRunActivity") {
        onActivity([data.payload]);
        return;
      }

      onData({
        nodes: data.payload.nodes,
        links: data.payload.edges,
      });

      const nodes_type_map: { [key: string]: number } = {};

      for (let i = 0; i < data.payload.nodes.length; i++) {
        const node = data.payload.nodes[i];
        if (!nodes_type_map[node.type]) {
          nodes_type_map[node.type] = 0;
        }
        nodes_type_map[node.type] += 1;
      }

      const activityMessage = Object.entries(nodes_type_map).reduce((message, [type, count]) => {
        return `${message}\n | ${type}: ${count}`;
      }, "Graph updated:");

      onActivity([{
        id: uuid4(),
        timestamp: Date.now(),
        activity: activityMessage,
      }]);

      if (data.status === "PipelineRunCompleted") {
        websocket.close();
      }
    };

    onData(null);
    setIsCrewAIRunning(true);

    return fetch("/v1/crewai/run", {
      method: "POST",
      body: JSON.stringify(crewAIConfig),
      headers: {
        "Content-Type": "application/json",
      },
    })
      .then(response => response.json())
      .then(() => {
        onActivity([{ id: uuid4(), timestamp: Date.now(), activity: "Hiring crew agents made a decision" }]);
      })
      .catch(() => {
        onActivity([{ id: uuid4(), timestamp: Date.now(), activity: "Hiring crew agents had problems while executing" }]);
      })
      .finally(() => {
        websocket.close();
        setIsCrewAIRunning(false);
      });
  };

  return (
    <form className="w-full flex flex-col gap-2" onSubmit={handleRunCrewAI}>
      <h1 className="text-2xl text-white">Cognee Dev Mexican Standoff</h1>
      <span className="text-white">Agents compare GitHub profiles, and make a decision who is a better developer</span>
      <div className="flex flex-row gap-2">
        <div className="flex flex-col w-full flex-1/2">
          <label className="block mb-1 text-white" htmlFor="username1">GitHub username</label>
          <Input name="username1" type="text" placeholder="Github Username" required defaultValue="hajdul88" />
        </div>
        <div className="flex flex-col w-full flex-1/2">
          <label className="block mb-1 text-white" htmlFor="username2">GitHub username</label>
          <Input name="username2" type="text" placeholder="Github Username" required defaultValue="lxobr" />
        </div>
      </div>
      <CTAButton type="submit" disabled={isCrewAIRunning} className="whitespace-nowrap">
        Start Mexican Standoff
        {isCrewAIRunning && <LoadingIndicator />}
      </CTAButton>
    </form>
  );
}
@@ -6,7 +6,6 @@ import { NodeObject, LinkObject } from "react-force-graph-2d";
import { ChangeEvent, useEffect, useImperativeHandle, useRef, useState } from "react";

import { DeleteIcon } from "@/ui/Icons";
-// import { FeedbackForm } from "@/ui/Partials";
import { CTAButton, Input, NeutralButton, Select } from "@/ui/elements";

interface GraphControlsProps {

@@ -111,7 +110,7 @@ export default function GraphControls({ data, isAddNodeFormOpen, onGraphShapeCha
  };

  const [isAuthShapeChangeEnabled, setIsAuthShapeChangeEnabled] = useState(true);
-  const shapeChangeTimeout = useRef<number | null>();
+  const shapeChangeTimeout = useRef<number | null>(null);

  useEffect(() => {
    onGraphShapeChange(DEFAULT_GRAPH_SHAPE);

@@ -230,12 +229,6 @@ export default function GraphControls({ data, isAddNodeFormOpen, onGraphShapeCha
        )}
      </>
-      {/* )} */}

-      {/* {selectedTab === "feedback" && (
-        <div className="flex flex-col gap-2">
-          <FeedbackForm onSuccess={() => {}} />
-        </div>
-      )} */}
    </div>
  </>
);
@@ -1,6 +1,6 @@
"use client";

-import { useCallback, useRef, useState, MutableRefObject } from "react";
+import { useCallback, useRef, useState, RefObject } from "react";

import Link from "next/link";
import { TextLogo } from "@/ui/App";

@@ -47,11 +47,11 @@ export default function GraphView() {
    updateData(newData);
  }, []);

-  const graphRef = useRef<GraphVisualizationAPI>();
+  const graphRef = useRef<GraphVisualizationAPI>(null);

-  const graphControls = useRef<GraphControlsAPI>();
+  const graphControls = useRef<GraphControlsAPI>(null);

-  const activityLog = useRef<ActivityLogAPI>();
+  const activityLog = useRef<ActivityLogAPI>(null);

  return (
    <main className="flex flex-col h-full">

@@ -74,21 +74,18 @@ export default function GraphView() {
      <div className="w-full h-full relative overflow-hidden">
        <GraphVisualization
          key={data?.nodes.length}
-          ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
+          ref={graphRef as RefObject<GraphVisualizationAPI>}
          data={data}
-          graphControls={graphControls as MutableRefObject<GraphControlsAPI>}
+          graphControls={graphControls as RefObject<GraphControlsAPI>}
        />

        <div className="absolute top-2 left-2 flex flex-col gap-2">
          <div className="bg-gray-500 pt-4 pr-4 pb-4 pl-4 rounded-md w-sm">
            <CogneeAddWidget onData={onDataChange} />
          </div>
-          {/* <div className="bg-gray-500 pt-4 pr-4 pb-4 pl-4 rounded-md w-sm">
-            <CrewAITrigger onData={onDataChange} onActivity={(activities) => activityLog.current?.updateActivityLog(activities)} />
-          </div> */}
          <div className="bg-gray-500 pt-4 pr-4 pb-4 pl-4 rounded-md w-sm">
            <h2 className="text-xl text-white mb-4">Activity Log</h2>
-            <ActivityLog ref={activityLog as MutableRefObject<ActivityLogAPI>} />
+            <ActivityLog ref={activityLog as RefObject<ActivityLogAPI>} />
          </div>
        </div>

@@ -96,7 +93,7 @@ export default function GraphView() {
      <div className="bg-gray-500 pt-4 pr-4 pb-4 pl-4 rounded-md w-110">
        <GraphControls
          data={data}
-          ref={graphControls as MutableRefObject<GraphControlsAPI>}
+          ref={graphControls as RefObject<GraphControlsAPI>}
          isAddNodeFormOpen={isAddNodeFormOpen}
          onFitIntoView={() => graphRef.current!.zoomToFit(1000, 50)}
          onGraphShapeChange={(shape) => graphRef.current!.setGraphShape(shape)}
@@ -1,7 +1,7 @@
"use client";

import classNames from "classnames";
-import { MutableRefObject, useEffect, useImperativeHandle, useRef, useState, useCallback } from "react";
+import { RefObject, useEffect, useImperativeHandle, useRef, useState, useCallback } from "react";
import { forceCollide, forceManyBody } from "d3-force-3d";
import dynamic from "next/dynamic";
import { GraphControlsAPI } from "./GraphControls";

@@ -16,9 +16,9 @@ const ForceGraph = dynamic(() => import("react-force-graph-2d"), {
import type { ForceGraphMethods, GraphData, LinkObject, NodeObject } from "react-force-graph-2d";

interface GraphVisuzaliationProps {
-  ref: MutableRefObject<GraphVisualizationAPI>;
+  ref: RefObject<GraphVisualizationAPI>;
  data?: GraphData<NodeObject, LinkObject>;
-  graphControls: MutableRefObject<GraphControlsAPI>;
+  graphControls: RefObject<GraphControlsAPI>;
  className?: string;
}

@@ -205,7 +205,7 @@ export default function GraphVisualization({ ref, data, graphControls, className
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  function handleDagError(loopNodeIds: (string | number)[]) {}

-  const graphRef = useRef<ForceGraphMethods>();
+  const graphRef = useRef<ForceGraphMethods>(null);

  useEffect(() => {
    if (data && graphRef.current) {

@@ -224,6 +224,7 @@ export default function GraphVisualization({ ref, data, graphControls, className
  ) => {
    if (!graphRef.current) {
+      console.warn("GraphVisualization: graphRef not ready yet");
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      return undefined as any;
    }

@@ -239,7 +240,7 @@ export default function GraphVisualization({ ref, data, graphControls, className
  return (
    <div ref={containerRef} className={classNames("w-full h-full", className)} id="graph-container">
      <ForceGraph
-        ref={graphRef}
+        ref={graphRef as RefObject<ForceGraphMethods>}
        width={dimensions.width}
        height={dimensions.height}
        dagMode={graphShape as unknown as undefined}
@@ -1,8 +1,8 @@
-"use server";
+"use client";

import Dashboard from "./Dashboard";

-export default async function Page() {
+export default function Page() {
  const accessToken = "";

  return (
@@ -1,3 +1,3 @@
export { default } from "./dashboard/page";

-// export const dynamic = "force-dynamic";
+export const dynamic = "force-dynamic";
@@ -13,7 +13,6 @@ export interface Dataset {

function useDatasets(useCloud = false) {
  const [datasets, setDatasets] = useState<Dataset[]>([]);
-  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  // const statusTimeout = useRef<any>(null);

  // const fetchDatasetStatuses = useCallback((datasets: Dataset[]) => {
@@ -2,7 +2,7 @@ import { NextResponse, type NextRequest } from "next/server";
// import { auth0 } from "./modules/auth/auth0";

// eslint-disable-next-line @typescript-eslint/no-unused-vars
-export async function middleware(request: NextRequest) {
+export async function proxy(request: NextRequest) {
  // if (process.env.USE_AUTH0_AUTHORIZATION?.toLowerCase() === "true") {
  //   if (request.nextUrl.pathname === "/auth/token") {
  //     return NextResponse.next();
@@ -1,69 +0,0 @@
"use client";

import { useState } from "react";
import { LoadingIndicator } from "@/ui/App";
import { fetch, useBoolean } from "@/utils";
import { CTAButton, TextArea } from "@/ui/elements";

interface SignInFormPayload extends HTMLFormElement {
  feedback: HTMLTextAreaElement;
}

interface FeedbackFormProps {
  onSuccess: () => void;
}

export default function FeedbackForm({ onSuccess }: FeedbackFormProps) {
  const {
    value: isSubmittingFeedback,
    setTrue: disableFeedbackSubmit,
    setFalse: enableFeedbackSubmit,
  } = useBoolean(false);

  const [feedbackError, setFeedbackError] = useState<string | null>(null);

  const signIn = (event: React.FormEvent<SignInFormPayload>) => {
    event.preventDefault();
    const formElements = event.currentTarget;

    setFeedbackError(null);
    disableFeedbackSubmit();

    fetch("/v1/crewai/feedback", {
      method: "POST",
      body: JSON.stringify({
        feedback: formElements.feedback.value,
      }),
      headers: {
        "Content-Type": "application/json",
      },
    })
      .then(response => response.json())
      .then(() => {
        onSuccess();
        formElements.feedback.value = "";
      })
      .catch(error => setFeedbackError(error.detail))
      .finally(() => enableFeedbackSubmit());
  };

  return (
    <form onSubmit={signIn} className="flex flex-col gap-2">
      <div className="flex flex-col gap-2">
        <div className="mb-4">
          <label className="block text-white" htmlFor="feedback">Feedback on agent's reasoning</label>
          <TextArea id="feedback" name="feedback" type="text" placeholder="Your feedback" />
        </div>
      </div>

      <CTAButton type="submit">
        <span>Submit feedback</span>
        {isSubmittingFeedback && <LoadingIndicator />}
      </CTAButton>

      {feedbackError && (
        <span className="text-s text-white">{feedbackError}</span>
      )}
    </form>
  )
}
@@ -3,4 +3,3 @@ export { default as Footer } from "./Footer/Footer";
export { default as SearchView } from "./SearchView/SearchView";
export { default as IFrameView } from "./IFrameView/IFrameView";
// export { default as Explorer } from "./Explorer/Explorer";
-export { default as FeedbackForm } from "./FeedbackForm";
@@ -2,7 +2,7 @@

import { v4 as uuid4 } from "uuid";
import classNames from "classnames";
-import { Fragment, MouseEvent, MutableRefObject, useCallback, useEffect, useRef, useState } from "react";
+import { Fragment, MouseEvent, RefObject, useCallback, useEffect, useRef, useState } from "react";

import { useModal } from "@/ui/elements/Modal";
import { CaretIcon, CloseIcon, PlusIcon } from "@/ui/Icons";
@@ -282,7 +282,7 @@ export default function Notebook({ notebook, updateNotebook, runCell }: Notebook
function CellResult({ content }: { content: [] }) {
  const parsedContent = [];

-  const graphRef = useRef<GraphVisualizationAPI>();
+  const graphRef = useRef<GraphVisualizationAPI>(null);
  const graphControls = useRef<GraphControlsAPI>({
    setSelectedNode: () => {},
    getSelectedNode: () => null,

@@ -298,7 +298,7 @@ function CellResult({ content }: { content: [] }) {
        <span className="text-sm pl-2 mb-4">reasoning graph</span>
        <GraphVisualization
          data={transformInsightsGraphData(line)}
-          ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
+          ref={graphRef as RefObject<GraphVisualizationAPI>}
          graphControls={graphControls}
          className="min-h-80"
        />

@@ -346,7 +346,7 @@ function CellResult({ content }: { content: [] }) {
        <span className="text-sm pl-2 mb-4">reasoning graph (datasets: {datasetName})</span>
        <GraphVisualization
          data={transformToVisualizationData(graph)}
-          ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
+          ref={graphRef as RefObject<GraphVisualizationAPI>}
          graphControls={graphControls}
          className="min-h-80"
        />

@@ -377,7 +377,7 @@ function CellResult({ content }: { content: [] }) {
        <span className="text-sm pl-2 mb-4">reasoning graph (datasets: {datasetName})</span>
        <GraphVisualization
          data={transformToVisualizationData(graph)}
-          ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
+          ref={graphRef as RefObject<GraphVisualizationAPI>}
          graphControls={graphControls}
          className="min-h-80"
        />
@@ -33,7 +33,7 @@ export default function NotebookCellHeader({
    setFalse: setIsNotRunningCell,
  } = useBoolean(false);

-  const [runInstance, setRunInstance] = useState<string>(isCloudEnvironment() ? "cloud" : "local");
+  const [runInstance] = useState<string>(isCloudEnvironment() ? "cloud" : "local");

  const handleCellRun = () => {
    if (runCell) {
@@ -14,7 +14,7 @@
    "moduleResolution": "bundler",
    "resolveJsonModule": true,
    "isolatedModules": true,
-    "jsx": "preserve",
+    "jsx": "react-jsx",
    "incremental": true,
    "plugins": [
      {

@@ -32,7 +32,8 @@
    "next-env.d.ts",
    "**/*.ts",
    "**/*.tsx",
-    ".next/types/**/*.ts"
+    ".next/types/**/*.ts",
+    ".next/dev/types/**/*.ts"
  ],
  "exclude": [
    "node_modules"
@@ -445,16 +445,22 @@ The MCP server exposes its functionality through tools. Call them from any MCP c

- **cognify**: Turns your data into a structured knowledge graph and stores it in memory

- **cognee_add_developer_rules**: Ingest core developer rule files into memory

- **codify**: Analyse a code repository, build a code graph, and store it in memory

- **search**: Query memory – supports GRAPH_COMPLETION, RAG_COMPLETION, CODE, CHUNKS

- **list_data**: List all datasets and their data items with IDs for deletion operations

- **delete**: Delete specific data from a dataset (supports soft/hard deletion modes)

- **get_developer_rules**: Retrieve all developer rules that were generated based on previous interactions

- **list_data**: List all datasets and their data items with IDs for deletion operations

- **save_interaction**: Logs user-agent interactions and query-answer pairs

- **prune**: Reset cognee for a fresh start (removes all data)

- **search**: Query memory – supports GRAPH_COMPLETION, RAG_COMPLETION, CODE, CHUNKS, SUMMARIES, CYPHER, and FEELING_LUCKY

- **cognify_status / codify_status**: Track pipeline progress

**Data Management Examples:**
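The examples that follow this heading are suppressed in this view; a hedged sketch of what a `list_data`-then-`delete` round trip looks like. The tool names come from the list above; the argument names (`dataset_id`, `data_id`, `mode`) and the session helper are assumptions based on the tool descriptions:

```python
# Hedged sketch: "session" is an mcp ClientSession like the one built in the
# test client later in this diff; IDs are placeholders, not real values.
async def clean_up(session) -> None:
    # Enumerate datasets and data items to find the IDs to delete.
    listing = await session.call_tool("list_data", arguments={})
    print(listing.content[0].text)

    # Soft-delete one data item from a dataset.
    await session.call_tool(
        "delete",
        arguments={
            "dataset_id": "<dataset-uuid>",  # placeholder
            "data_id": "<data-uuid>",        # placeholder
            "mode": "soft",                  # "soft" or "hard" per the description above
        },
    )
```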
@@ -1,6 +1,6 @@
[project]
name = "cognee-mcp"
-version = "0.4.0"
+version = "0.5.0"
description = "Cognee MCP server"
readme = "README.md"
requires-python = ">=3.10"
@@ -9,7 +9,7 @@ dependencies = [
    # For local cognee repo usage, remove the comment below and add the absolute path to cognee. Then run `uv sync --reinstall` in the mcp folder on local cognee changes.
    #"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j] @ file:/Users/igorilic/Desktop/cognee",
    # TODO: Remove gemini from optional dependencies for new Cognee version after 0.3.4
-    "cognee[postgres,docs,neo4j]==0.3.7",
+    "cognee[postgres,docs,neo4j]==0.5.0",
    "fastmcp>=2.10.0,<3.0.0",
    "mcp>=1.12.0,<2.0.0",
    "uv>=0.6.3,<1.0.0",
@@ -151,7 +151,7 @@ class CogneeClient:
        query_type: str,
        datasets: Optional[List[str]] = None,
        system_prompt: Optional[str] = None,
-        top_k: int = 10,
+        top_k: int = 5,
    ) -> Any:
        """
        Search the knowledge graph.

@@ -192,7 +192,7 @@ class CogneeClient:

        with redirect_stdout(sys.stderr):
            results = await self.cognee.search(
-                query_type=SearchType[query_type.upper()], query_text=query_text
+                query_type=SearchType[query_type.upper()], query_text=query_text, top_k=top_k
            )
            return results
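With the lower default, a direct client call now caps results at five unless the caller overrides it. A hedged sketch of how that surfaces to callers (the `search` signature is taken from this diff; the client construction and query text are illustrative):

```python
from typing import Any


# Hedged sketch: only the search() signature comes from the diff above;
# "client" is an already-constructed CogneeClient, and the query is made up.
async def focused_search(client: "CogneeClient") -> Any:
    return await client.search(
        query_text="Which datasets use Kuzu?",
        query_type="GRAPH_COMPLETION",
        top_k=5,  # new default shown above; raise for broader context
    )
```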
@@ -90,97 +90,6 @@ async def health_check(request):
    return JSONResponse({"status": "ok"})


@mcp.tool()
async def cognee_add_developer_rules(
    base_path: str = ".", graph_model_file: str = None, graph_model_name: str = None
) -> list:
    """
    Ingest core developer rule files into Cognee's memory layer.

    This function loads a predefined set of developer-related configuration,
    rule, and documentation files from the base repository and assigns them
    to the special 'developer_rules' node set in Cognee. It ensures these
    foundational files are always part of the structured memory graph.

    Parameters
    ----------
    base_path : str
        Root path to resolve relative file paths. Defaults to current directory.

    graph_model_file : str, optional
        Optional path to a custom schema file for knowledge graph generation.

    graph_model_name : str, optional
        Optional class name to use from the graph_model_file schema.

    Returns
    -------
    list
        A message indicating how many rule files were scheduled for ingestion,
        and how to check their processing status.

    Notes
    -----
    - Each file is processed asynchronously in the background.
    - Files are attached to the 'developer_rules' node set.
    - Missing files are skipped with a logged warning.
    """

    developer_rule_paths = [
        ".cursorrules",
        ".cursor/rules",
        ".same/todos.md",
        ".windsurfrules",
        ".clinerules",
        "CLAUDE.md",
        ".sourcegraph/memory.md",
        "AGENT.md",
        "AGENTS.md",
    ]

    async def cognify_task(file_path: str) -> None:
        with redirect_stdout(sys.stderr):
            logger.info(f"Starting cognify for: {file_path}")
            try:
                await cognee_client.add(file_path, node_set=["developer_rules"])

                model = None
                if graph_model_file and graph_model_name:
                    if cognee_client.use_api:
                        logger.warning(
                            "Custom graph models are not supported in API mode, ignoring."
                        )
                    else:
                        from cognee.shared.data_models import KnowledgeGraph

                        model = load_class(graph_model_file, graph_model_name)

                await cognee_client.cognify(graph_model=model)
                logger.info(f"Cognify finished for: {file_path}")
            except Exception as e:
                logger.error(f"Cognify failed for {file_path}: {str(e)}")
                raise ValueError(f"Failed to cognify: {str(e)}")

    tasks = []
    for rel_path in developer_rule_paths:
        abs_path = os.path.join(base_path, rel_path)
        if os.path.isfile(abs_path):
            tasks.append(asyncio.create_task(cognify_task(abs_path)))
        else:
            logger.warning(f"Skipped missing developer rule file: {abs_path}")
    log_file = get_log_file_location()
    return [
        types.TextContent(
            type="text",
            text=(
                f"Started cognify for {len(tasks)} developer rule files in background.\n"
                f"All are added to the `developer_rules` node set.\n"
                f"Use `cognify_status` or check logs at {log_file} to monitor progress."
            ),
        )
    ]


@mcp.tool()
async def cognify(
    data: str, graph_model_file: str = None, graph_model_name: str = None, custom_prompt: str = None
@@ -407,76 +316,7 @@ async def save_interaction(data: str) -> list:


@mcp.tool()
async def codify(repo_path: str) -> list:
    """
    Analyze and generate a code-specific knowledge graph from a software repository.

    This function launches a background task that processes the provided repository
    and builds a code knowledge graph. The function returns immediately while
    the processing continues in the background due to MCP timeout constraints.

    Parameters
    ----------
    repo_path : str
        Path to the code repository to analyze. This can be a local file path or a
        relative path to a repository. The path should point to the root of the
        repository or a specific directory within it.

    Returns
    -------
    list
        A list containing a single TextContent object with information about the
        background task launch and how to check its status.

    Notes
    -----
    - The function launches a background task and returns immediately
    - The code graph generation may take significant time for larger repositories
    - Use the codify_status tool to check the progress of the operation
    - Process results are logged to the standard Cognee log file
    - All stdout is redirected to stderr to maintain MCP communication integrity
    """

    if cognee_client.use_api:
        error_msg = "❌ Codify operation is not available in API mode. Please use direct mode for code graph pipeline."
        logger.error(error_msg)
        return [types.TextContent(type="text", text=error_msg)]

    async def codify_task(repo_path: str):
        # NOTE: MCP uses stdout to communicate, we must redirect all output
        # going to stdout ( like the print function ) to stderr.
        with redirect_stdout(sys.stderr):
            logger.info("Codify process starting.")
            from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline

            results = []
            async for result in run_code_graph_pipeline(repo_path, False):
                results.append(result)
                logger.info(result)
            if all(results):
                logger.info("Codify process finished successfully.")
            else:
                logger.info("Codify process failed.")

    asyncio.create_task(codify_task(repo_path))

    log_file = get_log_file_location()
    text = (
        f"Background process launched due to MCP timeout limitations.\n"
        f"To check current codify status use the codify_status tool\n"
        f"or you can check the log file at: {log_file}"
    )

    return [
        types.TextContent(
            type="text",
            text=text,
        )
    ]


@mcp.tool()
-async def search(search_query: str, search_type: str) -> list:
+async def search(search_query: str, search_type: str, top_k: int = 10) -> list:
    """
    Search and query the knowledge graph for insights, information, and connections.
@@ -549,6 +389,13 @@ async def search(search_query: str, search_type: str) -> list:

    The search_type is case-insensitive and will be converted to uppercase.

+    top_k : int, optional
+        Maximum number of results to return (default: 10).
+        Controls the amount of context retrieved from the knowledge graph.
+        - Lower values (3-5): Faster, more focused results
+        - Higher values (10-20): More comprehensive, but slower and more context-heavy
+        Helps manage response size and context window usage in MCP clients.

    Returns
    -------
    list

@@ -585,13 +432,32 @@ async def search(search_query: str, search_type: str) -> list:

    """

-    async def search_task(search_query: str, search_type: str) -> str:
-        """Search the knowledge graph"""
+    async def search_task(search_query: str, search_type: str, top_k: int) -> str:
+        """
+        Internal task to execute knowledge graph search with result formatting.
+
+        Handles the actual search execution and formats results appropriately
+        for MCP clients based on the search type and execution mode (API vs direct).
+
+        Parameters
+        ----------
+        search_query : str
+            The search query in natural language
+        search_type : str
+            Type of search to perform (GRAPH_COMPLETION, CHUNKS, etc.)
+        top_k : int
+            Maximum number of results to return
+
+        Returns
+        -------
+        str
+            Formatted search results as a string, with format depending on search_type
+        """
        # NOTE: MCP uses stdout to communicate, we must redirect all output
        # going to stdout ( like the print function ) to stderr.
        with redirect_stdout(sys.stderr):
            search_results = await cognee_client.search(
-                query_text=search_query, query_type=search_type
+                query_text=search_query, query_type=search_type, top_k=top_k
            )

        # Handle different result formats based on API vs direct mode
@@ -625,49 +491,10 @@ async def search(search_query: str, search_type: str) -> list:
        else:
            return str(search_results)

-    search_results = await search_task(search_query, search_type)
+    search_results = await search_task(search_query, search_type, top_k)
    return [types.TextContent(type="text", text=search_results)]
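For context, a hedged sketch of invoking the updated tool over MCP. The tool and argument names come from the signature above; the session object and query text are placeholders:

```python
# Hedged sketch: "session" is an mcp ClientSession like the one the test
# client below constructs; the query is illustrative.
async def run_search(session) -> str:
    result = await session.call_tool(
        "search",
        arguments={
            "search_query": "How are datasets linked to their databases?",
            "search_type": "GRAPH_COMPLETION",
            "top_k": 5,  # keep responses small for context-limited MCP clients
        },
    )
    return result.content[0].text
```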


@mcp.tool()
async def get_developer_rules() -> list:
    """
    Retrieve all developer rules that were generated based on previous interactions.

    This tool queries the Cognee knowledge graph and returns a list of developer
    rules.

    Parameters
    ----------
    None

    Returns
    -------
    list
        A list containing a single TextContent object with the retrieved developer rules.
        The format is plain text containing the developer rules in bulletpoints.

    Notes
    -----
    - The specific logic for fetching rules is handled internally.
    - This tool does not accept any parameters and is intended for simple rule inspection use cases.
    """

    async def fetch_rules_from_cognee() -> str:
        """Collect all developer rules from Cognee"""
        with redirect_stdout(sys.stderr):
            if cognee_client.use_api:
                logger.warning("Developer rules retrieval is not available in API mode")
                return "Developer rules retrieval is not available in API mode"

            developer_rules = await get_existing_rules(rules_nodeset_name="coding_agent_rules")
            return developer_rules

    rules_text = await fetch_rules_from_cognee()

    return [types.TextContent(type="text", text=rules_text)]


@mcp.tool()
async def list_data(dataset_id: str = None) -> list:
    """
@@ -953,48 +780,6 @@ async def cognify_status():
        return [types.TextContent(type="text", text=error_msg)]


@mcp.tool()
async def codify_status():
    """
    Get the current status of the codify pipeline.

    This function retrieves information about current and recently completed codify operations
    in the codebase dataset. It provides details on progress, success/failure status, and statistics
    about the processed code repositories.

    Returns
    -------
    list
        A list containing a single TextContent object with the status information as a string.
        The status includes information about active and completed jobs for the cognify_code_pipeline.

    Notes
    -----
    - The function retrieves pipeline status specifically for the "cognify_code_pipeline" on the "codebase" dataset
    - Status information includes job progress, execution time, and completion status
    - The status is returned in string format for easy reading
    - This operation is not available in API mode
    """
    with redirect_stdout(sys.stderr):
        try:
            from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
            from cognee.modules.users.methods import get_default_user

            user = await get_default_user()
            status = await cognee_client.get_pipeline_status(
                [await get_unique_dataset_id("codebase", user)], "cognify_code_pipeline"
            )
            return [types.TextContent(type="text", text=str(status))]
        except NotImplementedError:
            error_msg = "❌ Pipeline status is not available in API mode"
            logger.error(error_msg)
            return [types.TextContent(type="text", text=error_msg)]
        except Exception as e:
            error_msg = f"❌ Failed to get codify status: {str(e)}"
            logger.error(error_msg)
            return [types.TextContent(type="text", text=error_msg)]


def node_to_string(node):
    node_data = ", ".join(
        [f'{key}: "{value}"' for key, value in node.items() if key in ["id", "name"]]
@@ -3,7 +3,7 @@
Test client for Cognee MCP Server functionality.

This script tests all the tools and functions available in the Cognee MCP server,
-including cognify, codify, search, prune, status checks, and utility functions.
+including cognify, search, prune, status checks, and utility functions.

Usage:
    # Set your OpenAI API key first

@@ -23,6 +23,7 @@ import tempfile
import time
from contextlib import asynccontextmanager
+from cognee.shared.logging_utils import setup_logging
from logging import ERROR, INFO

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

@@ -35,7 +36,7 @@ from src.server import (
    load_class,
)

-# Set timeout for cognify/codify to complete in
+# Set timeout for cognify to complete in
TIMEOUT = 5 * 60  # 5 min in seconds

@@ -151,12 +152,9 @@ DEBUG = True

        expected_tools = {
            "cognify",
-            "codify",
            "search",
            "prune",
            "cognify_status",
-            "codify_status",
-            "cognee_add_developer_rules",
            "list_data",
            "delete",
        }
@@ -247,106 +245,6 @@ DEBUG = True
            }
            print(f"❌ {test_name} test failed: {e}")

    async def test_codify(self):
        """Test the codify functionality using MCP client."""
        print("\n🧪 Testing codify functionality...")
        try:
            async with self.mcp_server_session() as session:
                codify_result = await session.call_tool(
                    "codify", arguments={"repo_path": self.test_repo_dir}
                )

                start = time.time()  # mark the start
                while True:
                    try:
                        # Wait a moment
                        await asyncio.sleep(5)

                        # Check if codify processing is finished
                        status_result = await session.call_tool("codify_status", arguments={})
                        if hasattr(status_result, "content") and status_result.content:
                            status_text = (
                                status_result.content[0].text
                                if status_result.content
                                else str(status_result)
                            )
                        else:
                            status_text = str(status_result)

                        if str(PipelineRunStatus.DATASET_PROCESSING_COMPLETED) in status_text:
                            break
                        elif time.time() - start > TIMEOUT:
                            raise TimeoutError("Codify did not complete in 5min")
                    except DatabaseNotCreatedError:
                        if time.time() - start > TIMEOUT:
                            raise TimeoutError("Database was not created in 5min")

                self.test_results["codify"] = {
                    "status": "PASS",
                    "result": codify_result,
                    "message": "Codify executed successfully",
                }
                print("✅ Codify test passed")

        except Exception as e:
            self.test_results["codify"] = {
                "status": "FAIL",
                "error": str(e),
                "message": "Codify test failed",
            }
            print(f"❌ Codify test failed: {e}")

    async def test_cognee_add_developer_rules(self):
        """Test the cognee_add_developer_rules functionality using MCP client."""
        print("\n🧪 Testing cognee_add_developer_rules functionality...")
        try:
            async with self.mcp_server_session() as session:
                result = await session.call_tool(
                    "cognee_add_developer_rules", arguments={"base_path": self.test_data_dir}
                )

                start = time.time()  # mark the start
                while True:
                    try:
                        # Wait a moment
                        await asyncio.sleep(5)

                        # Check if developer rule cognify processing is finished
                        status_result = await session.call_tool("cognify_status", arguments={})
                        if hasattr(status_result, "content") and status_result.content:
                            status_text = (
                                status_result.content[0].text
                                if status_result.content
                                else str(status_result)
                            )
                        else:
                            status_text = str(status_result)

                        if str(PipelineRunStatus.DATASET_PROCESSING_COMPLETED) in status_text:
                            break
                        elif time.time() - start > TIMEOUT:
                            raise TimeoutError(
                                "Cognify of developer rules did not complete in 5min"
                            )
                    except DatabaseNotCreatedError:
                        if time.time() - start > TIMEOUT:
                            raise TimeoutError("Database was not created in 5min")

                self.test_results["cognee_add_developer_rules"] = {
                    "status": "PASS",
                    "result": result,
                    "message": "Developer rules addition executed successfully",
                }
                print("✅ Developer rules test passed")

        except Exception as e:
            self.test_results["cognee_add_developer_rules"] = {
                "status": "FAIL",
                "error": str(e),
                "message": "Developer rules test failed",
            }
            print(f"❌ Developer rules test failed: {e}")

    async def test_search_functionality(self):
        """Test the search functionality with different search types using MCP client."""
        print("\n🧪 Testing search functionality...")
@@ -359,7 +257,11 @@ DEBUG = True
            # Go through all Cognee search types
            for search_type in SearchType:
                # Don't test these search types
-                if search_type in [SearchType.NATURAL_LANGUAGE, SearchType.CYPHER]:
+                if search_type in [
+                    SearchType.NATURAL_LANGUAGE,
+                    SearchType.CYPHER,
+                    SearchType.TRIPLET_COMPLETION,
+                ]:
                    break
                try:
                    async with self.mcp_server_session() as session:

@@ -681,9 +583,6 @@ class TestModel:
            test_name="Cognify2",
        )

-        await self.test_codify()
-        await self.test_cognee_add_developer_rules()

        # Test list_data and delete functionality
        await self.test_list_data()
        await self.test_delete()

@@ -739,7 +638,5 @@ async def main():


if __name__ == "__main__":
-    from logging import ERROR
-
    logger = setup_logging(log_level=ERROR)
    asyncio.run(main())
cognee-mcp/uv.lock (generated, 7633 lines changed; diff suppressed because it is too large)
@@ -21,7 +21,7 @@ from cognee.api.v1.notebooks.routers import get_notebooks_router
from cognee.api.v1.permissions.routers import get_permissions_router
from cognee.api.v1.settings.routers import get_settings_router
from cognee.api.v1.datasets.routers import get_datasets_router
-from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router
+from cognee.api.v1.cognify.routers import get_cognify_router
from cognee.api.v1.search.routers import get_search_router
from cognee.api.v1.ontologies.routers.get_ontology_router import get_ontology_router
from cognee.api.v1.memify.routers import get_memify_router

@@ -278,10 +278,6 @@ app.include_router(get_responses_router(), prefix="/api/v1/responses", tags=["re
app.include_router(get_sync_router(), prefix="/api/v1/sync", tags=["sync"])

-codegraph_routes = get_code_pipeline_router()
-if codegraph_routes:
-    app.include_router(codegraph_routes, prefix="/api/v1/code-pipeline", tags=["code-pipeline"])

app.include_router(
    get_users_router(),
    prefix="/api/v1/users",
@@ -155,7 +155,7 @@ async def add(
    - LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)

    Optional:
-    - LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral"
+    - LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral", "bedrock"
    - LLM_MODEL: Model name (default: "gpt-5-mini")
    - DEFAULT_USER_EMAIL: Custom default user email
    - DEFAULT_USER_PASSWORD: Custom default user password
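A minimal sketch of wiring those variables up before using cognee. Only the variable names come from the docstring above; the values are placeholders:

```python
import os

# Placeholder values; only the variable names are taken from the docstring above.
os.environ["LLM_API_KEY"] = "sk-..."    # required
os.environ["LLM_PROVIDER"] = "openai"   # or "anthropic", "gemini", "ollama", "mistral", "bedrock"
os.environ["LLM_MODEL"] = "gpt-5-mini"  # the documented default

import cognee  # cognee picks the settings up from the environment
```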
@@ -205,6 +205,7 @@ async def add(
        pipeline_name="add_pipeline",
        vector_db_config=vector_db_config,
        graph_db_config=graph_db_config,
+        use_pipeline_cache=True,
        incremental_loading=incremental_loading,
        data_per_batch=data_per_batch,
    ):
@@ -1,119 +0,0 @@
import os
import pathlib
import asyncio
from typing import Optional
from cognee.shared.logging_utils import get_logger, setup_logging
from cognee.modules.observability.get_observe import get_observe

from cognee.api.v1.search import SearchType, search
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.methods import get_default_user
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.data.methods import create_dataset
from cognee.tasks.documents import classify_documents, extract_chunks_from_documents
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.ingestion import ingest_data
from cognee.tasks.repo_processor import get_non_py_files, get_repo_file_dependencies

from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_text
from cognee.infrastructure.llm import get_max_chunk_tokens
from cognee.infrastructure.databases.relational import get_relational_engine

observe = get_observe()

logger = get_logger("code_graph_pipeline")


@observe
async def run_code_graph_pipeline(
    repo_path,
    include_docs=False,
    excluded_paths: Optional[list[str]] = None,
    supported_languages: Optional[list[str]] = None,
):
    import cognee
    from cognee.low_level import setup

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    cognee_config = get_cognify_config()
    user = await get_default_user()
    detailed_extraction = True

    tasks = [
        Task(
            get_repo_file_dependencies,
            detailed_extraction=detailed_extraction,
            supported_languages=supported_languages,
            excluded_paths=excluded_paths,
        ),
        # Task(summarize_code, task_config={"batch_size": 500}),  # This task takes a long time to complete
        Task(add_data_points, task_config={"batch_size": 30}),
    ]

    if include_docs:
        # These tasks take a long time to complete
        non_code_tasks = [
            Task(get_non_py_files, task_config={"batch_size": 50}),
            Task(ingest_data, dataset_name="repo_docs", user=user),
            Task(classify_documents),
            Task(extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()),
            Task(
                extract_graph_from_data,
                graph_model=KnowledgeGraph,
                task_config={"batch_size": 50},
            ),
            Task(
                summarize_text,
                summarization_model=cognee_config.summarization_model,
                task_config={"batch_size": 50},
            ),
        ]

    dataset_name = "codebase"

    # Save dataset to database
    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        dataset = await create_dataset(dataset_name, user, session)

    if include_docs:
        non_code_pipeline_run = run_tasks(
            non_code_tasks, dataset.id, repo_path, user, "cognify_pipeline"
        )
        async for run_status in non_code_pipeline_run:
            yield run_status

    async for run_status in run_tasks(
        tasks, dataset.id, repo_path, user, "cognify_code_pipeline", incremental_loading=False
    ):
        yield run_status


if __name__ == "__main__":

    async def main():
        async for run_status in run_code_graph_pipeline("REPO_PATH"):
            print(f"{run_status.pipeline_run_id}: {run_status.status}")

        file_path = os.path.join(
            pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html"
        )
        await visualize_graph(file_path)

        search_results = await search(
            query_type=SearchType.CODE,
            query_text="How is Relationship weight calculated?",
        )

        for file in search_results:
            print(file["name"])

    logger = setup_logging(name="code_graph_pipeline")
    asyncio.run(main())
@@ -3,6 +3,7 @@ from pydantic import BaseModel
from typing import Union, Optional
from uuid import UUID

+from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
from cognee.shared.logging_utils import get_logger
from cognee.shared.data_models import KnowledgeGraph

@@ -19,7 +20,6 @@ from cognee.modules.ontology.get_default_ontology_resolver import (
from cognee.modules.users.models import User

from cognee.tasks.documents import (
    check_permissions_on_dataset,
    classify_documents,
    extract_chunks_from_documents,
)

@@ -53,6 +53,7 @@ async def cognify(
    custom_prompt: Optional[str] = None,
    temporal_cognify: bool = False,
+    data_per_batch: int = 20,
    **kwargs,
):
    """
    Transform ingested data into a structured knowledge graph.
@@ -78,12 +79,11 @@ async def cognify(

    Processing Pipeline:
    1. **Document Classification**: Identifies document types and structures
-    2. **Permission Validation**: Ensures user has processing rights
-    3. **Text Chunking**: Breaks content into semantically meaningful segments
-    4. **Entity Extraction**: Identifies key concepts, people, places, organizations
-    5. **Relationship Detection**: Discovers connections between entities
-    6. **Graph Construction**: Builds semantic knowledge graph with embeddings
-    7. **Content Summarization**: Creates hierarchical summaries for navigation
+    2. **Text Chunking**: Breaks content into semantically meaningful segments
+    3. **Entity Extraction**: Identifies key concepts, people, places, organizations
+    4. **Relationship Detection**: Discovers connections between entities
+    5. **Graph Construction**: Builds semantic knowledge graph with embeddings
+    6. **Content Summarization**: Creates hierarchical summaries for navigation

    Graph Model Customization:
    The `graph_model` parameter allows custom knowledge structures:
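A hedged sketch of that customization. The `graph_model=` keyword and the pydantic base appear elsewhere in this diff; the model itself is invented for illustration and may not match every extraction constraint cognee imposes:

```python
# Hypothetical custom graph model; only the graph_model= keyword and the
# pydantic BaseModel base are taken from this diff.
from pydantic import BaseModel

import cognee


class Ingredient(BaseModel):
    name: str


class RecipeGraph(BaseModel):
    title: str
    ingredients: list[Ingredient]


async def build() -> None:
    await cognee.add("recipes.txt")                # placeholder input file
    await cognee.cognify(graph_model=RecipeGraph)  # extract into the custom structure
```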
@@ -224,6 +224,7 @@ async def cognify(
        config=config,
        custom_prompt=custom_prompt,
+        chunks_per_batch=chunks_per_batch,
        **kwargs,
    )

    # By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for

@@ -238,6 +239,7 @@ async def cognify(
        vector_db_config=vector_db_config,
        graph_db_config=graph_db_config,
        incremental_loading=incremental_loading,
+        use_pipeline_cache=True,
        pipeline_name="cognify_pipeline",
        data_per_batch=data_per_batch,
    )

@@ -251,6 +253,7 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
    config: Config = None,
    custom_prompt: Optional[str] = None,
+    chunks_per_batch: int = 100,
    **kwargs,
) -> list[Task]:
    if config is None:
        ontology_config = get_ontology_env_config()

@@ -272,9 +275,11 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
    if chunks_per_batch is None:
        chunks_per_batch = 100

+    cognify_config = get_cognify_config()
+    embed_triplets = cognify_config.triplet_embedding

    default_tasks = [
        Task(classify_documents),
        Task(check_permissions_on_dataset, user=user, permissions=["write"]),
        Task(
            extract_chunks_from_documents,
            max_chunk_size=chunk_size or get_max_chunk_tokens(),

@@ -286,12 +291,17 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
            config=config,
            custom_prompt=custom_prompt,
            task_config={"batch_size": chunks_per_batch},
            **kwargs,
        ),  # Generate knowledge graphs from the document chunks.
        Task(
            summarize_text,
            task_config={"batch_size": chunks_per_batch},
        ),
-        Task(add_data_points, task_config={"batch_size": chunks_per_batch}),
+        Task(
+            add_data_points,
+            embed_triplets=embed_triplets,
+            task_config={"batch_size": chunks_per_batch},
+        ),
    ]

    return default_tasks
@@ -305,14 +315,13 @@ async def get_temporal_tasks(

    The pipeline includes:
    1. Document classification.
-    2. Dataset permission checks (requires "write" access).
-    3. Document chunking with a specified or default chunk size.
-    4. Event and timestamp extraction from chunks.
-    5. Knowledge graph extraction from events.
-    6. Batched insertion of data points.
+    2. Document chunking with a specified or default chunk size.
+    3. Event and timestamp extraction from chunks.
+    4. Knowledge graph extraction from events.
+    5. Batched insertion of data points.

    Args:
-        user (User, optional): The user requesting task execution, used for permission checks.
+        user (User, optional): The user requesting task execution.
        chunker (Callable, optional): A text chunking function/class to split documents. Defaults to TextChunker.
        chunk_size (int, optional): Maximum token size per chunk. If not provided, uses system default.
        chunks_per_batch (int, optional): Number of chunks to process in a single batch in Cognify
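The temporal task list is selected through the `temporal_cognify` flag added to `cognify()` earlier in this diff. A hedged sketch of opting in (the flag name comes from the signature above; the sample input is a placeholder):

```python
import asyncio

import cognee


async def main() -> None:
    # temporal_cognify comes from the cognify() signature in this diff;
    # the sample sentence is a placeholder.
    await cognee.add("The merger closed on 2021-03-01 and the rebrand followed in June.")
    await cognee.cognify(temporal_cognify=True)  # routes through get_temporal_tasks()


asyncio.run(main())
```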
@@ -325,7 +334,6 @@ async def get_temporal_tasks(

    temporal_tasks = [
        Task(classify_documents),
-        Task(check_permissions_on_dataset, user=user, permissions=["write"]),
        Task(
            extract_chunks_from_documents,
            max_chunk_size=chunk_size or get_max_chunk_tokens(),
@@ -1,2 +1 @@
from .get_cognify_router import get_cognify_router
-from .get_code_pipeline_router import get_code_pipeline_router
@@ -1,90 +0,0 @@
import json
from cognee.shared.logging_utils import get_logger
from fastapi import APIRouter
from fastapi.responses import JSONResponse
from cognee.api.DTO import InDTO
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.storage.utils import JSONEncoder


logger = get_logger()


class CodePipelineIndexPayloadDTO(InDTO):
    repo_path: str
    include_docs: bool = False


class CodePipelineRetrievePayloadDTO(InDTO):
    query: str
    full_input: str


def get_code_pipeline_router() -> APIRouter:
    try:
        import cognee.api.v1.cognify.code_graph_pipeline
    except ModuleNotFoundError:
        logger.error("codegraph dependencies not found. Skipping codegraph API routes.")
        return None

    router = APIRouter()

    @router.post("/index", response_model=None)
    async def code_pipeline_index(payload: CodePipelineIndexPayloadDTO):
        """
        Run indexation on a code repository.

        This endpoint processes a code repository to create a knowledge graph
        of the codebase structure, dependencies, and relationships.

        ## Request Parameters
        - **repo_path** (str): Path to the code repository
        - **include_docs** (bool): Whether to include documentation files (default: false)

        ## Response
        No content returned. Processing results are logged.

        ## Error Codes
        - **409 Conflict**: Error during indexation process
        """
        from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline

        try:
            async for result in run_code_graph_pipeline(payload.repo_path, payload.include_docs):
                logger.info(result)
        except Exception as error:
            return JSONResponse(status_code=409, content={"error": str(error)})

    @router.post("/retrieve", response_model=list[dict])
    async def code_pipeline_retrieve(payload: CodePipelineRetrievePayloadDTO):
        """
        Retrieve context from the code knowledge graph.

        This endpoint searches the indexed code repository to find relevant
        context based on the provided query.

        ## Request Parameters
        - **query** (str): Search query for code context
        - **full_input** (str): Full input text for processing

        ## Response
        Returns a list of relevant code files and context as JSON.

        ## Error Codes
        - **409 Conflict**: Error during retrieval process
        """
        try:
            query = (
                payload.full_input.replace("cognee ", "")
                if payload.full_input.startswith("cognee ")
                else payload.full_input
            )

            retriever = CodeRetriever()
            retrieved_files = await retriever.get_context(query)

            return json.dumps(retrieved_files, cls=JSONEncoder)
        except Exception as error:
            return JSONResponse(status_code=409, content={"error": str(error)})

    return router
@ -42,7 +42,9 @@ class CognifyPayloadDTO(InDTO):
|
|||
default="", description="Custom prompt for entity extraction and graph generation"
|
||||
)
|
||||
ontology_key: Optional[List[str]] = Field(
|
||||
default=None, description="Reference to one or more previously uploaded ontologies"
|
||||
default=None,
|
||||
examples=[[]],
|
||||
description="Reference to one or more previously uploaded ontologies",
|
||||
)
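Illustratively, a request body exercising the updated field (other DTO fields elided; the keys must reference ontologies uploaded earlier, and the empty-list example value presumably means no ontology is applied):

payload = {
    "ontology_key": ["legal_ontology"],  # previously uploaded ontology key(s); illustrative name
}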
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -208,14 +208,14 @@ def get_datasets_router() -> APIRouter:
|
|||
},
|
||||
)
|
||||
|
||||
from cognee.modules.data.methods import get_dataset, delete_dataset
|
||||
from cognee.modules.data.methods import delete_dataset
|
||||
|
||||
dataset = await get_dataset(user.id, dataset_id)
|
||||
dataset = await get_authorized_existing_datasets([dataset_id], "delete", user)
|
||||
|
||||
if dataset is None:
|
||||
raise DatasetNotFoundError(message=f"Dataset ({str(dataset_id)}) not found.")
|
||||
|
||||
await delete_dataset(dataset)
|
||||
await delete_dataset(dataset[0])
|
||||
|
||||
@router.delete(
|
||||
"/{dataset_id}/data/{data_id}",
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from pathlib import Path
|
|||
from datetime import datetime, timezone
|
||||
from typing import Optional, List
|
||||
from dataclasses import dataclass
|
||||
from fastapi import UploadFile
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -45,8 +46,10 @@ class OntologyService:
|
|||
json.dump(metadata, f, indent=2)
|
||||
|
||||
async def upload_ontology(
|
||||
self, ontology_key: str, file, user, description: Optional[str] = None
|
||||
self, ontology_key: str, file: UploadFile, user, description: Optional[str] = None
|
||||
) -> OntologyMetadata:
|
||||
if not file.filename:
|
||||
raise ValueError("File must have a filename")
|
||||
if not file.filename.lower().endswith(".owl"):
|
||||
raise ValueError("File must be in .owl format")
|
||||
|
||||
|
|
@ -57,8 +60,6 @@ class OntologyService:
|
|||
raise ValueError(f"Ontology key '{ontology_key}' already exists")
|
||||
|
||||
content = await file.read()
|
||||
if len(content) > 10 * 1024 * 1024:
|
||||
raise ValueError("File size exceeds 10MB limit")
|
||||
|
||||
file_path = user_dir / f"{ontology_key}.owl"
|
||||
with open(file_path, "wb") as f:
|
||||
|
|
@ -82,7 +83,11 @@ class OntologyService:
|
|||
)
|
||||
|
||||
async def upload_ontologies(
|
||||
self, ontology_key: List[str], files: List, user, descriptions: Optional[List[str]] = None
|
||||
self,
|
||||
ontology_key: List[str],
|
||||
files: List[UploadFile],
|
||||
user,
|
||||
descriptions: Optional[List[str]] = None,
|
||||
) -> List[OntologyMetadata]:
|
||||
"""
|
||||
Upload ontology files with their respective keys.
|
||||
|
|
@ -105,47 +110,17 @@ class OntologyService:
|
|||
if len(set(ontology_key)) != len(ontology_key):
|
||||
raise ValueError("Duplicate ontology keys not allowed")
|
||||
|
||||
if descriptions and len(descriptions) != len(files):
|
||||
raise ValueError("Number of descriptions must match number of files")
|
||||
|
||||
results = []
|
||||
user_dir = self._get_user_dir(str(user.id))
|
||||
metadata = self._load_metadata(user_dir)
|
||||
|
||||
for i, (key, file) in enumerate(zip(ontology_key, files)):
|
||||
if key in metadata:
|
||||
raise ValueError(f"Ontology key '{key}' already exists")
|
||||
|
||||
if not file.filename.lower().endswith(".owl"):
|
||||
raise ValueError(f"File '{file.filename}' must be in .owl format")
|
||||
|
||||
content = await file.read()
|
||||
if len(content) > 10 * 1024 * 1024:
|
||||
raise ValueError(f"File '{file.filename}' exceeds 10MB limit")
|
||||
|
||||
file_path = user_dir / f"{key}.owl"
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
ontology_metadata = {
|
||||
"filename": file.filename,
|
||||
"size_bytes": len(content),
|
||||
"uploaded_at": datetime.now(timezone.utc).isoformat(),
|
||||
"description": descriptions[i] if descriptions else None,
|
||||
}
|
||||
metadata[key] = ontology_metadata
|
||||
|
||||
results.append(
|
||||
OntologyMetadata(
|
||||
await self.upload_ontology(
|
||||
ontology_key=key,
|
||||
filename=file.filename,
|
||||
size_bytes=len(content),
|
||||
uploaded_at=ontology_metadata["uploaded_at"],
|
||||
file=file,
|
||||
user=user,
|
||||
description=descriptions[i] if descriptions else None,
|
||||
)
|
||||
)
|
||||
|
||||
self._save_metadata(user_dir, metadata)
|
||||
return results
|
||||
|
||||
def get_ontology_contents(self, ontology_key: List[str], user) -> List[str]:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException
|
||||
from fastapi import APIRouter, File, Form, UploadFile, Depends, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import Optional, List
|
||||
|
||||
|
|
@ -15,28 +15,25 @@ def get_ontology_router() -> APIRouter:
|
|||
|
||||
@router.post("", response_model=dict)
|
||||
async def upload_ontology(
|
||||
request: Request,
|
||||
ontology_key: str = Form(...),
|
||||
ontology_file: List[UploadFile] = File(...),
|
||||
descriptions: Optional[str] = Form(None),
|
||||
ontology_file: UploadFile = File(...),
|
||||
description: Optional[str] = Form(None),
|
||||
user: User = Depends(get_authenticated_user),
|
||||
):
|
||||
"""
|
||||
Upload ontology files with their respective keys for later use in cognify operations.
|
||||
|
||||
Supports both single and multiple file uploads:
|
||||
- Single file: ontology_key=["key"], ontology_file=[file]
|
||||
- Multiple files: ontology_key=["key1", "key2"], ontology_file=[file1, file2]
|
||||
Upload a single ontology file for later use in cognify operations.
|
||||
|
||||
## Request Parameters
|
||||
- **ontology_key** (str): JSON array string of user-defined identifiers for the ontologies
|
||||
- **ontology_file** (List[UploadFile]): OWL format ontology files
|
||||
- **descriptions** (Optional[str]): JSON array string of optional descriptions
|
||||
- **ontology_key** (str): User-defined identifier for the ontology.
|
||||
- **ontology_file** (UploadFile): Single OWL format ontology file
|
||||
- **description** (Optional[str]): Optional description for the ontology.
|
||||
|
||||
## Response
|
||||
Returns metadata about uploaded ontologies including keys, filenames, sizes, and upload timestamps.
|
||||
Returns metadata about the uploaded ontology including key, filename, size, and upload timestamp.
|
||||
|
||||
## Error Codes
|
||||
- **400 Bad Request**: Invalid file format, duplicate keys, array length mismatches, file size exceeded
|
||||
- **400 Bad Request**: Invalid file format, duplicate key, multiple files uploaded
|
||||
- **500 Internal Server Error**: File system or processing errors
|
||||
"""
|
||||
send_telemetry(
|
||||
|
|
@ -49,16 +46,22 @@ def get_ontology_router() -> APIRouter:
|
|||
)
|
||||
|
||||
try:
|
||||
import json
|
||||
# Enforce: exactly one uploaded file for "ontology_file"
|
||||
form = await request.form()
|
||||
uploaded_files = form.getlist("ontology_file")
|
||||
if len(uploaded_files) != 1:
|
||||
raise ValueError("Only one ontology_file is allowed")
|
||||
|
||||
ontology_keys = json.loads(ontology_key)
|
||||
description_list = json.loads(descriptions) if descriptions else None
|
||||
if ontology_key.strip().startswith(("[", "{")):
|
||||
raise ValueError("ontology_key must be a string")
|
||||
if description is not None and description.strip().startswith(("[", "{")):
|
||||
raise ValueError("description must be a string")
|
||||
|
||||
if not isinstance(ontology_keys, list):
|
||||
raise ValueError("ontology_key must be a JSON array")
|
||||
|
||||
results = await ontology_service.upload_ontologies(
|
||||
ontology_keys, ontology_file, user, description_list
|
||||
result = await ontology_service.upload_ontology(
|
||||
ontology_key=ontology_key,
|
||||
file=ontology_file,
|
||||
user=user,
|
||||
description=description,
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
@ -70,10 +73,9 @@ def get_ontology_router() -> APIRouter:
|
|||
"uploaded_at": result.uploaded_at,
|
||||
"description": result.description,
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
}
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
except ValueError as e:
|
||||
return JSONResponse(status_code=400, content={"error": str(e)})
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=500, content={"error": str(e)})
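For reference, a minimal client-side sketch of the new single-file upload contract. The base URL, mount path, and MIME type are assumptions; adjust to your deployment.

import requests

API_URL = "http://localhost:8000/api/v1/ontologies"  # assumed mount path

with open("my_ontology.owl", "rb") as f:
    response = requests.post(
        API_URL,
        data={"ontology_key": "my_ontology", "description": "Domain ontology"},
        files={"ontology_file": ("my_ontology.owl", f, "application/rdf+xml")},
    )

# 400 is returned for a non-.owl file, a duplicate key, or multiple files;
# on success the response carries the key, filename, size, and upload timestamp.
response.raise_for_status()
print(response.json())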
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ async def search(
|
|||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
verbose: bool = False,
|
||||
) -> Union[List[SearchResult], CombinedSearchResult]:
|
||||
"""
|
||||
Search and query the knowledge graph for insights, information, and connections.
|
||||
|
|
@ -123,6 +124,8 @@ async def search(
|
|||
|
||||
session_id: Optional session identifier for caching Q&A interactions. Defaults to 'default_session' if None.
|
||||
|
||||
verbose: If True, returns detailed result information including graph representation (when possible).
|
||||
|
||||
Returns:
|
||||
list: Search results in format determined by query_type:
|
||||
|
||||
|
|
@ -204,6 +207,7 @@ async def search(
|
|||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
return filtered_search_results
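For orientation, a hedged usage sketch of the extended search signature. The query arguments are not visible in this hunk and follow cognee's public search API; treat any name other than the new parameters as an assumption. The numeric values are just the defaults.

import cognee

# triplet_distance_penalty (new here) weights graph distance when scoring
# triplets; wide_search_top_k bounds the initial wide retrieval.
results = await cognee.search(
    query_text="How are these entities connected?",
    wide_search_top_k=100,
    triplet_distance_penalty=3.5,
    verbose=False,
)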
|
||||
|
|
|
|||
cognee/api/v1/ui/node_setup.py (new file, 360 lines)
|
|
@ -0,0 +1,360 @@
|
|||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def get_nvm_dir() -> Path:
|
||||
"""
|
||||
Get the nvm directory path following standard nvm installation logic.
|
||||
Uses XDG_CONFIG_HOME if set, otherwise falls back to ~/.nvm.
|
||||
"""
|
||||
xdg_config_home = os.environ.get("XDG_CONFIG_HOME")
|
||||
if xdg_config_home:
|
||||
return Path(xdg_config_home) / "nvm"
|
||||
return Path.home() / ".nvm"
|
||||
|
||||
|
||||
def get_nvm_sh_path() -> Path:
|
||||
"""
|
||||
Get the path to nvm.sh following standard nvm installation logic.
|
||||
"""
|
||||
return get_nvm_dir() / "nvm.sh"
|
||||
|
||||
|
||||
def check_nvm_installed() -> bool:
|
||||
"""
|
||||
Check if nvm (Node Version Manager) is installed.
|
||||
"""
|
||||
try:
|
||||
# Check if nvm is available in the shell
|
||||
# nvm is typically sourced in shell config files, so we need to check via shell
|
||||
if platform.system() == "Windows":
|
||||
# On Windows, nvm-windows uses a different approach
|
||||
result = subprocess.run(
|
||||
["nvm", "version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
shell=True,
|
||||
)
|
||||
else:
|
||||
# On Unix-like systems, nvm is a shell function, so we need to source it
|
||||
# First check if nvm.sh exists
|
||||
nvm_path = get_nvm_sh_path()
|
||||
if not nvm_path.exists():
|
||||
logger.debug(f"nvm.sh not found at {nvm_path}")
|
||||
return False
|
||||
|
||||
# Try to source nvm and check version, capturing errors
|
||||
result = subprocess.run(
|
||||
["bash", "-c", f"source {nvm_path} && nvm --version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
# Log the error to help diagnose configuration issues
|
||||
if result.stderr:
|
||||
logger.debug(f"nvm check failed: {result.stderr.strip()}")
|
||||
return False
|
||||
|
||||
return result.returncode == 0
|
||||
except Exception as e:
|
||||
logger.debug(f"Exception checking nvm: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
def install_nvm() -> bool:
|
||||
"""
|
||||
Install nvm (Node Version Manager) on Unix-like systems.
|
||||
"""
|
||||
if platform.system() == "Windows":
|
||||
logger.error("nvm installation on Windows requires nvm-windows.")
|
||||
logger.error(
|
||||
"Please install nvm-windows manually from: https://github.com/coreybutler/nvm-windows"
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info("Installing nvm (Node Version Manager)...")
|
||||
|
||||
try:
|
||||
# Download and install nvm
|
||||
nvm_install_script = "https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.3/install.sh"
|
||||
logger.info(f"Downloading nvm installer from {nvm_install_script}...")
|
||||
|
||||
response = requests.get(nvm_install_script, timeout=60)
|
||||
response.raise_for_status()
|
||||
|
||||
# Create a temporary script file
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False) as f:
|
||||
f.write(response.text)
|
||||
install_script_path = f.name
|
||||
|
||||
try:
|
||||
# Make the script executable and run it
|
||||
os.chmod(install_script_path, 0o755)
|
||||
result = subprocess.run(
|
||||
["bash", install_script_path],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info("✓ nvm installed successfully")
|
||||
# Source nvm in current shell session
|
||||
nvm_dir = get_nvm_dir()
|
||||
if nvm_dir.exists():
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
f"nvm installation completed but nvm directory not found at {nvm_dir}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
logger.error(f"nvm installation failed: {result.stderr}")
|
||||
return False
|
||||
finally:
|
||||
# Clean up temporary script
|
||||
try:
|
||||
os.unlink(install_script_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Failed to download nvm installer: {str(e)}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to install nvm: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
def install_node_with_nvm() -> bool:
|
||||
"""
|
||||
Install the latest Node.js version using nvm.
|
||||
Returns True if installation succeeds, False otherwise.
|
||||
"""
|
||||
if platform.system() == "Windows":
|
||||
logger.error("Node.js installation via nvm on Windows requires nvm-windows.")
|
||||
logger.error("Please install Node.js manually from: https://nodejs.org/")
|
||||
return False
|
||||
|
||||
logger.info("Installing latest Node.js version using nvm...")
|
||||
|
||||
try:
|
||||
# Source nvm and install latest Node.js
|
||||
nvm_path = get_nvm_sh_path()
|
||||
if not nvm_path.exists():
|
||||
logger.error(f"nvm.sh not found at {nvm_path}. nvm may not be properly installed.")
|
||||
return False
|
||||
|
||||
nvm_source_cmd = f"source {nvm_path}"
|
||||
install_cmd = f"{nvm_source_cmd} && nvm install node"
|
||||
|
||||
result = subprocess.run(
|
||||
["bash", "-c", install_cmd],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5 minutes timeout for Node.js installation
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info("✓ Node.js installed successfully via nvm")
|
||||
|
||||
# Set as default version
|
||||
use_cmd = f"{nvm_source_cmd} && nvm alias default node"
|
||||
subprocess.run(
|
||||
["bash", "-c", use_cmd],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
# Add nvm to PATH for current session
|
||||
# This ensures node/npm are available in subsequent commands
|
||||
nvm_dir = get_nvm_dir()
|
||||
if nvm_dir.exists():
|
||||
# Update PATH for current process
|
||||
nvm_bin = nvm_dir / "versions" / "node"
|
||||
# Find the latest installed version
|
||||
if nvm_bin.exists():
|
||||
# Sort by numeric version parts, not lexicographically (v10 must rank above v9)
versions = sorted(nvm_bin.iterdir(), key=lambda p: [int(x) for x in p.name.lstrip("v").split(".") if x.isdigit()], reverse=True)
|
||||
if versions:
|
||||
latest_node_bin = versions[0] / "bin"
|
||||
if latest_node_bin.exists():
|
||||
current_path = os.environ.get("PATH", "")
|
||||
os.environ["PATH"] = f"{latest_node_bin}:{current_path}"
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to install Node.js: {result.stderr}")
|
||||
return False
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Timeout installing Node.js (this can take several minutes)")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error installing Node.js: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
def check_node_npm() -> tuple[bool, str]: # (is_available, error_message)
|
||||
"""
|
||||
Check if Node.js and npm are available.
|
||||
If not available, attempts to install nvm and Node.js automatically.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Check Node.js - try direct command first, then with nvm if needed
|
||||
result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=10)
|
||||
if result.returncode != 0:
|
||||
# If direct command fails, try with nvm sourced (in case nvm is installed but not in PATH)
|
||||
nvm_path = get_nvm_sh_path()
|
||||
if nvm_path.exists():
|
||||
result = subprocess.run(
|
||||
["bash", "-c", f"source {nvm_path} && node --version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode != 0 and result.stderr:
|
||||
logger.debug(f"Failed to source nvm or run node: {result.stderr.strip()}")
|
||||
if result.returncode != 0:
|
||||
# Node.js is not installed, try to install it
|
||||
logger.info("Node.js is not installed. Attempting to install automatically...")
|
||||
|
||||
# Check if nvm is installed
|
||||
if not check_nvm_installed():
|
||||
logger.info("nvm is not installed. Installing nvm first...")
|
||||
if not install_nvm():
|
||||
return (
|
||||
False,
|
||||
"Failed to install nvm. Please install Node.js manually from https://nodejs.org/",
|
||||
)
|
||||
|
||||
# Install Node.js using nvm
|
||||
if not install_node_with_nvm():
|
||||
return (
|
||||
False,
|
||||
"Failed to install Node.js. Please install Node.js manually from https://nodejs.org/",
|
||||
)
|
||||
|
||||
# Verify installation after automatic setup
|
||||
# Try with nvm sourced first
|
||||
nvm_path = get_nvm_sh_path()
|
||||
if nvm_path.exists():
|
||||
result = subprocess.run(
|
||||
["bash", "-c", f"source {nvm_path} && node --version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode != 0 and result.stderr:
|
||||
logger.debug(
|
||||
f"Failed to verify node after installation: {result.stderr.strip()}"
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["node", "--version"], capture_output=True, text=True, timeout=10
|
||||
)
|
||||
if result.returncode != 0:
|
||||
nvm_path = get_nvm_sh_path()
|
||||
return (
|
||||
False,
|
||||
f"Node.js installation completed but node command is not available. Please restart your terminal or source {nvm_path}",
|
||||
)
|
||||
|
||||
node_version = result.stdout.strip()
|
||||
logger.debug(f"Found Node.js version: {node_version}")
|
||||
|
||||
# Check npm - handle Windows PowerShell scripts
|
||||
if platform.system() == "Windows":
|
||||
# On Windows, npm might be a PowerShell script, so we need to use shell=True
|
||||
result = subprocess.run(
|
||||
["npm", "--version"], capture_output=True, text=True, timeout=10, shell=True
|
||||
)
|
||||
else:
|
||||
# On Unix-like systems, if we just installed via nvm, we may need to source nvm
|
||||
# Try direct command first
|
||||
result = subprocess.run(
|
||||
["npm", "--version"], capture_output=True, text=True, timeout=10
|
||||
)
|
||||
if result.returncode != 0:
|
||||
# Try with nvm sourced
|
||||
nvm_path = get_nvm_sh_path()
|
||||
if nvm_path.exists():
|
||||
result = subprocess.run(
|
||||
["bash", "-c", f"source {nvm_path} && npm --version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode != 0 and result.stderr:
|
||||
logger.debug(f"Failed to source nvm or run npm: {result.stderr.strip()}")
|
||||
|
||||
if result.returncode != 0:
|
||||
return False, "npm is not installed or not in PATH"
|
||||
|
||||
npm_version = result.stdout.strip()
|
||||
logger.debug(f"Found npm version: {npm_version}")
|
||||
|
||||
return True, f"Node.js {node_version}, npm {npm_version}"
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return False, "Timeout checking Node.js/npm installation"
|
||||
except FileNotFoundError:
|
||||
# Node.js is not installed, try to install it
|
||||
logger.info("Node.js is not found. Attempting to install automatically...")
|
||||
|
||||
# Check if nvm is installed
|
||||
if not check_nvm_installed():
|
||||
logger.info("nvm is not installed. Installing nvm first...")
|
||||
if not install_nvm():
|
||||
return (
|
||||
False,
|
||||
"Failed to install nvm. Please install Node.js manually from https://nodejs.org/",
|
||||
)
|
||||
|
||||
# Install Node.js using nvm
|
||||
if not install_node_with_nvm():
|
||||
return (
|
||||
False,
|
||||
"Failed to install Node.js. Please install Node.js manually from https://nodejs.org/",
|
||||
)
|
||||
|
||||
# Retry checking Node.js after installation
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["node", "--version"], capture_output=True, text=True, timeout=10
|
||||
)
|
||||
if result.returncode == 0:
|
||||
node_version = result.stdout.strip()
|
||||
# Check npm
|
||||
nvm_path = get_nvm_sh_path()
|
||||
if nvm_path.exists():
|
||||
result = subprocess.run(
|
||||
["bash", "-c", f"source {nvm_path} && npm --version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
npm_version = result.stdout.strip()
|
||||
return True, f"Node.js {node_version}, npm {npm_version}"
|
||||
elif result.stderr:
|
||||
logger.debug(f"Failed to source nvm or run npm: {result.stderr.strip()}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Exception retrying node/npm check: {str(e)}")
|
||||
|
||||
return False, "Node.js/npm not found. Please install Node.js from https://nodejs.org/"
|
||||
except Exception as e:
|
||||
return False, f"Error checking Node.js/npm: {str(e)}"
|
||||
cognee/api/v1/ui/npm_utils.py (new file, 50 lines)
|
|
@ -0,0 +1,50 @@
|
|||
import platform
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from .node_setup import get_nvm_sh_path
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def run_npm_command(cmd: List[str], cwd: Path, timeout: int = 300) -> subprocess.CompletedProcess:
|
||||
"""
|
||||
Run an npm command, ensuring nvm is sourced if needed (Unix-like systems only).
|
||||
Returns the CompletedProcess result.
|
||||
"""
|
||||
if platform.system() == "Windows":
|
||||
# On Windows, use shell=True for npm commands
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
shell=True,
|
||||
)
|
||||
else:
|
||||
# On Unix-like systems, try direct command first
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
# If it fails and nvm might be installed, try with nvm sourced
|
||||
if result.returncode != 0:
|
||||
nvm_path = get_nvm_sh_path()
|
||||
if nvm_path.exists():
|
||||
nvm_cmd = f"source {nvm_path} && {' '.join(cmd)}"
|
||||
result = subprocess.run(
|
||||
["bash", "-c", nvm_cmd],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
if result.returncode != 0 and result.stderr:
|
||||
logger.debug(f"npm command failed with nvm: {result.stderr.strip()}")
|
||||
return result
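A brief usage sketch; the frontend path is illustrative:

from pathlib import Path
from cognee.api.v1.ui.npm_utils import run_npm_command

# Falls back to a bash shell with nvm sourced if the direct npm call fails.
result = run_npm_command(["npm", "install"], Path("cognee-frontend"), timeout=300)
if result.returncode != 0:
    print(result.stderr)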
|
||||
|
|
@ -15,6 +15,8 @@ import shutil
|
|||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.version import get_cognee_version
|
||||
from .node_setup import check_node_npm, get_nvm_dir, get_nvm_sh_path
|
||||
from .npm_utils import run_npm_command
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -285,48 +287,6 @@ def find_frontend_path() -> Optional[Path]:
|
|||
return None
|
||||
|
||||
|
||||
def check_node_npm() -> tuple[bool, str]:
|
||||
"""
|
||||
Check if Node.js and npm are available.
|
||||
Returns (is_available, error_message)
|
||||
"""
|
||||
|
||||
try:
|
||||
# Check Node.js
|
||||
result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=10)
|
||||
if result.returncode != 0:
|
||||
return False, "Node.js is not installed or not in PATH"
|
||||
|
||||
node_version = result.stdout.strip()
|
||||
logger.debug(f"Found Node.js version: {node_version}")
|
||||
|
||||
# Check npm - handle Windows PowerShell scripts
|
||||
if platform.system() == "Windows":
|
||||
# On Windows, npm might be a PowerShell script, so we need to use shell=True
|
||||
result = subprocess.run(
|
||||
["npm", "--version"], capture_output=True, text=True, timeout=10, shell=True
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["npm", "--version"], capture_output=True, text=True, timeout=10
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
return False, "npm is not installed or not in PATH"
|
||||
|
||||
npm_version = result.stdout.strip()
|
||||
logger.debug(f"Found npm version: {npm_version}")
|
||||
|
||||
return True, f"Node.js {node_version}, npm {npm_version}"
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return False, "Timeout checking Node.js/npm installation"
|
||||
except FileNotFoundError:
|
||||
return False, "Node.js/npm not found. Please install Node.js from https://nodejs.org/"
|
||||
except Exception as e:
|
||||
return False, f"Error checking Node.js/npm: {str(e)}"
|
||||
|
||||
|
||||
def install_frontend_dependencies(frontend_path: Path) -> bool:
|
||||
"""
|
||||
Install frontend dependencies if node_modules doesn't exist.
|
||||
|
|
@ -341,24 +301,7 @@ def install_frontend_dependencies(frontend_path: Path) -> bool:
|
|||
logger.info("Installing frontend dependencies (this may take a few minutes)...")
|
||||
|
||||
try:
|
||||
# Use shell=True on Windows for npm commands
|
||||
if platform.system() == "Windows":
|
||||
result = subprocess.run(
|
||||
["npm", "install"],
|
||||
cwd=frontend_path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5 minutes timeout
|
||||
shell=True,
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["npm", "install"],
|
||||
cwd=frontend_path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5 minutes timeout
|
||||
)
|
||||
result = run_npm_command(["npm", "install"], frontend_path, timeout=300)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info("Frontend dependencies installed successfully")
|
||||
|
|
@ -642,6 +585,21 @@ def start_ui(
|
|||
env["HOST"] = "localhost"
|
||||
env["PORT"] = str(port)
|
||||
|
||||
# If nvm is installed, ensure it's available in the environment
|
||||
nvm_path = get_nvm_sh_path()
|
||||
if platform.system() != "Windows" and nvm_path.exists():
|
||||
# Add nvm to PATH for the subprocess
|
||||
nvm_dir = get_nvm_dir()
|
||||
# Find the latest Node.js version installed via nvm
|
||||
nvm_versions = nvm_dir / "versions" / "node"
|
||||
if nvm_versions.exists():
|
||||
# Sort by numeric version parts, not lexicographically (v10 must rank above v9)
versions = sorted(nvm_versions.iterdir(), key=lambda p: [int(x) for x in p.name.lstrip("v").split(".") if x.isdigit()], reverse=True)
|
||||
if versions:
|
||||
latest_node_bin = versions[0] / "bin"
|
||||
if latest_node_bin.exists():
|
||||
current_path = env.get("PATH", "")
|
||||
env["PATH"] = f"{latest_node_bin}:{current_path}"
|
||||
|
||||
# Start the development server
|
||||
logger.info(f"Starting frontend server at http://localhost:{port}")
|
||||
logger.info("This may take a moment to compile and start...")
|
||||
|
|
@ -659,14 +617,26 @@ def start_ui(
|
|||
shell=True,
|
||||
)
|
||||
else:
|
||||
process = subprocess.Popen(
|
||||
["npm", "run", "dev"],
|
||||
cwd=frontend_path,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
||||
)
|
||||
# On Unix-like systems, use bash with nvm sourced if available
|
||||
if nvm_path.exists():
|
||||
# Use bash to source nvm and run npm
|
||||
process = subprocess.Popen(
|
||||
["bash", "-c", f"source {nvm_path} && npm run dev"],
|
||||
cwd=frontend_path,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
||||
)
|
||||
else:
|
||||
process = subprocess.Popen(
|
||||
["npm", "run", "dev"],
|
||||
cwd=frontend_path,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
||||
)
|
||||
|
||||
# Start threads to stream frontend output with prefix
|
||||
_stream_process_output(process, "stdout", "[FRONTEND]", "\033[33m") # Yellow
|
||||
|
|
|
|||
|
|
@ -4,9 +4,10 @@ from typing import Union
|
|||
from uuid import UUID
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_context_config
|
||||
from cognee.infrastructure.databases.vector.config import get_vectordb_config
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
|
||||
from cognee.infrastructure.databases.utils import resolve_dataset_database_connection_info
|
||||
from cognee.infrastructure.files.storage.config import file_storage_config
|
||||
from cognee.modules.users.methods import get_user
|
||||
|
||||
|
|
@ -16,22 +17,59 @@ vector_db_config = ContextVar("vector_db_config", default=None)
|
|||
graph_db_config = ContextVar("graph_db_config", default=None)
|
||||
session_user = ContextVar("session_user", default=None)
|
||||
|
||||
VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
|
||||
GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
|
||||
|
||||
|
||||
async def set_session_user_context_variable(user):
|
||||
session_user.set(user)
|
||||
|
||||
|
||||
def multi_user_support_possible():
|
||||
graph_db_config = get_graph_context_config()
|
||||
vector_db_config = get_vectordb_context_config()
|
||||
return (
|
||||
graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT
|
||||
and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT
|
||||
graph_db_config = get_graph_config()
|
||||
vector_db_config = get_vectordb_config()
|
||||
|
||||
graph_handler = graph_db_config.graph_dataset_database_handler
|
||||
vector_handler = vector_db_config.vector_dataset_database_handler
|
||||
from cognee.infrastructure.databases.dataset_database_handler import (
|
||||
supported_dataset_database_handlers,
|
||||
)
|
||||
|
||||
if graph_handler not in supported_dataset_database_handlers:
|
||||
raise EnvironmentError(
|
||||
"Unsupported graph dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
|
||||
f"Selected graph dataset to database handler: {graph_handler}\n"
|
||||
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
|
||||
)
|
||||
|
||||
if vector_handler not in supported_dataset_database_handlers:
|
||||
raise EnvironmentError(
|
||||
"Unsupported vector dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
|
||||
f"Selected vector dataset to database handler: {vector_handler}\n"
|
||||
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
|
||||
)
|
||||
|
||||
if (
|
||||
supported_dataset_database_handlers[graph_handler]["handler_provider"]
|
||||
!= graph_db_config.graph_database_provider
|
||||
):
|
||||
raise EnvironmentError(
|
||||
"The selected graph dataset to database handler does not work with the configured graph database provider. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
|
||||
f"Selected graph database provider: {graph_db_config.graph_database_provider}\n"
|
||||
f"Selected graph dataset to database handler: {graph_handler}\n"
|
||||
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
|
||||
)
|
||||
|
||||
if (
|
||||
supported_dataset_database_handlers[vector_handler]["handler_provider"]
|
||||
!= vector_db_config.vector_db_provider
|
||||
):
|
||||
raise EnvironmentError(
|
||||
"The selected vector dataset to database handler does not work with the configured vector database provider. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
|
||||
f"Selected vector database provider: {vector_db_config.vector_db_provider}\n"
|
||||
f"Selected vector dataset to database handler: {vector_handler}\n"
|
||||
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def backend_access_control_enabled():
|
||||
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
|
||||
|
|
@ -41,12 +79,7 @@ def backend_access_control_enabled():
|
|||
return multi_user_support_possible()
|
||||
elif backend_access_control.lower() == "true":
|
||||
# If enabled, ensure that the current graph and vector DBs can support it
|
||||
multi_user_support = multi_user_support_possible()
|
||||
if not multi_user_support:
|
||||
raise EnvironmentError(
|
||||
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
|
||||
)
|
||||
return True
|
||||
return multi_user_support_possible()
|
||||
return False
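Taken together, the toggle and the handler checks mean multi-user mode is driven entirely by configuration. A minimal sketch; the env var spellings are assumptions derived from the pydantic field names:

import os

# Each handler must match its provider per supported_dataset_database_handlers,
# e.g. the "kuzu" handler requires graph_database_provider == "kuzu".
os.environ["GRAPH_DATABASE_PROVIDER"] = "kuzu"
os.environ["GRAPH_DATASET_DATABASE_HANDLER"] = "kuzu"
os.environ["VECTOR_DB_PROVIDER"] = "lancedb"
os.environ["VECTOR_DATASET_DATABASE_HANDLER"] = "lancedb"

# "true" defers to multi_user_support_possible(), which now raises
# EnvironmentError for unsupported handler/provider combinations.
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "true"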
|
||||
|
||||
|
||||
|
|
@ -76,6 +109,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
|
||||
# To ensure permissions are enforced properly all datasets will have their own databases
|
||||
dataset_database = await get_or_create_dataset_database(dataset, user)
|
||||
# Ensure that all connection info is resolved properly
|
||||
dataset_database = await resolve_dataset_database_connection_info(dataset_database)
|
||||
|
||||
base_config = get_base_config()
|
||||
data_root_directory = os.path.join(
|
||||
|
|
@ -86,6 +121,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
)
|
||||
|
||||
# Set vector and graph database configuration based on dataset database information
|
||||
# TODO: Add better handling of vector and graph config across Cognee.
|
||||
# LRU_CACHE takes into account order of inputs, if order of inputs is changed it will be registered as a new DB adapter
|
||||
vector_config = {
|
||||
"vector_db_provider": dataset_database.vector_database_provider,
|
||||
"vector_db_url": dataset_database.vector_database_url,
|
||||
|
|
@ -101,6 +138,14 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
"graph_file_path": os.path.join(
|
||||
databases_directory_path, dataset_database.graph_database_name
|
||||
),
|
||||
"graph_database_username": dataset_database.graph_database_connection_info.get(
|
||||
"graph_database_username", ""
|
||||
),
|
||||
"graph_database_password": dataset_database.graph_database_connection_info.get(
|
||||
"graph_database_password", ""
|
||||
),
|
||||
"graph_dataset_database_handler": "",
|
||||
"graph_database_port": "",
|
||||
}
|
||||
|
||||
storage_config = {
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from cognee.modules.users.models import User
|
|||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.tasks.documents import (
|
||||
check_permissions_on_dataset,
|
||||
classify_documents,
|
||||
extract_chunks_from_documents,
|
||||
)
|
||||
|
|
@ -31,7 +30,6 @@ async def get_cascade_graph_tasks(
|
|||
cognee_config = get_cognify_config()
|
||||
default_tasks = [
|
||||
Task(classify_documents),
|
||||
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
|
||||
Task(
|
||||
extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()
|
||||
), # Extract text chunks based on the document type.
|
||||
|
|
|
|||
|
|
@ -30,8 +30,8 @@ async def get_no_summary_tasks(
|
|||
ontology_file_path=None,
|
||||
) -> List[Task]:
|
||||
"""Returns default tasks without summarization tasks."""
|
||||
# Get base tasks (0=classify, 1=check_permissions, 2=extract_chunks)
|
||||
base_tasks = await get_default_tasks_by_indices([0, 1, 2], chunk_size, chunker)
|
||||
# Get base tasks (0=classify, 1=extract_chunks)
|
||||
base_tasks = await get_default_tasks_by_indices([0, 1], chunk_size, chunker)
|
||||
|
||||
ontology_adapter = RDFLibOntologyResolver(ontology_file=ontology_file_path)
|
||||
|
||||
|
|
@ -51,8 +51,8 @@ async def get_just_chunks_tasks(
|
|||
chunk_size: int = None, chunker=TextChunker, user=None
|
||||
) -> List[Task]:
|
||||
"""Returns default tasks with only chunk extraction and data points addition."""
|
||||
# Get base tasks (0=classify, 1=check_permissions, 2=extract_chunks)
|
||||
base_tasks = await get_default_tasks_by_indices([0, 1, 2], chunk_size, chunker)
|
||||
# Get base tasks (0=classify, 1=extract_chunks)
|
||||
base_tasks = await get_default_tasks_by_indices([0, 1], chunk_size, chunker)
|
||||
|
||||
add_data_points_task = Task(add_data_points, task_config={"batch_size": 10})
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
from .dataset_database_handler_interface import DatasetDatabaseHandlerInterface
|
||||
from .supported_dataset_database_handlers import supported_dataset_database_handlers
|
||||
from .use_dataset_database_handler import use_dataset_database_handler
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from cognee.modules.users.models.User import User
|
||||
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
|
||||
|
||||
|
||||
class DatasetDatabaseHandlerInterface(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||
"""
|
||||
Return a dictionary with database connection/resolution info for a graph or vector database for the given dataset.
|
||||
The implementation may also handle deploying the actual database if needed, but this is not required.
|
||||
Providing connection info alone is sufficient; this info will be used when connecting to the given dataset in the future.
|
||||
Needed for Cognee multi-tenant/multi-user and backend access control support.
|
||||
|
||||
The dictionary returned from this function will be used to create a DatasetDatabase row in the relational database,
|
||||
from which the internal mapping of dataset -> database connection info is derived.
|
||||
|
||||
The returned dictionary is stored verbatim in the relational database and is later passed to
|
||||
resolve_dataset_connection_info() at connection time. For safe credential handling, prefer
|
||||
returning only references to secrets or role identifiers, not plaintext credentials.
|
||||
|
||||
Each dataset must map to a unique graph or vector database when backend access control is enabled, to keep each dataset's data separate.
|
||||
|
||||
Args:
|
||||
dataset_id: UUID of the dataset if needed by the database creation logic
|
||||
user: User object if needed by the database creation logic
|
||||
Returns:
|
||||
dict: Connection info for the created graph or vector database instance.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def resolve_dataset_connection_info(
|
||||
cls, dataset_database: DatasetDatabase
|
||||
) -> DatasetDatabase:
|
||||
"""
|
||||
Resolve runtime connection details for a dataset’s backing graph/vector database.
|
||||
This method is intended to be overridden to implement custom logic for resolving connection info.
|
||||
|
||||
This method is invoked right before the application opens a connection for a given dataset.
|
||||
It receives the DatasetDatabase row that was persisted when create_dataset() ran and must
|
||||
return a modified instance of DatasetDatabase with concrete connection parameters that the client/driver can use.
|
||||
Do not update these new DatasetDatabase values in the relational database to avoid storing secure credentials.
|
||||
|
||||
When separate graph and vector database handlers are configured, each handler should implement its own logic for resolving
|
||||
connection info and change only the parameters related to its own database. The resolution functions will then
|
||||
be called one after another, each receiving the DatasetDatabase value updated by the previous one as input.
|
||||
|
||||
Typical behavior:
|
||||
- If the DatasetDatabase row already contains raw connection fields (e.g., host/port/db/user/password
|
||||
or api_url/api_key), return them as-is.
|
||||
- If the row stores only references (e.g., secret IDs, vault paths, cloud resource ARNs/IDs, IAM
|
||||
roles, SSO tokens), resolve those references by calling the appropriate secret manager or provider
|
||||
API to obtain short-lived credentials and assemble the final connection DatasetDatabase object.
|
||||
- Do not persist any resolved or decrypted secrets back to the relational database. Return them only
|
||||
to the caller.
|
||||
|
||||
Args:
|
||||
dataset_database: DatasetDatabase row from the relational database
|
||||
Returns:
|
||||
DatasetDatabase: Updated instance with resolved connection info
|
||||
"""
|
||||
return dataset_database
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def delete_dataset(cls, dataset_database: DatasetDatabase) -> None:
|
||||
"""
|
||||
Delete the graph or vector database for the given dataset.
|
||||
The implementation should delete the actual database, or ask the appropriate service to delete the database or mark it as no longer needed for the given dataset.
|
||||
Needed to maintain per-dataset databases for Cognee multi-tenant/multi-user and backend access control support.
|
||||
|
||||
Args:
|
||||
dataset_database: DatasetDatabase row containing connection/resolution info for the graph or vector database to delete.
|
||||
"""
|
||||
pass
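To make the contract concrete, a minimal hypothetical implementation; the class name, provider string, and returned fields are illustrative, modeled on the Kuzu handler further below:

from typing import Optional
from uuid import UUID

from cognee.infrastructure.databases.dataset_database_handler import (
    DatasetDatabaseHandlerInterface,
)
from cognee.modules.users.models.User import User
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase


class ExampleDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """Hypothetical handler: one named database per dataset, no plaintext secrets stored."""

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        # Store only references; credentials are resolved at connection time
        # by resolve_dataset_connection_info.
        return {
            "graph_database_name": f"{dataset_id}",
            "graph_database_provider": "example",
            "graph_dataset_database_handler": "example",
        }

    @classmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase) -> None:
        # Drop the per-dataset database, or mark it unused with the backing service.
        ...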
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
from cognee.infrastructure.databases.graph.neo4j_driver.Neo4jAuraDevDatasetDatabaseHandler import (
|
||||
Neo4jAuraDevDatasetDatabaseHandler,
|
||||
)
|
||||
from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandler import (
|
||||
LanceDBDatasetDatabaseHandler,
|
||||
)
|
||||
from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import (
|
||||
KuzuDatasetDatabaseHandler,
|
||||
)
|
||||
|
||||
supported_dataset_database_handlers = {
|
||||
"neo4j_aura_dev": {
|
||||
"handler_instance": Neo4jAuraDevDatasetDatabaseHandler,
|
||||
"handler_provider": "neo4j",
|
||||
},
|
||||
"lancedb": {"handler_instance": LanceDBDatasetDatabaseHandler, "handler_provider": "lancedb"},
|
||||
"kuzu": {"handler_instance": KuzuDatasetDatabaseHandler, "handler_provider": "kuzu"},
|
||||
}
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
from .supported_dataset_database_handlers import supported_dataset_database_handlers
|
||||
|
||||
|
||||
def use_dataset_database_handler(
|
||||
dataset_database_handler_name, dataset_database_handler, dataset_database_provider
|
||||
):
|
||||
supported_dataset_database_handlers[dataset_database_handler_name] = {
|
||||
"handler_instance": dataset_database_handler,
|
||||
"handler_provider": dataset_database_provider,
|
||||
}
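Registering a custom handler is then a one-liner. This sketch reuses the hypothetical ExampleDatasetDatabaseHandler from above; the provider string must match the configured database provider per the checks in multi_user_support_possible():

from cognee.infrastructure.databases.dataset_database_handler import (
    use_dataset_database_handler,
)

use_dataset_database_handler("example", ExampleDatasetDatabaseHandler, "example")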
|
||||
|
|
@ -47,6 +47,7 @@ class GraphConfig(BaseSettings):
|
|||
graph_filename: str = ""
|
||||
graph_model: object = KnowledgeGraph
|
||||
graph_topology: object = KnowledgeGraph
|
||||
graph_dataset_database_handler: str = "kuzu"
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow", populate_by_name=True)
|
||||
|
||||
# Model validator updates graph_filename and path dynamically after class creation based on current database provider
|
||||
|
|
@ -97,6 +98,7 @@ class GraphConfig(BaseSettings):
|
|||
"graph_model": self.graph_model,
|
||||
"graph_topology": self.graph_topology,
|
||||
"model_config": self.model_config,
|
||||
"graph_dataset_database_handler": self.graph_dataset_database_handler,
|
||||
}
|
||||
|
||||
def to_hashable_dict(self) -> dict:
|
||||
|
|
@ -121,6 +123,7 @@ class GraphConfig(BaseSettings):
|
|||
"graph_database_port": self.graph_database_port,
|
||||
"graph_database_key": self.graph_database_key,
|
||||
"graph_file_path": self.graph_file_path,
|
||||
"graph_dataset_database_handler": self.graph_dataset_database_handler,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ def create_graph_engine(
|
|||
graph_database_password="",
|
||||
graph_database_port="",
|
||||
graph_database_key="",
|
||||
graph_dataset_database_handler="",
|
||||
):
|
||||
"""
|
||||
Create a graph engine based on the specified provider type.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,81 @@
|
|||
import os
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.models import DatasetDatabase
|
||||
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||
|
||||
|
||||
class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||
"""
|
||||
Handler for interacting with Kuzu Dataset databases.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||
"""
|
||||
Create a new Kuzu instance for the dataset. Return connection info that will be mapped to the dataset.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset UUID
|
||||
user: User object who owns the dataset and is making the request
|
||||
|
||||
Returns:
|
||||
dict: Connection details for the created Kuzu instance
|
||||
|
||||
"""
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||
|
||||
graph_config = get_graph_config()
|
||||
|
||||
if graph_config.graph_database_provider != "kuzu":
|
||||
raise ValueError(
|
||||
"KuzuDatasetDatabaseHandler can only be used with Kuzu graph database provider."
|
||||
)
|
||||
|
||||
graph_db_name = f"{dataset_id}.pkl"
|
||||
graph_db_url = graph_config.graph_database_url
|
||||
graph_db_key = graph_config.graph_database_key
|
||||
graph_db_username = graph_config.graph_database_username
|
||||
graph_db_password = graph_config.graph_database_password
|
||||
|
||||
return {
|
||||
"graph_database_name": graph_db_name,
|
||||
"graph_database_url": graph_db_url,
|
||||
"graph_database_provider": graph_config.graph_database_provider,
|
||||
"graph_database_key": graph_db_key,
|
||||
"graph_dataset_database_handler": "kuzu",
|
||||
"graph_database_connection_info": {
|
||||
"graph_database_username": graph_db_username,
|
||||
"graph_database_password": graph_db_password,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def delete_dataset(cls, dataset_database: DatasetDatabase):
|
||||
base_config = get_base_config()
|
||||
databases_directory_path = os.path.join(
|
||||
base_config.system_root_directory, "databases", str(dataset_database.owner_id)
|
||||
)
|
||||
graph_file_path = os.path.join(
|
||||
databases_directory_path, dataset_database.graph_database_name
|
||||
)
|
||||
graph_engine = create_graph_engine(
|
||||
graph_database_provider=dataset_database.graph_database_provider,
|
||||
graph_database_url=dataset_database.graph_database_url,
|
||||
graph_database_name=dataset_database.graph_database_name,
|
||||
graph_database_key=dataset_database.graph_database_key,
|
||||
graph_file_path=graph_file_path,
|
||||
graph_database_username=dataset_database.graph_database_connection_info.get(
|
||||
"graph_database_username", ""
|
||||
),
|
||||
graph_database_password=dataset_database.graph_database_connection_info.get(
|
||||
"graph_database_password", ""
|
||||
),
|
||||
graph_dataset_database_handler="",
|
||||
graph_database_port="",
|
||||
)
|
||||
await graph_engine.delete_graph()
|
||||
|
|
@ -2005,3 +2005,134 @@ class KuzuAdapter(GraphDBInterface):
|
|||
time_ids_list = [item[0] for item in time_nodes]
|
||||
|
||||
return ", ".join(f"'{uid}'" for uid in time_ids_list)
|
||||
|
||||
async def get_triplets_batch(self, offset: int, limit: int) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Retrieve a batch of triplets (start_node, relationship, end_node) from the graph.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
- offset (int): Number of triplets to skip before returning results.
|
||||
- limit (int): Maximum number of triplets to return.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
- list[dict[str, Any]]: A list of triplets, where each triplet is a dictionary
|
||||
with keys: 'start_node', 'relationship_properties', 'end_node'.
|
||||
|
||||
Raises:
|
||||
-------
|
||||
- ValueError: If offset or limit are negative.
|
||||
- Exception: Re-raises any exceptions from query execution.
|
||||
"""
|
||||
if offset < 0:
|
||||
raise ValueError(f"Offset must be non-negative, got {offset}")
|
||||
if limit < 0:
|
||||
raise ValueError(f"Limit must be non-negative, got {limit}")
|
||||
|
||||
query = """
|
||||
MATCH (start_node:Node)-[relationship:EDGE]->(end_node:Node)
|
||||
RETURN {
|
||||
start_node: {
|
||||
id: start_node.id,
|
||||
name: start_node.name,
|
||||
type: start_node.type,
|
||||
properties: start_node.properties
|
||||
},
|
||||
relationship_properties: {
|
||||
relationship_name: relationship.relationship_name,
|
||||
properties: relationship.properties
|
||||
},
|
||||
end_node: {
|
||||
id: end_node.id,
|
||||
name: end_node.name,
|
||||
type: end_node.type,
|
||||
properties: end_node.properties
|
||||
}
|
||||
} AS triplet
|
||||
SKIP $offset LIMIT $limit
|
||||
"""
|
||||
|
||||
try:
|
||||
results = await self.query(query, {"offset": offset, "limit": limit})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute triplet query: {str(e)}")
|
||||
logger.error(f"Query: {query}")
|
||||
logger.error(f"Parameters: offset={offset}, limit={limit}")
|
||||
raise
|
||||
|
||||
triplets = []
|
||||
for idx, row in enumerate(results):
|
||||
try:
|
||||
if not row or len(row) == 0:
|
||||
logger.warning(f"Skipping empty row at index {idx} in triplet batch")
|
||||
continue
|
||||
|
||||
if not isinstance(row[0], dict):
|
||||
logger.warning(
|
||||
f"Skipping invalid row at index {idx}: expected dict, got {type(row[0])}"
|
||||
)
|
||||
continue
|
||||
|
||||
triplet = row[0]
|
||||
|
||||
if "start_node" not in triplet:
|
||||
logger.warning(f"Skipping triplet at index {idx}: missing 'start_node' key")
|
||||
continue
|
||||
|
||||
if not isinstance(triplet["start_node"], dict):
|
||||
logger.warning(f"Skipping triplet at index {idx}: 'start_node' is not a dict")
|
||||
continue
|
||||
|
||||
triplet["start_node"] = self._parse_node_properties(triplet["start_node"].copy())
|
||||
|
||||
if "relationship_properties" not in triplet:
|
||||
logger.warning(
|
||||
f"Skipping triplet at index {idx}: missing 'relationship_properties' key"
|
||||
)
|
||||
continue
|
||||
|
||||
if not isinstance(triplet["relationship_properties"], dict):
|
||||
logger.warning(
|
||||
f"Skipping triplet at index {idx}: 'relationship_properties' is not a dict"
|
||||
)
|
||||
continue
|
||||
|
||||
rel_props = triplet["relationship_properties"].copy()
|
||||
relationship_name = rel_props.get("relationship_name") or ""
|
||||
|
||||
if rel_props.get("properties"):
|
||||
try:
|
||||
parsed_props = json.loads(rel_props["properties"])
|
||||
if isinstance(parsed_props, dict):
|
||||
rel_props.update(parsed_props)
|
||||
del rel_props["properties"]
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed relationship properties is not a dict for triplet at index {idx}"
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.warning(
|
||||
f"Failed to parse relationship properties JSON for triplet at index {idx}: {e}"
|
||||
)
|
||||
|
||||
rel_props["relationship_name"] = relationship_name
|
||||
triplet["relationship_properties"] = rel_props
|
||||
|
||||
if "end_node" not in triplet:
|
||||
logger.warning(f"Skipping triplet at index {idx}: missing 'end_node' key")
|
||||
continue
|
||||
|
||||
if not isinstance(triplet["end_node"], dict):
|
||||
logger.warning(f"Skipping triplet at index {idx}: 'end_node' is not a dict")
|
||||
continue
|
||||
|
||||
triplet["end_node"] = self._parse_node_properties(triplet["end_node"].copy())
|
||||
|
||||
triplets.append(triplet)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing triplet at index {idx}: {e}", exc_info=True)
|
||||
continue
|
||||
|
||||
return triplets
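A short pagination sketch over the new method, to be run inside an async function with an initialized KuzuAdapter (here named graph_engine); the batch size is illustrative:

offset, limit = 0, 500
while True:
    batch = await graph_engine.get_triplets_batch(offset, limit)
    if not batch:
        break
    for triplet in batch:
        print(
            triplet["start_node"].get("name"),
            triplet["relationship_properties"].get("relationship_name"),
            triplet["end_node"].get("name"),
        )
    offset += limit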
|
||||
|
|
|
|||
|
|
@ -0,0 +1,168 @@
|
|||
import os
|
||||
import asyncio
|
||||
import requests
|
||||
import base64
|
||||
import hashlib
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_config
|
||||
from cognee.modules.users.models import User, DatasetDatabase
|
||||
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||
|
||||
|
||||
class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||
"""
|
||||
Handler for a quick development PoC integration of Cognee multi-user and permission mode with Neo4j Aura databases.
|
||||
This handler creates a new Neo4j Aura instance for each Cognee dataset created.
|
||||
|
||||
Improvements needed to be production ready:
|
||||
- Secret management for client credentials: currently secrets are encrypted and stored in the Cognee relational database;
|
||||
a secret manager or a similar system should be used instead.
|
||||
|
||||
Quality of life improvements:
|
||||
- Allow configuration of different Neo4j Aura plans and regions.
|
||||
- Requests should be made async; currently the blocking requests library is used.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||
"""
|
||||
Create a new Neo4j Aura instance for the dataset. Return connection info that will be mapped to the dataset.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset UUID
|
||||
user: User object who owns the dataset and is making the request
|
||||
|
||||
Returns:
|
||||
dict: Connection details for the created Neo4j instance
|
||||
|
||||
"""
|
||||
graph_config = get_graph_config()
|
||||
|
||||
if graph_config.graph_database_provider != "neo4j":
|
||||
raise ValueError(
|
||||
"Neo4jAuraDevDatasetDatabaseHandler can only be used with Neo4j graph database provider."
|
||||
)
|
||||
|
||||
graph_db_name = f"{dataset_id}"
|
||||
|
||||
# Client credentials and encryption
|
||||
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
|
||||
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
|
||||
tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
|
||||
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
|
||||
encryption_key = base64.urlsafe_b64encode(
|
||||
hashlib.sha256(encryption_env_key.encode()).digest()
|
||||
)
|
||||
cipher = Fernet(encryption_key)
|
||||
|
||||
if client_id is None or client_secret is None or tenant_id is None:
|
||||
raise ValueError(
|
||||
"NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling."
|
||||
)
|
||||
|
||||
# Make the request with HTTP Basic Auth
|
||||
def get_aura_token(client_id: str, client_secret: str) -> dict:
|
||||
url = "https://api.neo4j.io/oauth/token"
|
||||
data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded
|
||||
|
||||
resp = requests.post(url, data=data, auth=(client_id, client_secret))
|
||||
resp.raise_for_status() # raises if the request failed
|
||||
return resp.json()
|
||||
|
||||
resp = get_aura_token(client_id, client_secret)
|
||||
|
||||
url = "https://api.neo4j.io/v1/instances"
|
||||
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
"Authorization": f"Bearer {resp['access_token']}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# TODO: Maybe we can allow **kwargs parameter forwarding for cases like these
|
||||
# Too allow different configurations between datasets
|
||||
payload = {
|
||||
"version": "5",
|
||||
"region": "europe-west1",
|
||||
"memory": "1GB",
|
||||
"name": graph_db_name[
|
||||
0:29
|
||||
], # TODO: Find better name to name Neo4j instance within 30 character limit
|
||||
"type": "professional-db",
|
||||
"tenant_id": tenant_id,
|
||||
"cloud_provider": "gcp",
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
|
||||
graph_db_name = "neo4j" # Has to be 'neo4j' for Aura
|
||||
graph_db_url = response.json()["data"]["connection_url"]
|
||||
graph_db_key = resp["access_token"]
|
||||
graph_db_username = response.json()["data"]["username"]
|
||||
graph_db_password = response.json()["data"]["password"]
|
||||
|
||||
async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
|
||||
# Poll until the instance is running
|
||||
status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
|
||||
status = ""
|
||||
for attempt in range(30): # Try for up to ~5 minutes
|
||||
status_resp = requests.get(
|
||||
status_url, headers=headers
|
||||
) # TODO: Use async requests with httpx
|
||||
status = status_resp.json()["data"]["status"]
|
||||
if status.lower() == "running":
|
||||
return
|
||||
await asyncio.sleep(10)
|
||||
raise TimeoutError(
|
||||
f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}"
|
||||
)
|
||||
|
||||
instance_id = response.json()["data"]["id"]
|
||||
await _wait_for_neo4j_instance_provisioning(instance_id, headers)
|
||||
|
||||
encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
|
||||
encrypted_db_password_string = encrypted_db_password_bytes.decode()
|
||||
|
||||
return {
|
||||
"graph_database_name": graph_db_name,
|
||||
"graph_database_url": graph_db_url,
|
||||
"graph_database_provider": "neo4j",
|
||||
"graph_database_key": graph_db_key,
|
||||
"graph_dataset_database_handler": "neo4j_aura_dev",
|
||||
"graph_database_connection_info": {
|
||||
"graph_database_username": graph_db_username,
|
||||
"graph_database_password": encrypted_db_password_string,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def resolve_dataset_connection_info(
|
||||
cls, dataset_database: DatasetDatabase
|
||||
) -> DatasetDatabase:
|
||||
"""
|
||||
Resolve and decrypt connection info for the Neo4j dataset database.
|
||||
In this case, decrypt the password stored in the database.
|
||||
|
||||
Args:
|
||||
dataset_database: DatasetDatabase instance containing encrypted connection info.
|
||||
"""
|
||||
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
|
||||
encryption_key = base64.urlsafe_b64encode(
|
||||
hashlib.sha256(encryption_env_key.encode()).digest()
|
||||
)
|
||||
cipher = Fernet(encryption_key)
|
||||
graph_db_password = cipher.decrypt(
|
||||
dataset_database.graph_database_connection_info["graph_database_password"].encode()
|
||||
).decode()
|
||||
|
||||
dataset_database.graph_database_connection_info["graph_database_password"] = (
|
||||
graph_db_password
|
||||
)
|
||||
return dataset_database
|
||||
|
||||
@classmethod
|
||||
async def delete_dataset(cls, dataset_database: DatasetDatabase):
|
||||
pass
|
||||
|
|
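Reviewer note on the encryption used above: the handler derives a Fernet key by hashing the NEO4J_ENCRYPTION_KEY environment variable with SHA-256 and url-safe base64-encoding the 32-byte digest, which is exactly what Fernet expects. A minimal round-trip sketch of the same derivation (the passphrase and plaintext are illustrative):

```python
import base64
import hashlib

from cryptography.fernet import Fernet

# Derive a 32-byte urlsafe-base64 key from an arbitrary passphrase,
# mirroring the handler's key derivation above.
passphrase = "test_key"  # illustrative; the handler reads NEO4J_ENCRYPTION_KEY
key = base64.urlsafe_b64encode(hashlib.sha256(passphrase.encode()).digest())
cipher = Fernet(key)

token = cipher.encrypt(b"aura-password")  # stored as a string in the relational DB
assert cipher.decrypt(token) == b"aura-password"
```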
@@ -8,7 +8,7 @@ from neo4j import AsyncSession
from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError
from contextlib import asynccontextmanager
-from typing import Optional, Any, List, Dict, Type, Tuple
+from typing import Optional, Any, List, Dict, Type, Tuple, Coroutine

from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
@@ -1527,3 +1527,25 @@ class Neo4jAdapter(GraphDBInterface):
        time_ids_list = [item["id"] for item in time_nodes if "id" in item]

        return ", ".join(f"'{uid}'" for uid in time_ids_list)

    async def get_triplets_batch(self, offset: int, limit: int) -> list[dict[str, Any]]:
        """
        Retrieve a batch of triplets (start_node, relationship, end_node) from the graph.

        Parameters:
        -----------
        - offset (int): Number of triplets to skip before returning results.
        - limit (int): Maximum number of triplets to return.

        Returns:
        --------
        - list[dict[str, Any]]: A list of triplets.
        """
        query = f"""
        MATCH (start_node:`{BASE_LABEL}`)-[relationship]->(end_node:`{BASE_LABEL}`)
        RETURN start_node, properties(relationship) AS relationship_properties, end_node
        SKIP $offset LIMIT $limit
        """
        results = await self.query(query, {"offset": offset, "limit": limit})

        return results
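The new `get_triplets_batch` pairs naturally with the triplet-parsing loop shown earlier. A hedged sketch of how a caller might page through the whole graph (the `adapter` variable and page size are illustrative):

```python
# Illustrative pagination over all triplets; assumes `adapter` is an
# initialized Neo4jAdapter instance.
async def iterate_all_triplets(adapter, batch_size: int = 500):
    offset = 0
    while True:
        batch = await adapter.get_triplets_batch(offset=offset, limit=batch_size)
        if not batch:
            break  # no more triplets to fetch
        for triplet in batch:
            yield triplet
        offset += batch_size
```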
@@ -1 +1,4 @@
from .get_or_create_dataset_database import get_or_create_dataset_database
from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info
from .get_graph_dataset_database_handler import get_graph_dataset_database_handler
from .get_vector_dataset_database_handler import get_vector_dataset_database_handler
@@ -0,0 +1,10 @@
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase


def get_graph_dataset_database_handler(dataset_database: DatasetDatabase) -> dict:
    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
        supported_dataset_database_handlers,
    )

    handler = supported_dataset_database_handlers[dataset_database.graph_dataset_database_handler]
    return handler
@@ -1,11 +1,9 @@
import os
from uuid import UUID
-from typing import Union
+from typing import Union, Optional

from sqlalchemy import select
from sqlalchemy.exc import IntegrityError

from cognee.base_config import get_base_config
from cognee.modules.data.methods import create_dataset
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.vector import get_vectordb_config
@@ -15,6 +13,53 @@ from cognee.modules.users.models import DatasetDatabase
from cognee.modules.users.models import User


async def _get_vector_db_info(dataset_id: UUID, user: User) -> dict:
    vector_config = get_vectordb_config()

    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
        supported_dataset_database_handlers,
    )

    handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler]
    return await handler["handler_instance"].create_dataset(dataset_id, user)


async def _get_graph_db_info(dataset_id: UUID, user: User) -> dict:
    graph_config = get_graph_config()

    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
        supported_dataset_database_handlers,
    )

    handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler]
    return await handler["handler_instance"].create_dataset(dataset_id, user)


async def _existing_dataset_database(
    dataset_id: UUID,
    user: User,
) -> Optional[DatasetDatabase]:
    """
    Check if a DatasetDatabase row already exists for the given owner + dataset.
    Return None if it doesn't exist, return the row if it does.

    Args:
        dataset_id: UUID of the dataset.
        user: Owner of the dataset.

    Returns:
        DatasetDatabase or None
    """
    db_engine = get_relational_engine()

    async with db_engine.get_async_session() as session:
        stmt = select(DatasetDatabase).where(
            DatasetDatabase.owner_id == user.id,
            DatasetDatabase.dataset_id == dataset_id,
        )
        existing: DatasetDatabase = await session.scalar(stmt)
        return existing


async def get_or_create_dataset_database(
    dataset: Union[str, UUID],
    user: User,
@@ -25,6 +70,8 @@ async def get_or_create_dataset_database(
    • If the row already exists, it is fetched and returned.
    • Otherwise a new one is created atomically and returned.

    The DatasetDatabase row contains connection and provider info for the vector and graph databases.

    Parameters
    ----------
    user : User
@@ -36,59 +83,26 @@ async def get_or_create_dataset_database(

    dataset_id = await get_unique_dataset_id(dataset, user)

-    vector_config = get_vectordb_config()
-    graph_config = get_graph_config()
+    # If the dataset is given as a name, make sure the dataset is created first
+    if isinstance(dataset, str):
+        async with db_engine.get_async_session() as session:
+            await create_dataset(dataset, user, session)

-    # Note: for hybrid databases both graph and vector DB name have to be the same
-    if graph_config.graph_database_provider == "kuzu":
-        graph_db_name = f"{dataset_id}.pkl"
-    else:
-        graph_db_name = f"{dataset_id}"
+    # If the dataset database already exists, return it
+    existing_dataset_database = await _existing_dataset_database(dataset_id, user)
+    if existing_dataset_database:
+        return existing_dataset_database

-    if vector_config.vector_db_provider == "lancedb":
-        vector_db_name = f"{dataset_id}.lance.db"
-    else:
-        vector_db_name = f"{dataset_id}"
-
-    base_config = get_base_config()
-    databases_directory_path = os.path.join(
-        base_config.system_root_directory, "databases", str(user.id)
-    )
-
-    # Determine vector database URL
-    if vector_config.vector_db_provider == "lancedb":
-        vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name)
-    else:
-        vector_db_url = vector_config.vector_database_url
-
-    # Determine graph database URL
+    graph_config_dict = await _get_graph_db_info(dataset_id, user)
+    vector_config_dict = await _get_vector_db_info(dataset_id, user)

    async with db_engine.get_async_session() as session:
-        # Create dataset if it doesn't exist
-        if isinstance(dataset, str):
-            dataset = await create_dataset(dataset, user, session)
-
-        # Try to fetch an existing row first
-        stmt = select(DatasetDatabase).where(
-            DatasetDatabase.owner_id == user.id,
-            DatasetDatabase.dataset_id == dataset_id,
-        )
-        existing: DatasetDatabase = await session.scalar(stmt)
-        if existing:
-            return existing
-
        # If there are no existing rows, build a new row
        record = DatasetDatabase(
            owner_id=user.id,
            dataset_id=dataset_id,
-            vector_database_name=vector_db_name,
-            graph_database_name=graph_db_name,
-            vector_database_provider=vector_config.vector_db_provider,
-            graph_database_provider=graph_config.graph_database_provider,
-            vector_database_url=vector_db_url,
-            graph_database_url=graph_config.graph_database_url,
-            vector_database_key=vector_config.vector_db_key,
-            graph_database_key=graph_config.graph_database_key,
+            **graph_config_dict,  # Unpack graph db config
+            **vector_config_dict,  # Unpack vector db config
        )

        try:
@@ -0,0 +1,10 @@
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase


def get_vector_dataset_database_handler(dataset_database: DatasetDatabase) -> dict:
    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
        supported_dataset_database_handlers,
    )

    handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler]
    return handler
@@ -0,0 +1,30 @@
from cognee.infrastructure.databases.utils.get_graph_dataset_database_handler import (
    get_graph_dataset_database_handler,
)
from cognee.infrastructure.databases.utils.get_vector_dataset_database_handler import (
    get_vector_dataset_database_handler,
)
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase


async def resolve_dataset_database_connection_info(
    dataset_database: DatasetDatabase,
) -> DatasetDatabase:
    """
    Resolve the connection info for the given DatasetDatabase instance.
    Resolve both vector and graph database connection info and return the updated DatasetDatabase instance.

    Args:
        dataset_database: DatasetDatabase instance
    Returns:
        DatasetDatabase instance with resolved connection info
    """
    vector_dataset_database_handler = get_vector_dataset_database_handler(dataset_database)
    graph_dataset_database_handler = get_graph_dataset_database_handler(dataset_database)
    dataset_database = await vector_dataset_database_handler[
        "handler_instance"
    ].resolve_dataset_connection_info(dataset_database)
    dataset_database = await graph_dataset_database_handler[
        "handler_instance"
    ].resolve_dataset_connection_info(dataset_database)
    return dataset_database
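Taken together, the handler machinery turns a DatasetDatabase row back into usable credentials. A hedged end-to-end sketch of the lookup path, using the two functions defined in this PR (assumes an existing dataset UUID and user, inside an async context):

```python
# Illustrative flow: create-or-fetch the mapping row, then resolve
# (e.g. decrypt) its connection info before connecting.
dataset_database = await get_or_create_dataset_database(dataset_id, user)
dataset_database = await resolve_dataset_database_connection_info(dataset_database)

# For the Neo4j Aura dev handler above, this value is now decrypted.
password = dataset_database.graph_database_connection_info["graph_database_password"]
```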
@@ -28,6 +28,7 @@ class VectorConfig(BaseSettings):
    vector_db_name: str = ""
    vector_db_key: str = ""
    vector_db_provider: str = "lancedb"
    vector_dataset_database_handler: str = "lancedb"

    model_config = SettingsConfigDict(env_file=".env", extra="allow")
@@ -63,6 +64,7 @@ class VectorConfig(BaseSettings):
            "vector_db_name": self.vector_db_name,
            "vector_db_key": self.vector_db_key,
            "vector_db_provider": self.vector_db_provider,
            "vector_dataset_database_handler": self.vector_dataset_database_handler,
        }
@@ -12,6 +12,7 @@ def create_vector_engine(
    vector_db_name: str,
    vector_db_port: str = "",
    vector_db_key: str = "",
    vector_dataset_database_handler: str = "",
):
    """
    Create a vector database engine based on the specified provider.
@@ -17,6 +17,7 @@ from cognee.infrastructure.databases.exceptions import EmbeddingException
from cognee.infrastructure.llm.tokenizer.TikToken import (
    TikTokenTokenizer,
)
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager

litellm.set_verbose = False
logger = get_logger("FastembedEmbeddingEngine")
@@ -68,7 +69,7 @@ class FastembedEmbeddingEngine(EmbeddingEngine):

    @retry(
        stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
        before_sleep=before_sleep_log(logger, logging.DEBUG),
        reraise=True,
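The same retry tuning recurs across the adapters in this PR: tenacity's `wait_exponential_jitter(initial, max)` starts the exponential backoff at `initial` seconds and caps each wait at `max`, adding up to one second of random jitter per attempt by default, so moving from 2 to 8 makes the first retry noticeably more patient. A standalone sketch of the same policy:

```python
import logging

from tenacity import retry, stop_after_delay, wait_exponential_jitter, before_sleep_log

logger = logging.getLogger(__name__)


# Same shape as the decorators in this PR: exponential backoff starting at
# 8s, capped at 128s per wait, giving up once 128s have elapsed in total.
@retry(
    stop=stop_after_delay(128),
    wait=wait_exponential_jitter(8, 128),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
def flaky_call():
    ...  # illustrative body
```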
@@ -96,11 +97,12 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
        if self.mock:
            return [[0.0] * self.dimensions for _ in text]
        else:
-            embeddings = self.embedding_model.embed(
-                text,
-                batch_size=len(text),
-                parallel=None,
-            )
+            async with embedding_rate_limiter_context_manager():
+                embeddings = self.embedding_model.embed(
+                    text,
+                    batch_size=len(text),
+                    parallel=None,
+                )

        return list(embeddings)
@@ -25,6 +25,7 @@ from cognee.infrastructure.llm.tokenizer.Mistral import (
from cognee.infrastructure.llm.tokenizer.TikToken import (
    TikTokenTokenizer,
)
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager

litellm.set_verbose = False
logger = get_logger("LiteLLMEmbeddingEngine")
@@ -109,13 +110,14 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
            response = {"data": [{"embedding": [0.0] * self.dimensions} for _ in text]}
            return [data["embedding"] for data in response["data"]]
        else:
-            response = await litellm.aembedding(
-                model=self.model,
-                input=text,
-                api_key=self.api_key,
-                api_base=self.endpoint,
-                api_version=self.api_version,
-            )
+            async with embedding_rate_limiter_context_manager():
+                response = await litellm.aembedding(
+                    model=self.model,
+                    input=text,
+                    api_key=self.api_key,
+                    api_base=self.endpoint,
+                    api_version=self.api_version,
+                )

            return [data["embedding"] for data in response.data]
@@ -18,10 +18,7 @@ from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import Em
from cognee.infrastructure.llm.tokenizer.HuggingFace import (
    HuggingFaceTokenizer,
)
-from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
-    embedding_rate_limit_async,
-    embedding_sleep_and_retry_async,
-)
+from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
from cognee.shared.utils import create_secure_ssl_context

logger = get_logger("OllamaEmbeddingEngine")
@@ -101,7 +98,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):

    @retry(
        stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
        before_sleep=before_sleep_log(logger, logging.DEBUG),
        reraise=True,
@@ -120,11 +117,15 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
        ssl_context = create_secure_ssl_context()
        connector = aiohttp.TCPConnector(ssl=ssl_context)
        async with aiohttp.ClientSession(connector=connector) as session:
-            async with session.post(
-                self.endpoint, json=payload, headers=headers, timeout=60.0
-            ) as response:
-                data = await response.json()
-                return data["embeddings"][0]
+            async with embedding_rate_limiter_context_manager():
+                async with session.post(
+                    self.endpoint, json=payload, headers=headers, timeout=60.0
+                ) as response:
+                    data = await response.json()
+                    if "embeddings" in data:
+                        return data["embeddings"][0]
+                    else:
+                        return data["data"][0]["embedding"]

    def get_vector_size(self) -> int:
        """
@@ -1,544 +0,0 @@
import threading
import logging
import functools
import os
import time
import asyncio
import random
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.config import get_llm_config


logger = get_logger()

# Common error patterns that indicate rate limiting
RATE_LIMIT_ERROR_PATTERNS = [
    "rate limit",
    "rate_limit",
    "ratelimit",
    "too many requests",
    "retry after",
    "capacity",
    "quota",
    "limit exceeded",
    "tps limit exceeded",
    "request limit exceeded",
    "maximum requests",
    "exceeded your current quota",
    "throttled",
    "throttling",
]

# Default retry settings
DEFAULT_MAX_RETRIES = 5
DEFAULT_INITIAL_BACKOFF = 1.0  # seconds
DEFAULT_BACKOFF_FACTOR = 2.0  # exponential backoff multiplier
DEFAULT_JITTER = 0.1  # 10% jitter to avoid thundering herd


class EmbeddingRateLimiter:
    """
    Rate limiter for embedding API calls.

    This class implements a singleton pattern to ensure that rate limiting
    is consistent across all embedding requests. It uses the limits
    library with a moving window strategy to control request rates.

    The rate limiter uses the same configuration as the LLM API rate limiter
    but uses a separate key to track embedding API calls independently.

    Public Methods:
    - get_instance
    - reset_instance
    - hit_limit
    - wait_if_needed
    - async_wait_if_needed

    Instance Variables:
    - enabled
    - requests_limit
    - interval_seconds
    - request_times
    - lock
    """

    _instance = None
    lock = threading.Lock()

    @classmethod
    def get_instance(cls):
        """
        Retrieve the singleton instance of the EmbeddingRateLimiter.

        This method ensures that only one instance of the class exists and
        is thread-safe. It lazily initializes the instance if it doesn't
        already exist.

        Returns:
        --------

        The singleton instance of the EmbeddingRateLimiter class.
        """
        if cls._instance is None:
            with cls.lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    @classmethod
    def reset_instance(cls):
        """
        Reset the singleton instance of the EmbeddingRateLimiter.

        This method is thread-safe and sets the instance to None, allowing
        for a new instance to be created when requested again.
        """
        with cls.lock:
            cls._instance = None

    def __init__(self):
        config = get_llm_config()
        self.enabled = config.embedding_rate_limit_enabled
        self.requests_limit = config.embedding_rate_limit_requests
        self.interval_seconds = config.embedding_rate_limit_interval
        self.request_times = []
        self.lock = threading.Lock()

        logging.info(
            f"EmbeddingRateLimiter initialized: enabled={self.enabled}, "
            f"requests_limit={self.requests_limit}, interval_seconds={self.interval_seconds}"
        )

    def hit_limit(self) -> bool:
        """
        Check if the current request would exceed the rate limit.

        This method checks if the rate limiter is enabled and evaluates
        the number of requests made in the elapsed interval.

        Returns:
        --------

        - bool: True if the rate limit would be exceeded, otherwise False.
        """
        if not self.enabled:
            return False

        current_time = time.time()

        with self.lock:
            # Remove expired request times
            cutoff_time = current_time - self.interval_seconds
            self.request_times = [t for t in self.request_times if t > cutoff_time]

            # Check if adding a new request would exceed the limit
            if len(self.request_times) >= self.requests_limit:
                logger.info(
                    f"Rate limit hit: {len(self.request_times)} requests in the last {self.interval_seconds} seconds"
                )
                return True

            # Otherwise, we're under the limit
            return False

    def wait_if_needed(self) -> float:
        """
        Block until a request can be made without exceeding the rate limit.

        This method will wait if the current request would exceed the
        rate limit and returns the time waited in seconds.

        Returns:
        --------

        - float: Time waited in seconds before proceeding.
        """
        if not self.enabled:
            return 0

        wait_time = 0
        start_time = time.time()

        while self.hit_limit():
            time.sleep(0.5)  # Poll every 0.5 seconds
            wait_time = time.time() - start_time

        # Record this request
        with self.lock:
            self.request_times.append(time.time())

        return wait_time

    async def async_wait_if_needed(self) -> float:
        """
        Asynchronously wait until a request can be made without exceeding the rate limit.

        This method will wait if the current request would exceed the
        rate limit and returns the time waited in seconds.

        Returns:
        --------

        - float: Time waited in seconds before proceeding.
        """
        if not self.enabled:
            return 0

        wait_time = 0
        start_time = time.time()

        while self.hit_limit():
            await asyncio.sleep(0.5)  # Poll every 0.5 seconds
            wait_time = time.time() - start_time

        # Record this request
        with self.lock:
            self.request_times.append(time.time())

        return wait_time


def embedding_rate_limit_sync(func):
    """
    Apply rate limiting to a synchronous embedding function.

    Parameters:
    -----------

    - func: Function to decorate with rate limiting logic.

    Returns:
    --------

    Returns the decorated function that applies rate limiting.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        """
        Wrap the given function with rate limiting logic to control the embedding API usage.

        Checks if the rate limit has been exceeded before allowing the function to execute. If
        the limit is hit, it logs a warning and raises an EmbeddingException. Otherwise, it
        updates the request count and proceeds to call the original function.

        Parameters:
        -----------

        - *args: Variable length argument list for the wrapped function.
        - **kwargs: Keyword arguments for the wrapped function.

        Returns:
        --------

        Returns the result of the wrapped function if rate limiting conditions are met.
        """
        limiter = EmbeddingRateLimiter.get_instance()

        # Check if rate limiting is enabled and if we're at the limit
        if limiter.hit_limit():
            error_msg = "Embedding API rate limit exceeded"
            logger.warning(error_msg)

            # Create a custom embedding rate limit exception
            from cognee.infrastructure.databases.exceptions import EmbeddingException

            raise EmbeddingException(error_msg)

        # Add this request to the counter and proceed
        limiter.wait_if_needed()
        return func(*args, **kwargs)

    return wrapper


def embedding_rate_limit_async(func):
    """
    Decorator that applies rate limiting to an asynchronous embedding function.

    Parameters:
    -----------

    - func: Async function to decorate.

    Returns:
    --------

    Returns the decorated async function that applies rate limiting.
    """

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        """
        Handle function calls with embedding rate limiting.

        This asynchronous wrapper checks if the embedding API rate limit is exceeded before
        allowing the function to execute. If the limit is exceeded, it logs a warning and raises
        an EmbeddingException. If not, it waits as necessary and proceeds with the function
        call.

        Parameters:
        -----------

        - *args: Positional arguments passed to the wrapped function.
        - **kwargs: Keyword arguments passed to the wrapped function.

        Returns:
        --------

        Returns the result of the wrapped function after handling rate limiting.
        """
        limiter = EmbeddingRateLimiter.get_instance()

        # Check if rate limiting is enabled and if we're at the limit
        if limiter.hit_limit():
            error_msg = "Embedding API rate limit exceeded"
            logger.warning(error_msg)

            # Create a custom embedding rate limit exception
            from cognee.infrastructure.databases.exceptions import EmbeddingException

            raise EmbeddingException(error_msg)

        # Add this request to the counter and proceed
        await limiter.async_wait_if_needed()
        return await func(*args, **kwargs)

    return wrapper


def embedding_sleep_and_retry_sync(max_retries=5, base_backoff=1.0, jitter=0.5):
    """
    Add retry with exponential backoff for synchronous embedding functions.

    Parameters:
    -----------

    - max_retries: Maximum number of retries before giving up. (default 5)
    - base_backoff: Base backoff time in seconds for retry intervals. (default 1.0)
    - jitter: Jitter factor to randomize the backoff time to avoid collision. (default 0.5)

    Returns:
    --------

    A decorator that retries the wrapped function on rate limit errors, applying
    exponential backoff with jitter.
    """

    def decorator(func):
        """
        Wraps a function to apply retry logic on rate limit errors.

        Parameters:
        -----------

        - func: The function to be wrapped with retry logic.

        Returns:
        --------

        Returns the wrapped function with retry logic applied.
        """

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            """
            Retry the execution of a function with backoff on failure due to rate limit errors.

            This wrapper function will call the specified function and if it raises an exception, it
            will handle retries according to defined conditions. It will check the environment for a
            DISABLE_RETRIES flag to determine whether to retry or propagate errors immediately
            during tests. If the error is identified as a rate limit error, it will apply an
            exponential backoff strategy with jitter before retrying, up to a maximum number of
            retries. If the retries are exhausted, it raises the last encountered error.

            Parameters:
            -----------

            - *args: Positional arguments passed to the wrapped function.
            - **kwargs: Keyword arguments passed to the wrapped function.

            Returns:
            --------

            Returns the result of the wrapped function if successful; otherwise, raises the last
            error encountered after maximum retries are exhausted.
            """
            # If DISABLE_RETRIES is set, don't retry for testing purposes
            disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
                "true",
                "1",
                "yes",
            )

            retries = 0
            last_error = None

            while retries <= max_retries:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    # Check if this is a rate limit error
                    error_str = str(e).lower()
                    error_type = type(e).__name__
                    is_rate_limit = any(
                        pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS
                    )

                    if disable_retries:
                        # For testing, propagate the exception immediately
                        raise

                    if is_rate_limit and retries < max_retries:
                        # Calculate backoff with jitter
                        backoff = (
                            base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
                        )

                        logger.warning(
                            f"Embedding rate limit hit, retrying in {backoff:.2f}s "
                            f"(attempt {retries + 1}/{max_retries}): "
                            f"({error_str!r}, {error_type!r})"
                        )

                        time.sleep(backoff)
                        retries += 1
                        last_error = e
                    else:
                        # Not a rate limit error or max retries reached, raise
                        raise

            # If we exit the loop due to max retries, raise the last error
            if last_error:
                raise last_error

        return wrapper

    return decorator


def embedding_sleep_and_retry_async(max_retries=5, base_backoff=1.0, jitter=0.5):
    """
    Add retry logic with exponential backoff for asynchronous embedding functions.

    This decorator retries the wrapped asynchronous function upon encountering rate limit
    errors, utilizing exponential backoff with optional jitter to space out retry attempts.
    It allows for a maximum number of retries before giving up and raising the last error
    encountered.

    Parameters:
    -----------

    - max_retries: Maximum number of retries allowed before giving up. (default 5)
    - base_backoff: Base amount of time in seconds to wait before retrying after a rate
      limit error. (default 1.0)
    - jitter: Amount of randomness to add to the backoff duration to help mitigate burst
      issues on retries. (default 0.5)

    Returns:
    --------

    Returns a decorated asynchronous function that implements the retry logic on rate
    limit errors.
    """

    def decorator(func):
        """
        Handle retries for an async function with exponential backoff and jitter.

        Parameters:
        -----------

        - func: An asynchronous function to be wrapped with retry logic.

        Returns:
        --------

        Returns the wrapper function that manages the retry behavior for the wrapped async
        function.
        """

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            """
            Handle retries for an async function with exponential backoff and jitter.

            If the environment variable DISABLE_RETRIES is set to true, 1, or yes, the function will
            not retry on errors.
            It attempts to call the wrapped function until it succeeds or the maximum number of
            retries is reached. If an exception occurs, it checks if it's a rate limit error to
            determine if a retry is needed.

            Parameters:
            -----------

            - *args: Positional arguments passed to the wrapped function.
            - **kwargs: Keyword arguments passed to the wrapped function.

            Returns:
            --------

            Returns the result of the wrapped async function if successful; raises the last
            encountered error if all retries fail.
            """
            # If DISABLE_RETRIES is set, don't retry for testing purposes
            disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
                "true",
                "1",
                "yes",
            )

            retries = 0
            last_error = None

            while retries <= max_retries:
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    # Check if this is a rate limit error
                    error_str = str(e).lower()
                    error_type = type(e).__name__
                    is_rate_limit = any(
                        pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS
                    )

                    if disable_retries:
                        # For testing, propagate the exception immediately
                        raise

                    if is_rate_limit and retries < max_retries:
                        # Calculate backoff with jitter
                        backoff = (
                            base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
                        )

                        logger.warning(
                            f"Embedding rate limit hit, retrying in {backoff:.2f}s "
                            f"(attempt {retries + 1}/{max_retries}): "
                            f"({error_str!r}, {error_type!r})"
                        )

                        await asyncio.sleep(backoff)
                        retries += 1
                        last_error = e
                    else:
                        # Not a rate limit error or max retries reached, raise
                        raise

            # If we exit the loop due to max retries, raise the last error
            if last_error:
                raise last_error

        return wrapper

    return decorator
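This entire decorator-based embedding rate limiter is deleted by the PR; the call sites shown earlier now wrap requests in `embedding_rate_limiter_context_manager` from `cognee.shared.rate_limiting` instead. That module's internals are not part of this diff, so the following is only a minimal sketch of the pattern, assuming an asyncio-based concurrency cap:

```python
import asyncio
from contextlib import asynccontextmanager

# Hypothetical stand-in for cognee.shared.rate_limiting; the real
# implementation is not shown in this diff.
_semaphore = asyncio.Semaphore(10)  # assumed concurrency cap


@asynccontextmanager
async def embedding_rate_limiter_context_manager():
    # Acquire a slot before the embedding call, release it afterwards.
    async with _semaphore:
        yield
```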
@@ -193,6 +193,8 @@ class LanceDBAdapter(VectorDBInterface):
            for (data_point_index, data_point) in enumerate(data_points)
        ]

        lance_data_points = list({dp.id: dp for dp in lance_data_points}.values())

        async with self.VECTOR_DB_LOCK:
            await (
                collection.merge_insert("id")
@@ -0,0 +1,50 @@
import os
from uuid import UUID
from typing import Optional

from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
from cognee.modules.users.models import User
from cognee.modules.users.models import DatasetDatabase
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface


class LanceDBDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """
    Handler for interacting with LanceDB dataset databases.
    """

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        vector_config = get_vectordb_config()
        base_config = get_base_config()

        if vector_config.vector_db_provider != "lancedb":
            raise ValueError(
                "LanceDBDatasetDatabaseHandler can only be used with the LanceDB vector database provider."
            )

        databases_directory_path = os.path.join(
            base_config.system_root_directory, "databases", str(user.id)
        )

        vector_db_name = f"{dataset_id}.lance.db"

        return {
            "vector_database_provider": vector_config.vector_db_provider,
            "vector_database_url": os.path.join(databases_directory_path, vector_db_name),
            "vector_database_key": vector_config.vector_db_key,
            "vector_database_name": vector_db_name,
            "vector_dataset_database_handler": "lancedb",
        }

    @classmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase):
        vector_engine = create_vector_engine(
            vector_db_provider=dataset_database.vector_database_provider,
            vector_db_url=dataset_database.vector_database_url,
            vector_db_key=dataset_database.vector_database_key,
            vector_db_name=dataset_database.vector_database_name,
        )
        await vector_engine.prune()
@@ -2,6 +2,8 @@ from typing import List, Protocol, Optional, Union, Any
from abc import abstractmethod
from cognee.infrastructure.engine import DataPoint
from .models.PayloadSchema import PayloadSchema
from uuid import UUID
from cognee.modules.users.models import User


class VectorDBInterface(Protocol):
|
@ -217,3 +219,36 @@ class VectorDBInterface(Protocol):
|
|||
- Any: The schema object suitable for this vector database
|
||||
"""
|
||||
return model_type
|
||||
|
||||
@classmethod
|
||||
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||
"""
|
||||
Return a dictionary with connection info for a vector database for the given dataset.
|
||||
Function can auto handle deploying of the actual database if needed, but is not necessary.
|
||||
Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future.
|
||||
Needed for Cognee multi-tenant/multi-user and backend access control support.
|
||||
|
||||
Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database.
|
||||
From which internal mapping of dataset -> database connection info will be done.
|
||||
|
||||
Each dataset needs to map to a unique vector database when backend access control is enabled to facilitate a separation of concern for data.
|
||||
|
||||
Args:
|
||||
dataset_id: UUID of the dataset if needed by the database creation logic
|
||||
user: User object if needed by the database creation logic
|
||||
Returns:
|
||||
dict: Connection info for the created vector database instance.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def delete_dataset(self, dataset_id: UUID, user: User) -> None:
|
||||
"""
|
||||
Delete the vector database for the given dataset.
|
||||
Function should auto handle deleting of the actual database or send a request to the proper service to delete the database.
|
||||
Needed for maintaining a database for Cognee multi-tenant/multi-user and backend access control.
|
||||
|
||||
Args:
|
||||
dataset_id: UUID of the dataset
|
||||
user: User object
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
|
|
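The docstrings above spell out the contract a custom dataset-database handler has to satisfy. A minimal sketch of a custom implementation against the DatasetDatabaseHandlerInterface used elsewhere in this PR (all names in the body are illustrative, not a real provider):

```python
from typing import Optional
from uuid import UUID

from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
from cognee.modules.users.models import DatasetDatabase, User


class MyVectorDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """Illustrative handler that maps every dataset to one shared server."""

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        # Only connection info is returned; no database is deployed here.
        return {
            "vector_database_provider": "my_provider",  # assumed provider name
            "vector_database_name": f"{dataset_id}",
            "vector_database_url": "https://vector.example.com",  # illustrative
            "vector_database_key": "",
            "vector_dataset_database_handler": "my_provider",
        }

    @classmethod
    async def resolve_dataset_connection_info(
        cls, dataset_database: DatasetDatabase
    ) -> DatasetDatabase:
        # Nothing is encrypted in this sketch, so pass the row through.
        return dataset_database

    @classmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase):
        pass  # deletion is a no-op for the shared server in this sketch
```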
@@ -9,6 +9,8 @@ class S3Config(BaseSettings):
    aws_access_key_id: Optional[str] = None
    aws_secret_access_key: Optional[str] = None
    aws_session_token: Optional[str] = None
    aws_profile_name: Optional[str] = None
    aws_bedrock_runtime_endpoint: Optional[str] = None
    model_config = SettingsConfigDict(env_file=".env", extra="allow")
@@ -11,7 +11,7 @@ class LLMGateway:

    @staticmethod
    def acreate_structured_output(
-        text_input: str, system_prompt: str, response_model: Type[BaseModel]
+        text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
    ) -> Coroutine:
        llm_config = get_llm_config()
        if llm_config.structured_output_framework.upper() == "BAML":
@@ -31,7 +31,10 @@ class LLMGateway:

        llm_client = get_llm_client()
        return llm_client.acreate_structured_output(
-            text_input=text_input, system_prompt=system_prompt, response_model=response_model
+            text_input=text_input,
+            system_prompt=system_prompt,
+            response_model=response_model,
+            **kwargs,
        )

    @staticmethod
@@ -74,6 +74,41 @@ class LLMConfig(BaseSettings):

    model_config = SettingsConfigDict(env_file=".env", extra="allow")

    @model_validator(mode="after")
    def strip_quotes_from_strings(self) -> "LLMConfig":
        """
        Strip surrounding quotes from specific string fields that often come from
        environment variables with extra quotes (e.g., via Docker's --env-file).

        Only applies to known config keys where quotes are invalid or cause issues.
        """
        string_fields_to_strip = [
            "llm_api_key",
            "llm_endpoint",
            "llm_api_version",
            "baml_llm_api_key",
            "baml_llm_endpoint",
            "baml_llm_api_version",
            "fallback_api_key",
            "fallback_endpoint",
            "fallback_model",
            "llm_provider",
            "llm_model",
            "baml_llm_provider",
            "baml_llm_model",
        ]

        cls = self.__class__
        for field_name in string_fields_to_strip:
            if field_name not in cls.model_fields:
                continue
            value = getattr(self, field_name, None)
            if isinstance(value, str) and len(value) >= 2:
                if value[0] == value[-1] and value[0] in ("'", '"'):
                    setattr(self, field_name, value[1:-1])

        return self

    def model_post_init(self, __context) -> None:
        """Initialize the BAML registry after the model is created."""
        # Check if BAML is selected as structured output framework but not available
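To make the validator's effect concrete, here is a small sketch of the rule it implements (field values are illustrative; only one pair of matching surrounding quotes is removed):

```python
def strip_symmetric_quotes(value: str) -> str:
    # Same rule as the validator above: strip one pair of matching
    # surrounding quotes, leave everything else untouched.
    if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
        return value[1:-1]
    return value


assert strip_symmetric_quotes('"sk-abc123"') == "sk-abc123"  # illustrative key
assert strip_symmetric_quotes("'gpt-4o'") == "gpt-4o"
assert strip_symmetric_quotes('"mismatched\'') == '"mismatched\''  # left as-is
```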
@@ -10,7 +10,7 @@ from cognee.infrastructure.llm.config import (


async def extract_content_graph(
-    content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None
+    content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None, **kwargs
):
    if custom_prompt:
        system_prompt = custom_prompt
@@ -30,7 +30,7 @@ async def extract_content_graph(
    system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)

    content_graph = await LLMGateway.acreate_structured_output(
-        content, system_prompt, response_model
+        content, system_prompt, response_model, **kwargs
    )

    return content_graph
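The new `**kwargs` plumbing lets callers pass provider-specific options from `extract_content_graph` through `LLMGateway` down to the underlying adapter. A hedged usage sketch (the response model and the extra keyword are illustrative assumptions about what a caller might pass):

```python
import asyncio

from pydantic import BaseModel


class KnowledgeGraph(BaseModel):  # illustrative response model
    nodes: list[str]
    edges: list[str]


async def main():
    # Extra keyword arguments are now forwarded to the LLM client;
    # `temperature` here is an assumption, not a documented parameter.
    graph = await extract_content_graph(
        "Alan Turing worked at Bletchley Park.",
        response_model=KnowledgeGraph,
        temperature=0.0,
    )
    print(graph.nodes)


asyncio.run(main())
```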
@@ -1,7 +1,15 @@
import asyncio
from typing import Type
-from cognee.shared.logging_utils import get_logger
+from pydantic import BaseModel
+from tenacity import (
+    retry,
+    stop_after_delay,
+    wait_exponential_jitter,
+    retry_if_not_exception_type,
+    before_sleep_log,
+)

+from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction.create_dynamic_baml_type import (
    create_dynamic_baml_type,
@@ -10,12 +18,18 @@ from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.type
    TypeBuilder,
)
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client import b
-from pydantic import BaseModel

+from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
+import logging

logger = get_logger()


+@retry(
+    stop=stop_after_delay(128),
+    wait=wait_exponential_jitter(8, 128),
+    before_sleep=before_sleep_log(logger, logging.DEBUG),
+    reraise=True,
+)
async def acreate_structured_output(
    text_input: str, system_prompt: str, response_model: Type[BaseModel]
):
|
@ -45,11 +59,12 @@ async def acreate_structured_output(
|
|||
tb = TypeBuilder()
|
||||
type_builder = create_dynamic_baml_type(tb, tb.ResponseModel, response_model)
|
||||
|
||||
result = await b.AcreateStructuredOutput(
|
||||
text_input=text_input,
|
||||
system_prompt=system_prompt,
|
||||
baml_options={"client_registry": config.baml_registry, "tb": type_builder},
|
||||
)
|
||||
async with llm_rate_limiter_context_manager():
|
||||
result = await b.AcreateStructuredOutput(
|
||||
text_input=text_input,
|
||||
system_prompt=system_prompt,
|
||||
baml_options={"client_registry": config.baml_registry, "tb": type_builder},
|
||||
)
|
||||
|
||||
# Transform BAML response to proper pydantic reponse model
|
||||
if response_model is str:
|
||||
|
|
|
|||
|
|
@@ -15,6 +15,7 @@ from tenacity import (
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
    LLMInterface,
)
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from cognee.infrastructure.llm.config import get_llm_config

logger = get_logger()
@@ -45,13 +46,13 @@ class AnthropicAdapter(LLMInterface):

    @retry(
        stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
        before_sleep=before_sleep_log(logger, logging.DEBUG),
        reraise=True,
    )
    async def acreate_structured_output(
-        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
    ) -> BaseModel:
        """
        Generate a response from a user query.
@@ -69,17 +70,17 @@ class AnthropicAdapter(LLMInterface):

        - BaseModel: An instance of BaseModel containing the structured response.
        """

-        return await self.aclient(
-            model=self.model,
-            max_tokens=4096,
-            max_retries=5,
-            messages=[
-                {
-                    "role": "user",
-                    "content": f"""Use the given format to extract information
-from the following input: {text_input}. {system_prompt}""",
-                }
-            ],
-            response_model=response_model,
-        )
+        async with llm_rate_limiter_context_manager():
+            return await self.aclient(
+                model=self.model,
+                max_tokens=4096,
+                max_retries=2,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": f"""Use the given format to extract information
+from the following input: {text_input}. {system_prompt}""",
+                    }
+                ],
+                response_model=response_model,
+            )
@@ -0,0 +1,5 @@
"""Bedrock LLM adapter module."""

from .adapter import BedrockAdapter

__all__ = ["BedrockAdapter"]
@@ -0,0 +1,153 @@
import litellm
import instructor
from typing import Type
from pydantic import BaseModel
from litellm.exceptions import ContentPolicyViolationError
from instructor.exceptions import InstructorRetryException

from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
    LLMInterface,
)
from cognee.infrastructure.llm.exceptions import (
    ContentPolicyFilterError,
    MissingSystemPromptPathError,
)
from cognee.infrastructure.files.storage.s3_config import get_s3_config
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
    rate_limit_async,
    rate_limit_sync,
    sleep_and_retry_async,
    sleep_and_retry_sync,
)
from cognee.modules.observability.get_observe import get_observe

observe = get_observe()


class BedrockAdapter(LLMInterface):
    """
    Adapter for the AWS Bedrock API with support for three authentication methods:
    1. API Key (Bearer Token)
    2. AWS Credentials (access key + secret key)
    3. AWS Profile (boto3 credential chain)
    """

    name = "Bedrock"
    model: str
    api_key: str
    default_instructor_mode = "json_schema_mode"

    MAX_RETRIES = 5

    def __init__(
        self,
        model: str,
        api_key: str = None,
        max_completion_tokens: int = 16384,
        streaming: bool = False,
        instructor_mode: str = None,
    ):
        self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode

        self.aclient = instructor.from_litellm(
            litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
        )
        self.client = instructor.from_litellm(litellm.completion)
        self.model = model
        self.api_key = api_key
        self.max_completion_tokens = max_completion_tokens
        self.streaming = streaming

    def _create_bedrock_request(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> dict:
        """Create a Bedrock request with authentication."""

        request_params = {
            "model": self.model,
            "custom_llm_provider": "bedrock",
            "drop_params": True,
            "messages": [
                {"role": "user", "content": text_input},
                {"role": "system", "content": system_prompt},
            ],
            "response_model": response_model,
            "max_retries": self.MAX_RETRIES,
            "max_completion_tokens": self.max_completion_tokens,
            "stream": self.streaming,
        }

        s3_config = get_s3_config()

        # Add authentication parameters
        if self.api_key:
            request_params["api_key"] = self.api_key
        elif s3_config.aws_access_key_id and s3_config.aws_secret_access_key:
            request_params["aws_access_key_id"] = s3_config.aws_access_key_id
            request_params["aws_secret_access_key"] = s3_config.aws_secret_access_key
            if s3_config.aws_session_token:
                request_params["aws_session_token"] = s3_config.aws_session_token
        elif s3_config.aws_profile_name:
            request_params["aws_profile_name"] = s3_config.aws_profile_name

        if s3_config.aws_region:
            request_params["aws_region_name"] = s3_config.aws_region

        # Add optional parameters
        if s3_config.aws_bedrock_runtime_endpoint:
            request_params["aws_bedrock_runtime_endpoint"] = s3_config.aws_bedrock_runtime_endpoint

        return request_params

    @observe(as_type="generation")
    @sleep_and_retry_async()
    @rate_limit_async
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
        """Generate structured output from the AWS Bedrock API."""

        try:
            request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
            return await self.aclient.chat.completions.create(**request_params)

        except (
            ContentPolicyViolationError,
            InstructorRetryException,
        ) as error:
            if (
                isinstance(error, InstructorRetryException)
                and "content management policy" not in str(error).lower()
            ):
                raise error

            raise ContentPolicyFilterError(
                f"The provided input contains content that is not aligned with our content policy: {text_input}"
            )

    @observe
    @sleep_and_retry_sync()
    @rate_limit_sync
    def create_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
        """Generate structured output from the AWS Bedrock API (synchronous)."""

        request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
        return self.client.chat.completions.create(**request_params)

    def show_prompt(self, text_input: str, system_prompt: str) -> str:
        """Format and display the prompt for a user query."""
        if not text_input:
            text_input = "No user input provided."
        if not system_prompt:
            raise MissingSystemPromptPathError()
        system_prompt = LLMGateway.read_query_prompt(system_prompt)

        formatted_prompt = (
            f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
            if system_prompt
            else None
        )
        return formatted_prompt
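A hedged sketch of how the new adapter might be exercised directly (the model ID and response model are illustrative assumptions; in practice the adapter would be constructed through cognee's LLM client factory):

```python
import asyncio

from pydantic import BaseModel


class Summary(BaseModel):  # illustrative response model
    text: str


async def main():
    # Model ID is an assumption; any Bedrock model litellm supports would work,
    # and auth falls back to S3Config credentials when no api_key is given.
    adapter = BedrockAdapter(model="anthropic.claude-3-haiku-20240307-v1:0")

    summary = await adapter.acreate_structured_output(
        text_input="Bedrock hosts foundation models behind one AWS API.",
        system_prompt="Summarize the input in one sentence.",
        response_model=Summary,
    )
    print(summary.text)


asyncio.run(main())
```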
@@ -13,6 +13,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
    LLMInterface,
)
import logging
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from cognee.shared.logging_utils import get_logger
from tenacity import (
    retry,
|
@ -73,13 +74,13 @@ class GeminiAdapter(LLMInterface):
|
|||
|
||||
@retry(
|
||||
stop=stop_after_delay(128),
|
||||
wait=wait_exponential_jitter(2, 128),
|
||||
wait=wait_exponential_jitter(8, 128),
|
||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
async def acreate_structured_output(
|
||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
|
||||
) -> BaseModel:
|
||||
"""
|
||||
Generate a response from a user query.
|
||||
|
|
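The two changes in this hunk recur across every adapter in the diff: the first backoff step grows from 2s to 8s, and the signature gains a **kwargs passthrough. tenacity's wait_exponential_jitter(8, 128) waits roughly initial * 2**n seconds on the n-th retry, adds up to 1s of uniform jitter, and clamps at 128s; a standalone sketch of that schedule:

import random

def backoff_seconds(attempt: int, initial: float = 8.0, cap: float = 128.0) -> float:
    # Mirrors tenacity.wait_exponential_jitter(initial, max): exponential
    # growth from `initial`, plus up to 1s of jitter, clamped to `cap`.
    return min(initial * (2**attempt) + random.uniform(0, 1), cap)

print([round(backoff_seconds(n)) for n in range(5)])  # roughly [8, 16, 32, 64, 128]

Combined with stop_after_delay(128), the decorator gives up once 128 seconds of total retrying have elapsed.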
@@ -105,24 +106,25 @@ class GeminiAdapter(LLMInterface):
         """

         try:
-            return await self.aclient.chat.completions.create(
-                model=self.model,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": f"""{text_input}""",
-                    },
-                    {
-                        "role": "system",
-                        "content": system_prompt,
-                    },
-                ],
-                api_key=self.api_key,
-                max_retries=5,
-                api_base=self.endpoint,
-                api_version=self.api_version,
-                response_model=response_model,
-            )
+            async with llm_rate_limiter_context_manager():
+                return await self.aclient.chat.completions.create(
+                    model=self.model,
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": f"""{text_input}""",
+                        },
+                        {
+                            "role": "system",
+                            "content": system_prompt,
+                        },
+                    ],
+                    api_key=self.api_key,
+                    max_retries=2,
+                    api_base=self.endpoint,
+                    api_version=self.api_version,
+                    response_model=response_model,
+                )
         except (
             ContentFilterFinishReasonError,
             ContentPolicyViolationError,
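llm_rate_limiter_context_manager itself is not part of this diff; as a mental model only, an async context manager that gates concurrent LLM calls behind a shared semaphore would look roughly like this (the name reuse and the concurrency cap are assumptions):

import asyncio
from contextlib import asynccontextmanager

_llm_slots = asyncio.Semaphore(8)  # hypothetical cap on in-flight LLM calls

@asynccontextmanager
async def llm_rate_limiter_context_manager():
    # Wait for a free slot, run the wrapped call, release on exit.
    async with _llm_slots:
        yield

Lowering the inner max_retries from 5 to 2 pairs naturally with this: the outer tenacity decorator already retries the whole block, so aggressive inner retries would multiply the total attempts.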
@@ -140,23 +142,24 @@ class GeminiAdapter(LLMInterface):
             )

             try:
-                return await self.aclient.chat.completions.create(
-                    model=self.fallback_model,
-                    messages=[
-                        {
-                            "role": "user",
-                            "content": f"""{text_input}""",
-                        },
-                        {
-                            "role": "system",
-                            "content": system_prompt,
-                        },
-                    ],
-                    max_retries=5,
-                    api_key=self.fallback_api_key,
-                    api_base=self.fallback_endpoint,
-                    response_model=response_model,
-                )
+                async with llm_rate_limiter_context_manager():
+                    return await self.aclient.chat.completions.create(
+                        model=self.fallback_model,
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": f"""{text_input}""",
+                            },
+                            {
+                                "role": "system",
+                                "content": system_prompt,
+                            },
+                        ],
+                        max_retries=2,
+                        api_key=self.fallback_api_key,
+                        api_base=self.fallback_endpoint,
+                        response_model=response_model,
+                    )
             except (
                 ContentFilterFinishReasonError,
                 ContentPolicyViolationError,

@@ -13,6 +13,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
     LLMInterface,
 )
 import logging
+from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
 from cognee.shared.logging_utils import get_logger
 from tenacity import (
     retry,
@@ -73,13 +74,13 @@ class GenericAPIAdapter(LLMInterface):

     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
     async def acreate_structured_output(
-        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
     ) -> BaseModel:
         """
         Generate a response from a user query.
@@ -105,23 +106,24 @@ class GenericAPIAdapter(LLMInterface):
         """

         try:
-            return await self.aclient.chat.completions.create(
-                model=self.model,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": f"""{text_input}""",
-                    },
-                    {
-                        "role": "system",
-                        "content": system_prompt,
-                    },
-                ],
-                max_retries=5,
-                api_key=self.api_key,
-                api_base=self.endpoint,
-                response_model=response_model,
-            )
+            async with llm_rate_limiter_context_manager():
+                return await self.aclient.chat.completions.create(
+                    model=self.model,
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": f"""{text_input}""",
+                        },
+                        {
+                            "role": "system",
+                            "content": system_prompt,
+                        },
+                    ],
+                    max_retries=2,
+                    api_key=self.api_key,
+                    api_base=self.endpoint,
+                    response_model=response_model,
+                )
         except (
             ContentFilterFinishReasonError,
             ContentPolicyViolationError,
@@ -139,23 +141,24 @@ class GenericAPIAdapter(LLMInterface):
             ) from error

             try:
-                return await self.aclient.chat.completions.create(
-                    model=self.fallback_model,
-                    messages=[
-                        {
-                            "role": "user",
-                            "content": f"""{text_input}""",
-                        },
-                        {
-                            "role": "system",
-                            "content": system_prompt,
-                        },
-                    ],
-                    max_retries=5,
-                    api_key=self.fallback_api_key,
-                    api_base=self.fallback_endpoint,
-                    response_model=response_model,
-                )
+                async with llm_rate_limiter_context_manager():
+                    return await self.aclient.chat.completions.create(
+                        model=self.fallback_model,
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": f"""{text_input}""",
+                            },
+                            {
+                                "role": "system",
+                                "content": system_prompt,
+                            },
+                        ],
+                        max_retries=2,
+                        api_key=self.fallback_api_key,
+                        api_base=self.fallback_endpoint,
+                        response_model=response_model,
+                    )
             except (
                 ContentFilterFinishReasonError,
                 ContentPolicyViolationError,

@@ -24,6 +24,7 @@ class LLMProvider(Enum):
     - CUSTOM: Represents a custom provider option.
     - GEMINI: Represents the Gemini provider.
     - MISTRAL: Represents the Mistral AI provider.
+    - BEDROCK: Represents the AWS Bedrock provider.
     """

     OPENAI = "openai"
@@ -32,6 +33,7 @@ class LLMProvider(Enum):
     CUSTOM = "custom"
     GEMINI = "gemini"
     MISTRAL = "mistral"
+    BEDROCK = "bedrock"


 def get_llm_client(raise_api_key_error: bool = True):
@@ -154,7 +156,7 @@ def get_llm_client(raise_api_key_error: bool = True):
         )

     elif provider == LLMProvider.MISTRAL:
-        if llm_config.llm_api_key is None:
+        if llm_config.llm_api_key is None and raise_api_key_error:
             raise LLMAPIKeyNotSetError()

         from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
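The added `and raise_api_key_error` brings the Mistral branch in line with the factory's existing flag, so a caller can obtain an adapter without credentials when it only needs to inspect wiring; a one-line sketch:

# With no Mistral key configured, this no longer raises LLMAPIKeyNotSetError;
# the adapter is returned with api_key=None (hedged reading of the flag).
client = get_llm_client(raise_api_key_error=False)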
@@ -169,5 +171,21 @@ def get_llm_client(raise_api_key_error: bool = True):
             instructor_mode=llm_config.llm_instructor_mode.lower(),
         )

+    elif provider == LLMProvider.BEDROCK:
+        # if llm_config.llm_api_key is None and raise_api_key_error:
+        #     raise LLMAPIKeyNotSetError()
+
+        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.bedrock.adapter import (
+            BedrockAdapter,
+        )
+
+        return BedrockAdapter(
+            model=llm_config.llm_model,
+            api_key=llm_config.llm_api_key,
+            max_completion_tokens=max_completion_tokens,
+            streaming=llm_config.llm_streaming,
+            instructor_mode=llm_config.llm_instructor_mode.lower(),
+        )
+
     else:
         raise UnsupportedLLMProviderError(provider)

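With this branch in place, Bedrock resolves through the same factory as every other provider. A hedged sketch of the call site, assuming the provider is selected via configuration (the exact config attribute or environment variable is not shown in this hunk):

# Assuming llm_config resolves the provider to LLMProvider.BEDROCK:
client = get_llm_client()
print(type(client).__name__)  # BedrockAdapter

Note the commented-out API-key guard: Bedrock typically authenticates through AWS credentials rather than a single API key, which is presumably why the check is disabled here.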
@@ -10,6 +10,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
     LLMInterface,
 )
 from cognee.infrastructure.llm.config import get_llm_config
+from cognee.shared.rate_limiting import llm_rate_limiter_context_manager

 import logging
 from tenacity import (
@@ -62,13 +63,13 @@ class MistralAdapter(LLMInterface):

     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
     async def acreate_structured_output(
-        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
     ) -> BaseModel:
         """
         Generate a response from the user query.
@@ -97,13 +98,14 @@ class MistralAdapter(LLMInterface):
             },
         ]
         try:
-            response = await self.aclient.chat.completions.create(
-                model=self.model,
-                max_tokens=self.max_completion_tokens,
-                max_retries=5,
-                messages=messages,
-                response_model=response_model,
-            )
+            async with llm_rate_limiter_context_manager():
+                response = await self.aclient.chat.completions.create(
+                    model=self.model,
+                    max_tokens=self.max_completion_tokens,
+                    max_retries=2,
+                    messages=messages,
+                    response_model=response_model,
+                )
             if response.choices and response.choices[0].message.content:
                 content = response.choices[0].message.content
                 return response_model.model_validate_json(content)

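Unlike the other adapters, the Mistral path manually validates the returned message content against the response model. In isolation, pydantic v2's model_validate_json step works like this (the Entities model is illustrative):

from pydantic import BaseModel

class Entities(BaseModel):  # illustrative response model
    names: list[str]

parsed = Entities.model_validate_json('{"names": ["cognee", "mistral"]}')
print(parsed.names)  # ['cognee', 'mistral']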
@@ -11,6 +11,8 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
 )
 from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.shared.logging_utils import get_logger
+from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
+
 from tenacity import (
     retry,
     stop_after_delay,
@@ -68,13 +70,13 @@ class OllamaAPIAdapter(LLMInterface):

     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
     async def acreate_structured_output(
-        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
     ) -> BaseModel:
         """
         Generate a structured output from the LLM using the provided text and system prompt.
@@ -95,33 +97,33 @@ class OllamaAPIAdapter(LLMInterface):

         - BaseModel: A structured output that conforms to the specified response model.
         """

-        response = self.aclient.chat.completions.create(
-            model=self.model,
-            messages=[
-                {
-                    "role": "user",
-                    "content": f"{text_input}",
-                },
-                {
-                    "role": "system",
-                    "content": system_prompt,
-                },
-            ],
-            max_retries=5,
-            response_model=response_model,
-        )
+        async with llm_rate_limiter_context_manager():
+            response = self.aclient.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": f"{text_input}",
+                    },
+                    {
+                        "role": "system",
+                        "content": system_prompt,
+                    },
+                ],
+                max_retries=2,
+                response_model=response_model,
+            )

         return response

     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
-    async def create_transcript(self, input_file: str) -> str:
+    async def create_transcript(self, input_file: str, **kwargs) -> str:
         """
         Generate an audio transcript from a user query.

@@ -160,7 +162,7 @@ class OllamaAPIAdapter(LLMInterface):
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
-    async def transcribe_image(self, input_file: str) -> str:
+    async def transcribe_image(self, input_file: str, **kwargs) -> str:
         """
         Transcribe content from an image using base64 encoding.

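transcribe_image's docstring refers to base64 encoding; that step, as a self-contained sketch (function name and file path are placeholders):

import base64

def encode_image_base64(path: str) -> str:
    # Read the raw bytes and return the base64 string an
    # image-capable model endpoint expects.
    with open(path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")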
@@ -22,6 +22,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
 from cognee.infrastructure.llm.exceptions import (
     ContentPolicyFilterError,
 )
+from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
 from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.modules.observability.get_observe import get_observe
 from cognee.shared.logging_utils import get_logger
@@ -105,13 +106,13 @@ class OpenAIAdapter(LLMInterface):
     @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
     async def acreate_structured_output(
-        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
     ) -> BaseModel:
         """
         Generate a response from a user query.
@@ -135,34 +136,9 @@ class OpenAIAdapter(LLMInterface):
         """

         try:
-            return await self.aclient.chat.completions.create(
-                model=self.model,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": f"""{text_input}""",
-                    },
-                    {
-                        "role": "system",
-                        "content": system_prompt,
-                    },
-                ],
-                api_key=self.api_key,
-                api_base=self.endpoint,
-                api_version=self.api_version,
-                response_model=response_model,
-                max_retries=self.MAX_RETRIES,
-            )
-        except (
-            ContentFilterFinishReasonError,
-            ContentPolicyViolationError,
-            InstructorRetryException,
-        ) as e:
-            if not (self.fallback_model and self.fallback_api_key):
-                raise e
-            try:
+            async with llm_rate_limiter_context_manager():
                 return await self.aclient.chat.completions.create(
-                    model=self.fallback_model,
+                    model=self.model,
                     messages=[
                         {
                             "role": "user",
@@ -173,11 +149,40 @@ class OpenAIAdapter(LLMInterface):
                         "content": system_prompt,
                         },
                     ],
-                    api_key=self.fallback_api_key,
-                    # api_base=self.fallback_endpoint,
+                    api_key=self.api_key,
+                    api_base=self.endpoint,
+                    api_version=self.api_version,
                     response_model=response_model,
                     max_retries=self.MAX_RETRIES,
+                    **kwargs,
                 )
+        except (
+            ContentFilterFinishReasonError,
+            ContentPolicyViolationError,
+            InstructorRetryException,
+        ) as e:
+            if not (self.fallback_model and self.fallback_api_key):
+                raise e
+            try:
+                async with llm_rate_limiter_context_manager():
+                    return await self.aclient.chat.completions.create(
+                        model=self.fallback_model,
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": f"""{text_input}""",
+                            },
+                            {
+                                "role": "system",
+                                "content": system_prompt,
+                            },
+                        ],
+                        api_key=self.fallback_api_key,
+                        # api_base=self.fallback_endpoint,
+                        response_model=response_model,
+                        max_retries=self.MAX_RETRIES,
+                        **kwargs,
+                    )
         except (
             ContentFilterFinishReasonError,
             ContentPolicyViolationError,
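Stripped of provider detail, the control flow these OpenAI hunks restructure is primary-then-fallback with a guard on fallback configuration. A runnable skeleton, with a generic exception standing in for the content-policy and instructor-retry exceptions:

import asyncio

async def call_with_fallback(primary, fallback=None):
    try:
        return await primary()
    except Exception as error:  # stands in for the content-policy exceptions
        if fallback is None:
            raise error
        return await fallback()

async def main():
    async def primary():
        raise RuntimeError("content filtered")

    async def fallback():
        return "fallback answer"

    print(await call_with_fallback(primary, fallback))  # fallback answer

asyncio.run(main())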
@@ -202,7 +207,7 @@ class OpenAIAdapter(LLMInterface):
         reraise=True,
     )
     def create_structured_output(
-        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
     ) -> BaseModel:
         """
         Generate a response from a user query.
@@ -242,6 +247,7 @@ class OpenAIAdapter(LLMInterface):
             api_version=self.api_version,
             response_model=response_model,
             max_retries=self.MAX_RETRIES,
+            **kwargs,
         )

     @retry(
@@ -251,7 +257,7 @@ class OpenAIAdapter(LLMInterface):
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
-    async def create_transcript(self, input):
+    async def create_transcript(self, input, **kwargs):
         """
         Generate an audio transcript from a user query.

@@ -278,6 +284,7 @@ class OpenAIAdapter(LLMInterface):
             api_base=self.endpoint,
             api_version=self.api_version,
             max_retries=self.MAX_RETRIES,
+            **kwargs,
         )

         return transcription
@@ -289,7 +296,7 @@ class OpenAIAdapter(LLMInterface):
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
-    async def transcribe_image(self, input) -> BaseModel:
+    async def transcribe_image(self, input, **kwargs) -> BaseModel:
         """
         Generate a transcription of an image from a user query.

@@ -334,4 +341,5 @@ class OpenAIAdapter(LLMInterface):
             api_version=self.api_version,
             max_completion_tokens=300,
             max_retries=self.MAX_RETRIES,
+            **kwargs,
         )

53  cognee/memify_pipelines/create_triplet_embeddings.py  Normal file
@@ -0,0 +1,53 @@
+from typing import Any
+
+from cognee import memify
+from cognee.context_global_variables import (
+    set_database_global_context_variables,
+)
+from cognee.exceptions import CogneeValidationError
+from cognee.modules.data.methods import get_authorized_existing_datasets
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.pipelines.tasks.task import Task
+from cognee.modules.users.models import User
+from cognee.tasks.memify.get_triplet_datapoints import get_triplet_datapoints
+from cognee.tasks.storage import index_data_points
+
+logger = get_logger("create_triplet_embeddings")
+
+
+async def create_triplet_embeddings(
+    user: User,
+    dataset: str = "main_dataset",
+    run_in_background: bool = False,
+    triplets_batch_size: int = 100,
+) -> dict[str, Any]:
+    dataset_to_write = await get_authorized_existing_datasets(
+        user=user, datasets=[dataset], permission_type="write"
+    )
+
+    if not dataset_to_write:
+        raise CogneeValidationError(
+            message=f"User does not have write access to dataset: {dataset}",
+            log=False,
+        )
+
+    await set_database_global_context_variables(
+        dataset_to_write[0].id, dataset_to_write[0].owner_id
+    )
+
+    extraction_tasks = [Task(get_triplet_datapoints, triplets_batch_size=triplets_batch_size)]
+
+    enrichment_tasks = [
+        Task(index_data_points, task_config={"batch_size": triplets_batch_size}),
+    ]
+
+    result = await memify(
+        extraction_tasks=extraction_tasks,
+        enrichment_tasks=enrichment_tasks,
+        dataset=dataset_to_write[0].id,
+        data=[{}],
+        user=user,
+        run_in_background=run_in_background,
+    )
+
+    return result
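A hedged usage sketch of the new pipeline; get_default_user is assumed to be the usual way to obtain a User here, and the import path follows the file location above:

import asyncio

from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings
from cognee.modules.users.methods import get_default_user  # assumed helper

async def main():
    user = await get_default_user()
    result = await create_triplet_embeddings(
        user=user,
        dataset="main_dataset",
        triplets_batch_size=100,
    )
    print(result)

asyncio.run(main())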
@@ -8,12 +8,14 @@ import os
 class CognifyConfig(BaseSettings):
     classification_model: object = DefaultContentPrediction
     summarization_model: object = SummarizedContent
+    triplet_embedding: bool = False
     model_config = SettingsConfigDict(env_file=".env", extra="allow")

     def to_dict(self) -> dict:
         return {
             "classification_model": self.classification_model,
             "summarization_model": self.summarization_model,
+            "triplet_embedding": self.triplet_embedding,
         }
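Because CognifyConfig is a pydantic BaseSettings backed by an .env file, the new flag should be toggleable from the environment without code changes. A sketch, assuming pydantic-settings' default field-to-variable name mapping:

import os

os.environ["TRIPLET_EMBEDDING"] = "true"  # assumed variable name for the field

config = CognifyConfig()  # class defined in the hunk above
print(config.to_dict()["triplet_embedding"])  # True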
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue