Merge remote-tracking branch 'origin/dev' into feat/modal-parallelization

This commit is contained in:
Boris Arzentar 2025-04-30 15:40:38 +02:00
commit 44cc74994f
55 changed files with 12965 additions and 4198 deletions

View file

@ -24,4 +24,4 @@ runs:
- name: Install dependencies - name: Install dependencies
shell: bash shell: bash
run: poetry install --no-interaction -E api -E docs -E evals -E gemini -E codegraph -E ollama run: poetry install --no-interaction -E api -E docs -E evals -E gemini -E codegraph -E ollama -E dev

View file

@ -58,8 +58,10 @@ jobs:
python-version: ${{ inputs.python-version }} python-version: ${{ inputs.python-version }}
- name: Run unit tests - name: Run unit tests
shell: bash
run: poetry run pytest cognee/tests/unit/ run: poetry run pytest cognee/tests/unit/
env: env:
PYTHONUTF8: 1
LLM_PROVIDER: openai LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
@ -74,10 +76,26 @@ jobs:
- name: Run integration tests - name: Run integration tests
if: ${{ !contains(matrix.os, 'windows') }} if: ${{ !contains(matrix.os, 'windows') }}
shell: bash
run: poetry run pytest cognee/tests/integration/ run: poetry run pytest cognee/tests/integration/
env:
PYTHONUTF8: 1
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
- name: Run default basic pipeline - name: Run default basic pipeline
shell: bash
env: env:
PYTHONUTF8: 1
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }} GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }} GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
@ -95,6 +113,7 @@ jobs:
run: poetry run python ./cognee/tests/test_library.py run: poetry run python ./cognee/tests/test_library.py
- name: Build with Poetry - name: Build with Poetry
shell: bash
run: poetry build run: poetry build
- name: Install Package - name: Install Package

View file

@ -29,6 +29,7 @@ RUN apt-get update
RUN apt-get install -y \ RUN apt-get install -y \
gcc \ gcc \
build-essential \
libpq-dev libpq-dev
WORKDIR /app WORKDIR /app
@ -40,7 +41,7 @@ RUN pip install poetry
RUN poetry config virtualenvs.create false RUN poetry config virtualenvs.create false
# Install the dependencies using the defined extras # Install the dependencies using the defined extras
RUN poetry install --extras "${POETRY_EXTRAS}" --no-root --without dev RUN poetry install --extras "${POETRY_EXTRAS}" --no-root
# Set the PYTHONPATH environment variable to include the /app directory # Set the PYTHONPATH environment variable to include the /app directory
ENV PYTHONPATH=/app ENV PYTHONPATH=/app

View file

@ -32,6 +32,14 @@ Build dynamic Agent memory using scalable, modular ECL (Extract, Cognify, Load)
More on [use-cases](https://docs.cognee.ai/use-cases) and [evals](https://github.com/topoteretes/cognee/tree/main/evals) More on [use-cases](https://docs.cognee.ai/use-cases) and [evals](https://github.com/topoteretes/cognee/tree/main/evals)
<p align="center">
🌐 Available Languages
:
<a href="community/README.pt.md">🇵🇹 Português</a>
·
<a href="community/README.zh.md">🇨🇳 [中文]</a>
</p>
<div style="text-align: center"> <div style="text-align: center">
<img src="https://raw.githubusercontent.com/topoteretes/cognee/refs/heads/main/assets/cognee_benefits.png" alt="Why cognee?" width="50%" /> <img src="https://raw.githubusercontent.com/topoteretes/cognee/refs/heads/main/assets/cognee_benefits.png" alt="Why cognee?" width="50%" />
</div> </div>
@ -50,7 +58,7 @@ More on [use-cases](https://docs.cognee.ai/use-cases) and [evals](https://github
## Get Started ## Get Started
Get started quickly with a Google Colab <a href="https://colab.research.google.com/drive/1g-Qnx6l_ecHZi0IOw23rg0qC4TYvEvWZ?usp=sharing">notebook</a> or <a href="https://github.com/topoteretes/cognee-starter">starter repo</a> Get started quickly with a Google Colab <a href="https://colab.research.google.com/drive/1jHbWVypDgCLwjE71GSXhRL3YxYhCZzG1?usp=sharing">notebook</a> or <a href="https://github.com/topoteretes/cognee-starter">starter repo</a>
## Contributing ## Contributing
Your contributions are at the core of making this a true open source project. Any contributions you make are **greatly appreciated**. See [`CONTRIBUTING.md`](CONTRIBUTING.md) for more information. Your contributions are at the core of making this a true open source project. Any contributions you make are **greatly appreciated**. See [`CONTRIBUTING.md`](CONTRIBUTING.md) for more information.
@ -116,12 +124,14 @@ Example output:
Natural Language Processing (NLP) is a cross-disciplinary and interdisciplinary field that involves computer science and information retrieval. It focuses on the interaction between computers and human language, enabling machines to understand and process natural language. Natural Language Processing (NLP) is a cross-disciplinary and interdisciplinary field that involves computer science and information retrieval. It focuses on the interaction between computers and human language, enabling machines to understand and process natural language.
``` ```
Graph visualization:
<a href="https://rawcdn.githack.com/topoteretes/cognee/refs/heads/add-visualization-readme/assets/graph_visualization.html"><img src="assets/graph_visualization.png" width="100%" alt="Graph Visualization"></a>
Open in [browser](https://rawcdn.githack.com/topoteretes/cognee/refs/heads/add-visualization-readme/assets/graph_visualization.html).
For more advanced usage, have a look at our <a href="https://docs.cognee.ai"> documentation</a>. ### cognee UI
You can also cognify your files and query using cognee UI.
<img src="assets/cognee-ui-2.webp" width="100%" alt="Cognee UI 2">
Try cognee UI out locally [here](https://docs.cognee.ai/how-to-guides/cognee-ui).
## Understand our architecture ## Understand our architecture

BIN
assets/cognee-ui-1.webp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 878 KiB

BIN
assets/cognee-ui-2.webp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 936 KiB

View file

@ -1,128 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<script src="https://d3js.org/d3.v5.min.js"></script>
<style>
body, html { margin: 0; padding: 0; width: 100%; height: 100%; overflow: hidden; background: linear-gradient(90deg, #101010, #1a1a2e); color: white; font-family: 'Inter', sans-serif; }
svg { width: 100vw; height: 100vh; display: block; }
.links line { stroke: rgba(255, 255, 255, 0.4); stroke-width: 2px; }
.nodes circle { stroke: white; stroke-width: 0.5px; filter: drop-shadow(0 0 5px rgba(255,255,255,0.3)); }
.node-label { font-size: 5px; font-weight: bold; fill: white; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
.edge-label { font-size: 3px; fill: rgba(255, 255, 255, 0.7); text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
</style>
</head>
<body>
<svg></svg>
<script>
var nodes = [{"version": 1, "topological_rank": 0, "metadata": {"index_fields": ["text"]}, "type": "DocumentChunk", "text": "Natural language processing (NLP) is an interdisciplinary subfield of computer science and information retrieval.", "chunk_size": 34, "chunk_index": 0, "cut_type": "sentence_end", "id": "b5b7b6b3-3bb7-5efd-a975-5a01e0d40220", "color": "#801212", "name": "b5b7b6b3-3bb7-5efd-a975-5a01e0d40220"}, {"version": 1, "topological_rank": 0, "metadata": {"index_fields": ["name"]}, "type": "Entity", "name": "natural language processing", "description": "An interdisciplinary subfield of computer science and information retrieval.", "ontology_valid": false, "id": "bc338a39-64d6-549a-acec-da60846dd90d", "color": "#f47710"}, {"version": 1, "topological_rank": 0, "metadata": {"index_fields": ["name"]}, "type": "EntityType", "name": "concept", "description": "concept", "ontology_valid": false, "id": "dd9713b7-dc20-5101-aad0-1c4216811147", "color": "#6510f4"}, {"version": 1, "topological_rank": 0, "metadata": {"index_fields": ["name"]}, "type": "Entity", "name": "information retrieval", "description": "The activity of obtaining information system resources that are relevant to an information need.", "ontology_valid": false, "id": "02bdab9a-0981-518c-a0d4-1684e0329447", "color": "#f47710"}, {"version": 1, "topological_rank": 0, "metadata": {"index_fields": ["name"]}, "type": "EntityType", "name": "field", "description": "field", "ontology_valid": false, "id": "0198571b-3e94-50ea-8b9f-19e3a31080c0", "color": "#6510f4"}, {"version": 1, "topological_rank": 0, "metadata": {"index_fields": ["name"]}, "type": "Entity", "name": "computer science", "description": "The study of computers and computational systems.", "ontology_valid": false, "id": "6218dbab-eb6a-5759-a864-b3419755ffe0", "color": "#f47710"}, {"version": 1, "topological_rank": 0, "metadata": {"index_fields": ["name"]}, "type": "TextDocument", "name": "text_46d2fce36f0f7b6ebc0575e353fdba5c", 
"raw_data_location": "/Users/handekafkas/Documents/local-code/new-cognee/cognee/cognee/.data_storage/data/text_46d2fce36f0f7b6ebc0575e353fdba5c.txt", "external_metadata": "{}", "mime_type": "text/plain", "id": "c07949fe-5a9f-53b9-ac90-5cb48a8a4303", "color": "#D3D3D3"}, {"version": 1, "topological_rank": 0, "metadata": {"index_fields": ["text"]}, "type": "TextSummary", "text": "Natural language processing (NLP) is a cross-disciplinary area of computer science and information extraction.", "id": "9da41e72-8150-5055-9217-eea49d1bc447", "color": "#1077f4", "name": "9da41e72-8150-5055-9217-eea49d1bc447"}];
var links = [{"source": "b5b7b6b3-3bb7-5efd-a975-5a01e0d40220", "target": "bc338a39-64d6-549a-acec-da60846dd90d", "relation": "contains"}, {"source": "b5b7b6b3-3bb7-5efd-a975-5a01e0d40220", "target": "02bdab9a-0981-518c-a0d4-1684e0329447", "relation": "contains"}, {"source": "b5b7b6b3-3bb7-5efd-a975-5a01e0d40220", "target": "6218dbab-eb6a-5759-a864-b3419755ffe0", "relation": "contains"}, {"source": "b5b7b6b3-3bb7-5efd-a975-5a01e0d40220", "target": "c07949fe-5a9f-53b9-ac90-5cb48a8a4303", "relation": "is_part_of"}, {"source": "bc338a39-64d6-549a-acec-da60846dd90d", "target": "dd9713b7-dc20-5101-aad0-1c4216811147", "relation": "is_a"}, {"source": "bc338a39-64d6-549a-acec-da60846dd90d", "target": "6218dbab-eb6a-5759-a864-b3419755ffe0", "relation": "is_a_subfield_of"}, {"source": "bc338a39-64d6-549a-acec-da60846dd90d", "target": "02bdab9a-0981-518c-a0d4-1684e0329447", "relation": "is_a_subfield_of"}, {"source": "02bdab9a-0981-518c-a0d4-1684e0329447", "target": "0198571b-3e94-50ea-8b9f-19e3a31080c0", "relation": "is_a"}, {"source": "6218dbab-eb6a-5759-a864-b3419755ffe0", "target": "0198571b-3e94-50ea-8b9f-19e3a31080c0", "relation": "is_a"}, {"source": "9da41e72-8150-5055-9217-eea49d1bc447", "target": "b5b7b6b3-3bb7-5efd-a975-5a01e0d40220", "relation": "made_from"}];
// Root <svg> element plus the viewport size used to centre the layout.
// `width` and `height` are reassigned by the window "resize" listener.
var svg = d3.select("svg"),
width = window.innerWidth,
height = window.innerHeight;
// All graph elements are appended to this <g>; the zoom handler transforms it as a whole.
var container = svg.append("g");
// Force-directed layout over the `nodes` and `links` arrays declared above.
var simulation = d3.forceSimulation(nodes)
// Links reference their endpoints by node `id`.
.force("link", d3.forceLink(links).id(d => d.id).strength(0.1))
// Negative many-body strength, i.e. nodes repel each other.
.force("charge", d3.forceManyBody().strength(-275))
// Centering force and x/y positioning forces, all targeting the viewport midpoint.
.force("center", d3.forceCenter(width / 2, height / 2))
.force("x", d3.forceX().strength(0.1).x(width / 2))
.force("y", d3.forceY().strength(0.1).y(height / 2));
// One <line> per entry in `links`, grouped under <g class="links">.
var link = container.append("g")
.attr("class", "links")
.selectAll("line")
.data(links)
.enter().append("line")
.attr("stroke-width", 2);
// One <text> per link showing its `relation`; positioned in the tick handler.
var edgeLabels = container.append("g")
.attr("class", "edge-labels")
.selectAll("text")
.data(links)
.enter().append("text")
.attr("class", "edge-label")
.text(d => d.relation);
// One <g> per node; each group receives a circle and a text label below.
var nodeGroup = container.append("g")
.attr("class", "nodes")
.selectAll("g")
.data(nodes)
.enter().append("g");
// Node circle, filled with the node's own `color` and wired to the drag handlers.
var node = nodeGroup.append("circle")
.attr("r", 13)
.attr("fill", d => d.color)
.call(d3.drag()
.on("start", dragstarted)
.on("drag", dragged)
.on("end", dragended));
// Node label showing the node's `name`.
nodeGroup.append("text")
.attr("class", "node-label")
.attr("dy", 4)
.attr("text-anchor", "middle")
.text(d => d.name);
// <title> child of each circle carrying the full node object serialised as JSON.
node.append("title").text(d => JSON.stringify(d));
// On every simulation tick, copy the current layout coordinates onto the SVG elements.
simulation.on("tick", function() {
  // Each edge is drawn between the current positions of its two endpoints.
  link
    .attr("x1", function(edge) { return edge.source.x; })
    .attr("y1", function(edge) { return edge.source.y; })
    .attr("x2", function(edge) { return edge.target.x; })
    .attr("y2", function(edge) { return edge.target.y; });

  // Edge labels sit at the midpoint of the edge, shifted 5 units up.
  edgeLabels
    .attr("x", function(edge) { return (edge.source.x + edge.target.x) / 2; })
    .attr("y", function(edge) { return (edge.source.y + edge.target.y) / 2 - 5; });

  // Circles follow their node.
  node
    .attr("cx", function(vertex) { return vertex.x; })
    .attr("cy", function(vertex) { return vertex.y; });

  // Labels follow their node; dy and text-anchor are re-applied on every tick.
  nodeGroup.select("text")
    .attr("x", function(vertex) { return vertex.x; })
    .attr("y", function(vertex) { return vertex.y; })
    .attr("dy", 4)
    .attr("text-anchor", "middle");
});
// Attach a d3 zoom behaviour to the root <svg>: on each "zoom" event the current
// transform (read from d3.event) is applied to the container group holding the graph.
svg.call(d3.zoom().on("zoom", function() {
container.attr("transform", d3.event.transform);
}));
// Drag "start" handler: restarts the simulation with an alpha target of 0.3 unless
// d3.event.active is truthy, then sets the node's fx/fy to its current position.
function dragstarted(d) {
  var alreadyActive = d3.event.active;
  if (!alreadyActive) {
    simulation.alphaTarget(0.3).restart();
  }
  d.fx = d.x;
  d.fy = d.y;
}
// Drag "drag" handler: sets the node's fx/fy to the coordinates of the current drag event.
function dragged(d) {
  var dragEvent = d3.event;
  d.fx = dragEvent.x;
  d.fy = dragEvent.y;
}
// Drag "end" handler: resets the simulation's alpha target to 0 unless
// d3.event.active is truthy, then clears the node's fx/fy.
function dragended(d) {
  var stillActive = d3.event.active;
  if (!stillActive) {
    simulation.alphaTarget(0);
  }
  d.fx = null;
  d.fy = null;
}
// Keep the layout centred when the browser window is resized.
window.addEventListener("resize", function() {
  width = window.innerWidth;
  height = window.innerHeight;
  svg.attr("width", width).attr("height", height);

  var centerX = width / 2;
  var centerY = height / 2;

  // Re-target every force that was configured with the viewport midpoint.
  // Previously only "center" was updated, so the "x" and "y" positioning
  // forces kept pulling nodes toward the old midpoint after a resize.
  simulation.force("center", d3.forceCenter(centerX, centerY));
  simulation.force("x").x(centerX);
  simulation.force("y").y(centerY);

  // Reheat the simulation so the nodes settle around the new centre.
  simulation.alpha(1).restart();
});
</script>
<svg style="position: fixed; bottom: 10px; right: 10px; width: 150px; height: auto; z-index: 9999;" viewBox="0 0 158 44" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M11.7496 4.92654C7.83308 4.92654 4.8585 7.94279 4.8585 11.3612V14.9304C4.8585 18.3488 7.83308 21.3651 11.7496 21.3651C13.6831 21.3651 15.0217 20.8121 16.9551 19.3543C18.0458 18.5499 19.5331 18.8013 20.3263 19.9072C21.1195 21.0132 20.8717 22.5213 19.781 23.3257C17.3518 25.0851 15.0217 26.2414 11.7 26.2414C5.35425 26.2414 0 21.2646 0 14.9304V11.3612C0 4.97681 5.35425 0.0502739 11.7 0.0502739C15.0217 0.0502739 17.3518 1.2065 19.781 2.96598C20.8717 3.77032 21.1195 5.27843 20.3263 6.38439C19.5331 7.49035 18.0458 7.69144 16.9551 6.93737C15.0217 5.52979 13.6831 4.92654 11.7496 4.92654ZM35.5463 4.92654C31.7289 4.92654 28.6552 8.04333 28.6552 11.8639V14.478C28.6552 18.2986 31.7289 21.4154 35.5463 21.4154C39.3141 21.4154 42.3878 18.2986 42.3878 14.478V11.8639C42.3878 8.04333 39.3141 4.92654 35.5463 4.92654ZM23.7967 11.8639C23.7967 5.32871 29.0518 0 35.5463 0C42.0408 0 47.2463 5.32871 47.2463 11.8639V14.478C47.2463 21.0132 42.0408 26.3419 35.5463 26.3419C29.0518 26.3419 23.7967 21.0635 23.7967 14.478V11.8639ZM63.3091 5.07736C59.4917 5.07736 56.418 8.19415 56.418 12.0147C56.418 15.8353 59.4917 18.9521 63.3091 18.9521C67.1265 18.9521 70.1506 15.8856 70.1506 12.0147C70.1506 8.14388 67.0769 5.07736 63.3091 5.07736ZM51.5595 11.9645C51.5595 5.42925 56.8146 0.150814 63.3091 0.150814C66.0854 0.150814 68.5642 1.10596 70.5968 2.71463L72.4311 0.904876C73.3731 -0.0502693 74.9099 -0.0502693 75.8519 0.904876C76.7938 1.86002 76.7938 3.41841 75.8519 4.37356L73.7201 6.53521C74.5629 8.19414 75.0587 10.0542 75.0587 12.0147C75.0587 18.4997 69.8532 23.8284 63.3587 23.8284C63.3091 23.8284 63.2099 23.8284 63.1603 23.8284H58.0044C57.1616 23.8284 56.4675 24.5322 56.4675 25.3868C56.4675 26.2414 57.1616 26.9452 58.0044 26.9452H64.6476H66.7794C68.5146 26.9452 70.3489 27.4479 71.7866 28.6041C73.2739 29.8106 74.2159 31.5701 74.4142 33.7317C74.7116 37.6026 72.0345 40.2166 69.8532 41.0713L63.8048 43.7859C62.5654 44.3389 61.1277 43.7859 60.6319 42.5291C60.0866 
41.2723 60.6319 39.8648 61.8714 39.3118L68.0188 36.5972C68.0684 36.5972 68.118 36.5469 68.1675 36.5469C68.4154 36.4463 68.8616 36.1447 69.2087 35.6923C69.5061 35.2398 69.7044 34.7371 69.6548 34.1339C69.6053 33.229 69.2582 32.7263 68.8616 32.4247C68.4154 32.0728 67.7214 31.8214 66.8786 31.8214H58.2027C58.1531 31.8214 58.1531 31.8214 58.1035 31.8214H58.054C54.534 31.8214 51.6586 28.956 51.6586 25.3868C51.6586 23.0743 52.8485 21.0635 54.6828 19.9072C52.6997 17.7959 51.5595 15.031 51.5595 11.9645ZM90.8736 5.07736C87.0562 5.07736 83.9824 8.19415 83.9824 12.0147V23.9289C83.9824 25.2862 82.8917 26.3922 81.5532 26.3922C80.2146 26.3922 79.1239 25.2862 79.1239 23.9289V11.9645C79.1239 5.42925 84.379 0.150814 90.824 0.150814C97.2689 0.150814 102.524 5.42925 102.524 11.9645V23.8786C102.524 25.2359 101.433 26.3419 100.095 26.3419C98.7562 26.3419 97.6655 25.2359 97.6655 23.8786V11.9645C97.7647 8.14387 94.6414 5.07736 90.8736 5.07736ZM119.43 5.07736C115.513 5.07736 112.39 8.24441 112.39 12.065V14.5785C112.39 18.4494 115.513 21.5662 119.43 21.5662C120.768 21.5662 122.057 21.164 123.098 20.5105C124.238 19.8067 125.726 20.1586 126.42 21.3148C127.114 22.4711 126.767 23.9792 125.627 24.683C123.842 25.7889 121.71 26.4425 119.43 26.4425C112.885 26.4425 107.581 21.1137 107.581 14.5785V12.065C107.581 5.47952 112.935 0.201088 119.43 0.201088C125.032 0.201088 129.692 4.07194 130.931 9.3001L131.427 11.3612L121.115 15.584C119.876 16.0867 118.488 15.4834 117.942 14.2266C117.447 12.9699 118.041 11.5623 119.281 11.0596L125.478 8.54604C124.238 6.43466 122.008 5.07736 119.43 5.07736ZM146.003 5.07736C142.086 5.07736 138.963 8.24441 138.963 12.065V14.5785C138.963 18.4494 142.086 21.5662 146.003 21.5662C147.341 21.5662 148.63 21.164 149.671 20.5105C150.217 20.1586 150.663 19.8067 151.109 19.304C152.001 18.2986 153.538 18.2483 154.53 19.2034C155.521 20.1083 155.571 21.6667 154.629 22.6721C153.935 23.4262 153.092 24.13 152.2 24.683C150.415 25.7889 148.283 26.4425 146.003 26.4425C139.458 26.4425 134.154 
21.1137 134.154 14.5785V12.065C134.154 5.47952 139.508 0.201088 146.003 0.201088C151.605 0.201088 156.265 4.07194 157.504 9.3001L158 11.3612L147.688 15.584C146.449 16.0867 145.061 15.4834 144.515 14.2266C144.019 12.9699 144.614 11.5623 145.854 11.0596L152.051 8.54604C150.762 6.43466 148.58 5.07736 146.003 5.07736Z" fill="white"/>
</svg>
</body>
</html>

Binary file not shown.

Before

Width:  |  Height:  |  Size: 365 KiB

View file

@ -7,6 +7,7 @@ requires-python = ">=3.10"
dependencies = [ dependencies = [
"cognee[postgres,codegraph,gemini,huggingface]==0.1.39", "cognee[postgres,codegraph,gemini,huggingface]==0.1.39",
"fastmcp>=1.0",
"mcp==1.5.0", "mcp==1.5.0",
"uv>=0.6.3", "uv>=0.6.3",
] ]

View file

@ -1,253 +1,141 @@
import asyncio
import json import json
import os import os
import sys import sys
import argparse
import cognee import cognee
import asyncio
from cognee.shared.logging_utils import get_logger, get_log_file_location from cognee.shared.logging_utils import get_logger, get_log_file_location
import importlib.util import importlib.util
from contextlib import redirect_stdout from contextlib import redirect_stdout
# from PIL import Image as PILImage
import mcp.types as types import mcp.types as types
from mcp.server import Server, NotificationOptions from mcp.server import FastMCP
from mcp.server.models import InitializationOptions
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.modules.search.types import SearchType from cognee.modules.search.types import SearchType
from cognee.shared.data_models import KnowledgeGraph from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.storage.utils import JSONEncoder from cognee.modules.storage.utils import JSONEncoder
mcp = Server("cognee") mcp = FastMCP("Cognee")
logger = get_logger() logger = get_logger()
log_file = get_log_file_location()
@mcp.list_tools() @mcp.tool()
async def list_tools() -> list[types.Tool]: async def cognify(text: str, graph_model_file: str = None, graph_model_name: str = None) -> list:
async def cognify_task(
text: str, graph_model_file: str = None, graph_model_name: str = None
) -> str:
"""Build knowledge graph from the input text"""
# NOTE: MCP uses stdout to communicate, we must redirect all output
# going to stdout ( like the print function ) to stderr.
# As cognify is an async background job the output had to be redirected again.
with redirect_stdout(sys.stderr):
logger.info("Cognify process starting.")
if graph_model_file and graph_model_name:
graph_model = load_class(graph_model_file, graph_model_name)
else:
graph_model = KnowledgeGraph
await cognee.add(text)
try:
await cognee.cognify(graph_model=graph_model)
logger.info("Cognify process finished.")
except Exception as e:
logger.error("Cognify process failed.")
raise ValueError(f"Failed to cognify: {str(e)}")
asyncio.create_task(
cognify_task(
text=text,
graph_model_file=graph_model_file,
graph_model_name=graph_model_name,
)
)
text = (
f"Background process launched due to MCP timeout limitations.\n"
f"Average completion time is around 4 minutes.\n"
f"For current cognify status you can check the log file at: {log_file}"
)
return [ return [
types.Tool( types.TextContent(
name="cognify", type="text",
description="Cognifies text into knowledge graph", text=text,
inputSchema={ )
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "The text to cognify",
},
"graph_model_file": {
"type": "string",
"description": "The path to the graph model file (Optional)",
},
"graph_model_name": {
"type": "string",
"description": "The name of the graph model (Optional)",
},
},
"required": ["text"],
},
),
types.Tool(
name="codify",
description="Transforms codebase into knowledge graph",
inputSchema={
"type": "object",
"properties": {
"repo_path": {
"type": "string",
},
},
"required": ["repo_path"],
},
),
types.Tool(
name="search",
description="Searches for information in knowledge graph",
inputSchema={
"type": "object",
"properties": {
"search_query": {
"type": "string",
"description": "The query to search for",
},
"search_type": {
"type": "string",
"description": "The type of search to perform (e.g., INSIGHTS, CODE)",
},
},
"required": ["search_query"],
},
),
types.Tool(
name="prune",
description="Prunes knowledge graph",
inputSchema={
"type": "object",
"properties": {},
},
),
] ]
@mcp.call_tool() @mcp.tool()
async def call_tools(name: str, arguments: dict) -> list[types.TextContent]: async def codify(repo_path: str) -> list:
try: async def codify_task(repo_path: str):
# NOTE: MCP uses stdout to communicate, we must redirect all output
# going to stdout ( like the print function ) to stderr.
# As codify is an async background job the output had to be redirected again.
with redirect_stdout(sys.stderr):
logger.info("Codify process starting.")
results = []
async for result in run_code_graph_pipeline(repo_path, False):
results.append(result)
logger.info(result)
if all(results):
logger.info("Codify process finished succesfully.")
else:
logger.info("Codify process failed.")
asyncio.create_task(codify_task(repo_path))
text = (
f"Background process launched due to MCP timeout limitations.\n"
f"Average completion time is around 4 minutes.\n"
f"For current codify status you can check the log file at: {log_file}"
)
return [
types.TextContent(
type="text",
text=text,
)
]
@mcp.tool()
async def search(search_query: str, search_type: str) -> list:
async def search_task(search_query: str, search_type: str) -> str:
"""Search the knowledge graph"""
# NOTE: MCP uses stdout to communicate, we must redirect all output # NOTE: MCP uses stdout to communicate, we must redirect all output
# going to stdout ( like the print function ) to stderr. # going to stdout ( like the print function ) to stderr.
with redirect_stdout(sys.stderr): with redirect_stdout(sys.stderr):
log_file = get_log_file_location() search_results = await cognee.search(
query_type=SearchType[search_type.upper()], query_text=search_query
if name == "cognify":
asyncio.create_task(
cognify(
text=arguments["text"],
graph_model_file=arguments.get("graph_model_file"),
graph_model_name=arguments.get("graph_model_name"),
)
)
text = (
f"Background process launched due to MCP timeout limitations.\n"
f"Average completion time is around 4 minutes.\n"
f"For current cognify status you can check the log file at: {log_file}"
)
return [
types.TextContent(
type="text",
text=text,
)
]
if name == "codify":
asyncio.create_task(codify(arguments.get("repo_path")))
text = (
f"Background process launched due to MCP timeout limitations.\n"
f"Average completion time is around 4 minutes.\n"
f"For current codify status you can check the log file at: {log_file}"
)
return [
types.TextContent(
type="text",
text=text,
)
]
elif name == "search":
search_results = await search(arguments["search_query"], arguments["search_type"])
return [types.TextContent(type="text", text=search_results)]
elif name == "prune":
await prune()
return [types.TextContent(type="text", text="Pruned")]
except Exception as e:
logger.error(f"Error calling tool '{name}': {str(e)}")
return [types.TextContent(type="text", text=f"Error calling tool '{name}': {str(e)}")]
async def cognify(text: str, graph_model_file: str = None, graph_model_name: str = None) -> str:
"""Build knowledge graph from the input text"""
# NOTE: MCP uses stdout to communicate, we must redirect all output
# going to stdout ( like the print function ) to stderr.
# As cognify is an async background job the output had to be redirected again.
with redirect_stdout(sys.stderr):
logger.info("Cognify process starting.")
if graph_model_file and graph_model_name:
graph_model = load_class(graph_model_file, graph_model_name)
else:
graph_model = KnowledgeGraph
await cognee.add(text)
try:
await cognee.cognify(graph_model=graph_model)
logger.info("Cognify process finished.")
except Exception as e:
logger.error("Cognify process failed.")
raise ValueError(f"Failed to cognify: {str(e)}")
async def codify(repo_path: str):
# NOTE: MCP uses stdout to communicate, we must redirect all output
# going to stdout ( like the print function ) to stderr.
# As codify is an async background job the output had to be redirected again.
with redirect_stdout(sys.stderr):
logger.info("Codify process starting.")
results = []
async for result in run_code_graph_pipeline(repo_path, False):
results.append(result)
logger.info(result)
if all(results):
logger.info("Codify process finished succesfully.")
else:
logger.info("Codify process failed.")
async def search(search_query: str, search_type: str) -> str:
"""Search the knowledge graph"""
# NOTE: MCP uses stdout to communicate, we must redirect all output
# going to stdout ( like the print function ) to stderr.
with redirect_stdout(sys.stderr):
search_results = await cognee.search(
query_type=SearchType[search_type.upper()], query_text=search_query
)
if search_type.upper() == "CODE":
return json.dumps(search_results, cls=JSONEncoder)
elif search_type.upper() == "GRAPH_COMPLETION" or search_type.upper() == "RAG_COMPLETION":
return search_results[0]
elif search_type.upper() == "CHUNKS":
return str(search_results)
elif search_type.upper() == "INSIGHTS":
results = retrieved_edges_to_string(search_results)
return results
else:
return str(search_results)
async def prune():
"""Reset the knowledge graph"""
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
async def main():
try:
from mcp.server.stdio import stdio_server
logger.info("Cognee MCP server started...")
async with stdio_server() as (read_stream, write_stream):
await mcp.run(
read_stream=read_stream,
write_stream=write_stream,
initialization_options=InitializationOptions(
server_name="cognee",
server_version="0.1.0",
capabilities=mcp.get_capabilities(
notification_options=NotificationOptions(),
experimental_capabilities={},
),
),
raise_exceptions=True,
) )
logger.info("Cognee MCP server closed.") if search_type.upper() == "CODE":
return json.dumps(search_results, cls=JSONEncoder)
elif (
search_type.upper() == "GRAPH_COMPLETION" or search_type.upper() == "RAG_COMPLETION"
):
return search_results[0]
elif search_type.upper() == "CHUNKS":
return str(search_results)
elif search_type.upper() == "INSIGHTS":
results = retrieved_edges_to_string(search_results)
return results
else:
return str(search_results)
except Exception as e: search_results = await search_task(search_query, search_type)
logger.error(f"Server failed to start: {str(e)}", exc_info=True) return [types.TextContent(type="text", text=search_results)]
raise
# async def visualize() -> Image: @mcp.tool()
# """Visualize the knowledge graph""" async def prune():
# try: """Reset the knowledge graph"""
# image_path = await cognee.visualize_graph() with redirect_stdout(sys.stderr):
await cognee.prune.prune_data()
# img = PILImage.open(image_path) await cognee.prune.prune_system(metadata=True)
# return Image(data=img.tobytes(), format="png") return [types.TextContent(type="text", text="Pruned")]
# except (FileNotFoundError, IOError, ValueError) as e:
# raise ValueError(f"Failed to create visualization: {str(e)}")
def node_to_string(node): def node_to_string(node):
@ -265,6 +153,7 @@ def retrieved_edges_to_string(search_results):
relationship_type = edge["relationship_name"] relationship_type = edge["relationship_name"]
edge_str = f"{node_to_string(node1)} {relationship_type} {node_to_string(node2)}" edge_str = f"{node_to_string(node1)} {relationship_type} {node_to_string(node2)}"
edge_strings.append(edge_str) edge_strings.append(edge_str)
return "\n".join(edge_strings) return "\n".join(edge_strings)
@ -279,32 +168,31 @@ def load_class(model_file, model_name):
return model_class return model_class
# def get_freshest_png(directory: str) -> Image: async def main():
# if not os.path.exists(directory): parser = argparse.ArgumentParser()
# raise FileNotFoundError(f"Directory {directory} does not exist")
# # List all files in 'directory' that end with .png parser.add_argument(
# files = [f for f in os.listdir(directory) if f.endswith(".png")] "--transport",
# if not files: choices=["sse", "stdio"],
# raise FileNotFoundError("No PNG files found in the given directory.") default="stdio",
help="Transport to use for communication with the client. (default: stdio)",
)
# # Sort by integer value of the filename (minus the '.png') args = parser.parse_args()
# # Example filename: 1673185134.png -> integer 1673185134
# try:
# files_sorted = sorted(files, key=lambda x: int(x.replace(".png", "")))
# except ValueError as e:
# raise ValueError("Invalid PNG filename format. Expected timestamp format.") from e
# # The "freshest" file has the largest timestamp logger.info(f"Starting MCP server with transport: {args.transport}")
# freshest_filename = files_sorted[-1] if args.transport == "stdio":
# freshest_path = os.path.join(directory, freshest_filename) await mcp.run_stdio_async()
elif args.transport == "sse":
logger.info(
f"Running MCP server with SSE transport on {mcp.settings.host}:{mcp.settings.port}"
)
await mcp.run_sse_async()
# # Open the image with PIL and return the PIL Image object
# try:
# return PILImage.open(freshest_path)
# except (IOError, OSError) as e:
# raise IOError(f"Failed to open PNG file {freshest_path}") from e
if __name__ == "__main__": if __name__ == "__main__":
# Initialize and run the server try:
asyncio.run(main()) asyncio.run(main())
except Exception as e:
logger.error(f"Error initializing Cognee MCP server: {str(e)}")
raise

4946
cognee-mcp/uv.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,18 +1,17 @@
import os import os
import pathlib import pathlib
import asyncio import asyncio
from cognee.shared.logging_utils import get_logger
from uuid import NAMESPACE_OID, uuid5 from uuid import NAMESPACE_OID, uuid5
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
from cognee.api.v1.search import SearchType, search from cognee.api.v1.search import SearchType, search
from cognee.api.v1.visualize.visualize import visualize_graph from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.base_config import get_base_config
from cognee.modules.cognify.config import get_cognify_config from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.task import Task from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.methods import get_default_user from cognee.modules.users.methods import get_default_user
from cognee.shared.data_models import KnowledgeGraph, MonitoringTool from cognee.shared.data_models import KnowledgeGraph
from cognee.shared.utils import render_graph
from cognee.tasks.documents import classify_documents, extract_chunks_from_documents from cognee.tasks.documents import classify_documents, extract_chunks_from_documents
from cognee.tasks.graph import extract_graph_from_data from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.ingestion import ingest_data from cognee.tasks.ingestion import ingest_data
@ -22,11 +21,7 @@ from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_text from cognee.tasks.summarization import summarize_text
from cognee.infrastructure.llm import get_max_chunk_tokens from cognee.infrastructure.llm import get_max_chunk_tokens
monitoring = get_base_config().monitoring_tool observe = get_observe()
if monitoring == MonitoringTool.LANGFUSE:
from langfuse.decorators import observe
logger = get_logger("code_graph_pipeline") logger = get_logger("code_graph_pipeline")

View file

@ -1,14 +1,14 @@
import os import os
from typing import Optional from typing import Optional
from functools import lru_cache from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.root_dir import get_absolute_path from cognee.root_dir import get_absolute_path
from cognee.shared.data_models import MonitoringTool from cognee.modules.observability.observers import Observer
from pydantic_settings import BaseSettings, SettingsConfigDict
class BaseConfig(BaseSettings): class BaseConfig(BaseSettings):
data_root_directory: str = get_absolute_path(".data_storage") data_root_directory: str = get_absolute_path(".data_storage")
monitoring_tool: object = MonitoringTool.LANGFUSE monitoring_tool: object = Observer.LANGFUSE
graphistry_username: Optional[str] = os.getenv("GRAPHISTRY_USERNAME") graphistry_username: Optional[str] = os.getenv("GRAPHISTRY_USERNAME")
graphistry_password: Optional[str] = os.getenv("GRAPHISTRY_PASSWORD") graphistry_password: Optional[str] = os.getenv("GRAPHISTRY_PASSWORD")
langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY") langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")

View file

@ -12,13 +12,22 @@ class CogneeApiError(Exception):
message: str = "Service is unavailable.", message: str = "Service is unavailable.",
name: str = "Cognee", name: str = "Cognee",
status_code=status.HTTP_418_IM_A_TEAPOT, status_code=status.HTTP_418_IM_A_TEAPOT,
log=True,
log_level="ERROR",
): ):
self.message = message self.message = message
self.name = name self.name = name
self.status_code = status_code self.status_code = status_code
# Automatically log the exception details # Automatically log the exception details
logger.error(f"{self.name}: {self.message} (Status code: {self.status_code})") if log and (log_level == "ERROR"):
logger.error(f"{self.name}: {self.message} (Status code: {self.status_code})")
elif log and (log_level == "WARNING"):
logger.warning(f"{self.name}: {self.message} (Status code: {self.status_code})")
elif log and (log_level == "INFO"):
logger.info(f"{self.name}: {self.message} (Status code: {self.status_code})")
elif log and (log_level == "DEBUG"):
logger.debug(f"{self.name}: {self.message} (Status code: {self.status_code})")
super().__init__(self.message, self.name) super().__init__(self.message, self.name)

View file

@ -58,7 +58,7 @@ def record_graph_changes(func):
session.add(relationship) session.add(relationship)
await session.flush() await session.flush()
except Exception as e: except Exception as e:
logger.error(f"Error adding relationship: {e}") logger.debug(f"Error adding relationship: {e}")
await session.rollback() await session.rollback()
continue continue
@ -78,14 +78,14 @@ def record_graph_changes(func):
session.add(relationship) session.add(relationship)
await session.flush() await session.flush()
except Exception as e: except Exception as e:
logger.error(f"Error adding relationship: {e}") logger.debug(f"Error adding relationship: {e}")
await session.rollback() await session.rollback()
continue continue
try: try:
await session.commit() await session.commit()
except Exception as e: except Exception as e:
logger.error(f"Error committing session: {e}") logger.debug(f"Error committing session: {e}")
return result return result

View file

@ -42,7 +42,7 @@ class NetworkXAdapter(GraphDBInterface):
async def query(self, query: str, params: dict): async def query(self, query: str, params: dict):
pass pass
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: UUID) -> bool:
return self.graph.has_node(node_id) return self.graph.has_node(node_id)
async def add_node(self, node: DataPoint) -> None: async def add_node(self, node: DataPoint) -> None:
@ -136,7 +136,7 @@ class NetworkXAdapter(GraphDBInterface):
logger.error(f"Failed to add edges: {e}") logger.error(f"Failed to add edges: {e}")
raise raise
async def get_edges(self, node_id: str): async def get_edges(self, node_id: UUID):
return list(self.graph.in_edges(node_id, data=True)) + list( return list(self.graph.in_edges(node_id, data=True)) + list(
self.graph.out_edges(node_id, data=True) self.graph.out_edges(node_id, data=True)
) )
@ -174,13 +174,13 @@ class NetworkXAdapter(GraphDBInterface):
return disconnected_nodes return disconnected_nodes
async def extract_node(self, node_id: str) -> dict: async def extract_node(self, node_id: UUID) -> dict:
if self.graph.has_node(node_id): if self.graph.has_node(node_id):
return self.graph.nodes[node_id] return self.graph.nodes[node_id]
return None return None
async def extract_nodes(self, node_ids: List[str]) -> List[dict]: async def extract_nodes(self, node_ids: List[UUID]) -> List[dict]:
return [self.graph.nodes[node_id] for node_id in node_ids if self.graph.has_node(node_id)] return [self.graph.nodes[node_id] for node_id in node_ids if self.graph.has_node(node_id)]
async def get_predecessors(self, node_id: UUID, edge_label: str = None) -> list: async def get_predecessors(self, node_id: UUID, edge_label: str = None) -> list:
@ -215,7 +215,7 @@ class NetworkXAdapter(GraphDBInterface):
return nodes return nodes
async def get_neighbors(self, node_id: str) -> list: async def get_neighbors(self, node_id: UUID) -> list:
if not self.graph.has_node(node_id): if not self.graph.has_node(node_id):
return [] return []
@ -264,7 +264,7 @@ class NetworkXAdapter(GraphDBInterface):
return connections return connections
async def remove_connection_to_predecessors_of( async def remove_connection_to_predecessors_of(
self, node_ids: list[str], edge_label: str self, node_ids: list[UUID], edge_label: str
) -> None: ) -> None:
for node_id in node_ids: for node_id in node_ids:
if self.graph.has_node(node_id): if self.graph.has_node(node_id):
@ -275,7 +275,7 @@ class NetworkXAdapter(GraphDBInterface):
await self.save_graph_to_file(self.filename) await self.save_graph_to_file(self.filename)
async def remove_connection_to_successors_of( async def remove_connection_to_successors_of(
self, node_ids: list[str], edge_label: str self, node_ids: list[UUID], edge_label: str
) -> None: ) -> None:
for node_id in node_ids: for node_id in node_ids:
if self.graph.has_node(node_id): if self.graph.has_node(node_id):
@ -621,12 +621,12 @@ class NetworkXAdapter(GraphDBInterface):
nodes.append(node_data) nodes.append(node_data)
return nodes return nodes
async def get_node(self, node_id: str) -> dict: async def get_node(self, node_id: UUID) -> dict:
if self.graph.has_node(node_id): if self.graph.has_node(node_id):
return self.graph.nodes[node_id] return self.graph.nodes[node_id]
return None return None
async def get_nodes(self, node_ids: List[str] = None) -> List[dict]: async def get_nodes(self, node_ids: List[UUID] = None) -> List[dict]:
if node_ids is None: if node_ids is None:
return [{"id": node_id, **data} for node_id, data in self.graph.nodes(data=True)] return [{"id": node_id, **data} for node_id, data in self.graph.nodes(data=True)]
return [ return [

View file

@ -69,7 +69,7 @@ class SQLAlchemyAdapter:
async def delete_table(self, table_name: str, schema_name: Optional[str] = "public"): async def delete_table(self, table_name: str, schema_name: Optional[str] = "public"):
async with self.engine.begin() as connection: async with self.engine.begin() as connection:
if self.engine.dialect.name == "sqlite": if self.engine.dialect.name == "sqlite":
# SQLite doesnt support schema namespaces and the CASCADE keyword. # SQLite doesn't support schema namespaces and the CASCADE keyword.
# However, foreign key constraint can be defined with ON DELETE CASCADE during table creation. # However, foreign key constraint can be defined with ON DELETE CASCADE during table creation.
await connection.execute(text(f'DROP TABLE IF EXISTS "{table_name}";')) await connection.execute(text(f'DROP TABLE IF EXISTS "{table_name}";'))
else: else:
@ -327,10 +327,10 @@ class SQLAlchemyAdapter:
file.write("") file.write("")
else: else:
async with self.engine.begin() as connection: async with self.engine.begin() as connection:
schema_list = await self.get_schema_list()
# Create a MetaData instance to load table information # Create a MetaData instance to load table information
metadata = MetaData() metadata = MetaData()
# Drop all tables from all schemas # Drop all tables from the public schema
schema_list = ["public", "public_staging"]
for schema_name in schema_list: for schema_name in schema_list:
# Load the schema information into the MetaData object # Load the schema information into the MetaData object
await connection.run_sync(metadata.reflect, schema=schema_name) await connection.run_sync(metadata.reflect, schema=schema_name)

View file

@ -6,8 +6,9 @@ from chromadb import AsyncHttpClient, Settings
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.modules.storage.utils import get_own_properties from cognee.modules.storage.utils import get_own_properties
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
@ -108,9 +109,7 @@ class ChromaDBAdapter(VectorDBInterface):
return await self.embedding_engine.embed_text(data) return await self.embedding_engine.embed_text(data)
async def has_collection(self, collection_name: str) -> bool: async def has_collection(self, collection_name: str) -> bool:
client = await self.get_connection() collections = await self.get_collection_names()
collections = await client.list_collections()
# In ChromaDB v0.6.0, list_collections returns collection names directly
return collection_name in collections return collection_name in collections
async def create_collection(self, collection_name: str, payload_schema=None): async def create_collection(self, collection_name: str, payload_schema=None):
@ -119,13 +118,17 @@ class ChromaDBAdapter(VectorDBInterface):
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
await client.create_collection(name=collection_name, metadata={"hnsw:space": "cosine"}) await client.create_collection(name=collection_name, metadata={"hnsw:space": "cosine"})
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]): async def get_collection(self, collection_name: str) -> AsyncHttpClient:
client = await self.get_connection()
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
await self.create_collection(collection_name) raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
collection = await client.get_collection(collection_name) client = await self.get_connection()
return await client.get_collection(collection_name)
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
await self.create_collection(collection_name)
collection = await self.get_collection(collection_name)
texts = [DataPoint.get_embeddable_data(data_point) for data_point in data_points] texts = [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
embeddings = await self.embed_data(texts) embeddings = await self.embed_data(texts)
@ -161,8 +164,7 @@ class ChromaDBAdapter(VectorDBInterface):
async def retrieve(self, collection_name: str, data_point_ids: list[str]): async def retrieve(self, collection_name: str, data_point_ids: list[str]):
"""Retrieve data points by their IDs from a collection.""" """Retrieve data points by their IDs from a collection."""
client = await self.get_connection() collection = await self.get_collection(collection_name)
collection = await client.get_collection(collection_name)
results = await collection.get(ids=data_point_ids, include=["metadatas"]) results = await collection.get(ids=data_point_ids, include=["metadatas"])
return [ return [
@ -174,62 +176,12 @@ class ChromaDBAdapter(VectorDBInterface):
for id, metadata in zip(results["ids"], results["metadatas"]) for id, metadata in zip(results["ids"], results["metadatas"])
] ]
async def get_distance_from_collection_elements(
self, collection_name: str, query_text: str = None, query_vector: List[float] = None
):
"""Calculate distance between query and all elements in a collection."""
if query_text is None and query_vector is None:
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
client = await self.get_connection()
try:
collection = await client.get_collection(collection_name)
collection_count = await collection.count()
results = await collection.query(
query_embeddings=[query_vector],
include=["metadatas", "distances"],
n_results=collection_count,
)
result_values = []
for i, (id, metadata, distance) in enumerate(
zip(results["ids"][0], results["metadatas"][0], results["distances"][0])
):
result_values.append(
{
"id": parse_id(id),
"payload": restore_data_from_chroma(metadata),
"_distance": distance,
}
)
normalized_values = normalize_distances(result_values)
scored_results = []
for i, result in enumerate(result_values):
scored_results.append(
ScoredResult(
id=result["id"],
payload=result["payload"],
score=normalized_values[i],
)
)
return scored_results
except Exception:
return []
async def search( async def search(
self, self,
collection_name: str, collection_name: str,
query_text: str = None, query_text: str = None,
query_vector: List[float] = None, query_vector: List[float] = None,
limit: int = 5, limit: int = 15,
with_vector: bool = False, with_vector: bool = False,
normalized: bool = True, normalized: bool = True,
): ):
@ -241,8 +193,10 @@ class ChromaDBAdapter(VectorDBInterface):
query_vector = (await self.embedding_engine.embed_text([query_text]))[0] query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
try: try:
client = await self.get_connection() collection = await self.get_collection(collection_name)
collection = await client.get_collection(collection_name)
if limit == 0:
limit = await collection.count()
results = await collection.query( results = await collection.query(
query_embeddings=[query_vector], query_embeddings=[query_vector],
@ -296,8 +250,7 @@ class ChromaDBAdapter(VectorDBInterface):
"""Perform multiple searches in a single request for efficiency.""" """Perform multiple searches in a single request for efficiency."""
query_vectors = await self.embed_data(query_texts) query_vectors = await self.embed_data(query_texts)
client = await self.get_connection() collection = await self.get_collection(collection_name)
collection = await client.get_collection(collection_name)
results = await collection.query( results = await collection.query(
query_embeddings=query_vectors, query_embeddings=query_vectors,
@ -346,15 +299,14 @@ class ChromaDBAdapter(VectorDBInterface):
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
"""Remove data points from a collection by their IDs.""" """Remove data points from a collection by their IDs."""
client = await self.get_connection() collection = await self.get_collection(collection_name)
collection = await client.get_collection(collection_name)
await collection.delete(ids=data_point_ids) await collection.delete(ids=data_point_ids)
return True return True
async def prune(self): async def prune(self):
"""Delete all collections in the ChromaDB database.""" """Delete all collections in the ChromaDB database."""
client = await self.get_connection() client = await self.get_connection()
collections = await client.list_collections() collections = await self.list_collections()
for collection_name in collections: for collection_name in collections:
await client.delete_collection(collection_name) await client.delete_collection(collection_name)
return True return True
@ -362,4 +314,8 @@ class ChromaDBAdapter(VectorDBInterface):
async def get_collection_names(self): async def get_collection_names(self):
"""Get a list of all collection names in the database.""" """Get a list of all collection names in the database."""
client = await self.get_connection() client = await self.get_connection()
return await client.list_collections() collections = await client.list_collections()
return [
collection.name if hasattr(collection, "name") else collection["name"]
for collection in collections
]

View file

@ -6,7 +6,9 @@ class CollectionNotFoundError(CriticalError):
def __init__( def __init__(
self, self,
message, message,
name: str = "DatabaseNotCreatedError", name: str = "CollectionNotFoundError",
status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY, status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
log=True,
log_level="ERROR",
): ):
super().__init__(message, name, status_code) super().__init__(message, name, status_code, log, log_level)

View file

@ -1,6 +1,5 @@
import asyncio import asyncio
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
import lancedb import lancedb
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector
from pydantic import BaseModel from pydantic import BaseModel
@ -76,9 +75,14 @@ class LanceDBAdapter(VectorDBInterface):
exist_ok=True, exist_ok=True,
) )
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]): async def get_collection(self, collection_name: str):
connection = await self.get_connection() if not await self.has_collection(collection_name):
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
connection = await self.get_connection()
return await connection.open_table(collection_name)
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
payload_schema = type(data_points[0]) payload_schema = type(data_points[0])
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
@ -87,7 +91,7 @@ class LanceDBAdapter(VectorDBInterface):
payload_schema, payload_schema,
) )
collection = await connection.open_table(collection_name) collection = await self.get_collection(collection_name)
data_vectors = await self.embed_data( data_vectors = await self.embed_data(
[DataPoint.get_embeddable_data(data_point) for data_point in data_points] [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
@ -125,8 +129,7 @@ class LanceDBAdapter(VectorDBInterface):
) )
async def retrieve(self, collection_name: str, data_point_ids: list[str]): async def retrieve(self, collection_name: str, data_point_ids: list[str]):
connection = await self.get_connection() collection = await self.get_collection(collection_name)
collection = await connection.open_table(collection_name)
if len(data_point_ids) == 1: if len(data_point_ids) == 1:
results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas() results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas()
@ -142,48 +145,12 @@ class LanceDBAdapter(VectorDBInterface):
for result in results.to_dict("index").values() for result in results.to_dict("index").values()
] ]
async def get_distance_from_collection_elements(
self, collection_name: str, query_text: str = None, query_vector: List[float] = None
):
if query_text is None and query_vector is None:
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
connection = await self.get_connection()
try:
collection = await connection.open_table(collection_name)
collection_size = await collection.count_rows()
results = (
await collection.vector_search(query_vector).limit(collection_size).to_pandas()
)
result_values = list(results.to_dict("index").values())
normalized_values = normalize_distances(result_values)
return [
ScoredResult(
id=parse_id(result["id"]),
payload=result["payload"],
score=normalized_values[value_index],
)
for value_index, result in enumerate(result_values)
]
except ValueError:
# Ignore if collection doesn't exist
return []
async def search( async def search(
self, self,
collection_name: str, collection_name: str,
query_text: str = None, query_text: str = None,
query_vector: List[float] = None, query_vector: List[float] = None,
limit: int = 5, limit: int = 15,
with_vector: bool = False, with_vector: bool = False,
normalized: bool = True, normalized: bool = True,
): ):
@ -193,12 +160,10 @@ class LanceDBAdapter(VectorDBInterface):
if query_text and not query_vector: if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0] query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
connection = await self.get_connection() collection = await self.get_collection(collection_name)
try: if limit == 0:
collection = await connection.open_table(collection_name) limit = await collection.count_rows()
except ValueError:
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
results = await collection.vector_search(query_vector).limit(limit).to_pandas() results = await collection.vector_search(query_vector).limit(limit).to_pandas()
@ -242,8 +207,7 @@ class LanceDBAdapter(VectorDBInterface):
def delete_data_points(self, collection_name: str, data_point_ids: list[str]): def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10)) @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
async def _delete_data_points(): async def _delete_data_points():
connection = await self.get_connection() collection = await self.get_collection(collection_name)
collection = await connection.open_table(collection_name)
# Delete one at a time to avoid commit conflicts # Delete one at a time to avoid commit conflicts
for data_point_id in data_point_ids: for data_point_id in data_point_ids:
@ -288,7 +252,7 @@ class LanceDBAdapter(VectorDBInterface):
collection_names = await connection.table_names() collection_names = await connection.table_names()
for collection_name in collection_names: for collection_name in collection_names:
collection = await connection.open_table(collection_name) collection = await self.get_collection(collection_name)
await collection.delete("id IS NOT NULL") await collection.delete("id IS NOT NULL")
await connection.drop_table(collection_name) await connection.drop_table(collection_name)

