Merge branch 'dev' into COG-748
This commit is contained in:
commit
93bca8ee5f
349 changed files with 6418 additions and 4534 deletions
2
.github/workflows/cd.yaml
vendored
2
.github/workflows/cd.yaml
vendored
|
|
@ -17,7 +17,7 @@ jobs:
|
|||
|
||||
publish_docker_to_ecr:
|
||||
name: Publish Cognee Docker image
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
|
|
|||
2
.github/workflows/cd_prd.yaml
vendored
2
.github/workflows/cd_prd.yaml
vendored
|
|
@ -17,7 +17,7 @@ jobs:
|
|||
|
||||
publish_docker_to_ecr:
|
||||
name: Publish Docker PromethAI image
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
|
|
|||
2
.github/workflows/ci.yaml
vendored
2
.github/workflows/ci.yaml
vendored
|
|
@ -9,7 +9,7 @@ jobs:
|
|||
|
||||
build_docker:
|
||||
name: Build Cognee Backend Docker App Image
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out Cognee code
|
||||
uses: actions/checkout@v3
|
||||
|
|
|
|||
2
.github/workflows/community_greetings.yml
vendored
2
.github/workflows/community_greetings.yml
vendored
|
|
@ -4,7 +4,7 @@ on: [pull_request, issues]
|
|||
|
||||
jobs:
|
||||
greeting:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/first-interaction@v1
|
||||
with:
|
||||
|
|
|
|||
2
.github/workflows/docker_compose.yml
vendored
2
.github/workflows/docker_compose.yml
vendored
|
|
@ -12,7 +12,7 @@ on:
|
|||
|
||||
jobs:
|
||||
docker-compose-test:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
|
|
|
|||
4
.github/workflows/dockerhub.yml
vendored
4
.github/workflows/dockerhub.yml
vendored
|
|
@ -7,7 +7,7 @@ on:
|
|||
|
||||
jobs:
|
||||
docker-build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
|
|
@ -32,7 +32,7 @@ jobs:
|
|||
run: |
|
||||
IMAGE_NAME=cognee/cognee
|
||||
TAG_VERSION="${BRANCH_NAME}-${COMMIT_SHA}"
|
||||
|
||||
|
||||
echo "Building image: ${IMAGE_NAME}:${TAG_VERSION}"
|
||||
docker buildx build \
|
||||
--platform linux/amd64,linux/arm64 \
|
||||
|
|
|
|||
92
.github/workflows/profiling.yaml
vendored
92
.github/workflows/profiling.yaml
vendored
|
|
@ -7,7 +7,7 @@ on:
|
|||
|
||||
jobs:
|
||||
profiler:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
steps:
|
||||
# Checkout the code from the repository with full history
|
||||
|
|
@ -47,7 +47,7 @@ jobs:
|
|||
python-version: '3.10'
|
||||
|
||||
- name: Install Poetry
|
||||
uses: snok/install-poetry@v1.3.2
|
||||
uses: snok/install-poetry@v1.4.1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
|
|
@ -82,50 +82,50 @@ jobs:
|
|||
poetry run pyinstrument --renderer json -o base_results.json cognee/api/v1/cognify/code_graph_pipeline.py
|
||||
|
||||
# Run profiler on head branch
|
||||
- name: Run profiler on head branch
|
||||
env:
|
||||
HEAD_SHA: ${{ env.HEAD_SHA }}
|
||||
run: |
|
||||
echo "Profiling the head branch for code_graph_pipeline.py"
|
||||
echo "Checking out head SHA: $HEAD_SHA"
|
||||
git checkout $HEAD_SHA
|
||||
echo "This is the working directory: $PWD"
|
||||
# Ensure the script is executable
|
||||
chmod +x cognee/api/v1/cognify/code_graph_pipeline.py
|
||||
# Run Scalene
|
||||
poetry run pyinstrument --renderer json -o head_results.json cognee/api/v1/cognify/code_graph_pipeline.py
|
||||
|
||||
# Compare profiling results
|
||||
- name: Compare profiling results
|
||||
run: |
|
||||
python -c '
|
||||
import json
|
||||
try:
|
||||
with open("base_results.json") as f:
|
||||
base = json.load(f)
|
||||
with open("head_results.json") as f:
|
||||
head = json.load(f)
|
||||
cpu_diff = head.get("total_cpu_samples_python", 0) - base.get("total_cpu_samples_python", 0)
|
||||
memory_diff = head.get("malloc_samples", 0) - base.get("malloc_samples", 0)
|
||||
results = [
|
||||
f"CPU Usage Difference: {cpu_diff}",
|
||||
f"Memory Usage Difference: {memory_diff} bytes"
|
||||
]
|
||||
with open("profiling_diff.txt", "w") as f:
|
||||
f.write("\\n".join(results) + "\\n")
|
||||
print("\\n".join(results)) # Print results to terminal
|
||||
except Exception as e:
|
||||
error_message = f"Error comparing profiling results: {e}"
|
||||
with open("profiling_diff.txt", "w") as f:
|
||||
f.write(error_message + "\\n")
|
||||
print(error_message) # Print error to terminal
|
||||
'
|
||||
|
||||
- name: Upload profiling diff artifact
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: profiling-diff
|
||||
path: profiling_diff.txt
|
||||
# - name: Run profiler on head branch
|
||||
# env:
|
||||
# HEAD_SHA: ${{ env.HEAD_SHA }}
|
||||
# run: |
|
||||
# echo "Profiling the head branch for code_graph_pipeline.py"
|
||||
# echo "Checking out head SHA: $HEAD_SHA"
|
||||
# git checkout $HEAD_SHA
|
||||
# echo "This is the working directory: $PWD"
|
||||
# # Ensure the script is executable
|
||||
# chmod +x cognee/api/v1/cognify/code_graph_pipeline.py
|
||||
# # Run Scalene
|
||||
# poetry run pyinstrument --renderer json -o head_results.json cognee/api/v1/cognify/code_graph_pipeline.py
|
||||
#
|
||||
# # Compare profiling results
|
||||
# - name: Compare profiling results
|
||||
# run: |
|
||||
# python -c '
|
||||
# import json
|
||||
# try:
|
||||
# with open("base_results.json") as f:
|
||||
# base = json.load(f)
|
||||
# with open("head_results.json") as f:
|
||||
# head = json.load(f)
|
||||
# cpu_diff = head.get("total_cpu_samples_python", 0) - base.get("total_cpu_samples_python", 0)
|
||||
# memory_diff = head.get("malloc_samples", 0) - base.get("malloc_samples", 0)
|
||||
# results = [
|
||||
# f"CPU Usage Difference: {cpu_diff}",
|
||||
# f"Memory Usage Difference: {memory_diff} bytes"
|
||||
# ]
|
||||
# with open("profiling_diff.txt", "w") as f:
|
||||
# f.write("\\n".join(results) + "\\n")
|
||||
# print("\\n".join(results)) # Print results to terminal
|
||||
# except Exception as e:
|
||||
# error_message = f"Error comparing profiling results: {e}"
|
||||
# with open("profiling_diff.txt", "w") as f:
|
||||
# f.write(error_message + "\\n")
|
||||
# print(error_message) # Print error to terminal
|
||||
# '
|
||||
#
|
||||
# - name: Upload profiling diff artifact
|
||||
# uses: actions/upload-artifact@v3
|
||||
# with:
|
||||
# name: profiling-diff
|
||||
# path: profiling_diff.txt
|
||||
|
||||
# Post results to the pull request
|
||||
# - name: Post profiling results to PR
|
||||
|
|
|
|||
4
.github/workflows/py_lint.yml
vendored
4
.github/workflows/py_lint.yml
vendored
|
|
@ -16,8 +16,8 @@ jobs:
|
|||
fail-fast: true
|
||||
matrix:
|
||||
os:
|
||||
- ubuntu-latest
|
||||
python-version: ["3.8.x", "3.9.x", "3.10.x", "3.11.x"]
|
||||
- ubuntu-22.04
|
||||
python-version: ["3.10.x", "3.11.x"]
|
||||
|
||||
defaults:
|
||||
run:
|
||||
|
|
|
|||
2
.github/workflows/release_discord_action.yml
vendored
2
.github/workflows/release_discord_action.yml
vendored
|
|
@ -6,7 +6,7 @@ on:
|
|||
|
||||
jobs:
|
||||
github-releases-to-discord:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
|
|
|||
6
.github/workflows/reusable_notebook.yml
vendored
6
.github/workflows/reusable_notebook.yml
vendored
|
|
@ -22,7 +22,7 @@ jobs:
|
|||
|
||||
run_notebook_test:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
|
@ -36,7 +36,7 @@ jobs:
|
|||
python-version: '3.11.x'
|
||||
|
||||
- name: Install Poetry
|
||||
uses: snok/install-poetry@v1.3.2
|
||||
uses: snok/install-poetry@v1.4.1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
|
|
@ -58,4 +58,4 @@ jobs:
|
|||
--to notebook \
|
||||
--execute ${{ inputs.notebook-location }} \
|
||||
--output executed_notebook.ipynb \
|
||||
--ExecutePreprocessor.timeout=1200
|
||||
--ExecutePreprocessor.timeout=1200
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ jobs:
|
|||
|
||||
run_notebook_test:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
|
@ -33,10 +33,10 @@ jobs:
|
|||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
python-version: '3.12.x'
|
||||
|
||||
- name: Install Poetry
|
||||
uses: snok/install-poetry@v1.3.2
|
||||
uses: snok/install-poetry@v1.4.1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
|
|
@ -49,7 +49,8 @@ jobs:
|
|||
- name: Execute Python Example
|
||||
env:
|
||||
ENV: 'dev'
|
||||
PYTHONFAULTHANDLER: 1
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||
run: poetry run python ${{ inputs.example-location }}
|
||||
run: poetry run python ${{ inputs.example-location }}
|
||||
|
|
|
|||
2
.github/workflows/ruff_format.yaml
vendored
2
.github/workflows/ruff_format.yaml
vendored
|
|
@ -3,7 +3,7 @@ on: [ pull_request ]
|
|||
|
||||
jobs:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: astral-sh/ruff-action@v2
|
||||
|
|
|
|||
2
.github/workflows/ruff_lint.yaml
vendored
2
.github/workflows/ruff_lint.yaml
vendored
|
|
@ -3,7 +3,7 @@ on: [ pull_request ]
|
|||
|
||||
jobs:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: astral-sh/ruff-action@v2
|
||||
|
|
|
|||
5
.github/workflows/test_deduplication.yml
vendored
5
.github/workflows/test_deduplication.yml
vendored
|
|
@ -16,8 +16,7 @@ env:
|
|||
jobs:
|
||||
run_deduplication_test:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
runs-on: ubuntu-22.04
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
|
@ -46,7 +45,7 @@ jobs:
|
|||
python-version: '3.11.x'
|
||||
|
||||
- name: Install Poetry
|
||||
uses: snok/install-poetry@v1.3.2
|
||||
uses: snok/install-poetry@v1.4.1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
|
|
|
|||
5
.github/workflows/test_milvus.yml
vendored
5
.github/workflows/test_milvus.yml
vendored
|
|
@ -17,8 +17,7 @@ jobs:
|
|||
|
||||
run_milvus:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
runs-on: ubuntu-22.04
|
||||
strategy:
|
||||
fail-fast: false
|
||||
defaults:
|
||||
|
|
@ -36,7 +35,7 @@ jobs:
|
|||
|
||||
- name: Install Poetry
|
||||
# https://github.com/snok/install-poetry#running-on-windows
|
||||
uses: snok/install-poetry@v1.3.2
|
||||
uses: snok/install-poetry@v1.4.1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
|
|
|
|||
5
.github/workflows/test_neo4j.yml
vendored
5
.github/workflows/test_neo4j.yml
vendored
|
|
@ -15,8 +15,7 @@ env:
|
|||
jobs:
|
||||
run_neo4j_integration_test:
|
||||
name: test
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
defaults:
|
||||
run:
|
||||
|
|
@ -32,7 +31,7 @@ jobs:
|
|||
python-version: '3.11.x'
|
||||
|
||||
- name: Install Poetry
|
||||
uses: snok/install-poetry@v1.3.2
|
||||
uses: snok/install-poetry@v1.4.1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
|
|
|
|||
5
.github/workflows/test_pgvector.yml
vendored
5
.github/workflows/test_pgvector.yml
vendored
|
|
@ -17,8 +17,7 @@ jobs:
|
|||
|
||||
run_pgvector_integration_test:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
runs-on: ubuntu-22.04
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
|
@ -47,7 +46,7 @@ jobs:
|
|||
python-version: '3.11.x'
|
||||
|
||||
- name: Install Poetry
|
||||
uses: snok/install-poetry@v1.3.2
|
||||
uses: snok/install-poetry@v1.4.1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
|
|
|
|||
6
.github/workflows/test_python_3_10.yml
vendored
6
.github/workflows/test_python_3_10.yml
vendored
|
|
@ -14,11 +14,9 @@ env:
|
|||
ENV: 'dev'
|
||||
|
||||
jobs:
|
||||
|
||||
run_common:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
runs-on: ubuntu-22.04
|
||||
strategy:
|
||||
fail-fast: false
|
||||
defaults:
|
||||
|
|
@ -36,7 +34,7 @@ jobs:
|
|||
|
||||
- name: Install Poetry
|
||||
# https://github.com/snok/install-poetry#running-on-windows
|
||||
uses: snok/install-poetry@v1.3.2
|
||||
uses: snok/install-poetry@v1.4.1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
|
|
|
|||
5
.github/workflows/test_python_3_11.yml
vendored
5
.github/workflows/test_python_3_11.yml
vendored
|
|
@ -17,8 +17,7 @@ jobs:
|
|||
|
||||
run_common:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
runs-on: ubuntu-22.04
|
||||
strategy:
|
||||
fail-fast: false
|
||||
defaults:
|
||||
|
|
@ -36,7 +35,7 @@ jobs:
|
|||
|
||||
- name: Install Poetry
|
||||
# https://github.com/snok/install-poetry#running-on-windows
|
||||
uses: snok/install-poetry@v1.3.2
|
||||
uses: snok/install-poetry@v1.4.1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
name: test | python 3.9
|
||||
name: test | python 3.12
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
|
@ -17,8 +17,7 @@ jobs:
|
|||
|
||||
run_common:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
runs-on: ubuntu-22.04
|
||||
strategy:
|
||||
fail-fast: false
|
||||
defaults:
|
||||
|
|
@ -32,11 +31,11 @@ jobs:
|
|||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.9.x'
|
||||
python-version: '3.12.x'
|
||||
|
||||
- name: Install Poetry
|
||||
# https://github.com/snok/install-poetry#running-on-windows
|
||||
uses: snok/install-poetry@v1.3.2
|
||||
uses: snok/install-poetry@v1.4.1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
6
.github/workflows/test_qdrant.yml
vendored
6
.github/workflows/test_qdrant.yml
vendored
|
|
@ -17,9 +17,7 @@ jobs:
|
|||
|
||||
run_qdrant_integration_test:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
|
||||
runs-on: ubuntu-22.04
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
|
@ -34,7 +32,7 @@ jobs:
|
|||
python-version: '3.11.x'
|
||||
|
||||
- name: Install Poetry
|
||||
uses: snok/install-poetry@v1.3.2
|
||||
uses: snok/install-poetry@v1.4.1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
|
|
|
|||
6
.github/workflows/test_weaviate.yml
vendored
6
.github/workflows/test_weaviate.yml
vendored
|
|
@ -17,9 +17,7 @@ jobs:
|
|||
|
||||
run_weaviate_integration_test:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
|
||||
runs-on: ubuntu-22.04
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
|
@ -34,7 +32,7 @@ jobs:
|
|||
python-version: '3.11.x'
|
||||
|
||||
- name: Install Poetry
|
||||
uses: snok/install-poetry@v1.3.2
|
||||
uses: snok/install-poetry@v1.4.1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ RUN pip install poetry
|
|||
RUN poetry config virtualenvs.create false
|
||||
|
||||
# Install the dependencies
|
||||
RUN poetry install --all-extras --no-root --no-dev
|
||||
RUN poetry install --all-extras --no-root --without dev
|
||||
|
||||
|
||||
# Set the PYTHONPATH environment variable to include the /app directory
|
||||
|
|
|
|||
|
|
@ -4,8 +4,9 @@ from logging.config import fileConfig
|
|||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
from alembic import context
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine, get_relational_config
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
|
|
@ -20,7 +21,7 @@ if config.config_file_name is not None:
|
|||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
|
|
@ -83,12 +84,11 @@ def run_migrations_online() -> None:
|
|||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine, get_relational_config
|
||||
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
if db_engine.engine.dialect.name == "sqlite":
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
|
||||
db_config = get_relational_config()
|
||||
LocalStorage.ensure_directory_exists(db_config.db_path)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Revises: 8057ae7329c2
|
|||
Create Date: 2024-10-16 22:17:18.634638
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from sqlalchemy.util import await_only
|
||||
|
|
@ -13,8 +14,8 @@ from cognee.modules.users.methods import create_default_user, delete_user
|
|||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '482cd6517ce4'
|
||||
down_revision: Union[str, None] = '8057ae7329c2'
|
||||
revision: str = "482cd6517ce4"
|
||||
down_revision: Union[str, None] = "8057ae7329c2"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = "8057ae7329c2"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
"""Initial migration
|
||||
|
||||
Revision ID: 8057ae7329c2
|
||||
Revises:
|
||||
Revises:
|
||||
Create Date: 2024-10-02 12:55:20.989372
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
from sqlalchemy.util import await_only
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
|
|
|||
Binary file not shown.
|
Before Width: | Height: | Size: 1.3 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 3.4 KiB |
BIN
assets/cognee_logo.png
Normal file
BIN
assets/cognee_logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 34 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 27 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 86 KiB |
1
cognee-mcp/.python-version
Normal file
1
cognee-mcp/.python-version
Normal file
|
|
@ -0,0 +1 @@
|
|||
3.11.5
|
||||
|
|
@ -7,6 +7,7 @@ def main():
|
|||
"""Main entry point for the package."""
|
||||
asyncio.run(server.main())
|
||||
|
||||
|
||||
# Optionally expose other important items at package level
|
||||
__all__ = ["main", "server"]
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,9 @@ def node_to_string(node):
|
|||
# keys_to_keep = ["chunk_index", "topological_rank", "cut_type", "id", "text"]
|
||||
# keyset = set(keys_to_keep) & node.keys()
|
||||
# return "Node(" + " ".join([key + ": " + str(node[key]) + "," for key in keyset]) + ")"
|
||||
node_data = ", ".join([f"{key}: \"{value}\"" for key, value in node.items() if key in ["id", "name"]])
|
||||
node_data = ", ".join(
|
||||
[f'{key}: "{value}"' for key, value in node.items() if key in ["id", "name"]]
|
||||
)
|
||||
|
||||
return f"Node({node_data})"
|
||||
|
||||
|
|
@ -52,9 +54,9 @@ async def handle_list_tools() -> list[types.Tool]:
|
|||
"""
|
||||
return [
|
||||
types.Tool(
|
||||
name = "cognify",
|
||||
description = "Build knowledge graph from the input text.",
|
||||
inputSchema = {
|
||||
name="cognify",
|
||||
description="Build knowledge graph from the input text.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string"},
|
||||
|
|
@ -65,9 +67,9 @@ async def handle_list_tools() -> list[types.Tool]:
|
|||
},
|
||||
),
|
||||
types.Tool(
|
||||
name = "search",
|
||||
description = "Search the knowledge graph.",
|
||||
inputSchema = {
|
||||
name="search",
|
||||
description="Search the knowledge graph.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
|
|
@ -76,9 +78,9 @@ async def handle_list_tools() -> list[types.Tool]:
|
|||
},
|
||||
),
|
||||
types.Tool(
|
||||
name = "prune",
|
||||
description = "Reset the knowledge graph.",
|
||||
inputSchema = {
|
||||
name="prune",
|
||||
description="Reset the knowledge graph.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
|
|
@ -90,8 +92,7 @@ async def handle_list_tools() -> list[types.Tool]:
|
|||
|
||||
@server.call_tool()
|
||||
async def handle_call_tool(
|
||||
name: str,
|
||||
arguments: dict | None
|
||||
name: str, arguments: dict | None
|
||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
"""
|
||||
Handle tool execution requests.
|
||||
|
|
@ -115,12 +116,12 @@ async def handle_call_tool(
|
|||
|
||||
await cognee.add(text)
|
||||
|
||||
await cognee.cognify(graph_model = graph_model)
|
||||
await cognee.cognify(graph_model=graph_model)
|
||||
|
||||
return [
|
||||
types.TextContent(
|
||||
type = "text",
|
||||
text = "Ingested",
|
||||
type="text",
|
||||
text="Ingested",
|
||||
)
|
||||
]
|
||||
elif name == "search":
|
||||
|
|
@ -131,16 +132,14 @@ async def handle_call_tool(
|
|||
|
||||
search_query = arguments.get("query")
|
||||
|
||||
search_results = await cognee.search(
|
||||
SearchType.INSIGHTS, query_text = search_query
|
||||
)
|
||||
search_results = await cognee.search(SearchType.INSIGHTS, query_text=search_query)
|
||||
|
||||
results = retrieved_edges_to_string(search_results)
|
||||
|
||||
return [
|
||||
types.TextContent(
|
||||
type = "text",
|
||||
text = results,
|
||||
type="text",
|
||||
text=results,
|
||||
)
|
||||
]
|
||||
elif name == "prune":
|
||||
|
|
@ -151,8 +150,8 @@ async def handle_call_tool(
|
|||
|
||||
return [
|
||||
types.TextContent(
|
||||
type = "text",
|
||||
text = "Pruned",
|
||||
type="text",
|
||||
text="Pruned",
|
||||
)
|
||||
]
|
||||
else:
|
||||
|
|
@ -166,15 +165,16 @@ async def main():
|
|||
read_stream,
|
||||
write_stream,
|
||||
InitializationOptions(
|
||||
server_name = "cognee-mcp",
|
||||
server_version = "0.1.0",
|
||||
capabilities = server.get_capabilities(
|
||||
notification_options = NotificationOptions(),
|
||||
experimental_capabilities = {},
|
||||
server_name="cognee-mcp",
|
||||
server_version="0.1.0",
|
||||
capabilities=server.get_capabilities(
|
||||
notification_options=NotificationOptions(),
|
||||
experimental_capabilities={},
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# This is needed if you'd like to connect to a custom client
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
188
cognee-mcp/uv.lock
generated
188
cognee-mcp/uv.lock
generated
|
|
@ -321,6 +321,26 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/b1/fe/e8c672695b37eecc5cbf43e1d0638d88d66ba3a44c4d321c796f4e59167f/beautifulsoup4-4.12.3-py3-none-any.whl", hash = "sha256:b80878c9f40111313e55da8ba20bdba06d8fa3969fc68304167741bbf9e082ed", size = 147925 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bokeh"
|
||||
version = "3.6.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "contourpy" },
|
||||
{ name = "jinja2" },
|
||||
{ name = "numpy" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pandas" },
|
||||
{ name = "pillow" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "tornado" },
|
||||
{ name = "xyzservices" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/da/9d/cc9c561e1db8cbecc5cfad972159020700fff2339bdaa316498ace1cb04c/bokeh-3.6.2.tar.gz", hash = "sha256:2f3043d9ecb3d5dc2e8c0ebf8ad55727617188d4e534f3e7208b36357e352396", size = 6247610 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/56/12/2c266a0dc57379c60b4e73a2f93e71343db4170bf26c5a76a74e7d8bce2a/bokeh-3.6.2-py3-none-any.whl", hash = "sha256:fddc4b91f8b40178c0e3e83dfcc33886d7803a3a1f041a840834255e435a18c2", size = 6866799 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "boto3"
|
||||
version = "1.35.84"
|
||||
|
|
@ -358,6 +378,34 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/a4/07/14f8ad37f2d12a5ce41206c21820d8cb6561b728e51fad4530dff0552a67/cachetools-5.5.0-py3-none-any.whl", hash = "sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292", size = 9524 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cairocffi"
|
||||
version = "1.7.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cffi" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/70/c5/1a4dc131459e68a173cbdab5fad6b524f53f9c1ef7861b7698e998b837cc/cairocffi-1.7.1.tar.gz", hash = "sha256:2e48ee864884ec4a3a34bfa8c9ab9999f688286eb714a15a43ec9d068c36557b", size = 88096 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/93/d8/ba13451aa6b745c49536e87b6bf8f629b950e84bd0e8308f7dc6883b67e2/cairocffi-1.7.1-py3-none-any.whl", hash = "sha256:9803a0e11f6c962f3b0ae2ec8ba6ae45e957a146a004697a1ac1bbf16b073b3f", size = 75611 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cairosvg"
|
||||
version = "2.7.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cairocffi" },
|
||||
{ name = "cssselect2" },
|
||||
{ name = "defusedxml" },
|
||||
{ name = "pillow" },
|
||||
{ name = "tinycss2" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d5/e6/ec5900b724e3c44af7f6f51f719919137284e5da4aabe96508baec8a1b40/CairoSVG-2.7.1.tar.gz", hash = "sha256:432531d72347291b9a9ebfb6777026b607563fd8719c46ee742db0aef7271ba0", size = 8399085 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/01/a5/1866b42151f50453f1a0d28fc4c39f5be5f412a2e914f33449c42daafdf1/CairoSVG-2.7.1-py3-none-any.whl", hash = "sha256:8a5222d4e6c3f86f1f7046b63246877a63b49923a1cd202184c3a634ef546b3b", size = 43235 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2024.12.14"
|
||||
|
|
@ -412,6 +460,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/7c/fc/6a8cb64e5f0324877d503c854da15d76c1e50eb722e320b15345c4d0c6de/cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a", size = 182009 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cfgv"
|
||||
version = "3.4.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/11/74/539e56497d9bd1d484fd863dd69cbbfa653cd2aa27abfe35653494d85e94/cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560", size = 7114 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "chardet"
|
||||
version = "5.2.0"
|
||||
|
|
@ -480,7 +537,7 @@ name = "click"
|
|||
version = "8.1.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "platform_system == 'Windows'" },
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 }
|
||||
wheels = [
|
||||
|
|
@ -497,7 +554,9 @@ dependencies = [
|
|||
{ name = "aiosqlite" },
|
||||
{ name = "alembic" },
|
||||
{ name = "anthropic" },
|
||||
{ name = "bokeh" },
|
||||
{ name = "boto3" },
|
||||
{ name = "cairosvg" },
|
||||
{ name = "datasets" },
|
||||
{ name = "dlt", extra = ["sqlalchemy"] },
|
||||
{ name = "fastapi" },
|
||||
|
|
@ -518,6 +577,7 @@ dependencies = [
|
|||
{ name = "numpy" },
|
||||
{ name = "openai" },
|
||||
{ name = "pandas" },
|
||||
{ name = "pre-commit" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "pypdf" },
|
||||
|
|
@ -541,8 +601,10 @@ requires-dist = [
|
|||
{ name = "alembic", specifier = ">=1.13.3,<2.0.0" },
|
||||
{ name = "anthropic", specifier = ">=0.26.1,<0.27.0" },
|
||||
{ name = "asyncpg", marker = "extra == 'postgres'", specifier = "==0.30.0" },
|
||||
{ name = "bokeh", specifier = ">=3.6.2,<4.0.0" },
|
||||
{ name = "boto3", specifier = ">=1.26.125,<2.0.0" },
|
||||
{ name = "botocore", marker = "extra == 'filesystem'", specifier = ">=1.35.54,<2.0.0" },
|
||||
{ name = "cairosvg", specifier = ">=2.7.1,<3.0.0" },
|
||||
{ name = "datasets", specifier = "==3.1.0" },
|
||||
{ name = "deepeval", marker = "extra == 'deepeval'", specifier = ">=2.0.1,<3.0.0" },
|
||||
{ name = "dlt", extras = ["sqlalchemy"], specifier = ">=1.4.1,<2.0.0" },
|
||||
|
|
@ -573,6 +635,7 @@ requires-dist = [
|
|||
{ name = "pandas", specifier = "==2.0.3" },
|
||||
{ name = "pgvector", marker = "extra == 'postgres'", specifier = ">=0.3.5,<0.4.0" },
|
||||
{ name = "posthog", marker = "extra == 'posthog'", specifier = ">=3.5.0,<4.0.0" },
|
||||
{ name = "pre-commit", specifier = ">=4.0.1,<5.0.0" },
|
||||
{ name = "psycopg2", marker = "extra == 'postgres'", specifier = ">=2.9.10,<3.0.0" },
|
||||
{ name = "pydantic", specifier = "==2.8.2" },
|
||||
{ name = "pydantic-settings", specifier = ">=2.2.1,<3.0.0" },
|
||||
|
|
@ -885,6 +948,19 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/97/9b/443270b9210f13f6ef240eff73fd32e02d381e7103969dc66ce8e89ee901/cryptography-44.0.0-cp39-abi3-win_amd64.whl", hash = "sha256:708ee5f1bafe76d041b53a4f95eb28cdeb8d18da17e597d46d7833ee59b97ede", size = 3202071 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cssselect2"
|
||||
version = "0.7.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "tinycss2" },
|
||||
{ name = "webencodings" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e7/fc/326cb6f988905998f09bb54a3f5d98d4462ba119363c0dfad29750d48c09/cssselect2-0.7.0.tar.gz", hash = "sha256:1ccd984dab89fc68955043aca4e1b03e0cf29cad9880f6e28e3ba7a74b14aa5a", size = 35888 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/3a/e39436efe51894243ff145a37c4f9a030839b97779ebcc4f13b3ba21c54e/cssselect2-0.7.0-py3-none-any.whl", hash = "sha256:fd23a65bfd444595913f02fc71f6b286c29261e354c41d722ca7a261a49b5969", size = 15586 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cycler"
|
||||
version = "0.12.1"
|
||||
|
|
@ -996,6 +1072,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/39/60/533ce66e28295e2b94267126a851ac091ad29a835a9827d1f9c30574fce4/deepeval-2.0.6-py3-none-any.whl", hash = "sha256:57302830ff9d3d16ad4f1961338c7b4453e48039ff131990f258880728f33b6b", size = 504101 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "defusedxml"
|
||||
version = "0.7.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0f/d5/c66da9b79e5bdb124974bfe172b4daf3c984ebd9c2a06e2b8a4dc7331c72/defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69", size = 75520 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deprecated"
|
||||
version = "1.2.15"
|
||||
|
|
@ -1056,6 +1141,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/68/69/1bcf70f81de1b4a9f21b3a62ec0c83bdff991c88d6cc2267d02408457e88/dirtyjson-1.0.8-py3-none-any.whl", hash = "sha256:125e27248435a58acace26d5c2c4c11a1c0de0a9c5124c5a94ba78e517d74f53", size = 25197 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "distlib"
|
||||
version = "0.3.9"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0d/dd/1bec4c5ddb504ca60fc29472f3d27e8d4da1257a854e1d96742f15c1d02d/distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403", size = 613923 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "distro"
|
||||
version = "1.9.0"
|
||||
|
|
@ -1686,6 +1780,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/d7/de/85a784bcc4a3779d1753a7ec2dee5de90e18c7bcf402e71b51fcf150b129/hyperframe-6.0.1-py3-none-any.whl", hash = "sha256:0ec6bafd80d8ad2195c4f03aacba3a8265e57bc4cff261e802bf39970ed02a15", size = 12389 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "identify"
|
||||
version = "2.6.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/1a/5f/05f0d167be94585d502b4adf8c7af31f1dc0b1c7e14f9938a88fdbbcf4a7/identify-2.6.3.tar.gz", hash = "sha256:62f5dae9b5fef52c84cc188514e9ea4f3f636b1d8799ab5ebc475471f9e47a02", size = 99179 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c9/f5/09644a3ad803fae9eca8efa17e1f2aef380c7f0b02f7ec4e8d446e51d64a/identify-2.6.3-py2.py3-none-any.whl", hash = "sha256:9edba65473324c2ea9684b1f944fe3191db3345e50b6d04571d10ed164f8d7bd", size = 99049 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.10"
|
||||
|
|
@ -2557,6 +2660,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/4d/66/7d9e26593edda06e8cb531874633f7c2372279c3b0f46235539fe546df8b/nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1", size = 1505442 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nodeenv"
|
||||
version = "1.9.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "numpy"
|
||||
version = "1.26.4"
|
||||
|
|
@ -2946,7 +3058,7 @@ name = "portalocker"
|
|||
version = "2.10.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pywin32", marker = "platform_system == 'Windows'" },
|
||||
{ name = "pywin32", marker = "sys_platform == 'win32'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 }
|
||||
wheels = [
|
||||
|
|
@ -2969,6 +3081,22 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/d3/f2/5ee24cd69e2120bf87356c02ace0438b4e4fb78229fddcbf6f1c6be377d5/posthog-3.7.4-py2.py3-none-any.whl", hash = "sha256:21c18c6bf43b2de303ea4cd6e95804cc0f24c20cb2a96a8fd09da2ed50b62faa", size = 54777 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pre-commit"
|
||||
version = "4.0.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cfgv" },
|
||||
{ name = "identify" },
|
||||
{ name = "nodeenv" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "virtualenv" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/2e/c8/e22c292035f1bac8b9f5237a2622305bc0304e776080b246f3df57c4ff9f/pre_commit-4.0.1.tar.gz", hash = "sha256:80905ac375958c0444c65e9cebebd948b3cdb518f335a091a670a89d652139d2", size = 191678 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/16/8f/496e10d51edd6671ebe0432e33ff800aa86775d2d147ce7d43389324a525/pre_commit-4.0.1-py2.py3-none-any.whl", hash = "sha256:efde913840816312445dc98787724647c65473daefe420785f885e8ed9a06878", size = 218713 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "propcache"
|
||||
version = "0.2.1"
|
||||
|
|
@ -3065,6 +3193,7 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/33/39/5a9a229bb5414abeb86e33b8fc8143ab0aecce5a7f698a53e31367d30caa/psycopg2-2.9.10-cp311-cp311-win_amd64.whl", hash = "sha256:0435034157049f6846e95103bd8f5a668788dd913a7c30162ca9503fdf542cb4", size = 1163736 },
|
||||
{ url = "https://files.pythonhosted.org/packages/3d/16/4623fad6076448df21c1a870c93a9774ad8a7b4dd1660223b59082dd8fec/psycopg2-2.9.10-cp312-cp312-win32.whl", hash = "sha256:65a63d7ab0e067e2cdb3cf266de39663203d38d6a8ed97f5ca0cb315c73fe067", size = 1025113 },
|
||||
{ url = "https://files.pythonhosted.org/packages/66/de/baed128ae0fc07460d9399d82e631ea31a1f171c0c4ae18f9808ac6759e3/psycopg2-2.9.10-cp312-cp312-win_amd64.whl", hash = "sha256:4a579d6243da40a7b3182e0430493dbd55950c493d8c68f4eec0b302f6bbf20e", size = 1163951 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ae/49/a6cfc94a9c483b1fa401fbcb23aca7892f60c7269c5ffa2ac408364f80dc/psycopg2-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:91fd603a2155da8d0cfcdbf8ab24a2d54bca72795b90d2a3ed2b6da8d979dee2", size = 2569060 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -4223,6 +4352,18 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/35/75/c4d8b2f0fe7dac22854d88a9c509d428e78ac4bf284bc54cfe83f75cc13b/time_machine-2.16.0-cp313-cp313-win_arm64.whl", hash = "sha256:4d3843143c46dddca6491a954bbd0abfd435681512ac343169560e9bab504129", size = 18047 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinycss2"
|
||||
version = "1.4.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "webencodings" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7a/fd/7a5ee21fd08ff70d3d33a5781c255cbe779659bd03278feb98b19ee550f4/tinycss2-1.4.0.tar.gz", hash = "sha256:10c0972f6fc0fbee87c3edb76549357415e94548c1ae10ebccdea16fb404a9b7", size = 87085 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e6/34/ebdc18bae6aa14fbee1a08b63c015c72b64868ff7dae68808ab500c492e2/tinycss2-1.4.0-py3-none-any.whl", hash = "sha256:3a49cf47b7675da0b15d0c6e1df8df4ebd96e9394bb905a5775adb0d884c5289", size = 26610 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokenizers"
|
||||
version = "0.21.0"
|
||||
|
|
@ -4257,12 +4398,30 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/f9/b6/a447b5e4ec71e13871be01ba81f5dfc9d0af7e473da256ff46bc0e24026f/tomlkit-0.13.2-py3-none-any.whl", hash = "sha256:7a974427f6e119197f670fbbbeae7bef749a6c14e793db934baefc1b5f03efde", size = 37955 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tornado"
|
||||
version = "6.4.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/59/45/a0daf161f7d6f36c3ea5fc0c2de619746cc3dd4c76402e9db545bd920f63/tornado-6.4.2.tar.gz", hash = "sha256:92bad5b4746e9879fd7bf1eb21dce4e3fc5128d71601f80005afa39237ad620b", size = 501135 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/26/7e/71f604d8cea1b58f82ba3590290b66da1e72d840aeb37e0d5f7291bd30db/tornado-6.4.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e828cce1123e9e44ae2a50a9de3055497ab1d0aeb440c5ac23064d9e44880da1", size = 436299 },
|
||||
{ url = "https://files.pythonhosted.org/packages/96/44/87543a3b99016d0bf54fdaab30d24bf0af2e848f1d13d34a3a5380aabe16/tornado-6.4.2-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:072ce12ada169c5b00b7d92a99ba089447ccc993ea2143c9ede887e0937aa803", size = 434253 },
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/fb/fdf679b4ce51bcb7210801ef4f11fdac96e9885daa402861751353beea6e/tornado-6.4.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a017d239bd1bb0919f72af256a970624241f070496635784d9bf0db640d3fec", size = 437602 },
|
||||
{ url = "https://files.pythonhosted.org/packages/4f/3b/e31aeffffc22b475a64dbeb273026a21b5b566f74dee48742817626c47dc/tornado-6.4.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c36e62ce8f63409301537222faffcef7dfc5284f27eec227389f2ad11b09d946", size = 436972 },
|
||||
{ url = "https://files.pythonhosted.org/packages/22/55/b78a464de78051a30599ceb6983b01d8f732e6f69bf37b4ed07f642ac0fc/tornado-6.4.2-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bca9eb02196e789c9cb5c3c7c0f04fb447dc2adffd95265b2c7223a8a615ccbf", size = 437173 },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/5e/be4fb0d1684eb822c9a62fb18a3e44a06188f78aa466b2ad991d2ee31104/tornado-6.4.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:304463bd0772442ff4d0f5149c6f1c2135a1fae045adf070821c6cdc76980634", size = 437892 },
|
||||
{ url = "https://files.pythonhosted.org/packages/f5/33/4f91fdd94ea36e1d796147003b490fe60a0215ac5737b6f9c65e160d4fe0/tornado-6.4.2-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:c82c46813ba483a385ab2a99caeaedf92585a1f90defb5693351fa7e4ea0bf73", size = 437334 },
|
||||
{ url = "https://files.pythonhosted.org/packages/2b/ae/c1b22d4524b0e10da2f29a176fb2890386f7bd1f63aacf186444873a88a0/tornado-6.4.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:932d195ca9015956fa502c6b56af9eb06106140d844a335590c1ec7f5277d10c", size = 437261 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b5/25/36dbd49ab6d179bcfc4c6c093a51795a4f3bed380543a8242ac3517a1751/tornado-6.4.2-cp38-abi3-win32.whl", hash = "sha256:2876cef82e6c5978fde1e0d5b1f919d756968d5b4282418f3146b79b58556482", size = 438463 },
|
||||
{ url = "https://files.pythonhosted.org/packages/61/cc/58b1adeb1bb46228442081e746fcdbc4540905c87e8add7c277540934edb/tornado-6.4.2-cp38-abi3-win_amd64.whl", hash = "sha256:908b71bf3ff37d81073356a5fadcc660eb10c1476ee6e2725588626ce7e5ca38", size = 438907 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tqdm"
|
||||
version = "4.67.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "platform_system == 'Windows'" },
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 }
|
||||
wheels = [
|
||||
|
|
@ -4536,6 +4695,20 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/04/22/91b4bd36df27e651daedd93d03d5d3bb6029fdb0b55494e45ee46c36c570/validators-0.33.0-py3-none-any.whl", hash = "sha256:134b586a98894f8139865953899fc2daeb3d0c35569552c5518f089ae43ed075", size = 43298 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "virtualenv"
|
||||
version = "20.28.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "distlib" },
|
||||
{ name = "filelock" },
|
||||
{ name = "platformdirs" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/bf/75/53316a5a8050069228a2f6d11f32046cfa94fbb6cc3f08703f59b873de2e/virtualenv-20.28.0.tar.gz", hash = "sha256:2c9c3262bb8e7b87ea801d715fae4495e6032450c71d2309be9550e7364049aa", size = 7650368 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/10/f9/0919cf6f1432a8c4baa62511f8f8da8225432d22e83e3476f5be1a1edc6e/virtualenv-20.28.0-py3-none-any.whl", hash = "sha256:23eae1b4516ecd610481eda647f3a7c09aea295055337331bb4e6892ecce47b0", size = 4276702 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "weaviate-client"
|
||||
version = "4.6.7"
|
||||
|
|
@ -4692,6 +4865,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/27/ee/518b72faa2073f5aa8e3262408d284892cb79cf2754ba0c3a5870645ef73/xxhash-3.5.0-cp313-cp313-win_arm64.whl", hash = "sha256:4811336f1ce11cac89dcbd18f3a25c527c16311709a89313c3acaf771def2d4b", size = 26801 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xyzservices"
|
||||
version = "2024.9.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a0/16/ae87cbd2d6bfc40a419077521c35aadf5121725b7bee3d7c51b56f50958b/xyzservices-2024.9.0.tar.gz", hash = "sha256:68fb8353c9dbba4f1ff6c0f2e5e4e596bb9e1db7f94f4f7dfbcb26e25aa66fde", size = 1131900 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4c/d3/e07ce413d16ef64e885bea37551eac4c5ca3ddd440933f9c94594273d0d9/xyzservices-2024.9.0-py3-none-any.whl", hash = "sha256:776ae82b78d6e5ca63dd6a94abb054df8130887a4a308473b54a6bd364de8644", size = 85130 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yarl"
|
||||
version = "1.18.3"
|
||||
|
|
|
|||
|
|
@ -4,11 +4,15 @@ from .api.v1.config.config import config
|
|||
from .api.v1.datasets.datasets import datasets
|
||||
from .api.v1.prune import prune
|
||||
from .api.v1.search import SearchType, get_search_history, search
|
||||
from .api.v1.visualize import visualize
|
||||
from .shared.utils import create_cognee_style_network_with_logo
|
||||
|
||||
# Pipelines
|
||||
from .modules import pipelines
|
||||
|
||||
try:
|
||||
import dotenv
|
||||
|
||||
dotenv.load_dotenv()
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -4,12 +4,13 @@ from pydantic.alias_generators import to_camel, to_snake
|
|||
|
||||
class OutDTO(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
alias_generator = to_camel,
|
||||
populate_by_name = True,
|
||||
alias_generator=to_camel,
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
|
||||
class InDTO(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
alias_generator = to_camel,
|
||||
populate_by_name = True,
|
||||
alias_generator=to_camel,
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
""" FastAPI server for the Cognee API. """
|
||||
"""FastAPI server for the Cognee API."""
|
||||
|
||||
import os
|
||||
import uvicorn
|
||||
import logging
|
||||
|
|
@ -6,9 +7,26 @@ import sentry_sdk
|
|||
from fastapi import FastAPI, status
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from cognee.api.v1.permissions.routers import get_permissions_router
|
||||
from cognee.api.v1.settings.routers import get_settings_router
|
||||
from cognee.api.v1.datasets.routers import get_datasets_router
|
||||
from cognee.api.v1.cognify.routers import get_cognify_router
|
||||
from cognee.api.v1.search.routers import get_search_router
|
||||
from cognee.api.v1.add.routers import get_add_router
|
||||
from fastapi import Request
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from cognee.exceptions import CogneeApiError
|
||||
from traceback import format_exc
|
||||
from cognee.api.v1.users.routers import (
|
||||
get_auth_router,
|
||||
get_register_router,
|
||||
get_reset_password_router,
|
||||
get_verify_router,
|
||||
get_users_router,
|
||||
get_visualize_router,
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
|
|
@ -19,15 +37,15 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
if os.getenv("ENV", "prod") == "prod":
|
||||
sentry_sdk.init(
|
||||
dsn = os.getenv("SENTRY_REPORTING_URL"),
|
||||
traces_sample_rate = 1.0,
|
||||
profiles_sample_rate = 1.0,
|
||||
dsn=os.getenv("SENTRY_REPORTING_URL"),
|
||||
traces_sample_rate=1.0,
|
||||
profiles_sample_rate=1.0,
|
||||
)
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
app_environment = os.getenv("ENV", "prod")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# from cognee.modules.data.deletion import prune_system, prune_data
|
||||
|
|
@ -35,50 +53,42 @@ async def lifespan(app: FastAPI):
|
|||
# await prune_system(metadata = True)
|
||||
# if app_environment == "local" or app_environment == "dev":
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
db_engine = get_relational_engine()
|
||||
await db_engine.create_database()
|
||||
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
||||
await get_default_user()
|
||||
|
||||
yield
|
||||
|
||||
app = FastAPI(debug = app_environment != "prod", lifespan = lifespan)
|
||||
|
||||
app = FastAPI(debug=app_environment != "prod", lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins = ["*"],
|
||||
allow_credentials = True,
|
||||
allow_methods = ["OPTIONS", "GET", "POST", "DELETE"],
|
||||
allow_headers = ["*"],
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["OPTIONS", "GET", "POST", "DELETE"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
from cognee.api.v1.users.routers import get_auth_router, get_register_router,\
|
||||
get_reset_password_router, get_verify_router, get_users_router
|
||||
from cognee.api.v1.permissions.routers import get_permissions_router
|
||||
from cognee.api.v1.settings.routers import get_settings_router
|
||||
from cognee.api.v1.datasets.routers import get_datasets_router
|
||||
from cognee.api.v1.cognify.routers import get_cognify_router
|
||||
from cognee.api.v1.search.routers import get_search_router
|
||||
from cognee.api.v1.add.routers import get_add_router
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def request_validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
if request.url.path == "/api/v1/auth/login":
|
||||
return JSONResponse(
|
||||
status_code = 400,
|
||||
content = {"detail": "LOGIN_BAD_CREDENTIALS"},
|
||||
status_code=400,
|
||||
content={"detail": "LOGIN_BAD_CREDENTIALS"},
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code = 400,
|
||||
content = jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
|
||||
status_code=400,
|
||||
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(CogneeApiError)
|
||||
async def exception_handler(_: Request, exc: CogneeApiError) -> JSONResponse:
|
||||
detail = {}
|
||||
|
|
@ -95,46 +105,42 @@ async def exception_handler(_: Request, exc: CogneeApiError) -> JSONResponse:
|
|||
|
||||
# log the stack trace for easier serverside debugging
|
||||
logger.error(format_exc())
|
||||
return JSONResponse(
|
||||
status_code=status_code, content={"detail": detail["message"]}
|
||||
)
|
||||
return JSONResponse(status_code=status_code, content={"detail": detail["message"]})
|
||||
|
||||
app.include_router(
|
||||
get_auth_router(),
|
||||
prefix = "/api/v1/auth",
|
||||
tags = ["auth"]
|
||||
)
|
||||
|
||||
app.include_router(get_auth_router(), prefix="/api/v1/auth", tags=["auth"])
|
||||
|
||||
app.include_router(
|
||||
get_register_router(),
|
||||
prefix = "/api/v1/auth",
|
||||
tags = ["auth"],
|
||||
prefix="/api/v1/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
get_reset_password_router(),
|
||||
prefix = "/api/v1/auth",
|
||||
tags = ["auth"],
|
||||
prefix="/api/v1/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
get_verify_router(),
|
||||
prefix = "/api/v1/auth",
|
||||
tags = ["auth"],
|
||||
prefix="/api/v1/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
get_users_router(),
|
||||
prefix = "/api/v1/users",
|
||||
tags = ["users"],
|
||||
prefix="/api/v1/users",
|
||||
tags=["users"],
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
get_permissions_router(),
|
||||
prefix = "/api/v1/permissions",
|
||||
tags = ["permissions"],
|
||||
prefix="/api/v1/permissions",
|
||||
tags=["permissions"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""
|
||||
|
|
@ -148,37 +154,21 @@ def health_check():
|
|||
"""
|
||||
Health check endpoint that returns the server status.
|
||||
"""
|
||||
return Response(status_code = 200)
|
||||
return Response(status_code=200)
|
||||
|
||||
app.include_router(
|
||||
get_datasets_router(),
|
||||
prefix="/api/v1/datasets",
|
||||
tags=["datasets"]
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
get_add_router(),
|
||||
prefix="/api/v1/add",
|
||||
tags=["add"]
|
||||
)
|
||||
app.include_router(get_datasets_router(), prefix="/api/v1/datasets", tags=["datasets"])
|
||||
|
||||
app.include_router(
|
||||
get_cognify_router(),
|
||||
prefix="/api/v1/cognify",
|
||||
tags=["cognify"]
|
||||
)
|
||||
app.include_router(get_add_router(), prefix="/api/v1/add", tags=["add"])
|
||||
|
||||
app.include_router(
|
||||
get_search_router(),
|
||||
prefix="/api/v1/search",
|
||||
tags=["search"]
|
||||
)
|
||||
app.include_router(get_cognify_router(), prefix="/api/v1/cognify", tags=["cognify"])
|
||||
|
||||
app.include_router(get_search_router(), prefix="/api/v1/search", tags=["search"])
|
||||
|
||||
app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["settings"])
|
||||
|
||||
app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"])
|
||||
|
||||
app.include_router(
|
||||
get_settings_router(),
|
||||
prefix="/api/v1/settings",
|
||||
tags=["settings"]
|
||||
)
|
||||
|
||||
def start_api_server(host: str = "0.0.0.0", port: int = 8000):
|
||||
"""
|
||||
|
|
@ -190,7 +180,7 @@ def start_api_server(host: str = "0.0.0.0", port: int = 8000):
|
|||
try:
|
||||
logger.info("Starting server at %s:%s", host, port)
|
||||
|
||||
uvicorn.run(app, host = host, port = port)
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to start server: {e}")
|
||||
# Here you could add any cleanup code or error recovery code.
|
||||
|
|
|
|||
|
|
@ -14,10 +14,19 @@ from cognee.tasks.ingestion import get_dlt_destination
|
|||
from cognee.modules.users.permissions.methods import give_permission_on_document
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.data.methods import create_dataset
|
||||
from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables
|
||||
from cognee.infrastructure.databases.vector.pgvector import create_db_and_tables as create_pgvector_db_and_tables
|
||||
from cognee.infrastructure.databases.relational import (
|
||||
create_db_and_tables as create_relational_db_and_tables,
|
||||
)
|
||||
from cognee.infrastructure.databases.vector.pgvector import (
|
||||
create_db_and_tables as create_pgvector_db_and_tables,
|
||||
)
|
||||
|
||||
async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_name: str = "main_dataset", user: User = None):
|
||||
|
||||
async def add(
|
||||
data: Union[BinaryIO, List[BinaryIO], str, List[str]],
|
||||
dataset_name: str = "main_dataset",
|
||||
user: User = None,
|
||||
):
|
||||
await create_relational_db_and_tables()
|
||||
await create_pgvector_db_and_tables()
|
||||
|
||||
|
|
@ -25,7 +34,9 @@ async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_nam
|
|||
if "data://" in data:
|
||||
# data is a data directory path
|
||||
datasets = get_matched_datasets(data.replace("data://", ""), dataset_name)
|
||||
return await asyncio.gather(*[add(file_paths, dataset_name) for [dataset_name, file_paths] in datasets])
|
||||
return await asyncio.gather(
|
||||
*[add(file_paths, dataset_name) for [dataset_name, file_paths] in datasets]
|
||||
)
|
||||
|
||||
if "file://" in data:
|
||||
# data is a file path
|
||||
|
|
@ -37,7 +48,7 @@ async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_nam
|
|||
return await add([file_path], dataset_name)
|
||||
|
||||
if hasattr(data, "file"):
|
||||
file_path = save_data_to_file(data.file, filename = data.filename)
|
||||
file_path = save_data_to_file(data.file, filename=data.filename)
|
||||
return await add([file_path], dataset_name)
|
||||
|
||||
# data is a list of file paths or texts
|
||||
|
|
@ -45,7 +56,7 @@ async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_nam
|
|||
|
||||
for data_item in data:
|
||||
if hasattr(data_item, "file"):
|
||||
file_paths.append(save_data_to_file(data_item, filename = data_item.filename))
|
||||
file_paths.append(save_data_to_file(data_item, filename=data_item.filename))
|
||||
elif isinstance(data_item, str) and (
|
||||
data_item.startswith("/") or data_item.startswith("file://")
|
||||
):
|
||||
|
|
@ -58,10 +69,11 @@ async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_nam
|
|||
|
||||
return []
|
||||
|
||||
|
||||
async def add_files(file_paths: List[str], dataset_name: str, user: User = None):
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
|
||||
base_config = get_base_config()
|
||||
data_directory_path = base_config.data_root_directory
|
||||
|
||||
|
|
@ -72,7 +84,11 @@ async def add_files(file_paths: List[str], dataset_name: str, user: User = None)
|
|||
|
||||
if data_directory_path not in file_path:
|
||||
file_name = file_path.split("/")[-1]
|
||||
file_directory_path = data_directory_path + "/" + (dataset_name.replace(".", "/") + "/" if dataset_name != "main_dataset" else "")
|
||||
file_directory_path = (
|
||||
data_directory_path
|
||||
+ "/"
|
||||
+ (dataset_name.replace(".", "/") + "/" if dataset_name != "main_dataset" else "")
|
||||
)
|
||||
dataset_file_path = path.join(file_directory_path, file_name)
|
||||
|
||||
LocalStorage.ensure_directory_exists(file_directory_path)
|
||||
|
|
@ -85,16 +101,20 @@ async def add_files(file_paths: List[str], dataset_name: str, user: User = None)
|
|||
destination = get_dlt_destination()
|
||||
|
||||
pipeline = dlt.pipeline(
|
||||
pipeline_name = "file_load_from_filesystem",
|
||||
destination = destination,
|
||||
pipeline_name="file_load_from_filesystem",
|
||||
destination=destination,
|
||||
)
|
||||
|
||||
dataset_name = dataset_name.replace(" ", "_").replace(".", "_") if dataset_name is not None else "main_dataset"
|
||||
dataset_name = (
|
||||
dataset_name.replace(" ", "_").replace(".", "_")
|
||||
if dataset_name is not None
|
||||
else "main_dataset"
|
||||
)
|
||||
|
||||
@dlt.resource(standalone = True, merge_key = "id")
|
||||
@dlt.resource(standalone=True, merge_key="id")
|
||||
async def data_resources(file_paths: str, user: User):
|
||||
for file_path in file_paths:
|
||||
with open(file_path.replace("file://", ""), mode = "rb") as file:
|
||||
with open(file_path.replace("file://", ""), mode="rb") as file:
|
||||
classified_data = ingestion.classify(file)
|
||||
|
||||
data_id = ingestion.identify(classified_data)
|
||||
|
|
@ -109,9 +129,9 @@ async def add_files(file_paths: List[str], dataset_name: str, user: User = None)
|
|||
async with db_engine.get_async_session() as session:
|
||||
dataset = await create_dataset(dataset_name, user.id, session)
|
||||
|
||||
data = (await session.execute(
|
||||
select(Data).filter(Data.id == data_id)
|
||||
)).scalar_one_or_none()
|
||||
data = (
|
||||
await session.execute(select(Data).filter(Data.id == data_id))
|
||||
).scalar_one_or_none()
|
||||
|
||||
if data is not None:
|
||||
data.name = file_metadata["name"]
|
||||
|
|
@ -123,11 +143,11 @@ async def add_files(file_paths: List[str], dataset_name: str, user: User = None)
|
|||
await session.commit()
|
||||
else:
|
||||
data = Data(
|
||||
id = data_id,
|
||||
name = file_metadata["name"],
|
||||
raw_data_location = file_metadata["file_path"],
|
||||
extension = file_metadata["extension"],
|
||||
mime_type = file_metadata["mime_type"],
|
||||
id=data_id,
|
||||
name=file_metadata["name"],
|
||||
raw_data_location=file_metadata["file_path"],
|
||||
extension=file_metadata["extension"],
|
||||
mime_type=file_metadata["mime_type"],
|
||||
)
|
||||
|
||||
dataset.data.append(data)
|
||||
|
|
@ -144,14 +164,13 @@ async def add_files(file_paths: List[str], dataset_name: str, user: User = None)
|
|||
await give_permission_on_document(user, data_id, "read")
|
||||
await give_permission_on_document(user, data_id, "write")
|
||||
|
||||
|
||||
send_telemetry("cognee.add EXECUTION STARTED", user_id = user.id)
|
||||
send_telemetry("cognee.add EXECUTION STARTED", user_id=user.id)
|
||||
run_info = pipeline.run(
|
||||
data_resources(processed_file_paths, user),
|
||||
table_name = "file_metadata",
|
||||
dataset_name = dataset_name,
|
||||
write_disposition = "merge",
|
||||
table_name="file_metadata",
|
||||
dataset_name=dataset_name,
|
||||
write_disposition="merge",
|
||||
)
|
||||
send_telemetry("cognee.add EXECUTION COMPLETED", user_id = user.id)
|
||||
send_telemetry("cognee.add EXECUTION COMPLETED", user_id=user.id)
|
||||
|
||||
return run_info
|
||||
|
|
|
|||
|
|
@ -3,22 +3,28 @@ from cognee.modules.users.models import User
|
|||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.pipelines import run_tasks, Task
|
||||
from cognee.tasks.ingestion import ingest_data_with_metadata, resolve_data_directories
|
||||
from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables
|
||||
from cognee.infrastructure.databases.vector.pgvector import create_db_and_tables as create_pgvector_db_and_tables
|
||||
from cognee.infrastructure.databases.relational import (
|
||||
create_db_and_tables as create_relational_db_and_tables,
|
||||
)
|
||||
from cognee.infrastructure.databases.vector.pgvector import (
|
||||
create_db_and_tables as create_pgvector_db_and_tables,
|
||||
)
|
||||
|
||||
async def add(data: Union[BinaryIO, list[BinaryIO], str, list[str]], dataset_name: str = "main_dataset", user: User = None):
|
||||
|
||||
async def add(
|
||||
data: Union[BinaryIO, list[BinaryIO], str, list[str]],
|
||||
dataset_name: str = "main_dataset",
|
||||
user: User = None,
|
||||
):
|
||||
await create_relational_db_and_tables()
|
||||
await create_pgvector_db_and_tables()
|
||||
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
tasks = [
|
||||
Task(resolve_data_directories),
|
||||
Task(ingest_data_with_metadata, dataset_name, user)
|
||||
]
|
||||
tasks = [Task(resolve_data_directories), Task(ingest_data_with_metadata, dataset_name, user)]
|
||||
|
||||
pipeline = run_tasks(tasks, data, "add_pipeline")
|
||||
|
||||
async for result in pipeline:
|
||||
print(result)
|
||||
print(result)
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from .get_add_router import get_add_router
|
||||
from .get_add_router import get_add_router
|
||||
|
|
|
|||
|
|
@ -11,17 +11,19 @@ from cognee.modules.users.methods import get_authenticated_user
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_add_router() -> APIRouter:
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/", response_model=None)
|
||||
async def add(
|
||||
data: List[UploadFile],
|
||||
datasetId: str = Form(...),
|
||||
user: User = Depends(get_authenticated_user),
|
||||
data: List[UploadFile],
|
||||
datasetId: str = Form(...),
|
||||
user: User = Depends(get_authenticated_user),
|
||||
):
|
||||
""" This endpoint is responsible for adding data to the graph."""
|
||||
"""This endpoint is responsible for adding data to the graph."""
|
||||
from cognee.api.v1.add import add as cognee_add
|
||||
|
||||
try:
|
||||
if isinstance(data, str) and data.startswith("http"):
|
||||
if "github" in data:
|
||||
|
|
@ -52,9 +54,6 @@ def get_add_router() -> APIRouter:
|
|||
user=user,
|
||||
)
|
||||
except Exception as error:
|
||||
return JSONResponse(
|
||||
status_code=409,
|
||||
content={"error": str(error)}
|
||||
)
|
||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||
|
||||
return router
|
||||
return router
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from cognee.infrastructure.databases.relational.user_authentication.users import authenticate_user_method
|
||||
from cognee.infrastructure.databases.relational.user_authentication.users import (
|
||||
authenticate_user_method,
|
||||
)
|
||||
|
||||
|
||||
async def authenticate_user(email: str, password: str):
|
||||
|
|
@ -11,6 +13,7 @@ async def authenticate_user(email: str, password: str):
|
|||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Define an example user
|
||||
example_email = "example@example.com"
|
||||
example_password = "securepassword123"
|
||||
|
|
|
|||
|
|
@ -3,31 +3,30 @@ import logging
|
|||
from pathlib import Path
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.databases.vector.embeddings import \
|
||||
get_embedding_engine
|
||||
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
|
||||
from cognee.modules.cognify.config import get_cognify_config
|
||||
from cognee.modules.pipelines import run_tasks
|
||||
from cognee.modules.pipelines.tasks.Task import Task
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.data_models import KnowledgeGraph, MonitoringTool
|
||||
from cognee.tasks.documents import (classify_documents,
|
||||
extract_chunks_from_documents)
|
||||
from cognee.tasks.documents import classify_documents, extract_chunks_from_documents
|
||||
from cognee.tasks.graph import extract_graph_from_data
|
||||
from cognee.tasks.ingestion import ingest_data_with_metadata
|
||||
from cognee.tasks.repo_processor import (enrich_dependency_graph,
|
||||
expand_dependency_graph,
|
||||
get_data_list_for_user,
|
||||
get_non_py_files,
|
||||
get_repo_file_dependencies)
|
||||
from cognee.tasks.repo_processor.get_source_code_chunks import \
|
||||
get_source_code_chunks
|
||||
from cognee.tasks.repo_processor import (
|
||||
enrich_dependency_graph,
|
||||
expand_dependency_graph,
|
||||
get_data_list_for_user,
|
||||
get_non_py_files,
|
||||
get_repo_file_dependencies,
|
||||
)
|
||||
from cognee.tasks.repo_processor.get_source_code_chunks import get_source_code_chunks
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.tasks.summarization import summarize_code, summarize_text
|
||||
|
||||
monitoring = get_base_config().monitoring_tool
|
||||
if monitoring == MonitoringTool.LANGFUSE:
|
||||
from langfuse.decorators import observe
|
||||
|
||||
from cognee.tasks.summarization import summarize_code, summarize_text
|
||||
|
||||
logger = logging.getLogger("code_graph_pipeline")
|
||||
update_status_lock = asyncio.Lock()
|
||||
|
|
@ -42,9 +41,13 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
|
|||
from cognee.infrastructure.databases.relational import create_db_and_tables
|
||||
|
||||
file_path = Path(__file__).parent
|
||||
data_directory_path = str(pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve())
|
||||
data_directory_path = str(
|
||||
pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve()
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
cognee_directory_path = str(pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve())
|
||||
cognee_directory_path = str(
|
||||
pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve()
|
||||
)
|
||||
cognee.config.system_root_directory(cognee_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
|
|
@ -60,7 +63,11 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
|
|||
Task(get_repo_file_dependencies),
|
||||
Task(enrich_dependency_graph),
|
||||
Task(expand_dependency_graph, task_config={"batch_size": 50}),
|
||||
Task(get_source_code_chunks, embedding_model=embedding_engine.model, task_config={"batch_size": 50}),
|
||||
Task(
|
||||
get_source_code_chunks,
|
||||
embedding_model=embedding_engine.model,
|
||||
task_config={"batch_size": 50},
|
||||
),
|
||||
Task(summarize_code, task_config={"batch_size": 50}),
|
||||
Task(add_data_points, task_config={"batch_size": 50}),
|
||||
]
|
||||
|
|
@ -72,17 +79,19 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
|
|||
Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
|
||||
Task(classify_documents),
|
||||
Task(extract_chunks_from_documents),
|
||||
Task(extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}),
|
||||
Task(
|
||||
extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
|
||||
),
|
||||
Task(
|
||||
summarize_text,
|
||||
summarization_model=cognee_config.summarization_model,
|
||||
task_config={"batch_size": 50}
|
||||
task_config={"batch_size": 50},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
if include_docs:
|
||||
async for result in run_tasks(non_code_tasks, repo_path):
|
||||
yield result
|
||||
|
||||
async for result in run_tasks(tasks, repo_path, "cognify_code_pipeline"):
|
||||
yield result
|
||||
yield result
|
||||
|
|
|
|||
|
|
@ -10,18 +10,18 @@ from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
|||
from cognee.modules.data.models import Data, Dataset
|
||||
from cognee.modules.pipelines import run_tasks
|
||||
from cognee.modules.pipelines.models import PipelineRunStatus
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import \
|
||||
get_pipeline_status
|
||||
from cognee.modules.pipelines.operations.log_pipeline_status import \
|
||||
log_pipeline_status
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
|
||||
from cognee.modules.pipelines.tasks.Task import Task
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.tasks.documents import (check_permissions_on_documents,
|
||||
classify_documents,
|
||||
extract_chunks_from_documents)
|
||||
from cognee.tasks.documents import (
|
||||
check_permissions_on_documents,
|
||||
classify_documents,
|
||||
extract_chunks_from_documents,
|
||||
)
|
||||
from cognee.tasks.graph import extract_graph_from_data
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.tasks.storage.index_graph_edges import index_graph_edges
|
||||
|
|
@ -31,7 +31,12 @@ logger = logging.getLogger("cognify.v2")
|
|||
|
||||
update_status_lock = asyncio.Lock()
|
||||
|
||||
async def cognify(datasets: Union[str, list[str]] = None, user: User = None, graph_model: BaseModel = KnowledgeGraph):
|
||||
|
||||
async def cognify(
|
||||
datasets: Union[str, list[str]] = None,
|
||||
user: User = None,
|
||||
graph_model: BaseModel = KnowledgeGraph,
|
||||
):
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
|
|
@ -41,7 +46,7 @@ async def cognify(datasets: Union[str, list[str]] = None, user: User = None, gra
|
|||
# If no datasets are provided, cognify all existing datasets.
|
||||
datasets = existing_datasets
|
||||
|
||||
if type(datasets[0]) == str:
|
||||
if isinstance(datasets[0], str):
|
||||
datasets = await get_datasets_by_name(datasets, user.id)
|
||||
|
||||
existing_datasets_map = {
|
||||
|
|
@ -59,8 +64,10 @@ async def cognify(datasets: Union[str, list[str]] = None, user: User = None, gra
|
|||
return await asyncio.gather(*awaitables)
|
||||
|
||||
|
||||
async def run_cognify_pipeline(dataset: Dataset, user: User, graph_model: BaseModel = KnowledgeGraph):
|
||||
data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)
|
||||
async def run_cognify_pipeline(
|
||||
dataset: Dataset, user: User, graph_model: BaseModel = KnowledgeGraph
|
||||
):
|
||||
data_documents: list[Data] = await get_dataset_data(dataset_id=dataset.id)
|
||||
|
||||
document_ids_str = [str(document.id) for document in data_documents]
|
||||
|
||||
|
|
@ -69,32 +76,41 @@ async def run_cognify_pipeline(dataset: Dataset, user: User, graph_model: BaseMo
|
|||
|
||||
send_telemetry("cognee.cognify EXECUTION STARTED", user.id)
|
||||
|
||||
#async with update_status_lock: TODO: Add UI lock to prevent multiple backend requests
|
||||
# async with update_status_lock: TODO: Add UI lock to prevent multiple backend requests
|
||||
task_status = await get_pipeline_status([dataset_id])
|
||||
|
||||
if dataset_id in task_status and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
|
||||
if (
|
||||
dataset_id in task_status
|
||||
and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED
|
||||
):
|
||||
logger.info("Dataset %s is already being processed.", dataset_name)
|
||||
return
|
||||
|
||||
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_STARTED, {
|
||||
"dataset_name": dataset_name,
|
||||
"files": document_ids_str,
|
||||
})
|
||||
await log_pipeline_status(
|
||||
dataset_id,
|
||||
PipelineRunStatus.DATASET_PROCESSING_STARTED,
|
||||
{
|
||||
"dataset_name": dataset_name,
|
||||
"files": document_ids_str,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
cognee_config = get_cognify_config()
|
||||
|
||||
tasks = [
|
||||
Task(classify_documents),
|
||||
Task(check_permissions_on_documents, user = user, permissions = ["write"]),
|
||||
Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
|
||||
Task(extract_graph_from_data, graph_model = graph_model, task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks.
|
||||
Task(check_permissions_on_documents, user=user, permissions=["write"]),
|
||||
Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
|
||||
Task(
|
||||
extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10}
|
||||
), # Generate knowledge graphs from the document chunks.
|
||||
Task(
|
||||
summarize_text,
|
||||
summarization_model = cognee_config.summarization_model,
|
||||
task_config = { "batch_size": 10 }
|
||||
summarization_model=cognee_config.summarization_model,
|
||||
task_config={"batch_size": 10},
|
||||
),
|
||||
Task(add_data_points, task_config = { "batch_size": 10 }),
|
||||
Task(add_data_points, only_root=True, task_config={"batch_size": 10}),
|
||||
]
|
||||
|
||||
pipeline = run_tasks(tasks, data_documents, "cognify_pipeline")
|
||||
|
|
@ -106,17 +122,25 @@ async def run_cognify_pipeline(dataset: Dataset, user: User, graph_model: BaseMo
|
|||
|
||||
send_telemetry("cognee.cognify EXECUTION COMPLETED", user.id)
|
||||
|
||||
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_COMPLETED, {
|
||||
"dataset_name": dataset_name,
|
||||
"files": document_ids_str,
|
||||
})
|
||||
await log_pipeline_status(
|
||||
dataset_id,
|
||||
PipelineRunStatus.DATASET_PROCESSING_COMPLETED,
|
||||
{
|
||||
"dataset_name": dataset_name,
|
||||
"files": document_ids_str,
|
||||
},
|
||||
)
|
||||
except Exception as error:
|
||||
send_telemetry("cognee.cognify EXECUTION ERRORED", user.id)
|
||||
|
||||
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_ERRORED, {
|
||||
"dataset_name": dataset_name,
|
||||
"files": document_ids_str,
|
||||
})
|
||||
await log_pipeline_status(
|
||||
dataset_id,
|
||||
PipelineRunStatus.DATASET_PROCESSING_ERRORED,
|
||||
{
|
||||
"dataset_name": dataset_name,
|
||||
"files": document_ids_str,
|
||||
},
|
||||
)
|
||||
raise error
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from .get_cognify_router import get_cognify_router
|
||||
from .get_cognify_router import get_cognify_router
|
||||
|
|
|
|||
|
|
@ -7,23 +7,23 @@ from cognee.modules.users.methods import get_authenticated_user
|
|||
from fastapi import Depends
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
|
||||
|
||||
class CognifyPayloadDTO(BaseModel):
|
||||
datasets: List[str]
|
||||
graph_model: Optional[BaseModel] = KnowledgeGraph
|
||||
|
||||
|
||||
def get_cognify_router() -> APIRouter:
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/", response_model=None)
|
||||
async def cognify(payload: CognifyPayloadDTO, user: User = Depends(get_authenticated_user)):
|
||||
""" This endpoint is responsible for the cognitive processing of the content."""
|
||||
"""This endpoint is responsible for the cognitive processing of the content."""
|
||||
from cognee.api.v1.cognify.cognify_v2 import cognify as cognee_cognify
|
||||
|
||||
try:
|
||||
await cognee_cognify(payload.datasets, user, payload.graph_model)
|
||||
except Exception as error:
|
||||
return JSONResponse(
|
||||
status_code=409,
|
||||
content={"error": str(error)}
|
||||
)
|
||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||
|
||||
return router
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
""" This module is used to set the configuration of the system."""
|
||||
"""This module is used to set the configuration of the system."""
|
||||
|
||||
import os
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.exceptions import InvalidValueError, InvalidAttributeError
|
||||
|
|
@ -10,7 +11,8 @@ from cognee.infrastructure.llm.config import get_llm_config
|
|||
from cognee.infrastructure.databases.relational import get_relational_config
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
|
||||
class config():
|
||||
|
||||
class config:
|
||||
@staticmethod
|
||||
def system_root_directory(system_root_directory: str):
|
||||
databases_directory_path = os.path.join(system_root_directory, "databases")
|
||||
|
|
@ -39,12 +41,12 @@ class config():
|
|||
@staticmethod
|
||||
def set_classification_model(classification_model: object):
|
||||
cognify_config = get_cognify_config()
|
||||
cognify_config.classification_model = classification_model
|
||||
cognify_config.classification_model = classification_model
|
||||
|
||||
@staticmethod
|
||||
def set_summarization_model(summarization_model: object):
|
||||
cognify_config = get_cognify_config()
|
||||
cognify_config.summarization_model=summarization_model
|
||||
cognify_config.summarization_model = summarization_model
|
||||
|
||||
@staticmethod
|
||||
def set_graph_model(graph_model: object):
|
||||
|
|
@ -79,14 +81,16 @@ class config():
|
|||
@staticmethod
|
||||
def set_llm_config(config_dict: dict):
|
||||
"""
|
||||
Updates the llm config with values from config_dict.
|
||||
Updates the llm config with values from config_dict.
|
||||
"""
|
||||
llm_config = get_llm_config()
|
||||
for key, value in config_dict.items():
|
||||
if hasattr(llm_config, key):
|
||||
object.__setattr__(llm_config, key, value)
|
||||
else:
|
||||
raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.")
|
||||
raise InvalidAttributeError(
|
||||
message=f"'{key}' is not a valid attribute of the config."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def set_chunk_strategy(chunk_strategy: object):
|
||||
|
|
@ -108,7 +112,6 @@ class config():
|
|||
chunk_config = get_chunk_config()
|
||||
chunk_config.chunk_size = chunk_size
|
||||
|
||||
|
||||
@staticmethod
|
||||
def set_vector_db_provider(vector_db_provider: str):
|
||||
vector_db_config = get_vectordb_config()
|
||||
|
|
@ -117,33 +120,36 @@ class config():
|
|||
@staticmethod
|
||||
def set_relational_db_config(config_dict: dict):
|
||||
"""
|
||||
Updates the relational db config with values from config_dict.
|
||||
Updates the relational db config with values from config_dict.
|
||||
"""
|
||||
relational_db_config = get_relational_config()
|
||||
for key, value in config_dict.items():
|
||||
if hasattr(relational_db_config, key):
|
||||
object.__setattr__(relational_db_config, key, value)
|
||||
else:
|
||||
raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.")
|
||||
raise InvalidAttributeError(
|
||||
message=f"'{key}' is not a valid attribute of the config."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def set_vector_db_config(config_dict: dict):
|
||||
"""
|
||||
Updates the vector db config with values from config_dict.
|
||||
Updates the vector db config with values from config_dict.
|
||||
"""
|
||||
vector_db_config = get_vectordb_config()
|
||||
for key, value in config_dict.items():
|
||||
if hasattr(vector_db_config, key):
|
||||
object.__setattr__(vector_db_config, key, value)
|
||||
else:
|
||||
raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.")
|
||||
raise InvalidAttributeError(
|
||||
message=f"'{key}' is not a valid attribute of the config."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def set_vector_db_key(db_key: str):
|
||||
vector_db_config = get_vectordb_config()
|
||||
vector_db_config.vector_db_key = db_key
|
||||
|
||||
|
||||
@staticmethod
|
||||
def set_vector_db_url(db_url: str):
|
||||
vector_db_config = get_vectordb_config()
|
||||
|
|
@ -154,7 +160,9 @@ class config():
|
|||
base_config = get_base_config()
|
||||
|
||||
if "username" not in graphistry_config or "password" not in graphistry_config:
|
||||
raise InvalidValueError(message="graphistry_config dictionary must contain 'username' and 'password' keys.")
|
||||
raise InvalidValueError(
|
||||
message="graphistry_config dictionary must contain 'username' and 'password' keys."
|
||||
)
|
||||
|
||||
base_config.graphistry_username = graphistry_config.get("username")
|
||||
base_config.graphistry_password = graphistry_config.get("password")
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.ingestion import discover_directory_datasets
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||
class datasets():
|
||||
|
||||
|
||||
class datasets:
|
||||
@staticmethod
|
||||
async def list_datasets():
|
||||
from cognee.modules.data.methods import get_datasets
|
||||
|
||||
user = await get_default_user()
|
||||
return await get_datasets(user.id)
|
||||
|
||||
|
|
@ -15,6 +18,7 @@ class datasets():
|
|||
@staticmethod
|
||||
async def list_data(dataset_id: str):
|
||||
from cognee.modules.data.methods import get_dataset, get_dataset_data
|
||||
|
||||
user = await get_default_user()
|
||||
|
||||
dataset = await get_dataset(user.id, dataset_id)
|
||||
|
|
@ -28,6 +32,7 @@ class datasets():
|
|||
@staticmethod
|
||||
async def delete_dataset(dataset_id: str):
|
||||
from cognee.modules.data.methods import get_dataset, delete_dataset
|
||||
|
||||
user = await get_default_user()
|
||||
dataset = await get_dataset(user.id, dataset_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from .get_datasets_router import get_datasets_router
|
||||
from .get_datasets_router import get_datasets_router
|
||||
|
|
|
|||
|
|
@ -16,9 +16,11 @@ from cognee.modules.pipelines.models import PipelineRunStatus
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ErrorResponseDTO(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class DatasetDTO(OutDTO):
|
||||
id: UUID
|
||||
name: str
|
||||
|
|
@ -26,6 +28,7 @@ class DatasetDTO(OutDTO):
|
|||
updated_at: Optional[datetime] = None
|
||||
owner_id: UUID
|
||||
|
||||
|
||||
class DataDTO(OutDTO):
|
||||
id: UUID
|
||||
name: str
|
||||
|
|
@ -35,6 +38,7 @@ class DataDTO(OutDTO):
|
|||
mime_type: str
|
||||
raw_data_location: str
|
||||
|
||||
|
||||
def get_datasets_router() -> APIRouter:
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -42,46 +46,51 @@ def get_datasets_router() -> APIRouter:
|
|||
async def get_datasets(user: User = Depends(get_authenticated_user)):
|
||||
try:
|
||||
from cognee.modules.data.methods import get_datasets
|
||||
|
||||
datasets = await get_datasets(user.id)
|
||||
|
||||
return datasets
|
||||
except Exception as error:
|
||||
logger.error(f"Error retrieving datasets: {str(error)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error retrieving datasets: {str(error)}") from error
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error retrieving datasets: {str(error)}"
|
||||
) from error
|
||||
|
||||
@router.delete("/{dataset_id}", response_model=None, responses={404: {"model": ErrorResponseDTO}})
|
||||
@router.delete(
|
||||
"/{dataset_id}", response_model=None, responses={404: {"model": ErrorResponseDTO}}
|
||||
)
|
||||
async def delete_dataset(dataset_id: str, user: User = Depends(get_authenticated_user)):
|
||||
from cognee.modules.data.methods import get_dataset, delete_dataset
|
||||
|
||||
dataset = await get_dataset(user.id, dataset_id)
|
||||
|
||||
if dataset is None:
|
||||
raise EntityNotFoundError(
|
||||
message=f"Dataset ({dataset_id}) not found."
|
||||
)
|
||||
raise EntityNotFoundError(message=f"Dataset ({dataset_id}) not found.")
|
||||
|
||||
await delete_dataset(dataset)
|
||||
|
||||
@router.delete("/{dataset_id}/data/{data_id}", response_model=None, responses={404: {"model": ErrorResponseDTO}})
|
||||
async def delete_data(dataset_id: str, data_id: str, user: User = Depends(get_authenticated_user)):
|
||||
@router.delete(
|
||||
"/{dataset_id}/data/{data_id}",
|
||||
response_model=None,
|
||||
responses={404: {"model": ErrorResponseDTO}},
|
||||
)
|
||||
async def delete_data(
|
||||
dataset_id: str, data_id: str, user: User = Depends(get_authenticated_user)
|
||||
):
|
||||
from cognee.modules.data.methods import get_data, delete_data
|
||||
from cognee.modules.data.methods import get_dataset
|
||||
|
||||
# Check if user has permission to access dataset and data by trying to get the dataset
|
||||
dataset = await get_dataset(user.id, dataset_id)
|
||||
|
||||
#TODO: Handle situation differently if user doesn't have permission to access data?
|
||||
# TODO: Handle situation differently if user doesn't have permission to access data?
|
||||
if dataset is None:
|
||||
raise EntityNotFoundError(
|
||||
message=f"Dataset ({dataset_id}) not found."
|
||||
)
|
||||
raise EntityNotFoundError(message=f"Dataset ({dataset_id}) not found.")
|
||||
|
||||
data = await get_data(user.id, data_id)
|
||||
|
||||
if data is None:
|
||||
raise EntityNotFoundError(
|
||||
message=f"Data ({data_id}) not found."
|
||||
)
|
||||
raise EntityNotFoundError(message=f"Data ({data_id}) not found.")
|
||||
|
||||
await delete_data(data)
|
||||
|
||||
|
|
@ -98,14 +107,18 @@ def get_datasets_router() -> APIRouter:
|
|||
status_code=200,
|
||||
content=str(graph_url),
|
||||
)
|
||||
except:
|
||||
except Exception as error:
|
||||
print(error)
|
||||
return JSONResponse(
|
||||
status_code=409,
|
||||
content="Graphistry credentials are not set. Please set them in your .env file.",
|
||||
)
|
||||
|
||||
@router.get("/{dataset_id}/data", response_model=list[DataDTO],
|
||||
responses={404: {"model": ErrorResponseDTO}})
|
||||
@router.get(
|
||||
"/{dataset_id}/data",
|
||||
response_model=list[DataDTO],
|
||||
responses={404: {"model": ErrorResponseDTO}},
|
||||
)
|
||||
async def get_dataset_data(dataset_id: str, user: User = Depends(get_authenticated_user)):
|
||||
from cognee.modules.data.methods import get_dataset_data, get_dataset
|
||||
|
||||
|
|
@ -125,8 +138,10 @@ def get_datasets_router() -> APIRouter:
|
|||
return dataset_data
|
||||
|
||||
@router.get("/status", response_model=dict[str, PipelineRunStatus])
|
||||
async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset")] = None,
|
||||
user: User = Depends(get_authenticated_user)):
|
||||
async def get_dataset_status(
|
||||
datasets: Annotated[List[str], Query(alias="dataset")] = None,
|
||||
user: User = Depends(get_authenticated_user),
|
||||
):
|
||||
from cognee.api.v1.datasets.datasets import datasets as cognee_datasets
|
||||
|
||||
try:
|
||||
|
|
@ -134,13 +149,12 @@ def get_datasets_router() -> APIRouter:
|
|||
|
||||
return datasets_statuses
|
||||
except Exception as error:
|
||||
return JSONResponse(
|
||||
status_code=409,
|
||||
content={"error": str(error)}
|
||||
)
|
||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||
|
||||
@router.get("/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
|
||||
async def get_raw_data(dataset_id: str, data_id: str, user: User = Depends(get_authenticated_user)):
|
||||
async def get_raw_data(
|
||||
dataset_id: str, data_id: str, user: User = Depends(get_authenticated_user)
|
||||
):
|
||||
from cognee.modules.data.methods import get_data
|
||||
from cognee.modules.data.methods import get_dataset, get_dataset_data
|
||||
|
||||
|
|
@ -148,10 +162,7 @@ def get_datasets_router() -> APIRouter:
|
|||
|
||||
if dataset is None:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"detail": f"Dataset ({dataset_id}) not found."
|
||||
}
|
||||
status_code=404, content={"detail": f"Dataset ({dataset_id}) not found."}
|
||||
)
|
||||
|
||||
dataset_data = await get_dataset_data(dataset.id)
|
||||
|
|
@ -163,13 +174,17 @@ def get_datasets_router() -> APIRouter:
|
|||
|
||||
# Check if matching_data contains an element
|
||||
if len(matching_data) == 0:
|
||||
raise EntityNotFoundError(message= f"Data ({data_id}) not found in dataset ({dataset_id}).")
|
||||
raise EntityNotFoundError(
|
||||
message=f"Data ({data_id}) not found in dataset ({dataset_id})."
|
||||
)
|
||||
|
||||
data = await get_data(user.id, data_id)
|
||||
|
||||
if data is None:
|
||||
raise EntityNotFoundError(message=f"Data ({data_id}) not found in dataset ({dataset_id}).")
|
||||
raise EntityNotFoundError(
|
||||
message=f"Data ({data_id}) not found in dataset ({dataset_id})."
|
||||
)
|
||||
|
||||
return data.raw_data_location
|
||||
|
||||
return router
|
||||
return router
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from .get_permissions_router import get_permissions_router
|
||||
from .get_permissions_router import get_permissions_router
|
||||
|
|
|
|||
|
|
@ -10,40 +10,56 @@ from cognee.modules.users.exceptions import UserNotFoundError, GroupNotFoundErro
|
|||
from cognee.modules.users import get_user_db
|
||||
from cognee.modules.users.models import User, Group, Permission, UserGroup, GroupPermission
|
||||
|
||||
|
||||
def get_permissions_router() -> APIRouter:
|
||||
permissions_router = APIRouter()
|
||||
|
||||
@permissions_router.post("/groups/{group_id}/permissions")
|
||||
async def give_permission_to_group(group_id: str, permission: str, db: Session = Depends(get_user_db)):
|
||||
group = (await db.session.execute(select(Group).where(Group.id == group_id))).scalars().first()
|
||||
async def give_permission_to_group(
|
||||
group_id: str, permission: str, db: Session = Depends(get_user_db)
|
||||
):
|
||||
group = (
|
||||
(await db.session.execute(select(Group).where(Group.id == group_id))).scalars().first()
|
||||
)
|
||||
|
||||
if not group:
|
||||
raise GroupNotFoundError
|
||||
|
||||
permission_entity = (
|
||||
await db.session.execute(select(Permission).where(Permission.name == permission))).scalars().first()
|
||||
(await db.session.execute(select(Permission).where(Permission.name == permission)))
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
|
||||
if not permission_entity:
|
||||
stmt = insert(Permission).values(name=permission)
|
||||
await db.session.execute(stmt)
|
||||
permission_entity = (
|
||||
await db.session.execute(select(Permission).where(Permission.name == permission))).scalars().first()
|
||||
(await db.session.execute(select(Permission).where(Permission.name == permission)))
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
|
||||
try:
|
||||
# add permission to group
|
||||
await db.session.execute(
|
||||
insert(GroupPermission).values(group_id=group.id, permission_id=permission_entity.id))
|
||||
except IntegrityError as e:
|
||||
insert(GroupPermission).values(
|
||||
group_id=group.id, permission_id=permission_entity.id
|
||||
)
|
||||
)
|
||||
except IntegrityError:
|
||||
raise EntityAlreadyExistsError(message="Group permission already exists.")
|
||||
|
||||
await db.session.commit()
|
||||
|
||||
return JSONResponse(status_code = 200, content = {"message": "Permission assigned to group"})
|
||||
return JSONResponse(status_code=200, content={"message": "Permission assigned to group"})
|
||||
|
||||
@permissions_router.post("/users/{user_id}/groups")
|
||||
async def add_user_to_group(user_id: str, group_id: str, db: Session = Depends(get_user_db)):
|
||||
user = (await db.session.execute(select(User).where(User.id == user_id))).scalars().first()
|
||||
group = (await db.session.execute(select(Group).where(Group.id == group_id))).scalars().first()
|
||||
group = (
|
||||
(await db.session.execute(select(Group).where(Group.id == group_id))).scalars().first()
|
||||
)
|
||||
|
||||
if not user:
|
||||
raise UserNotFoundError
|
||||
|
|
@ -54,11 +70,11 @@ def get_permissions_router() -> APIRouter:
|
|||
# Add association directly to the association table
|
||||
stmt = insert(UserGroup).values(user_id=user_id, group_id=group_id)
|
||||
await db.session.execute(stmt)
|
||||
except IntegrityError as e:
|
||||
except IntegrityError:
|
||||
raise EntityAlreadyExistsError(message="User is already part of group.")
|
||||
|
||||
await db.session.commit()
|
||||
|
||||
return JSONResponse(status_code = 200, content = {"message": "User added to group"})
|
||||
return JSONResponse(status_code=200, content={"message": "User added to group"})
|
||||
|
||||
return permissions_router
|
||||
|
|
|
|||
|
|
@ -1,19 +1,21 @@
|
|||
from cognee.modules.data.deletion import prune_system, prune_data
|
||||
|
||||
class prune:
    """Namespace grouping the two pruning coroutines as static methods."""

    @staticmethod
    async def prune_data():
        """Await the module-level ``prune_data`` imported from ``cognee.modules.data.deletion``."""
        # Inside a method body the bare name resolves to the module-level
        # import, not to this static method (class scope is not searched).
        await prune_data()

    @staticmethod
    async def prune_system(graph=True, vector=True, metadata=False):
        """Await the module-level ``prune_system``, forwarding the three flags positionally.

        Args:
            graph: forwarded as the first positional argument (default True).
            vector: forwarded as the second positional argument (default True).
            metadata: forwarded as the third positional argument (default False).
        """
        await prune_system(graph, vector, metadata)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import asyncio

    async def main():
        """Run both pruning steps in order: data first, then system (default flags)."""
        await prune.prune_data()
        await prune.prune_system()

    # Script entry point: drive the coroutine to completion on a fresh event loop.
    asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -2,8 +2,9 @@ from cognee.modules.search.operations import get_history
|
|||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
|
||||
async def get_search_history(user: User = None) -> list:
    """Return the search history stored for a user.

    Args:
        user: the user whose history is fetched. When falsy (the default
            ``None``), the result of ``get_default_user()`` is used instead.

    Returns:
        Whatever ``get_history`` returns for the resolved user's ``id``.
    """
    resolved_user = user if user else await get_default_user()
    return await get_history(resolved_user.id)
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from .get_search_router import get_search_router
|
||||
from .get_search_router import get_search_router
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ class SearchPayloadDTO(InDTO):
|
|||
search_type: SearchType
|
||||
query: str
|
||||
|
||||
|
||||
def get_search_router() -> APIRouter:
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -22,21 +23,18 @@ def get_search_router() -> APIRouter:
|
|||
user: str
|
||||
created_at: datetime
|
||||
|
||||
@router.get("/", response_model = list[SearchHistoryItem])
|
||||
@router.get("/", response_model=list[SearchHistoryItem])
|
||||
async def get_search_history(user: User = Depends(get_authenticated_user)):
|
||||
try:
|
||||
history = await get_history(user.id)
|
||||
|
||||
return history
|
||||
except Exception as error:
|
||||
return JSONResponse(
|
||||
status_code = 500,
|
||||
content = {"error": str(error)}
|
||||
)
|
||||
return JSONResponse(status_code=500, content={"error": str(error)})
|
||||
|
||||
@router.post("/", response_model = list)
|
||||
@router.post("/", response_model=list)
|
||||
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
|
||||
""" This endpoint is responsible for searching for nodes in the graph."""
|
||||
"""This endpoint is responsible for searching for nodes in the graph."""
|
||||
from cognee.api.v1.search import search as cognee_search
|
||||
|
||||
try:
|
||||
|
|
@ -44,9 +42,6 @@ def get_search_router() -> APIRouter:
|
|||
|
||||
return results
|
||||
except Exception as error:
|
||||
return JSONResponse(
|
||||
status_code = 409,
|
||||
content = {"error": str(error)}
|
||||
)
|
||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||
|
||||
return router
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
""" This module contains the search function that is used to search for nodes in the graph."""
|
||||
"""This module contains the search function that is used to search for nodes in the graph."""
|
||||
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, Callable, List
|
||||
|
|
@ -16,6 +17,7 @@ from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
|||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
|
||||
class SearchType(Enum):
|
||||
ADJACENT = "ADJACENT"
|
||||
TRAVERSE = "TRAVERSE"
|
||||
|
|
@ -23,7 +25,7 @@ class SearchType(Enum):
|
|||
SUMMARY = "SUMMARY"
|
||||
SUMMARY_CLASSIFICATION = "SUMMARY_CLASSIFICATION"
|
||||
NODE_CLASSIFICATION = "NODE_CLASSIFICATION"
|
||||
DOCUMENT_CLASSIFICATION = "DOCUMENT_CLASSIFICATION",
|
||||
DOCUMENT_CLASSIFICATION = ("DOCUMENT_CLASSIFICATION",)
|
||||
CYPHER = "CYPHER"
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -33,12 +35,13 @@ class SearchType(Enum):
|
|||
except KeyError as error:
|
||||
raise ValueError(f"{name} is not a valid SearchType") from error
|
||||
|
||||
|
||||
class SearchParameters(BaseModel):
|
||||
search_type: SearchType
|
||||
params: Dict[str, Any]
|
||||
|
||||
@field_validator("search_type", mode="before")
|
||||
def convert_string_to_enum(cls, value): # pylint: disable=no-self-argument
|
||||
def convert_string_to_enum(cls, value): # pylint: disable=no-self-argument
|
||||
if isinstance(value, str):
|
||||
return SearchType.from_str(value)
|
||||
return value
|
||||
|
|
@ -52,7 +55,7 @@ async def search(search_type: str, params: Dict[str, Any], user: User = None) ->
|
|||
raise UserNotFoundError
|
||||
|
||||
own_document_ids = await get_document_ids_for_user(user.id)
|
||||
search_params = SearchParameters(search_type = search_type, params = params)
|
||||
search_params = SearchParameters(search_type=search_type, params=params)
|
||||
search_results = await specific_search([search_params], user)
|
||||
|
||||
from uuid import UUID
|
||||
|
|
@ -61,7 +64,7 @@ async def search(search_type: str, params: Dict[str, Any], user: User = None) ->
|
|||
|
||||
for search_result in search_results:
|
||||
document_id = search_result["document_id"] if "document_id" in search_result else None
|
||||
document_id = UUID(document_id) if type(document_id) == str else document_id
|
||||
document_id = UUID(document_id) if isinstance(document_id, str) else document_id
|
||||
|
||||
if document_id is None or document_id in own_document_ids:
|
||||
filtered_search_results.append(search_result)
|
||||
|
|
|
|||
|
|
@ -16,14 +16,20 @@ from cognee.tasks.graph import query_graph_connections
|
|||
from cognee.tasks.summarization import query_summaries
|
||||
from cognee.tasks.completion import query_completion
|
||||
|
||||
|
||||
class SearchType(Enum):
    """Search modes accepted by ``search``; every member's value equals its name."""

    SUMMARIES = "SUMMARIES"
    INSIGHTS = "INSIGHTS"
    CHUNKS = "CHUNKS"
    COMPLETION = "COMPLETION"
|
||||
|
||||
async def search(query_type: SearchType, query_text: str, user: User = None,
|
||||
datasets: Union[list[str], str, None] = None) -> list:
|
||||
|
||||
async def search(
|
||||
query_type: SearchType,
|
||||
query_text: str,
|
||||
user: User = None,
|
||||
datasets: Union[list[str], str, None] = None,
|
||||
) -> list:
|
||||
# We use lists from now on for datasets
|
||||
if isinstance(datasets, str):
|
||||
datasets = [datasets]
|
||||
|
|
@ -43,15 +49,16 @@ async def search(query_type: SearchType, query_text: str, user: User = None,
|
|||
|
||||
for search_result in search_results:
|
||||
document_id = search_result["document_id"] if "document_id" in search_result else None
|
||||
document_id = UUID(document_id) if type(document_id) == str else document_id
|
||||
document_id = UUID(document_id) if isinstance(document_id, str) else document_id
|
||||
|
||||
if document_id is None or document_id in own_document_ids:
|
||||
filtered_search_results.append(search_result)
|
||||
|
||||
await log_result(query.id, json.dumps(filtered_search_results, cls = JSONEncoder), user.id)
|
||||
await log_result(query.id, json.dumps(filtered_search_results, cls=JSONEncoder), user.id)
|
||||
|
||||
return filtered_search_results
|
||||
|
||||
|
||||
async def specific_search(query_type: SearchType, query: str, user) -> list:
|
||||
search_tasks: Dict[SearchType, Callable] = {
|
||||
SearchType.SUMMARIES: query_summaries,
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from .get_settings_router import get_settings_router
|
||||
from .get_settings_router import get_settings_router
|
||||
|
|
|
|||
|
|
@ -6,40 +6,50 @@ from fastapi import Depends
|
|||
from cognee.modules.users.models import User
|
||||
from cognee.modules.settings.get_settings import LLMConfig, VectorDBConfig
|
||||
|
||||
|
||||
class LLMConfigOutputDTO(OutDTO, LLMConfig):
    # Output DTO for the LLM settings: combines OutDTO with LLMConfig and
    # declares no fields of its own.
    ...
|
||||
|
||||
|
||||
class VectorDBConfigOutputDTO(OutDTO, VectorDBConfig):
    # Output DTO for the vector-database settings: combines OutDTO with
    # VectorDBConfig and declares no fields of its own.
    ...
|
||||
|
||||
|
||||
class SettingsDTO(OutDTO):
    # Response shape of the settings GET route (used as its response_model).

    # LLM section of the settings.
    llm: LLMConfigOutputDTO
    # Vector-database section of the settings.
    vector_db: VectorDBConfigOutputDTO
|
||||
|
||||
|
||||
class LLMConfigInputDTO(InDTO):
    # Incoming LLM settings.

    # Only these three provider identifiers are accepted.
    provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"]]
    model: str
    api_key: str
|
||||
|
||||
|
||||
class VectorDBConfigInputDTO(InDTO):
    # Incoming vector-database settings.

    # Only these four provider identifiers are accepted.
    provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"], Literal["pgvector"]]
    url: str
    api_key: str
|
||||
|
||||
|
||||
class SettingsPayloadDTO(InDTO):
    # Request body of the settings POST route. Both sections are optional and
    # default to None, so a caller may send either one alone.

    llm: Optional[LLMConfigInputDTO] = None
    vector_db: Optional[VectorDBConfigInputDTO] = None
|
||||
|
||||
|
||||
def get_settings_router() -> APIRouter:
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/", response_model=SettingsDTO)
|
||||
async def get_settings(user: User = Depends(get_authenticated_user)):
|
||||
from cognee.modules.settings import get_settings as get_cognee_settings
|
||||
|
||||
return get_cognee_settings()
|
||||
|
||||
@router.post("/", response_model=None)
|
||||
async def save_settings(new_settings: SettingsPayloadDTO, user: User = Depends(get_authenticated_user)):
|
||||
async def save_settings(
|
||||
new_settings: SettingsPayloadDTO, user: User = Depends(get_authenticated_user)
|
||||
):
|
||||
from cognee.modules.settings import save_llm_config, save_vector_db_config
|
||||
|
||||
if new_settings.llm is not None:
|
||||
|
|
@ -48,4 +58,4 @@ def get_settings_router() -> APIRouter:
|
|||
if new_settings.vector_db is not None:
|
||||
await save_vector_db_config(new_settings.vector_db)
|
||||
|
||||
return router
|
||||
return router
|
||||
|
|
|
|||
|
|
@ -3,10 +3,10 @@ from cognee.modules.users.methods import create_user as create_user_method
|
|||
|
||||
async def create_user(email: str, password: str, is_superuser: bool = False):
    """Create a user through ``create_user_method`` and return the result.

    Args:
        email: forwarded as ``email``.
        password: forwarded as ``password``.
        is_superuser: forwarded as ``is_superuser`` (default False).

    Returns:
        The value produced by ``create_user_method``.
    """
    # Users created through this entry point are always marked as verified.
    return await create_user_method(
        email=email,
        password=password,
        is_superuser=is_superuser,
        is_verified=True,
    )
|
||||
|
|
|
|||
|
|
@ -3,3 +3,4 @@ from .get_register_router import get_register_router
|
|||
from .get_reset_password_router import get_reset_password_router
|
||||
from .get_users_router import get_users_router
|
||||
from .get_verify_router import get_verify_router
|
||||
from .get_visualize_router import get_visualize_router
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from cognee.modules.users.get_fastapi_users import get_fastapi_users
|
||||
from cognee.modules.users.authentication.get_auth_backend import get_auth_backend
|
||||
|
||||
|
||||
def get_auth_router():
    """Return the auth router built by the fastapi-users instance for the configured backend."""
    # Backend is resolved first, then the fastapi-users instance (same order as before).
    backend = get_auth_backend()
    fastapi_users = get_fastapi_users()
    return fastapi_users.get_auth_router(backend)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from cognee.modules.users.get_fastapi_users import get_fastapi_users
|
||||
from cognee.modules.users.models.User import UserRead, UserCreate
|
||||
|
||||
|
||||
def get_register_router():
    """Return the register router, built with the ``UserRead`` and ``UserCreate`` schemas."""
    fastapi_users = get_fastapi_users()
    return fastapi_users.get_register_router(UserRead, UserCreate)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from cognee.modules.users.get_fastapi_users import get_fastapi_users
|
||||
|
||||
|
||||
def get_reset_password_router():
    """Return the reset-password router provided by the fastapi-users instance."""
    fastapi_users = get_fastapi_users()
    return fastapi_users.get_reset_password_router()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from cognee.modules.users.get_fastapi_users import get_fastapi_users
|
||||
from cognee.modules.users.models.User import UserRead, UserUpdate
|
||||
|
||||
|
||||
def get_users_router():
    """Return the users router, built with the ``UserRead`` and ``UserUpdate`` schemas."""
    fastapi_users = get_fastapi_users()
    return fastapi_users.get_users_router(UserRead, UserUpdate)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from cognee.modules.users.get_fastapi_users import get_fastapi_users
|
||||
from cognee.modules.users.models.User import UserRead
|
||||
|
||||
|
||||
def get_verify_router():
    """Return the verify router, built with the ``UserRead`` schema."""
    fastapi_users = get_fastapi_users()
    return fastapi_users.get_verify_router(UserRead)
|
||||
|
|
|
|||
32
cognee/api/v1/users/routers/get_visualize_router.py
Normal file
32
cognee/api/v1/users/routers/get_visualize_router.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
from fastapi import Form, UploadFile, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import APIRouter
|
||||
from typing import List
|
||||
import aiohttp
|
||||
import subprocess
|
||||
import logging
|
||||
import os
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_visualize_router() -> APIRouter:
    """Create the router that exposes the graph-visualization endpoint.

    Returns:
        An ``APIRouter`` with a single ``POST /`` route that requires an
        authenticated user.
    """
    router = APIRouter()

    @router.post("/", response_model=None)
    async def visualize(
        user: User = Depends(get_authenticated_user),
    ):
        """This endpoint is responsible for rendering a visualization of the graph."""
        # Imported inside the handler, as in the original, so the import only
        # happens when the route is actually called.
        from cognee.api.v1.visualize import visualize_graph

        try:
            # NOTE(review): the result is presumably an HTML string; returned
            # as-is it goes through FastAPI's default response handling.
            # Confirm whether an HTMLResponse is wanted here.
            return await visualize_graph()
        except Exception as error:
            # Record the failure server-side too; previously only the client saw it.
            logger.error("Graph visualization failed: %s", error, exc_info=True)
            return JSONResponse(status_code=409, content={"error": str(error)})

    return router
|
||||
1
cognee/api/v1/visualize/__init__.py
Normal file
1
cognee/api/v1/visualize/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .visualize import visualize_graph
|
||||
14
cognee/api/v1/visualize/visualize.py
Normal file
14
cognee/api/v1/visualize/visualize.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from cognee.shared.utils import create_cognee_style_network_with_logo, graph_to_tuple
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
import logging
|
||||
|
||||
|
||||
async def visualize_graph(label: str = "name"):
    """Fetch the graph data and render it as a styled network visualization.

    Args:
        label: forwarded as the ``label`` keyword of
            ``create_cognee_style_network_with_logo`` (default ``"name"``).

    Returns:
        The value produced by ``create_cognee_style_network_with_logo``
        (the visualize router names it ``html_visualization``).
    """
    graph_engine = await get_graph_engine()
    graph_data = await graph_engine.get_graph_data()

    # Log through this module's named logger with lazy %-formatting instead of
    # passing the data object as the message to the root logger.
    logging.getLogger(__name__).info("Graph data: %s", graph_data)

    return await create_cognee_style_network_with_logo(graph_data, label=label)
|
||||
|
|
@ -5,6 +5,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|||
from cognee.root_dir import get_absolute_path
|
||||
from cognee.shared.data_models import MonitoringTool
|
||||
|
||||
|
||||
class BaseConfig(BaseSettings):
|
||||
data_root_directory: str = get_absolute_path(".data_storage")
|
||||
monitoring_tool: object = MonitoringTool.LANGFUSE
|
||||
|
|
@ -13,8 +14,8 @@ class BaseConfig(BaseSettings):
|
|||
langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")
|
||||
langfuse_secret_key: Optional[str] = os.getenv("LANGFUSE_SECRET_KEY")
|
||||
langfuse_host: Optional[str] = os.getenv("LANGFUSE_HOST")
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
|
|
@ -22,6 +23,7 @@ class BaseConfig(BaseSettings):
|
|||
"monitoring_tool": self.monitoring_tool,
|
||||
}
|
||||
|
||||
|
||||
@lru_cache
def get_base_config():
    """Return a ``BaseConfig`` instance.

    The function takes no arguments and is wrapped in ``lru_cache``, so the
    instance built on the first call is returned by later calls as well.
    """
    config = BaseConfig()
    return config
|
||||
|
|
|
|||
|
|
@ -10,4 +10,4 @@ from .exceptions import (
|
|||
ServiceError,
|
||||
InvalidValueError,
|
||||
InvalidAttributeError,
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import logging
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CogneeApiError(Exception):
|
||||
"""Base exception class"""
|
||||
|
||||
|
|
@ -36,19 +37,19 @@ class ServiceError(CogneeApiError):
|
|||
|
||||
class InvalidValueError(CogneeApiError):
    """API error for an invalid value; defaults to HTTP 422 Unprocessable Entity."""

    def __init__(
        self,
        message: str = "Invalid Value.",
        name: str = "InvalidValueError",
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
    ):
        """Forward ``message``, ``name`` and ``status_code`` positionally to ``CogneeApiError``."""
        super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class InvalidAttributeError(CogneeApiError):
    """API error for an invalid attribute; defaults to HTTP 422 Unprocessable Entity."""

    def __init__(
        self,
        message: str = "Invalid attribute.",
        name: str = "InvalidAttributeError",
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
    ):
        """Forward ``message``, ``name`` and ``status_code`` positionally to ``CogneeApiError``."""
        super().__init__(message, name, status_code)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ sys.path.insert(0, parent_dir)
|
|||
environment = os.getenv("AWS_ENV", "dev")
|
||||
|
||||
|
||||
def fetch_secret(secret_name:str, region_name:str, env_file_path:str):
|
||||
def fetch_secret(secret_name: str, region_name: str, env_file_path: str):
|
||||
"""Fetch the secret from AWS Secrets Manager and write it to the .env file."""
|
||||
print("Initializing session")
|
||||
session = boto3.session.Session()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
""" Chunking strategies for splitting text into smaller parts."""
|
||||
"""Chunking strategies for splitting text into smaller parts."""
|
||||
|
||||
from __future__ import annotations
|
||||
import re
|
||||
from cognee.shared.data_models import ChunkStrategy
|
||||
|
|
@ -6,17 +7,15 @@ from cognee.shared.data_models import ChunkStrategy
|
|||
|
||||
# /Users/vasa/Projects/cognee/cognee/infrastructure/data/chunking/DefaultChunkEngine.py
|
||||
|
||||
class DefaultChunkEngine():
|
||||
|
||||
class DefaultChunkEngine:
|
||||
def __init__(self, chunk_strategy=None, chunk_size=None, chunk_overlap=None):
|
||||
self.chunk_strategy = chunk_strategy
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _split_text_with_regex(
|
||||
text: str, separator: str, keep_separator: bool
|
||||
) -> list[str]:
|
||||
def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
|
||||
# Now that we have the separator, split the text
|
||||
if separator:
|
||||
if keep_separator:
|
||||
|
|
@ -32,13 +31,12 @@ class DefaultChunkEngine():
|
|||
splits = list(text)
|
||||
return [s for s in splits if s != ""]
|
||||
|
||||
|
||||
|
||||
def chunk_data(self,
|
||||
chunk_strategy = None,
|
||||
source_data = None,
|
||||
chunk_size = None,
|
||||
chunk_overlap = None,
|
||||
def chunk_data(
|
||||
self,
|
||||
chunk_strategy=None,
|
||||
source_data=None,
|
||||
chunk_size=None,
|
||||
chunk_overlap=None,
|
||||
):
|
||||
"""
|
||||
Chunk data based on the specified strategy.
|
||||
|
|
@ -54,44 +52,47 @@ class DefaultChunkEngine():
|
|||
"""
|
||||
|
||||
if self.chunk_strategy == ChunkStrategy.PARAGRAPH:
|
||||
chunked_data, chunk_number = self.chunk_data_by_paragraph(source_data,chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
|
||||
chunked_data, chunk_number = self.chunk_data_by_paragraph(
|
||||
source_data, chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
|
||||
)
|
||||
elif self.chunk_strategy == ChunkStrategy.SENTENCE:
|
||||
chunked_data, chunk_number = self.chunk_by_sentence(source_data, chunk_size = self.chunk_size, chunk_overlap=self.chunk_overlap)
|
||||
chunked_data, chunk_number = self.chunk_by_sentence(
|
||||
source_data, chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
|
||||
)
|
||||
elif self.chunk_strategy == ChunkStrategy.EXACT:
|
||||
chunked_data, chunk_number = self.chunk_data_exact(source_data, chunk_size = self.chunk_size, chunk_overlap=self.chunk_overlap)
|
||||
chunked_data, chunk_number = self.chunk_data_exact(
|
||||
source_data, chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
|
||||
)
|
||||
else:
|
||||
chunked_data, chunk_number = "Invalid chunk strategy.", [0, "Invalid chunk strategy."]
|
||||
|
||||
|
||||
return chunked_data, chunk_number
|
||||
|
||||
|
||||
|
||||
def chunk_data_exact(self, data_chunks, chunk_size, chunk_overlap):
    """Split the concatenated input into fixed-size, overlapping character windows.

    Args:
        data_chunks: iterable of strings; joined with no separator before splitting.
        chunk_size: maximum number of characters in each chunk.
        chunk_overlap: number of characters by which consecutive windows overlap;
            the window start advances by ``chunk_size - chunk_overlap``.

    Returns:
        A tuple ``(chunks, numbered_chunks)``: ``chunks`` is the list of
        substrings and ``numbered_chunks`` is ``[[1, chunk], [2, chunk], ...]``.

    Raises:
        ValueError: if ``chunk_size - chunk_overlap`` is not positive.
    """
    step = chunk_size - chunk_overlap
    if step <= 0:
        # A zero step used to fail inside range() with an unrelated message and
        # a negative step silently produced no chunks at all; fail clearly.
        raise ValueError(
            f"chunk_overlap ({chunk_overlap}) must be smaller than chunk_size ({chunk_size})."
        )

    data = "".join(data_chunks)
    chunks = [data[start : start + chunk_size] for start in range(0, len(data), step)]
    numbered_chunks = [[number, chunk] for number, chunk in enumerate(chunks, start=1)]
    return chunks, numbered_chunks
|
||||
|
||||
|
||||
|
||||
def chunk_by_sentence(self, data_chunks, chunk_size, chunk_overlap):
|
||||
# Split by periods, question marks, exclamation marks, and ellipses
|
||||
data = "".join(data_chunks)
|
||||
|
||||
# The regular expression is used to find series of charaters that end with one the following chaacters (. ! ? ...)
|
||||
sentence_endings = r'(?<=[.!?…]) +'
|
||||
sentence_endings = r"(?<=[.!?…]) +"
|
||||
sentences = re.split(sentence_endings, data)
|
||||
|
||||
sentence_chunks = []
|
||||
for sentence in sentences:
|
||||
if len(sentence) > chunk_size:
|
||||
chunks = self.chunk_data_exact(data_chunks=[sentence], chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
chunks = self.chunk_data_exact(
|
||||
data_chunks=[sentence], chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
||||
)
|
||||
sentence_chunks.extend(chunks)
|
||||
else:
|
||||
sentence_chunks.append(sentence)
|
||||
|
|
@ -102,9 +103,7 @@ class DefaultChunkEngine():
|
|||
numbered_chunks.append(numbered_chunk)
|
||||
return sentence_chunks, numbered_chunks
|
||||
|
||||
|
||||
|
||||
def chunk_data_by_paragraph(self, data_chunks, chunk_size, chunk_overlap, bound = 0.75):
|
||||
def chunk_data_by_paragraph(self, data_chunks, chunk_size, chunk_overlap, bound=0.75):
|
||||
data = "".join(data_chunks)
|
||||
total_length = len(data)
|
||||
chunks = []
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from cognee.infrastructure.data.chunking.DefaultChunkEngine import DefaultChunkE
|
|||
from cognee.shared.data_models import ChunkStrategy
|
||||
|
||||
|
||||
|
||||
class LangchainChunkEngine:
|
||||
def __init__(self, chunk_strategy=None, source_data=None, chunk_size=None, chunk_overlap=None):
|
||||
self.chunk_strategy = chunk_strategy
|
||||
|
|
@ -13,13 +12,12 @@ class LangchainChunkEngine:
|
|||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
|
||||
|
||||
|
||||
def chunk_data(self,
|
||||
chunk_strategy = None,
|
||||
source_data = None,
|
||||
chunk_size = None,
|
||||
chunk_overlap = None,
|
||||
def chunk_data(
|
||||
self,
|
||||
chunk_strategy=None,
|
||||
source_data=None,
|
||||
chunk_size=None,
|
||||
chunk_overlap=None,
|
||||
):
|
||||
"""
|
||||
Chunk data based on the specified strategy.
|
||||
|
|
@ -35,20 +33,24 @@ class LangchainChunkEngine:
|
|||
"""
|
||||
|
||||
if chunk_strategy == ChunkStrategy.CODE:
|
||||
chunked_data, chunk_number = self.chunk_data_by_code(source_data,self.chunk_size, self.chunk_overlap)
|
||||
chunked_data, chunk_number = self.chunk_data_by_code(
|
||||
source_data, self.chunk_size, self.chunk_overlap
|
||||
)
|
||||
|
||||
elif chunk_strategy == ChunkStrategy.LANGCHAIN_CHARACTER:
|
||||
chunked_data, chunk_number = self.chunk_data_by_character(source_data,self.chunk_size, self.chunk_overlap)
|
||||
chunked_data, chunk_number = self.chunk_data_by_character(
|
||||
source_data, self.chunk_size, self.chunk_overlap
|
||||
)
|
||||
else:
|
||||
chunked_data, chunk_number = "Invalid chunk strategy.", [0, "Invalid chunk strategy."]
|
||||
chunked_data, chunk_number = "Invalid chunk strategy.", [0, "Invalid chunk strategy."]
|
||||
return chunked_data, chunk_number
|
||||
|
||||
|
||||
def chunk_data_by_code(self, data_chunks, chunk_size, chunk_overlap= 10, language=None):
|
||||
def chunk_data_by_code(self, data_chunks, chunk_size, chunk_overlap=10, language=None):
|
||||
from langchain_text_splitters import (
|
||||
Language,
|
||||
RecursiveCharacterTextSplitter,
|
||||
)
|
||||
|
||||
if language is None:
|
||||
language = Language.PYTHON
|
||||
python_splitter = RecursiveCharacterTextSplitter.from_language(
|
||||
|
|
@ -67,7 +69,10 @@ class LangchainChunkEngine:
|
|||
|
||||
def chunk_data_by_character(self, data_chunks, chunk_size=1500, chunk_overlap=10):
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
splitter = RecursiveCharacterTextSplitter(chunk_size =chunk_size, chunk_overlap=chunk_overlap)
|
||||
|
||||
splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
||||
)
|
||||
data_chunks = splitter.create_documents([data_chunks])
|
||||
|
||||
only_content = [chunk.page_content for chunk in data_chunks]
|
||||
|
|
@ -78,4 +83,3 @@ class LangchainChunkEngine:
|
|||
numbered_chunks.append(numbered_chunk)
|
||||
|
||||
return only_content, numbered_chunks
|
||||
|
||||
|
|
|
|||
|
|
@ -11,8 +11,7 @@ class ChunkConfig(BaseSettings):
|
|||
chunk_strategy: object = ChunkStrategy.PARAGRAPH
|
||||
chunk_engine: object = ChunkEngine.DEFAULT_ENGINE
|
||||
|
||||
|
||||
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -9,9 +9,11 @@ class ChunkingConfig(Dict):
|
|||
vector_db_key: str
|
||||
vector_db_provider: str
|
||||
|
||||
|
||||
def create_chunking_engine(config: ChunkingConfig):
|
||||
if config["chunk_engine"] == ChunkEngine.LANGCHAIN_ENGINE:
|
||||
from cognee.infrastructure.data.chunking.LangchainChunkingEngine import LangchainChunkEngine
|
||||
|
||||
return LangchainChunkEngine(
|
||||
chunk_size=config["chunk_size"],
|
||||
chunk_overlap=config["chunk_overlap"],
|
||||
|
|
|
|||
|
|
@ -2,5 +2,6 @@ from .config import get_chunk_config
|
|||
|
||||
from .create_chunking_engine import create_chunking_engine
|
||||
|
||||
|
||||
def get_chunk_engine():
    """Build a chunking engine from the current chunk configuration, passed as a dict."""
    chunk_settings = get_chunk_config().to_dict()
    return create_chunking_engine(chunk_settings)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from sklearn.feature_extraction.text import TfidfVectorizer
|
|||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.shared.utils import extract_pos_tags
|
||||
|
||||
|
||||
def extract_keywords(text: str) -> list[str]:
|
||||
if len(text) == 0:
|
||||
raise InvalidValueError(message="extract_keywords cannot extract keywords from empty text.")
|
||||
|
|
@ -14,9 +15,7 @@ def extract_keywords(text: str) -> list[str]:
|
|||
tfidf = vectorizer.fit_transform(nouns)
|
||||
|
||||
top_nouns = sorted(
|
||||
vectorizer.vocabulary_,
|
||||
key = lambda x: tfidf[0, vectorizer.vocabulary_[x]],
|
||||
reverse = True
|
||||
vectorizer.vocabulary_, key=lambda x: tfidf[0, vectorizer.vocabulary_[x]], reverse=True
|
||||
)
|
||||
|
||||
keywords = []
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
class EmbeddingException(Exception):
    """Custom exception for handling embedding-related errors."""
|
||||
|
|
|
|||
|
|
@ -7,4 +7,4 @@ This module defines a set of exceptions for handling various database errors
|
|||
from .exceptions import (
|
||||
EntityNotFoundError,
|
||||
EntityAlreadyExistsError,
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from cognee.exceptions import CogneeApiError
|
||||
from fastapi import status
|
||||
|
||||
|
||||
class EntityNotFoundError(CogneeApiError):
|
||||
"""Database returns nothing"""
|
||||
|
||||
|
|
@ -22,4 +23,4 @@ class EntityAlreadyExistsError(CogneeApiError):
|
|||
name: str = "EntityAlreadyExistsError",
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
super().__init__(message, name, status_code)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
""" This module contains the configuration for the graph database. """
|
||||
"""This module contains the configuration for the graph database."""
|
||||
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
|
@ -15,8 +15,7 @@ class GraphConfig(BaseSettings):
|
|||
graph_database_password: str = ""
|
||||
graph_database_port: int = 123
|
||||
graph_file_path: str = os.path.join(
|
||||
os.path.join(get_absolute_path(".cognee_system"), "databases"),
|
||||
graph_filename
|
||||
os.path.join(get_absolute_path(".cognee_system"), "databases"), graph_filename
|
||||
)
|
||||
graph_model: object = KnowledgeGraph
|
||||
graph_topology: object = KnowledgeGraph
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""Factory function to get the appropriate graph client based on the graph type."""
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
from .config import get_graph_config
|
||||
|
|
@ -26,7 +27,11 @@ def create_graph_engine() -> GraphDBInterface:
|
|||
config = get_graph_config()
|
||||
|
||||
if config.graph_database_provider == "neo4j":
|
||||
if not (config.graph_database_url and config.graph_database_username and config.graph_database_password):
|
||||
if not (
|
||||
config.graph_database_url
|
||||
and config.graph_database_username
|
||||
and config.graph_database_password
|
||||
):
|
||||
raise EnvironmentError("Missing required Neo4j credentials.")
|
||||
|
||||
from .neo4j_driver.adapter import Neo4jAdapter
|
||||
|
|
@ -34,7 +39,7 @@ def create_graph_engine() -> GraphDBInterface:
|
|||
return Neo4jAdapter(
|
||||
graph_database_url=config.graph_database_url,
|
||||
graph_database_username=config.graph_database_username,
|
||||
graph_database_password=config.graph_database_password
|
||||
graph_database_password=config.graph_database_password,
|
||||
)
|
||||
|
||||
elif config.graph_database_provider == "falkordb":
|
||||
|
|
@ -53,6 +58,7 @@ def create_graph_engine() -> GraphDBInterface:
|
|||
)
|
||||
|
||||
from .networkx.adapter import NetworkXAdapter
|
||||
|
||||
graph_client = NetworkXAdapter(filename=config.graph_file_path)
|
||||
|
||||
return graph_client
|
||||
|
|
|
|||
|
|
@ -1,47 +1,35 @@
|
|||
from typing import Protocol, Optional, Dict, Any
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
class GraphDBInterface(Protocol):
|
||||
@abstractmethod
async def query(self, query: str, params: dict):
    """Interface declaration: run ``query`` with ``params``; adapters provide the behaviour.

    Raises:
        NotImplementedError: always, in this interface definition.
    """
    raise NotImplementedError()
|
||||
|
||||
@abstractmethod
async def add_node(self, node_id: str, node_properties: dict):
    """Interface declaration: add one node given its id and properties.

    Raises:
        NotImplementedError: always, in this interface definition.
    """
    raise NotImplementedError()
|
||||
|
||||
@abstractmethod
async def add_nodes(self, nodes: list[tuple[str, dict]]):
    """Interface declaration: add several nodes, annotated as ``(str, dict)`` tuples.

    Raises:
        NotImplementedError: always, in this interface definition.
    """
    raise NotImplementedError()
|
||||
|
||||
@abstractmethod
async def delete_node(self, node_id: str):
    """Interface declaration: delete the node identified by ``node_id``.

    Raises:
        NotImplementedError: always, in this interface definition.
    """
    raise NotImplementedError()
|
||||
|
||||
@abstractmethod
async def delete_nodes(self, node_ids: list[str]):
    """Interface declaration: delete the nodes identified by ``node_ids``.

    Raises:
        NotImplementedError: always, in this interface definition.
    """
    raise NotImplementedError()
|
||||
|
||||
@abstractmethod
async def extract_node(self, node_id: str):
    """Interface declaration: extract the node identified by ``node_id``.

    Raises:
        NotImplementedError: always, in this interface definition.
    """
    raise NotImplementedError()
|
||||
|
||||
@abstractmethod
async def extract_nodes(self, node_ids: list[str]):
    """Interface declaration: extract the nodes identified by ``node_ids``.

    Raises:
        NotImplementedError: always, in this interface definition.
    """
    raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
async def add_edge(
|
||||
|
|
@ -49,21 +37,20 @@ class GraphDBInterface(Protocol):
|
|||
from_node: str,
|
||||
to_node: str,
|
||||
relationship_name: str,
|
||||
edge_properties: Optional[Dict[str, Any]] = None
|
||||
): raise NotImplementedError
|
||||
edge_properties: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
async def add_edges(self, edges: tuple[str, str, str, dict]):
    """Interface declaration: add edges.

    NOTE(review): the annotation describes a single ``(str, str, str, dict)``
    tuple although the name is plural; a list of such tuples may be intended.
    Confirm against the adapters.

    Raises:
        NotImplementedError: always, in this interface definition.
    """
    raise NotImplementedError()
|
||||
|
||||
@abstractmethod
async def delete_graph(self):
    """Interface declaration: delete the graph; takes no arguments besides ``self``.

    Raises:
        NotImplementedError: always, in this interface definition.
    """
    raise NotImplementedError()
|
||||
|
||||
@abstractmethod
async def get_graph_data(self):
    """Interface declaration: return the graph's data; takes no arguments besides ``self``.

    Raises:
        NotImplementedError: always, in this interface definition.
    """
    raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
""" Neo4j Adapter for Graph Database"""
|
||||
"""Neo4j Adapter for Graph Database"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from textwrap import dedent
|
||||
|
|
@ -13,6 +14,7 @@ from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInte
|
|||
|
||||
logger = logging.getLogger("Neo4jAdapter")
|
||||
|
||||
|
||||
class Neo4jAdapter(GraphDBInterface):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -23,8 +25,8 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
):
|
||||
self.driver = driver or AsyncGraphDatabase.driver(
|
||||
graph_database_url,
|
||||
auth = (graph_database_username, graph_database_password),
|
||||
max_connection_lifetime = 120
|
||||
auth=(graph_database_username, graph_database_password),
|
||||
max_connection_lifetime=120,
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
|
|
@ -39,11 +41,11 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
async with self.get_session() as session:
|
||||
result = await session.run(query, parameters = params)
|
||||
result = await session.run(query, parameters=params)
|
||||
data = await result.data()
|
||||
return data
|
||||
except Neo4jError as error:
|
||||
logger.error("Neo4j query error: %s", error, exc_info = True)
|
||||
logger.error("Neo4j query error: %s", error, exc_info=True)
|
||||
raise error
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
|
|
@ -53,7 +55,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
WHERE n.id = $node_id
|
||||
RETURN COUNT(n) > 0 AS node_exists
|
||||
""",
|
||||
{"node_id": node_id}
|
||||
{"node_id": node_id},
|
||||
)
|
||||
return results[0]["node_exists"] if len(results) > 0 else False
|
||||
|
||||
|
|
@ -83,15 +85,17 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
|
||||
"""
|
||||
|
||||
nodes = [{
|
||||
"node_id": str(node.id),
|
||||
"properties": self.serialize_properties(node.model_dump()),
|
||||
} for node in nodes]
|
||||
nodes = [
|
||||
{
|
||||
"node_id": str(node.id),
|
||||
"properties": self.serialize_properties(node.model_dump()),
|
||||
}
|
||||
for node in nodes
|
||||
]
|
||||
|
||||
results = await self.query(query, dict(nodes = nodes))
|
||||
results = await self.query(query, dict(nodes=nodes))
|
||||
return results
|
||||
|
||||
|
||||
async def extract_node(self, node_id: str):
|
||||
results = await self.extract_nodes([node_id])
|
||||
|
||||
|
|
@ -103,9 +107,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
MATCH (node {id: id})
|
||||
RETURN node"""
|
||||
|
||||
params = {
|
||||
"node_ids": node_ids
|
||||
}
|
||||
params = {"node_ids": node_ids}
|
||||
|
||||
results = await self.query(query, params)
|
||||
|
||||
|
|
@ -115,7 +117,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
node_id = id.replace(":", "_")
|
||||
|
||||
query = f"MATCH (node:`{node_id}` {{id: $node_id}}) DETACH DELETE n"
|
||||
params = { "node_id": node_id }
|
||||
params = {"node_id": node_id}
|
||||
|
||||
return await self.query(query, params)
|
||||
|
||||
|
|
@ -125,9 +127,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
MATCH (node {id: id})
|
||||
DETACH DELETE node"""
|
||||
|
||||
params = {
|
||||
"node_ids": node_ids
|
||||
}
|
||||
params = {"node_ids": node_ids}
|
||||
|
||||
return await self.query(query, params)
|
||||
|
||||
|
|
@ -157,21 +157,29 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
try:
|
||||
params = {
|
||||
"edges": [{
|
||||
"from_node": str(edge[0]),
|
||||
"to_node": str(edge[1]),
|
||||
"relationship_name": edge[2],
|
||||
} for edge in edges],
|
||||
"edges": [
|
||||
{
|
||||
"from_node": str(edge[0]),
|
||||
"to_node": str(edge[1]),
|
||||
"relationship_name": edge[2],
|
||||
}
|
||||
for edge in edges
|
||||
],
|
||||
}
|
||||
|
||||
results = await self.query(query, params)
|
||||
return [result["edge_exists"] for result in results]
|
||||
except Neo4jError as error:
|
||||
logger.error("Neo4j query error: %s", error, exc_info = True)
|
||||
logger.error("Neo4j query error: %s", error, exc_info=True)
|
||||
raise error
|
||||
|
||||
|
||||
async def add_edge(self, from_node: UUID, to_node: UUID, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
|
||||
async def add_edge(
|
||||
self,
|
||||
from_node: UUID,
|
||||
to_node: UUID,
|
||||
relationship_name: str,
|
||||
edge_properties: Optional[Dict[str, Any]] = {},
|
||||
):
|
||||
serialized_properties = self.serialize_properties(edge_properties)
|
||||
|
||||
query = dedent("""MATCH (from_node {id: $from_node}),
|
||||
|
|
@ -186,12 +194,11 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
"from_node": str(from_node),
|
||||
"to_node": str(to_node),
|
||||
"relationship_name": relationship_name,
|
||||
"properties": serialized_properties
|
||||
"properties": serialized_properties,
|
||||
}
|
||||
|
||||
return await self.query(query, params)
|
||||
|
||||
|
||||
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
||||
query = """
|
||||
UNWIND $edges AS edge
|
||||
|
|
@ -201,22 +208,25 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
RETURN rel
|
||||
"""
|
||||
|
||||
edges = [{
|
||||
"from_node": str(edge[0]),
|
||||
"to_node": str(edge[1]),
|
||||
"relationship_name": edge[2],
|
||||
"properties": {
|
||||
**(edge[3] if edge[3] else {}),
|
||||
"source_node_id": str(edge[0]),
|
||||
"target_node_id": str(edge[1]),
|
||||
},
|
||||
} for edge in edges]
|
||||
edges = [
|
||||
{
|
||||
"from_node": str(edge[0]),
|
||||
"to_node": str(edge[1]),
|
||||
"relationship_name": edge[2],
|
||||
"properties": {
|
||||
**(edge[3] if edge[3] else {}),
|
||||
"source_node_id": str(edge[0]),
|
||||
"target_node_id": str(edge[1]),
|
||||
},
|
||||
}
|
||||
for edge in edges
|
||||
]
|
||||
|
||||
try:
|
||||
results = await self.query(query, dict(edges = edges))
|
||||
results = await self.query(query, dict(edges=edges))
|
||||
return results
|
||||
except Neo4jError as error:
|
||||
logger.error("Neo4j query error: %s", error, exc_info = True)
|
||||
logger.error("Neo4j query error: %s", error, exc_info=True)
|
||||
raise error
|
||||
|
||||
async def get_edges(self, node_id: str):
|
||||
|
|
@ -225,9 +235,12 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
RETURN n, r, m
|
||||
"""
|
||||
|
||||
results = await self.query(query, dict(node_id = node_id))
|
||||
results = await self.query(query, dict(node_id=node_id))
|
||||
|
||||
return [(result["n"]["id"], result["m"]["id"], {"relationship_name": result["r"][1]}) for result in results]
|
||||
return [
|
||||
(result["n"]["id"], result["m"]["id"], {"relationship_name": result["r"][1]})
|
||||
for result in results
|
||||
]
|
||||
|
||||
async def get_disconnected_nodes(self) -> list[str]:
|
||||
# return await self.query(
|
||||
|
|
@ -267,7 +280,6 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
results = await self.query(query)
|
||||
return results[0]["ids"] if len(results) > 0 else []
|
||||
|
||||
|
||||
async def get_predecessors(self, node_id: str, edge_label: str = None) -> list[str]:
|
||||
if edge_label is not None:
|
||||
query = """
|
||||
|
|
@ -279,9 +291,9 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
results = await self.query(
|
||||
query,
|
||||
dict(
|
||||
node_id = node_id,
|
||||
edge_label = edge_label,
|
||||
)
|
||||
node_id=node_id,
|
||||
edge_label=edge_label,
|
||||
),
|
||||
)
|
||||
|
||||
return [result["predecessor"] for result in results]
|
||||
|
|
@ -295,8 +307,8 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
results = await self.query(
|
||||
query,
|
||||
dict(
|
||||
node_id = node_id,
|
||||
)
|
||||
node_id=node_id,
|
||||
),
|
||||
)
|
||||
|
||||
return [result["predecessor"] for result in results]
|
||||
|
|
@ -312,8 +324,8 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
results = await self.query(
|
||||
query,
|
||||
dict(
|
||||
node_id = node_id,
|
||||
edge_label = edge_label,
|
||||
node_id=node_id,
|
||||
edge_label=edge_label,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -328,14 +340,16 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
results = await self.query(
|
||||
query,
|
||||
dict(
|
||||
node_id = node_id,
|
||||
)
|
||||
node_id=node_id,
|
||||
),
|
||||
)
|
||||
|
||||
return [result["successor"] for result in results]
|
||||
|
||||
async def get_neighbours(self, node_id: str) -> List[Dict[str, Any]]:
|
||||
predecessors, successors = await asyncio.gather(self.get_predecessors(node_id), self.get_successors(node_id))
|
||||
predecessors, successors = await asyncio.gather(
|
||||
self.get_predecessors(node_id), self.get_successors(node_id)
|
||||
)
|
||||
|
||||
return predecessors + successors
|
||||
|
||||
|
|
@ -352,52 +366,55 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
"""
|
||||
|
||||
predecessors, successors = await asyncio.gather(
|
||||
self.query(predecessors_query, dict(node_id = str(node_id))),
|
||||
self.query(successors_query, dict(node_id = str(node_id))),
|
||||
self.query(predecessors_query, dict(node_id=str(node_id))),
|
||||
self.query(successors_query, dict(node_id=str(node_id))),
|
||||
)
|
||||
|
||||
connections = []
|
||||
|
||||
for neighbour in predecessors:
|
||||
neighbour = neighbour["relation"]
|
||||
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
|
||||
connections.append((neighbour[0], {"relationship_name": neighbour[1]}, neighbour[2]))
|
||||
|
||||
for neighbour in successors:
|
||||
neighbour = neighbour["relation"]
|
||||
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
|
||||
connections.append((neighbour[0], {"relationship_name": neighbour[1]}, neighbour[2]))
|
||||
|
||||
return connections
|
||||
|
||||
async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str) -> None:
|
||||
async def remove_connection_to_predecessors_of(
|
||||
self, node_ids: list[str], edge_label: str
|
||||
) -> None:
|
||||
query = f"""
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (node:`{id}`)-[r:{edge_label}]->(predecessor)
|
||||
DELETE r;
|
||||
"""
|
||||
|
||||
params = { "node_ids": node_ids }
|
||||
params = {"node_ids": node_ids}
|
||||
|
||||
return await self.query(query, params)
|
||||
|
||||
async def remove_connection_to_successors_of(self, node_ids: list[str], edge_label: str) -> None:
|
||||
async def remove_connection_to_successors_of(
|
||||
self, node_ids: list[str], edge_label: str
|
||||
) -> None:
|
||||
query = f"""
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (node:`{id}`)<-[r:{edge_label}]-(successor)
|
||||
DELETE r;
|
||||
"""
|
||||
|
||||
params = { "node_ids": node_ids }
|
||||
params = {"node_ids": node_ids}
|
||||
|
||||
return await self.query(query, params)
|
||||
|
||||
|
||||
async def delete_graph(self):
|
||||
query = """MATCH (node)
|
||||
DETACH DELETE node;"""
|
||||
|
||||
return await self.query(query)
|
||||
|
||||
def serialize_properties(self, properties = dict()):
|
||||
def serialize_properties(self, properties=dict()):
|
||||
serialized_properties = {}
|
||||
|
||||
for property_key, property_value in properties.items():
|
||||
|
|
@ -414,22 +431,28 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
result = await self.query(query)
|
||||
|
||||
nodes = [(
|
||||
record["properties"]["id"],
|
||||
record["properties"],
|
||||
) for record in result]
|
||||
nodes = [
|
||||
(
|
||||
record["properties"]["id"],
|
||||
record["properties"],
|
||||
)
|
||||
for record in result
|
||||
]
|
||||
|
||||
query = """
|
||||
MATCH (n)-[r]->(m)
|
||||
RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
|
||||
"""
|
||||
result = await self.query(query)
|
||||
edges = [(
|
||||
record["properties"]["source_node_id"],
|
||||
record["properties"]["target_node_id"],
|
||||
record["type"],
|
||||
record["properties"],
|
||||
) for record in result]
|
||||
edges = [
|
||||
(
|
||||
record["properties"]["source_node_id"],
|
||||
record["properties"]["target_node_id"],
|
||||
record["type"],
|
||||
record["properties"],
|
||||
)
|
||||
for record in result
|
||||
]
|
||||
|
||||
return (nodes, edges)
|
||||
|
||||
|
|
@ -446,7 +469,9 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
"""
|
||||
where_clauses = []
|
||||
for attribute, values in attribute_filters[0].items():
|
||||
values_str = ", ".join(f"'{value}'" if isinstance(value, str) else str(value) for value in values)
|
||||
values_str = ", ".join(
|
||||
f"'{value}'" if isinstance(value, str) else str(value) for value in values
|
||||
)
|
||||
where_clauses.append(f"n.{attribute} IN [{values_str}]")
|
||||
|
||||
where_clause = " AND ".join(where_clauses)
|
||||
|
|
@ -458,10 +483,13 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
"""
|
||||
result_nodes = await self.query(query_nodes)
|
||||
|
||||
nodes = [(
|
||||
record["id"],
|
||||
record["properties"],
|
||||
) for record in result_nodes]
|
||||
nodes = [
|
||||
(
|
||||
record["id"],
|
||||
record["properties"],
|
||||
)
|
||||
for record in result_nodes
|
||||
]
|
||||
|
||||
query_edges = f"""
|
||||
MATCH (n)-[r]->(m)
|
||||
|
|
@ -470,11 +498,14 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
"""
|
||||
result_edges = await self.query(query_edges)
|
||||
|
||||
edges = [(
|
||||
record["source"],
|
||||
record["target"],
|
||||
record["type"],
|
||||
record["properties"],
|
||||
) for record in result_edges]
|
||||
edges = [
|
||||
(
|
||||
record["source"],
|
||||
record["target"],
|
||||
record["type"],
|
||||
record["properties"],
|
||||
)
|
||||
for record in result_edges
|
||||
]
|
||||
|
||||
return (nodes, edges)
|
||||
return (nodes, edges)
|
||||
|
|
|
|||
|
|
@ -17,9 +17,10 @@ from cognee.modules.storage.utils import JSONEncoder
|
|||
|
||||
logger = logging.getLogger("NetworkXAdapter")
|
||||
|
||||
|
||||
class NetworkXAdapter(GraphDBInterface):
|
||||
_instance = None
|
||||
graph = None # Class variable to store the singleton instance
|
||||
graph = None # Class variable to store the singleton instance
|
||||
|
||||
def __new__(cls, filename):
|
||||
if cls._instance is None:
|
||||
|
|
@ -27,12 +28,12 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
cls._instance.filename = filename
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, filename = "cognee_graph.pkl"):
|
||||
def __init__(self, filename="cognee_graph.pkl"):
|
||||
self.filename = filename
|
||||
|
||||
async def get_graph_data(self):
|
||||
await self.load_graph_from_file()
|
||||
return (list(self.graph.nodes(data = True)), list(self.graph.edges(data = True, keys = True)))
|
||||
return (list(self.graph.nodes(data=True)), list(self.graph.edges(data=True, keys=True)))
|
||||
|
||||
async def query(self, query: str, params: dict):
|
||||
pass
|
||||
|
|
@ -57,24 +58,21 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
self.graph.add_nodes_from(nodes)
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
|
||||
async def get_graph(self):
|
||||
return self.graph
|
||||
|
||||
|
||||
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
|
||||
return self.graph.has_edge(from_node, to_node, key = edge_label)
|
||||
return self.graph.has_edge(from_node, to_node, key=edge_label)
|
||||
|
||||
async def has_edges(self, edges):
|
||||
result = []
|
||||
|
||||
for (from_node, to_node, edge_label) in edges:
|
||||
for from_node, to_node, edge_label in edges:
|
||||
if self.graph.has_edge(from_node, to_node, edge_label):
|
||||
result.append((from_node, to_node, edge_label))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def add_edge(
|
||||
self,
|
||||
from_node: str,
|
||||
|
|
@ -83,24 +81,38 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
edge_properties: Dict[str, Any] = {},
|
||||
) -> None:
|
||||
edge_properties["updated_at"] = datetime.now(timezone.utc)
|
||||
self.graph.add_edge(from_node, to_node, key = relationship_name, **(edge_properties if edge_properties else {}))
|
||||
self.graph.add_edge(
|
||||
from_node,
|
||||
to_node,
|
||||
key=relationship_name,
|
||||
**(edge_properties if edge_properties else {}),
|
||||
)
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
async def add_edges(
|
||||
self,
|
||||
edges: tuple[str, str, str, dict],
|
||||
) -> None:
|
||||
edges = [(edge[0], edge[1], edge[2], {
|
||||
**(edge[3] if len(edge) == 4 else {}),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
}) for edge in edges]
|
||||
edges = [
|
||||
(
|
||||
edge[0],
|
||||
edge[1],
|
||||
edge[2],
|
||||
{
|
||||
**(edge[3] if len(edge) == 4 else {}),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
for edge in edges
|
||||
]
|
||||
|
||||
self.graph.add_edges_from(edges)
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
async def get_edges(self, node_id: str):
|
||||
return list(self.graph.in_edges(node_id, data = True)) + list(self.graph.out_edges(node_id, data = True))
|
||||
|
||||
return list(self.graph.in_edges(node_id, data=True)) + list(
|
||||
self.graph.out_edges(node_id, data=True)
|
||||
)
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
"""Asynchronously delete a node from the graph if it exists."""
|
||||
|
|
@ -112,12 +124,11 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
self.graph.remove_nodes_from(node_ids)
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
|
||||
async def get_disconnected_nodes(self) -> List[str]:
|
||||
connected_components = list(nx.weakly_connected_components(self.graph))
|
||||
|
||||
disconnected_nodes = []
|
||||
biggest_subgraph = max(connected_components, key = len)
|
||||
biggest_subgraph = max(connected_components, key=len)
|
||||
|
||||
for component in connected_components:
|
||||
if component != biggest_subgraph:
|
||||
|
|
@ -125,7 +136,6 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
|
||||
return disconnected_nodes
|
||||
|
||||
|
||||
async def extract_node(self, node_id: str) -> dict:
|
||||
if self.graph.has_node(node_id):
|
||||
return self.graph.nodes[node_id]
|
||||
|
|
@ -139,8 +149,8 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
if self.graph.has_node(node_id):
|
||||
if edge_label is None:
|
||||
return [
|
||||
self.graph.nodes[predecessor] for predecessor \
|
||||
in list(self.graph.predecessors(node_id))
|
||||
self.graph.nodes[predecessor]
|
||||
for predecessor in list(self.graph.predecessors(node_id))
|
||||
]
|
||||
|
||||
nodes = []
|
||||
|
|
@ -155,8 +165,8 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
if self.graph.has_node(node_id):
|
||||
if edge_label is None:
|
||||
return [
|
||||
self.graph.nodes[successor] for successor \
|
||||
in list(self.graph.successors(node_id))
|
||||
self.graph.nodes[successor]
|
||||
for successor in list(self.graph.successors(node_id))
|
||||
]
|
||||
|
||||
nodes = []
|
||||
|
|
@ -210,7 +220,9 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
|
||||
return connections
|
||||
|
||||
async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str) -> None:
|
||||
async def remove_connection_to_predecessors_of(
|
||||
self, node_ids: list[str], edge_label: str
|
||||
) -> None:
|
||||
for node_id in node_ids:
|
||||
if self.graph.has_node(node_id):
|
||||
for predecessor_id in list(self.graph.predecessors(node_id)):
|
||||
|
|
@ -219,7 +231,9 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
async def remove_connection_to_successors_of(self, node_ids: list[str], edge_label: str) -> None:
|
||||
async def remove_connection_to_successors_of(
|
||||
self, node_ids: list[str], edge_label: str
|
||||
) -> None:
|
||||
for node_id in node_ids:
|
||||
if self.graph.has_node(node_id):
|
||||
for successor_id in list(self.graph.successors(node_id)):
|
||||
|
|
@ -228,7 +242,7 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
async def save_graph_to_file(self, file_path: str=None) -> None:
|
||||
async def save_graph_to_file(self, file_path: str = None) -> None:
|
||||
"""Asynchronously save the graph to a file in JSON format."""
|
||||
if not file_path:
|
||||
file_path = self.filename
|
||||
|
|
@ -236,8 +250,7 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
graph_data = nx.readwrite.json_graph.node_link_data(self.graph)
|
||||
|
||||
async with aiofiles.open(file_path, "w") as file:
|
||||
await file.write(json.dumps(graph_data, cls = JSONEncoder))
|
||||
|
||||
await file.write(json.dumps(graph_data, cls=JSONEncoder))
|
||||
|
||||
async def load_graph_from_file(self, file_path: str = None):
|
||||
"""Asynchronously load the graph from a file in JSON format."""
|
||||
|
|
@ -252,50 +265,59 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
graph_data = json.loads(await file.read())
|
||||
for node in graph_data["nodes"]:
|
||||
try:
|
||||
node["id"] = UUID(node["id"])
|
||||
except:
|
||||
pass
|
||||
node["id"] = UUID(node["id"])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
if "updated_at" in node:
|
||||
node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||
node["updated_at"] = datetime.strptime(
|
||||
node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z"
|
||||
)
|
||||
|
||||
for edge in graph_data["links"]:
|
||||
try:
|
||||
source_id = UUID(edge["source"])
|
||||
target_id = UUID(edge["target"])
|
||||
source_id = UUID(edge["source"])
|
||||
target_id = UUID(edge["target"])
|
||||
|
||||
edge["source"] = source_id
|
||||
edge["target"] = target_id
|
||||
edge["source_node_id"] = source_id
|
||||
edge["target_node_id"] = target_id
|
||||
except:
|
||||
pass
|
||||
edge["source"] = source_id
|
||||
edge["target"] = target_id
|
||||
edge["source_node_id"] = source_id
|
||||
edge["target_node_id"] = target_id
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
|
||||
if "updated_at" in edge:
|
||||
edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||
edge["updated_at"] = datetime.strptime(
|
||||
edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z"
|
||||
)
|
||||
|
||||
self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
|
||||
|
||||
for node_id, node_data in self.graph.nodes(data=True):
|
||||
node_data['id'] = node_id
|
||||
node_data["id"] = node_id
|
||||
else:
|
||||
# Log that the file does not exist and an empty graph is initialized
|
||||
logger.warning("File %s not found. Initializing an empty graph.", file_path)
|
||||
self.graph = nx.MultiDiGraph() # Use MultiDiGraph to keep it consistent with __init__
|
||||
self.graph = (
|
||||
nx.MultiDiGraph()
|
||||
) # Use MultiDiGraph to keep it consistent with __init__
|
||||
|
||||
file_dir = os.path.dirname(file_path)
|
||||
if not os.path.exists(file_dir):
|
||||
os.makedirs(file_dir, exist_ok = True)
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
|
||||
await self.save_graph_to_file(file_path)
|
||||
|
||||
except Exception:
|
||||
logger.error("Failed to load graph from file: %s", file_path)
|
||||
|
||||
|
||||
async def delete_graph(self, file_path: str = None):
|
||||
"""Asynchronously delete the graph file from the filesystem."""
|
||||
if file_path is None:
|
||||
file_path = self.filename # Assuming self.filename is defined elsewhere and holds the default graph file path
|
||||
file_path = (
|
||||
self.filename
|
||||
) # Assuming self.filename is defined elsewhere and holds the default graph file path
|
||||
try:
|
||||
if os.path.exists(file_path):
|
||||
await aiofiles_os.remove(file_path)
|
||||
|
|
@ -305,7 +327,9 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
except Exception as error:
|
||||
logger.error("Failed to delete graph: %s", error)
|
||||
|
||||
async def get_filtered_graph_data(self, attribute_filters: List[Dict[str, List[Union[str, int]]]]):
|
||||
async def get_filtered_graph_data(
|
||||
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
|
||||
):
|
||||
"""
|
||||
Fetches nodes and relationships filtered by specified attribute values.
|
||||
|
||||
|
|
@ -325,18 +349,21 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
|
||||
# Filter nodes
|
||||
filtered_nodes = [
|
||||
(node, data) for node, data in self.graph.nodes(data=True)
|
||||
(node, data)
|
||||
for node, data in self.graph.nodes(data=True)
|
||||
if all(data.get(attr) in values for attr, values in where_clauses)
|
||||
]
|
||||
|
||||
# Filter edges where both source and target nodes satisfy the filters
|
||||
filtered_edges = [
|
||||
(source, target, data.get('relationship_type', 'UNKNOWN'), data)
|
||||
(source, target, data.get("relationship_type", "UNKNOWN"), data)
|
||||
for source, target, data in self.graph.edges(data=True)
|
||||
if (
|
||||
all(self.graph.nodes[source].get(attr) in values for attr, values in where_clauses) and
|
||||
all(self.graph.nodes[target].get(attr) in values for attr, values in where_clauses)
|
||||
all(self.graph.nodes[source].get(attr) in values for attr, values in where_clauses)
|
||||
and all(
|
||||
self.graph.nodes[target].get(attr) in values for attr, values in where_clauses
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
return filtered_nodes, filtered_edges
|
||||
return filtered_nodes, filtered_edges
|
||||
|
|
|
|||
|
|
@ -1,38 +1,36 @@
|
|||
import asyncio
|
||||
|
||||
# from datetime import datetime
|
||||
import json
|
||||
from textwrap import dedent
|
||||
from uuid import UUID
|
||||
from webbrowser import Error
|
||||
|
||||
from falkordb import FalkorDB
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import \
|
||||
GraphDBInterface
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine
|
||||
from cognee.infrastructure.databases.vector.vector_db_interface import \
|
||||
VectorDBInterface
|
||||
from cognee.infrastructure.databases.vector.vector_db_interface import VectorDBInterface
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
||||
class IndexSchema(DataPoint):
|
||||
text: str
|
||||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["text"],
|
||||
"type": "IndexSchema"
|
||||
}
|
||||
_metadata: dict = {"index_fields": ["text"], "type": "IndexSchema"}
|
||||
|
||||
|
||||
class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
||||
def __init__(
|
||||
self,
|
||||
database_url: str,
|
||||
database_port: int,
|
||||
embedding_engine = EmbeddingEngine,
|
||||
embedding_engine=EmbeddingEngine,
|
||||
):
|
||||
self.driver = FalkorDB(
|
||||
host = database_url,
|
||||
port = database_port,
|
||||
host=database_url,
|
||||
port=database_port,
|
||||
)
|
||||
self.embedding_engine = embedding_engine
|
||||
self.graph_name = "cognee_graph"
|
||||
|
|
@ -56,7 +54,11 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|||
return f"'{str(value)}'"
|
||||
if type(value) is int or type(value) is float:
|
||||
return value
|
||||
if type(value) is list and type(value[0]) is float and len(value) == self.embedding_engine.get_vector_size():
|
||||
if (
|
||||
type(value) is list
|
||||
and type(value[0]) is float
|
||||
and len(value) == self.embedding_engine.get_vector_size()
|
||||
):
|
||||
return f"'vecf32({value})'"
|
||||
# if type(value) is datetime:
|
||||
# return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||
|
|
@ -70,14 +72,21 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|||
node_label = type(data_point).__tablename__
|
||||
property_names = DataPoint.get_embeddable_property_names(data_point)
|
||||
|
||||
node_properties = await self.stringify_properties({
|
||||
**data_point.model_dump(),
|
||||
**({
|
||||
property_names[index]: (vectorized_values[index] \
|
||||
if index < len(vectorized_values) else getattr(data_point, property_name, None)) \
|
||||
node_properties = await self.stringify_properties(
|
||||
{
|
||||
**data_point.model_dump(),
|
||||
**(
|
||||
{
|
||||
property_names[index]: (
|
||||
vectorized_values[index]
|
||||
if index < len(vectorized_values)
|
||||
else getattr(data_point, property_name, None)
|
||||
)
|
||||
for index, property_name in enumerate(property_names)
|
||||
}),
|
||||
})
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
return dedent(f"""
|
||||
MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
|
||||
|
|
@ -129,12 +138,13 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|||
await self.create_data_point_query(
|
||||
data_point,
|
||||
[
|
||||
vectorized_values[vector_map[str(data_point.id)][property_name]] \
|
||||
if vector_map[str(data_point.id)][property_name] is not None \
|
||||
else None \
|
||||
vectorized_values[vector_map[str(data_point.id)][property_name]]
|
||||
if vector_map[str(data_point.id)][property_name] is not None
|
||||
else None
|
||||
for property_name in DataPoint.get_embeddable_property_names(data_point)
|
||||
],
|
||||
) for data_point in data_points
|
||||
)
|
||||
for data_point in data_points
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
|
|
@ -144,17 +154,27 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|||
graph = self.driver.select_graph(self.graph_name)
|
||||
|
||||
if not self.has_vector_index(graph, index_name, index_property_name):
|
||||
graph.create_node_vector_index(index_name, index_property_name, dim = self.embedding_engine.get_vector_size())
|
||||
graph.create_node_vector_index(
|
||||
index_name, index_property_name, dim=self.embedding_engine.get_vector_size()
|
||||
)
|
||||
|
||||
def has_vector_index(self, graph, index_name: str, index_property_name: str) -> bool:
|
||||
try:
|
||||
indices = graph.list_indices()
|
||||
|
||||
return any([(index[0] == index_name and index_property_name in index[1]) for index in indices.result_set])
|
||||
except:
|
||||
return any(
|
||||
[
|
||||
(index[0] == index_name and index_property_name in index[1])
|
||||
for index in indices.result_set
|
||||
]
|
||||
)
|
||||
except Error as e:
|
||||
print(e)
|
||||
return False
|
||||
|
||||
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
|
||||
async def index_data_points(
|
||||
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
|
||||
):
|
||||
pass
|
||||
|
||||
async def add_node(self, node: DataPoint):
|
||||
|
|
@ -183,11 +203,14 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|||
""").strip()
|
||||
|
||||
params = {
|
||||
"edges": [{
|
||||
"from_node": str(edge[0]),
|
||||
"to_node": str(edge[1]),
|
||||
"relationship_name": edge[2],
|
||||
} for edge in edges],
|
||||
"edges": [
|
||||
{
|
||||
"from_node": str(edge[0]),
|
||||
"to_node": str(edge[1]),
|
||||
"relationship_name": edge[2],
|
||||
}
|
||||
for edge in edges
|
||||
],
|
||||
}
|
||||
|
||||
results = self.query(query, params).result_set
|
||||
|
|
@ -196,7 +219,7 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|||
|
||||
async def retrieve(self, data_point_ids: list[UUID]):
|
||||
result = self.query(
|
||||
f"MATCH (node) WHERE node.id IN $node_ids RETURN node",
|
||||
"MATCH (node) WHERE node.id IN $node_ids RETURN node",
|
||||
{
|
||||
"node_ids": [str(data_point) for data_point in data_point_ids],
|
||||
},
|
||||
|
|
@ -224,19 +247,19 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|||
"""
|
||||
|
||||
predecessors, successors = await asyncio.gather(
|
||||
self.query(predecessors_query, dict(node_id = node_id)),
|
||||
self.query(successors_query, dict(node_id = node_id)),
|
||||
self.query(predecessors_query, dict(node_id=node_id)),
|
||||
self.query(successors_query, dict(node_id=node_id)),
|
||||
)
|
||||
|
||||
connections = []
|
||||
|
||||
for neighbour in predecessors:
|
||||
neighbour = neighbour["relation"]
|
||||
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
|
||||
connections.append((neighbour[0], {"relationship_name": neighbour[1]}, neighbour[2]))
|
||||
|
||||
for neighbour in successors:
|
||||
neighbour = neighbour["relation"]
|
||||
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
|
||||
connections.append((neighbour[0], {"relationship_name": neighbour[1]}, neighbour[2]))
|
||||
|
||||
return connections
|
||||
|
||||
|
|
@ -279,12 +302,15 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|||
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||
|
||||
return await asyncio.gather(
|
||||
*[self.search(
|
||||
collection_name = collection_name,
|
||||
query_vector = query_vector,
|
||||
limit = limit,
|
||||
with_vector = with_vectors,
|
||||
) for query_vector in query_vectors]
|
||||
*[
|
||||
self.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=query_vector,
|
||||
limit=limit,
|
||||
with_vector=with_vectors,
|
||||
)
|
||||
for query_vector in query_vectors
|
||||
]
|
||||
)
|
||||
|
||||
async def get_graph_data(self):
|
||||
|
|
@ -292,28 +318,34 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|||
|
||||
result = self.query(query)
|
||||
|
||||
nodes = [(
|
||||
record[2]["id"],
|
||||
record[2],
|
||||
) for record in result.result_set]
|
||||
nodes = [
|
||||
(
|
||||
record[2]["id"],
|
||||
record[2],
|
||||
)
|
||||
for record in result.result_set
|
||||
]
|
||||
|
||||
query = """
|
||||
MATCH (n)-[r]->(m)
|
||||
RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
|
||||
"""
|
||||
result = self.query(query)
|
||||
edges = [(
|
||||
record[3]["source_node_id"],
|
||||
record[3]["target_node_id"],
|
||||
record[2],
|
||||
record[3],
|
||||
) for record in result.result_set]
|
||||
edges = [
|
||||
(
|
||||
record[3]["source_node_id"],
|
||||
record[3]["target_node_id"],
|
||||
record[2],
|
||||
record[3],
|
||||
)
|
||||
for record in result.result_set
|
||||
]
|
||||
|
||||
return (nodes, edges)
|
||||
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
|
||||
return self.query(
|
||||
f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
|
||||
"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
|
||||
{
|
||||
"node_ids": [str(data_point) for data_point in data_point_ids],
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -4,16 +4,17 @@ from functools import lru_cache
|
|||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from cognee.root_dir import get_absolute_path
|
||||
|
||||
|
||||
class RelationalConfig(BaseSettings):
|
||||
db_path: str = os.path.join(get_absolute_path(".cognee_system"), "databases")
|
||||
db_name: str = "cognee_db"
|
||||
db_host: Union[str, None] = None # "localhost"
|
||||
db_port: Union[str, None] = None # "5432"
|
||||
db_username: Union[str, None] = None # "cognee"
|
||||
db_password: Union[str, None] = None # "cognee"
|
||||
db_path: str = os.path.join(get_absolute_path(".cognee_system"), "databases")
|
||||
db_name: str = "cognee_db"
|
||||
db_host: Union[str, None] = None # "localhost"
|
||||
db_port: Union[str, None] = None # "5432"
|
||||
db_username: Union[str, None] = None # "cognee"
|
||||
db_password: Union[str, None] = None # "cognee"
|
||||
db_provider: str = "sqlite"
|
||||
|
||||
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
|
|
@ -26,6 +27,7 @@ class RelationalConfig(BaseSettings):
|
|||
"db_provider": self.db_provider,
|
||||
}
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_relational_config():
|
||||
return RelationalConfig()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from cognee.infrastructure.files.storage import LocalStorage
|
|||
from .ModelBase import Base
|
||||
from .get_relational_engine import get_relational_engine, get_relational_config
|
||||
|
||||
|
||||
async def create_db_and_tables():
|
||||
relational_config = get_relational_config()
|
||||
relational_engine = get_relational_engine()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from .sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
||||
|
||||
|
||||
def create_relational_engine(
|
||||
db_path: str,
|
||||
db_name: str,
|
||||
|
|
@ -13,6 +14,8 @@ def create_relational_engine(
|
|||
connection_string = f"sqlite+aiosqlite:///{db_path}/{db_name}"
|
||||
|
||||
if db_provider == "postgres":
|
||||
connection_string = f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
|
||||
connection_string = (
|
||||
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
|
||||
)
|
||||
|
||||
return SQLAlchemyAdapter(connection_string)
|
||||
|
|
|
|||
|
|
@ -3,8 +3,9 @@
|
|||
from .config import get_relational_config
|
||||
from .create_relational_engine import create_relational_engine
|
||||
|
||||
|
||||
# @lru_cache
|
||||
def get_relational_engine():
|
||||
relational_config = get_relational_config()
|
||||
|
||||
return create_relational_engine(**relational_config.to_dict())
|
||||
return create_relational_engine(**relational_config.to_dict())
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
from os import path
|
||||
import logging
|
||||
import logging
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
from typing import AsyncGenerator, List
|
||||
|
|
@ -18,7 +18,8 @@ from ..ModelBase import Base
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SQLAlchemyAdapter():
|
||||
|
||||
class SQLAlchemyAdapter:
|
||||
def __init__(self, connection_string: str):
|
||||
self.db_path: str = None
|
||||
self.db_uri: str = connection_string
|
||||
|
|
@ -58,17 +59,23 @@ class SQLAlchemyAdapter():
|
|||
fields_query_parts = [f"{item['name']} {item['type']}" for item in table_config]
|
||||
async with self.engine.begin() as connection:
|
||||
await connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name};"))
|
||||
await connection.execute(text(f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({', '.join(fields_query_parts)});"))
|
||||
await connection.execute(
|
||||
text(
|
||||
f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({', '.join(fields_query_parts)});"
|
||||
)
|
||||
)
|
||||
await connection.close()
|
||||
|
||||
async def delete_table(self, table_name: str, schema_name: Optional[str] = "public"):
|
||||
async def delete_table(self, table_name: str, schema_name: Optional[str] = "public"):
|
||||
async with self.engine.begin() as connection:
|
||||
if self.engine.dialect.name == "sqlite":
|
||||
# SQLite doesn’t support schema namespaces and the CASCADE keyword.
|
||||
# However, foreign key constraint can be defined with ON DELETE CASCADE during table creation.
|
||||
await connection.execute(text(f"DROP TABLE IF EXISTS {table_name};"))
|
||||
else:
|
||||
await connection.execute(text(f"DROP TABLE IF EXISTS {schema_name}.{table_name} CASCADE;"))
|
||||
await connection.execute(
|
||||
text(f"DROP TABLE IF EXISTS {schema_name}.{table_name} CASCADE;")
|
||||
)
|
||||
|
||||
async def insert_data(self, schema_name: str, table_name: str, data: list[dict]):
|
||||
columns = ", ".join(data[0].keys())
|
||||
|
|
@ -94,7 +101,9 @@ class SQLAlchemyAdapter():
|
|||
return [schema[0] for schema in result.fetchall()]
|
||||
return []
|
||||
|
||||
async def delete_entity_by_id(self, table_name: str, data_id: UUID, schema_name: Optional[str] = "public"):
|
||||
async def delete_entity_by_id(
|
||||
self, table_name: str, data_id: UUID, schema_name: Optional[str] = "public"
|
||||
):
|
||||
"""
|
||||
Delete entity in given table based on id. Table must have an id Column.
|
||||
"""
|
||||
|
|
@ -114,7 +123,6 @@ class SQLAlchemyAdapter():
|
|||
await session.execute(TableModel.delete().where(TableModel.c.id == data_id))
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def delete_data_entity(self, data_id: UUID):
|
||||
"""
|
||||
Delete data and local files related to data if there are no references to it anymore.
|
||||
|
|
@ -131,14 +139,19 @@ class SQLAlchemyAdapter():
|
|||
raise EntityNotFoundError(message=f"Entity not found: {str(e)}")
|
||||
|
||||
# Check if other data objects point to the same raw data location
|
||||
raw_data_location_entities = (await session.execute(
|
||||
select(Data.raw_data_location).where(Data.raw_data_location == data_entity.raw_data_location))).all()
|
||||
raw_data_location_entities = (
|
||||
await session.execute(
|
||||
select(Data.raw_data_location).where(
|
||||
Data.raw_data_location == data_entity.raw_data_location
|
||||
)
|
||||
)
|
||||
).all()
|
||||
|
||||
# Don't delete local file unless this is the only reference to the file in the database
|
||||
if len(raw_data_location_entities) == 1:
|
||||
|
||||
# delete local file only if it's created by cognee
|
||||
from cognee.base_config import get_base_config
|
||||
|
||||
config = get_base_config()
|
||||
|
||||
if config.data_root_directory in raw_data_location_entities[0].raw_data_location:
|
||||
|
|
@ -198,15 +211,18 @@ class SQLAlchemyAdapter():
|
|||
metadata.clear()
|
||||
return table_names
|
||||
|
||||
|
||||
async def get_data(self, table_name: str, filters: dict = None):
|
||||
async with self.engine.begin() as connection:
|
||||
query = f"SELECT * FROM {table_name}"
|
||||
if filters:
|
||||
filter_conditions = " AND ".join([
|
||||
f"{key} IN ({', '.join([f':{key}{i}' for i in range(len(value))])})" if isinstance(value, list)
|
||||
else f"{key} = :{key}" for key, value in filters.items()
|
||||
])
|
||||
filter_conditions = " AND ".join(
|
||||
[
|
||||
f"{key} IN ({', '.join([f':{key}{i}' for i in range(len(value))])})"
|
||||
if isinstance(value, list)
|
||||
else f"{key} = :{key}"
|
||||
for key, value in filters.items()
|
||||
]
|
||||
)
|
||||
query += f" WHERE {filter_conditions};"
|
||||
query = text(query)
|
||||
results = await connection.execute(query, filters)
|
||||
|
|
@ -252,7 +268,6 @@ class SQLAlchemyAdapter():
|
|||
except Exception as e:
|
||||
print(f"Error dropping database tables: {e}")
|
||||
|
||||
|
||||
async def create_database(self):
|
||||
if self.engine.dialect.name == "sqlite":
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
|
|
@ -264,7 +279,6 @@ class SQLAlchemyAdapter():
|
|||
if len(Base.metadata.tables.keys()) > 0:
|
||||
await connection.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
async def delete_database(self):
|
||||
try:
|
||||
if self.engine.dialect.name == "sqlite":
|
||||
|
|
@ -281,7 +295,9 @@ class SQLAlchemyAdapter():
|
|||
# Load the schema information into the MetaData object
|
||||
await connection.run_sync(metadata.reflect, schema=schema_name)
|
||||
for table in metadata.sorted_tables:
|
||||
drop_table_query = text(f"DROP TABLE IF EXISTS {schema_name}.{table.name} CASCADE")
|
||||
drop_table_query = text(
|
||||
f"DROP TABLE IF EXISTS {schema_name}.{table.name} CASCADE"
|
||||
)
|
||||
await connection.execute(drop_table_query)
|
||||
metadata.clear()
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -3,16 +3,16 @@ from functools import lru_cache
|
|||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from cognee.root_dir import get_absolute_path
|
||||
|
||||
|
||||
class VectorConfig(BaseSettings):
|
||||
vector_db_url: str = os.path.join(
|
||||
os.path.join(get_absolute_path(".cognee_system"), "databases"),
|
||||
"cognee.lancedb"
|
||||
os.path.join(get_absolute_path(".cognee_system"), "databases"), "cognee.lancedb"
|
||||
)
|
||||
vector_db_port: int = 1234
|
||||
vector_db_key: str = ""
|
||||
vector_db_provider: str = "lancedb"
|
||||
|
||||
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
|
|
@ -22,6 +22,7 @@ class VectorConfig(BaseSettings):
|
|||
"vector_db_provider": self.vector_db_provider,
|
||||
}
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_vectordb_config():
|
||||
return VectorConfig()
|
||||
|
|
|
|||
|
|
@ -16,9 +16,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
|||
raise EnvironmentError("Missing requred Weaviate credentials!")
|
||||
|
||||
return WeaviateAdapter(
|
||||
config["vector_db_url"],
|
||||
config["vector_db_key"],
|
||||
embedding_engine=embedding_engine
|
||||
config["vector_db_url"], config["vector_db_key"], embedding_engine=embedding_engine
|
||||
)
|
||||
|
||||
elif config["vector_db_provider"] == "qdrant":
|
||||
|
|
@ -30,10 +28,10 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
|||
return QDrantAdapter(
|
||||
url=config["vector_db_url"],
|
||||
api_key=config["vector_db_key"],
|
||||
embedding_engine=embedding_engine
|
||||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
elif config['vector_db_provider'] == 'milvus':
|
||||
elif config["vector_db_provider"] == "milvus":
|
||||
from .milvus.MilvusAdapter import MilvusAdapter
|
||||
|
||||
if not config["vector_db_url"]:
|
||||
|
|
@ -41,11 +39,10 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
|||
|
||||
return MilvusAdapter(
|
||||
url=config["vector_db_url"],
|
||||
api_key=config['vector_db_key'],
|
||||
embedding_engine=embedding_engine
|
||||
api_key=config["vector_db_key"],
|
||||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
|
||||
elif config["vector_db_provider"] == "pgvector":
|
||||
from cognee.infrastructure.databases.relational import get_relational_config
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Protocol
|
||||
|
||||
|
||||
class EmbeddingEngine(Protocol):
|
||||
async def embed_text(self, text: list[str]) -> list[list[float]]:
|
||||
raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -43,14 +43,12 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
|
||||
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
||||
async def exponential_backoff(attempt):
|
||||
wait_time = min(10 * (2 ** attempt), 60) # Max 60 seconds
|
||||
wait_time = min(10 * (2**attempt), 60) # Max 60 seconds
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
try:
|
||||
if self.mock:
|
||||
response = {
|
||||
"data": [{"embedding": [0.0] * self.dimensions} for _ in text]
|
||||
}
|
||||
response = {"data": [{"embedding": [0.0] * self.dimensions} for _ in text]}
|
||||
|
||||
self.retry_count = 0
|
||||
|
||||
|
|
@ -61,7 +59,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
input=text,
|
||||
api_key=self.api_key,
|
||||
api_base=self.endpoint,
|
||||
api_version=self.api_version
|
||||
api_version=self.api_version,
|
||||
)
|
||||
|
||||
self.retry_count = 0
|
||||
|
|
@ -73,7 +71,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
if len(text) == 1:
|
||||
parts = [text]
|
||||
else:
|
||||
parts = [text[0:math.ceil(len(text) / 2)], text[math.ceil(len(text) / 2):]]
|
||||
parts = [text[0 : math.ceil(len(text) / 2)], text[math.ceil(len(text) / 2) :]]
|
||||
|
||||
parts_futures = [self.embed_text(part) for part in parts]
|
||||
embeddings = await asyncio.gather(*parts_futures)
|
||||
|
|
@ -89,7 +87,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
|
||||
except litellm.exceptions.RateLimitError:
|
||||
if self.retry_count >= self.MAX_RETRIES:
|
||||
raise Exception(f"Rate limit exceeded and no more retries left.")
|
||||
raise Exception("Rate limit exceeded and no more retries left.")
|
||||
|
||||
await exponential_backoff(self.retry_count)
|
||||
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue