Merge remote-tracking branch 'origin/dev'
This commit is contained in:
commit
a01d153971
161 changed files with 7350 additions and 5293 deletions
|
|
@ -1,12 +1,28 @@
|
|||
ENV="local"
|
||||
TOKENIZERS_PARALLELISM="false"
|
||||
LLM_API_KEY=
|
||||
|
||||
# LLM Configuration
|
||||
LLM_API_KEY=""
|
||||
LLM_MODEL="openai/gpt-4o-mini"
|
||||
LLM_PROVIDER="openai"
|
||||
LLM_ENDPOINT=""
|
||||
LLM_API_VERSION=""
|
||||
LLM_MAX_TOKENS="16384"
|
||||
|
||||
GRAPHISTRY_USERNAME=
|
||||
GRAPHISTRY_PASSWORD=
|
||||
|
||||
SENTRY_REPORTING_URL=
|
||||
|
||||
# Embedding Configuration
|
||||
EMBEDDING_PROVIDER="openai"
|
||||
EMBEDDING_API_KEY=""
|
||||
EMBEDDING_MODEL="openai/text-embedding-3-large"
|
||||
EMBEDDING_ENDPOINT=""
|
||||
EMBEDDING_API_VERSION=""
|
||||
EMBEDDING_DIMENSIONS=3072
|
||||
EMBEDDING_MAX_TOKENS=8191
|
||||
|
||||
# "neo4j" or "networkx"
|
||||
GRAPH_DATABASE_PROVIDER="networkx"
|
||||
# Not needed if using networkx
|
||||
|
|
|
|||
7
.github/pull_request_template.md
vendored
Normal file
7
.github/pull_request_template.md
vendored
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
<!-- .github/pull_request_template.md -->
|
||||
|
||||
## Description
|
||||
<!-- Provide a clear description of the changes in this PR -->
|
||||
|
||||
## DCO Affirmation
|
||||
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin
|
||||
53
.github/workflows/approve_dco.yaml
vendored
Normal file
53
.github/workflows/approve_dco.yaml
vendored
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
name: DCO Check
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, reopened, synchronize, ready_for_review]
|
||||
|
||||
jobs:
|
||||
check-dco:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Validate Developer Certificate of Origin statement
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
# If using the built-in GITHUB_TOKEN, ensure it has 'read:org' permission.
|
||||
# In GitHub Enterprise or private orgs, you might need a PAT (personal access token) with read:org scope.
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const orgName = 'YOUR_ORGANIZATION_NAME'; // Replace with your org
|
||||
const prUser = context.payload.pull_request.user.login;
|
||||
const prBody = context.payload.pull_request.body || '';
|
||||
|
||||
// Exact text you require in the PR body
|
||||
const requiredStatement = "I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin";
|
||||
|
||||
// 1. Check if user is in the org
|
||||
let isOrgMember = false;
|
||||
try {
|
||||
// Attempt to get membership info
|
||||
const membership = await github.rest.orgs.getMembershipForUser({
|
||||
org: orgName,
|
||||
username: prUser,
|
||||
});
|
||||
// If we get here without an error, user is in the org
|
||||
isOrgMember = true;
|
||||
console.log(`${prUser} is a member of ${orgName}. Skipping DCO check.`);
|
||||
} catch (error) {
|
||||
// If we get a 404, user is NOT an org member
|
||||
if (error.status === 404) {
|
||||
console.log(`${prUser} is NOT a member of ${orgName}. Enforcing DCO check.`);
|
||||
} else {
|
||||
// Some other error—fail the workflow or handle accordingly
|
||||
core.setFailed(`Error checking organization membership: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. If user is not in the org, enforce the DCO statement
|
||||
if (!isOrgMember) {
|
||||
if (!prBody.includes(requiredStatement)) {
|
||||
core.setFailed(
|
||||
`DCO check failed. The PR body must include the following statement:\n\n${requiredStatement}`
|
||||
);
|
||||
}
|
||||
}
|
||||
1
.github/workflows/cd.yaml
vendored
1
.github/workflows/cd.yaml
vendored
|
|
@ -4,7 +4,6 @@ on:
|
|||
push:
|
||||
branches:
|
||||
- dev
|
||||
- feature/*
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'examples/**'
|
||||
|
|
|
|||
24
.github/workflows/clean_stale_pr.yaml
vendored
Normal file
24
.github/workflows/clean_stale_pr.yaml
vendored
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
name: clean | remove stale PRs
|
||||
|
||||
on:
|
||||
# Run this action periodically (daily at 0:00 UTC in this example).
|
||||
schedule:
|
||||
- cron: "0 0 * * *"
|
||||
# Optionally, also run when pull requests are labeled, unlabeled, synchronized, or reopened
|
||||
# to update the stale timer as needed. Uncomment if desired.
|
||||
# pull_request:
|
||||
# types: [labeled, unlabeled, synchronize, reopened]
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Mark and Close Stale
|
||||
uses: actions/stale@v6
|
||||
with:
|
||||
# Number of days of inactivity before the pull request is marked stale
|
||||
days-before-stale: 60
|
||||
# Number of days of inactivity after being marked stale before the pull request is closed
|
||||
days-before-close: 7
|
||||
# Comment to post when marking as stale
|
||||
stale-pr-message: "This pull request has been automatically marke
|
||||
67
.github/workflows/dockerhub.yml
vendored
67
.github/workflows/dockerhub.yml
vendored
|
|
@ -1,8 +1,9 @@
|
|||
name: build | Build and Push Docker Image to DockerHub
|
||||
name: build | Build and Push Docker Image to dockerhub
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- dev
|
||||
- main
|
||||
|
||||
jobs:
|
||||
|
|
@ -10,42 +11,38 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||
|
||||
- name: Extract Git information
|
||||
id: git-info
|
||||
run: |
|
||||
echo "BRANCH_NAME=${GITHUB_REF_NAME}" >> "$GITHUB_ENV"
|
||||
echo "COMMIT_SHA=${GITHUB_SHA::7}" >> "$GITHUB_ENV"
|
||||
- name: Extract metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: cognee/cognee
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=sha,prefix={{branch}}-
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
|
||||
- name: Build and Push Docker Image
|
||||
run: |
|
||||
IMAGE_NAME=cognee/cognee
|
||||
TAG_VERSION="${BRANCH_NAME}-${COMMIT_SHA}"
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=registry,ref=cognee/cognee:buildcache
|
||||
cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max
|
||||
|
||||
echo "Building image: ${IMAGE_NAME}:${TAG_VERSION}"
|
||||
docker buildx build \
|
||||
--platform linux/amd64,linux/arm64 \
|
||||
--push \
|
||||
--tag "${IMAGE_NAME}:${TAG_VERSION}" \
|
||||
--tag "${IMAGE_NAME}:latest" \
|
||||
.
|
||||
|
||||
- name: Verify pushed Docker images
|
||||
run: |
|
||||
# Verify both platform variants
|
||||
for PLATFORM in "linux/amd64" "linux/arm64"; do
|
||||
echo "Verifying image for $PLATFORM..."
|
||||
docker buildx imagetools inspect "${IMAGE_NAME}:${TAG_VERSION}" --format "{{.Manifest.$PLATFORM.Digest}}"
|
||||
done
|
||||
echo "Successfully verified images in Docker Hub"
|
||||
- name: Image digest
|
||||
run: echo ${{ steps.build.outputs.digest }}
|
||||
40
.github/workflows/profiling.yaml
vendored
40
.github/workflows/profiling.yaml
vendored
|
|
@ -68,32 +68,32 @@ jobs:
|
|||
echo "HEAD_SHA=${{ github.event.pull_request.head.sha }}" >> $GITHUB_ENV
|
||||
|
||||
# Run profiler on the base branch
|
||||
- name: Run profiler on base branch
|
||||
env:
|
||||
BASE_SHA: ${{ env.BASE_SHA }}
|
||||
run: |
|
||||
echo "Profiling the base branch for code_graph_pipeline.py"
|
||||
echo "Checking out base SHA: $BASE_SHA"
|
||||
git checkout $BASE_SHA
|
||||
echo "This is the working directory: $PWD"
|
||||
# Ensure the script is executable
|
||||
chmod +x cognee/api/v1/cognify/code_graph_pipeline.py
|
||||
# Run pyinstrument
|
||||
poetry run pyinstrument --renderer json -o base_results.json cognee/api/v1/cognify/code_graph_pipeline.py
|
||||
|
||||
# Run profiler on head branch
|
||||
# - name: Run profiler on head branch
|
||||
# - name: Run profiler on base branch
|
||||
# env:
|
||||
# HEAD_SHA: ${{ env.HEAD_SHA }}
|
||||
# BASE_SHA: ${{ env.BASE_SHA }}
|
||||
# run: |
|
||||
# echo "Profiling the head branch for code_graph_pipeline.py"
|
||||
# echo "Checking out head SHA: $HEAD_SHA"
|
||||
# git checkout $HEAD_SHA
|
||||
# echo "Profiling the base branch for code_graph_pipeline.py"
|
||||
# echo "Checking out base SHA: $BASE_SHA"
|
||||
# git checkout $BASE_SHA
|
||||
# echo "This is the working directory: $PWD"
|
||||
# # Ensure the script is executable
|
||||
# chmod +x cognee/api/v1/cognify/code_graph_pipeline.py
|
||||
# # Run pyinstrument
|
||||
# poetry run pyinstrument --renderer json -o head_results.json cognee/api/v1/cognify/code_graph_pipeline.py
|
||||
# poetry run pyinstrument --renderer json -o base_results.json cognee/api/v1/cognify/code_graph_pipeline.py
|
||||
|
||||
# Run profiler on head branch
|
||||
- name: Run profiler on head branch
|
||||
env:
|
||||
HEAD_SHA: ${{ env.HEAD_SHA }}
|
||||
run: |
|
||||
echo "Profiling the head branch for code_graph_pipeline.py"
|
||||
echo "Checking out head SHA: $HEAD_SHA"
|
||||
git checkout $HEAD_SHA
|
||||
echo "This is the working directory: $PWD"
|
||||
# Ensure the script is executable
|
||||
chmod +x cognee/api/v1/cognify/code_graph_pipeline.py
|
||||
# Run pyinstrument
|
||||
poetry run pyinstrument --renderer json -o head_results.json cognee/api/v1/cognify/code_graph_pipeline.py
|
||||
|
||||
# # Compare profiling results
|
||||
# - name: Compare profiling results
|
||||
|
|
|
|||
25
.github/workflows/reusable_notebook.yml
vendored
25
.github/workflows/reusable_notebook.yml
vendored
|
|
@ -12,8 +12,24 @@ on:
|
|||
required: true
|
||||
GRAPHISTRY_PASSWORD:
|
||||
required: true
|
||||
#LLM_MODEL:
|
||||
# required: true
|
||||
#LLM_ENDPOINT:
|
||||
# required: true
|
||||
LLM_API_KEY:
|
||||
required: true
|
||||
OPENAI_API_KEY:
|
||||
required: true
|
||||
#LLM_API_VERSION:
|
||||
# required: true
|
||||
EMBEDDING_MODEL:
|
||||
required: true
|
||||
EMBEDDING_ENDPOINT:
|
||||
required: true
|
||||
EMBEDDING_API_KEY:
|
||||
required: true
|
||||
EMBEDDING_API_VERSION:
|
||||
required: true
|
||||
|
||||
env:
|
||||
RUNTIME__LOG_LEVEL: ERROR
|
||||
|
|
@ -50,8 +66,15 @@ jobs:
|
|||
- name: Execute Jupyter Notebook
|
||||
env:
|
||||
ENV: 'dev'
|
||||
#LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
#LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} # Use OpenAI Until a multimedia model is deployed and DeepEval support for other models is added
|
||||
#LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||
run: |
|
||||
|
|
|
|||
32
.github/workflows/reusable_python_example.yml
vendored
32
.github/workflows/reusable_python_example.yml
vendored
|
|
@ -7,12 +7,32 @@ on:
|
|||
description: "Location of example script to run"
|
||||
required: true
|
||||
type: string
|
||||
arguments:
|
||||
description: "Arguments for example script"
|
||||
required: false
|
||||
type: string
|
||||
secrets:
|
||||
GRAPHISTRY_USERNAME:
|
||||
required: true
|
||||
GRAPHISTRY_PASSWORD:
|
||||
required: true
|
||||
LLM_MODEL:
|
||||
required: true
|
||||
LLM_ENDPOINT:
|
||||
required: true
|
||||
LLM_API_KEY:
|
||||
required: true
|
||||
OPENAI_API_KEY:
|
||||
required: false
|
||||
LLM_API_VERSION:
|
||||
required: true
|
||||
EMBEDDING_MODEL:
|
||||
required: true
|
||||
EMBEDDING_ENDPOINT:
|
||||
required: true
|
||||
EMBEDDING_API_KEY:
|
||||
required: true
|
||||
EMBEDDING_API_VERSION:
|
||||
required: true
|
||||
|
||||
env:
|
||||
|
|
@ -50,7 +70,15 @@ jobs:
|
|||
env:
|
||||
ENV: 'dev'
|
||||
PYTHONFAULTHANDLER: 1
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||
run: poetry run python ${{ inputs.example-location }}
|
||||
run: poetry run python ${{ inputs.example-location }} ${{ inputs.arguments }}
|
||||
|
|
|
|||
29
.github/workflows/test_code_graph_example.yml
vendored
Normal file
29
.github/workflows/test_code_graph_example.yml
vendored
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
name: test | code graph example
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
types: [labeled, synchronize]
|
||||
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
run_simple_example_test:
|
||||
uses: ./.github/workflows/reusable_python_example.yml
|
||||
with:
|
||||
example-location: ./examples/python/code_graph_example.py
|
||||
arguments: "--repo_path ./evals"
|
||||
secrets:
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||
|
|
@ -15,6 +15,14 @@ jobs:
|
|||
with:
|
||||
notebook-location: notebooks/cognee_llama_index.ipynb
|
||||
secrets:
|
||||
#LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
#LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
#LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,14 @@ jobs:
|
|||
with:
|
||||
notebook-location: notebooks/cognee_multimedia_demo.ipynb
|
||||
secrets:
|
||||
#LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
#LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
#LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||
|
|
|
|||
9
.github/workflows/test_deduplication.yml
vendored
9
.github/workflows/test_deduplication.yml
vendored
|
|
@ -57,5 +57,12 @@ jobs:
|
|||
- name: Run deduplication test
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: poetry run python ./cognee/tests/test_deduplication.py
|
||||
|
|
|
|||
|
|
@ -16,6 +16,13 @@ jobs:
|
|||
with:
|
||||
example-location: ./examples/python/dynamic_steps_example.py
|
||||
secrets:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||
|
|
|
|||
49
.github/workflows/test_dynamic_steps_example_windows.yml
vendored
Normal file
49
.github/workflows/test_dynamic_steps_example_windows.yml
vendored
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
name: test
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
types: [labeled, synchronize]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
run_notebook_test_windows:
|
||||
name: windows-latest
|
||||
runs-on: windows-latest
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Check out
|
||||
uses: actions/checkout@master
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install poetry
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
poetry install --no-interaction --all-extras
|
||||
|
||||
- name: Execute Python Example
|
||||
env:
|
||||
ENV: 'dev'
|
||||
PYTHONFAULTHANDLER: 1
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: poetry run python ./examples/python/dynamic_steps_example.py
|
||||
|
|
@ -15,6 +15,14 @@ jobs:
|
|||
with:
|
||||
notebook-location: notebooks/llama_index_cognee_integration.ipynb
|
||||
secrets:
|
||||
#LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
#LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
#LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||
|
|
|
|||
9
.github/workflows/test_milvus.yml
vendored
9
.github/workflows/test_milvus.yml
vendored
|
|
@ -47,7 +47,14 @@ jobs:
|
|||
- name: Run default basic pipeline
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: poetry run python ./cognee/tests/test_milvus.py
|
||||
|
||||
- name: Clean up disk space
|
||||
|
|
|
|||
|
|
@ -16,6 +16,13 @@ jobs:
|
|||
with:
|
||||
example-location: ./examples/python/multimedia_example.py
|
||||
secrets:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
#LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
#LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} # Use OpenAI until we deploy models to handle multimedia
|
||||
#LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||
|
|
|
|||
30
.github/workflows/test_multimetric_qa_eval_run.yaml
vendored
Normal file
30
.github/workflows/test_multimetric_qa_eval_run.yaml
vendored
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
name: test | multimetric qa eval run
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
types: [labeled, synchronize]
|
||||
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
run_multimetric_qa_eval_test:
|
||||
uses: ./.github/workflows/reusable_python_example.yml
|
||||
with:
|
||||
example-location: ./evals/multimetric_qa_eval_run.py
|
||||
arguments: "--params_file evals/qa_eval_parameters.json --out_dir dirname"
|
||||
secrets:
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} # Until we add support for azure for DeepEval
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||
9
.github/workflows/test_neo4j.yml
vendored
9
.github/workflows/test_neo4j.yml
vendored
|
|
@ -43,7 +43,14 @@ jobs:
|
|||
- name: Run default Neo4j
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
|
||||
GRAPH_DATABASE_USERNAME: "neo4j"
|
||||
|
|
|
|||
8
.github/workflows/test_notebook.yml
vendored
8
.github/workflows/test_notebook.yml
vendored
|
|
@ -16,6 +16,14 @@ jobs:
|
|||
with:
|
||||
notebook-location: notebooks/cognee_demo.ipynb
|
||||
secrets:
|
||||
#LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
#LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
#LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||
|
|
|
|||
9
.github/workflows/test_pgvector.yml
vendored
9
.github/workflows/test_pgvector.yml
vendored
|
|
@ -58,5 +58,12 @@ jobs:
|
|||
- name: Run default PGVector
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: poetry run python ./cognee/tests/test_pgvector.py
|
||||
|
|
|
|||
13
.github/workflows/test_python_3_10.yml
vendored
13
.github/workflows/test_python_3_10.yml
vendored
|
|
@ -42,6 +42,10 @@ jobs:
|
|||
|
||||
- name: Install dependencies
|
||||
run: poetry install --no-interaction -E docs
|
||||
- name: Download NLTK tokenizer data
|
||||
run: |
|
||||
poetry run python -m nltk.downloader punkt_tab averaged_perceptron_tagger_eng
|
||||
|
||||
|
||||
- name: Run unit tests
|
||||
run: poetry run pytest cognee/tests/unit/
|
||||
|
|
@ -52,7 +56,14 @@ jobs:
|
|||
- name: Run default basic pipeline
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: poetry run python ./cognee/tests/test_library.py
|
||||
|
||||
- name: Clean up disk space
|
||||
|
|
|
|||
14
.github/workflows/test_python_3_11.yml
vendored
14
.github/workflows/test_python_3_11.yml
vendored
|
|
@ -44,6 +44,11 @@ jobs:
|
|||
- name: Install dependencies
|
||||
run: poetry install --no-interaction -E docs
|
||||
|
||||
- name: Download NLTK tokenizer data
|
||||
run: |
|
||||
poetry run python -m nltk.downloader punkt_tab averaged_perceptron_tagger_eng
|
||||
|
||||
|
||||
- name: Run unit tests
|
||||
run: poetry run pytest cognee/tests/unit/
|
||||
|
||||
|
|
@ -53,7 +58,14 @@ jobs:
|
|||
- name: Run default basic pipeline
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: poetry run python ./cognee/tests/test_library.py
|
||||
|
||||
- name: Clean up disk space
|
||||
|
|
|
|||
12
.github/workflows/test_python_3_12.yml
vendored
12
.github/workflows/test_python_3_12.yml
vendored
|
|
@ -43,6 +43,9 @@ jobs:
|
|||
|
||||
- name: Install dependencies
|
||||
run: poetry install --no-interaction -E docs
|
||||
- name: Download NLTK tokenizer data
|
||||
run: |
|
||||
poetry run python -m nltk.downloader punkt_tab averaged_perceptron_tagger_eng
|
||||
|
||||
- name: Run unit tests
|
||||
run: poetry run pytest cognee/tests/unit/
|
||||
|
|
@ -53,7 +56,14 @@ jobs:
|
|||
- name: Run default basic pipeline
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: poetry run python ./cognee/tests/test_library.py
|
||||
|
||||
- name: Clean up disk space
|
||||
|
|
|
|||
9
.github/workflows/test_qdrant.yml
vendored
9
.github/workflows/test_qdrant.yml
vendored
|
|
@ -44,7 +44,14 @@ jobs:
|
|||
- name: Run default Qdrant
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
VECTOR_DB_URL: ${{ secrets.QDRANT_API_URL }}
|
||||
VECTOR_DB_KEY: ${{ secrets.QDRANT_API_KEY }}
|
||||
run: poetry run python ./cognee/tests/test_qdrant.py
|
||||
|
|
|
|||
9
.github/workflows/test_simple_example.yml
vendored
9
.github/workflows/test_simple_example.yml
vendored
|
|
@ -16,6 +16,13 @@ jobs:
|
|||
with:
|
||||
example-location: ./examples/python/simple_example.py
|
||||
secrets:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||
|
|
|
|||
9
.github/workflows/test_weaviate.yml
vendored
9
.github/workflows/test_weaviate.yml
vendored
|
|
@ -44,7 +44,14 @@ jobs:
|
|||
- name: Run default Weaviate
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
VECTOR_DB_URL: ${{ secrets.WEAVIATE_API_URL }}
|
||||
VECTOR_DB_KEY: ${{ secrets.WEAVIATE_API_KEY }}
|
||||
run: poetry run python ./cognee/tests/test_weaviate.py
|
||||
|
|
|
|||
|
|
@ -79,6 +79,9 @@ $ git config alias.cos "commit -s"
|
|||
|
||||
Will allow you to write git cos which will automatically sign-off your commit. By signing a commit you are agreeing to the DCO and agree that you will be banned from the topoteretes GitHub organisation and Discord server if you violate the DCO.
|
||||
|
||||
When a commit is ready to be merged, please use the following template to agree to our developer certificate of origin:
|
||||
'I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin'
|
||||
|
||||
We consider the following as violations to the DCO:
|
||||
|
||||
Signing the DCO with a fake name or pseudonym, if you are registered on GitHub or another platform with a fake name then you will not be able to contribute to topoteretes before updating your name;
|
||||
|
|
|
|||
32
Dockerfile_modal
Normal file
32
Dockerfile_modal
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
FROM python:3.11-slim

# Set environment variables
ENV PIP_NO_CACHE_DIR=true
ENV PATH="${PATH}:/root/.poetry/bin"
ENV PYTHONPATH=/app
ENV RUN_MODE=modal
ENV SKIP_MIGRATIONS=true

# System dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    libpq-dev \
    git \
    curl \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# PYTHONPATH and WORKDIR are each declared once; the repeated declarations
# that used to follow here were no-ops.
WORKDIR /app

# Copy only the dependency manifests before installing, so the install step
# does not depend on the application sources copied further down.
COPY pyproject.toml poetry.lock /app/

RUN pip install poetry

RUN poetry install --all-extras --no-root --without dev

# Application sources
COPY cognee/ /app/cognee
COPY README.md /app/README.md
|
||||
71
README.md
71
README.md
|
|
@ -12,7 +12,11 @@ We build for developers who need a reliable, production-ready data layer for AI
|
|||
|
||||
## What is cognee?
|
||||
|
||||
Cognee implements scalable, modular ECL (Extract, Cognify, Load) pipelines that allow you to interconnect and retrieve past conversations, documents, and audio transcriptions while reducing hallucinations, developer effort, and cost.
|
||||
Cognee implements scalable, modular ECL (Extract, Cognify, Load) pipelines that allow you to interconnect and retrieve past conversations, documents, and audio transcriptions while reducing hallucinations, developer effort, and cost.
|
||||
|
||||
Cognee merges graph and vector databases to uncover hidden relationships and new patterns in your data. You can automatically model, load and retrieve entities and objects representing your business domain and analyze their relationships, uncovering insights that neither vector stores nor graph stores alone can provide. Learn more about use-cases [here](https://docs.cognee.ai/use_cases)
|
||||
|
||||
|
||||
Try it in a Google Colab <a href="https://colab.research.google.com/drive/1g-Qnx6l_ecHZi0IOw23rg0qC4TYvEvWZ?usp=sharing">notebook</a> or have a look at our <a href="https://docs.cognee.ai">documentation</a>
|
||||
|
||||
If you have questions, join our <a href="https://discord.gg/NQPKmU5CCg">Discord</a> community
|
||||
|
|
@ -85,7 +89,7 @@ import os
|
|||
os.environ["LLM_API_KEY"] = "YOUR OPENAI_API_KEY"
|
||||
|
||||
```
|
||||
or
|
||||
or
|
||||
```
|
||||
import cognee
|
||||
cognee.config.set_llm_api_key("YOUR_OPENAI_API_KEY")
|
||||
|
|
@ -115,7 +119,7 @@ DB_PORT=5432
|
|||
DB_NAME=cognee_db
|
||||
DB_USERNAME=cognee
|
||||
DB_PASSWORD=cognee
|
||||
```
|
||||
```
|
||||
|
||||
### Simple example
|
||||
|
||||
|
|
@ -140,14 +144,14 @@ async def main():
|
|||
Natural language processing (NLP) is an interdisciplinary
|
||||
subfield of computer science and information retrieval.
|
||||
"""
|
||||
|
||||
|
||||
print("Adding text to cognee:")
|
||||
print(text.strip())
|
||||
print(text.strip())
|
||||
# Add the text, and make it available for cognify
|
||||
await cognee.add(text)
|
||||
print("Text added successfully.\n")
|
||||
|
||||
|
||||
|
||||
print("Running cognify to create knowledge graph...\n")
|
||||
print("Cognify process steps:")
|
||||
print("1. Classifying the document: Determining the type and category of the input text.")
|
||||
|
|
@ -156,19 +160,19 @@ async def main():
|
|||
print("4. Adding data points: Storing the extracted chunks for processing.")
|
||||
print("5. Generating knowledge graph: Extracting entities and relationships to form a knowledge graph.")
|
||||
print("6. Summarizing text: Creating concise summaries of the content for quick insights.\n")
|
||||
|
||||
|
||||
# Use LLMs and cognee to create knowledge graph
|
||||
await cognee.cognify()
|
||||
print("Cognify process complete.\n")
|
||||
|
||||
|
||||
|
||||
query_text = 'Tell me about NLP'
|
||||
print(f"Searching cognee for insights with query: '{query_text}'")
|
||||
# Query cognee for insights on the added text
|
||||
search_results = await cognee.search(
|
||||
SearchType.INSIGHTS, query_text=query_text
|
||||
)
|
||||
|
||||
|
||||
print("Search results:")
|
||||
# Display results
|
||||
for result_text in search_results:
|
||||
|
|
@ -212,15 +216,16 @@ Cognee supports a variety of tools and services for different operations:
|
|||
- **Language Models (LLMs)**: You can use either Anyscale or Ollama as your LLM provider.
|
||||
|
||||
- **Graph Stores**: In addition to NetworkX, Neo4j is also supported for graph storage.
|
||||
|
||||
|
||||
- **User management**: Create individual user graphs and manage permissions
|
||||
|
||||
## Demo
|
||||
|
||||
Check out our demo notebook [here](https://github.com/topoteretes/cognee/blob/main/notebooks/cognee_demo.ipynb)
|
||||
Check out our demo notebook [here](https://github.com/topoteretes/cognee/blob/main/notebooks/cognee_demo.ipynb) or watch the YouTube video below
|
||||
|
||||
|
||||
[<img src="https://i3.ytimg.com/vi/-ARUfIzhzC4/maxresdefault.jpg" width="100%">](https://www.youtube.com/watch?v=BDFt4xVPmro "Learn about cognee: 55")
|
||||
[<img src="https://img.youtube.com/vi/fI4hDzguN5k/maxresdefault.jpg" width="100%">](https://www.youtube.com/watch?v=fI4hDzguN5k "Learn about cognee: 55")
|
||||
|
||||
|
||||
|
||||
## Get Started
|
||||
|
|
@ -241,6 +246,28 @@ Please see the cognee [Development Guide](https://docs.cognee.ai/quickstart/) fo
|
|||
```bash
|
||||
pip install cognee
|
||||
```
|
||||
### Deployment at Scale (Modal)
|
||||
|
||||
Scale cognee in 4(+1) simple steps to handle enterprise workloads using [Modal](https://modal.com)'s GPU-powered infrastructure
|
||||
|
||||
**1. Install the modal python client**
|
||||
```bash
|
||||
pip install modal
|
||||
```
|
||||
**2. Create a free account on [Modal](https://modal.com)**
|
||||
|
||||
**3. Set Up Modal API Key**
|
||||
```bash
|
||||
modal token set --token-id TOKEN_ID --token-secret TOKEN_SECRET --profile=PROFILE
|
||||
modal profile activate PROFILE
|
||||
```
|
||||
**4. Run cognee example**
|
||||
|
||||
This simple example will deploy separate cognee instances building their own memory stores and answering a list of questions at scale.
|
||||
```bash
|
||||
modal run -d modal_deployment.py
|
||||
```
|
||||
**5. Change the `modal_deployment.py` script and develop your own AI memory at scale 🚀**
|
||||
|
||||
## 💫 Contributors
|
||||
|
||||
|
|
@ -258,13 +285,13 @@ pip install cognee
|
|||
|
||||
|
||||
|
||||
| Name | Type | Current state | Known Issues |
|
||||
|----------|--------------------|-------------------|--------------|
|
||||
| Qdrant | Vector | Stable ✅ | |
|
||||
| Weaviate | Vector | Stable ✅ | |
|
||||
| LanceDB | Vector | Stable ✅ | |
|
||||
| Neo4j | Graph | Stable ✅ | |
|
||||
| NetworkX | Graph | Stable ✅ | |
|
||||
| FalkorDB | Vector/Graph | Unstable ❌ | |
|
||||
| PGVector | Vector | Stable ✅ | |
|
||||
| Milvus | Vector | Stable ✅ | |
|
||||
| Name | Type | Current state (Mac/Linux) | Known Issues | Current state (Windows) | Known Issues |
|
||||
|----------|--------------------|---------------------------|--------------|-------------------------|--------------|
|
||||
| Qdrant | Vector | Stable ✅ | | Unstable ❌ | |
|
||||
| Weaviate | Vector | Stable ✅ | | Unstable ❌ | |
|
||||
| LanceDB | Vector | Stable ✅ | | Stable ✅ | |
|
||||
| Neo4j | Graph | Stable ✅ | | Stable ✅ | |
|
||||
| NetworkX | Graph | Stable ✅ | | Stable ✅ | |
|
||||
| FalkorDB | Vector/Graph | Stable ✅ | | Unstable ❌ | |
|
||||
| PGVector | Vector | Stable ✅ | | Unstable ❌ | |
|
||||
| Milvus | Vector | Stable ✅ | | Unstable ❌ | |
|
||||
|
|
|
|||
|
|
@ -1,31 +1,19 @@
|
|||
# cognee MCP server
|
||||
|
||||
|
||||
|
||||
|
||||
### Installing Manually
|
||||
A MCP server project
|
||||
=======
|
||||
1. Clone the [cognee](www.github.com/topoteretes/cognee) repo
|
||||
|
||||
|
||||
1. Clone the [cognee](https://github.com/topoteretes/cognee) repo
|
||||
|
||||
2. Install dependencies
|
||||
|
||||
```
|
||||
pip install uv
|
||||
```
|
||||
```
|
||||
brew install postgresql
|
||||
```
|
||||
|
||||
```
|
||||
brew install rust
|
||||
brew install uv
|
||||
```
|
||||
|
||||
```jsx
|
||||
cd cognee-mcp
|
||||
uv sync --dev --all-extras
|
||||
uv sync --dev --all-extras --reinstall
|
||||
```
|
||||
|
||||
3. Activate the venv with
|
||||
|
|
@ -37,11 +25,17 @@ source .venv/bin/activate
|
|||
4. Add the new server to your Claude config:
|
||||
|
||||
The file should be located here: ~/Library/Application\ Support/Claude/
|
||||
```
|
||||
cd ~/Library/Application\ Support/Claude/
|
||||
```
|
||||
You need to create claude_desktop_config.json in this folder if it doesn't exist
|
||||
|
||||
Make sure to add your paths and LLM API key to the file below
|
||||
Use your editor of choice, for example Nano:
|
||||
```
|
||||
nano claude_desktop_config.json
|
||||
```
|
||||
|
||||
|
||||
```
|
||||
{
|
||||
"mcpServers": {
|
||||
"cognee": {
|
||||
|
|
@ -57,16 +51,7 @@ You need to create claude_desktop_config.json in this folder if it doesn't exist
|
|||
"TOKENIZERS_PARALLELISM": "false",
|
||||
"LLM_API_KEY": "sk-"
|
||||
}
|
||||
},
|
||||
"filesystem": {
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-filesystem",
|
||||
"/Users/{user}/Desktop",
|
||||
"/Users/{user}/Projects"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
|
@ -83,3 +68,19 @@ npx -y @smithery/cli install cognee --client claude
|
|||
|
||||
Define cognify tool in server.py
|
||||
Restart your Claude desktop.
|
||||
|
||||
|
||||
To use debugger, run:
|
||||
```bash
|
||||
mcp dev src/server.py
|
||||
```
|
||||
Open inspector with timeout passed:
|
||||
```
|
||||
http://localhost:5173?timeout=120000
|
||||
```
|
||||
|
||||
To apply new changes while developing cognee you need to do:
|
||||
|
||||
1. `poetry lock` in cognee folder
|
||||
2. `uv sync --dev --all-extras --reinstall`
|
||||
3. `mcp dev src/server.py`
|
||||
|
|
|
|||
|
|
@ -1,15 +0,0 @@
|
|||
import asyncio
|
||||
|
||||
from . import server
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the package."""
|
||||
asyncio.run(server.main())
|
||||
|
||||
|
||||
# Optionally expose other important items at package level
|
||||
__all__ = ["main", "server"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,180 +0,0 @@
|
|||
import importlib.util
|
||||
import os
|
||||
import asyncio
|
||||
from contextlib import redirect_stderr, redirect_stdout
|
||||
|
||||
import cognee
|
||||
import mcp.server.stdio
|
||||
import mcp.types as types
|
||||
from cognee.api.v1.search import SearchType
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from mcp.server import NotificationOptions, Server
|
||||
from mcp.server.models import InitializationOptions
|
||||
|
||||
server = Server("cognee-mcp")
|
||||
|
||||
|
||||
def node_to_string(node):
|
||||
# keys_to_keep = ["chunk_index", "topological_rank", "cut_type", "id", "text"]
|
||||
# keyset = set(keys_to_keep) & node.keys()
|
||||
# return "Node(" + " ".join([key + ": " + str(node[key]) + "," for key in keyset]) + ")"
|
||||
node_data = ", ".join(
|
||||
[f'{key}: "{value}"' for key, value in node.items() if key in ["id", "name"]]
|
||||
)
|
||||
|
||||
return f"Node({node_data})"
|
||||
|
||||
|
||||
def retrieved_edges_to_string(search_results):
|
||||
edge_strings = []
|
||||
for triplet in search_results:
|
||||
node1, edge, node2 = triplet
|
||||
relationship_type = edge["relationship_name"]
|
||||
edge_str = f"{node_to_string(node1)} {relationship_type} {node_to_string(node2)}"
|
||||
edge_strings.append(edge_str)
|
||||
return "\n".join(edge_strings)
|
||||
|
||||
|
||||
def load_class(model_file, model_name):
|
||||
model_file = os.path.abspath(model_file)
|
||||
spec = importlib.util.spec_from_file_location("graph_model", model_file)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
model_class = getattr(module, model_name)
|
||||
|
||||
return model_class
|
||||
|
||||
|
||||
@server.list_tools()
|
||||
async def handle_list_tools() -> list[types.Tool]:
|
||||
"""
|
||||
List available tools.
|
||||
Each tool specifies its arguments using JSON Schema validation.
|
||||
"""
|
||||
return [
|
||||
types.Tool(
|
||||
name="cognify",
|
||||
description="Build knowledge graph from the input text.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string"},
|
||||
"graph_model_file": {"type": "string"},
|
||||
"graph_model_name": {"type": "string"},
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
),
|
||||
types.Tool(
|
||||
name="search",
|
||||
description="Search the knowledge graph.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
types.Tool(
|
||||
name="prune",
|
||||
description="Reset the knowledge graph.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@server.call_tool()
|
||||
async def handle_call_tool(
|
||||
name: str, arguments: dict | None
|
||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
"""
|
||||
Handle tool execution requests.
|
||||
Tools can modify server state and notify clients of changes.
|
||||
"""
|
||||
if name == "cognify":
|
||||
with open(os.devnull, "w") as fnull:
|
||||
with redirect_stdout(fnull), redirect_stderr(fnull):
|
||||
if not arguments:
|
||||
raise ValueError("Missing arguments")
|
||||
|
||||
text = arguments.get("text")
|
||||
|
||||
if ("graph_model_file" in arguments) and ("graph_model_name" in arguments):
|
||||
model_file = arguments.get("graph_model_file")
|
||||
model_name = arguments.get("graph_model_name")
|
||||
|
||||
graph_model = load_class(model_file, model_name)
|
||||
else:
|
||||
graph_model = KnowledgeGraph
|
||||
|
||||
await cognee.add(text)
|
||||
|
||||
await cognee.cognify(graph_model=graph_model)
|
||||
|
||||
return [
|
||||
types.TextContent(
|
||||
type="text",
|
||||
text="Ingested",
|
||||
)
|
||||
]
|
||||
elif name == "search":
|
||||
with open(os.devnull, "w") as fnull:
|
||||
with redirect_stdout(fnull), redirect_stderr(fnull):
|
||||
if not arguments:
|
||||
raise ValueError("Missing arguments")
|
||||
|
||||
search_query = arguments.get("query")
|
||||
|
||||
search_results = await cognee.search(SearchType.INSIGHTS, query_text=search_query)
|
||||
|
||||
results = retrieved_edges_to_string(search_results)
|
||||
|
||||
return [
|
||||
types.TextContent(
|
||||
type="text",
|
||||
text=results,
|
||||
)
|
||||
]
|
||||
elif name == "prune":
|
||||
with open(os.devnull, "w") as fnull:
|
||||
with redirect_stdout(fnull), redirect_stderr(fnull):
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
return [
|
||||
types.TextContent(
|
||||
type="text",
|
||||
text="Pruned",
|
||||
)
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
|
||||
|
||||
async def main():
|
||||
# Run the server using stdin/stdout streams
|
||||
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
|
||||
await server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
InitializationOptions(
|
||||
server_name="cognee-mcp",
|
||||
server_version="0.1.0",
|
||||
capabilities=server.get_capabilities(
|
||||
notification_options=NotificationOptions(),
|
||||
experimental_capabilities={},
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# This is needed if you'd like to connect to a custom client
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -4,74 +4,10 @@ version = "0.1.0"
|
|||
description = "A MCP server project"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
dependencies = [
|
||||
"mcp>=1.1.1",
|
||||
"openai==1.59.4",
|
||||
"pydantic==2.8.2",
|
||||
"python-dotenv==1.0.1",
|
||||
"fastapi>=0.109.2,<0.110.0",
|
||||
"uvicorn==0.22.0",
|
||||
"requests==2.32.3",
|
||||
"aiohttp==3.10.10",
|
||||
"typing_extensions==4.12.2",
|
||||
"nest_asyncio==1.6.0",
|
||||
"numpy==1.26.4",
|
||||
"datasets==3.1.0",
|
||||
"falkordb==1.0.9", # Optional
|
||||
"boto3>=1.26.125,<2.0.0",
|
||||
"botocore>=1.35.54,<2.0.0",
|
||||
"gunicorn>=20.1.0,<21.0.0",
|
||||
"sqlalchemy==2.0.36",
|
||||
"instructor==1.7.2",
|
||||
"networkx>=3.2.1,<4.0.0",
|
||||
"aiosqlite>=0.20.0,<0.21.0",
|
||||
"pandas==2.2.3",
|
||||
"filetype>=1.2.0,<2.0.0",
|
||||
"nltk>=3.8.1,<4.0.0",
|
||||
"dlt[sqlalchemy]>=1.4.1,<2.0.0",
|
||||
"aiofiles>=23.2.1,<24.0.0",
|
||||
"qdrant-client>=1.9.0,<2.0.0", # Optional
|
||||
"graphistry>=0.33.5,<0.34.0",
|
||||
"tenacity>=9.0.0",
|
||||
"weaviate-client==4.6.7", # Optional
|
||||
"scikit-learn>=1.5.0,<2.0.0",
|
||||
"pypdf>=4.1.0,<5.0.0",
|
||||
"neo4j>=5.20.0,<6.0.0", # Optional
|
||||
"jinja2>=3.1.3,<4.0.0",
|
||||
"matplotlib>=3.8.3,<4.0.0",
|
||||
"tiktoken==0.7.0",
|
||||
"langchain_text_splitters==0.3.2", # Optional
|
||||
"langsmith==0.1.139", # Optional
|
||||
"langdetect==1.0.9",
|
||||
"posthog>=3.5.0,<4.0.0", # Optional
|
||||
"lancedb==0.16.0",
|
||||
"litellm==1.57.2",
|
||||
"groq==0.8.0", # Optional
|
||||
"langfuse>=2.32.0,<3.0.0", # Optional
|
||||
"pydantic-settings>=2.2.1,<3.0.0",
|
||||
"anthropic>=0.26.1,<1.0.0",
|
||||
"sentry-sdk[fastapi]>=2.9.0,<3.0.0",
|
||||
"fastapi-users[sqlalchemy]", # Optional
|
||||
"alembic>=1.13.3,<2.0.0",
|
||||
"asyncpg==0.30.0", # Optional
|
||||
"pgvector>=0.3.5,<0.4.0", # Optional
|
||||
"psycopg2>=2.9.10,<3.0.0", # Optional
|
||||
"llama-index-core>=0.12.0", # Optional
|
||||
"deepeval>=2.0.1,<3.0.0", # Optional
|
||||
"transformers>=4.46.3,<5.0.0",
|
||||
"pymilvus>=2.5.0,<3.0.0", # Optional
|
||||
"unstructured[csv,doc,docx,epub,md,odt,org,ppt,pptx,rst,rtf,tsv,xlsx]>=0.16.10,<1.0.0", # Optional
|
||||
"pytest>=7.4.0,<8.0.0",
|
||||
"pytest-asyncio>=0.21.1,<0.22.0",
|
||||
"coverage>=7.3.2,<8.0.0",
|
||||
"mypy>=1.7.1,<2.0.0",
|
||||
"deptry>=0.20.0,<0.21.0",
|
||||
"debugpy==1.8.2",
|
||||
"pylint>=3.0.3,<4.0.0",
|
||||
"ruff>=0.2.2,<0.3.0",
|
||||
"tweepy==4.14.0",
|
||||
"gitpython>=3.1.43,<4.0.0",
|
||||
"cognee",
|
||||
"mcp==1.2.0",
|
||||
]
|
||||
|
||||
[[project.authors]]
|
||||
|
|
@ -79,16 +15,14 @@ name = "Rita Aleksziev"
|
|||
email = "rita@topoteretes.com"
|
||||
|
||||
[build-system]
|
||||
requires = [ "hatchling",]
|
||||
requires = [ "hatchling", ]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src"]
|
||||
|
||||
[tool.uv.sources]
|
||||
cognee = { path = "../../cognee" }
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"cognee",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
cognee = "cognee_mcp:main"
|
||||
cognee = "src:main"
|
||||
|
|
|
|||
8
cognee-mcp/src/__init__.py
Normal file
8
cognee-mcp/src/__init__.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
from .server import main as server_main
|
||||
|
||||
|
||||
def main():
    """Synchronous entry point for the package.

    Drives the server's async ``main`` coroutine to completion with
    ``asyncio.run``.
    """
    # Imported lazily so asyncio is only loaded when the entry point runs.
    from asyncio import run as run_until_complete

    run_until_complete(server_main())
|
||||
50
cognee-mcp/src/client.py
Executable file
50
cognee-mcp/src/client.py
Executable file
|
|
@ -0,0 +1,50 @@
|
|||
from datetime import timedelta
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
# Create server parameters for stdio connection
|
||||
server_params = StdioServerParameters(
|
||||
command="uv", # Executable
|
||||
args=["--directory", ".", "run", "cognee"], # Optional command line arguments
|
||||
env=None, # Optional environment variables
|
||||
)
|
||||
|
||||
text = """
|
||||
Artificial intelligence, or AI, is technology that enables computers
|
||||
and machines to simulate human intelligence and problem-solving
|
||||
capabilities.
|
||||
On its own or combined with other technologies (e.g., sensors,
|
||||
geolocation, robotics) AI can perform tasks that would otherwise
|
||||
require human intelligence or intervention. Digital assistants, GPS
|
||||
guidance, autonomous vehicles, and generative AI tools (like Open
|
||||
AI's Chat GPT) are just a few examples of AI in the daily news and
|
||||
our daily lives.
|
||||
As a field of computer science, artificial intelligence encompasses
|
||||
(and is often mentioned together with) machine learning and deep
|
||||
learning. These disciplines involve the development of AI
|
||||
algorithms, modeled after the decision-making processes of the human
|
||||
brain, that can ‘learn’ from available data and make increasingly
|
||||
more accurate classifications or predictions over time.
|
||||
"""
|
||||
|
||||
|
||||
async def run():
    """Exercise the cognee MCP server end to end over a stdio connection.

    Starts the server described by ``server_params``, initializes a client
    session, then in order: lists the tools, calls ``prune``, calls
    ``cognify`` with the sample ``text``, and calls ``search`` for "AI".
    Only the search result is printed.
    """
    async with stdio_client(server_params) as (read, write):
        # NOTE(review): the third positional argument is presumably the
        # session read timeout — confirm against the ClientSession signature.
        async with ClientSession(read, write, timedelta(minutes=3)) as session:
            await session.initialize()

            # These calls are made for their effect on the server; their
            # results are not used by this script.
            await session.list_tools()
            await session.call_tool("prune", arguments={})
            await session.call_tool("cognify", arguments={"text": text})

            search_result = await session.call_tool("search", arguments={"search_query": "AI"})

            # What is printed is the output of the search tool, so label it so.
            print(f"Search result: {search_result.content}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(run())
|
||||
219
cognee-mcp/src/server.py
Executable file
219
cognee-mcp/src/server.py
Executable file
|
|
@ -0,0 +1,219 @@
|
|||
import os
|
||||
import cognee
|
||||
import logging
|
||||
import importlib.util
|
||||
from contextlib import redirect_stderr, redirect_stdout
|
||||
|
||||
# from PIL import Image as PILImage
|
||||
import mcp.types as types
|
||||
from mcp.server import Server, NotificationOptions
|
||||
from mcp.server.models import InitializationOptions
|
||||
from cognee.api.v1.search import SearchType
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
|
||||
mcp = Server("cognee")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.list_tools()
async def list_tools() -> list[types.Tool]:
    """Describe the tools this server exposes: ``cognify``, ``search`` and ``prune``.

    Each tool declares its arguments as a JSON Schema object.
    """

    def string_property(description: str) -> dict:
        # Every argument these tools accept is a described string.
        return {"type": "string", "description": description}

    cognify_tool = types.Tool(
        name="cognify",
        description="Cognifies text into knowledge graph",
        inputSchema={
            "type": "object",
            "properties": {
                "text": string_property("The text to cognify"),
                "graph_model_file": string_property("The path to the graph model file"),
                "graph_model_name": string_property("The name of the graph model"),
            },
            "required": ["text"],
        },
    )

    search_tool = types.Tool(
        name="search",
        description="Searches for information in knowledge graph",
        inputSchema={
            "type": "object",
            "properties": {
                "search_query": string_property("The query to search for"),
            },
            "required": ["search_query"],
        },
    )

    # prune takes no arguments, hence the empty property map.
    prune_tool = types.Tool(
        name="prune",
        description="Prunes knowledge graph",
        inputSchema={
            "type": "object",
            "properties": {},
        },
    )

    return [cognify_tool, search_tool, prune_tool]
|
||||
|
||||
|
||||
@mcp.call_tool()
async def call_tools(name: str, arguments: dict) -> list[types.TextContent]:
    """Run the tool called *name* and report its outcome as text content.

    Args:
        name: Tool to run; one of ``"cognify"``, ``"search"`` or ``"prune"``.
        arguments: Tool arguments. ``cognify`` reads ``text`` plus the optional
            ``graph_model_file`` / ``graph_model_name``; ``search`` reads
            ``search_query``; ``prune`` reads nothing.

    Returns:
        A one-element list: ``"Ingested"``, the search results, or ``"Pruned"``
        on success. Any failure, including an unknown tool name, is logged and
        returned as an ``"Error calling tool ..."`` message instead of being
        raised.
    """
    try:
        # stdout/stderr are discarded while the tool runs — presumably so stray
        # library output cannot reach the stdio transport; TODO confirm.
        with open(os.devnull, "w") as fnull:
            with redirect_stdout(fnull), redirect_stderr(fnull):
                if name == "cognify":
                    await cognify(
                        text=arguments["text"],
                        graph_model_file=arguments.get("graph_model_file", None),
                        graph_model_name=arguments.get("graph_model_name", None),
                    )

                    return [types.TextContent(type="text", text="Ingested")]
                elif name == "search":
                    search_results = await search(arguments["search_query"])

                    return [types.TextContent(type="text", text=search_results)]
                elif name == "prune":
                    await prune()

                    return [types.TextContent(type="text", text="Pruned")]
                else:
                    # Without this branch an unknown name fell through and the
                    # function returned None instead of a content list.
                    raise ValueError(f"Unknown tool: {name}")
    except Exception as e:
        logger.error(f"Error calling tool '{name}': {str(e)}")
        return [types.TextContent(type="text", text=f"Error calling tool '{name}': {str(e)}")]
|
||||
|
||||
|
||||
async def cognify(
    text: str, graph_model_file: str | None = None, graph_model_name: str | None = None
) -> None:
    """Build knowledge graph from the input text.

    Args:
        text: Text to add to cognee and cognify.
        graph_model_file: Path to a Python file defining a custom graph model.
        graph_model_name: Name of the model inside ``graph_model_file``.
            The custom model is used only when both values are provided;
            otherwise ``KnowledgeGraph`` is used.

    Raises:
        ValueError: if ``cognee.cognify`` fails; the original exception is
            attached as the cause.
    """
    if graph_model_file and graph_model_name:
        graph_model = load_class(graph_model_file, graph_model_name)
    else:
        graph_model = KnowledgeGraph

    await cognee.add(text)

    try:
        await cognee.cognify(graph_model=graph_model)
    except Exception as e:
        # Chain explicitly so the underlying failure stays visible as __cause__.
        raise ValueError(f"Failed to cognify: {str(e)}") from e
|
||||
|
||||
|
||||
async def search(search_query: str) -> str:
    """Run an INSIGHTS search for *search_query* and return the edges as text.

    The results of ``cognee.search`` are formatted by
    ``retrieved_edges_to_string``.
    """
    insight_triplets = await cognee.search(SearchType.INSIGHTS, query_text=search_query)
    return retrieved_edges_to_string(insight_triplets)
|
||||
|
||||
|
||||
async def prune() -> None:
    """Reset the knowledge graph.

    Prunes cognee's data first, then its system state with ``metadata=True``.
    Nothing is returned; the previous ``-> str`` annotation did not match the
    body.
    """
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
|
||||
async def main():
    """Run the ``cognee`` MCP server over the stdio transport.

    Any exception raised while setting up or running the server is logged
    with its traceback and then re-raised.
    """
    try:
        from mcp.server.stdio import stdio_server

        async with stdio_server() as (read_stream, write_stream):
            # Built inside the transport context, as before, so any failure
            # here still unwinds through stdio_server.
            server_capabilities = mcp.get_capabilities(
                notification_options=NotificationOptions(),
                experimental_capabilities={},
            )
            init_options = InitializationOptions(
                server_name="cognee",
                server_version="0.1.0",
                capabilities=server_capabilities,
            )

            await mcp.run(
                read_stream=read_stream,
                write_stream=write_stream,
                initialization_options=init_options,
                raise_exceptions=True,
            )

    except Exception as e:
        logger.error(f"Server failed to start: {str(e)}", exc_info=True)
        raise
|
||||
|
||||
|
||||
# async def visualize() -> Image:
|
||||
# """Visualize the knowledge graph"""
|
||||
# try:
|
||||
# image_path = await cognee.visualize_graph()
|
||||
|
||||
# img = PILImage.open(image_path)
|
||||
# return Image(data=img.tobytes(), format="png")
|
||||
# except (FileNotFoundError, IOError, ValueError) as e:
|
||||
# raise ValueError(f"Failed to create visualization: {str(e)}")
|
||||
|
||||
|
||||
def node_to_string(node):
    """Render a node mapping as ``Node(id: "...", name: "...")``.

    Only the ``id`` and ``name`` entries are shown, in the mapping's own
    iteration order; every other key is ignored.
    """
    shown_fields = []
    for field, field_value in node.items():
        if field == "id" or field == "name":
            shown_fields.append(f'{field}: "{field_value}"')

    return "Node(" + ", ".join(shown_fields) + ")"
|
||||
|
||||
|
||||
def retrieved_edges_to_string(search_results):
    """Format ``(node, edge, node)`` triplets as one line of text per edge.

    Each line reads ``<source node> <relationship_name> <target node>``, with
    both nodes rendered by ``node_to_string``. Lines are joined with newlines.
    """

    def describe(triplet):
        source, link, target = triplet
        # Read the relationship before rendering the nodes, as the original did.
        relation = link["relationship_name"]
        return f"{node_to_string(source)} {relation} {node_to_string(target)}"

    return "\n".join([describe(triplet) for triplet in search_results])
|
||||
|
||||
|
||||
def load_class(model_file, model_name):
    """Execute the Python file at *model_file* and return its *model_name* attribute.

    The file is loaded as a module named ``graph_model``.

    Args:
        model_file: Path to a Python source file; resolved to an absolute path.
        model_name: Name of the attribute (typically a class) to fetch from it.

    Returns:
        The object bound to *model_name* in the loaded module.

    Raises:
        ImportError: if no import spec or loader can be created for the file.
        AttributeError: if the module does not define *model_name*.
    """
    # NOTE(security): this runs whatever code the file contains. In this server
    # the path comes from tool-call arguments, so it must be trusted input.
    model_file = os.path.abspath(model_file)

    spec = importlib.util.spec_from_file_location("graph_model", model_file)
    # spec_from_file_location returns None when no loader matches the file,
    # which otherwise surfaces as an opaque AttributeError further down.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load graph model from file: {model_file}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    return getattr(module, model_name)
|
||||
|
||||
|
||||
# def get_freshest_png(directory: str) -> Image:
|
||||
# if not os.path.exists(directory):
|
||||
# raise FileNotFoundError(f"Directory {directory} does not exist")
|
||||
|
||||
# # List all files in 'directory' that end with .png
|
||||
# files = [f for f in os.listdir(directory) if f.endswith(".png")]
|
||||
# if not files:
|
||||
# raise FileNotFoundError("No PNG files found in the given directory.")
|
||||
|
||||
# # Sort by integer value of the filename (minus the '.png')
|
||||
# # Example filename: 1673185134.png -> integer 1673185134
|
||||
# try:
|
||||
# files_sorted = sorted(files, key=lambda x: int(x.replace(".png", "")))
|
||||
# except ValueError as e:
|
||||
# raise ValueError("Invalid PNG filename format. Expected timestamp format.") from e
|
||||
|
||||
# # The "freshest" file has the largest timestamp
|
||||
# freshest_filename = files_sorted[-1]
|
||||
# freshest_path = os.path.join(directory, freshest_filename)
|
||||
|
||||
# # Open the image with PIL and return the PIL Image object
|
||||
# try:
|
||||
# return PILImage.open(freshest_path)
|
||||
# except (IOError, OSError) as e:
|
||||
# raise IOError(f"Failed to open PNG file {freshest_path}") from e
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Initialize and run the server
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
3060
cognee-mcp/uv.lock
generated
3060
cognee-mcp/uv.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -4,7 +4,7 @@ from .api.v1.config.config import config
|
|||
from .api.v1.datasets.datasets import datasets
|
||||
from .api.v1.prune import prune
|
||||
from .api.v1.search import SearchType, get_search_history, search
|
||||
from .api.v1.visualize import visualize
|
||||
from .api.v1.visualize import visualize_graph
|
||||
from .shared.utils import create_cognee_style_network_with_logo
|
||||
|
||||
# Pipelines
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
from cognee.api.v1.permissions.routers import get_permissions_router
|
||||
from cognee.api.v1.settings.routers import get_settings_router
|
||||
from cognee.api.v1.datasets.routers import get_datasets_router
|
||||
from cognee.api.v1.cognify.routers import get_cognify_router
|
||||
from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router
|
||||
from cognee.api.v1.search.routers import get_search_router
|
||||
from cognee.api.v1.add.routers import get_add_router
|
||||
from fastapi import Request
|
||||
|
|
@ -169,6 +169,10 @@ app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["sett
|
|||
|
||||
app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"])
|
||||
|
||||
app.include_router(
|
||||
get_code_pipeline_router(), prefix="/api/v1/code-pipeline", tags=["code-pipeline"]
|
||||
)
|
||||
|
||||
|
||||
def start_api_server(host: str = "0.0.0.0", port: int = 8000):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import Union, BinaryIO
|
|||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.pipelines import run_tasks, Task
|
||||
from cognee.tasks.ingestion import ingest_data_with_metadata, resolve_data_directories
|
||||
from cognee.tasks.ingestion import ingest_data, resolve_data_directories
|
||||
from cognee.infrastructure.databases.relational import (
|
||||
create_db_and_tables as create_relational_db_and_tables,
|
||||
)
|
||||
|
|
@ -22,7 +22,7 @@ async def add(
|
|||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
tasks = [Task(resolve_data_directories), Task(ingest_data_with_metadata, dataset_name, user)]
|
||||
tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user)]
|
||||
|
||||
pipeline = run_tasks(tasks, data, "add_pipeline")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.modules.cognify.config import get_cognify_config
|
||||
|
|
@ -10,7 +9,7 @@ from cognee.modules.users.methods import get_default_user
|
|||
from cognee.shared.data_models import KnowledgeGraph, MonitoringTool
|
||||
from cognee.tasks.documents import classify_documents, extract_chunks_from_documents
|
||||
from cognee.tasks.graph import extract_graph_from_data
|
||||
from cognee.tasks.ingestion import ingest_data_with_metadata
|
||||
from cognee.tasks.ingestion import ingest_data
|
||||
from cognee.tasks.repo_processor import (
|
||||
enrich_dependency_graph,
|
||||
expand_dependency_graph,
|
||||
|
|
@ -21,6 +20,7 @@ from cognee.tasks.repo_processor import (
|
|||
from cognee.tasks.repo_processor.get_source_code_chunks import get_source_code_chunks
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.tasks.summarization import summarize_code, summarize_text
|
||||
from cognee.infrastructure.llm import get_max_chunk_tokens
|
||||
|
||||
monitoring = get_base_config().monitoring_tool
|
||||
if monitoring == MonitoringTool.LANGFUSE:
|
||||
|
|
@ -33,22 +33,9 @@ update_status_lock = asyncio.Lock()
|
|||
|
||||
@observe
|
||||
async def run_code_graph_pipeline(repo_path, include_docs=True):
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
import cognee
|
||||
from cognee.infrastructure.databases.relational import create_db_and_tables
|
||||
|
||||
file_path = Path(__file__).parent
|
||||
data_directory_path = str(
|
||||
pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve()
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
cognee_directory_path = str(
|
||||
pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve()
|
||||
)
|
||||
cognee.config.system_root_directory(cognee_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await create_db_and_tables()
|
||||
|
|
@ -68,10 +55,10 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
|
|||
if include_docs:
|
||||
non_code_tasks = [
|
||||
Task(get_non_py_files, task_config={"batch_size": 50}),
|
||||
Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user),
|
||||
Task(ingest_data, dataset_name="repo_docs", user=user),
|
||||
Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
|
||||
Task(classify_documents),
|
||||
Task(extract_chunks_from_documents, max_tokens=cognee_config.max_tokens),
|
||||
Task(extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()),
|
||||
Task(
|
||||
extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
|
||||
),
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Union
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.infrastructure.llm import get_max_chunk_tokens
|
||||
from cognee.modules.cognify.config import get_cognify_config
|
||||
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
|
||||
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
||||
|
|
@ -24,6 +25,7 @@ from cognee.tasks.documents import (
|
|||
)
|
||||
from cognee.tasks.graph import extract_graph_from_data
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.tasks.storage.descriptive_metrics import store_descriptive_metrics
|
||||
from cognee.tasks.storage.index_graph_edges import index_graph_edges
|
||||
from cognee.tasks.summarization import summarize_text
|
||||
|
||||
|
|
@ -36,6 +38,7 @@ async def cognify(
|
|||
datasets: Union[str, list[str]] = None,
|
||||
user: User = None,
|
||||
graph_model: BaseModel = KnowledgeGraph,
|
||||
tasks: list[Task] = None,
|
||||
):
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
|
@ -55,18 +58,19 @@ async def cognify(
|
|||
|
||||
awaitables = []
|
||||
|
||||
if tasks is None:
|
||||
tasks = await get_default_tasks(user, graph_model)
|
||||
|
||||
for dataset in datasets:
|
||||
dataset_name = generate_dataset_name(dataset.name)
|
||||
|
||||
if dataset_name in existing_datasets_map:
|
||||
awaitables.append(run_cognify_pipeline(dataset, user, graph_model))
|
||||
awaitables.append(run_cognify_pipeline(dataset, user, tasks))
|
||||
|
||||
return await asyncio.gather(*awaitables)
|
||||
|
||||
|
||||
async def run_cognify_pipeline(
|
||||
dataset: Dataset, user: User, graph_model: BaseModel = KnowledgeGraph
|
||||
):
|
||||
async def run_cognify_pipeline(dataset: Dataset, user: User, tasks: list[Task]):
|
||||
data_documents: list[Data] = await get_dataset_data(dataset_id=dataset.id)
|
||||
|
||||
document_ids_str = [str(document.id) for document in data_documents]
|
||||
|
|
@ -96,22 +100,12 @@ async def run_cognify_pipeline(
|
|||
)
|
||||
|
||||
try:
|
||||
cognee_config = get_cognify_config()
|
||||
if not isinstance(tasks, list):
|
||||
raise ValueError("Tasks must be a list")
|
||||
|
||||
tasks = [
|
||||
Task(classify_documents),
|
||||
Task(check_permissions_on_documents, user=user, permissions=["write"]),
|
||||
Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
|
||||
Task(
|
||||
extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10}
|
||||
), # Generate knowledge graphs from the document chunks.
|
||||
Task(
|
||||
summarize_text,
|
||||
summarization_model=cognee_config.summarization_model,
|
||||
task_config={"batch_size": 10},
|
||||
),
|
||||
Task(add_data_points, only_root=True, task_config={"batch_size": 10}),
|
||||
]
|
||||
for task in tasks:
|
||||
if not isinstance(task, Task):
|
||||
raise ValueError(f"Task {task} is not an instance of Task")
|
||||
|
||||
pipeline = run_tasks(tasks, data_documents, "cognify_pipeline")
|
||||
|
||||
|
|
@ -146,3 +140,34 @@ async def run_cognify_pipeline(
|
|||
|
||||
def generate_dataset_name(dataset_name: str) -> str:
    """Return *dataset_name* with every dot and every space turned into an underscore."""
    # One translation pass covers both characters.
    separators_to_underscore = str.maketrans({".": "_", " ": "_"})
    return dataset_name.translate(separators_to_underscore)
|
||||
|
||||
|
||||
async def get_default_tasks(
    user: User = None, graph_model: BaseModel = KnowledgeGraph
) -> list[Task]:
    """Assemble the task list that `cognify` runs when the caller supplies none.

    Args:
        user: User the permission check is performed for; the default user
            is looked up when this is ``None``.
        graph_model: Model handed to the graph-extraction task.

    Returns:
        The tasks in execution order: classify documents, check write
        permissions, chunk, extract graph, summarize, store data points,
        store descriptive metrics.

    Raises:
        Exception: Whatever task construction raised, re-raised after a
            telemetry event has been sent.
    """
    if user is None:
        user = await get_default_user()

    try:
        cognee_config = get_cognify_config()

        classify_task = Task(classify_documents)
        permissions_task = Task(check_permissions_on_documents, user=user, permissions=["write"])
        # Extract text chunks based on the document type.
        chunking_task = Task(
            extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()
        )
        # Generate knowledge graphs from the document chunks.
        graph_task = Task(
            extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10}
        )
        summary_task = Task(
            summarize_text,
            summarization_model=cognee_config.summarization_model,
            task_config={"batch_size": 10},
        )
        storage_task = Task(add_data_points, task_config={"batch_size": 10})
        metrics_task = Task(store_descriptive_metrics)

        return [
            classify_task,
            permissions_task,
            chunking_task,
            graph_task,
            summary_task,
            storage_task,
            metrics_task,
        ]
    except Exception as error:
        send_telemetry("cognee.cognify DEFAULT TASKS CREATION ERRORED", user.id)
        raise error
|
||||
|
|
|
|||
|
|
@ -1 +1,2 @@
|
|||
from .get_cognify_router import get_cognify_router
|
||||
from .get_code_pipeline_router import get_code_pipeline_router
|
||||
|
|
|
|||
57
cognee/api/v1/cognify/routers/get_code_pipeline_router.py
Normal file
57
cognee/api/v1/cognify/routers/get_code_pipeline_router.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
|
||||
from cognee.modules.retrieval.description_to_codepart_search import (
|
||||
code_description_to_code_part_search,
|
||||
)
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
# Request body of the code-pipeline "/index" endpoint.
class CodePipelineIndexPayloadDTO(BaseModel):
    # Passed as the first argument to run_code_graph_pipeline.
    repo_path: str
    # Passed as the second argument to run_code_graph_pipeline; defaults to off.
    include_docs: bool = False
|
||||
|
||||
|
||||
# Request body of the code-pipeline "/retrieve" endpoint.
class CodePipelineRetrievePayloadDTO(BaseModel):
    # Required by validation, but not read by the retrieve handler in this file.
    query: str
    # Text the retrieve handler searches with, after removing the "cognee "
    # marker when the text starts with it.
    fullInput: str
|
||||
|
||||
|
||||
def get_code_pipeline_router() -> APIRouter:
    """Build the router that exposes the code-pipeline endpoints.

    Returns:
        An ``APIRouter`` with two POST routes:
        ``/index`` runs ``run_code_graph_pipeline`` over a repository, and
        ``/retrieve`` searches code parts for a textual description.
        Both routes answer with HTTP 409 and ``{"error": <message>}`` when
        the underlying call raises.
    """
    router = APIRouter()

    @router.post("/index", response_model=None)
    async def code_pipeline_index(payload: CodePipelineIndexPayloadDTO):
        """This endpoint is responsible for running the indexation on code repo."""
        try:
            # Drain the async generator; each yielded result is only printed.
            async for result in run_code_graph_pipeline(payload.repo_path, payload.include_docs):
                print(result)
        except Exception as error:
            return JSONResponse(status_code=409, content={"error": str(error)})

    @router.post("/retrieve", response_model=list[dict])
    async def code_pipeline_retrieve(payload: CodePipelineRetrievePayloadDTO):
        """This endpoint is responsible for retrieving the context."""
        try:
            # Drop only a leading "cognee " marker. str.replace would also
            # delete later occurrences inside the user's actual question.
            query = payload.fullInput.removeprefix("cognee ")

            retrieved_codeparts, __ = await code_description_to_code_part_search(
                query, include_docs=False
            )

            return [
                {
                    "name": codepart.attributes["id"],
                    "description": codepart.attributes["id"],
                    "content": codepart.attributes["source_code"],
                }
                for codepart in retrieved_codeparts
            ]
        except Exception as error:
            return JSONResponse(status_code=409, content={"error": str(error)})

    return router
|
||||
|
|
@ -21,7 +21,7 @@ class SettingsDTO(OutDTO):
|
|||
|
||||
|
||||
class LLMConfigInputDTO(InDTO):
|
||||
provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"]]
|
||||
provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"], Literal["gemini"]]
|
||||
model: str
|
||||
api_key: str
|
||||
|
||||
|
|
|
|||
|
|
@ -10,5 +10,6 @@ async def visualize_graph(label: str = "name"):
|
|||
logging.info(graph_data)
|
||||
|
||||
graph = await create_cognee_style_network_with_logo(graph_data, label=label)
|
||||
logging.info("The HTML file has been stored on your home directory! Navigate there with cd ~")
|
||||
|
||||
return graph
|
||||
|
|
|
|||
|
|
@ -1,4 +1,14 @@
|
|||
class EmbeddingException(Exception):
|
||||
from cognee.exceptions import CogneeApiError
|
||||
from fastapi import status
|
||||
|
||||
|
||||
class EmbeddingException(CogneeApiError):
|
||||
"""Custom exception for handling embedding-related errors."""
|
||||
|
||||
pass
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Embedding Exception.",
|
||||
name: str = "EmbeddingException",
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
|
|
|||
|
|
@ -62,10 +62,12 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
async def add_node(self, node: DataPoint):
|
||||
serialized_properties = self.serialize_properties(node.model_dump())
|
||||
|
||||
query = dedent("""MERGE (node {id: $node_id})
|
||||
query = dedent(
|
||||
"""MERGE (node {id: $node_id})
|
||||
ON CREATE SET node += $properties, node.updated_at = timestamp()
|
||||
ON MATCH SET node += $properties, node.updated_at = timestamp()
|
||||
RETURN ID(node) AS internal_id, node.id AS nodeId""")
|
||||
RETURN ID(node) AS internal_id, node.id AS nodeId"""
|
||||
)
|
||||
|
||||
params = {
|
||||
"node_id": str(node.id),
|
||||
|
|
@ -182,13 +184,15 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
):
|
||||
serialized_properties = self.serialize_properties(edge_properties)
|
||||
|
||||
query = dedent("""MATCH (from_node {id: $from_node}),
|
||||
query = dedent(
|
||||
"""MATCH (from_node {id: $from_node}),
|
||||
(to_node {id: $to_node})
|
||||
MERGE (from_node)-[r]->(to_node)
|
||||
ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
|
||||
ON MATCH SET r += $properties, r.updated_at = timestamp()
|
||||
RETURN r
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
params = {
|
||||
"from_node": str(from_node),
|
||||
|
|
@ -201,12 +205,20 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
||||
query = """
|
||||
UNWIND $edges AS edge
|
||||
MATCH (from_node {id: edge.from_node})
|
||||
MATCH (to_node {id: edge.to_node})
|
||||
CALL apoc.create.relationship(from_node, edge.relationship_name, edge.properties, to_node) YIELD rel
|
||||
RETURN rel
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (from_node {id: edge.from_node})
|
||||
MATCH (to_node {id: edge.to_node})
|
||||
CALL apoc.merge.relationship(
|
||||
from_node,
|
||||
edge.relationship_name,
|
||||
{
|
||||
source_node_id: edge.from_node,
|
||||
target_node_id: edge.to_node
|
||||
},
|
||||
edge.properties,
|
||||
to_node
|
||||
) YIELD rel
|
||||
RETURN rel"""
|
||||
|
||||
edges = [
|
||||
{
|
||||
|
|
@ -426,6 +438,15 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return serialized_properties
|
||||
|
||||
async def get_model_independent_graph_data(self):
    """Query all nodes and all directed relationships of the graph.

    Returns:
        A ``(nodes, edges)`` tuple with the raw results of the two queries:
        the first collects every node into ``nodes``, the second collects
        every ``[n, r, m]`` triple into ``elements``.
    """
    node_result = await self.query("MATCH (n) RETURN collect(n) AS nodes")
    edge_result = await self.query("MATCH (n)-[r]->(m) RETURN collect([n, r, m]) AS elements")

    return (node_result, edge_result)
|
||||
|
||||
async def get_graph_data(self):
|
||||
query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
|
||||
|
||||
|
|
|
|||
|
|
@ -247,10 +247,11 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
if not file_path:
|
||||
file_path = self.filename
|
||||
|
||||
graph_data = nx.readwrite.json_graph.node_link_data(self.graph)
|
||||
graph_data = nx.readwrite.json_graph.node_link_data(self.graph, edges="links")
|
||||
|
||||
async with aiofiles.open(file_path, "w") as file:
|
||||
await file.write(json.dumps(graph_data, cls=JSONEncoder))
|
||||
json_data = json.dumps(graph_data, cls=JSONEncoder)
|
||||
await file.write(json_data)
|
||||
|
||||
async def load_graph_from_file(self, file_path: str = None):
|
||||
"""Asynchronously load the graph from a file in JSON format."""
|
||||
|
|
@ -265,19 +266,32 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
graph_data = json.loads(await file.read())
|
||||
for node in graph_data["nodes"]:
|
||||
try:
|
||||
node["id"] = UUID(node["id"])
|
||||
if not isinstance(node["id"], UUID):
|
||||
node["id"] = UUID(node["id"])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
if "updated_at" in node:
|
||||
|
||||
if isinstance(node.get("updated_at"), int):
|
||||
node["updated_at"] = datetime.fromtimestamp(
|
||||
node["updated_at"] / 1000, tz=timezone.utc
|
||||
)
|
||||
elif isinstance(node.get("updated_at"), str):
|
||||
node["updated_at"] = datetime.strptime(
|
||||
node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z"
|
||||
)
|
||||
|
||||
for edge in graph_data["links"]:
|
||||
try:
|
||||
source_id = UUID(edge["source"])
|
||||
target_id = UUID(edge["target"])
|
||||
if not isinstance(edge["source"], UUID):
|
||||
source_id = UUID(edge["source"])
|
||||
else:
|
||||
source_id = edge["source"]
|
||||
|
||||
if not isinstance(edge["target"], UUID):
|
||||
target_id = UUID(edge["target"])
|
||||
else:
|
||||
target_id = edge["target"]
|
||||
|
||||
edge["source"] = source_id
|
||||
edge["target"] = target_id
|
||||
|
|
@ -287,12 +301,16 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
print(e)
|
||||
pass
|
||||
|
||||
if "updated_at" in edge:
|
||||
if isinstance(edge["updated_at"], int): # Handle timestamp in milliseconds
|
||||
edge["updated_at"] = datetime.fromtimestamp(
|
||||
edge["updated_at"] / 1000, tz=timezone.utc
|
||||
)
|
||||
elif isinstance(edge["updated_at"], str):
|
||||
edge["updated_at"] = datetime.strptime(
|
||||
edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z"
|
||||
)
|
||||
|
||||
self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
|
||||
self.graph = nx.readwrite.json_graph.node_link_graph(graph_data, edges="links")
|
||||
|
||||
for node_id, node_data in self.graph.nodes(data=True):
|
||||
node_data["id"] = node_id
|
||||
|
|
|
|||
|
|
@ -88,23 +88,27 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|||
}
|
||||
)
|
||||
|
||||
return dedent(f"""
|
||||
return dedent(
|
||||
f"""
|
||||
MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
|
||||
ON CREATE SET node += ({{{node_properties}}}), node.updated_at = timestamp()
|
||||
ON MATCH SET node += ({{{node_properties}}}), node.updated_at = timestamp()
|
||||
""").strip()
|
||||
"""
|
||||
).strip()
|
||||
|
||||
async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str:
|
||||
properties = await self.stringify_properties(edge[3])
|
||||
properties = f"{{{properties}}}"
|
||||
|
||||
return dedent(f"""
|
||||
return dedent(
|
||||
f"""
|
||||
MERGE (source {{id:'{edge[0]}'}})
|
||||
MERGE (target {{id: '{edge[1]}'}})
|
||||
MERGE (source)-[edge:{edge[2]} {properties}]->(target)
|
||||
ON MATCH SET edge.updated_at = timestamp()
|
||||
ON CREATE SET edge.updated_at = timestamp()
|
||||
""").strip()
|
||||
"""
|
||||
).strip()
|
||||
|
||||
async def create_collection(self, collection_name: str):
    """Do nothing; this adapter takes no action when asked to create a collection."""
    pass
|
||||
|
|
@ -195,12 +199,14 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|||
self.query(query)
|
||||
|
||||
async def has_edges(self, edges):
|
||||
query = dedent("""
|
||||
query = dedent(
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (a)-[r]->(b)
|
||||
WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
|
||||
RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
|
||||
""").strip()
|
||||
"""
|
||||
).strip()
|
||||
|
||||
params = {
|
||||
"edges": [
|
||||
|
|
@ -279,14 +285,16 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|||
|
||||
[label, attribute_name] = collection_name.split(".")
|
||||
|
||||
query = dedent(f"""
|
||||
query = dedent(
|
||||
f"""
|
||||
CALL db.idx.vector.queryNodes(
|
||||
'{label}',
|
||||
'{attribute_name}',
|
||||
{limit},
|
||||
vecf32({query_vector})
|
||||
) YIELD node, score
|
||||
""").strip()
|
||||
"""
|
||||
).strip()
|
||||
|
||||
result = self.query(query)
|
||||
|
||||
|
|
|
|||
|
|
@ -77,14 +77,51 @@ class SQLAlchemyAdapter:
|
|||
text(f"DROP TABLE IF EXISTS {schema_name}.{table_name} CASCADE;")
|
||||
)
|
||||
|
||||
async def insert_data(self, schema_name: str, table_name: str, data: list[dict]):
|
||||
columns = ", ".join(data[0].keys())
|
||||
values = ", ".join([f"({', '.join([f':{key}' for key in row.keys()])})" for row in data])
|
||||
insert_query = text(f"INSERT INTO {schema_name}.{table_name} ({columns}) VALUES {values};")
|
||||
async def insert_data(
|
||||
self,
|
||||
table_name: str,
|
||||
data: list[dict],
|
||||
schema_name: Optional[str] = "public",
|
||||
) -> int:
|
||||
"""
|
||||
Insert data into specified table using SQLAlchemy Core with batch optimization
|
||||
Returns number of inserted rows
|
||||
|
||||
async with self.engine.begin() as connection:
|
||||
await connection.execute(insert_query, data)
|
||||
await connection.close()
|
||||
Usage Example:
|
||||
from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine
|
||||
from uuid import UUID
|
||||
db = get_relational_engine()
|
||||
table_name = "groups"
|
||||
data = {
|
||||
"id": UUID("c70a3cec-3309-44df-8ee6-eced820cf438"),
|
||||
"name": "test"
|
||||
}
|
||||
await db.insert_data(table_name, data)
|
||||
"""
|
||||
if not data:
|
||||
logger.info("No data provided for insertion")
|
||||
return 0
|
||||
|
||||
try:
|
||||
# Use SQLAlchemy Core insert with execution options
|
||||
async with self.engine.begin() as conn:
|
||||
# Dialect-agnostic table reference
|
||||
if self.engine.dialect.name == "sqlite":
|
||||
# Foreign key constraints are disabled by default in SQLite (for backwards compatibility),
|
||||
# so must be enabled for each database connection/session separately.
|
||||
await conn.execute(text("PRAGMA foreign_keys=ON"))
|
||||
table = await self.get_table(table_name) # SQLite ignores schemas
|
||||
else:
|
||||
table = await self.get_table(table_name, schema_name)
|
||||
|
||||
result = await conn.execute(table.insert().values(data))
|
||||
|
||||
# Return rowcount for validation
|
||||
return result.rowcount
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Insert failed: {str(e)}")
|
||||
raise e # Re-raise for error handling upstream
|
||||
|
||||
async def get_schema_list(self) -> List[str]:
|
||||
"""
|
||||
|
|
@ -93,10 +130,12 @@ class SQLAlchemyAdapter:
|
|||
if self.engine.dialect.name == "postgresql":
|
||||
async with self.engine.begin() as connection:
|
||||
result = await connection.execute(
|
||||
text("""
|
||||
text(
|
||||
"""
|
||||
SELECT schema_name FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'pg_toast', 'information_schema');
|
||||
""")
|
||||
"""
|
||||
)
|
||||
)
|
||||
return [schema[0] for schema in result.fetchall()]
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -6,6 +6,9 @@ import litellm
|
|||
import os
|
||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
|
||||
from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer
|
||||
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
|
||||
from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
|
||||
|
||||
litellm.set_verbose = False
|
||||
logger = logging.getLogger("LiteLLMEmbeddingEngine")
|
||||
|
|
@ -15,32 +18,38 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
api_key: str
|
||||
endpoint: str
|
||||
api_version: str
|
||||
provider: str
|
||||
model: str
|
||||
dimensions: int
|
||||
mock: bool
|
||||
|
||||
MAX_RETRIES = 5
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[str] = "text-embedding-3-large",
|
||||
model: Optional[str] = "openai/text-embedding-3-large",
|
||||
provider: str = "openai",
|
||||
dimensions: Optional[int] = 3072,
|
||||
api_key: str = None,
|
||||
endpoint: str = None,
|
||||
api_version: str = None,
|
||||
max_tokens: int = 512,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.endpoint = endpoint
|
||||
self.api_version = api_version
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.dimensions = dimensions
|
||||
self.max_tokens = max_tokens
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.retry_count = 0
|
||||
|
||||
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
|
||||
if isinstance(enable_mocking, bool):
|
||||
enable_mocking = str(enable_mocking).lower()
|
||||
self.mock = enable_mocking in ("true", "1", "yes")
|
||||
|
||||
MAX_RETRIES = 5
|
||||
retry_count = 0
|
||||
|
||||
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
||||
async def exponential_backoff(attempt):
|
||||
wait_time = min(10 * (2**attempt), 60) # Max 60 seconds
|
||||
|
|
@ -55,14 +64,14 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
return [data["embedding"] for data in response["data"]]
|
||||
else:
|
||||
response = await litellm.aembedding(
|
||||
self.model,
|
||||
model=self.model,
|
||||
input=text,
|
||||
api_key=self.api_key,
|
||||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
)
|
||||
|
||||
self.retry_count = 0
|
||||
self.retry_count = 0 # Reset retry count on successful call
|
||||
|
||||
return [data["embedding"] for data in response.data]
|
||||
|
||||
|
|
@ -90,13 +99,16 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
raise Exception("Rate limit exceeded and no more retries left.")
|
||||
|
||||
await exponential_backoff(self.retry_count)
|
||||
|
||||
self.retry_count += 1
|
||||
|
||||
return await self.embed_text(text)
|
||||
|
||||
except (litellm.exceptions.BadRequestError, litellm.llms.OpenAI.openai.OpenAIError):
|
||||
raise EmbeddingException("Failed to index data points.")
|
||||
except (
|
||||
litellm.exceptions.BadRequestError,
|
||||
litellm.exceptions.NotFoundError,
|
||||
) as e:
|
||||
logger.error(f"Embedding error with model {self.model}: {str(e)}")
|
||||
raise EmbeddingException(f"Failed to index data points using model {self.model}")
|
||||
|
||||
except Exception as error:
|
||||
logger.error("Error embedding text: %s", str(error))
|
||||
|
|
@ -104,3 +116,18 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
|
||||
def get_vector_size(self) -> int:
    """Return the embedding dimension count stored on this engine (``self.dimensions``)."""
    return self.dimensions
|
||||
|
||||
def get_tokenizer(self):
    """Create the tokenizer matching this engine's provider.

    The provider name is matched case-insensitively by substring:
    "openai" selects ``TikTokenTokenizer``, "gemini" selects
    ``GeminiTokenizer``, anything else selects ``HuggingFaceTokenizer``.
    The first two receive the model name without any "<prefix>/" part;
    the HuggingFace tokenizer receives ``self.model`` unchanged.

    Returns:
        The constructed tokenizer, configured with ``self.max_tokens``.
    """
    logger.debug(f"Loading tokenizer for model {self.model}...")

    # If model also contains provider information, extract only model information
    *_, bare_model = self.model.split("/")
    provider = self.provider.lower()

    if "openai" in provider:
        tokenizer_cls, tokenizer_model = TikTokenTokenizer, bare_model
    elif "gemini" in provider:
        tokenizer_cls, tokenizer_model = GeminiTokenizer, bare_model
    else:
        tokenizer_cls, tokenizer_model = HuggingFaceTokenizer, self.model

    tokenizer = tokenizer_cls(model=tokenizer_model, max_tokens=self.max_tokens)

    logger.debug(f"Tokenizer loaded for model: {self.model}")
    return tokenizer
|
||||
|
|
|
|||
|
|
@ -4,12 +4,13 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|||
|
||||
|
||||
class EmbeddingConfig(BaseSettings):
|
||||
embedding_model: Optional[str] = "text-embedding-3-large"
|
||||
embedding_provider: Optional[str] = "openai"
|
||||
embedding_model: Optional[str] = "openai/text-embedding-3-large"
|
||||
embedding_dimensions: Optional[int] = 3072
|
||||
embedding_endpoint: Optional[str] = None
|
||||
embedding_api_key: Optional[str] = None
|
||||
embedding_api_version: Optional[str] = None
|
||||
|
||||
embedding_max_tokens: Optional[int] = 8191
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,9 +10,11 @@ def get_embedding_engine() -> EmbeddingEngine:
|
|||
|
||||
return LiteLLMEmbeddingEngine(
|
||||
# If OpenAI API is used for embeddings, litellm needs only the api_key.
|
||||
provider=config.embedding_provider,
|
||||
api_key=config.embedding_api_key or llm_config.llm_api_key,
|
||||
endpoint=config.embedding_endpoint,
|
||||
api_version=config.embedding_api_version,
|
||||
model=config.embedding_model,
|
||||
dimensions=config.embedding_dimensions,
|
||||
max_tokens=config.embedding_max_tokens,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -152,7 +152,9 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
connection = await self.get_connection()
|
||||
collection = await connection.open_table(collection_name)
|
||||
|
||||
results = await collection.vector_search(query_vector).to_pandas()
|
||||
collection_size = await collection.count_rows()
|
||||
|
||||
results = await collection.vector_search(query_vector).limit(collection_size).to_pandas()
|
||||
|
||||
result_values = list(results.to_dict("index").values())
|
||||
|
||||
|
|
@ -250,9 +252,16 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
)
|
||||
|
||||
async def prune(self):
|
||||
# Clean up the database if it was set up as temporary
|
||||
connection = await self.get_connection()
|
||||
collection_names = await connection.table_names()
|
||||
|
||||
for collection_name in collection_names:
|
||||
collection = await connection.open_table(collection_name)
|
||||
await collection.delete("id IS NOT NULL")
|
||||
await connection.drop_table(collection_name)
|
||||
|
||||
if self.url.startswith("/"):
|
||||
LocalStorage.remove_all(self.url) # Remove the temporary directory and files inside
|
||||
LocalStorage.remove_all(self.url)
|
||||
|
||||
def get_data_point_schema(self, model_type):
|
||||
return copy_model(
|
||||
|
|
|
|||
|
|
@ -14,9 +14,9 @@ from ...relational.ModelBase import Base
|
|||
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
from ..utils import normalize_distances
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from .serialize_data import serialize_data
|
||||
from ..utils import normalize_distances
|
||||
|
||||
|
||||
class IndexSchema(DataPoint):
|
||||
|
|
@ -247,12 +247,22 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
|
||||
# Extract distances and find min/max for normalization
|
||||
for vector in closest_items:
|
||||
# TODO: Add normalization of similarity score
|
||||
vector_list.append(vector)
|
||||
vector_list.append(
|
||||
{
|
||||
"id": UUID(str(vector.id)),
|
||||
"payload": vector.payload,
|
||||
"_distance": vector.similarity,
|
||||
}
|
||||
)
|
||||
|
||||
# Normalize vector distance and add this as score information to vector_list
|
||||
normalized_values = normalize_distances(vector_list)
|
||||
for i in range(0, len(normalized_values)):
|
||||
vector_list[i]["score"] = normalized_values[i]
|
||||
|
||||
# Create and return ScoredResult objects
|
||||
return [
|
||||
ScoredResult(id=UUID(str(row.id)), payload=row.payload, score=row.similarity)
|
||||
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
|
||||
for row in vector_list
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,25 +1,31 @@
|
|||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from typing import Optional, Any, Dict
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
import pickle
|
||||
|
||||
|
||||
# Define metadata type
|
||||
class MetaData(TypedDict):
    """Typed shape of the ``_metadata`` dictionary carried by ``DataPoint``."""

    # Names of the model fields that DataPoint.get_embeddable_properties and
    # DataPoint.get_embeddable_property_names read.
    # NOTE(review): DataPoint's default _metadata also sets a "type" key that is
    # not declared here — confirm whether it should be.
    index_fields: list[str]
|
||||
|
||||
|
||||
# Updated DataPoint model with versioning and new fields
|
||||
class DataPoint(BaseModel):
|
||||
__tablename__ = "data_point"
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
updated_at: Optional[datetime] = datetime.now(timezone.utc)
|
||||
created_at: int = Field(
|
||||
default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)
|
||||
)
|
||||
updated_at: int = Field(
|
||||
default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)
|
||||
)
|
||||
version: int = 1 # Default version
|
||||
topological_rank: Optional[int] = 0
|
||||
_metadata: Optional[MetaData] = {"index_fields": [], "type": "DataPoint"}
|
||||
|
||||
# class Config:
|
||||
# underscore_attrs_are_private = True
|
||||
|
||||
@classmethod
|
||||
def get_embeddable_data(self, data_point):
|
||||
if (
|
||||
|
|
@ -31,11 +37,11 @@ class DataPoint(BaseModel):
|
|||
|
||||
if isinstance(attribute, str):
|
||||
return attribute.strip()
|
||||
else:
|
||||
return attribute
|
||||
return attribute
|
||||
|
||||
@classmethod
|
||||
def get_embeddable_properties(self, data_point):
|
||||
"""Retrieve all embeddable properties."""
|
||||
if data_point._metadata and len(data_point._metadata["index_fields"]) > 0:
|
||||
return [
|
||||
getattr(data_point, field, None) for field in data_point._metadata["index_fields"]
|
||||
|
|
@ -45,4 +51,40 @@ class DataPoint(BaseModel):
|
|||
|
||||
@classmethod
def get_embeddable_property_names(cls, data_point):
    """Return the ``index_fields`` names of *data_point*, or ``[]`` when that entry is empty or falsy."""
    index_fields = data_point._metadata["index_fields"]
    if not index_fields:
        return []
    return index_fields
|
||||
|
||||
def update_version(self):
    """Bump ``version`` by one and refresh ``updated_at``.

    ``updated_at`` is set to the current UTC time expressed as an integer
    number of milliseconds since the Unix epoch.
    """
    self.version = self.version + 1
    now_utc = datetime.now(timezone.utc)
    self.updated_at = int(now_utc.timestamp() * 1000)
|
||||
|
||||
# JSON Serialization
|
||||
def to_json(self) -> str:
    """Serialize the instance to a JSON string.

    Returns:
        The JSON text produced by pydantic's ``model_dump_json``.
    """
    # ``json()`` is the deprecated pydantic-v1 spelling; use the v2 method.
    return self.model_dump_json()
|
||||
|
||||
@classmethod
def from_json(cls, json_str: str):
    """Deserialize an instance from a JSON string.

    Args:
        json_str: JSON document describing the model's fields.

    Returns:
        Whatever ``cls.model_validate_json`` builds from ``json_str``.
    """
    # The receiver of a classmethod is the class itself, hence ``cls``.
    return cls.model_validate_json(json_str)
|
||||
|
||||
# Pickle Serialization
|
||||
def to_pickle(self) -> bytes:
    """Serialize the instance's field values to pickle bytes.

    Returns:
        ``pickle.dumps`` of the dictionary returned by ``model_dump``.
    """
    # ``dict()`` is the deprecated pydantic-v1 spelling; use the v2 method.
    return pickle.dumps(self.model_dump())
|
||||
|
||||
@classmethod
def from_pickle(cls, pickled_data: bytes):
    """Deserialize an instance from pickled bytes.

    Args:
        pickled_data: Bytes holding a pickled mapping of keyword arguments
            for the class constructor.

    Returns:
        ``cls(**data)`` where ``data`` is the unpickled mapping.
    """
    # SECURITY: pickle.loads can execute arbitrary code while loading; only
    # pass bytes that come from a trusted source.
    data = pickle.loads(pickled_data)
    return cls(**data)
|
||||
|
||||
def to_dict(self, **kwargs) -> Dict[str, Any]:
    """Return the model as a dictionary.

    Keyword arguments are forwarded unchanged to ``model_dump``.
    """
    dumped = self.model_dump(**kwargs)
    return dumped
|
||||
|
||||
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DataPoint":
    """Build a model instance from a dictionary.

    The dictionary is handed to ``cls.model_validate`` and its result is
    returned as-is.
    """
    instance = cls.model_validate(data)
    return instance
|
||||
|
|
|
|||
|
|
@ -1 +1,2 @@
|
|||
from .config import get_llm_config
|
||||
from .utils import get_max_chunk_tokens
|
||||
|
|
|
|||
|
|
@ -14,11 +14,12 @@ class AnthropicAdapter(LLMInterface):
|
|||
name = "Anthropic"
|
||||
model: str
|
||||
|
||||
def __init__(self, model: str = None):
|
||||
def __init__(self, max_tokens: int, model: str = None):
|
||||
self.aclient = instructor.patch(
|
||||
create=anthropic.Anthropic().messages.create, mode=instructor.Mode.ANTHROPIC_TOOLS
|
||||
)
|
||||
self.model = model
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
async def acreate_structured_output(
|
||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ class LLMConfig(BaseSettings):
|
|||
llm_api_version: Optional[str] = None
|
||||
llm_temperature: float = 0.0
|
||||
llm_streaming: bool = False
|
||||
llm_max_tokens: int = 16384
|
||||
transcription_model: str = "whisper-1"
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
|
@ -24,6 +25,7 @@ class LLMConfig(BaseSettings):
|
|||
"api_version": self.llm_api_version,
|
||||
"temperature": self.llm_temperature,
|
||||
"streaming": self.llm_streaming,
|
||||
"max_tokens": self.llm_max_tokens,
|
||||
"transcription_model": self.transcription_model,
|
||||
}
|
||||
|
||||
|
|
|
|||
0
cognee/infrastructure/llm/gemini/__init__.py
Normal file
0
cognee/infrastructure/llm/gemini/__init__.py
Normal file
155
cognee/infrastructure/llm/gemini/adapter.py
Normal file
155
cognee/infrastructure/llm/gemini/adapter.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
from typing import Type, Optional
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
import litellm
|
||||
import asyncio
|
||||
from litellm import acompletion, JSONSchemaValidationError
|
||||
from cognee.shared.data_models import MonitoringTool
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
from cognee.base_config import get_base_config
|
||||
|
||||
logger = logging.getLogger(__name__)

monitoring = get_base_config().monitoring_tool
if monitoring == MonitoringTool.LANGFUSE:
    from langfuse.decorators import observe
else:

    def observe(func=None, **_kwargs):
        """No-op stand-in for the langfuse ``observe`` decorator.

        Keeps ``@observe`` and ``@observe(...)`` usable in this module when
        langfuse monitoring is not selected, instead of failing with a
        NameError when the decorated class is defined.
        """
        if func is None:
            # Called with arguments: return a decorator that changes nothing.
            return lambda wrapped: wrapped
        return func
|
||||
|
||||
|
||||
class GeminiAdapter(LLMInterface):
    """LLM adapter that requests JSON completions through litellm's ``acompletion``."""

    MAX_RETRIES = 5

    def __init__(
        self,
        api_key: str,
        model: str,
        max_tokens: int,
        endpoint: Optional[str] = None,
        api_version: Optional[str] = None,
        streaming: bool = False,
    ) -> None:
        """Store the settings used by later completion calls.

        Args:
            api_key: Key sent with every completion request.
            model: Model identifier sent with every completion request.
            max_tokens: Value sent as ``max_tokens`` with every request.
            endpoint: Stored on the instance; not read by this class's methods.
            api_version: Stored on the instance; not read by this class's methods.
            streaming: Stored on the instance; not read by this class's methods.
        """
        self.api_key = api_key
        self.model = model
        self.endpoint = endpoint
        self.api_version = api_version
        self.streaming = streaming
        self.max_tokens = max_tokens

    @observe(as_type="generation")
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
        """Request a JSON answer and validate it with ``response_model``.

        Args:
            text_input: Sent as the user message.
            system_prompt: Embedded at the top of the system message.
            response_model: Pydantic model used to validate the returned JSON.

        Returns:
            ``response_model.model_validate_json`` applied to the first
            choice's message content.

        Raises:
            ValueError: If litellm reports a bad request, if the response has
                no usable content, or if litellm reports a schema validation
                failure.
        """
        try:
            # NOTE(review): the schema sent with the request is fixed to this
            # summary/description/nodes/edges shape and does not depend on
            # ``response_model``, which is only used to validate the reply.
            response_schema = {
                "type": "object",
                "properties": {
                    "summary": {"type": "string"},
                    "description": {"type": "string"},
                    "nodes": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "name": {"type": "string"},
                                "type": {"type": "string"},
                                "description": {"type": "string"},
                                "id": {"type": "string"},
                                "label": {"type": "string"},
                            },
                            "required": ["name", "type", "description", "id", "label"],
                        },
                    },
                    "edges": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "source_node_id": {"type": "string"},
                                "target_node_id": {"type": "string"},
                                "relationship_name": {"type": "string"},
                            },
                            "required": ["source_node_id", "target_node_id", "relationship_name"],
                        },
                    },
                },
                "required": ["summary", "description", "nodes", "edges"],
            }

            simplified_prompt = f"""
{system_prompt}

IMPORTANT: Your response must be a valid JSON object with these required fields:
1. summary: A brief summary
2. description: A detailed description
3. nodes: Array of nodes with name, type, description, id, and label
4. edges: Array of edges with source_node_id, target_node_id, and relationship_name

Example structure:
{{
"summary": "Brief summary",
"description": "Detailed description",
"nodes": [
{{
"name": "Example Node",
"type": "Concept",
"description": "Node description",
"id": "example-id",
"label": "Concept"
}}
],
"edges": [
{{
"source_node_id": "source-id",
"target_node_id": "target-id",
"relationship_name": "relates_to"
}}
]
}}"""

            messages = [
                {"role": "system", "content": simplified_prompt},
                {"role": "user", "content": text_input},
            ]

            try:
                response = await acompletion(
                    model=self.model,
                    messages=messages,
                    api_key=self.api_key,
                    max_tokens=self.max_tokens,
                    temperature=0.1,
                    response_format={"type": "json_object", "schema": response_schema},
                    timeout=10,
                    num_retries=self.MAX_RETRIES,
                )

                if response.choices and response.choices[0].message.content:
                    content = response.choices[0].message.content
                    return response_model.model_validate_json(content)

            except litellm.exceptions.BadRequestError as e:
                logger.error("Bad request error: %s", e)
                # Chain the cause so the original traceback is preserved.
                raise ValueError(f"Invalid request: {e}") from e

            # Reached only when the response carried no choices or no content.
            raise ValueError("Failed to get valid response after retries")

        except JSONSchemaValidationError as e:
            logger.error("Schema validation failed: %s", e)
            logger.debug("Raw response: %s", e.raw_response)
            raise ValueError(f"Response failed schema validation: {e}") from e

    def show_prompt(self, text_input: str, system_prompt: str) -> str:
        """Return the system prompt and user input formatted as one text block.

        Args:
            text_input: User text; replaced by a placeholder sentence when empty.
            system_prompt: Passed to ``read_query_prompt`` to load the prompt text.

        Returns:
            The formatted text, or ``None`` when the loaded prompt is falsy.

        Raises:
            InvalidValueError: If ``system_prompt`` is empty.
        """
        if not text_input:
            text_input = "No user input provided."
        if not system_prompt:
            raise InvalidValueError(message="No system prompt path provided.")
        system_prompt = read_query_prompt(system_prompt)

        formatted_prompt = (
            f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
            if system_prompt
            else None
        )
        return formatted_prompt
|
||||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import asyncio
|
||||
from typing import List, Type
|
||||
|
||||
from pydantic import BaseModel
|
||||
import instructor
|
||||
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||
|
|
@ -16,11 +17,12 @@ class GenericAPIAdapter(LLMInterface):
|
|||
model: str
|
||||
api_key: str
|
||||
|
||||
def __init__(self, endpoint, api_key: str, model: str, name: str):
|
||||
def __init__(self, endpoint, api_key: str, model: str, name: str, max_tokens: int):
|
||||
self.name = name
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.endpoint = endpoint
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
llm_config = get_llm_config()
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ class LLMProvider(Enum):
|
|||
OLLAMA = "ollama"
|
||||
ANTHROPIC = "anthropic"
|
||||
CUSTOM = "custom"
|
||||
GEMINI = "gemini"
|
||||
|
||||
|
||||
def get_llm_client():
|
||||
|
|
@ -20,6 +21,15 @@ def get_llm_client():
|
|||
|
||||
provider = LLMProvider(llm_config.llm_provider)
|
||||
|
||||
# Check if max_token value is defined in liteLLM for given model
|
||||
# if not use value from cognee configuration
|
||||
from cognee.infrastructure.llm.utils import (
|
||||
get_model_max_tokens,
|
||||
) # imported here to avoid circular imports
|
||||
|
||||
model_max_tokens = get_model_max_tokens(llm_config.llm_model)
|
||||
max_tokens = model_max_tokens if model_max_tokens else llm_config.llm_max_tokens
|
||||
|
||||
if provider == LLMProvider.OPENAI:
|
||||
if llm_config.llm_api_key is None:
|
||||
raise InvalidValueError(message="LLM API key is not set.")
|
||||
|
|
@ -32,6 +42,7 @@ def get_llm_client():
|
|||
api_version=llm_config.llm_api_version,
|
||||
model=llm_config.llm_model,
|
||||
transcription_model=llm_config.transcription_model,
|
||||
max_tokens=max_tokens,
|
||||
streaming=llm_config.llm_streaming,
|
||||
)
|
||||
|
||||
|
|
@ -42,13 +53,17 @@ def get_llm_client():
|
|||
from .generic_llm_api.adapter import GenericAPIAdapter
|
||||
|
||||
return GenericAPIAdapter(
|
||||
llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama"
|
||||
llm_config.llm_endpoint,
|
||||
llm_config.llm_api_key,
|
||||
llm_config.llm_model,
|
||||
"Ollama",
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
elif provider == LLMProvider.ANTHROPIC:
|
||||
from .anthropic.adapter import AnthropicAdapter
|
||||
|
||||
return AnthropicAdapter(llm_config.llm_model)
|
||||
return AnthropicAdapter(max_tokens=max_tokens, model=llm_config.llm_model)
|
||||
|
||||
elif provider == LLMProvider.CUSTOM:
|
||||
if llm_config.llm_api_key is None:
|
||||
|
|
@ -57,7 +72,26 @@ def get_llm_client():
|
|||
from .generic_llm_api.adapter import GenericAPIAdapter
|
||||
|
||||
return GenericAPIAdapter(
|
||||
llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom"
|
||||
llm_config.llm_endpoint,
|
||||
llm_config.llm_api_key,
|
||||
llm_config.llm_model,
|
||||
"Custom",
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
elif provider == LLMProvider.GEMINI:
|
||||
if llm_config.llm_api_key is None:
|
||||
raise InvalidValueError(message="LLM API key is not set.")
|
||||
|
||||
from .gemini.adapter import GeminiAdapter
|
||||
|
||||
return GeminiAdapter(
|
||||
api_key=llm_config.llm_api_key,
|
||||
model=llm_config.llm_model,
|
||||
max_tokens=max_tokens,
|
||||
endpoint=llm_config.llm_endpoint,
|
||||
api_version=llm_config.llm_api_version,
|
||||
streaming=llm_config.llm_streaming,
|
||||
)
|
||||
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -23,6 +23,8 @@ class OpenAIAdapter(LLMInterface):
|
|||
api_key: str
|
||||
api_version: str
|
||||
|
||||
MAX_RETRIES = 5
|
||||
|
||||
"""Adapter for OpenAI's GPT-3, GPT=4 API"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -32,6 +34,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
api_version: str,
|
||||
model: str,
|
||||
transcription_model: str,
|
||||
max_tokens: int,
|
||||
streaming: bool = False,
|
||||
):
|
||||
self.aclient = instructor.from_litellm(litellm.acompletion)
|
||||
|
|
@ -41,6 +44,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
self.api_key = api_key
|
||||
self.endpoint = endpoint
|
||||
self.api_version = api_version
|
||||
self.max_tokens = max_tokens
|
||||
self.streaming = streaming
|
||||
|
||||
@observe(as_type="generation")
|
||||
|
|
@ -66,7 +70,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
response_model=response_model,
|
||||
max_retries=5,
|
||||
max_retries=self.MAX_RETRIES,
|
||||
)
|
||||
|
||||
@observe
|
||||
|
|
@ -92,7 +96,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
response_model=response_model,
|
||||
max_retries=5,
|
||||
max_retries=self.MAX_RETRIES,
|
||||
)
|
||||
|
||||
def create_transcript(self, input):
|
||||
|
|
@ -110,7 +114,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
api_key=self.api_key,
|
||||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
max_retries=5,
|
||||
max_retries=self.MAX_RETRIES,
|
||||
)
|
||||
|
||||
return transcription
|
||||
|
|
@ -142,7 +146,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
max_tokens=300,
|
||||
max_retries=5,
|
||||
max_retries=self.MAX_RETRIES,
|
||||
)
|
||||
|
||||
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
||||
|
|
|
|||
9
cognee/infrastructure/llm/prompts/llm_judge_prompts.py
Normal file
9
cognee/infrastructure/llm/prompts/llm_judge_prompts.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
# LLM-as-a-judge metrics as described here: https://arxiv.org/abs/2404.16130
|
||||
|
||||
# Maps each LLM-as-a-judge metric name to the instruction text for that metric.
llm_judge_prompts = {
    "correctness": "Determine whether the actual output is factually correct based on the expected output.",
    "comprehensiveness": "Determine how much detail the answer provides to cover all the aspects and details of the question.",
    "diversity": "Determine how varied and rich the answer is in providing different perspectives and insights on the question.",
    "empowerment": "Determine how well the answer helps the reader understand and make informed judgements about the topic.",
    "directness": "Determine how specifically and clearly the answer addresses the question.",
}
|
||||
1
cognee/infrastructure/llm/tokenizer/Gemini/__init__.py
Normal file
1
cognee/infrastructure/llm/tokenizer/Gemini/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .adapter import GeminiTokenizer
|
||||
45
cognee/infrastructure/llm/tokenizer/Gemini/adapter.py
Normal file
45
cognee/infrastructure/llm/tokenizer/Gemini/adapter.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
from typing import List, Any, Union
|
||||
|
||||
from ..tokenizer_interface import TokenizerInterface
|
||||
|
||||
|
||||
class GeminiTokenizer(TokenizerInterface):
    """Tokenizer adapter built on the ``google.generativeai`` client."""

    def __init__(
        self,
        model: str,
        max_tokens: int = 3072,
    ):
        """Remember the model settings and configure the ``genai`` client.

        Args:
            model: Model name; ``count_tokens`` prefixes it with ``models/``.
            max_tokens: Stored on the instance as ``max_tokens``.
        """
        self.model = model
        self.max_tokens = max_tokens

        # Function-scope imports: resolved only when a tokenizer is constructed.
        from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
        from cognee.infrastructure.llm.config import get_llm_config

        embedding_settings = get_embedding_config()
        llm_settings = get_llm_config()

        import google.generativeai as genai

        # The embedding API key is preferred; the LLM API key is the fallback
        # when the embedding key is falsy.
        api_key = embedding_settings.embedding_api_key or llm_settings.llm_api_key
        genai.configure(api_key=api_key)

    def extract_tokens(self, text: str) -> List[Any]:
        """Not implemented for this adapter."""
        raise NotImplementedError

    def decode_single_token(self, encoding: int):
        """Not implemented for this adapter."""
        # Gemini tokenizer doesn't have the option to decode tokens
        raise NotImplementedError

    def count_tokens(self, text: str) -> int:
        """Return ``len()`` of the ``genai.embed_content`` result for ``text``.

        Args:
            text: Text sent to ``embed_content`` for the ``models/<model>`` model.

        Returns:
            The length of the object returned by ``embed_content``.
        """
        import google.generativeai as genai

        # NOTE(review): this measures the length of the embed_content result,
        # not an explicit token total — confirm it matches the intended count.
        embedding_result = genai.embed_content(model=f"models/{self.model}", content=text)
        return len(embedding_result)
|
||||
|
|
@ -0,0 +1 @@
|
|||
from .adapter import HuggingFaceTokenizer
|
||||
38
cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py
Normal file
38
cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
from typing import List, Any
|
||||
|
||||
from ..tokenizer_interface import TokenizerInterface
|
||||
|
||||
|
||||
class HuggingFaceTokenizer(TokenizerInterface):
    """Tokenizer adapter backed by ``transformers.AutoTokenizer``."""

    def __init__(
        self,
        model: str,
        max_tokens: int = 512,
    ):
        """Load the pretrained tokenizer for ``model``.

        Args:
            model: Identifier passed to ``AutoTokenizer.from_pretrained``.
            max_tokens: Stored on the instance as ``max_tokens``.
        """
        self.model = model
        self.max_tokens = max_tokens

        # Import here to make it an optional dependency
        from transformers import AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(model)

    def extract_tokens(self, text: str) -> List[Any]:
        """Return the tokens produced by the tokenizer's ``tokenize`` for ``text``."""
        return self.tokenizer.tokenize(text)

    def count_tokens(self, text: str) -> int:
        """Return the number of tokens in the given text.

        Args:
            text: Text to tokenize.

        Returns:
            Length of the token list produced by ``tokenize``.
        """
        token_list = self.tokenizer.tokenize(text)
        return len(token_list)

    def decode_single_token(self, encoding: int):
        """Not implemented for this adapter."""
        # Decoding a single token id is not provided by this adapter.
        raise NotImplementedError
|
||||
1
cognee/infrastructure/llm/tokenizer/TikToken/__init__.py
Normal file
1
cognee/infrastructure/llm/tokenizer/TikToken/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .adapter import TikTokenTokenizer
|
||||
72
cognee/infrastructure/llm/tokenizer/TikToken/adapter.py
Normal file
72
cognee/infrastructure/llm/tokenizer/TikToken/adapter.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
from typing import List, Any
|
||||
import tiktoken
|
||||
|
||||
from ..tokenizer_interface import TokenizerInterface
|
||||
|
||||
|
||||
class TikTokenTokenizer(TokenizerInterface):
    """
    Tokenizer adapter for OpenAI.
    Intended to be used as part of LLM Embedding and LLM Adapters classes
    """

    def __init__(
        self,
        model: str,
        max_tokens: int = 8191,
    ):
        """Look up the tiktoken encoding for ``model``.

        Args:
            model: Model name passed to ``tiktoken.encoding_for_model``.
            max_tokens: Token limit used by ``trim_text_to_max_tokens``.
        """
        self.model = model
        self.max_tokens = max_tokens
        # Initialize TikToken for GPT based on model
        self.tokenizer = tiktoken.encoding_for_model(self.model)

    def extract_tokens(self, text: str) -> List[Any]:
        """Return the token ids tiktoken produces for ``text``."""
        return self.tokenizer.encode(text)

    def decode_token_list(self, tokens: List[Any]) -> List[Any]:
        """Decode each entry of ``tokens`` to text.

        Args:
            tokens: A list of entries, or a single entry. An entry is either
                one token id or a list/tuple of token ids.

        Returns:
            One decoded string per entry.
        """
        if not isinstance(tokens, list):
            tokens = [tokens]
        # ``Encoding.decode`` takes a sequence of ids, so a bare id is wrapped
        # in a one-element list before decoding.
        return [
            self.tokenizer.decode(entry if isinstance(entry, (list, tuple)) else [entry])
            for entry in tokens
        ]

    def decode_single_token(self, token: int):
        """Decode one token id, replacing bytes that are not valid UTF-8."""
        return self.tokenizer.decode_single_token_bytes(token).decode("utf-8", errors="replace")

    def count_tokens(self, text: str) -> int:
        """Return the number of tokens in the given text.

        Args:
            text: Text to encode.

        Returns:
            Number of token ids produced by the encoder.
        """
        return len(self.tokenizer.encode(text))

    def trim_text_to_max_tokens(self, text: str) -> str:
        """Trim ``text`` so its token count does not exceed ``max_tokens``.

        Args:
            text: Original text string to be trimmed.

        Returns:
            The original text when it is within the limit, otherwise the
            decoding of its first ``max_tokens`` token ids.
        """
        # Encode once and reuse the ids for both the length check and the cut.
        encoded_text = self.tokenizer.encode(text)
        if len(encoded_text) <= self.max_tokens:
            return text

        # This is a simple trim, it may cut words in half; consider using
        # word boundaries for a cleaner cut.
        return self.tokenizer.decode(encoded_text[: self.max_tokens])
|
||||
1
cognee/infrastructure/llm/tokenizer/__init__.py
Normal file
1
cognee/infrastructure/llm/tokenizer/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .tokenizer_interface import TokenizerInterface
|
||||
18
cognee/infrastructure/llm/tokenizer/tokenizer_interface.py
Normal file
18
cognee/infrastructure/llm/tokenizer/tokenizer_interface.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from typing import List, Protocol, Any
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
class TokenizerInterface(Protocol):
    """Protocol listing the methods a tokenizer adapter provides."""

    @abstractmethod
    def extract_tokens(self, text: str) -> List[Any]:
        """Return the tokens of ``text``."""
        raise NotImplementedError

    @abstractmethod
    def count_tokens(self, text: str) -> int:
        """Return how many tokens ``text`` contains."""
        raise NotImplementedError

    @abstractmethod
    def decode_single_token(self, token: int) -> str:
        """Return the text for a single token."""
        raise NotImplementedError
|
||||
38
cognee/infrastructure/llm/utils.py
Normal file
38
cognee/infrastructure/llm/utils.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
import logging
|
||||
import litellm
|
||||
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_max_chunk_tokens():
    """Return the largest token count a single chunk may have.

    The result is the smaller of the embedding engine's ``max_tokens`` and
    half of the LLM client's ``max_tokens`` (rounded down).
    """
    embedder = get_vector_engine().embedding_engine
    client = get_llm_client()

    # A chunk may use at most half of the LLM limit, and can never be larger
    # than what the embedding engine accepts.
    half_llm_limit = client.max_tokens // 2
    return min(embedder.max_tokens, half_llm_limit)
|
||||
|
||||
|
||||
def get_model_max_tokens(model_name: str):
    """Look up a model's ``max_tokens`` entry in LiteLLM's ``model_cost`` table.

    Args:
        model_name: name of LLM or embedding model

    Returns:
        The ``max_tokens`` value recorded for the model, or None if the model
        is unknown or its entry has no ``max_tokens`` key.
    """
    model_info = litellm.model_cost.get(model_name)
    if model_info is None:
        logger.info("Model not found in LiteLLM's model_cost.")
        return None

    # .get() keeps the documented "None when unavailable" contract for
    # entries that do not carry a max_tokens key.
    max_tokens = model_info.get("max_tokens")
    logger.debug(f"Max input tokens for {model_name}: {max_tokens}")
    return max_tokens
|
||||
|
|
@ -14,17 +14,15 @@ class TextChunker:
|
|||
chunk_size = 0
|
||||
token_count = 0
|
||||
|
||||
def __init__(
|
||||
self, document, get_text: callable, max_tokens: Optional[int] = None, chunk_size: int = 1024
|
||||
):
|
||||
def __init__(self, document, get_text: callable, max_chunk_tokens: int, chunk_size: int = 1024):
|
||||
self.document = document
|
||||
self.max_chunk_size = chunk_size
|
||||
self.get_text = get_text
|
||||
self.max_tokens = max_tokens if max_tokens else float("inf")
|
||||
self.max_chunk_tokens = max_chunk_tokens
|
||||
|
||||
def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data):
|
||||
word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size
|
||||
token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_tokens
|
||||
token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_chunk_tokens
|
||||
return word_count_fits and token_count_fits
|
||||
|
||||
def read(self):
|
||||
|
|
@ -32,7 +30,7 @@ class TextChunker:
|
|||
for content_text in self.get_text():
|
||||
for chunk_data in chunk_by_paragraph(
|
||||
content_text,
|
||||
self.max_tokens,
|
||||
self.max_chunk_tokens,
|
||||
self.max_chunk_size,
|
||||
batch_paragraphs=True,
|
||||
):
|
||||
|
|
@ -48,13 +46,13 @@ class TextChunker:
|
|||
id=chunk_data["chunk_id"],
|
||||
text=chunk_data["text"],
|
||||
word_count=chunk_data["word_count"],
|
||||
token_count=chunk_data["token_count"],
|
||||
is_part_of=self.document,
|
||||
chunk_index=self.chunk_index,
|
||||
cut_type=chunk_data["cut_type"],
|
||||
contains=[],
|
||||
_metadata={
|
||||
"index_fields": ["text"],
|
||||
"metadata_id": self.document.metadata_id,
|
||||
},
|
||||
)
|
||||
paragraph_chunks = []
|
||||
|
|
@ -68,13 +66,13 @@ class TextChunker:
|
|||
),
|
||||
text=chunk_text,
|
||||
word_count=self.chunk_size,
|
||||
token_count=self.token_count,
|
||||
is_part_of=self.document,
|
||||
chunk_index=self.chunk_index,
|
||||
cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
|
||||
contains=[],
|
||||
_metadata={
|
||||
"index_fields": ["text"],
|
||||
"metadata_id": self.document.metadata_id,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
@ -91,11 +89,12 @@ class TextChunker:
|
|||
id=uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
|
||||
text=" ".join(chunk["text"] for chunk in paragraph_chunks),
|
||||
word_count=self.chunk_size,
|
||||
token_count=self.token_count,
|
||||
is_part_of=self.document,
|
||||
chunk_index=self.chunk_index,
|
||||
cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
|
||||
contains=[],
|
||||
_metadata={"index_fields": ["text"], "metadata_id": self.document.metadata_id},
|
||||
_metadata={"index_fields": ["text"]},
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ class DocumentChunk(DataPoint):
|
|||
__tablename__ = "document_chunk"
|
||||
text: str
|
||||
word_count: int
|
||||
token_count: int
|
||||
chunk_index: int
|
||||
cut_type: str
|
||||
is_part_of: Document
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import os
|
|||
class CognifyConfig(BaseSettings):
|
||||
classification_model: object = DefaultContentPrediction
|
||||
summarization_model: object = SummarizedContent
|
||||
max_tokens: Optional[int] = os.getenv("MAX_TOKENS")
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
|
|
|
|||
|
|
@ -1,13 +1,11 @@
|
|||
from datetime import datetime, timezone
|
||||
from typing import List
|
||||
from uuid import uuid4
|
||||
from sqlalchemy import UUID, Column, DateTime, String
|
||||
from sqlalchemy.orm import Mapped, relationship
|
||||
from sqlalchemy import UUID, Column, DateTime, String, JSON, Integer
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
|
||||
from .DatasetData import DatasetData
|
||||
from .Metadata import Metadata
|
||||
|
||||
|
||||
class Data(Base):
|
||||
|
|
@ -21,6 +19,8 @@ class Data(Base):
|
|||
raw_data_location = Column(String)
|
||||
owner_id = Column(UUID, index=True)
|
||||
content_hash = Column(String)
|
||||
external_metadata = Column(JSON)
|
||||
token_count = Column(Integer)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
|
@ -32,13 +32,6 @@ class Data(Base):
|
|||
cascade="all, delete",
|
||||
)
|
||||
|
||||
metadata_relationship = relationship(
|
||||
"Metadata",
|
||||
back_populates="data",
|
||||
lazy="noload",
|
||||
cascade="all, delete",
|
||||
)
|
||||
|
||||
def to_json(self) -> dict:
|
||||
return {
|
||||
"id": str(self.id),
|
||||
|
|
|
|||
27
cognee/modules/data/models/GraphMetrics.py
Normal file
27
cognee/modules/data/models/GraphMetrics.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import Column, DateTime, Float, Integer, JSON, UUID
|
||||
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
class GraphMetrics(Base):
    """ORM model for the ``graph_metrics_table`` table of graph statistics."""

    __tablename__ = "graph_metrics_table"

    # TODO: Change ID to reflect unique id of graph database
    id = Column(UUID, primary_key=True, default=uuid4)
    # Every metric column below is nullable.
    num_tokens = Column(Integer, nullable=True)
    num_nodes = Column(Integer, nullable=True)
    num_edges = Column(Integer, nullable=True)
    mean_degree = Column(Float, nullable=True)
    edge_density = Column(Float, nullable=True)
    num_connected_components = Column(Integer, nullable=True)
    sizes_of_connected_components = Column(JSON, nullable=True)
    num_selfloops = Column(Integer, nullable=True)
    diameter = Column(Integer, nullable=True)
    avg_shortest_path_length = Column(Float, nullable=True)
    avg_clustering = Column(Float, nullable=True)

    # created_at defaults to the current UTC time; updated_at has no default
    # and is only filled by its onupdate callable.
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import UUID, Column, DateTime, String, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
|
||||
|
||||
class Metadata(Base):
    """ORM model for the ``metadata_table`` table, linked to a ``Data`` row."""

    __tablename__ = "metadata_table"

    id = Column(UUID, primary_key=True, default=uuid4)
    # Both columns are plain strings.
    metadata_repr = Column(String)
    metadata_source = Column(String)

    # created_at defaults to the current UTC time; updated_at has no default
    # and is only filled by its onupdate callable.
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))

    # Foreign key to data.id, declared with ON DELETE CASCADE.
    data_id = Column(UUID, ForeignKey("data.id", ondelete="CASCADE"), primary_key=False)
    data = relationship("Data", back_populates="metadata_relationship")
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
import warnings
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
from ..models.Metadata import Metadata
|
||||
|
||||
|
||||
async def delete_metadata(metadata_id: UUID):
    """Delete the ``Metadata`` row with the given id, if it exists.

    Args:
        metadata_id: Primary key of the row to delete.

    A warning is emitted and nothing is deleted when no such row exists.
    """
    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        metadata = await session.get(Metadata, metadata_id)
        if metadata is None:
            warnings.warn(f"metadata for metadata_id: {metadata_id} not found")
            # Nothing to delete; do not hand None to session.delete().
            return

        # The session is awaited for get() above, so delete() and commit()
        # are awaited as well; un-awaited they would never run.
        await session.delete(metadata)
        await session.commit()
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
import json
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
from ..models.Metadata import Metadata
|
||||
|
||||
|
||||
async def get_metadata(metadata_id: UUID) -> Metadata:
    """Fetch the ``Metadata`` row with the given id.

    Args:
        metadata_id: Primary key to look up.

    Returns:
        Whatever ``session.get`` yields for that key.
    """
    engine = get_relational_engine()

    async with engine.get_async_session() as session:
        found = await session.get(Metadata, metadata_id)

    return found