View file

@ -1,11 +1,12 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from cognee.shared.logging_utils import get_logger from uuid import UUID
from typing import List, Optional from typing import List, Optional
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult from ..models.ScoredResult import ScoredResult
@ -96,7 +97,7 @@ class MilvusAdapter(VectorDBInterface):
raise e raise e
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
from pymilvus import MilvusException from pymilvus import MilvusException, exceptions
client = self.get_milvus_client() client = self.get_milvus_client()
data_vectors = await self.embed_data( data_vectors = await self.embed_data(
@ -118,6 +119,10 @@ class MilvusAdapter(VectorDBInterface):
f"Inserted {result.get('insert_count', 0)} data points into collection '{collection_name}'." f"Inserted {result.get('insert_count', 0)} data points into collection '{collection_name}'."
) )
return result return result
except exceptions.CollectionNotExistException as error:
raise CollectionNotFoundError(
f"Collection '{collection_name}' does not exist!"
) from error
except MilvusException as e: except MilvusException as e:
logger.error( logger.error(
f"Error inserting data points into collection '{collection_name}': {str(e)}" f"Error inserting data points into collection '{collection_name}': {str(e)}"
@ -140,8 +145,8 @@ class MilvusAdapter(VectorDBInterface):
collection_name = f"{index_name}_{index_property_name}" collection_name = f"{index_name}_{index_property_name}"
await self.create_data_points(collection_name, formatted_data_points) await self.create_data_points(collection_name, formatted_data_points)
async def retrieve(self, collection_name: str, data_point_ids: list[str]): async def retrieve(self, collection_name: str, data_point_ids: list[UUID]):
from pymilvus import MilvusException from pymilvus import MilvusException, exceptions
client = self.get_milvus_client() client = self.get_milvus_client()
try: try:
@ -153,6 +158,10 @@ class MilvusAdapter(VectorDBInterface):
output_fields=["*"], output_fields=["*"],
) )
return results return results
except exceptions.CollectionNotExistException as error:
raise CollectionNotFoundError(
f"Collection '{collection_name}' does not exist!"
) from error
except MilvusException as e: except MilvusException as e:
logger.error( logger.error(
f"Error retrieving data points from collection '{collection_name}': {str(e)}" f"Error retrieving data points from collection '{collection_name}': {str(e)}"
@ -164,10 +173,10 @@ class MilvusAdapter(VectorDBInterface):
collection_name: str, collection_name: str,
query_text: Optional[str] = None, query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None, query_vector: Optional[List[float]] = None,
limit: int = 5, limit: int = 15,
with_vector: bool = False, with_vector: bool = False,
): ):
from pymilvus import MilvusException from pymilvus import MilvusException, exceptions
client = self.get_milvus_client() client = self.get_milvus_client()
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
@ -184,7 +193,7 @@ class MilvusAdapter(VectorDBInterface):
collection_name=collection_name, collection_name=collection_name,
data=[query_vector], data=[query_vector],
anns_field="vector", anns_field="vector",
limit=limit, limit=limit if limit > 0 else None,
output_fields=output_fields, output_fields=output_fields,
search_params={ search_params={
"metric_type": "COSINE", "metric_type": "COSINE",
@ -199,6 +208,10 @@ class MilvusAdapter(VectorDBInterface):
) )
for result in results[0] for result in results[0]
] ]
except exceptions.CollectionNotExistException as error:
raise CollectionNotFoundError(
f"Collection '{collection_name}' does not exist!"
) from error
except MilvusException as e: except MilvusException as e:
logger.error(f"Error during search in collection '{collection_name}': {str(e)}") logger.error(f"Error during search in collection '{collection_name}': {str(e)}")
raise e raise e
@ -220,7 +233,7 @@ class MilvusAdapter(VectorDBInterface):
] ]
) )
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
from pymilvus import MilvusException from pymilvus import MilvusException
client = self.get_milvus_client() client = self.get_milvus_client()

