Merge branch 'dev' into COG-748

This commit is contained in:
Vasilije 2025-01-08 18:39:48 +01:00 committed by GitHub
commit 93bca8ee5f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
349 changed files with 6418 additions and 4534 deletions

View file

@ -17,7 +17,7 @@ jobs:
publish_docker_to_ecr:
name: Publish Cognee Docker image
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
permissions:
id-token: write
contents: read

View file

@ -17,7 +17,7 @@ jobs:
publish_docker_to_ecr:
name: Publish Docker PromethAI image
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
permissions:
id-token: write
contents: read

View file

@ -9,7 +9,7 @@ jobs:
build_docker:
name: Build Cognee Backend Docker App Image
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
- name: Check out Cognee code
uses: actions/checkout@v3

View file

@ -4,7 +4,7 @@ on: [pull_request, issues]
jobs:
greeting:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
- uses: actions/first-interaction@v1
with:

View file

@ -12,7 +12,7 @@ on:
jobs:
docker-compose-test:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
- name: Checkout repository

View file

@ -7,7 +7,7 @@ on:
jobs:
docker-build-and-push:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
- name: Checkout repository
@ -32,7 +32,7 @@ jobs:
run: |
IMAGE_NAME=cognee/cognee
TAG_VERSION="${BRANCH_NAME}-${COMMIT_SHA}"
echo "Building image: ${IMAGE_NAME}:${TAG_VERSION}"
docker buildx build \
--platform linux/amd64,linux/arm64 \

View file

@ -7,7 +7,7 @@ on:
jobs:
profiler:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
# Checkout the code from the repository with full history
@ -47,7 +47,7 @@ jobs:
python-version: '3.10'
- name: Install Poetry
uses: snok/install-poetry@v1.3.2
uses: snok/install-poetry@v1.4.1
with:
virtualenvs-create: true
virtualenvs-in-project: true
@ -82,50 +82,50 @@ jobs:
poetry run pyinstrument --renderer json -o base_results.json cognee/api/v1/cognify/code_graph_pipeline.py
# Run profiler on head branch
- name: Run profiler on head branch
env:
HEAD_SHA: ${{ env.HEAD_SHA }}
run: |
echo "Profiling the head branch for code_graph_pipeline.py"
echo "Checking out head SHA: $HEAD_SHA"
git checkout $HEAD_SHA
echo "This is the working directory: $PWD"
# Ensure the script is executable
chmod +x cognee/api/v1/cognify/code_graph_pipeline.py
# Run Scalene
poetry run pyinstrument --renderer json -o head_results.json cognee/api/v1/cognify/code_graph_pipeline.py
# Compare profiling results
- name: Compare profiling results
run: |
python -c '
import json
try:
with open("base_results.json") as f:
base = json.load(f)
with open("head_results.json") as f:
head = json.load(f)
cpu_diff = head.get("total_cpu_samples_python", 0) - base.get("total_cpu_samples_python", 0)
memory_diff = head.get("malloc_samples", 0) - base.get("malloc_samples", 0)
results = [
f"CPU Usage Difference: {cpu_diff}",
f"Memory Usage Difference: {memory_diff} bytes"
]
with open("profiling_diff.txt", "w") as f:
f.write("\\n".join(results) + "\\n")
print("\\n".join(results)) # Print results to terminal
except Exception as e:
error_message = f"Error comparing profiling results: {e}"
with open("profiling_diff.txt", "w") as f:
f.write(error_message + "\\n")
print(error_message) # Print error to terminal
'
- name: Upload profiling diff artifact
uses: actions/upload-artifact@v3
with:
name: profiling-diff
path: profiling_diff.txt
# - name: Run profiler on head branch
# env:
# HEAD_SHA: ${{ env.HEAD_SHA }}
# run: |
# echo "Profiling the head branch for code_graph_pipeline.py"
# echo "Checking out head SHA: $HEAD_SHA"
# git checkout $HEAD_SHA
# echo "This is the working directory: $PWD"
# # Ensure the script is executable
# chmod +x cognee/api/v1/cognify/code_graph_pipeline.py
# # Run Scalene
# poetry run pyinstrument --renderer json -o head_results.json cognee/api/v1/cognify/code_graph_pipeline.py
#
# # Compare profiling results
# - name: Compare profiling results
# run: |
# python -c '
# import json
# try:
# with open("base_results.json") as f:
# base = json.load(f)
# with open("head_results.json") as f:
# head = json.load(f)
# cpu_diff = head.get("total_cpu_samples_python", 0) - base.get("total_cpu_samples_python", 0)
# memory_diff = head.get("malloc_samples", 0) - base.get("malloc_samples", 0)
# results = [
# f"CPU Usage Difference: {cpu_diff}",
# f"Memory Usage Difference: {memory_diff} bytes"
# ]
# with open("profiling_diff.txt", "w") as f:
# f.write("\\n".join(results) + "\\n")
# print("\\n".join(results)) # Print results to terminal
# except Exception as e:
# error_message = f"Error comparing profiling results: {e}"
# with open("profiling_diff.txt", "w") as f:
# f.write(error_message + "\\n")
# print(error_message) # Print error to terminal
# '
#
# - name: Upload profiling diff artifact
# uses: actions/upload-artifact@v3
# with:
# name: profiling-diff
# path: profiling_diff.txt
# Post results to the pull request
# - name: Post profiling results to PR

View file

@ -16,8 +16,8 @@ jobs:
fail-fast: true
matrix:
os:
- ubuntu-latest
python-version: ["3.8.x", "3.9.x", "3.10.x", "3.11.x"]
- ubuntu-22.04
python-version: ["3.10.x", "3.11.x"]
defaults:
run:

View file

@ -6,7 +6,7 @@ on:
jobs:
github-releases-to-discord:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@v3

View file

@ -22,7 +22,7 @@ jobs:
run_notebook_test:
name: test
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
defaults:
run:
shell: bash
@ -36,7 +36,7 @@ jobs:
python-version: '3.11.x'
- name: Install Poetry
uses: snok/install-poetry@v1.3.2
uses: snok/install-poetry@v1.4.1
with:
virtualenvs-create: true
virtualenvs-in-project: true
@ -58,4 +58,4 @@ jobs:
--to notebook \
--execute ${{ inputs.notebook-location }} \
--output executed_notebook.ipynb \
--ExecutePreprocessor.timeout=1200
--ExecutePreprocessor.timeout=1200

View file

@ -22,7 +22,7 @@ jobs:
run_notebook_test:
name: test
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
defaults:
run:
shell: bash
@ -33,10 +33,10 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11.x'
python-version: '3.12.x'
- name: Install Poetry
uses: snok/install-poetry@v1.3.2
uses: snok/install-poetry@v1.4.1
with:
virtualenvs-create: true
virtualenvs-in-project: true
@ -49,7 +49,8 @@ jobs:
- name: Execute Python Example
env:
ENV: 'dev'
PYTHONFAULTHANDLER: 1
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
run: poetry run python ${{ inputs.example-location }}
run: poetry run python ${{ inputs.example-location }}

View file

@ -3,7 +3,7 @@ on: [ pull_request ]
jobs:
ruff:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- uses: astral-sh/ruff-action@v2

View file

@ -3,7 +3,7 @@ on: [ pull_request ]
jobs:
ruff:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- uses: astral-sh/ruff-action@v2

View file

@ -16,8 +16,7 @@ env:
jobs:
run_deduplication_test:
name: test
runs-on: ubuntu-latest
if: ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-22.04
defaults:
run:
shell: bash
@ -46,7 +45,7 @@ jobs:
python-version: '3.11.x'
- name: Install Poetry
uses: snok/install-poetry@v1.3.2
uses: snok/install-poetry@v1.4.1
with:
virtualenvs-create: true
virtualenvs-in-project: true

View file

@ -17,8 +17,7 @@ jobs:
run_milvus:
name: test
runs-on: ubuntu-latest
if: ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-22.04
strategy:
fail-fast: false
defaults:
@ -36,7 +35,7 @@ jobs:
- name: Install Poetry
# https://github.com/snok/install-poetry#running-on-windows
uses: snok/install-poetry@v1.3.2
uses: snok/install-poetry@v1.4.1
with:
virtualenvs-create: true
virtualenvs-in-project: true

View file

@ -15,8 +15,7 @@ env:
jobs:
run_neo4j_integration_test:
name: test
if: ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
defaults:
run:
@ -32,7 +31,7 @@ jobs:
python-version: '3.11.x'
- name: Install Poetry
uses: snok/install-poetry@v1.3.2
uses: snok/install-poetry@v1.4.1
with:
virtualenvs-create: true
virtualenvs-in-project: true

View file

@ -17,8 +17,7 @@ jobs:
run_pgvector_integration_test:
name: test
runs-on: ubuntu-latest
if: ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-22.04
defaults:
run:
shell: bash
@ -47,7 +46,7 @@ jobs:
python-version: '3.11.x'
- name: Install Poetry
uses: snok/install-poetry@v1.3.2
uses: snok/install-poetry@v1.4.1
with:
virtualenvs-create: true
virtualenvs-in-project: true

View file

@ -14,11 +14,9 @@ env:
ENV: 'dev'
jobs:
run_common:
name: test
runs-on: ubuntu-latest
if: ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-22.04
strategy:
fail-fast: false
defaults:
@ -36,7 +34,7 @@ jobs:
- name: Install Poetry
# https://github.com/snok/install-poetry#running-on-windows
uses: snok/install-poetry@v1.3.2
uses: snok/install-poetry@v1.4.1
with:
virtualenvs-create: true
virtualenvs-in-project: true

View file

@ -17,8 +17,7 @@ jobs:
run_common:
name: test
runs-on: ubuntu-latest
if: ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-22.04
strategy:
fail-fast: false
defaults:
@ -36,7 +35,7 @@ jobs:
- name: Install Poetry
# https://github.com/snok/install-poetry#running-on-windows
uses: snok/install-poetry@v1.3.2
uses: snok/install-poetry@v1.4.1
with:
virtualenvs-create: true
virtualenvs-in-project: true

View file

@ -1,4 +1,4 @@
name: test | python 3.9
name: test | python 3.12
on:
workflow_dispatch:
@ -17,8 +17,7 @@ jobs:
run_common:
name: test
runs-on: ubuntu-latest
if: ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-22.04
strategy:
fail-fast: false
defaults:
@ -32,11 +31,11 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.9.x'
python-version: '3.12.x'
- name: Install Poetry
# https://github.com/snok/install-poetry#running-on-windows
uses: snok/install-poetry@v1.3.2
uses: snok/install-poetry@v1.4.1
with:
virtualenvs-create: true
virtualenvs-in-project: true

View file

@ -17,9 +17,7 @@ jobs:
run_qdrant_integration_test:
name: test
runs-on: ubuntu-latest
if: ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-22.04
defaults:
run:
shell: bash
@ -34,7 +32,7 @@ jobs:
python-version: '3.11.x'
- name: Install Poetry
uses: snok/install-poetry@v1.3.2
uses: snok/install-poetry@v1.4.1
with:
virtualenvs-create: true
virtualenvs-in-project: true

View file

@ -17,9 +17,7 @@ jobs:
run_weaviate_integration_test:
name: test
runs-on: ubuntu-latest
if: ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-22.04
defaults:
run:
shell: bash
@ -34,7 +32,7 @@ jobs:
python-version: '3.11.x'
- name: Install Poetry
uses: snok/install-poetry@v1.3.2
uses: snok/install-poetry@v1.4.1
with:
virtualenvs-create: true
virtualenvs-in-project: true

View file

@ -25,7 +25,7 @@ RUN pip install poetry
RUN poetry config virtualenvs.create false
# Install the dependencies
RUN poetry install --all-extras --no-root --no-dev
RUN poetry install --all-extras --no-root --without dev
# Set the PYTHONPATH environment variable to include the /app directory

View file

@ -4,8 +4,9 @@ from logging.config import fileConfig
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config
from cognee.infrastructure.databases.relational import Base
from alembic import context
from cognee.infrastructure.databases.relational import get_relational_engine, get_relational_config
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
@ -20,7 +21,7 @@ if config.config_file_name is not None:
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
from cognee.infrastructure.databases.relational import Base
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
@ -83,12 +84,11 @@ def run_migrations_online() -> None:
asyncio.run(run_async_migrations())
from cognee.infrastructure.databases.relational import get_relational_engine, get_relational_config
db_engine = get_relational_engine()
if db_engine.engine.dialect.name == "sqlite":
from cognee.infrastructure.files.storage import LocalStorage
db_config = get_relational_config()
LocalStorage.ensure_directory_exists(db_config.db_path)

View file