|
||||
|
|
@ -1,65 +0,0 @@
|
|||
import inspect
|
||||
import json
|
||||
import re
|
||||
import warnings
|
||||
from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from typing import Any, BinaryIO, Union
|
||||
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.infrastructure.files.utils.get_file_metadata import FileMetadata
|
||||
from ..models.Metadata import Metadata
|
||||
|
||||
|
||||
async def write_metadata(
    data_item: Union[BinaryIO, str, Any], data_id: UUID, file_metadata: FileMetadata
) -> UUID:
    """Create or update the ``Metadata`` row linked to ``data_id``.

    Args:
        data_item: Item whose metadata is stored; its type name becomes
            ``metadata_source``.
        data_id: Id of the owning data row; also used as the id of a newly
            created metadata row.
        file_metadata: File metadata merged into the stored representation.

    Returns:
        The id of the metadata row that was written.
    """
    metadata_dict = get_metadata_dict(data_item, file_metadata)
    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        metadata = (
            await session.execute(select(Metadata).filter(Metadata.data_id == data_id))
        ).scalar_one_or_none()

        if metadata is not None:
            metadata.metadata_repr = json.dumps(metadata_dict)
            metadata.metadata_source = parse_type(type(data_item))
            await session.merge(metadata)
        else:
            metadata = Metadata(
                id=data_id,
                metadata_repr=json.dumps(metadata_dict),
                metadata_source=parse_type(type(data_item)),
                data_id=data_id,
            )
            session.add(metadata)

        # Capture the id before committing so it is available to return
        # without touching the instance again after the commit.
        metadata_id = metadata.id
        await session.commit()

    # The signature promises a UUID; return the id of the written row.
    return metadata_id