View file

@ -8,19 +8,18 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
from ...relational.ModelBase import Base from ...relational.ModelBase import Base
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface
from .serialize_data import serialize_data
from ..utils import normalize_distances from ..utils import normalize_distances
from ..models.ScoredResult import ScoredResult
from ..exceptions import CollectionNotFoundError
from ..vector_db_interface import VectorDBInterface
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from .serialize_data import serialize_data
logger = get_logger("PGVectorAdapter") logger = get_logger("PGVectorAdapter")
@ -180,7 +179,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
if collection_name in metadata.tables: if collection_name in metadata.tables:
return metadata.tables[collection_name] return metadata.tables[collection_name]
else: else:
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!") raise CollectionNotFoundError(
f"Collection '{collection_name}' not found!", log_level="DEBUG"
)
async def retrieve(self, collection_name: str, data_point_ids: List[str]): async def retrieve(self, collection_name: str, data_point_ids: List[str]):
# Get PGVectorDataPoint Table from database # Get PGVectorDataPoint Table from database
@ -197,60 +198,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
for result in results for result in results
] ]
async def get_distance_from_collection_elements(
self,
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
with_vector: bool = False,
) -> List[ScoredResult]:
if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!")
if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
try:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
# Use async session to connect to the database
async with self.get_async_session() as session:
# Find closest vectors to query_vector
closest_items = await session.execute(
select(
PGVectorDataPoint,
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label(
"similarity"
),
).order_by("similarity")
)
vector_list = []
# Extract distances and find min/max for normalization
for vector in closest_items:
# TODO: Add normalization of similarity score
vector_list.append(vector)
# Create and return ScoredResult objects
return [
ScoredResult(id=parse_id(str(row.id)), payload=row.payload, score=row.similarity)
for row in vector_list
]
except EntityNotFoundError:
# Ignore if collection does not exist
return []
except CollectionNotFoundError:
# Ignore if collection does not exist
return []
async def search( async def search(
self, self,
collection_name: str, collection_name: str,
query_text: Optional[str] = None, query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None, query_vector: Optional[List[float]] = None,
limit: int = 5, limit: int = 15,
with_vector: bool = False, with_vector: bool = False,
) -> List[ScoredResult]: ) -> List[ScoredResult]:
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
@ -262,24 +215,26 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Get PGVectorDataPoint Table from database # Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name) PGVectorDataPoint = await self.get_table(collection_name)
# NOTE: This needs to be initialized in case search doesn't return a value
closest_items = [] closest_items = []
# Use async session to connect to the database # Use async session to connect to the database
async with self.get_async_session() as session: async with self.get_async_session() as session:
query = select(
PGVectorDataPoint,
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
).order_by("similarity")
if limit > 0:
query = query.limit(limit)
# Find closest vectors to query_vector # Find closest vectors to query_vector
closest_items = await session.execute( closest_items = await session.execute(query)
select(
PGVectorDataPoint,
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
)
.order_by("similarity")
.limit(limit)
)
vector_list = [] vector_list = []
# Extract distances and find min/max for normalization # Extract distances and find min/max for normalization
for vector in closest_items: for vector in closest_items.all():
vector_list.append( vector_list.append(
{ {
"id": parse_id(str(vector.id)), "id": parse_id(str(vector.id)),
@ -288,6 +243,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
} }
) )
if len(vector_list) == 0:
return []
# Normalize vector distance and add this as score information to vector_list # Normalize vector distance and add this as score information to vector_list
normalized_values = normalize_distances(vector_list) normalized_values = normalize_distances(vector_list)
for i in range(0, len(normalized_values)): for i in range(0, len(normalized_values)):

View file

@ -1,12 +1,12 @@
from cognee.shared.logging_utils import get_logger
from typing import Dict, List, Optional from typing import Dict, List, Optional
from cognee.infrastructure.engine.utils import parse_id
from qdrant_client import AsyncQdrantClient, models from qdrant_client import AsyncQdrantClient, models
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine.utils import parse_id
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..vector_db_interface import VectorDBInterface from ..vector_db_interface import VectorDBInterface
@ -97,6 +97,8 @@ class QDrantAdapter(VectorDBInterface):
await client.close() await client.close()
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
from qdrant_client.http.exceptions import UnexpectedResponse
client = self.get_qdrant_client() client = self.get_qdrant_client()
data_vectors = await self.embed_data( data_vectors = await self.embed_data(
@ -114,6 +116,13 @@ class QDrantAdapter(VectorDBInterface):
try: try:
client.upload_points(collection_name=collection_name, points=points) client.upload_points(collection_name=collection_name, points=points)
except UnexpectedResponse as error:
if "Collection not found" in str(error):
raise CollectionNotFoundError(
message=f"Collection {collection_name} not found!"
) from error
else:
raise error
except Exception as error: except Exception as error:
logger.error("Error uploading data points to Qdrant: %s", str(error)) logger.error("Error uploading data points to Qdrant: %s", str(error))
raise error raise error
@ -143,19 +152,22 @@ class QDrantAdapter(VectorDBInterface):
await client.close() await client.close()
return results return results
async def get_distance_from_collection_elements( async def search(
self, self,
collection_name: str, collection_name: str,
query_text: str = None, query_text: Optional[str] = None,
query_vector: List[float] = None, query_vector: Optional[List[float]] = None,
limit: int = 15,
with_vector: bool = False, with_vector: bool = False,
) -> List[ScoredResult]: ):
if query_text is None and query_vector is None: from qdrant_client.http.exceptions import UnexpectedResponse
raise ValueError("One of query_text or query_vector must be provided!")
client = self.get_qdrant_client() if query_text is None and query_vector is None:
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
try: try:
client = self.get_qdrant_client()
results = await client.search( results = await client.search(
collection_name=collection_name, collection_name=collection_name,
query_vector=models.NamedVector( query_vector=models.NamedVector(
@ -164,9 +176,12 @@ class QDrantAdapter(VectorDBInterface):
if query_vector is not None if query_vector is not None
else (await self.embed_data([query_text]))[0], else (await self.embed_data([query_text]))[0],
), ),
limit=limit if limit > 0 else None,
with_vectors=with_vector, with_vectors=with_vector,
) )
await client.close()
return [ return [
ScoredResult( ScoredResult(
id=parse_id(result.id), id=parse_id(result.id),
@ -178,51 +193,16 @@ class QDrantAdapter(VectorDBInterface):
) )
for result in results for result in results
] ]
except ValueError: except UnexpectedResponse as error:
# Ignore if the collection doesn't exist if "Collection not found" in str(error):
return [] raise CollectionNotFoundError(
message=f"Collection {collection_name} not found!"
) from error
else:
raise error
finally: finally:
await client.close() await client.close()
async def search(
self,
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 5,
with_vector: bool = False,
):
if query_text is None and query_vector is None:
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
client = self.get_qdrant_client()
results = await client.search(
collection_name=collection_name,
query_vector=models.NamedVector(
name="text",
vector=query_vector
if query_vector is not None
else (await self.embed_data([query_text]))[0],
),
limit=limit,
with_vectors=with_vector,
)
await client.close()
return [
ScoredResult(
id=parse_id(result.id),
payload={
**result.payload,
"id": parse_id(result.id),
},
score=1 - result.score,
)
for result in results
]
async def batch_search( async def batch_search(
self, self,
collection_name: str, collection_name: str,

View file

@ -1,10 +1,10 @@
import asyncio
from cognee.shared.logging_utils import get_logger
from typing import List, Optional from typing import List, Optional
from cognee.shared.logging_utils import get_logger
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult from ..models.ScoredResult import ScoredResult
@ -34,21 +34,23 @@ class WeaviateAdapter(VectorDBInterface):
self.embedding_engine = embedding_engine self.embedding_engine = embedding_engine
self.client = weaviate.connect_to_wcs( self.client = weaviate.use_async_with_weaviate_cloud(
cluster_url=url, cluster_url=url,
auth_credentials=weaviate.auth.AuthApiKey(api_key), auth_credentials=weaviate.auth.AuthApiKey(api_key),
additional_config=wvc.init.AdditionalConfig(timeout=wvc.init.Timeout(init=30)), additional_config=wvc.init.AdditionalConfig(timeout=wvc.init.Timeout(init=30)),
) )
async def get_client(self):
await self.client.connect()
return self.client
async def embed_data(self, data: List[str]) -> List[float]: async def embed_data(self, data: List[str]) -> List[float]:
return await self.embedding_engine.embed_text(data) return await self.embedding_engine.embed_text(data)
async def has_collection(self, collection_name: str) -> bool: async def has_collection(self, collection_name: str) -> bool:
future = asyncio.Future() client = await self.get_client()
return await client.collections.exists(collection_name)
future.set_result(self.client.collections.exists(collection_name))
return await future
async def create_collection( async def create_collection(
self, self,
@ -57,26 +59,25 @@ class WeaviateAdapter(VectorDBInterface):
): ):
import weaviate.classes.config as wvcc import weaviate.classes.config as wvcc
future = asyncio.Future() if not await self.has_collection(collection_name):
client = await self.get_client()
if not self.client.collections.exists(collection_name): return await client.collections.create(
future.set_result( name=collection_name,
self.client.collections.create( properties=[
name=collection_name, wvcc.Property(
properties=[ name="text", data_type=wvcc.DataType.TEXT, skip_vectorization=True
wvcc.Property( )
name="text", data_type=wvcc.DataType.TEXT, skip_vectorization=True ],
)
],
)
) )
else: else:
future.set_result(self.get_collection(collection_name)) return await self.get_collection(collection_name)
return await future async def get_collection(self, collection_name: str):
if not await self.has_collection(collection_name):
raise CollectionNotFoundError(f"Collection '{collection_name}' not found.")
def get_collection(self, collection_name: str): client = await self.get_client()
return self.client.collections.get(collection_name) return client.collections.get(collection_name)
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
from weaviate.classes.data import DataObject from weaviate.classes.data import DataObject
@ -97,29 +98,30 @@ class WeaviateAdapter(VectorDBInterface):
data_points = [convert_to_weaviate_data_points(data_point) for data_point in data_points] data_points = [convert_to_weaviate_data_points(data_point) for data_point in data_points]
collection = self.get_collection(collection_name) collection = await self.get_collection(collection_name)
try: try:
if len(data_points) > 1: if len(data_points) > 1:
with collection.batch.dynamic() as batch: return await collection.data.insert_many(data_points)
for data_point in data_points: # with collection.batch.dynamic() as batch:
batch.add_object( # for data_point in data_points:
uuid=data_point.uuid, # batch.add_object(
vector=data_point.vector, # uuid=data_point.uuid,
properties=data_point.properties, # vector=data_point.vector,
references=data_point.references, # properties=data_point.properties,
) # references=data_point.references,
# )
else: else:
data_point: DataObject = data_points[0] data_point: DataObject = data_points[0]
if collection.data.exists(data_point.uuid): if collection.data.exists(data_point.uuid):
return collection.data.update( return await collection.data.update(
uuid=data_point.uuid, uuid=data_point.uuid,
vector=data_point.vector, vector=data_point.vector,
properties=data_point.properties, properties=data_point.properties,
references=data_point.references, references=data_point.references,
) )
else: else:
return collection.data.insert( return await collection.data.insert(
uuid=data_point.uuid, uuid=data_point.uuid,
vector=data_point.vector, vector=data_point.vector,
properties=data_point.properties, properties=data_point.properties,
@ -130,12 +132,12 @@ class WeaviateAdapter(VectorDBInterface):
raise error raise error
async def create_vector_index(self, index_name: str, index_property_name: str): async def create_vector_index(self, index_name: str, index_property_name: str):
await self.create_collection(f"{index_name}_{index_property_name}") return await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points( async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint] self, index_name: str, index_property_name: str, data_points: list[DataPoint]
): ):
await self.create_data_points( return await self.create_data_points(
f"{index_name}_{index_property_name}", f"{index_name}_{index_property_name}",
[ [
IndexSchema( IndexSchema(
@ -149,9 +151,8 @@ class WeaviateAdapter(VectorDBInterface):
async def retrieve(self, collection_name: str, data_point_ids: list[str]): async def retrieve(self, collection_name: str, data_point_ids: list[str]):
from weaviate.classes.query import Filter from weaviate.classes.query import Filter
future = asyncio.Future() collection = await self.get_collection(collection_name)
data_points = await collection.query.fetch_objects(
data_points = self.get_collection(collection_name).query.fetch_objects(
filters=Filter.by_id().contains_any(data_point_ids) filters=Filter.by_id().contains_any(data_point_ids)
) )
@ -160,30 +161,32 @@ class WeaviateAdapter(VectorDBInterface):
data_point.id = data_point.uuid data_point.id = data_point.uuid
del data_point.properties del data_point.properties
future.set_result(data_points.objects) return data_points.objects
return await future async def search(
async def get_distance_from_collection_elements(
self, self,
collection_name: str, collection_name: str,
query_text: str = None, query_text: Optional[str] = None,
query_vector: List[float] = None, query_vector: Optional[List[float]] = None,
limit: int = 15,
with_vector: bool = False, with_vector: bool = False,
) -> List[ScoredResult]: ):
import weaviate.classes as wvc import weaviate.classes as wvc
import weaviate.exceptions import weaviate.exceptions
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!") raise InvalidValueError(message="One of query_text or query_vector must be provided!")
if query_vector is None: if query_vector is None:
query_vector = (await self.embed_data([query_text]))[0] query_vector = (await self.embed_data([query_text]))[0]
collection = await self.get_collection(collection_name)
try: try:
search_result = self.get_collection(collection_name).query.hybrid( search_result = await collection.query.hybrid(
query=None, query=None,
vector=query_vector, vector=query_vector,
limit=limit if limit > 0 else None,
include_vector=with_vector, include_vector=with_vector,
return_metadata=wvc.query.MetadataQuery(score=True), return_metadata=wvc.query.MetadataQuery(score=True),
) )
@ -196,43 +199,10 @@ class WeaviateAdapter(VectorDBInterface):
) )
for result in search_result.objects for result in search_result.objects
] ]
except weaviate.exceptions.UnexpectedStatusCodeError: except weaviate.exceptions.WeaviateInvalidInputError:
# Ignore if the collection doesn't exist # Ignore if the collection doesn't exist
return [] return []
async def search(
self,
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = None,
with_vector: bool = False,
):
import weaviate.classes as wvc
if query_text is None and query_vector is None:
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
if query_vector is None:
query_vector = (await self.embed_data([query_text]))[0]
search_result = self.get_collection(collection_name).query.hybrid(
query=None,
vector=query_vector,
limit=limit,
include_vector=with_vector,
return_metadata=wvc.query.MetadataQuery(score=True),
)
return [
ScoredResult(
id=parse_id(str(result.uuid)),
payload=result.properties,
score=1 - float(result.metadata.score),
)
for result in search_result.objects
]
async def batch_search( async def batch_search(
self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
): ):
@ -248,14 +218,13 @@ class WeaviateAdapter(VectorDBInterface):
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
from weaviate.classes.query import Filter from weaviate.classes.query import Filter
future = asyncio.Future() collection = await self.get_collection(collection_name)
result = await collection.data.delete_many(
result = self.get_collection(collection_name).data.delete_many(
filters=Filter.by_id().contains_any(data_point_ids) filters=Filter.by_id().contains_any(data_point_ids)
) )
future.set_result(result)
return await future return result
async def prune(self): async def prune(self):
self.client.collections.delete_all() client = await self.get_client()
await client.collections.delete_all()

View file

@ -1,9 +1,10 @@
from typing import Type, Optional
from pydantic import BaseModel
from cognee.shared.logging_utils import get_logger
import litellm import litellm
from pydantic import BaseModel
from typing import Type, Optional
from litellm import acompletion, JSONSchemaValidationError from litellm import acompletion, JSONSchemaValidationError
from cognee.shared.data_models import MonitoringTool
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.llm_interface import LLMInterface from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.infrastructure.llm.prompts import read_query_prompt
@ -11,14 +12,9 @@ from cognee.infrastructure.llm.rate_limiter import (
rate_limit_async, rate_limit_async,
sleep_and_retry_async, sleep_and_retry_async,
) )
from cognee.base_config import get_base_config
logger = get_logger() logger = get_logger()
observe = get_observe()
monitoring = get_base_config().monitoring_tool
if monitoring == MonitoringTool.LANGFUSE:
from langfuse.decorators import observe
class GeminiAdapter(LLMInterface): class GeminiAdapter(LLMInterface):

View file

@ -1,14 +1,12 @@
import os import os
import base64 import base64
from typing import Type
import litellm import litellm
import instructor import instructor
from typing import Type
from pydantic import BaseModel from pydantic import BaseModel
from openai import ContentFilterFinishReasonError from openai import ContentFilterFinishReasonError
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.shared.data_models import MonitoringTool
from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.llm_interface import LLMInterface from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
@ -19,12 +17,9 @@ from cognee.infrastructure.llm.rate_limiter import (
sleep_and_retry_async, sleep_and_retry_async,
sleep_and_retry_sync, sleep_and_retry_sync,
) )
from cognee.base_config import get_base_config from cognee.modules.observability.get_observe import get_observe
monitoring = get_base_config().monitoring_tool observe = get_observe()
if monitoring == MonitoringTool.LANGFUSE:
from langfuse.decorators import observe
class OpenAIAdapter(LLMInterface): class OpenAIAdapter(LLMInterface):

View file

@ -128,8 +128,10 @@ class CogneeGraph(CogneeAbstractGraph):
if query_vector is None or len(query_vector) == 0: if query_vector is None or len(query_vector) == 0:
raise ValueError("Failed to generate query embedding.") raise ValueError("Failed to generate query embedding.")
edge_distances = await vector_engine.get_distance_from_collection_elements( edge_distances = await vector_engine.search(
"EdgeType_relationship_name", query_text=query collection_name="EdgeType_relationship_name",
query_text=query,
limit=0,
) )
embedding_map = {result.payload["text"]: result.score for result in edge_distances} embedding_map = {result.payload["text"]: result.score for result in edge_distances}

View file

@ -0,0 +1,11 @@
from cognee.base_config import get_base_config
from .observers import Observer
def get_observe():
monitoring = get_base_config().monitoring_tool
if monitoring == Observer.LANGFUSE:
from langfuse.decorators import observe
return observe

View file

@ -0,0 +1,9 @@
from enum import Enum
class Observer(str, Enum):
"""Monitoring tools"""
LANGFUSE = "langfuse"
LLMLITE = "llmlite"
LANGSMITH = "langsmith"

View file

@ -20,7 +20,9 @@ from ..tasks.task import Task
logger = get_logger("run_tasks(tasks: [Task], data)") logger = get_logger("run_tasks(tasks: [Task], data)")
async def run_tasks_with_telemetry(tasks: list[Task], data, user: User, pipeline_name: str): async def run_tasks_with_telemetry(
tasks: list[Task], data, user: User, pipeline_name: str, context: dict = None
):
config = get_current_settings() config = get_current_settings()
logger.debug("\nRunning pipeline with configuration:\n%s\n", json.dumps(config, indent=1)) logger.debug("\nRunning pipeline with configuration:\n%s\n", json.dumps(config, indent=1))
@ -36,7 +38,7 @@ async def run_tasks_with_telemetry(tasks: list[Task], data, user: User, pipeline
| config, | config,
) )
async for result in run_tasks_base(tasks, data, user): async for result in run_tasks_base(tasks, data, user, context):
yield result yield result
logger.info("Pipeline run completed: `%s`", pipeline_name) logger.info("Pipeline run completed: `%s`", pipeline_name)
@ -72,6 +74,7 @@ async def run_tasks(
data: Any = None, data: Any = None,
user: User = None, user: User = None,
pipeline_name: str = "unknown_pipeline", pipeline_name: str = "unknown_pipeline",
context: dict = None,
): ):
pipeline_id = uuid5(NAMESPACE_OID, pipeline_name) pipeline_id = uuid5(NAMESPACE_OID, pipeline_name)
@ -82,7 +85,11 @@ async def run_tasks(
try: try:
async for _ in run_tasks_with_telemetry( async for _ in run_tasks_with_telemetry(
tasks=tasks, data=data, user=user, pipeline_name=pipeline_id tasks=tasks,
data=data,
user=user,
pipeline_name=pipeline_id,
context=context,
): ):
pass pass

View file

@ -14,6 +14,7 @@ async def handle_task(
leftover_tasks: list[Task], leftover_tasks: list[Task],
next_task_batch_size: int, next_task_batch_size: int,
user: User, user: User,
context: dict = None,
): ):
"""Handle common task workflow with logging, telemetry, and error handling around the core execution logic.""" """Handle common task workflow with logging, telemetry, and error handling around the core execution logic."""
task_type = running_task.task_type task_type = running_task.task_type
@ -27,9 +28,16 @@ async def handle_task(
}, },
) )
has_context = any(
[key == "context" for key in inspect.signature(running_task.executable).parameters.keys()]
)
if has_context:
args.append(context)
try: try:
async for result_data in running_task.execute(args, next_task_batch_size): async for result_data in running_task.execute(args, next_task_batch_size):
async for result in run_tasks_base(leftover_tasks, result_data, user): async for result in run_tasks_base(leftover_tasks, result_data, user, context):
yield result yield result
logger.info(f"{task_type} task completed: `{running_task.executable.__name__}`") logger.info(f"{task_type} task completed: `{running_task.executable.__name__}`")
@ -55,7 +63,7 @@ async def handle_task(
raise error raise error
async def run_tasks_base(tasks: list[Task], data=None, user: User = None): async def run_tasks_base(tasks: list[Task], data=None, user: User = None, context: dict = None):
"""Base function to execute tasks in a pipeline, handling task type detection and execution.""" """Base function to execute tasks in a pipeline, handling task type detection and execution."""
if len(tasks) == 0: if len(tasks) == 0:
yield data yield data
@ -68,5 +76,7 @@ async def run_tasks_base(tasks: list[Task], data=None, user: User = None):
next_task = leftover_tasks[0] if len(leftover_tasks) > 0 else None next_task = leftover_tasks[0] if len(leftover_tasks) > 0 else None
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1 next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
async for result in handle_task(running_task, args, leftover_tasks, next_task_batch_size, user): async for result in handle_task(
running_task, args, leftover_tasks, next_task_batch_size, user, context
):
yield result yield result

View file

@ -4,4 +4,4 @@ Custom exceptions for the Cognee API.
This module defines a set of exceptions for handling various data errors This module defines a set of exceptions for handling various data errors
""" """
from .exceptions import SearchTypeNotSupported, CypherSearchError, CollectionDistancesNotFoundError from .exceptions import SearchTypeNotSupported, CypherSearchError

View file

@ -2,16 +2,6 @@ from fastapi import status
from cognee.exceptions import CogneeApiError, CriticalError from cognee.exceptions import CogneeApiError, CriticalError
class CollectionDistancesNotFoundError(CogneeApiError):
def __init__(
self,
message: str = "No distances found between the query and collections. It is possible that the given collection names don't exist.",
name: str = "CollectionDistancesNotFoundError",
status_code: int = status.HTTP_404_NOT_FOUND,
):
super().__init__(message, name, status_code)
class SearchTypeNotSupported(CogneeApiError): class SearchTypeNotSupported(CogneeApiError):
def __init__( def __init__(
self, self,

View file

@ -3,7 +3,6 @@ from collections import Counter
import string import string
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
@ -76,10 +75,7 @@ class GraphCompletionRetriever(BaseRetriever):
async def get_context(self, query: str) -> str: async def get_context(self, query: str) -> str:
"""Retrieves and resolves graph triplets into context.""" """Retrieves and resolves graph triplets into context."""
try: triplets = await self.get_triplets(query)
triplets = await self.get_triplets(query)
except EntityNotFoundError:
return ""
if len(triplets) == 0: if len(triplets) == 0:
return "" return ""

View file

@ -1,14 +1,15 @@
import asyncio import asyncio
from cognee.shared.logging_utils import get_logger, ERROR
from typing import List, Optional from typing import List, Optional
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.users.methods import get_default_user from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry from cognee.shared.utils import send_telemetry
from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
logger = get_logger(level=ERROR) logger = get_logger(level=ERROR)
@ -62,11 +63,14 @@ async def get_memory_fragment(
if properties_to_project is None: if properties_to_project is None:
properties_to_project = ["id", "description", "name", "type", "text"] properties_to_project = ["id", "description", "name", "type", "text"]
await memory_fragment.project_graph_from_db( try:
graph_engine, await memory_fragment.project_graph_from_db(
node_properties_to_project=properties_to_project, graph_engine,
edge_properties_to_project=["relationship_name"], node_properties_to_project=properties_to_project,
) edge_properties_to_project=["relationship_name"],
)
except EntityNotFoundError:
pass
return memory_fragment return memory_fragment
@ -139,16 +143,21 @@ async def brute_force_search(
send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id) send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
async def search_in_collection(collection_name: str):
try:
return await vector_engine.search(
collection_name=collection_name, query_text=query, limit=top_k
)
except CollectionNotFoundError:
return []
try: try:
results = await asyncio.gather( results = await asyncio.gather(
*[ *[search_in_collection(collection_name) for collection_name in collections]
vector_engine.get_distance_from_collection_elements(collection, query_text=query)
for collection in collections
]
) )
if all(not item for item in results): if all(not item for item in results):
raise CollectionDistancesNotFoundError() return []
node_distances = {collection: result for collection, result in zip(collections, results)} node_distances = {collection: result for collection, result in zip(collections, results)}
@ -161,6 +170,8 @@ async def brute_force_search(
return results return results
except CollectionNotFoundError:
return []
except Exception as error: except Exception as error:
logger.error( logger.error(
"Error during brute force search for user: %s, query: %s. Error: %s", "Error during brute force search for user: %s, query: %s. Error: %s",

View file

@ -350,11 +350,3 @@ class ChunkSummaries(BaseModel):
"""Relevant summary and chunk id""" """Relevant summary and chunk id"""
summaries: List[ChunkSummary] summaries: List[ChunkSummary]
class MonitoringTool(str, Enum):
"""Monitoring tools"""
LANGFUSE = "langfuse"
LLMLITE = "llmlite"
LANGSMITH = "langsmith"

View file

@ -312,7 +312,7 @@ def setup_logging(log_level=None, name=None):
root_logger.addHandler(file_handler) root_logger.addHandler(file_handler)
root_logger.setLevel(log_level) root_logger.setLevel(log_level)
if log_level > logging.WARNING: if log_level > logging.DEBUG:
import warnings import warnings
from sqlalchemy.exc import SAWarning from sqlalchemy.exc import SAWarning

View file

@ -1,11 +0,0 @@
import os
import pytest
@pytest.fixture(autouse=True, scope="session")
def copy_cognee_db_to_target_location():
os.makedirs("cognee/.cognee_system/databases/", exist_ok=True)
os.system(
"cp cognee/tests/integration/run_toy_tasks/data/cognee_db cognee/.cognee_system/databases/cognee_db"
)

View file

@ -74,19 +74,20 @@ async def main():
for result in search_results: for result in search_results:
print(f"{result}\n") print(f"{result}\n")
search_results = await cognee.search( # NOTE: Due to the test failing often on weak LLM models we've removed this test for now
query_type=SearchType.NATURAL_LANGUAGE, # search_results = await cognee.search(
query_text=f"Find nodes connected to node with name {random_node_name}", # query_type=SearchType.NATURAL_LANGUAGE,
) # query_text=f"Find nodes connected to node with name {random_node_name}",
assert len(search_results) != 0, "Query related natural language don't exist." # )
print("\nExtracted results are:\n") # assert len(search_results) != 0, "Query related natural language don't exist."
for result in search_results: # print("\nExtracted results are:\n")
print(f"{result}\n") # for result in search_results:
# print(f"{result}\n")
user = await get_default_user() user = await get_default_user()
history = await get_history(user.id) history = await get_history(user.id)
assert len(history) == 8, "Search history is not correct." assert len(history) == 6, "Search history is not correct."
await cognee.prune.prune_data() await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted" assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -82,7 +82,7 @@ async def main():
assert not os.path.isdir(data_directory_path), "Local data files are not deleted" assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system(metadata=True)
collections = get_vector_engine().client.collections.list_all() collections = await get_vector_engine().client.collections.list_all()
assert len(collections) == 0, "Weaviate vector database is not empty" assert len(collections) == 0, "Weaviate vector database is not empty"

View file

@ -48,3 +48,7 @@ async def run_and_check_tasks():
def test_run_tasks(): def test_run_tasks():
asyncio.run(run_and_check_tasks()) asyncio.run(run_and_check_tasks())
if __name__ == "__main__":
test_run_tasks()

View file

@ -0,0 +1,47 @@
import asyncio
import cognee
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines.operations.run_tasks import run_tasks_base
from cognee.infrastructure.databases.relational import create_db_and_tables
async def run_and_check_tasks():
    """Run a three-task pipeline with a shared context value and verify its output.

    Storage is pruned first so the run starts from a clean state, then the
    relational tables are created and the default user is fetched. The pipeline
    is fed ``data=5`` and ``context=7``; every value it yields must equal the
    expected final result, and at least one value must be yielded.

    Raises:
        AssertionError: if a yielded result differs from the expected value or
            if the pipeline yields nothing at all.
    """
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    # task_1 and task_3 declare a `context` parameter, task_2 does not.
    # NOTE(review): presumably run_tasks_base passes `context` only to tasks
    # that accept it - confirm against its implementation.
    def task_1(num, context):
        return num + context

    def task_2(num):
        return num * 2

    def task_3(num, context):
        return num**context

    await create_db_and_tables()
    user = await get_default_user()

    pipeline = run_tasks_base(
        [
            Task(task_1),
            Task(task_2),
            Task(task_3),
        ],
        data=5,
        user=user,
        context=7,
    )

    # ((5 + 7) * 2) ** 7
    final_result = 4586471424

    yielded = []
    async for result in pipeline:
        assert result == final_result
        yielded.append(result)

    # Without this check an empty pipeline would let the test pass vacuously.
    assert yielded, "Pipeline did not yield any results."
def test_run_tasks():
    """Synchronous entry point: drive the async pipeline check to completion."""
    check = run_and_check_tasks()
    asyncio.run(check)


if __name__ == "__main__":
    # Allow running this file directly as a script.
    test_run_tasks()

View file

@ -16,11 +16,11 @@ class TestChunksRetriever:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_chunk_context_simple(self): async def test_chunk_context_simple(self):
system_directory_path = os.path.join( system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_rag_context" pathlib.Path(__file__).parent, ".cognee_system/test_chunks_context_simple"
) )
cognee.config.system_root_directory(system_directory_path) cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join( data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_rag_context" pathlib.Path(__file__).parent, ".data_storage/test_chunks_context_simple"
) )
cognee.config.data_root_directory(data_directory_path) cognee.config.data_root_directory(data_directory_path)
@ -73,11 +73,11 @@ class TestChunksRetriever:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_chunk_context_complex(self): async def test_chunk_context_complex(self):
system_directory_path = os.path.join( system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context" pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_complex"
) )
cognee.config.system_root_directory(system_directory_path) cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join( data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context" pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_complex"
) )
cognee.config.data_root_directory(data_directory_path) cognee.config.data_root_directory(data_directory_path)
@ -162,11 +162,11 @@ class TestChunksRetriever:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_chunk_context_on_empty_graph(self): async def test_chunk_context_on_empty_graph(self):
system_directory_path = os.path.join( system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context" pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_empty"
) )
cognee.config.system_root_directory(system_directory_path) cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join( data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context" pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_empty"
) )
cognee.config.data_root_directory(data_directory_path) cognee.config.data_root_directory(data_directory_path)
@ -190,6 +190,9 @@ if __name__ == "__main__":
test = TestChunksRetriever() test = TestChunksRetriever()
run(test.test_chunk_context_simple()) async def main():
run(test.test_chunk_context_complex()) await test.test_chunk_context_simple()
run(test.test_chunk_context_on_empty_graph()) await test.test_chunk_context_complex()
await test.test_chunk_context_on_empty_graph()
run(main())

View file

@ -154,6 +154,9 @@ if __name__ == "__main__":
test = TestGraphCompletionRetriever() test = TestGraphCompletionRetriever()
run(test.test_graph_completion_context_simple()) async def main():
run(test.test_graph_completion_context_complex()) await test.test_graph_completion_context_simple()
run(test.test_get_graph_completion_context_on_empty_graph()) await test.test_graph_completion_context_complex()
await test.test_get_graph_completion_context_on_empty_graph()
run(main())

View file

@ -127,7 +127,7 @@ class TextSummariesRetriever:
await add_data_points(entities) await add_data_points(entities)
retriever = SummariesRetriever(limit=20) retriever = SummariesRetriever(top_k=20)
context = await retriever.get_context("Christina") context = await retriever.get_context("Christina")

View file

@ -1,44 +0,0 @@
import pytest
from unittest.mock import AsyncMock, patch
from cognee.modules.users.models import User
from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
brute_force_search,
brute_force_triplet_search,
)
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine")
async def test_brute_force_search_collection_not_found(mock_get_vector_engine):
user = User(id="test_user")
query = "test query"
collections = ["nonexistent_collection"]
top_k = 5
mock_memory_fragment = AsyncMock()
mock_vector_engine = AsyncMock()
mock_vector_engine.get_distance_from_collection_elements.return_value = []
mock_get_vector_engine.return_value = mock_vector_engine
with pytest.raises(CollectionDistancesNotFoundError):
await brute_force_search(
query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
)
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine")
async def test_brute_force_triplet_search_collection_not_found(mock_get_vector_engine):
user = User(id="test_user")
query = "test query"
collections = ["nonexistent_collection"]
top_k = 5
mock_memory_fragment = AsyncMock()
mock_vector_engine = AsyncMock()
mock_vector_engine.get_distance_from_collection_elements.return_value = []
mock_get_vector_engine.return_value = mock_vector_engine
with pytest.raises(CollectionDistancesNotFoundError):
await brute_force_triplet_search(
query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
)