@ -5,6 +5,7 @@ Revises: 8057ae7329c2
Create Date: 2024-10-16 22:17:18.634638
"""
from typing import Sequence, Union
from sqlalchemy.util import await_only
@ -13,8 +14,8 @@ from cognee.modules.users.methods import create_default_user, delete_user
# revision identifiers, used by Alembic.
revision: str = '482cd6517ce4'
down_revision: Union[str, None] = '8057ae7329c2'
revision: str = "482cd6517ce4"
down_revision: Union[str, None] = "8057ae7329c2"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = "8057ae7329c2"

View file

@ -1,10 +1,11 @@
"""Initial migration
Revision ID: 8057ae7329c2
Revises:
Revises:
Create Date: 2024-10-02 12:55:20.989372
"""
from typing import Sequence, Union
from sqlalchemy.util import await_only
from cognee.infrastructure.databases.relational import get_relational_engine

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.3 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.4 KiB

BIN
assets/cognee_logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 34 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 27 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 86 KiB

View file

@ -0,0 +1 @@
3.11.5

View file

@ -7,6 +7,7 @@ def main():
"""Main entry point for the package."""
asyncio.run(server.main())
# Optionally expose other important items at package level
__all__ = ["main", "server"]

View file

@ -18,7 +18,9 @@ def node_to_string(node):
# keys_to_keep = ["chunk_index", "topological_rank", "cut_type", "id", "text"]
# keyset = set(keys_to_keep) & node.keys()
# return "Node(" + " ".join([key + ": " + str(node[key]) + "," for key in keyset]) + ")"
node_data = ", ".join([f"{key}: \"{value}\"" for key, value in node.items() if key in ["id", "name"]])
node_data = ", ".join(
[f'{key}: "{value}"' for key, value in node.items() if key in ["id", "name"]]
)
return f"Node({node_data})"
@ -52,9 +54,9 @@ async def handle_list_tools() -> list[types.Tool]:
"""
return [
types.Tool(
name = "cognify",
description = "Build knowledge graph from the input text.",
inputSchema = {
name="cognify",
description="Build knowledge graph from the input text.",
inputSchema={
"type": "object",
"properties": {
"text": {"type": "string"},
@ -65,9 +67,9 @@ async def handle_list_tools() -> list[types.Tool]:
},
),
types.Tool(
name = "search",
description = "Search the knowledge graph.",
inputSchema = {
name="search",
description="Search the knowledge graph.",
inputSchema={
"type": "object",
"properties": {
"query": {"type": "string"},
@ -76,9 +78,9 @@ async def handle_list_tools() -> list[types.Tool]:
},
),
types.Tool(
name = "prune",
description = "Reset the knowledge graph.",
inputSchema = {
name="prune",
description="Reset the knowledge graph.",
inputSchema={
"type": "object",
"properties": {
"query": {"type": "string"},
@ -90,8 +92,7 @@ async def handle_list_tools() -> list[types.Tool]:
@server.call_tool()
async def handle_call_tool(
name: str,
arguments: dict | None
name: str, arguments: dict | None
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
"""
Handle tool execution requests.
@ -115,12 +116,12 @@ async def handle_call_tool(
await cognee.add(text)
await cognee.cognify(graph_model = graph_model)
await cognee.cognify(graph_model=graph_model)
return [
types.TextContent(
type = "text",
text = "Ingested",
type="text",
text="Ingested",
)
]
elif name == "search":
@ -131,16 +132,14 @@ async def handle_call_tool(
search_query = arguments.get("query")
search_results = await cognee.search(
SearchType.INSIGHTS, query_text = search_query
)
search_results = await cognee.search(SearchType.INSIGHTS, query_text=search_query)
results = retrieved_edges_to_string(search_results)
return [
types.TextContent(
type = "text",
text = results,
type="text",
text=results,
)
]
elif name == "prune":
@ -151,8 +150,8 @@ async def handle_call_tool(
return [
types.TextContent(
type = "text",
text = "Pruned",
type="text",
text="Pruned",
)
]
else:
@ -166,15 +165,16 @@ async def main():
read_stream,
write_stream,
InitializationOptions(
server_name = "cognee-mcp",
server_version = "0.1.0",
capabilities = server.get_capabilities(
notification_options = NotificationOptions(),
experimental_capabilities = {},
server_name="cognee-mcp",
server_version="0.1.0",
capabilities=server.get_capabilities(
notification_options=NotificationOptions(),
experimental_capabilities={},
),
),
)
# This is needed if you'd like to connect to a custom client
if __name__ == "__main__":
asyncio.run(main())

188
cognee-mcp/uv.lock generated
View file

@ -321,6 +321,26 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b1/fe/e8c672695b37eecc5cbf43e1d0638d88d66ba3a44c4d321c796f4e59167f/beautifulsoup4-4.12.3-py3-none-any.whl", hash = "sha256:b80878c9f40111313e55da8ba20bdba06d8fa3969fc68304167741bbf9e082ed", size = 147925 },
]
[[package]]
name = "bokeh"
version = "3.6.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "contourpy" },
{ name = "jinja2" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "pandas" },
{ name = "pillow" },
{ name = "pyyaml" },
{ name = "tornado" },
{ name = "xyzservices" },
]
sdist = { url = "https://files.pythonhosted.org/packages/da/9d/cc9c561e1db8cbecc5cfad972159020700fff2339bdaa316498ace1cb04c/bokeh-3.6.2.tar.gz", hash = "sha256:2f3043d9ecb3d5dc2e8c0ebf8ad55727617188d4e534f3e7208b36357e352396", size = 6247610 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/56/12/2c266a0dc57379c60b4e73a2f93e71343db4170bf26c5a76a74e7d8bce2a/bokeh-3.6.2-py3-none-any.whl", hash = "sha256:fddc4b91f8b40178c0e3e83dfcc33886d7803a3a1f041a840834255e435a18c2", size = 6866799 },
]
[[package]]
name = "boto3"
version = "1.35.84"
@ -358,6 +378,34 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a4/07/14f8ad37f2d12a5ce41206c21820d8cb6561b728e51fad4530dff0552a67/cachetools-5.5.0-py3-none-any.whl", hash = "sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292", size = 9524 },
]
[[package]]
name = "cairocffi"
version = "1.7.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cffi" },
]
sdist = { url = "https://files.pythonhosted.org/packages/70/c5/1a4dc131459e68a173cbdab5fad6b524f53f9c1ef7861b7698e998b837cc/cairocffi-1.7.1.tar.gz", hash = "sha256:2e48ee864884ec4a3a34bfa8c9ab9999f688286eb714a15a43ec9d068c36557b", size = 88096 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/93/d8/ba13451aa6b745c49536e87b6bf8f629b950e84bd0e8308f7dc6883b67e2/cairocffi-1.7.1-py3-none-any.whl", hash = "sha256:9803a0e11f6c962f3b0ae2ec8ba6ae45e957a146a004697a1ac1bbf16b073b3f", size = 75611 },
]
[[package]]
name = "cairosvg"
version = "2.7.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cairocffi" },
{ name = "cssselect2" },
{ name = "defusedxml" },
{ name = "pillow" },
{ name = "tinycss2" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d5/e6/ec5900b724e3c44af7f6f51f719919137284e5da4aabe96508baec8a1b40/CairoSVG-2.7.1.tar.gz", hash = "sha256:432531d72347291b9a9ebfb6777026b607563fd8719c46ee742db0aef7271ba0", size = 8399085 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/01/a5/1866b42151f50453f1a0d28fc4c39f5be5f412a2e914f33449c42daafdf1/CairoSVG-2.7.1-py3-none-any.whl", hash = "sha256:8a5222d4e6c3f86f1f7046b63246877a63b49923a1cd202184c3a634ef546b3b", size = 43235 },
]
[[package]]
name = "certifi"
version = "2024.12.14"
@ -412,6 +460,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/7c/fc/6a8cb64e5f0324877d503c854da15d76c1e50eb722e320b15345c4d0c6de/cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a", size = 182009 },
]
[[package]]
name = "cfgv"
version = "3.4.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/11/74/539e56497d9bd1d484fd863dd69cbbfa653cd2aa27abfe35653494d85e94/cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560", size = 7114 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249 },
]
[[package]]
name = "chardet"
version = "5.2.0"
@ -480,7 +537,7 @@ name = "click"
version = "8.1.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "platform_system == 'Windows'" },
{ name = "colorama", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 }
wheels = [
@ -497,7 +554,9 @@ dependencies = [
{ name = "aiosqlite" },
{ name = "alembic" },
{ name = "anthropic" },
{ name = "bokeh" },
{ name = "boto3" },
{ name = "cairosvg" },
{ name = "datasets" },
{ name = "dlt", extra = ["sqlalchemy"] },
{ name = "fastapi" },
@ -518,6 +577,7 @@ dependencies = [
{ name = "numpy" },
{ name = "openai" },
{ name = "pandas" },
{ name = "pre-commit" },
{ name = "pydantic" },
{ name = "pydantic-settings" },
{ name = "pypdf" },
@ -541,8 +601,10 @@ requires-dist = [
{ name = "alembic", specifier = ">=1.13.3,<2.0.0" },
{ name = "anthropic", specifier = ">=0.26.1,<0.27.0" },
{ name = "asyncpg", marker = "extra == 'postgres'", specifier = "==0.30.0" },
{ name = "bokeh", specifier = ">=3.6.2,<4.0.0" },
{ name = "boto3", specifier = ">=1.26.125,<2.0.0" },
{ name = "botocore", marker = "extra == 'filesystem'", specifier = ">=1.35.54,<2.0.0" },
{ name = "cairosvg", specifier = ">=2.7.1,<3.0.0" },
{ name = "datasets", specifier = "==3.1.0" },
{ name = "deepeval", marker = "extra == 'deepeval'", specifier = ">=2.0.1,<3.0.0" },
{ name = "dlt", extras = ["sqlalchemy"], specifier = ">=1.4.1,<2.0.0" },
@ -573,6 +635,7 @@ requires-dist = [
{ name = "pandas", specifier = "==2.0.3" },
{ name = "pgvector", marker = "extra == 'postgres'", specifier = ">=0.3.5,<0.4.0" },
{ name = "posthog", marker = "extra == 'posthog'", specifier = ">=3.5.0,<4.0.0" },
{ name = "pre-commit", specifier = ">=4.0.1,<5.0.0" },
{ name = "psycopg2", marker = "extra == 'postgres'", specifier = ">=2.9.10,<3.0.0" },
{ name = "pydantic", specifier = "==2.8.2" },
{ name = "pydantic-settings", specifier = ">=2.2.1,<3.0.0" },
@ -885,6 +948,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/97/9b/443270b9210f13f6ef240eff73fd32e02d381e7103969dc66ce8e89ee901/cryptography-44.0.0-cp39-abi3-win_amd64.whl", hash = "sha256:708ee5f1bafe76d041b53a4f95eb28cdeb8d18da17e597d46d7833ee59b97ede", size = 3202071 },
]
[[package]]
name = "cssselect2"
version = "0.7.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "tinycss2" },
{ name = "webencodings" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e7/fc/326cb6f988905998f09bb54a3f5d98d4462ba119363c0dfad29750d48c09/cssselect2-0.7.0.tar.gz", hash = "sha256:1ccd984dab89fc68955043aca4e1b03e0cf29cad9880f6e28e3ba7a74b14aa5a", size = 35888 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/9d/3a/e39436efe51894243ff145a37c4f9a030839b97779ebcc4f13b3ba21c54e/cssselect2-0.7.0-py3-none-any.whl", hash = "sha256:fd23a65bfd444595913f02fc71f6b286c29261e354c41d722ca7a261a49b5969", size = 15586 },
]
[[package]]
name = "cycler"
version = "0.12.1"
@ -996,6 +1072,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/39/60/533ce66e28295e2b94267126a851ac091ad29a835a9827d1f9c30574fce4/deepeval-2.0.6-py3-none-any.whl", hash = "sha256:57302830ff9d3d16ad4f1961338c7b4453e48039ff131990f258880728f33b6b", size = 504101 },
]
[[package]]
name = "defusedxml"
version = "0.7.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/0f/d5/c66da9b79e5bdb124974bfe172b4daf3c984ebd9c2a06e2b8a4dc7331c72/defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69", size = 75520 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604 },
]
[[package]]
name = "deprecated"
version = "1.2.15"
@ -1056,6 +1141,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/68/69/1bcf70f81de1b4a9f21b3a62ec0c83bdff991c88d6cc2267d02408457e88/dirtyjson-1.0.8-py3-none-any.whl", hash = "sha256:125e27248435a58acace26d5c2c4c11a1c0de0a9c5124c5a94ba78e517d74f53", size = 25197 },
]
[[package]]
name = "distlib"
version = "0.3.9"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/0d/dd/1bec4c5ddb504ca60fc29472f3d27e8d4da1257a854e1d96742f15c1d02d/distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403", size = 613923 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973 },
]
[[package]]
name = "distro"
version = "1.9.0"
@ -1686,6 +1780,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d7/de/85a784bcc4a3779d1753a7ec2dee5de90e18c7bcf402e71b51fcf150b129/hyperframe-6.0.1-py3-none-any.whl", hash = "sha256:0ec6bafd80d8ad2195c4f03aacba3a8265e57bc4cff261e802bf39970ed02a15", size = 12389 },
]
[[package]]
name = "identify"
version = "2.6.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/1a/5f/05f0d167be94585d502b4adf8c7af31f1dc0b1c7e14f9938a88fdbbcf4a7/identify-2.6.3.tar.gz", hash = "sha256:62f5dae9b5fef52c84cc188514e9ea4f3f636b1d8799ab5ebc475471f9e47a02", size = 99179 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c9/f5/09644a3ad803fae9eca8efa17e1f2aef380c7f0b02f7ec4e8d446e51d64a/identify-2.6.3-py2.py3-none-any.whl", hash = "sha256:9edba65473324c2ea9684b1f944fe3191db3345e50b6d04571d10ed164f8d7bd", size = 99049 },
]
[[package]]
name = "idna"
version = "3.10"
@ -2557,6 +2660,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/4d/66/7d9e26593edda06e8cb531874633f7c2372279c3b0f46235539fe546df8b/nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1", size = 1505442 },
]
[[package]]
name = "nodeenv"
version = "1.9.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 },
]
[[package]]
name = "numpy"
version = "1.26.4"
@ -2946,7 +3058,7 @@ name = "portalocker"
version = "2.10.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pywin32", marker = "platform_system == 'Windows'" },
{ name = "pywin32", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 }
wheels = [
@ -2969,6 +3081,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d3/f2/5ee24cd69e2120bf87356c02ace0438b4e4fb78229fddcbf6f1c6be377d5/posthog-3.7.4-py2.py3-none-any.whl", hash = "sha256:21c18c6bf43b2de303ea4cd6e95804cc0f24c20cb2a96a8fd09da2ed50b62faa", size = 54777 },
]
[[package]]
name = "pre-commit"
version = "4.0.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cfgv" },
{ name = "identify" },
{ name = "nodeenv" },
{ name = "pyyaml" },
{ name = "virtualenv" },
]
sdist = { url = "https://files.pythonhosted.org/packages/2e/c8/e22c292035f1bac8b9f5237a2622305bc0304e776080b246f3df57c4ff9f/pre_commit-4.0.1.tar.gz", hash = "sha256:80905ac375958c0444c65e9cebebd948b3cdb518f335a091a670a89d652139d2", size = 191678 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/16/8f/496e10d51edd6671ebe0432e33ff800aa86775d2d147ce7d43389324a525/pre_commit-4.0.1-py2.py3-none-any.whl", hash = "sha256:efde913840816312445dc98787724647c65473daefe420785f885e8ed9a06878", size = 218713 },
]
[[package]]
name = "propcache"
version = "0.2.1"
@ -3065,6 +3193,7 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/33/39/5a9a229bb5414abeb86e33b8fc8143ab0aecce5a7f698a53e31367d30caa/psycopg2-2.9.10-cp311-cp311-win_amd64.whl", hash = "sha256:0435034157049f6846e95103bd8f5a668788dd913a7c30162ca9503fdf542cb4", size = 1163736 },
{ url = "https://files.pythonhosted.org/packages/3d/16/4623fad6076448df21c1a870c93a9774ad8a7b4dd1660223b59082dd8fec/psycopg2-2.9.10-cp312-cp312-win32.whl", hash = "sha256:65a63d7ab0e067e2cdb3cf266de39663203d38d6a8ed97f5ca0cb315c73fe067", size = 1025113 },
{ url = "https://files.pythonhosted.org/packages/66/de/baed128ae0fc07460d9399d82e631ea31a1f171c0c4ae18f9808ac6759e3/psycopg2-2.9.10-cp312-cp312-win_amd64.whl", hash = "sha256:4a579d6243da40a7b3182e0430493dbd55950c493d8c68f4eec0b302f6bbf20e", size = 1163951 },
{ url = "https://files.pythonhosted.org/packages/ae/49/a6cfc94a9c483b1fa401fbcb23aca7892f60c7269c5ffa2ac408364f80dc/psycopg2-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:91fd603a2155da8d0cfcdbf8ab24a2d54bca72795b90d2a3ed2b6da8d979dee2", size = 2569060 },
]
[[package]]
@ -4223,6 +4352,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/35/75/c4d8b2f0fe7dac22854d88a9c509d428e78ac4bf284bc54cfe83f75cc13b/time_machine-2.16.0-cp313-cp313-win_arm64.whl", hash = "sha256:4d3843143c46dddca6491a954bbd0abfd435681512ac343169560e9bab504129", size = 18047 },
]
[[package]]
name = "tinycss2"
version = "1.4.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "webencodings" },
]
sdist = { url = "https://files.pythonhosted.org/packages/7a/fd/7a5ee21fd08ff70d3d33a5781c255cbe779659bd03278feb98b19ee550f4/tinycss2-1.4.0.tar.gz", hash = "sha256:10c0972f6fc0fbee87c3edb76549357415e94548c1ae10ebccdea16fb404a9b7", size = 87085 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e6/34/ebdc18bae6aa14fbee1a08b63c015c72b64868ff7dae68808ab500c492e2/tinycss2-1.4.0-py3-none-any.whl", hash = "sha256:3a49cf47b7675da0b15d0c6e1df8df4ebd96e9394bb905a5775adb0d884c5289", size = 26610 },
]
[[package]]
name = "tokenizers"
version = "0.21.0"
@ -4257,12 +4398,30 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f9/b6/a447b5e4ec71e13871be01ba81f5dfc9d0af7e473da256ff46bc0e24026f/tomlkit-0.13.2-py3-none-any.whl", hash = "sha256:7a974427f6e119197f670fbbbeae7bef749a6c14e793db934baefc1b5f03efde", size = 37955 },
]
[[package]]
name = "tornado"
version = "6.4.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/59/45/a0daf161f7d6f36c3ea5fc0c2de619746cc3dd4c76402e9db545bd920f63/tornado-6.4.2.tar.gz", hash = "sha256:92bad5b4746e9879fd7bf1eb21dce4e3fc5128d71601f80005afa39237ad620b", size = 501135 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/26/7e/71f604d8cea1b58f82ba3590290b66da1e72d840aeb37e0d5f7291bd30db/tornado-6.4.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e828cce1123e9e44ae2a50a9de3055497ab1d0aeb440c5ac23064d9e44880da1", size = 436299 },
{ url = "https://files.pythonhosted.org/packages/96/44/87543a3b99016d0bf54fdaab30d24bf0af2e848f1d13d34a3a5380aabe16/tornado-6.4.2-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:072ce12ada169c5b00b7d92a99ba089447ccc993ea2143c9ede887e0937aa803", size = 434253 },
{ url = "https://files.pythonhosted.org/packages/cb/fb/fdf679b4ce51bcb7210801ef4f11fdac96e9885daa402861751353beea6e/tornado-6.4.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a017d239bd1bb0919f72af256a970624241f070496635784d9bf0db640d3fec", size = 437602 },
{ url = "https://files.pythonhosted.org/packages/4f/3b/e31aeffffc22b475a64dbeb273026a21b5b566f74dee48742817626c47dc/tornado-6.4.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c36e62ce8f63409301537222faffcef7dfc5284f27eec227389f2ad11b09d946", size = 436972 },
{ url = "https://files.pythonhosted.org/packages/22/55/b78a464de78051a30599ceb6983b01d8f732e6f69bf37b4ed07f642ac0fc/tornado-6.4.2-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bca9eb02196e789c9cb5c3c7c0f04fb447dc2adffd95265b2c7223a8a615ccbf", size = 437173 },
{ url = "https://files.pythonhosted.org/packages/79/5e/be4fb0d1684eb822c9a62fb18a3e44a06188f78aa466b2ad991d2ee31104/tornado-6.4.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:304463bd0772442ff4d0f5149c6f1c2135a1fae045adf070821c6cdc76980634", size = 437892 },
{ url = "https://files.pythonhosted.org/packages/f5/33/4f91fdd94ea36e1d796147003b490fe60a0215ac5737b6f9c65e160d4fe0/tornado-6.4.2-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:c82c46813ba483a385ab2a99caeaedf92585a1f90defb5693351fa7e4ea0bf73", size = 437334 },
{ url = "https://files.pythonhosted.org/packages/2b/ae/c1b22d4524b0e10da2f29a176fb2890386f7bd1f63aacf186444873a88a0/tornado-6.4.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:932d195ca9015956fa502c6b56af9eb06106140d844a335590c1ec7f5277d10c", size = 437261 },
{ url = "https://files.pythonhosted.org/packages/b5/25/36dbd49ab6d179bcfc4c6c093a51795a4f3bed380543a8242ac3517a1751/tornado-6.4.2-cp38-abi3-win32.whl", hash = "sha256:2876cef82e6c5978fde1e0d5b1f919d756968d5b4282418f3146b79b58556482", size = 438463 },
{ url = "https://files.pythonhosted.org/packages/61/cc/58b1adeb1bb46228442081e746fcdbc4540905c87e8add7c277540934edb/tornado-6.4.2-cp38-abi3-win_amd64.whl", hash = "sha256:908b71bf3ff37d81073356a5fadcc660eb10c1476ee6e2725588626ce7e5ca38", size = 438907 },
]
[[package]]
name = "tqdm"
version = "4.67.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "platform_system == 'Windows'" },
{ name = "colorama", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 }
wheels = [
@ -4536,6 +4695,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/04/22/91b4bd36df27e651daedd93d03d5d3bb6029fdb0b55494e45ee46c36c570/validators-0.33.0-py3-none-any.whl", hash = "sha256:134b586a98894f8139865953899fc2daeb3d0c35569552c5518f089ae43ed075", size = 43298 },
]
[[package]]
name = "virtualenv"
version = "20.28.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "distlib" },
{ name = "filelock" },
{ name = "platformdirs" },
]
sdist = { url = "https://files.pythonhosted.org/packages/bf/75/53316a5a8050069228a2f6d11f32046cfa94fbb6cc3f08703f59b873de2e/virtualenv-20.28.0.tar.gz", hash = "sha256:2c9c3262bb8e7b87ea801d715fae4495e6032450c71d2309be9550e7364049aa", size = 7650368 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/10/f9/0919cf6f1432a8c4baa62511f8f8da8225432d22e83e3476f5be1a1edc6e/virtualenv-20.28.0-py3-none-any.whl", hash = "sha256:23eae1b4516ecd610481eda647f3a7c09aea295055337331bb4e6892ecce47b0", size = 4276702 },
]
[[package]]
name = "weaviate-client"
version = "4.6.7"
@ -4692,6 +4865,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/27/ee/518b72faa2073f5aa8e3262408d284892cb79cf2754ba0c3a5870645ef73/xxhash-3.5.0-cp313-cp313-win_arm64.whl", hash = "sha256:4811336f1ce11cac89dcbd18f3a25c527c16311709a89313c3acaf771def2d4b", size = 26801 },
]
[[package]]
name = "xyzservices"
version = "2024.9.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a0/16/ae87cbd2d6bfc40a419077521c35aadf5121725b7bee3d7c51b56f50958b/xyzservices-2024.9.0.tar.gz", hash = "sha256:68fb8353c9dbba4f1ff6c0f2e5e4e596bb9e1db7f94f4f7dfbcb26e25aa66fde", size = 1131900 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4c/d3/e07ce413d16ef64e885bea37551eac4c5ca3ddd440933f9c94594273d0d9/xyzservices-2024.9.0-py3-none-any.whl", hash = "sha256:776ae82b78d6e5ca63dd6a94abb054df8130887a4a308473b54a6bd364de8644", size = 85130 },
]
[[package]]
name = "yarl"
version = "1.18.3"

View file

@ -4,11 +4,15 @@ from .api.v1.config.config import config
from .api.v1.datasets.datasets import datasets
from .api.v1.prune import prune
from .api.v1.search import SearchType, get_search_history, search
from .api.v1.visualize import visualize
from .shared.utils import create_cognee_style_network_with_logo
# Pipelines
from .modules import pipelines
try:
import dotenv
dotenv.load_dotenv()
except ImportError:
pass

View file

@ -4,12 +4,13 @@ from pydantic.alias_generators import to_camel, to_snake
class OutDTO(BaseModel):
model_config = ConfigDict(
alias_generator = to_camel,
populate_by_name = True,
alias_generator=to_camel,
populate_by_name=True,
)
class InDTO(BaseModel):
model_config = ConfigDict(
alias_generator = to_camel,
populate_by_name = True,
alias_generator=to_camel,
populate_by_name=True,
)

View file

@ -1,4 +1,5 @@
""" FastAPI server for the Cognee API. """
"""FastAPI server for the Cognee API."""
import os
import uvicorn
import logging
@ -6,9 +7,26 @@ import sentry_sdk
from fastapi import FastAPI, status
from fastapi.responses import JSONResponse, Response
from fastapi.middleware.cors import CORSMiddleware
from cognee.api.v1.permissions.routers import get_permissions_router
from cognee.api.v1.settings.routers import get_settings_router
from cognee.api.v1.datasets.routers import get_datasets_router
from cognee.api.v1.cognify.routers import get_cognify_router
from cognee.api.v1.search.routers import get_search_router
from cognee.api.v1.add.routers import get_add_router
from fastapi import Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from cognee.exceptions import CogneeApiError
from traceback import format_exc
from cognee.api.v1.users.routers import (
get_auth_router,
get_register_router,
get_reset_password_router,
get_verify_router,
get_users_router,
get_visualize_router,
)
from contextlib import asynccontextmanager
# Set up logging
logging.basicConfig(
@ -19,15 +37,15 @@ logger = logging.getLogger(__name__)
if os.getenv("ENV", "prod") == "prod":
sentry_sdk.init(
dsn = os.getenv("SENTRY_REPORTING_URL"),
traces_sample_rate = 1.0,
profiles_sample_rate = 1.0,
dsn=os.getenv("SENTRY_REPORTING_URL"),
traces_sample_rate=1.0,
profiles_sample_rate=1.0,
)
from contextlib import asynccontextmanager
app_environment = os.getenv("ENV", "prod")
@asynccontextmanager
async def lifespan(app: FastAPI):
# from cognee.modules.data.deletion import prune_system, prune_data
@ -35,50 +53,42 @@ async def lifespan(app: FastAPI):
# await prune_system(metadata = True)
# if app_environment == "local" or app_environment == "dev":
from cognee.infrastructure.databases.relational import get_relational_engine
db_engine = get_relational_engine()
await db_engine.create_database()
from cognee.modules.users.methods import get_default_user
await get_default_user()
yield
app = FastAPI(debug = app_environment != "prod", lifespan = lifespan)
app = FastAPI(debug=app_environment != "prod", lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins = ["*"],
allow_credentials = True,
allow_methods = ["OPTIONS", "GET", "POST", "DELETE"],
allow_headers = ["*"],
allow_origins=["*"],
allow_credentials=True,
allow_methods=["OPTIONS", "GET", "POST", "DELETE"],
allow_headers=["*"],
)
from cognee.api.v1.users.routers import get_auth_router, get_register_router,\
get_reset_password_router, get_verify_router, get_users_router
from cognee.api.v1.permissions.routers import get_permissions_router
from cognee.api.v1.settings.routers import get_settings_router
from cognee.api.v1.datasets.routers import get_datasets_router
from cognee.api.v1.cognify.routers import get_cognify_router
from cognee.api.v1.search.routers import get_search_router
from cognee.api.v1.add.routers import get_add_router
from fastapi import Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
@app.exception_handler(RequestValidationError)
async def request_validation_exception_handler(request: Request, exc: RequestValidationError):
if request.url.path == "/api/v1/auth/login":
return JSONResponse(
status_code = 400,
content = {"detail": "LOGIN_BAD_CREDENTIALS"},
status_code=400,
content={"detail": "LOGIN_BAD_CREDENTIALS"},
)
return JSONResponse(
status_code = 400,
content = jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
status_code=400,
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
)
@app.exception_handler(CogneeApiError)
async def exception_handler(_: Request, exc: CogneeApiError) -> JSONResponse:
detail = {}
@ -95,46 +105,42 @@ async def exception_handler(_: Request, exc: CogneeApiError) -> JSONResponse:
# log the stack trace for easier serverside debugging
logger.error(format_exc())
return JSONResponse(
status_code=status_code, content={"detail": detail["message"]}
)
return JSONResponse(status_code=status_code, content={"detail": detail["message"]})
app.include_router(
get_auth_router(),
prefix = "/api/v1/auth",
tags = ["auth"]
)
app.include_router(get_auth_router(), prefix="/api/v1/auth", tags=["auth"])
app.include_router(
get_register_router(),
prefix = "/api/v1/auth",
tags = ["auth"],
prefix="/api/v1/auth",
tags=["auth"],
)
app.include_router(
get_reset_password_router(),
prefix = "/api/v1/auth",
tags = ["auth"],
prefix="/api/v1/auth",
tags=["auth"],
)
app.include_router(
get_verify_router(),
prefix = "/api/v1/auth",
tags = ["auth"],
prefix="/api/v1/auth",
tags=["auth"],
)
app.include_router(
get_users_router(),
prefix = "/api/v1/users",
tags = ["users"],
prefix="/api/v1/users",
tags=["users"],
)
app.include_router(
get_permissions_router(),
prefix = "/api/v1/permissions",
tags = ["permissions"],
prefix="/api/v1/permissions",
tags=["permissions"],
)
@app.get("/")
async def root():
"""
@ -148,37 +154,21 @@ def health_check():
"""
Health check endpoint that returns the server status.
"""
return Response(status_code = 200)
return Response(status_code=200)
app.include_router(
get_datasets_router(),
prefix="/api/v1/datasets",
tags=["datasets"]
)
app.include_router(
get_add_router(),
prefix="/api/v1/add",
tags=["add"]
)
app.include_router(get_datasets_router(), prefix="/api/v1/datasets", tags=["datasets"])
app.include_router(
get_cognify_router(),
prefix="/api/v1/cognify",
tags=["cognify"]
)
app.include_router(get_add_router(), prefix="/api/v1/add", tags=["add"])
app.include_router(
get_search_router(),
prefix="/api/v1/search",
tags=["search"]
)
app.include_router(get_cognify_router(), prefix="/api/v1/cognify", tags=["cognify"])
app.include_router(get_search_router(), prefix="/api/v1/search", tags=["search"])
app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["settings"])
app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"])
app.include_router(
get_settings_router(),
prefix="/api/v1/settings",
tags=["settings"]
)
def start_api_server(host: str = "0.0.0.0", port: int = 8000):
"""
@ -190,7 +180,7 @@ def start_api_server(host: str = "0.0.0.0", port: int = 8000):
try:
logger.info("Starting server at %s:%s", host, port)
uvicorn.run(app, host = host, port = port)
uvicorn.run(app, host=host, port=port)
except Exception as e:
logger.exception(f"Failed to start server: {e}")
# Here you could add any cleanup code or error recovery code.

View file

@ -14,10 +14,19 @@ from cognee.tasks.ingestion import get_dlt_destination
from cognee.modules.users.permissions.methods import give_permission_on_document
from cognee.modules.users.models import User
from cognee.modules.data.methods import create_dataset
from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables
from cognee.infrastructure.databases.vector.pgvector import create_db_and_tables as create_pgvector_db_and_tables
from cognee.infrastructure.databases.relational import (
create_db_and_tables as create_relational_db_and_tables,
)
from cognee.infrastructure.databases.vector.pgvector import (
create_db_and_tables as create_pgvector_db_and_tables,
)
async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_name: str = "main_dataset", user: User = None):
async def add(
data: Union[BinaryIO, List[BinaryIO], str, List[str]],
dataset_name: str = "main_dataset",
user: User = None,
):
await create_relational_db_and_tables()
await create_pgvector_db_and_tables()
@ -25,7 +34,9 @@ async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_nam
if "data://" in data:
# data is a data directory path
datasets = get_matched_datasets(data.replace("data://", ""), dataset_name)
return await asyncio.gather(*[add(file_paths, dataset_name) for [dataset_name, file_paths] in datasets])
return await asyncio.gather(
*[add(file_paths, dataset_name) for [dataset_name, file_paths] in datasets]
)
if "file://" in data:
# data is a file path
@ -37,7 +48,7 @@ async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_nam
return await add([file_path], dataset_name)
if hasattr(data, "file"):
file_path = save_data_to_file(data.file, filename = data.filename)
file_path = save_data_to_file(data.file, filename=data.filename)
return await add([file_path], dataset_name)
# data is a list of file paths or texts
@ -45,7 +56,7 @@ async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_nam
for data_item in data:
if hasattr(data_item, "file"):
file_paths.append(save_data_to_file(data_item, filename = data_item.filename))
file_paths.append(save_data_to_file(data_item, filename=data_item.filename))
elif isinstance(data_item, str) and (
data_item.startswith("/") or data_item.startswith("file://")
):
@ -58,10 +69,11 @@ async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_nam
return []
async def add_files(file_paths: List[str], dataset_name: str, user: User = None):
if user is None:
user = await get_default_user()
base_config = get_base_config()
data_directory_path = base_config.data_root_directory
@ -72,7 +84,11 @@ async def add_files(file_paths: List[str], dataset_name: str, user: User = None)
if data_directory_path not in file_path:
file_name = file_path.split("/")[-1]
file_directory_path = data_directory_path + "/" + (dataset_name.replace(".", "/") + "/" if dataset_name != "main_dataset" else "")
file_directory_path = (
data_directory_path
+ "/"
+ (dataset_name.replace(".", "/") + "/" if dataset_name != "main_dataset" else "")
)
dataset_file_path = path.join(file_directory_path, file_name)
LocalStorage.ensure_directory_exists(file_directory_path)
@ -85,16 +101,20 @@ async def add_files(file_paths: List[str], dataset_name: str, user: User = None)
destination = get_dlt_destination()
pipeline = dlt.pipeline(
pipeline_name = "file_load_from_filesystem",
destination = destination,
pipeline_name="file_load_from_filesystem",
destination=destination,
)
dataset_name = dataset_name.replace(" ", "_").replace(".", "_") if dataset_name is not None else "main_dataset"
dataset_name = (
dataset_name.replace(" ", "_").replace(".", "_")
if dataset_name is not None
else "main_dataset"
)
@dlt.resource(standalone = True, merge_key = "id")
@dlt.resource(standalone=True, merge_key="id")
async def data_resources(file_paths: str, user: User):
for file_path in file_paths:
with open(file_path.replace("file://", ""), mode = "rb") as file:
with open(file_path.replace("file://", ""), mode="rb") as file:
classified_data = ingestion.classify(file)
data_id = ingestion.identify(classified_data)
@ -109,9 +129,9 @@ async def add_files(file_paths: List[str], dataset_name: str, user: User = None)
async with db_engine.get_async_session() as session:
dataset = await create_dataset(dataset_name, user.id, session)
data = (await session.execute(
select(Data).filter(Data.id == data_id)
)).scalar_one_or_none()
data = (
await session.execute(select(Data).filter(Data.id == data_id))
).scalar_one_or_none()
if data is not None:
data.name = file_metadata["name"]
@ -123,11 +143,11 @@ async def add_files(file_paths: List[str], dataset_name: str, user: User = None)
await session.commit()
else:
data = Data(
id = data_id,
name = file_metadata["name"],
raw_data_location = file_metadata["file_path"],
extension = file_metadata["extension"],
mime_type = file_metadata["mime_type"],
id=data_id,
name=file_metadata["name"],
raw_data_location=file_metadata["file_path"],
extension=file_metadata["extension"],
mime_type=file_metadata["mime_type"],
)
dataset.data.append(data)
@ -144,14 +164,13 @@ async def add_files(file_paths: List[str], dataset_name: str, user: User = None)
await give_permission_on_document(user, data_id, "read")
await give_permission_on_document(user, data_id, "write")
send_telemetry("cognee.add EXECUTION STARTED", user_id = user.id)
send_telemetry("cognee.add EXECUTION STARTED", user_id=user.id)
run_info = pipeline.run(
data_resources(processed_file_paths, user),
table_name = "file_metadata",
dataset_name = dataset_name,
write_disposition = "merge",
table_name="file_metadata",
dataset_name=dataset_name,
write_disposition="merge",
)
send_telemetry("cognee.add EXECUTION COMPLETED", user_id = user.id)
send_telemetry("cognee.add EXECUTION COMPLETED", user_id=user.id)
return run_info

View file

@ -3,22 +3,28 @@ from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines import run_tasks, Task
from cognee.tasks.ingestion import ingest_data_with_metadata, resolve_data_directories
from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables
from cognee.infrastructure.databases.vector.pgvector import create_db_and_tables as create_pgvector_db_and_tables
from cognee.infrastructure.databases.relational import (
create_db_and_tables as create_relational_db_and_tables,
)
from cognee.infrastructure.databases.vector.pgvector import (
create_db_and_tables as create_pgvector_db_and_tables,
)
async def add(data: Union[BinaryIO, list[BinaryIO], str, list[str]], dataset_name: str = "main_dataset", user: User = None):
async def add(
data: Union[BinaryIO, list[BinaryIO], str, list[str]],
dataset_name: str = "main_dataset",
user: User = None,
):
await create_relational_db_and_tables()
await create_pgvector_db_and_tables()
if user is None:
user = await get_default_user()
tasks = [
Task(resolve_data_directories),
Task(ingest_data_with_metadata, dataset_name, user)
]
tasks = [Task(resolve_data_directories), Task(ingest_data_with_metadata, dataset_name, user)]
pipeline = run_tasks(tasks, data, "add_pipeline")
async for result in pipeline:
print(result)
print(result)

View file

@ -1 +1 @@
from .get_add_router import get_add_router
from .get_add_router import get_add_router

View file

@ -11,17 +11,19 @@ from cognee.modules.users.methods import get_authenticated_user
logger = logging.getLogger(__name__)
def get_add_router() -> APIRouter:
router = APIRouter()
@router.post("/", response_model=None)
async def add(
data: List[UploadFile],
datasetId: str = Form(...),
user: User = Depends(get_authenticated_user),
data: List[UploadFile],
datasetId: str = Form(...),
user: User = Depends(get_authenticated_user),
):
""" This endpoint is responsible for adding data to the graph."""
"""This endpoint is responsible for adding data to the graph."""
from cognee.api.v1.add import add as cognee_add
try:
if isinstance(data, str) and data.startswith("http"):
if "github" in data:
@ -52,9 +54,6 @@ def get_add_router() -> APIRouter:
user=user,
)
except Exception as error:
return JSONResponse(
status_code=409,
content={"error": str(error)}
)
return JSONResponse(status_code=409, content={"error": str(error)})
return router
return router

View file

@ -1,4 +1,6 @@
from cognee.infrastructure.databases.relational.user_authentication.users import authenticate_user_method
from cognee.infrastructure.databases.relational.user_authentication.users import (
authenticate_user_method,
)
async def authenticate_user(email: str, password: str):
@ -11,6 +13,7 @@ async def authenticate_user(email: str, password: str):
if __name__ == "__main__":
import asyncio
# Define an example user
example_email = "example@example.com"
example_password = "securepassword123"

View file

@ -3,31 +3,30 @@ import logging
from pathlib import Path
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector.embeddings import \
get_embedding_engine
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.users.methods import get_default_user
from cognee.shared.data_models import KnowledgeGraph, MonitoringTool
from cognee.tasks.documents import (classify_documents,
extract_chunks_from_documents)
from cognee.tasks.documents import classify_documents, extract_chunks_from_documents
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.ingestion import ingest_data_with_metadata
from cognee.tasks.repo_processor import (enrich_dependency_graph,
expand_dependency_graph,
get_data_list_for_user,
get_non_py_files,
get_repo_file_dependencies)
from cognee.tasks.repo_processor.get_source_code_chunks import \
get_source_code_chunks
from cognee.tasks.repo_processor import (
enrich_dependency_graph,
expand_dependency_graph,
get_data_list_for_user,
get_non_py_files,
get_repo_file_dependencies,
)
from cognee.tasks.repo_processor.get_source_code_chunks import get_source_code_chunks
from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_code, summarize_text
monitoring = get_base_config().monitoring_tool
if monitoring == MonitoringTool.LANGFUSE:
from langfuse.decorators import observe
from cognee.tasks.summarization import summarize_code, summarize_text
logger = logging.getLogger("code_graph_pipeline")
update_status_lock = asyncio.Lock()
@ -42,9 +41,13 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
from cognee.infrastructure.databases.relational import create_db_and_tables
file_path = Path(__file__).parent
data_directory_path = str(pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve())
data_directory_path = str(
pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve()
)
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = str(pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve())
cognee_directory_path = str(
pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve()
)
cognee.config.system_root_directory(cognee_directory_path)
await cognee.prune.prune_data()
@ -60,7 +63,11 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
Task(get_repo_file_dependencies),
Task(enrich_dependency_graph),
Task(expand_dependency_graph, task_config={"batch_size": 50}),
Task(get_source_code_chunks, embedding_model=embedding_engine.model, task_config={"batch_size": 50}),
Task(
get_source_code_chunks,
embedding_model=embedding_engine.model,
task_config={"batch_size": 50},
),
Task(summarize_code, task_config={"batch_size": 50}),
Task(add_data_points, task_config={"batch_size": 50}),
]
@ -72,17 +79,19 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
Task(classify_documents),
Task(extract_chunks_from_documents),
Task(extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}),
Task(
extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
),
Task(
summarize_text,
summarization_model=cognee_config.summarization_model,
task_config={"batch_size": 50}
task_config={"batch_size": 50},
),
]
if include_docs:
async for result in run_tasks(non_code_tasks, repo_path):
yield result
async for result in run_tasks(tasks, repo_path, "cognify_code_pipeline"):
yield result
yield result

View file

@ -10,18 +10,18 @@ from cognee.modules.data.methods.get_dataset_data import get_dataset_data
from cognee.modules.data.models import Data, Dataset
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.models import PipelineRunStatus
from cognee.modules.pipelines.operations.get_pipeline_status import \
get_pipeline_status
from cognee.modules.pipelines.operations.log_pipeline_status import \
log_pipeline_status
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.data_models import KnowledgeGraph
from cognee.shared.utils import send_telemetry
from cognee.tasks.documents import (check_permissions_on_documents,
classify_documents,
extract_chunks_from_documents)
from cognee.tasks.documents import (
check_permissions_on_documents,
classify_documents,
extract_chunks_from_documents,
)
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.storage import add_data_points
from cognee.tasks.storage.index_graph_edges import index_graph_edges
@ -31,7 +31,12 @@ logger = logging.getLogger("cognify.v2")
update_status_lock = asyncio.Lock()
async def cognify(datasets: Union[str, list[str]] = None, user: User = None, graph_model: BaseModel = KnowledgeGraph):
async def cognify(
datasets: Union[str, list[str]] = None,
user: User = None,
graph_model: BaseModel = KnowledgeGraph,
):
if user is None:
user = await get_default_user()
@ -41,7 +46,7 @@ async def cognify(datasets: Union[str, list[str]] = None, user: User = None, gra
# If no datasets are provided, cognify all existing datasets.
datasets = existing_datasets
if type(datasets[0]) == str:
if isinstance(datasets[0], str):
datasets = await get_datasets_by_name(datasets, user.id)
existing_datasets_map = {
@ -59,8 +64,10 @@ async def cognify(datasets: Union[str, list[str]] = None, user: User = None, gra
return await asyncio.gather(*awaitables)
async def run_cognify_pipeline(dataset: Dataset, user: User, graph_model: BaseModel = KnowledgeGraph):
data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)
async def run_cognify_pipeline(
dataset: Dataset, user: User, graph_model: BaseModel = KnowledgeGraph
):
data_documents: list[Data] = await get_dataset_data(dataset_id=dataset.id)
document_ids_str = [str(document.id) for document in data_documents]
@ -69,32 +76,41 @@ async def run_cognify_pipeline(dataset: Dataset, user: User, graph_model: BaseMo
send_telemetry("cognee.cognify EXECUTION STARTED", user.id)
#async with update_status_lock: TODO: Add UI lock to prevent multiple backend requests
# async with update_status_lock: TODO: Add UI lock to prevent multiple backend requests
task_status = await get_pipeline_status([dataset_id])
if dataset_id in task_status and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
if (
dataset_id in task_status
and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED
):
logger.info("Dataset %s is already being processed.", dataset_name)
return
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_STARTED, {
"dataset_name": dataset_name,
"files": document_ids_str,
})
await log_pipeline_status(
dataset_id,
PipelineRunStatus.DATASET_PROCESSING_STARTED,
{
"dataset_name": dataset_name,
"files": document_ids_str,
},
)
try:
cognee_config = get_cognify_config()
tasks = [
Task(classify_documents),
Task(check_permissions_on_documents, user = user, permissions = ["write"]),
Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
Task(extract_graph_from_data, graph_model = graph_model, task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks.
Task(check_permissions_on_documents, user=user, permissions=["write"]),
Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
Task(
extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10}
), # Generate knowledge graphs from the document chunks.
Task(
summarize_text,
summarization_model = cognee_config.summarization_model,
task_config = { "batch_size": 10 }
summarization_model=cognee_config.summarization_model,
task_config={"batch_size": 10},
),
Task(add_data_points, task_config = { "batch_size": 10 }),
Task(add_data_points, only_root=True, task_config={"batch_size": 10}),
]
pipeline = run_tasks(tasks, data_documents, "cognify_pipeline")
@ -106,17 +122,25 @@ async def run_cognify_pipeline(dataset: Dataset, user: User, graph_model: BaseMo
send_telemetry("cognee.cognify EXECUTION COMPLETED", user.id)
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_COMPLETED, {
"dataset_name": dataset_name,
"files": document_ids_str,
})
await log_pipeline_status(
dataset_id,
PipelineRunStatus.DATASET_PROCESSING_COMPLETED,
{
"dataset_name": dataset_name,
"files": document_ids_str,
},
)
except Exception as error:
send_telemetry("cognee.cognify EXECUTION ERRORED", user.id)
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_ERRORED, {
"dataset_name": dataset_name,
"files": document_ids_str,
})
await log_pipeline_status(
dataset_id,
PipelineRunStatus.DATASET_PROCESSING_ERRORED,
{
"dataset_name": dataset_name,
"files": document_ids_str,
},
)
raise error

View file

@ -1 +1 @@
from .get_cognify_router import get_cognify_router
from .get_cognify_router import get_cognify_router

View file

@ -7,23 +7,23 @@ from cognee.modules.users.methods import get_authenticated_user
from fastapi import Depends
from cognee.shared.data_models import KnowledgeGraph
class CognifyPayloadDTO(BaseModel):
datasets: List[str]
graph_model: Optional[BaseModel] = KnowledgeGraph
def get_cognify_router() -> APIRouter:
router = APIRouter()
@router.post("/", response_model=None)
async def cognify(payload: CognifyPayloadDTO, user: User = Depends(get_authenticated_user)):
""" This endpoint is responsible for the cognitive processing of the content."""
"""This endpoint is responsible for the cognitive processing of the content."""
from cognee.api.v1.cognify.cognify_v2 import cognify as cognee_cognify
try:
await cognee_cognify(payload.datasets, user, payload.graph_model)
except Exception as error:
return JSONResponse(
status_code=409,
content={"error": str(error)}
)
return JSONResponse(status_code=409, content={"error": str(error)})
return router

View file

@ -1,4 +1,5 @@
""" This module is used to set the configuration of the system."""
"""This module is used to set the configuration of the system."""
import os
from cognee.base_config import get_base_config
from cognee.exceptions import InvalidValueError, InvalidAttributeError
@ -10,7 +11,8 @@ from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.databases.relational import get_relational_config
from cognee.infrastructure.files.storage import LocalStorage
class config():
class config:
@staticmethod
def system_root_directory(system_root_directory: str):
databases_directory_path = os.path.join(system_root_directory, "databases")
@ -39,12 +41,12 @@ class config():
@staticmethod
def set_classification_model(classification_model: object):
cognify_config = get_cognify_config()
cognify_config.classification_model = classification_model
cognify_config.classification_model = classification_model
@staticmethod
def set_summarization_model(summarization_model: object):
cognify_config = get_cognify_config()
cognify_config.summarization_model=summarization_model
cognify_config.summarization_model = summarization_model
@staticmethod
def set_graph_model(graph_model: object):
@ -79,14 +81,16 @@ class config():
@staticmethod
def set_llm_config(config_dict: dict):
"""
Updates the llm config with values from config_dict.
Updates the llm config with values from config_dict.
"""
llm_config = get_llm_config()
for key, value in config_dict.items():
if hasattr(llm_config, key):
object.__setattr__(llm_config, key, value)
else:
raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.")
raise InvalidAttributeError(
message=f"'{key}' is not a valid attribute of the config."
)
@staticmethod
def set_chunk_strategy(chunk_strategy: object):
@ -108,7 +112,6 @@ class config():
chunk_config = get_chunk_config()
chunk_config.chunk_size = chunk_size
@staticmethod
def set_vector_db_provider(vector_db_provider: str):
vector_db_config = get_vectordb_config()
@ -117,33 +120,36 @@ class config():
@staticmethod
def set_relational_db_config(config_dict: dict):
"""
Updates the relational db config with values from config_dict.
Updates the relational db config with values from config_dict.
"""
relational_db_config = get_relational_config()
for key, value in config_dict.items():
if hasattr(relational_db_config, key):
object.__setattr__(relational_db_config, key, value)
else:
raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.")
raise InvalidAttributeError(
message=f"'{key}' is not a valid attribute of the config."
)
@staticmethod
def set_vector_db_config(config_dict: dict):
"""
Updates the vector db config with values from config_dict.
Updates the vector db config with values from config_dict.
"""
vector_db_config = get_vectordb_config()
for key, value in config_dict.items():
if hasattr(vector_db_config, key):
object.__setattr__(vector_db_config, key, value)
else:
raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.")
raise InvalidAttributeError(
message=f"'{key}' is not a valid attribute of the config."
)
@staticmethod
def set_vector_db_key(db_key: str):
vector_db_config = get_vectordb_config()
vector_db_config.vector_db_key = db_key
@staticmethod
def set_vector_db_url(db_url: str):
vector_db_config = get_vectordb_config()
@ -154,7 +160,9 @@ class config():
base_config = get_base_config()
if "username" not in graphistry_config or "password" not in graphistry_config:
raise InvalidValueError(message="graphistry_config dictionary must contain 'username' and 'password' keys.")
raise InvalidValueError(
message="graphistry_config dictionary must contain 'username' and 'password' keys."
)
base_config.graphistry_username = graphistry_config.get("username")
base_config.graphistry_password = graphistry_config.get("password")

View file

@ -1,10 +1,13 @@
from cognee.modules.users.methods import get_default_user
from cognee.modules.ingestion import discover_directory_datasets
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
class datasets():
class datasets:
@staticmethod
async def list_datasets():
from cognee.modules.data.methods import get_datasets
user = await get_default_user()
return await get_datasets(user.id)
@ -15,6 +18,7 @@ class datasets():
@staticmethod
async def list_data(dataset_id: str):
from cognee.modules.data.methods import get_dataset, get_dataset_data
user = await get_default_user()
dataset = await get_dataset(user.id, dataset_id)
@ -28,6 +32,7 @@ class datasets():
@staticmethod
async def delete_dataset(dataset_id: str):
from cognee.modules.data.methods import get_dataset, delete_dataset
user = await get_default_user()
dataset = await get_dataset(user.id, dataset_id)

View file

@ -1 +1 @@
from .get_datasets_router import get_datasets_router
from .get_datasets_router import get_datasets_router

View file

@ -16,9 +16,11 @@ from cognee.modules.pipelines.models import PipelineRunStatus
logger = logging.getLogger(__name__)
class ErrorResponseDTO(BaseModel):
message: str
class DatasetDTO(OutDTO):
id: UUID
name: str
@ -26,6 +28,7 @@ class DatasetDTO(OutDTO):
updated_at: Optional[datetime] = None
owner_id: UUID
class DataDTO(OutDTO):
id: UUID
name: str
@ -35,6 +38,7 @@ class DataDTO(OutDTO):
mime_type: str
raw_data_location: str
def get_datasets_router() -> APIRouter:
router = APIRouter()
@ -42,46 +46,51 @@ def get_datasets_router() -> APIRouter:
async def get_datasets(user: User = Depends(get_authenticated_user)):
try:
from cognee.modules.data.methods import get_datasets
datasets = await get_datasets(user.id)
return datasets
except Exception as error:
logger.error(f"Error retrieving datasets: {str(error)}")
raise HTTPException(status_code=500, detail=f"Error retrieving datasets: {str(error)}") from error
raise HTTPException(
status_code=500, detail=f"Error retrieving datasets: {str(error)}"
) from error
@router.delete("/{dataset_id}", response_model=None, responses={404: {"model": ErrorResponseDTO}})
@router.delete(
"/{dataset_id}", response_model=None, responses={404: {"model": ErrorResponseDTO}}
)
async def delete_dataset(dataset_id: str, user: User = Depends(get_authenticated_user)):
from cognee.modules.data.methods import get_dataset, delete_dataset
dataset = await get_dataset(user.id, dataset_id)
if dataset is None:
raise EntityNotFoundError(
message=f"Dataset ({dataset_id}) not found."
)
raise EntityNotFoundError(message=f"Dataset ({dataset_id}) not found.")
await delete_dataset(dataset)
@router.delete("/{dataset_id}/data/{data_id}", response_model=None, responses={404: {"model": ErrorResponseDTO}})
async def delete_data(dataset_id: str, data_id: str, user: User = Depends(get_authenticated_user)):
@router.delete(
"/{dataset_id}/data/{data_id}",
response_model=None,
responses={404: {"model": ErrorResponseDTO}},
)
async def delete_data(
dataset_id: str, data_id: str, user: User = Depends(get_authenticated_user)
):
from cognee.modules.data.methods import get_data, delete_data
from cognee.modules.data.methods import get_dataset
# Check if user has permission to access dataset and data by trying to get the dataset
dataset = await get_dataset(user.id, dataset_id)
#TODO: Handle situation differently if user doesn't have permission to access data?
# TODO: Handle situation differently if user doesn't have permission to access data?
if dataset is None:
raise EntityNotFoundError(
message=f"Dataset ({dataset_id}) not found."
)
raise EntityNotFoundError(message=f"Dataset ({dataset_id}) not found.")
data = await get_data(user.id, data_id)
if data is None:
raise EntityNotFoundError(
message=f"Data ({data_id}) not found."
)
raise EntityNotFoundError(message=f"Data ({data_id}) not found.")
await delete_data(data)
@ -98,14 +107,18 @@ def get_datasets_router() -> APIRouter:
status_code=200,
content=str(graph_url),
)
except:
except Exception as error:
print(error)
return JSONResponse(
status_code=409,
content="Graphistry credentials are not set. Please set them in your .env file.",
)
@router.get("/{dataset_id}/data", response_model=list[DataDTO],
responses={404: {"model": ErrorResponseDTO}})
@router.get(
"/{dataset_id}/data",
response_model=list[DataDTO],
responses={404: {"model": ErrorResponseDTO}},
)
async def get_dataset_data(dataset_id: str, user: User = Depends(get_authenticated_user)):
from cognee.modules.data.methods import get_dataset_data, get_dataset
@ -125,8 +138,10 @@ def get_datasets_router() -> APIRouter:
return dataset_data
@router.get("/status", response_model=dict[str, PipelineRunStatus])
async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset")] = None,
user: User = Depends(get_authenticated_user)):
async def get_dataset_status(
datasets: Annotated[List[str], Query(alias="dataset")] = None,
user: User = Depends(get_authenticated_user),
):
from cognee.api.v1.datasets.datasets import datasets as cognee_datasets
try:
@ -134,13 +149,12 @@ def get_datasets_router() -> APIRouter:
return datasets_statuses
except Exception as error:
return JSONResponse(
status_code=409,
content={"error": str(error)}
)
return JSONResponse(status_code=409, content={"error": str(error)})
@router.get("/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
async def get_raw_data(dataset_id: str, data_id: str, user: User = Depends(get_authenticated_user)):
async def get_raw_data(
dataset_id: str, data_id: str, user: User = Depends(get_authenticated_user)
):
from cognee.modules.data.methods import get_data
from cognee.modules.data.methods import get_dataset, get_dataset_data
@ -148,10 +162,7 @@ def get_datasets_router() -> APIRouter:
if dataset is None:
return JSONResponse(
status_code=404,
content={
"detail": f"Dataset ({dataset_id}) not found."
}
status_code=404, content={"detail": f"Dataset ({dataset_id}) not found."}
)
dataset_data = await get_dataset_data(dataset.id)
@ -163,13 +174,17 @@ def get_datasets_router() -> APIRouter:
# Check if matching_data contains an element
if len(matching_data) == 0:
raise EntityNotFoundError(message= f"Data ({data_id}) not found in dataset ({dataset_id}).")
raise EntityNotFoundError(
message=f"Data ({data_id}) not found in dataset ({dataset_id})."
)
data = await get_data(user.id, data_id)
if data is None:
raise EntityNotFoundError(message=f"Data ({data_id}) not found in dataset ({dataset_id}).")
raise EntityNotFoundError(
message=f"Data ({data_id}) not found in dataset ({dataset_id})."
)
return data.raw_data_location
return router
return router

View file

@ -1 +1 @@
from .get_permissions_router import get_permissions_router
from .get_permissions_router import get_permissions_router

View file

@ -10,40 +10,56 @@ from cognee.modules.users.exceptions import UserNotFoundError, GroupNotFoundErro
from cognee.modules.users import get_user_db
from cognee.modules.users.models import User, Group, Permission, UserGroup, GroupPermission
def get_permissions_router() -> APIRouter:
permissions_router = APIRouter()
@permissions_router.post("/groups/{group_id}/permissions")
async def give_permission_to_group(group_id: str, permission: str, db: Session = Depends(get_user_db)):
group = (await db.session.execute(select(Group).where(Group.id == group_id))).scalars().first()
async def give_permission_to_group(
group_id: str, permission: str, db: Session = Depends(get_user_db)
):
group = (
(await db.session.execute(select(Group).where(Group.id == group_id))).scalars().first()
)
if not group:
raise GroupNotFoundError
permission_entity = (
await db.session.execute(select(Permission).where(Permission.name == permission))).scalars().first()
(await db.session.execute(select(Permission).where(Permission.name == permission)))
.scalars()
.first()
)
if not permission_entity:
stmt = insert(Permission).values(name=permission)
await db.session.execute(stmt)
permission_entity = (
await db.session.execute(select(Permission).where(Permission.name == permission))).scalars().first()
(await db.session.execute(select(Permission).where(Permission.name == permission)))
.scalars()
.first()
)
try:
# add permission to group
await db.session.execute(
insert(GroupPermission).values(group_id=group.id, permission_id=permission_entity.id))
except IntegrityError as e:
insert(GroupPermission).values(
group_id=group.id, permission_id=permission_entity.id
)
)
except IntegrityError:
raise EntityAlreadyExistsError(message="Group permission already exists.")
await db.session.commit()
return JSONResponse(status_code = 200, content = {"message": "Permission assigned to group"})
return JSONResponse(status_code=200, content={"message": "Permission assigned to group"})
@permissions_router.post("/users/{user_id}/groups")
async def add_user_to_group(user_id: str, group_id: str, db: Session = Depends(get_user_db)):
user = (await db.session.execute(select(User).where(User.id == user_id))).scalars().first()
group = (await db.session.execute(select(Group).where(Group.id == group_id))).scalars().first()
group = (
(await db.session.execute(select(Group).where(Group.id == group_id))).scalars().first()
)
if not user:
raise UserNotFoundError
@ -54,11 +70,11 @@ def get_permissions_router() -> APIRouter:
# Add association directly to the association table
stmt = insert(UserGroup).values(user_id=user_id, group_id=group_id)
await db.session.execute(stmt)
except IntegrityError as e:
except IntegrityError:
raise EntityAlreadyExistsError(message="User is already part of group.")
await db.session.commit()
return JSONResponse(status_code = 200, content = {"message": "User added to group"})
return JSONResponse(status_code=200, content={"message": "User added to group"})
return permissions_router

View file

@ -1,19 +1,21 @@
from cognee.modules.data.deletion import prune_system, prune_data
class prune():
class prune:
@staticmethod
async def prune_data():
await prune_data()
@staticmethod
async def prune_system(graph = True, vector = True, metadata = False):
async def prune_system(graph=True, vector=True, metadata=False):
await prune_system(graph, vector, metadata)
if __name__ == "__main__":
import asyncio
async def main():
await prune.prune_data()
await prune.prune_system()
asyncio.run(main())

View file

@ -2,8 +2,9 @@ from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
async def get_search_history(user: User = None) -> list:
if not user:
user = await get_default_user()
return await get_history(user.id)

View file

@ -1 +1 @@
from .get_search_router import get_search_router
from .get_search_router import get_search_router

View file

@ -13,6 +13,7 @@ class SearchPayloadDTO(InDTO):
search_type: SearchType
query: str
def get_search_router() -> APIRouter:
router = APIRouter()
@ -22,21 +23,18 @@ def get_search_router() -> APIRouter:
user: str
created_at: datetime
@router.get("/", response_model = list[SearchHistoryItem])
@router.get("/", response_model=list[SearchHistoryItem])
async def get_search_history(user: User = Depends(get_authenticated_user)):
try:
history = await get_history(user.id)
return history
except Exception as error:
return JSONResponse(
status_code = 500,
content = {"error": str(error)}
)
return JSONResponse(status_code=500, content={"error": str(error)})
@router.post("/", response_model = list)
@router.post("/", response_model=list)
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
""" This endpoint is responsible for searching for nodes in the graph."""
"""This endpoint is responsible for searching for nodes in the graph."""
from cognee.api.v1.search import search as cognee_search
try:
@ -44,9 +42,6 @@ def get_search_router() -> APIRouter:
return results
except Exception as error:
return JSONResponse(
status_code = 409,
content = {"error": str(error)}
)
return JSONResponse(status_code=409, content={"error": str(error)})
return router

View file

@ -1,4 +1,5 @@
""" This module contains the search function that is used to search for nodes in the graph."""
"""This module contains the search function that is used to search for nodes in the graph."""
import asyncio
from enum import Enum
from typing import Dict, Any, Callable, List
@ -16,6 +17,7 @@ from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
class SearchType(Enum):
ADJACENT = "ADJACENT"
TRAVERSE = "TRAVERSE"
@ -23,7 +25,7 @@ class SearchType(Enum):
SUMMARY = "SUMMARY"
SUMMARY_CLASSIFICATION = "SUMMARY_CLASSIFICATION"
NODE_CLASSIFICATION = "NODE_CLASSIFICATION"
DOCUMENT_CLASSIFICATION = "DOCUMENT_CLASSIFICATION",
DOCUMENT_CLASSIFICATION = ("DOCUMENT_CLASSIFICATION",)
CYPHER = "CYPHER"
@staticmethod
@ -33,12 +35,13 @@ class SearchType(Enum):
except KeyError as error:
raise ValueError(f"{name} is not a valid SearchType") from error
class SearchParameters(BaseModel):
search_type: SearchType
params: Dict[str, Any]
@field_validator("search_type", mode="before")
def convert_string_to_enum(cls, value): # pylint: disable=no-self-argument
def convert_string_to_enum(cls, value): # pylint: disable=no-self-argument
if isinstance(value, str):
return SearchType.from_str(value)
return value
@ -52,7 +55,7 @@ async def search(search_type: str, params: Dict[str, Any], user: User = None) ->
raise UserNotFoundError
own_document_ids = await get_document_ids_for_user(user.id)
search_params = SearchParameters(search_type = search_type, params = params)
search_params = SearchParameters(search_type=search_type, params=params)
search_results = await specific_search([search_params], user)
from uuid import UUID
@ -61,7 +64,7 @@ async def search(search_type: str, params: Dict[str, Any], user: User = None) ->
for search_result in search_results:
document_id = search_result["document_id"] if "document_id" in search_result else None
document_id = UUID(document_id) if type(document_id) == str else document_id
document_id = UUID(document_id) if isinstance(document_id, str) else document_id
if document_id is None or document_id in own_document_ids:
filtered_search_results.append(search_result)

View file

@ -16,14 +16,20 @@ from cognee.tasks.graph import query_graph_connections
from cognee.tasks.summarization import query_summaries
from cognee.tasks.completion import query_completion
class SearchType(Enum):
SUMMARIES = "SUMMARIES"
INSIGHTS = "INSIGHTS"
CHUNKS = "CHUNKS"
COMPLETION = "COMPLETION"
async def search(query_type: SearchType, query_text: str, user: User = None,
datasets: Union[list[str], str, None] = None) -> list:
async def search(
query_type: SearchType,
query_text: str,
user: User = None,
datasets: Union[list[str], str, None] = None,
) -> list:
# We use lists from now on for datasets
if isinstance(datasets, str):
datasets = [datasets]
@ -43,15 +49,16 @@ async def search(query_type: SearchType, query_text: str, user: User = None,
for search_result in search_results:
document_id = search_result["document_id"] if "document_id" in search_result else None
document_id = UUID(document_id) if type(document_id) == str else document_id
document_id = UUID(document_id) if isinstance(document_id, str) else document_id
if document_id is None or document_id in own_document_ids:
filtered_search_results.append(search_result)
await log_result(query.id, json.dumps(filtered_search_results, cls = JSONEncoder), user.id)
await log_result(query.id, json.dumps(filtered_search_results, cls=JSONEncoder), user.id)
return filtered_search_results
async def specific_search(query_type: SearchType, query: str, user) -> list:
search_tasks: Dict[SearchType, Callable] = {
SearchType.SUMMARIES: query_summaries,

View file

@ -1 +1 @@
from .get_settings_router import get_settings_router
from .get_settings_router import get_settings_router

View file

@ -6,40 +6,50 @@ from fastapi import Depends
from cognee.modules.users.models import User
from cognee.modules.settings.get_settings import LLMConfig, VectorDBConfig
class LLMConfigOutputDTO(OutDTO, LLMConfig):
pass
class VectorDBConfigOutputDTO(OutDTO, VectorDBConfig):
pass
class SettingsDTO(OutDTO):
llm: LLMConfigOutputDTO
vector_db: VectorDBConfigOutputDTO
class LLMConfigInputDTO(InDTO):
provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"]]
model: str
api_key: str
class VectorDBConfigInputDTO(InDTO):
provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"], Literal["pgvector"]]
url: str
api_key: str
class SettingsPayloadDTO(InDTO):
llm: Optional[LLMConfigInputDTO] = None
vector_db: Optional[VectorDBConfigInputDTO] = None
def get_settings_router() -> APIRouter:
router = APIRouter()
@router.get("/", response_model=SettingsDTO)
async def get_settings(user: User = Depends(get_authenticated_user)):
from cognee.modules.settings import get_settings as get_cognee_settings
return get_cognee_settings()
@router.post("/", response_model=None)
async def save_settings(new_settings: SettingsPayloadDTO, user: User = Depends(get_authenticated_user)):
async def save_settings(
new_settings: SettingsPayloadDTO, user: User = Depends(get_authenticated_user)
):
from cognee.modules.settings import save_llm_config, save_vector_db_config
if new_settings.llm is not None:
@ -48,4 +58,4 @@ def get_settings_router() -> APIRouter:
if new_settings.vector_db is not None:
await save_vector_db_config(new_settings.vector_db)
return router
return router

View file

@ -3,10 +3,10 @@ from cognee.modules.users.methods import create_user as create_user_method
async def create_user(email: str, password: str, is_superuser: bool = False):
user = await create_user_method(
email = email,
password = password,
is_superuser = is_superuser,
is_verified = True,
email=email,
password=password,
is_superuser=is_superuser,
is_verified=True,
)
return user

View file

@ -3,3 +3,4 @@ from .get_register_router import get_register_router
from .get_reset_password_router import get_reset_password_router
from .get_users_router import get_users_router
from .get_verify_router import get_verify_router
from .get_visualize_router import get_visualize_router

View file

@ -1,6 +1,7 @@
from cognee.modules.users.get_fastapi_users import get_fastapi_users
from cognee.modules.users.authentication.get_auth_backend import get_auth_backend
def get_auth_router():
auth_backend = get_auth_backend()
return get_fastapi_users().get_auth_router(auth_backend)

View file

@ -1,5 +1,6 @@
from cognee.modules.users.get_fastapi_users import get_fastapi_users
from cognee.modules.users.models.User import UserRead, UserCreate
def get_register_router():
return get_fastapi_users().get_register_router(UserRead, UserCreate)

View file

@ -1,4 +1,5 @@
from cognee.modules.users.get_fastapi_users import get_fastapi_users
def get_reset_password_router():
return get_fastapi_users().get_reset_password_router()

View file

@ -1,5 +1,6 @@
from cognee.modules.users.get_fastapi_users import get_fastapi_users
from cognee.modules.users.models.User import UserRead, UserUpdate
def get_users_router():
return get_fastapi_users().get_users_router(UserRead, UserUpdate)

View file

@ -1,5 +1,6 @@
from cognee.modules.users.get_fastapi_users import get_fastapi_users
from cognee.modules.users.models.User import UserRead
def get_verify_router():
return get_fastapi_users().get_verify_router(UserRead)

View file

@ -0,0 +1,32 @@
from fastapi import Form, UploadFile, Depends
from fastapi.responses import JSONResponse
from fastapi import APIRouter
from typing import List
import aiohttp
import subprocess
import logging
import os
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
logger = logging.getLogger(__name__)
def get_visualize_router() -> APIRouter:
router = APIRouter()
@router.post("/", response_model=None)
async def visualize(
user: User = Depends(get_authenticated_user),
):
"""This endpoint is responsible for adding data to the graph."""
from cognee.api.v1.visualize import visualize_graph
try:
html_visualization = await visualize_graph()
return html_visualization
except Exception as error:
return JSONResponse(status_code=409, content={"error": str(error)})
return router

View file

@ -0,0 +1 @@
from .visualize import visualize_graph

View file

@ -0,0 +1,14 @@
from cognee.shared.utils import create_cognee_style_network_with_logo, graph_to_tuple
from cognee.infrastructure.databases.graph import get_graph_engine
import logging
async def visualize_graph(label: str = "name"):
""" """
graph_engine = await get_graph_engine()
graph_data = await graph_engine.get_graph_data()
logging.info(graph_data)
graph = await create_cognee_style_network_with_logo(graph_data, label=label)
return graph

View file

@ -5,6 +5,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.root_dir import get_absolute_path
from cognee.shared.data_models import MonitoringTool
class BaseConfig(BaseSettings):
data_root_directory: str = get_absolute_path(".data_storage")
monitoring_tool: object = MonitoringTool.LANGFUSE
@ -13,8 +14,8 @@ class BaseConfig(BaseSettings):
langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")
langfuse_secret_key: Optional[str] = os.getenv("LANGFUSE_SECRET_KEY")
langfuse_host: Optional[str] = os.getenv("LANGFUSE_HOST")
model_config = SettingsConfigDict(env_file=".env", extra="allow")
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
def to_dict(self) -> dict:
return {
@ -22,6 +23,7 @@ class BaseConfig(BaseSettings):
"monitoring_tool": self.monitoring_tool,
}
@lru_cache
def get_base_config():
return BaseConfig()

View file

@ -10,4 +10,4 @@ from .exceptions import (
ServiceError,
InvalidValueError,
InvalidAttributeError,
)
)

View file

@ -3,6 +3,7 @@ import logging
logger = logging.getLogger(__name__)
class CogneeApiError(Exception):
"""Base exception class"""
@ -36,19 +37,19 @@ class ServiceError(CogneeApiError):
class InvalidValueError(CogneeApiError):
def __init__(
self,
message: str = "Invalid Value.",
name: str = "InvalidValueError",
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
self,
message: str = "Invalid Value.",
name: str = "InvalidValueError",
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
):
super().__init__(message, name, status_code)
class InvalidAttributeError(CogneeApiError):
def __init__(
self,
message: str = "Invalid attribute.",
name: str = "InvalidAttributeError",
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
self,
message: str = "Invalid attribute.",
name: str = "InvalidAttributeError",
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
):
super().__init__(message, name, status_code)
super().__init__(message, name, status_code)

View file

@ -15,7 +15,7 @@ sys.path.insert(0, parent_dir)
environment = os.getenv("AWS_ENV", "dev")
def fetch_secret(secret_name:str, region_name:str, env_file_path:str):
def fetch_secret(secret_name: str, region_name: str, env_file_path: str):
"""Fetch the secret from AWS Secrets Manager and write it to the .env file."""
print("Initializing session")
session = boto3.session.Session()

View file

@ -1,4 +1,5 @@
""" Chunking strategies for splitting text into smaller parts."""
"""Chunking strategies for splitting text into smaller parts."""
from __future__ import annotations
import re
from cognee.shared.data_models import ChunkStrategy
@ -6,17 +7,15 @@ from cognee.shared.data_models import ChunkStrategy
# /Users/vasa/Projects/cognee/cognee/infrastructure/data/chunking/DefaultChunkEngine.py
class DefaultChunkEngine():
class DefaultChunkEngine:
def __init__(self, chunk_strategy=None, chunk_size=None, chunk_overlap=None):
self.chunk_strategy = chunk_strategy
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
@staticmethod
def _split_text_with_regex(
text: str, separator: str, keep_separator: bool
) -> list[str]:
def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
@ -32,13 +31,12 @@ class DefaultChunkEngine():
splits = list(text)
return [s for s in splits if s != ""]
def chunk_data(self,
chunk_strategy = None,
source_data = None,
chunk_size = None,
chunk_overlap = None,
def chunk_data(
self,
chunk_strategy=None,
source_data=None,
chunk_size=None,
chunk_overlap=None,
):
"""
Chunk data based on the specified strategy.
@ -54,44 +52,47 @@ class DefaultChunkEngine():
"""
if self.chunk_strategy == ChunkStrategy.PARAGRAPH:
chunked_data, chunk_number = self.chunk_data_by_paragraph(source_data,chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
chunked_data, chunk_number = self.chunk_data_by_paragraph(
source_data, chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
)
elif self.chunk_strategy == ChunkStrategy.SENTENCE:
chunked_data, chunk_number = self.chunk_by_sentence(source_data, chunk_size = self.chunk_size, chunk_overlap=self.chunk_overlap)
chunked_data, chunk_number = self.chunk_by_sentence(
source_data, chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
)
elif self.chunk_strategy == ChunkStrategy.EXACT:
chunked_data, chunk_number = self.chunk_data_exact(source_data, chunk_size = self.chunk_size, chunk_overlap=self.chunk_overlap)
chunked_data, chunk_number = self.chunk_data_exact(
source_data, chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
)
else:
chunked_data, chunk_number = "Invalid chunk strategy.", [0, "Invalid chunk strategy."]
return chunked_data, chunk_number
def chunk_data_exact(self, data_chunks, chunk_size, chunk_overlap):
data = "".join(data_chunks)
chunks = []
for i in range(0, len(data), chunk_size - chunk_overlap):
chunks.append(data[i:i + chunk_size])
chunks.append(data[i : i + chunk_size])
numbered_chunks = []
for i, chunk in enumerate(chunks):
numbered_chunk = [i + 1, chunk]
numbered_chunks.append(numbered_chunk)
return chunks, numbered_chunks
def chunk_by_sentence(self, data_chunks, chunk_size, chunk_overlap):
# Split by periods, question marks, exclamation marks, and ellipses
data = "".join(data_chunks)
    # The regular expression is used to find series of characters that end with one of the following characters (. ! ? ...)
sentence_endings = r'(?<=[.!?…]) +'
sentence_endings = r"(?<=[.!?…]) +"
sentences = re.split(sentence_endings, data)
sentence_chunks = []
for sentence in sentences:
if len(sentence) > chunk_size:
chunks = self.chunk_data_exact(data_chunks=[sentence], chunk_size=chunk_size, chunk_overlap=chunk_overlap)
chunks = self.chunk_data_exact(
data_chunks=[sentence], chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
sentence_chunks.extend(chunks)
else:
sentence_chunks.append(sentence)
@ -102,9 +103,7 @@ class DefaultChunkEngine():
numbered_chunks.append(numbered_chunk)
return sentence_chunks, numbered_chunks
def chunk_data_by_paragraph(self, data_chunks, chunk_size, chunk_overlap, bound = 0.75):
def chunk_data_by_paragraph(self, data_chunks, chunk_size, chunk_overlap, bound=0.75):
data = "".join(data_chunks)
total_length = len(data)
chunks = []

View file

@ -5,7 +5,6 @@ from cognee.infrastructure.data.chunking.DefaultChunkEngine import DefaultChunkE
from cognee.shared.data_models import ChunkStrategy
class LangchainChunkEngine:
def __init__(self, chunk_strategy=None, source_data=None, chunk_size=None, chunk_overlap=None):
self.chunk_strategy = chunk_strategy
@ -13,13 +12,12 @@ class LangchainChunkEngine:
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
def chunk_data(self,
chunk_strategy = None,
source_data = None,
chunk_size = None,
chunk_overlap = None,
def chunk_data(
self,
chunk_strategy=None,
source_data=None,
chunk_size=None,
chunk_overlap=None,
):
"""
Chunk data based on the specified strategy.
@ -35,20 +33,24 @@ class LangchainChunkEngine:
"""
if chunk_strategy == ChunkStrategy.CODE:
chunked_data, chunk_number = self.chunk_data_by_code(source_data,self.chunk_size, self.chunk_overlap)
chunked_data, chunk_number = self.chunk_data_by_code(
source_data, self.chunk_size, self.chunk_overlap
)
elif chunk_strategy == ChunkStrategy.LANGCHAIN_CHARACTER:
chunked_data, chunk_number = self.chunk_data_by_character(source_data,self.chunk_size, self.chunk_overlap)
chunked_data, chunk_number = self.chunk_data_by_character(
source_data, self.chunk_size, self.chunk_overlap
)
else:
chunked_data, chunk_number = "Invalid chunk strategy.", [0, "Invalid chunk strategy."]
chunked_data, chunk_number = "Invalid chunk strategy.", [0, "Invalid chunk strategy."]
return chunked_data, chunk_number
def chunk_data_by_code(self, data_chunks, chunk_size, chunk_overlap= 10, language=None):
def chunk_data_by_code(self, data_chunks, chunk_size, chunk_overlap=10, language=None):
from langchain_text_splitters import (
Language,
RecursiveCharacterTextSplitter,
)
if language is None:
language = Language.PYTHON
python_splitter = RecursiveCharacterTextSplitter.from_language(
@ -67,7 +69,10 @@ class LangchainChunkEngine:
def chunk_data_by_character(self, data_chunks, chunk_size=1500, chunk_overlap=10):
from langchain_text_splitters import RecursiveCharacterTextSplitter
splitter = RecursiveCharacterTextSplitter(chunk_size =chunk_size, chunk_overlap=chunk_overlap)
splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
data_chunks = splitter.create_documents([data_chunks])
only_content = [chunk.page_content for chunk in data_chunks]
@ -78,4 +83,3 @@ class LangchainChunkEngine:
numbered_chunks.append(numbered_chunk)
return only_content, numbered_chunks

View file

@ -11,8 +11,7 @@ class ChunkConfig(BaseSettings):
chunk_strategy: object = ChunkStrategy.PARAGRAPH
chunk_engine: object = ChunkEngine.DEFAULT_ENGINE
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict:
return {

View file

@ -9,9 +9,11 @@ class ChunkingConfig(Dict):
vector_db_key: str
vector_db_provider: str
def create_chunking_engine(config: ChunkingConfig):
if config["chunk_engine"] == ChunkEngine.LANGCHAIN_ENGINE:
from cognee.infrastructure.data.chunking.LangchainChunkingEngine import LangchainChunkEngine
return LangchainChunkEngine(
chunk_size=config["chunk_size"],
chunk_overlap=config["chunk_overlap"],

View file

@ -2,5 +2,6 @@ from .config import get_chunk_config
from .create_chunking_engine import create_chunking_engine
def get_chunk_engine():
return create_chunking_engine(get_chunk_config().to_dict())

View file

@ -3,6 +3,7 @@ from sklearn.feature_extraction.text import TfidfVectorizer
from cognee.exceptions import InvalidValueError
from cognee.shared.utils import extract_pos_tags
def extract_keywords(text: str) -> list[str]:
if len(text) == 0:
raise InvalidValueError(message="extract_keywords cannot extract keywords from empty text.")
@ -14,9 +15,7 @@ def extract_keywords(text: str) -> list[str]:
tfidf = vectorizer.fit_transform(nouns)
top_nouns = sorted(
vectorizer.vocabulary_,
key = lambda x: tfidf[0, vectorizer.vocabulary_[x]],
reverse = True
vectorizer.vocabulary_, key=lambda x: tfidf[0, vectorizer.vocabulary_[x]], reverse=True
)
keywords = []

View file

@ -1,3 +1,4 @@
class EmbeddingException(Exception):
"""Custom exception for handling embedding-related errors."""
pass
pass

View file

@ -7,4 +7,4 @@ This module defines a set of exceptions for handling various database errors
from .exceptions import (
EntityNotFoundError,
EntityAlreadyExistsError,
)
)

View file

@ -1,6 +1,7 @@
from cognee.exceptions import CogneeApiError
from fastapi import status
class EntityNotFoundError(CogneeApiError):
"""Database returns nothing"""
@ -22,4 +23,4 @@ class EntityAlreadyExistsError(CogneeApiError):
name: str = "EntityAlreadyExistsError",
status_code=status.HTTP_409_CONFLICT,
):
super().__init__(message, name, status_code)
super().__init__(message, name, status_code)

View file

@ -1,4 +1,4 @@
""" This module contains the configuration for the graph database. """
"""This module contains the configuration for the graph database."""
import os
from functools import lru_cache
@ -15,8 +15,7 @@ class GraphConfig(BaseSettings):
graph_database_password: str = ""
graph_database_port: int = 123
graph_file_path: str = os.path.join(
os.path.join(get_absolute_path(".cognee_system"), "databases"),
graph_filename
os.path.join(get_absolute_path(".cognee_system"), "databases"), graph_filename
)
graph_model: object = KnowledgeGraph
graph_topology: object = KnowledgeGraph

View file

@ -1,4 +1,5 @@
"""Factory function to get the appropriate graph client based on the graph type."""
from functools import lru_cache
from .config import get_graph_config
@ -26,7 +27,11 @@ def create_graph_engine() -> GraphDBInterface:
config = get_graph_config()
if config.graph_database_provider == "neo4j":
if not (config.graph_database_url and config.graph_database_username and config.graph_database_password):
if not (
config.graph_database_url
and config.graph_database_username
and config.graph_database_password
):
raise EnvironmentError("Missing required Neo4j credentials.")
from .neo4j_driver.adapter import Neo4jAdapter
@ -34,7 +39,7 @@ def create_graph_engine() -> GraphDBInterface:
return Neo4jAdapter(
graph_database_url=config.graph_database_url,
graph_database_username=config.graph_database_username,
graph_database_password=config.graph_database_password
graph_database_password=config.graph_database_password,
)
elif config.graph_database_provider == "falkordb":
@ -53,6 +58,7 @@ def create_graph_engine() -> GraphDBInterface:
)
from .networkx.adapter import NetworkXAdapter
graph_client = NetworkXAdapter(filename=config.graph_file_path)
return graph_client

View file

@ -1,47 +1,35 @@
from typing import Protocol, Optional, Dict, Any
from abc import abstractmethod
class GraphDBInterface(Protocol):
@abstractmethod
async def query(self, query: str, params: dict):
raise NotImplementedError
@abstractmethod
async def add_node(
self,
node_id: str,
node_properties: dict
): raise NotImplementedError
async def add_node(self, node_id: str, node_properties: dict):
raise NotImplementedError
@abstractmethod
async def add_nodes(
self,
nodes: list[tuple[str, dict]]
): raise NotImplementedError
async def add_nodes(self, nodes: list[tuple[str, dict]]):
raise NotImplementedError
@abstractmethod
async def delete_node(
self,
node_id: str
): raise NotImplementedError
async def delete_node(self, node_id: str):
raise NotImplementedError
@abstractmethod
async def delete_nodes(
self,
node_ids: list[str]
): raise NotImplementedError
async def delete_nodes(self, node_ids: list[str]):
raise NotImplementedError
@abstractmethod
async def extract_node(
self,
node_id: str
): raise NotImplementedError
async def extract_node(self, node_id: str):
raise NotImplementedError
@abstractmethod
async def extract_nodes(
self,
node_ids: list[str]
): raise NotImplementedError
async def extract_nodes(self, node_ids: list[str]):
raise NotImplementedError
@abstractmethod
async def add_edge(
@ -49,21 +37,20 @@ class GraphDBInterface(Protocol):
from_node: str,
to_node: str,
relationship_name: str,
edge_properties: Optional[Dict[str, Any]] = None
): raise NotImplementedError
edge_properties: Optional[Dict[str, Any]] = None,
):
raise NotImplementedError
@abstractmethod
async def add_edges(
self,
edges: tuple[str, str, str, dict]
): raise NotImplementedError
async def add_edges(self, edges: tuple[str, str, str, dict]):
raise NotImplementedError
@abstractmethod
async def delete_graph(
self,
): raise NotImplementedError
):
raise NotImplementedError
@abstractmethod
async def get_graph_data(
self
): raise NotImplementedError
async def get_graph_data(self):
raise NotImplementedError

View file

@ -1,4 +1,5 @@
""" Neo4j Adapter for Graph Database"""
"""Neo4j Adapter for Graph Database"""
import logging
import asyncio
from textwrap import dedent
@ -13,6 +14,7 @@ from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInte
logger = logging.getLogger("Neo4jAdapter")
class Neo4jAdapter(GraphDBInterface):
def __init__(
self,
@ -23,8 +25,8 @@ class Neo4jAdapter(GraphDBInterface):
):
self.driver = driver or AsyncGraphDatabase.driver(
graph_database_url,
auth = (graph_database_username, graph_database_password),
max_connection_lifetime = 120
auth=(graph_database_username, graph_database_password),
max_connection_lifetime=120,
)
@asynccontextmanager
@ -39,11 +41,11 @@ class Neo4jAdapter(GraphDBInterface):
) -> List[Dict[str, Any]]:
try:
async with self.get_session() as session:
result = await session.run(query, parameters = params)
result = await session.run(query, parameters=params)
data = await result.data()
return data
except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info = True)
logger.error("Neo4j query error: %s", error, exc_info=True)
raise error
async def has_node(self, node_id: str) -> bool:
@ -53,7 +55,7 @@ class Neo4jAdapter(GraphDBInterface):
WHERE n.id = $node_id
RETURN COUNT(n) > 0 AS node_exists
""",
{"node_id": node_id}
{"node_id": node_id},
)
return results[0]["node_exists"] if len(results) > 0 else False
@ -83,15 +85,17 @@ class Neo4jAdapter(GraphDBInterface):
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
"""
nodes = [{
"node_id": str(node.id),
"properties": self.serialize_properties(node.model_dump()),
} for node in nodes]
nodes = [
{
"node_id": str(node.id),
"properties": self.serialize_properties(node.model_dump()),
}
for node in nodes
]
results = await self.query(query, dict(nodes = nodes))
results = await self.query(query, dict(nodes=nodes))
return results
async def extract_node(self, node_id: str):
results = await self.extract_nodes([node_id])
@ -103,9 +107,7 @@ class Neo4jAdapter(GraphDBInterface):
MATCH (node {id: id})
RETURN node"""
params = {
"node_ids": node_ids
}
params = {"node_ids": node_ids}
results = await self.query(query, params)
@ -115,7 +117,7 @@ class Neo4jAdapter(GraphDBInterface):
node_id = id.replace(":", "_")
query = f"MATCH (node:`{node_id}` {{id: $node_id}}) DETACH DELETE n"
params = { "node_id": node_id }
params = {"node_id": node_id}
return await self.query(query, params)
@ -125,9 +127,7 @@ class Neo4jAdapter(GraphDBInterface):
MATCH (node {id: id})
DETACH DELETE node"""
params = {
"node_ids": node_ids
}
params = {"node_ids": node_ids}
return await self.query(query, params)
@ -157,21 +157,29 @@ class Neo4jAdapter(GraphDBInterface):
try:
params = {
"edges": [{
"from_node": str(edge[0]),
"to_node": str(edge[1]),
"relationship_name": edge[2],
} for edge in edges],
"edges": [
{
"from_node": str(edge[0]),
"to_node": str(edge[1]),
"relationship_name": edge[2],
}
for edge in edges
],
}
results = await self.query(query, params)
return [result["edge_exists"] for result in results]
except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info = True)
logger.error("Neo4j query error: %s", error, exc_info=True)
raise error
async def add_edge(self, from_node: UUID, to_node: UUID, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
async def add_edge(
self,
from_node: UUID,
to_node: UUID,
relationship_name: str,
edge_properties: Optional[Dict[str, Any]] = {},
):
serialized_properties = self.serialize_properties(edge_properties)
query = dedent("""MATCH (from_node {id: $from_node}),
@ -186,12 +194,11 @@ class Neo4jAdapter(GraphDBInterface):
"from_node": str(from_node),
"to_node": str(to_node),
"relationship_name": relationship_name,
"properties": serialized_properties
"properties": serialized_properties,
}
return await self.query(query, params)
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
query = """
UNWIND $edges AS edge
@ -201,22 +208,25 @@ class Neo4jAdapter(GraphDBInterface):
RETURN rel
"""
edges = [{
"from_node": str(edge[0]),
"to_node": str(edge[1]),
"relationship_name": edge[2],
"properties": {
**(edge[3] if edge[3] else {}),
"source_node_id": str(edge[0]),
"target_node_id": str(edge[1]),
},
} for edge in edges]
edges = [
{
"from_node": str(edge[0]),
"to_node": str(edge[1]),
"relationship_name": edge[2],
"properties": {
**(edge[3] if edge[3] else {}),
"source_node_id": str(edge[0]),
"target_node_id": str(edge[1]),
},
}
for edge in edges
]
try:
results = await self.query(query, dict(edges = edges))
results = await self.query(query, dict(edges=edges))
return results
except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info = True)
logger.error("Neo4j query error: %s", error, exc_info=True)
raise error
async def get_edges(self, node_id: str):
@ -225,9 +235,12 @@ class Neo4jAdapter(GraphDBInterface):
RETURN n, r, m
"""
results = await self.query(query, dict(node_id = node_id))
results = await self.query(query, dict(node_id=node_id))
return [(result["n"]["id"], result["m"]["id"], {"relationship_name": result["r"][1]}) for result in results]
return [
(result["n"]["id"], result["m"]["id"], {"relationship_name": result["r"][1]})
for result in results
]
async def get_disconnected_nodes(self) -> list[str]:
# return await self.query(
@ -267,7 +280,6 @@ class Neo4jAdapter(GraphDBInterface):
results = await self.query(query)
return results[0]["ids"] if len(results) > 0 else []
async def get_predecessors(self, node_id: str, edge_label: str = None) -> list[str]:
if edge_label is not None:
query = """
@ -279,9 +291,9 @@ class Neo4jAdapter(GraphDBInterface):
results = await self.query(
query,
dict(
node_id = node_id,
edge_label = edge_label,
)
node_id=node_id,
edge_label=edge_label,
),
)
return [result["predecessor"] for result in results]
@ -295,8 +307,8 @@ class Neo4jAdapter(GraphDBInterface):
results = await self.query(
query,
dict(
node_id = node_id,
)
node_id=node_id,
),
)
return [result["predecessor"] for result in results]
@ -312,8 +324,8 @@ class Neo4jAdapter(GraphDBInterface):
results = await self.query(
query,
dict(
node_id = node_id,
edge_label = edge_label,
node_id=node_id,
edge_label=edge_label,
),
)
@ -328,14 +340,16 @@ class Neo4jAdapter(GraphDBInterface):
results = await self.query(
query,
dict(
node_id = node_id,
)
node_id=node_id,
),
)
return [result["successor"] for result in results]
async def get_neighbours(self, node_id: str) -> List[Dict[str, Any]]:
predecessors, successors = await asyncio.gather(self.get_predecessors(node_id), self.get_successors(node_id))
predecessors, successors = await asyncio.gather(
self.get_predecessors(node_id), self.get_successors(node_id)
)
return predecessors + successors
@ -352,52 +366,55 @@ class Neo4jAdapter(GraphDBInterface):
"""
predecessors, successors = await asyncio.gather(
self.query(predecessors_query, dict(node_id = str(node_id))),
self.query(successors_query, dict(node_id = str(node_id))),
self.query(predecessors_query, dict(node_id=str(node_id))),
self.query(successors_query, dict(node_id=str(node_id))),
)
connections = []
for neighbour in predecessors:
neighbour = neighbour["relation"]
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
connections.append((neighbour[0], {"relationship_name": neighbour[1]}, neighbour[2]))
for neighbour in successors:
neighbour = neighbour["relation"]
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
connections.append((neighbour[0], {"relationship_name": neighbour[1]}, neighbour[2]))
return connections
async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str) -> None:
async def remove_connection_to_predecessors_of(
self, node_ids: list[str], edge_label: str
) -> None:
query = f"""
UNWIND $node_ids AS id
MATCH (node:`{id}`)-[r:{edge_label}]->(predecessor)
DELETE r;
"""
params = { "node_ids": node_ids }
params = {"node_ids": node_ids}
return await self.query(query, params)
async def remove_connection_to_successors_of(self, node_ids: list[str], edge_label: str) -> None:
async def remove_connection_to_successors_of(
self, node_ids: list[str], edge_label: str
) -> None:
query = f"""
UNWIND $node_ids AS id
MATCH (node:`{id}`)<-[r:{edge_label}]-(successor)
DELETE r;
"""
params = { "node_ids": node_ids }
params = {"node_ids": node_ids}
return await self.query(query, params)
async def delete_graph(self):
query = """MATCH (node)
DETACH DELETE node;"""
return await self.query(query)
def serialize_properties(self, properties = dict()):
def serialize_properties(self, properties=dict()):
serialized_properties = {}
for property_key, property_value in properties.items():
@ -414,22 +431,28 @@ class Neo4jAdapter(GraphDBInterface):
result = await self.query(query)
nodes = [(
record["properties"]["id"],
record["properties"],
) for record in result]
nodes = [
(
record["properties"]["id"],
record["properties"],
)
for record in result
]
query = """
MATCH (n)-[r]->(m)
RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
"""
result = await self.query(query)
edges = [(
record["properties"]["source_node_id"],
record["properties"]["target_node_id"],
record["type"],
record["properties"],
) for record in result]
edges = [
(
record["properties"]["source_node_id"],
record["properties"]["target_node_id"],
record["type"],
record["properties"],
)
for record in result
]
return (nodes, edges)
@ -446,7 +469,9 @@ class Neo4jAdapter(GraphDBInterface):
"""
where_clauses = []
for attribute, values in attribute_filters[0].items():
values_str = ", ".join(f"'{value}'" if isinstance(value, str) else str(value) for value in values)
values_str = ", ".join(
f"'{value}'" if isinstance(value, str) else str(value) for value in values
)
where_clauses.append(f"n.{attribute} IN [{values_str}]")
where_clause = " AND ".join(where_clauses)
@ -458,10 +483,13 @@ class Neo4jAdapter(GraphDBInterface):
"""
result_nodes = await self.query(query_nodes)
nodes = [(
record["id"],
record["properties"],
) for record in result_nodes]
nodes = [
(
record["id"],
record["properties"],
)
for record in result_nodes
]
query_edges = f"""
MATCH (n)-[r]->(m)
@ -470,11 +498,14 @@ class Neo4jAdapter(GraphDBInterface):
"""
result_edges = await self.query(query_edges)
edges = [(
record["source"],
record["target"],
record["type"],
record["properties"],
) for record in result_edges]
edges = [
(
record["source"],
record["target"],
record["type"],
record["properties"],
)
for record in result_edges
]
return (nodes, edges)
return (nodes, edges)

View file

@ -17,9 +17,10 @@ from cognee.modules.storage.utils import JSONEncoder
logger = logging.getLogger("NetworkXAdapter")
class NetworkXAdapter(GraphDBInterface):
_instance = None
graph = None # Class variable to store the singleton instance
graph = None # Class variable to store the singleton instance
def __new__(cls, filename):
if cls._instance is None:
@ -27,12 +28,12 @@ class NetworkXAdapter(GraphDBInterface):
cls._instance.filename = filename
return cls._instance
def __init__(self, filename = "cognee_graph.pkl"):
def __init__(self, filename="cognee_graph.pkl"):
self.filename = filename
async def get_graph_data(self):
await self.load_graph_from_file()
return (list(self.graph.nodes(data = True)), list(self.graph.edges(data = True, keys = True)))
return (list(self.graph.nodes(data=True)), list(self.graph.edges(data=True, keys=True)))
async def query(self, query: str, params: dict):
pass
@ -57,24 +58,21 @@ class NetworkXAdapter(GraphDBInterface):
self.graph.add_nodes_from(nodes)
await self.save_graph_to_file(self.filename)
async def get_graph(self):
return self.graph
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
return self.graph.has_edge(from_node, to_node, key = edge_label)
return self.graph.has_edge(from_node, to_node, key=edge_label)
async def has_edges(self, edges):
result = []
for (from_node, to_node, edge_label) in edges:
for from_node, to_node, edge_label in edges:
if self.graph.has_edge(from_node, to_node, edge_label):
result.append((from_node, to_node, edge_label))
return result
async def add_edge(
self,
from_node: str,
@ -83,24 +81,38 @@ class NetworkXAdapter(GraphDBInterface):
edge_properties: Dict[str, Any] = {},
) -> None:
edge_properties["updated_at"] = datetime.now(timezone.utc)
self.graph.add_edge(from_node, to_node, key = relationship_name, **(edge_properties if edge_properties else {}))
self.graph.add_edge(
from_node,
to_node,
key=relationship_name,
**(edge_properties if edge_properties else {}),
)
await self.save_graph_to_file(self.filename)
async def add_edges(
self,
edges: tuple[str, str, str, dict],
) -> None:
edges = [(edge[0], edge[1], edge[2], {
**(edge[3] if len(edge) == 4 else {}),
"updated_at": datetime.now(timezone.utc),
}) for edge in edges]
edges = [
(
edge[0],
edge[1],
edge[2],
{
**(edge[3] if len(edge) == 4 else {}),
"updated_at": datetime.now(timezone.utc),
},
)
for edge in edges
]
self.graph.add_edges_from(edges)
await self.save_graph_to_file(self.filename)
async def get_edges(self, node_id: str):
return list(self.graph.in_edges(node_id, data = True)) + list(self.graph.out_edges(node_id, data = True))
return list(self.graph.in_edges(node_id, data=True)) + list(
self.graph.out_edges(node_id, data=True)
)
async def delete_node(self, node_id: str) -> None:
"""Asynchronously delete a node from the graph if it exists."""
@ -112,12 +124,11 @@ class NetworkXAdapter(GraphDBInterface):
self.graph.remove_nodes_from(node_ids)
await self.save_graph_to_file(self.filename)
async def get_disconnected_nodes(self) -> List[str]:
connected_components = list(nx.weakly_connected_components(self.graph))
disconnected_nodes = []
biggest_subgraph = max(connected_components, key = len)
biggest_subgraph = max(connected_components, key=len)
for component in connected_components:
if component != biggest_subgraph:
@ -125,7 +136,6 @@ class NetworkXAdapter(GraphDBInterface):
return disconnected_nodes
async def extract_node(self, node_id: str) -> dict:
if self.graph.has_node(node_id):
return self.graph.nodes[node_id]
@ -139,8 +149,8 @@ class NetworkXAdapter(GraphDBInterface):
if self.graph.has_node(node_id):
if edge_label is None:
return [
self.graph.nodes[predecessor] for predecessor \
in list(self.graph.predecessors(node_id))
self.graph.nodes[predecessor]
for predecessor in list(self.graph.predecessors(node_id))
]
nodes = []
@ -155,8 +165,8 @@ class NetworkXAdapter(GraphDBInterface):
if self.graph.has_node(node_id):
if edge_label is None:
return [
self.graph.nodes[successor] for successor \
in list(self.graph.successors(node_id))
self.graph.nodes[successor]
for successor in list(self.graph.successors(node_id))
]
nodes = []
@ -210,7 +220,9 @@ class NetworkXAdapter(GraphDBInterface):
return connections
async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str) -> None:
async def remove_connection_to_predecessors_of(
self, node_ids: list[str], edge_label: str
) -> None:
for node_id in node_ids:
if self.graph.has_node(node_id):
for predecessor_id in list(self.graph.predecessors(node_id)):
@ -219,7 +231,9 @@ class NetworkXAdapter(GraphDBInterface):
await self.save_graph_to_file(self.filename)
async def remove_connection_to_successors_of(self, node_ids: list[str], edge_label: str) -> None:
async def remove_connection_to_successors_of(
self, node_ids: list[str], edge_label: str
) -> None:
for node_id in node_ids:
if self.graph.has_node(node_id):
for successor_id in list(self.graph.successors(node_id)):
@ -228,7 +242,7 @@ class NetworkXAdapter(GraphDBInterface):
await self.save_graph_to_file(self.filename)
async def save_graph_to_file(self, file_path: str=None) -> None:
async def save_graph_to_file(self, file_path: str = None) -> None:
"""Asynchronously save the graph to a file in JSON format."""
if not file_path:
file_path = self.filename
@ -236,8 +250,7 @@ class NetworkXAdapter(GraphDBInterface):
graph_data = nx.readwrite.json_graph.node_link_data(self.graph)
async with aiofiles.open(file_path, "w") as file:
await file.write(json.dumps(graph_data, cls = JSONEncoder))
await file.write(json.dumps(graph_data, cls=JSONEncoder))
async def load_graph_from_file(self, file_path: str = None):
"""Asynchronously load the graph from a file in JSON format."""
@ -252,50 +265,59 @@ class NetworkXAdapter(GraphDBInterface):
graph_data = json.loads(await file.read())
for node in graph_data["nodes"]:
try:
node["id"] = UUID(node["id"])
except:
pass
node["id"] = UUID(node["id"])
except Exception as e:
print(e)
pass
if "updated_at" in node:
node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
node["updated_at"] = datetime.strptime(
node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z"
)
for edge in graph_data["links"]:
try:
source_id = UUID(edge["source"])
target_id = UUID(edge["target"])
source_id = UUID(edge["source"])
target_id = UUID(edge["target"])
edge["source"] = source_id
edge["target"] = target_id
edge["source_node_id"] = source_id
edge["target_node_id"] = target_id
except:
pass
edge["source"] = source_id
edge["target"] = target_id
edge["source_node_id"] = source_id
edge["target_node_id"] = target_id
except Exception as e:
print(e)
pass
if "updated_at" in edge:
edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
edge["updated_at"] = datetime.strptime(
edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z"
)
self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
for node_id, node_data in self.graph.nodes(data=True):
node_data['id'] = node_id
node_data["id"] = node_id
else:
# Log that the file does not exist and an empty graph is initialized
logger.warning("File %s not found. Initializing an empty graph.", file_path)
self.graph = nx.MultiDiGraph() # Use MultiDiGraph to keep it consistent with __init__
self.graph = (
nx.MultiDiGraph()
) # Use MultiDiGraph to keep it consistent with __init__
file_dir = os.path.dirname(file_path)
if not os.path.exists(file_dir):
os.makedirs(file_dir, exist_ok = True)
os.makedirs(file_dir, exist_ok=True)
await self.save_graph_to_file(file_path)
except Exception:
logger.error("Failed to load graph from file: %s", file_path)
async def delete_graph(self, file_path: str = None):
"""Asynchronously delete the graph file from the filesystem."""
if file_path is None:
file_path = self.filename # Assuming self.filename is defined elsewhere and holds the default graph file path
file_path = (
self.filename
) # Assuming self.filename is defined elsewhere and holds the default graph file path
try:
if os.path.exists(file_path):
await aiofiles_os.remove(file_path)
@ -305,7 +327,9 @@ class NetworkXAdapter(GraphDBInterface):
except Exception as error:
logger.error("Failed to delete graph: %s", error)
async def get_filtered_graph_data(self, attribute_filters: List[Dict[str, List[Union[str, int]]]]):
async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
):
"""
Fetches nodes and relationships filtered by specified attribute values.
@ -325,18 +349,21 @@ class NetworkXAdapter(GraphDBInterface):
# Filter nodes
filtered_nodes = [
(node, data) for node, data in self.graph.nodes(data=True)
(node, data)
for node, data in self.graph.nodes(data=True)
if all(data.get(attr) in values for attr, values in where_clauses)
]
# Filter edges where both source and target nodes satisfy the filters
filtered_edges = [
(source, target, data.get('relationship_type', 'UNKNOWN'), data)
(source, target, data.get("relationship_type", "UNKNOWN"), data)
for source, target, data in self.graph.edges(data=True)
if (
all(self.graph.nodes[source].get(attr) in values for attr, values in where_clauses) and
all(self.graph.nodes[target].get(attr) in values for attr, values in where_clauses)
all(self.graph.nodes[source].get(attr) in values for attr, values in where_clauses)
and all(
self.graph.nodes[target].get(attr) in values for attr, values in where_clauses
)
)
]
return filtered_nodes, filtered_edges
return filtered_nodes, filtered_edges

View file

@ -1,38 +1,36 @@
import asyncio
# from datetime import datetime
import json
from textwrap import dedent
from uuid import UUID
from webbrowser import Error
from falkordb import FalkorDB
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.databases.graph.graph_db_interface import \
GraphDBInterface
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine
from cognee.infrastructure.databases.vector.vector_db_interface import \
VectorDBInterface
from cognee.infrastructure.databases.vector.vector_db_interface import VectorDBInterface
from cognee.infrastructure.engine import DataPoint
class IndexSchema(DataPoint):
text: str
_metadata: dict = {
"index_fields": ["text"],
"type": "IndexSchema"
}
_metadata: dict = {"index_fields": ["text"], "type": "IndexSchema"}
class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
def __init__(
self,
database_url: str,
database_port: int,
embedding_engine = EmbeddingEngine,
embedding_engine=EmbeddingEngine,
):
self.driver = FalkorDB(
host = database_url,
port = database_port,
host=database_url,
port=database_port,
)
self.embedding_engine = embedding_engine
self.graph_name = "cognee_graph"
@ -56,7 +54,11 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
return f"'{str(value)}'"
if type(value) is int or type(value) is float:
return value
if type(value) is list and type(value[0]) is float and len(value) == self.embedding_engine.get_vector_size():
if (
type(value) is list
and type(value[0]) is float
and len(value) == self.embedding_engine.get_vector_size()
):
return f"'vecf32({value})'"
# if type(value) is datetime:
# return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f%z")
@ -70,14 +72,21 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
node_label = type(data_point).__tablename__
property_names = DataPoint.get_embeddable_property_names(data_point)
node_properties = await self.stringify_properties({
**data_point.model_dump(),
**({
property_names[index]: (vectorized_values[index] \
if index < len(vectorized_values) else getattr(data_point, property_name, None)) \
node_properties = await self.stringify_properties(
{
**data_point.model_dump(),
**(
{
property_names[index]: (
vectorized_values[index]
if index < len(vectorized_values)
else getattr(data_point, property_name, None)
)
for index, property_name in enumerate(property_names)
}),
})
}
),
}
)
return dedent(f"""
MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
@ -129,12 +138,13 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
await self.create_data_point_query(
data_point,
[
vectorized_values[vector_map[str(data_point.id)][property_name]] \
if vector_map[str(data_point.id)][property_name] is not None \
else None \
vectorized_values[vector_map[str(data_point.id)][property_name]]
if vector_map[str(data_point.id)][property_name] is not None
else None
for property_name in DataPoint.get_embeddable_property_names(data_point)
],
) for data_point in data_points
)
for data_point in data_points
]
for query in queries:
@ -144,17 +154,27 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
graph = self.driver.select_graph(self.graph_name)
if not self.has_vector_index(graph, index_name, index_property_name):
graph.create_node_vector_index(index_name, index_property_name, dim = self.embedding_engine.get_vector_size())
graph.create_node_vector_index(
index_name, index_property_name, dim=self.embedding_engine.get_vector_size()
)
def has_vector_index(self, graph, index_name: str, index_property_name: str) -> bool:
try:
indices = graph.list_indices()
return any([(index[0] == index_name and index_property_name in index[1]) for index in indices.result_set])
except:
return any(
[
(index[0] == index_name and index_property_name in index[1])
for index in indices.result_set
]
)
except Error as e:
print(e)
return False
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
):
pass
async def add_node(self, node: DataPoint):
@ -183,11 +203,14 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
""").strip()
params = {
"edges": [{
"from_node": str(edge[0]),
"to_node": str(edge[1]),
"relationship_name": edge[2],
} for edge in edges],
"edges": [
{
"from_node": str(edge[0]),
"to_node": str(edge[1]),
"relationship_name": edge[2],
}
for edge in edges
],
}
results = self.query(query, params).result_set
@ -196,7 +219,7 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
async def retrieve(self, data_point_ids: list[UUID]):
result = self.query(
f"MATCH (node) WHERE node.id IN $node_ids RETURN node",
"MATCH (node) WHERE node.id IN $node_ids RETURN node",
{
"node_ids": [str(data_point) for data_point in data_point_ids],
},
@ -224,19 +247,19 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
"""
predecessors, successors = await asyncio.gather(
self.query(predecessors_query, dict(node_id = node_id)),
self.query(successors_query, dict(node_id = node_id)),
self.query(predecessors_query, dict(node_id=node_id)),
self.query(successors_query, dict(node_id=node_id)),
)
connections = []
for neighbour in predecessors:
neighbour = neighbour["relation"]
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
connections.append((neighbour[0], {"relationship_name": neighbour[1]}, neighbour[2]))
for neighbour in successors:
neighbour = neighbour["relation"]
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
connections.append((neighbour[0], {"relationship_name": neighbour[1]}, neighbour[2]))
return connections
@ -279,12 +302,15 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
query_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
*[self.search(
collection_name = collection_name,
query_vector = query_vector,
limit = limit,
with_vector = with_vectors,
) for query_vector in query_vectors]
*[
self.search(
collection_name=collection_name,
query_vector=query_vector,
limit=limit,
with_vector=with_vectors,
)
for query_vector in query_vectors
]
)
async def get_graph_data(self):
@ -292,28 +318,34 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
result = self.query(query)
nodes = [(
record[2]["id"],
record[2],
) for record in result.result_set]
nodes = [
(
record[2]["id"],
record[2],
)
for record in result.result_set
]
query = """
MATCH (n)-[r]->(m)
RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
"""
result = self.query(query)
edges = [(
record[3]["source_node_id"],
record[3]["target_node_id"],
record[2],
record[3],
) for record in result.result_set]
edges = [
(
record[3]["source_node_id"],
record[3]["target_node_id"],
record[2],
record[3],
)
for record in result.result_set
]
return (nodes, edges)
async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
return self.query(
f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
{
"node_ids": [str(data_point) for data_point in data_point_ids],
},

View file

@ -1,4 +1,5 @@
from sqlalchemy.orm import DeclarativeBase
class Base(DeclarativeBase):
pass

View file

@ -4,16 +4,17 @@ from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.root_dir import get_absolute_path
class RelationalConfig(BaseSettings):
db_path: str = os.path.join(get_absolute_path(".cognee_system"), "databases")
db_name: str = "cognee_db"
db_host: Union[str, None] = None # "localhost"
db_port: Union[str, None] = None # "5432"
db_username: Union[str, None] = None # "cognee"
db_password: Union[str, None] = None # "cognee"
db_path: str = os.path.join(get_absolute_path(".cognee_system"), "databases")
db_name: str = "cognee_db"
db_host: Union[str, None] = None # "localhost"
db_port: Union[str, None] = None # "5432"
db_username: Union[str, None] = None # "cognee"
db_password: Union[str, None] = None # "cognee"
db_provider: str = "sqlite"
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict:
return {
@ -26,6 +27,7 @@ class RelationalConfig(BaseSettings):
"db_provider": self.db_provider,
}
@lru_cache
def get_relational_config():
return RelationalConfig()

View file

@ -2,6 +2,7 @@ from cognee.infrastructure.files.storage import LocalStorage
from .ModelBase import Base
from .get_relational_engine import get_relational_engine, get_relational_config
async def create_db_and_tables():
relational_config = get_relational_config()
relational_engine = get_relational_engine()

View file

@ -1,5 +1,6 @@
from .sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
def create_relational_engine(
db_path: str,
db_name: str,
@ -13,6 +14,8 @@ def create_relational_engine(
connection_string = f"sqlite+aiosqlite:///{db_path}/{db_name}"
if db_provider == "postgres":
connection_string = f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
connection_string = (
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
)
return SQLAlchemyAdapter(connection_string)

View file

@ -3,8 +3,9 @@
from .config import get_relational_config
from .create_relational_engine import create_relational_engine
# @lru_cache
def get_relational_engine():
relational_config = get_relational_config()
return create_relational_engine(**relational_config.to_dict())
return create_relational_engine(**relational_config.to_dict())

View file

@ -1,6 +1,6 @@
import os
from os import path
import logging
import logging
from uuid import UUID
from typing import Optional
from typing import AsyncGenerator, List
@ -18,7 +18,8 @@ from ..ModelBase import Base
logger = logging.getLogger(__name__)
class SQLAlchemyAdapter():
class SQLAlchemyAdapter:
def __init__(self, connection_string: str):
self.db_path: str = None
self.db_uri: str = connection_string
@ -58,17 +59,23 @@ class SQLAlchemyAdapter():
fields_query_parts = [f"{item['name']} {item['type']}" for item in table_config]
async with self.engine.begin() as connection:
await connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name};"))
await connection.execute(text(f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({', '.join(fields_query_parts)});"))
await connection.execute(
text(
f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({', '.join(fields_query_parts)});"
)
)
await connection.close()
async def delete_table(self, table_name: str, schema_name: Optional[str] = "public"):
async def delete_table(self, table_name: str, schema_name: Optional[str] = "public"):
async with self.engine.begin() as connection:
if self.engine.dialect.name == "sqlite":
# SQLite doesnt support schema namespaces and the CASCADE keyword.
# However, foreign key constraint can be defined with ON DELETE CASCADE during table creation.
await connection.execute(text(f"DROP TABLE IF EXISTS {table_name};"))
else:
await connection.execute(text(f"DROP TABLE IF EXISTS {schema_name}.{table_name} CASCADE;"))
await connection.execute(
text(f"DROP TABLE IF EXISTS {schema_name}.{table_name} CASCADE;")
)
async def insert_data(self, schema_name: str, table_name: str, data: list[dict]):
columns = ", ".join(data[0].keys())
@ -94,7 +101,9 @@ class SQLAlchemyAdapter():
return [schema[0] for schema in result.fetchall()]
return []
async def delete_entity_by_id(self, table_name: str, data_id: UUID, schema_name: Optional[str] = "public"):
async def delete_entity_by_id(
self, table_name: str, data_id: UUID, schema_name: Optional[str] = "public"
):
"""
Delete entity in given table based on id. Table must have an id Column.
"""
@ -114,7 +123,6 @@ class SQLAlchemyAdapter():
await session.execute(TableModel.delete().where(TableModel.c.id == data_id))
await session.commit()
async def delete_data_entity(self, data_id: UUID):
"""
Delete data and local files related to data if there are no references to it anymore.
@ -131,14 +139,19 @@ class SQLAlchemyAdapter():
raise EntityNotFoundError(message=f"Entity not found: {str(e)}")
# Check if other data objects point to the same raw data location
raw_data_location_entities = (await session.execute(
select(Data.raw_data_location).where(Data.raw_data_location == data_entity.raw_data_location))).all()
raw_data_location_entities = (
await session.execute(
select(Data.raw_data_location).where(
Data.raw_data_location == data_entity.raw_data_location
)
)
).all()
# Don't delete local file unless this is the only reference to the file in the database
if len(raw_data_location_entities) == 1:
# delete local file only if it's created by cognee
from cognee.base_config import get_base_config
config = get_base_config()
if config.data_root_directory in raw_data_location_entities[0].raw_data_location:
@ -198,15 +211,18 @@ class SQLAlchemyAdapter():
metadata.clear()
return table_names
async def get_data(self, table_name: str, filters: dict = None):
async with self.engine.begin() as connection:
query = f"SELECT * FROM {table_name}"
if filters:
filter_conditions = " AND ".join([
f"{key} IN ({', '.join([f':{key}{i}' for i in range(len(value))])})" if isinstance(value, list)
else f"{key} = :{key}" for key, value in filters.items()
])
filter_conditions = " AND ".join(
[
f"{key} IN ({', '.join([f':{key}{i}' for i in range(len(value))])})"
if isinstance(value, list)
else f"{key} = :{key}"
for key, value in filters.items()
]
)
query += f" WHERE {filter_conditions};"
query = text(query)
results = await connection.execute(query, filters)
@ -252,7 +268,6 @@ class SQLAlchemyAdapter():
except Exception as e:
print(f"Error dropping database tables: {e}")
async def create_database(self):
if self.engine.dialect.name == "sqlite":
from cognee.infrastructure.files.storage import LocalStorage
@ -264,7 +279,6 @@ class SQLAlchemyAdapter():
if len(Base.metadata.tables.keys()) > 0:
await connection.run_sync(Base.metadata.create_all)
async def delete_database(self):
try:
if self.engine.dialect.name == "sqlite":
@ -281,7 +295,9 @@ class SQLAlchemyAdapter():
# Load the schema information into the MetaData object
await connection.run_sync(metadata.reflect, schema=schema_name)
for table in metadata.sorted_tables:
drop_table_query = text(f"DROP TABLE IF EXISTS {schema_name}.{table.name} CASCADE")
drop_table_query = text(
f"DROP TABLE IF EXISTS {schema_name}.{table.name} CASCADE"
)
await connection.execute(drop_table_query)
metadata.clear()
except Exception as e:

View file

@ -3,16 +3,16 @@ from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.root_dir import get_absolute_path
class VectorConfig(BaseSettings):
vector_db_url: str = os.path.join(
os.path.join(get_absolute_path(".cognee_system"), "databases"),
"cognee.lancedb"
os.path.join(get_absolute_path(".cognee_system"), "databases"), "cognee.lancedb"
)
vector_db_port: int = 1234
vector_db_key: str = ""
vector_db_provider: str = "lancedb"
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict:
return {
@ -22,6 +22,7 @@ class VectorConfig(BaseSettings):
"vector_db_provider": self.vector_db_provider,
}
@lru_cache
def get_vectordb_config():
return VectorConfig()

View file

@ -16,9 +16,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
raise EnvironmentError("Missing requred Weaviate credentials!")
return WeaviateAdapter(
config["vector_db_url"],
config["vector_db_key"],
embedding_engine=embedding_engine
config["vector_db_url"], config["vector_db_key"], embedding_engine=embedding_engine
)
elif config["vector_db_provider"] == "qdrant":
@ -30,10 +28,10 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
return QDrantAdapter(
url=config["vector_db_url"],
api_key=config["vector_db_key"],
embedding_engine=embedding_engine
embedding_engine=embedding_engine,
)
elif config['vector_db_provider'] == 'milvus':
elif config["vector_db_provider"] == "milvus":
from .milvus.MilvusAdapter import MilvusAdapter
if not config["vector_db_url"]:
@ -41,11 +39,10 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
return MilvusAdapter(
url=config["vector_db_url"],
api_key=config['vector_db_key'],
embedding_engine=embedding_engine
api_key=config["vector_db_key"],
embedding_engine=embedding_engine,
)
elif config["vector_db_provider"] == "pgvector":
from cognee.infrastructure.databases.relational import get_relational_config

View file

@ -1,5 +1,6 @@
from typing import Protocol
class EmbeddingEngine(Protocol):
async def embed_text(self, text: list[str]) -> list[list[float]]:
raise NotImplementedError()

View file

@ -43,14 +43,12 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
async def embed_text(self, text: List[str]) -> List[List[float]]:
async def exponential_backoff(attempt):
wait_time = min(10 * (2 ** attempt), 60) # Max 60 seconds
wait_time = min(10 * (2**attempt), 60) # Max 60 seconds
await asyncio.sleep(wait_time)
try:
if self.mock:
response = {
"data": [{"embedding": [0.0] * self.dimensions} for _ in text]
}
response = {"data": [{"embedding": [0.0] * self.dimensions} for _ in text]}
self.retry_count = 0
@ -61,7 +59,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
input=text,
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version
api_version=self.api_version,
)
self.retry_count = 0
@ -73,7 +71,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
if len(text) == 1:
parts = [text]
else:
parts = [text[0:math.ceil(len(text) / 2)], text[math.ceil(len(text) / 2):]]
parts = [text[0 : math.ceil(len(text) / 2)], text[math.ceil(len(text) / 2) :]]
parts_futures = [self.embed_text(part) for part in parts]
embeddings = await asyncio.gather(*parts_futures)
@ -89,7 +87,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
except litellm.exceptions.RateLimitError:
if self.retry_count >= self.MAX_RETRIES:
raise Exception(f"Rate limit exceeded and no more retries left.")
raise Exception("Rate limit exceeded and no more retries left.")
await exponential_backoff(self.retry_count)

Some files were not shown because too many files have changed in this diff Show more