|
||||
|
||||
|
||||
def parse_type(type_: Any) -> str:
    """Extract the quoted dotted name from the string form of a type.

    For example ``str(int)`` is ``"<class 'int'>"``, which yields ``"int"``.

    Raises:
        Exception: If the string form contains no quoted identifier.
    """
    found = re.search(r".+'([\w_\.]+)'", str(type_))
    if found is None:
        raise Exception(f"type: {type_} could not be parsed")
    return found.group(1)
|
||||
|
||||
|
||||
def get_metadata_dict(
    data_item: Union[BinaryIO, str, Any], file_metadata: FileMetadata
) -> dict[str, Any]:
    """Build the metadata payload stored for a data item.

    Args:
        data_item: Text, a binary stream, an object exposing a ``dict()``
            method, or any other value.
        file_metadata: File-level metadata that forms the base of the result.

    Returns:
        ``file_metadata`` unchanged for text and binary streams; merged with
        ``data_item.dict()`` when that bound method exists; otherwise merged
        with a ``"content"`` entry holding ``str(data_item)``.

    Raises:
        Exception: If the fallback string conversion fails.
    """
    import io  # local import so the block does not depend on the file header

    # typing.BinaryIO is an annotation-only class: real file objects such as
    # open(..., "rb") or io.BytesIO are not instances of it, so the concrete
    # io base classes must be checked too or streams hit the fallback below.
    if isinstance(data_item, (str, BinaryIO, io.RawIOBase, io.BufferedIOBase)):
        return file_metadata

    if hasattr(data_item, "dict") and inspect.ismethod(getattr(data_item, "dict")):
        return {**file_metadata, **data_item.dict()}

    warnings.warn(
        f"metadata of type {type(data_item)}: {str(data_item)[:20]}... does not have dict method. Defaulting to string method"
    )
    try:
        return {**dict(file_metadata), "content": str(data_item)}
    except Exception as e:
        # Keep the original exception as the cause for easier debugging.
        raise Exception(f"Could not cast metadata to string: {e}") from e