162
community/README.zh.md Normal file
View file

@ -0,0 +1,162 @@
<div align="center">
<a href="https://github.com/topoteretes/cognee">
<img src="https://raw.githubusercontent.com/topoteretes/cognee/refs/heads/dev/assets/cognee-logo-transparent.png" alt="Cognee Logo" height="60">
</a>
<br />
cognee - AI应用和智能体的记忆层
<p align="center">
<a href="https://www.youtube.com/watch?v=1bezuvLwJmw&t=2s">演示</a>
·
<a href="https://cognee.ai">了解更多</a>
·
<a href="https://discord.gg/NQPKmU5CCg">加入Discord</a>
</p>
[![GitHub forks](https://img.shields.io/github/forks/topoteretes/cognee.svg?style=social&label=Fork&maxAge=2592000)](https://GitHub.com/topoteretes/cognee/network/)
[![GitHub stars](https://img.shields.io/github/stars/topoteretes/cognee.svg?style=social&label=Star&maxAge=2592000)](https://GitHub.com/topoteretes/cognee/stargazers/)
[![GitHub commits](https://badgen.net/github/commits/topoteretes/cognee)](https://GitHub.com/topoteretes/cognee/commit/)
[![Github tag](https://badgen.net/github/tag/topoteretes/cognee)](https://github.com/topoteretes/cognee/tags/)
[![Downloads](https://static.pepy.tech/badge/cognee)](https://pepy.tech/project/cognee)
[![License](https://img.shields.io/github/license/topoteretes/cognee?colorA=00C586&colorB=000000)](https://github.com/topoteretes/cognee/blob/main/LICENSE)
[![Contributors](https://img.shields.io/github/contributors/topoteretes/cognee?colorA=00C586&colorB=000000)](https://github.com/topoteretes/cognee/graphs/contributors)
可靠的AI智能体响应。
使用可扩展、模块化的ECL(提取、认知、加载)管道构建动态智能体记忆。
更多[使用场景](https://docs.cognee.ai/use_cases)。
<div style="text-align: center">
<img src="cognee_benefits_zh.JPG" alt="为什么选择cognee" width="100%" />
</div>
</div>
## 功能特性
- 互联并检索您的历史对话、文档、图像和音频转录
- 减少幻觉、开发人员工作量和成本
- 仅使用Pydantic将数据加载到图形和向量数据库
- 从30多个数据源摄取数据时进行数据操作
## 开始使用
通过Google Colab <a href="https://colab.research.google.com/drive/1g-Qnx6l_ecHZi0IOw23rg0qC4TYvEvWZ?usp=sharing">笔记本</a>或<a href="https://github.com/topoteretes/cognee-starter">入门项目</a>快速上手。
## 贡献
您的贡献是使这成为真正开源项目的核心。我们**非常感谢**任何贡献。更多信息请参阅[`CONTRIBUTING.md`](CONTRIBUTING.md)。
## 📦 安装
您可以使用**pip**、**poetry**、**uv**或任何其他Python包管理器安装Cognee。
### 使用pip
```bash
pip install cognee
```
## 💻 基本用法
### 设置
```
import os
os.environ["LLM_API_KEY"] = "YOUR OPENAI_API_KEY"
```
您也可以通过创建.env文件来设置变量,可参考我们的<a href="https://github.com/topoteretes/cognee/blob/main/.env.template">模板</a>。
要使用不同的LLM提供商,请查看我们的<a href="https://docs.cognee.ai">文档</a>获取更多信息。
### 简单示例
此脚本将运行默认管道:
```python
import cognee
import asyncio
async def main():
# Add text to cognee
await cognee.add("自然语言处理NLP是计算机科学和信息检索的跨学科领域。")
# Generate the knowledge graph
await cognee.cognify()
# Query the knowledge graph
results = await cognee.search("告诉我关于NLP")
# Display the results
for result in results:
print(result)
if __name__ == '__main__':
asyncio.run(main())
```
示例输出:
```
自然语言处理(NLP)是计算机科学和信息检索的跨学科领域。它关注计算机和人类语言之间的交互,使机器能够理解和处理自然语言。
```
图形可视化:
<a href="https://rawcdn.githack.com/topoteretes/cognee/refs/heads/main/assets/graph_visualization.html"><img src="https://rawcdn.githack.com/topoteretes/cognee/refs/heads/main/assets/graph_visualization.png" width="100%" alt="图形可视化"></a>
在[浏览器](https://rawcdn.githack.com/topoteretes/cognee/refs/heads/main/assets/graph_visualization.html)中打开。
有关更高级的用法,请查看我们的<a href="https://docs.cognee.ai">文档</a>
## 了解我们的架构
<div style="text-align: center">
<img src="cognee_diagram_zh.JPG" alt="cognee概念图" width="100%" />
</div>
## 演示
1. 什么是AI记忆
[了解cognee](https://github.com/user-attachments/assets/8b2a0050-5ec4-424c-b417-8269971503f0)
2. 简单GraphRAG演示
[简单GraphRAG演示](https://github.com/user-attachments/assets/f57fd9ea-1dc0-4904-86eb-de78519fdc32)
3. cognee与Ollama
[cognee与本地模型](https://github.com/user-attachments/assets/834baf9a-c371-4ecf-92dd-e144bd0eb3f6)
## 行为准则
我们致力于为我们的社区提供愉快和尊重的开源体验。有关更多信息,请参阅<a href="https://github.com/topoteretes/cognee/blob/main/CODE_OF_CONDUCT.md"><code>CODE_OF_CONDUCT</code></a>
## 💫 贡献者
<a href="https://github.com/topoteretes/cognee/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=topoteretes/cognee"/>
</a>
## Star历史
[![Star History Chart](https://api.star-history.com/svg?repos=topoteretes/cognee&type=Date)](https://star-history.com/#topoteretes/cognee&Date)

Binary file not shown.

After

Width:  |  Height:  |  Size: 262 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 181 KiB

View file

@ -0,0 +1,37 @@
# Sample passage: short profiles of five German car manufacturers
# (Audi, BMW, Mercedes-Benz, Porsche, Volkswagen) plus a closing summary.
# NOTE(review): presumably used as ingestion input by an example or test -
# confirm against the modules that import it.
text_1 = """
1. Audi
Audi is known for its modern designs and advanced technology. Founded in the early 1900s, the brand has earned a reputation for precision engineering and innovation. With features like the Quattro all-wheel-drive system, Audi offers a range of vehicles from stylish sedans to high-performance sports cars.
2. BMW
BMW, short for Bayerische Motoren Werke, is celebrated for its focus on performance and driving pleasure. The company's vehicles are designed to provide a dynamic and engaging driving experience, and their slogan, "The Ultimate Driving Machine," reflects that commitment. BMW produces a variety of cars that combine luxury with sporty performance.
3. Mercedes-Benz
Mercedes-Benz is synonymous with luxury and quality. With a history dating back to the early 20th century, the brand is known for its elegant designs, innovative safety features, and high-quality engineering. Mercedes-Benz manufactures not only luxury sedans but also SUVs, sports cars, and commercial vehicles, catering to a wide range of needs.
4. Porsche
Porsche is a name that stands for high-performance sports cars. Founded in 1931, the brand has become famous for models like the iconic Porsche 911. Porsche cars are celebrated for their speed, precision, and distinctive design, appealing to car enthusiasts who value both performance and style.
5. Volkswagen
Volkswagen, which means "people's car" in German, was established with the idea of making affordable and reliable vehicles accessible to everyone. Over the years, Volkswagen has produced several iconic models, such as the Beetle and the Golf. Today, it remains one of the largest car manufacturers in the world, offering a wide range of vehicles that balance practicality with quality.
Each of these car manufacturer contributes to Germany's reputation as a leader in the global automotive industry, showcasing a blend of innovation, performance, and design excellence.
"""
# Sample passage: short profiles of five technology companies
# (Apple, Google, Microsoft, Amazon, Meta) plus a closing summary.
# NOTE(review): presumably used as ingestion input by an example or test -
# confirm against the modules that import it.
text_2 = """
1. Apple
Apple is renowned for its innovative consumer electronics and software. Its product lineup includes the iPhone, iPad, Mac computers, and wearables like the Apple Watch. Known for its emphasis on sleek design and user-friendly interfaces, Apple has built a loyal customer base and created a seamless ecosystem that integrates hardware, software, and services.
2. Google
Founded in 1998, Google started as a search engine and quickly became the go-to resource for finding information online. Over the years, the company has diversified its offerings to include digital advertising, cloud computing, mobile operating systems (Android), and various web services like Gmail and Google Maps. Google's innovations have played a major role in shaping the internet landscape.
3. Microsoft
Microsoft Corporation has been a dominant force in software for decades. Its Windows operating system and Microsoft Office suite are staples in both business and personal computing. In recent years, Microsoft has expanded into cloud computing with Azure, gaming with the Xbox platform, and even hardware through products like the Surface line. This evolution has helped the company maintain its relevance in a rapidly changing tech world.
4. Amazon
What began as an online bookstore has grown into one of the largest e-commerce platforms globally. Amazon is known for its vast online marketplace, but its influence extends far beyond retail. With Amazon Web Services (AWS), the company has become a leader in cloud computing, offering robust solutions that power websites, applications, and businesses around the world. Amazon's constant drive for innovation continues to reshape both retail and technology sectors.
5. Meta
Meta, originally known as Facebook, revolutionized social media by connecting billions of people worldwide. Beyond its core social networking service, Meta is investing in the next generation of digital experiences through virtual and augmented reality technologies, with projects like Oculus. The company's efforts signal a commitment to evolving digital interaction and building the metaverse—a shared virtual space where users can connect and collaborate.
Each of these companies has significantly impacted the technology landscape, driving innovation and transforming everyday life through their groundbreaking products and services.
"""

1725
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,12 +1,14 @@
[tool.poetry] [project]
name = "cognee" name = "cognee"
version = "0.1.39" version = "0.1.39"
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning." description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
authors = ["Vasilije Markovic", "Boris Arzentar"] authors = [
{ name = "Vasilije Markovic" },
{ name = "Boris Arzentar" },
]
# The upper bound is exclusive on the next minor release so that every 3.13.x
# patch version is accepted; "<=3.13" is compared against the full interpreter
# version and therefore rejects 3.13.1 and later.
requires-python = ">=3.10,<3.14"
readme = "README.md" readme = "README.md"
license = "Apache-2.0" license = "Apache-2.0"
homepage = "https://www.cognee.ai"
repository = "https://github.com/topoteretes/cognee"
classifiers = [ classifiers = [
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"Intended Audience :: Developers", "Intended Audience :: Developers",
@ -14,130 +16,131 @@ classifiers = [
"Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries",
"Operating System :: MacOS :: MacOS X", "Operating System :: MacOS :: MacOS X",
"Operating System :: POSIX :: Linux", "Operating System :: POSIX :: Linux",
"Operating System :: Microsoft :: Windows" "Operating System :: Microsoft :: Windows",
]
dependencies = [
"openai>=1.59.4,<2",
"python-dotenv==1.0.1",
"pydantic==2.10.5",
"pydantic-settings>=2.2.1,<3",
"typing_extensions==4.12.2",
"nltk==3.9.1",
"numpy>=1.26.4, <=2.1",
"pandas==2.2.3",
# Note: New s3fs and boto3 versions don't work well together
# Always use compatible fixed versions of these two dependencies
"s3fs==2025.3.2",
"boto3==1.37.1",
"botocore>=1.35.54,<2",
"sqlalchemy==2.0.39",
"aiosqlite>=0.20.0,<0.21",
"tiktoken<=0.9.0",
"litellm>=1.57.4",
"instructor==1.7.2",
"langfuse>=2.32.0,<3",
"filetype>=1.2.0,<2",
"aiohttp>=3.11.14,<4",
"aiofiles>=23.2.1,<24",
"owlready2>=0.47,<0.48",
"graphistry>=0.33.5,<0.34",
"pypdf>=4.1.0,<6.0.0",
"jinja2>=3.1.3,<4",
"matplotlib>=3.8.3,<4",
"networkx>=3.2.1,<4",
"lancedb==0.16.0",
"alembic>=1.13.3,<2",
"pre-commit>=4.0.1,<5",
"scikit-learn>=1.6.1,<2",
"limits>=4.4.1,<5",
"fastapi==0.115.7",
"python-multipart==0.0.20",
"fastapi-users[sqlalchemy]==14.0.1",
"dlt[sqlalchemy]>=1.9.0,<2",
"sentry-sdk[fastapi]>=2.9.0,<3",
"structlog>=25.2.0,<26",
] ]
[tool.poetry.dependencies] [project.optional-dependencies]
python = ">=3.10,<=3.13" api = [
openai = "^1.59.4" "uvicorn==0.34.0",
python-dotenv = "1.0.1" "gunicorn>=20.1.0,<21",
pydantic = "2.10.5" ]
pydantic-settings = "^2.2.1" distributed = [
typing_extensions = "4.12.2" "modal==0.74.15",
nltk = "3.9.1" ]
numpy = ">=1.26.4, <=2.1" weaviate = ["weaviate-client==4.9.6"]
pandas = "2.2.3" qdrant = ["qdrant-client>=1.9.0,<2"]
boto3 = "^1.26.125" neo4j = ["neo4j>=5.20.0,<6"]
botocore="^1.35.54" postgres = [
sqlalchemy = "2.0.39" "psycopg2>=2.9.10,<3",
aiosqlite = "^0.20.0" "pgvector>=0.3.5,<0.4",
tiktoken = "<=0.9.0" "asyncpg==0.30.0",
litellm = ">=1.57.4" ]
instructor = "1.7.2" notebook = ["notebook>=7.1.0,<8"]
langfuse = "^2.32.0" langchain = [
filetype = "^1.2.0" "langsmith==0.2.3",
aiohttp = "^3.11.14" "langchain_text_splitters==0.3.2",
aiofiles = "^23.2.1" ]
owlready2 = "^0.47" llama-index = ["llama-index-core>=0.12.11,<0.13"]
graphistry = "^0.33.5" gemini = ["google-generativeai>=0.8.4,<0.9"]
pypdf = ">=4.1.0,<6.0.0" huggingface = ["transformers>=4.46.3,<5"]
jinja2 = "^3.1.3" ollama = ["transformers>=4.46.3,<5"]
matplotlib = "^3.8.3" mistral = ["mistral-common>=1.5.2,<2"]
networkx = "^3.2.1" anthropic = ["anthropic>=0.26.1,<0.27"]
lancedb = "0.16.0" deepeval = ["deepeval>=2.0.1,<3"]
alembic = "^1.13.3" posthog = ["posthog>=3.5.0,<4"]
pre-commit = "^4.0.1" falkordb = ["falkordb==1.0.9"]
scikit-learn = "^1.6.1" kuzu = ["kuzu==0.8.2"]
limits = "^4.4.1" groq = ["groq==0.8.0"]
fastapi = {version = "0.115.7"} milvus = ["pymilvus>=2.5.0,<3"]
python-multipart = "0.0.20" chromadb = [
fastapi-users = {version = "14.0.1", extras = ["sqlalchemy"]} "chromadb>=0.3.0,<0.7",
uvicorn = {version = "0.34.0", optional = true} "pypika==0.48.8",
gunicorn = {version = "^20.1.0", optional = true} ]
dlt = {extras = ["sqlalchemy"], version = "^1.9.0"} docs = ["unstructured[csv, doc, docx, epub, md, odt, org, ppt, pptx, rst, rtf, tsv, xlsx]>=0.16.13,<0.17"]
qdrant-client = {version = "^1.9.0", optional = true} codegraph = [
weaviate-client = {version = "4.9.6", optional = true} "fastembed<=0.6.0 ; python_version < '3.13'",
neo4j = {version = "^5.20.0", optional = true} "transformers>=4.46.3,<5",
falkordb = {version = "1.0.9", optional = true} "tree-sitter>=0.24.0,<0.25",
kuzu = {version = "0.8.2", optional = true} "tree-sitter-python>=0.23.6,<0.24",
chromadb = {version = "^0.6.0", optional = true} ]
langchain_text_splitters = {version = "0.3.2", optional = true} evals = [
langsmith = {version = "0.2.3", optional = true} "plotly>=6.0.0,<7",
posthog = {version = "^3.5.0", optional = true} "gdown>=5.2.0,<6",
groq = {version = "0.8.0", optional = true} ]
anthropic = {version = "^0.26.1", optional = true} gui = [
sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"} "pyside6>=6.8.3,<7",
asyncpg = {version = "0.30.0", optional = true} "qasync>=0.27.1,<0.28",
pgvector = {version = "^0.3.5", optional = true} ]
psycopg2 = {version = "^2.9.10", optional = true} graphiti = ["graphiti-core>=0.7.0,<0.8"]
llama-index-core = {version = "^0.12.11", optional = true} dev = [
deepeval = {version = "^2.0.1", optional = true} "pytest>=7.4.0,<8",
transformers = {version = "^4.46.3", optional = true} "pytest-cov>=6.1.1",
pymilvus = {version = "^2.5.0", optional = true} "pytest-asyncio>=0.21.1,<0.22",
unstructured = { extras = ["csv", "doc", "docx", "epub", "md", "odt", "org", "ppt", "pptx", "rst", "rtf", "tsv", "xlsx"], version = "^0.16.13", optional = true } "coverage>=7.3.2,<8",
mistral-common = {version = "^1.5.2", optional = true} "mypy>=1.7.1,<2",
fastembed = {version = "<=0.6.0", optional = true, markers = "python_version < '3.13'"} "notebook>=7.1.0,<8",
tree-sitter = {version = "^0.24.0", optional = true} "deptry>=0.20.0,<0.21",
tree-sitter-python = {version = "^0.23.6", optional = true} "debugpy==1.8.9",
plotly = {version = "^6.0.0", optional = true} "pylint>=3.0.3,<4",
gdown = {version = "^5.2.0", optional = true} "ruff>=0.9.2,<1.0.0",
qasync = {version = "^0.27.1", optional = true} "tweepy==4.14.0",
graphiti-core = {version = "^0.7.0", optional = true} "gitpython>=3.1.43,<4",
structlog = "^25.2.0" "pylance==0.19.2",
pyside6 = {version = "^6.8.3", optional = true} "mkdocs-material>=9.5.42,<10",
google-generativeai = {version = "^0.8.4", optional = true} "mkdocs-minify-plugin>=0.8.0,<0.9",
notebook = {version = "^7.1.0", optional = true} "mkdocstrings[python]>=0.26.2,<0.27",
s3fs = "^2025.3.2" ]
modal = "^0.74.15"
[project.urls]
Homepage = "https://www.cognee.ai"
Repository = "https://github.com/topoteretes/cognee"
[tool.poetry.extras] [build-system]
api = ["uvicorn", "gunicorn"] requires = ["hatchling"]
weaviate = ["weaviate-client"] build-backend = "hatchling.build"
qdrant = ["qdrant-client"]
neo4j = ["neo4j"]
postgres = ["psycopg2", "pgvector", "asyncpg"]
notebook = ["notebook", "ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]
langchain = ["langsmith", "langchain_text_splitters"]
llama-index = ["llama-index-core"]
gemini = ["google-generativeai"]
huggingface = ["transformers"]
ollama = ["transformers"]
mistral = ["mistral-common"]
anthropic = ["anthropic"]
deepeval = ["deepeval"]
posthog = ["posthog"]
falkordb = ["falkordb"]
kuzu = ["kuzu"]
groq = ["groq"]
milvus = ["pymilvus"]
chromadb = ["chromadb"]
docs = ["unstructured"]
codegraph = ["fastembed", "transformers", "tree-sitter", "tree-sitter-python"]
evals = ["plotly", "gdown"]
gui = ["pyside6", "qasync"]
graphiti = ["graphiti-core"]
[tool.poetry.group.dev.dependencies] [tool.ruff]
pytest = "^7.4.0"
pytest-asyncio = "^0.21.1"
coverage = "^7.3.2"
mypy = "^1.7.1"
notebook = {version = "^7.1.0", optional = true}
deptry = "^0.20.0"
debugpy = "1.8.9"
pylint = "^3.0.3"
ruff = ">=0.9.2,<1.0.0"
tweepy = "4.14.0"
gitpython = "^3.1.43"
pylance = "0.19.2"
[tool.poetry.group.docs.dependencies]
mkdocs-material = "^9.5.42"
mkdocs-minify-plugin = "^0.8.0"
mkdocstrings = {extras = ["python"], version = "^0.26.2"}
[tool.ruff] # https://beta.ruff.rs/docs/
line-length = 100 line-length = 100
exclude = [ exclude = [
"migrations/", # Ignore migrations directory "migrations/", # Ignore migrations directory
@ -152,7 +155,3 @@ exclude = [
[tool.ruff.lint] [tool.ruff.lint]
ignore = ["F401"] ignore = ["F401"]
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

8607
uv.lock generated Normal file

File diff suppressed because it is too large Load diff