|
||||
|
|
@ -13,14 +13,14 @@ class AudioDocument(Document):
|
|||
result = get_llm_client().create_transcript(self.raw_data_location)
|
||||
return result.text
|
||||
|
||||
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
|
||||
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
|
||||
# Transcribe the audio file
|
||||
|
||||
text = self.create_transcript()
|
||||
|
||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||
chunker = chunker_func(
|
||||
self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
|
||||
self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
|
||||
)
|
||||
|
||||
yield from chunker.read()
|
||||
|
|
|
|||
|
|
@ -7,9 +7,9 @@ from cognee.infrastructure.engine import DataPoint
|
|||
class Document(DataPoint):
|
||||
name: str
|
||||
raw_data_location: str
|
||||
metadata_id: UUID
|
||||
external_metadata: Optional[str]
|
||||
mime_type: str
|
||||
_metadata: dict = {"index_fields": ["name"], "type": "Document"}
|
||||
|
||||
def read(self, chunk_size: int, chunker=str, max_tokens: Optional[int] = None) -> str:
|
||||
def read(self, chunk_size: int, chunker=str, max_chunk_tokens: Optional[int] = None) -> str:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -13,13 +13,13 @@ class ImageDocument(Document):
|
|||
result = get_llm_client().transcribe_image(self.raw_data_location)
|
||||
return result.choices[0].message.content
|
||||
|
||||
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
|
||||
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
|
||||
# Transcribe the image file
|
||||
text = self.transcribe_image()
|
||||
|
||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||
chunker = chunker_func(
|
||||
self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
|
||||
self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
|
||||
)
|
||||
|
||||
yield from chunker.read()
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from .Document import Document
|
|||
class PdfDocument(Document):
|
||||
type: str = "pdf"
|
||||
|
||||
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
|
||||
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
|
||||
file = PdfReader(self.raw_data_location)
|
||||
|
||||
def get_text():
|
||||
|
|
@ -19,7 +19,7 @@ class PdfDocument(Document):
|
|||
|
||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||
chunker = chunker_func(
|
||||
self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
|
||||
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
|
||||
)
|
||||
|
||||
yield from chunker.read()
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from .Document import Document
|
|||
class TextDocument(Document):
|
||||
type: str = "text"
|
||||
|
||||
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
|
||||
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
|
||||
def get_text():
|
||||
with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
|
||||
while True:
|
||||
|
|
@ -21,7 +21,7 @@ class TextDocument(Document):
|
|||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||
|
||||
chunker = chunker_func(
|
||||
self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
|
||||
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
|
||||
)
|
||||
|
||||
yield from chunker.read()
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from .Document import Document
|
|||
class UnstructuredDocument(Document):
|
||||
type: str = "unstructured"
|
||||
|
||||
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None) -> str:
|
||||
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int) -> str:
|
||||
def get_text():
|
||||
try:
|
||||
from unstructured.partition.auto import partition
|
||||
|
|
@ -29,6 +29,8 @@ class UnstructuredDocument(Document):
|
|||
|
||||
yield text
|
||||
|
||||
chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens)
|
||||
chunker = TextChunker(
|
||||
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
|
||||
)
|
||||
|
||||
yield from chunker.read()
|
||||
|
|
|
|||
|
|
@ -113,8 +113,10 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
|
||||
except (ValueError, TypeError) as e:
|
||||
print(f"Error projecting graph: {e}")
|
||||
raise e
|
||||
except Exception as ex:
|
||||
print(f"Unexpected error: {ex}")
|
||||
raise ex
|
||||
|
||||
async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
|
||||
for category, scored_results in node_distances.items():
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ async def get_graph_from_model(
|
|||
added_nodes: dict,
|
||||
added_edges: dict,
|
||||
visited_properties: dict = None,
|
||||
only_root=False,
|
||||
include_root=True,
|
||||
):
|
||||
if str(data_point.id) in added_nodes:
|
||||
|
|
@ -98,7 +97,7 @@ async def get_graph_from_model(
|
|||
)
|
||||
added_edges[str(edge_key)] = True
|
||||
|
||||
if str(field_value.id) in added_nodes or only_root:
|
||||
if str(field_value.id) in added_nodes:
|
||||
continue
|
||||
|
||||
property_nodes, property_edges = await get_graph_from_model(
|
||||
|
|
|
|||
|
|
@ -62,6 +62,8 @@ async def code_description_to_code_part(
|
|||
"Search initiated by user %s with query: '%s' and top_k: %d", user.id, query, top_k
|
||||
)
|
||||
|
||||
context_from_documents = ""
|
||||
|
||||
try:
|
||||
if include_docs:
|
||||
search_results = await search(SearchType.INSIGHTS, query_text=query)
|
||||
|
|
@ -131,14 +133,7 @@ async def code_description_to_code_part(
|
|||
len(code_pieces_to_return),
|
||||
)
|
||||
|
||||
context = ""
|
||||
for code_piece in code_pieces_to_return:
|
||||
context = context + code_piece.get_attribute("source_code")
|
||||
|
||||
if include_docs:
|
||||
context = context_from_documents + context
|
||||
|
||||
return context
|
||||
return code_pieces_to_return, context_from_documents
|
||||
|
||||
except Exception as exec_error:
|
||||
logging.error(
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ class ModelName(Enum):
|
|||
openai = "openai"
|
||||
ollama = "ollama"
|
||||
anthropic = "anthropic"
|
||||
gemini = "gemini"
|
||||
|
||||
|
||||
class LLMConfig(BaseModel):
|
||||
|
|
@ -72,6 +73,10 @@ def get_settings() -> SettingsDict:
|
|||
"value": "anthropic",
|
||||
"label": "Anthropic",
|
||||
},
|
||||
{
|
||||
"value": "gemini",
|
||||
"label": "Gemini",
|
||||
},
|
||||
]
|
||||
|
||||
return SettingsDict.model_validate(
|
||||
|
|
@ -136,6 +141,12 @@ def get_settings() -> SettingsDict:
|
|||
"label": "Claude 3 Haiku",
|
||||
},
|
||||
],
|
||||
"gemini": [
|
||||
{
|
||||
"value": "gemini-2.0-flash-exp",
|
||||
"label": "Gemini 2.0 Flash",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
vector_db={
|
||||
|
|
|
|||
|
|
@ -10,10 +10,6 @@ import graphistry
|
|||
import networkx as nx
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import tiktoken
|
||||
import nltk
|
||||
import base64
|
||||
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
|
@ -23,13 +19,40 @@ from cognee.infrastructure.databases.graph import get_graph_engine
|
|||
|
||||
from uuid import uuid4
|
||||
import pathlib
|
||||
|
||||
import nltk
|
||||
from cognee.shared.exceptions import IngestionError
|
||||
|
||||
# Analytics Proxy Url, currently hosted by Vercel
|
||||
proxy_url = "https://test.prometh.ai"
|
||||
|
||||
|
||||
def get_entities(tagged_tokens):
    """Run NLTK's named-entity chunker over POS-tagged tokens.

    The "maxent_ne_chunker" resource is requested via ``nltk.download``
    before the chunker is imported and applied.
    """
    nltk.download("maxent_ne_chunker", quiet=True)
    from nltk.chunk import ne_chunk

    entity_tree = ne_chunk(tagged_tokens)
    return entity_tree
|
||||
|
||||
|
||||
def extract_pos_tags(sentence):
    """Return Part-of-Speech (POS) tags for the words of a sentence."""

    # Request every NLTK resource the tokenizer and tagger rely on.
    for resource in ("words", "punkt", "averaged_perceptron_tagger"):
        nltk.download(resource, quiet=True)

    from nltk.tag import pos_tag
    from nltk.tokenize import word_tokenize

    # Split into word tokens, then label each token with its POS tag.
    return pos_tag(word_tokenize(sentence))
|
||||
|
||||
|
||||
def get_anonymous_id():
|
||||
"""Creates or reads a anonymous user id"""
|
||||
home_dir = str(pathlib.Path(pathlib.Path(__file__).parent.parent.parent.resolve()))
|
||||
|
|
@ -75,15 +98,6 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
|
|||
print(f"Error sending telemetry through proxy: {response.status_code}")
|
||||
|
||||
|
||||
def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Return the number of tokens in a text string.

    Args:
        string: Text to tokenize.
        encoding_name: Either a model name known to tiktoken or a tiktoken
            encoding name such as "cl100k_base".

    Returns:
        The length of the token sequence produced for ``string``.
    """
    try:
        encoding = tiktoken.encoding_for_model(encoding_name)
    except KeyError:
        # Not a known model name. The parameter is called `encoding_name`,
        # so also accept an actual encoding name instead of failing.
        encoding = tiktoken.get_encoding(encoding_name)
    return len(encoding.encode(string))
|
||||
|
||||
|
||||
def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str:
|
||||
h = hashlib.md5()
|
||||
|
||||
|
|
@ -109,34 +123,6 @@ def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str:
|
|||
raise IngestionError(message=f"Failed to load data from {file}: {e}")
|
||||
|
||||
|
||||
def trim_text_to_max_tokens(text: str, max_tokens: int, encoding_name: str) -> str:
    """
    Trims the text so that the number of tokens does not exceed max_tokens.

    Args:
        text (str): Original text string to be trimmed.
        max_tokens (int): Maximum number of tokens allowed.
        encoding_name (str): A model name known to tiktoken or the name of a
            tiktoken encoding.

    Returns:
        str: Trimmed version of text or original text if under the limit.
    """
    # Resolve the tokenizer once and use it for both counting and trimming.
    # Previously the count went through encoding_for_model() and the trim
    # through get_encoding(), so no single name was valid for both steps.
    try:
        encoding = tiktoken.encoding_for_model(encoding_name)
    except KeyError:
        encoding = tiktoken.get_encoding(encoding_name)

    encoded_text = encoding.encode(text)

    # If the number of tokens is within the limit, return the text as is
    if len(encoded_text) <= max_tokens:
        return text

    # Simple token-level cut; it may split a word in the middle.
    return encoding.decode(encoded_text[:max_tokens])
|
||||
|
||||
|
||||
def generate_color_palette(unique_layers):
|
||||
colormap = plt.cm.get_cmap("viridis", len(unique_layers))
|
||||
colors = [colormap(i) for i in range(len(unique_layers))]
|
||||
|
|
@ -243,33 +229,6 @@ async def render_graph(
|
|||
# return df.replace([np.inf, -np.inf, np.nan], None)
|
||||
|
||||
|
||||
def get_entities(tagged_tokens):
    """Run NLTK's named-entity chunker over POS-tagged tokens.

    The "maxent_ne_chunker" resource is requested via ``nltk.download``
    before the chunker is imported and applied.
    """
    nltk.download("maxent_ne_chunker", quiet=True)
    from nltk.chunk import ne_chunk

    entity_tree = ne_chunk(tagged_tokens)
    return entity_tree
|
||||
|
||||
|
||||
def extract_pos_tags(sentence):
    """Return Part-of-Speech (POS) tags for the words of a sentence."""

    # Request every NLTK resource the tokenizer and tagger rely on.
    for resource in ("words", "punkt", "averaged_perceptron_tagger"):
        nltk.download(resource, quiet=True)

    from nltk.tag import pos_tag
    from nltk.tokenize import word_tokenize

    # Split into word tokens, then label each token with its POS tag.
    return pos_tag(word_tokenize(sentence))
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
|
|
@ -396,6 +355,7 @@ async def create_cognee_style_network_with_logo(
|
|||
|
||||
from bokeh.embed import file_html
|
||||
from bokeh.resources import CDN
|
||||
from bokeh.io import export_png
|
||||
|
||||
logging.info("Converting graph to serializable format...")
|
||||
G = await convert_to_serializable_graph(G)
|
||||
|
|
@ -445,13 +405,14 @@ async def create_cognee_style_network_with_logo(
|
|||
|
||||
logging.info(f"Saving visualization to {output_filename}...")
|
||||
html_content = file_html(p, CDN, title)
|
||||
with open(output_filename, "w") as f:
|
||||
|
||||
home_dir = os.path.expanduser("~")
|
||||
|
||||
# Construct the final output file path
|
||||
output_filepath = os.path.join(home_dir, output_filename)
|
||||
with open(output_filepath, "w") as f:
|
||||
f.write(html_content)
|
||||
|
||||
logging.info("Visualization complete.")
|
||||
|
||||
if bokeh_object:
|
||||
return p
|
||||
return html_content
|
||||
|
||||
|
||||
|
|
@ -512,7 +473,7 @@ if __name__ == "__main__":
|
|||
G,
|
||||
output_filename="example_network.html",
|
||||
title="Example Cognee Network",
|
||||
node_attribute="group", # Attribute to use for coloring nodes
|
||||
label="group", # Attribute to use for coloring nodes
|
||||
layout_func=nx.spring_layout, # Layout function
|
||||
layout_scale=3.0, # Scale for the layout
|
||||
logo_alpha=0.2,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,29 @@ from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
|||
|
||||
async def chunk_naive_llm_classifier(
|
||||
data_chunks: list[DocumentChunk], classification_model: Type[BaseModel]
|
||||
):
|
||||
) -> list[DocumentChunk]:
|
||||
"""
|
||||
Classifies a list of document chunks using a specified classification model and updates vector and graph databases with the classification results.
|
||||
|
||||
Vector Database Structure:
|
||||
- Collection Name: `classification`
|
||||
- Payload Schema:
|
||||
- uuid (str): Unique identifier for the classification.
|
||||
- text (str): Text label of the classification.
|
||||
- chunk_id (str): Identifier of the chunk associated with this classification.
|
||||
- document_id (str): Identifier of the document associated with this classification.
|
||||
|
||||
Graph Database Structure:
|
||||
- Nodes:
|
||||
- Represent document chunks, classification types, and classification subtypes.
|
||||
- Edges:
|
||||
- `is_media_type`: Links document chunks to their classification type.
|
||||
- `is_subtype_of`: Links classification subtypes to their parent type.
|
||||
- `is_classified_as`: Links document chunks to their classification subtypes.
|
||||
Notes:
|
||||
- The function assumes that vector and graph database engines (`get_vector_engine` and `get_graph_engine`) are properly initialized and accessible.
|
||||
- Classification labels are processed to ensure uniqueness using UUIDs based on their values.
|
||||
"""
|
||||
if len(data_chunks) == 0:
|
||||
return data_chunks
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
from typing import Any, Dict, Iterator, Optional, Union
|
||||
from typing import Any, Dict, Iterator
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
|
||||
import tiktoken
|
||||
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
from .chunk_by_sentence import chunk_by_sentence
|
||||
|
|
@ -10,13 +8,19 @@ from .chunk_by_sentence import chunk_by_sentence
|
|||
|
||||
def chunk_by_paragraph(
|
||||
data: str,
|
||||
max_tokens: Optional[Union[int, float]] = None,
|
||||
max_chunk_tokens,
|
||||
paragraph_length: int = 1024,
|
||||
batch_paragraphs: bool = True,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
"""
|
||||
Chunks text by paragraph while preserving exact text reconstruction capability.
|
||||
When chunks are joined with empty string "", they reproduce the original text exactly.
|
||||
|
||||
Notes:
|
||||
- Tokenization is handled using our tokenization adapters, ensuring compatibility with the vector engine's embedding model.
|
||||
- If `batch_paragraphs` is False, each paragraph will be yielded as a separate chunk.
|
||||
- Handles cases where paragraphs exceed the specified token or word limits by splitting them as needed.
|
||||
- Remaining text at the end of the input will be yielded as a final chunk.
|
||||
"""
|
||||
current_chunk = ""
|
||||
current_word_count = 0
|
||||
|
|
@ -24,24 +28,17 @@ def chunk_by_paragraph(
|
|||
paragraph_ids = []
|
||||
last_cut_type = None
|
||||
current_token_count = 0
|
||||
if not max_tokens:
|
||||
max_tokens = float("inf")
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
embedding_model = vector_engine.embedding_engine.model
|
||||
embedding_model = embedding_model.split("/")[-1]
|
||||
|
||||
for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(
|
||||
data, maximum_length=paragraph_length
|
||||
):
|
||||
# Check if this sentence would exceed length limit
|
||||
|
||||
tokenizer = tiktoken.encoding_for_model(embedding_model)
|
||||
token_count = len(tokenizer.encode(sentence))
|
||||
embedding_engine = get_vector_engine().embedding_engine
|
||||
token_count = embedding_engine.tokenizer.count_tokens(sentence)
|
||||
|
||||
if current_word_count > 0 and (
|
||||
current_word_count + word_count > paragraph_length
|
||||
or current_token_count + token_count > max_tokens
|
||||
or current_token_count + token_count > max_chunk_tokens
|
||||
):
|
||||
# Yield current chunk
|
||||
chunk_dict = {
|
||||
|
|
|
|||
|
|
@ -1,9 +1,19 @@
|
|||
from uuid import uuid4
|
||||
from typing import Optional
|
||||
from uuid import uuid4, UUID
|
||||
from typing import Optional, Iterator, Tuple
|
||||
from .chunk_by_word import chunk_by_word
|
||||
|
||||
|
||||
def chunk_by_sentence(data: str, maximum_length: Optional[int] = None):
|
||||
def chunk_by_sentence(
|
||||
data: str, maximum_length: Optional[int] = None
|
||||
) -> Iterator[Tuple[UUID, str, int, Optional[str]]]:
|
||||
"""
|
||||
Splits the input text into sentences based on word-level processing, with optional sentence length constraints.
|
||||
|
||||
Notes:
|
||||
- Relies on the `chunk_by_word` function for word-level tokenization and classification.
|
||||
- Ensures sentences within paragraphs are uniquely identifiable using UUIDs.
|
||||
- Handles cases where the text ends mid-sentence by appending a special "sentence_cut" type.
|
||||
"""
|
||||
sentence = ""
|
||||
paragraph_id = uuid4()
|
||||
word_count = 0
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue