Merge remote-tracking branch 'origin/dev'

Boris Arzentar 2025-02-05 17:53:50 +01:00
commit 65d51d4aa7
54 changed files with 506 additions and 384 deletions

View file

@ -130,7 +130,7 @@ This script will run the default pipeline:
```python
import cognee
import asyncio
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
async def main():
# Create a clean slate for cognee -- reset data and system state

View file

@ -6,8 +6,8 @@ readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"cognee",
"mcp==1.2.0",
"cognee[codegraph]",
"mcp==1.1.3",
]
[[project.authors]]

View file

@ -1,3 +1,4 @@
import json
import os
import cognee
import logging
@ -8,8 +9,10 @@ from contextlib import redirect_stderr, redirect_stdout
import mcp.types as types
from mcp.server import Server, NotificationOptions
from mcp.server.models import InitializationOptions
from cognee.api.v1.search import SearchType
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.modules.search.types import SearchType
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.storage.utils import JSONEncoder
mcp = Server("cognee")
@ -41,6 +44,19 @@ async def list_tools() -> list[types.Tool]:
"required": ["text"],
},
),
types.Tool(
name="codify",
description="Transforms codebase into knowledge graph",
inputSchema={
"type": "object",
"properties": {
"repo_path": {
"type": "string",
},
},
"required": ["repo_path"],
},
),
types.Tool(
name="search",
description="Searches for information in knowledge graph",
@ -51,6 +67,10 @@ async def list_tools() -> list[types.Tool]:
"type": "string",
"description": "The query to search for",
},
"search_type": {
"type": "string",
"description": "The type of search to perform (e.g., INSIGHTS, CODE)",
},
},
"required": ["search_query"],
},
@ -72,15 +92,21 @@ async def call_tools(name: str, arguments: dict) -> list[types.TextContent]:
with open(os.devnull, "w") as fnull:
with redirect_stdout(fnull), redirect_stderr(fnull):
if name == "cognify":
await cognify(
cognify(
text=arguments["text"],
graph_model_file=arguments.get("graph_model_file", None),
graph_model_name=arguments.get("graph_model_name", None),
)
return [types.TextContent(type="text", text="Ingested")]
if name == "codify":
await codify(arguments.get("repo_path"))
return [types.TextContent(type="text", text="Indexed")]
elif name == "search":
search_results = await search(arguments["search_query"])
search_results = await search(
arguments["search_query"], arguments["search_type"]
)
return [types.TextContent(type="text", text=search_results)]
elif name == "prune":
@ -102,21 +128,30 @@ async def cognify(text: str, graph_model_file: str = None, graph_model_name: str
await cognee.add(text)
try:
await cognee.cognify(graph_model=graph_model)
asyncio.create_task(cognee.cognify(graph_model=graph_model))
except Exception as e:
raise ValueError(f"Failed to cognify: {str(e)}")
async def search(search_query: str) -> str:
async def codify(repo_path: str):
async for result in run_code_graph_pipeline(repo_path, False):
print(result)
async def search(search_query: str, search_type: str) -> str:
"""Search the knowledge graph"""
search_results = await cognee.search(SearchType.INSIGHTS, query_text=search_query)
search_results = await cognee.search(
query_type=SearchType[search_type.upper()], query_text=search_query
)
results = retrieved_edges_to_string(search_results)
return results
if search_type.upper() == "CODE":
return json.dumps(search_results, cls=JSONEncoder)
else:
results = retrieved_edges_to_string(search_results)
return results
async def prune() -> str:
async def prune():
"""Reset the knowledge graph"""
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
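For orientation, here is a minimal sketch of the search call the MCP server now delegates to, using the relocated SearchType enum; the sample text and query are illustrative, and a configured LLM plus embedding backend is assumed:

```python
import asyncio

import cognee
from cognee.modules.search.types import SearchType  # new import path introduced in this change


async def main():
    # Ingest a small snippet and build the graph, as the cognify tool does.
    await cognee.add("Cognee builds a knowledge graph from ingested text.")
    await cognee.cognify()

    # Search now takes the type as a keyword argument, mirroring the MCP handler above.
    results = await cognee.search(query_type=SearchType.INSIGHTS, query_text="knowledge graph")
    print(results)


if __name__ == "__main__":
    asyncio.run(main())
```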

90 changed lines in cognee-mcp/uv.lock (generated)
View file

@ -474,7 +474,7 @@ wheels = [
[[package]]
name = "cognee"
version = "0.1.22"
version = "0.1.23"
source = { directory = "../" }
dependencies = [
{ name = "aiofiles" },
@ -516,11 +516,16 @@ dependencies = [
{ name = "sqlalchemy" },
{ name = "tenacity" },
{ name = "tiktoken" },
{ name = "transformers" },
{ name = "typing-extensions" },
{ name = "uvicorn" },
]
[package.optional-dependencies]
codegraph = [
{ name = "jedi" },
{ name = "parso" },
]
[package.metadata]
requires-dist = [
{ name = "aiofiles", specifier = ">=23.2.1,<24.0.0" },
@ -539,6 +544,7 @@ requires-dist = [
{ name = "fastapi", specifier = "==0.115.7" },
{ name = "fastapi-users", extras = ["sqlalchemy"], specifier = "==14.0.0" },
{ name = "filetype", specifier = ">=1.2.0,<2.0.0" },
{ name = "google-generativeai", marker = "extra == 'gemini'", specifier = ">=0.8.4,<0.9.0" },
{ name = "graphistry", specifier = ">=0.33.5,<0.34.0" },
{ name = "groq", marker = "extra == 'groq'", specifier = "==0.8.0" },
{ name = "gunicorn", specifier = ">=20.1.0,<21.0.0" },
@ -578,7 +584,7 @@ requires-dist = [
{ name = "sqlalchemy", specifier = "==2.0.36" },
{ name = "tenacity", specifier = ">=9.0.0,<10.0.0" },
{ name = "tiktoken", specifier = "==0.7.0" },
{ name = "transformers", specifier = ">=4.46.3,<5.0.0" },
{ name = "transformers", marker = "extra == 'huggingface'", specifier = ">=4.46.3,<5.0.0" },
{ name = "typing-extensions", specifier = "==4.12.2" },
{ name = "unstructured", extras = ["csv", "doc", "docx", "epub", "md", "odt", "org", "ppt", "pptx", "rst", "rtf", "tsv", "xlsx"], marker = "extra == 'docs'", specifier = ">=0.16.13,<0.17.0" },
{ name = "uvicorn", specifier = "==0.34.0" },
@ -590,14 +596,14 @@ name = "cognee-mcp"
version = "0.1.0"
source = { editable = "." }
dependencies = [
{ name = "cognee" },
{ name = "cognee", extra = ["codegraph"] },
{ name = "mcp" },
]
[package.metadata]
requires-dist = [
{ name = "cognee", directory = "../" },
{ name = "mcp", specifier = "==1.2.0" },
{ name = "cognee", extras = ["codegraph"], directory = "../" },
{ name = "mcp", specifier = "==1.1.3" },
]
[[package]]
@ -1307,6 +1313,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c5/82/fd319382c1a33d7021cf151007b4cbd5daddf09d9ca5fb670e476668f9fc/instructor-1.7.2-py3-none-any.whl", hash = "sha256:cb43d27f6d7631c31762b936b2fcb44d2a3f9d8a020430a0f4d3484604ffb95b", size = 71353 },
]
[[package]]
name = "jedi"
version = "0.19.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "parso" },
]
sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278 },
]
[[package]]
name = "jinja2"
version = "3.1.5"
@ -1739,21 +1757,19 @@ wheels = [
[[package]]
name = "mcp"
version = "1.2.0"
version = "1.1.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "httpx" },
{ name = "httpx-sse" },
{ name = "pydantic" },
{ name = "pydantic-settings" },
{ name = "sse-starlette" },
{ name = "starlette" },
{ name = "uvicorn" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ab/a5/b08dc846ebedae9f17ced878e6975826e90e448cd4592f532f6a88a925a7/mcp-1.2.0.tar.gz", hash = "sha256:2b06c7ece98d6ea9e6379caa38d74b432385c338fb530cb82e2c70ea7add94f5", size = 102973 }
sdist = { url = "https://files.pythonhosted.org/packages/f7/60/66ebfd280b197f9a9d074c9e46cb1ac3186a32d12e6bd0425c24fe7cf7e8/mcp-1.1.3.tar.gz", hash = "sha256:af11018b8e9153cdd25f3722ec639fe7a462c00213a330fd6f593968341a9883", size = 57903 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/af/84/fca78f19ac8ce6c53ba416247c71baa53a9e791e98d3c81edbc20a77d6d1/mcp-1.2.0-py3-none-any.whl", hash = "sha256:1d0e77d8c14955a5aea1f5aa1f444c8e531c09355c829b20e42f7a142bc0755f", size = 66468 },
{ url = "https://files.pythonhosted.org/packages/b8/08/cfcfa13e41f8d27503c51a8cbf1939d720073ace92469d08655bb5de1b24/mcp-1.1.3-py3-none-any.whl", hash = "sha256:71462d6cd7c06c14689dfcf110ff22286ba1b608cfc3515c0a5cbe33d131731a", size = 36997 },
]
[[package]]
@ -2083,6 +2099,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ab/5f/b38085618b950b79d2d9164a711c52b10aefc0ae6833b96f626b7021b2ed/pandas-2.2.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ad5b65698ab28ed8d7f18790a0dc58005c7629f227be9ecc1072aa74c0c1d43a", size = 13098436 },
]
[[package]]
name = "parso"
version = "0.8.4"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/66/94/68e2e17afaa9169cf6412ab0f28623903be73d1b32e208d9e8e541bb086d/parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d", size = 400609 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650 },
]
[[package]]
name = "pathvalidate"
version = "3.2.3"
@ -2850,28 +2875,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/5f/ce/22673f4a85ccc640735b4f8d12178a0f41b5d3c6eda7f33756d10ce56901/s3transfer-0.11.1-py3-none-any.whl", hash = "sha256:8fa0aa48177be1f3425176dfe1ab85dcd3d962df603c3dbfc585e6bf857ef0ff", size = 84111 },
]
[[package]]
name = "safetensors"
version = "0.5.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f4/4f/2ef9ef1766f8c194b01b67a63a444d2e557c8fe1d82faf3ebd85f370a917/safetensors-0.5.2.tar.gz", hash = "sha256:cb4a8d98ba12fa016f4241932b1fc5e702e5143f5374bba0bbcf7ddc1c4cf2b8", size = 66957 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/96/d1/017e31e75e274492a11a456a9e7c171f8f7911fe50735b4ec6ff37221220/safetensors-0.5.2-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:45b6092997ceb8aa3801693781a71a99909ab9cc776fbc3fa9322d29b1d3bef2", size = 427067 },
{ url = "https://files.pythonhosted.org/packages/24/84/e9d3ff57ae50dd0028f301c9ee064e5087fe8b00e55696677a0413c377a7/safetensors-0.5.2-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:6d0d6a8ee2215a440e1296b843edf44fd377b055ba350eaba74655a2fe2c4bae", size = 408856 },
{ url = "https://files.pythonhosted.org/packages/f1/1d/fe95f5dd73db16757b11915e8a5106337663182d0381811c81993e0014a9/safetensors-0.5.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86016d40bcaa3bcc9a56cd74d97e654b5f4f4abe42b038c71e4f00a089c4526c", size = 450088 },
{ url = "https://files.pythonhosted.org/packages/cf/21/e527961b12d5ab528c6e47b92d5f57f33563c28a972750b238b871924e49/safetensors-0.5.2-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:990833f70a5f9c7d3fc82c94507f03179930ff7d00941c287f73b6fcbf67f19e", size = 458966 },
{ url = "https://files.pythonhosted.org/packages/a5/8b/1a037d7a57f86837c0b41905040369aea7d8ca1ec4b2a77592372b2ec380/safetensors-0.5.2-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3dfa7c2f3fe55db34eba90c29df94bcdac4821043fc391cb5d082d9922013869", size = 509915 },
{ url = "https://files.pythonhosted.org/packages/61/3d/03dd5cfd33839df0ee3f4581a20bd09c40246d169c0e4518f20b21d5f077/safetensors-0.5.2-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:46ff2116150ae70a4e9c490d2ab6b6e1b1b93f25e520e540abe1b81b48560c3a", size = 527664 },
{ url = "https://files.pythonhosted.org/packages/c5/dc/8952caafa9a10a3c0f40fa86bacf3190ae7f55fa5eef87415b97b29cb97f/safetensors-0.5.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ab696dfdc060caffb61dbe4066b86419107a24c804a4e373ba59be699ebd8d5", size = 461978 },
{ url = "https://files.pythonhosted.org/packages/60/da/82de1fcf1194e3dbefd4faa92dc98b33c06bed5d67890e0962dd98e18287/safetensors-0.5.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:03c937100f38c9ff4c1507abea9928a6a9b02c9c1c9c3609ed4fb2bf413d4975", size = 491253 },
{ url = "https://files.pythonhosted.org/packages/5a/9a/d90e273c25f90c3ba1b0196a972003786f04c39e302fbd6649325b1272bb/safetensors-0.5.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:a00e737948791b94dad83cf0eafc09a02c4d8c2171a239e8c8572fe04e25960e", size = 628644 },
{ url = "https://files.pythonhosted.org/packages/70/3c/acb23e05aa34b4f5edd2e7f393f8e6480fbccd10601ab42cd03a57d4ab5f/safetensors-0.5.2-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:d3a06fae62418ec8e5c635b61a8086032c9e281f16c63c3af46a6efbab33156f", size = 721648 },
{ url = "https://files.pythonhosted.org/packages/71/45/eaa3dba5253a7c6931230dc961641455710ab231f8a89cb3c4c2af70f8c8/safetensors-0.5.2-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:1506e4c2eda1431099cebe9abf6c76853e95d0b7a95addceaa74c6019c65d8cf", size = 659588 },
{ url = "https://files.pythonhosted.org/packages/b0/71/2f9851164f821064d43b481ddbea0149c2d676c4f4e077b178e7eeaa6660/safetensors-0.5.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5c5b5d9da594f638a259fca766046f44c97244cc7ab8bef161b3e80d04becc76", size = 632533 },
{ url = "https://files.pythonhosted.org/packages/00/f1/5680e2ef61d9c61454fad82c344f0e40b8741a9dbd1e31484f0d31a9b1c3/safetensors-0.5.2-cp38-abi3-win32.whl", hash = "sha256:fe55c039d97090d1f85277d402954dd6ad27f63034fa81985a9cc59655ac3ee2", size = 291167 },
{ url = "https://files.pythonhosted.org/packages/86/ca/aa489392ec6fb59223ffce825461e1f811a3affd417121a2088be7a5758b/safetensors-0.5.2-cp38-abi3-win_amd64.whl", hash = "sha256:78abdddd03a406646107f973c7843276e7b64e5e32623529dc17f3d94a20f589", size = 303756 },
]
[[package]]
name = "scikit-learn"
version = "1.6.1"
@ -3348,27 +3351,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540 },
]
[[package]]
name = "transformers"
version = "4.48.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock" },
{ name = "huggingface-hub" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "pyyaml" },
{ name = "regex" },
{ name = "requests" },
{ name = "safetensors" },
{ name = "tokenizers" },
{ name = "tqdm" },
]
sdist = { url = "https://files.pythonhosted.org/packages/21/6b/caf620fae7fbf35947c81e7dd0834493b9ad9b71bb9e433025ac7a07e79a/transformers-4.48.1.tar.gz", hash = "sha256:7c1931facc3ee8adcbf86fc7a87461d54c1e40eca3bb57fef1ee9f3ecd32187e", size = 8365872 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7b/9f/92d3091c44cb19add044064af1bf1345cd35fbb84d32a3690f912800a295/transformers-4.48.1-py3-none-any.whl", hash = "sha256:24be0564b0a36d9e433d9a65de248f1545b6f6edce1737669605eb6a8141bbbb", size = 9665001 },
]
[[package]]
name = "typer"
version = "0.15.1"

View file

@ -188,6 +188,7 @@ def start_api_server(host: str = "0.0.0.0", port: int = 8000):
except Exception as e:
logger.exception(f"Failed to start server: {e}")
# Here you could add any cleanup code or error recovery code.
raise e
if __name__ == "__main__":

View file

@ -16,9 +16,22 @@ async def add(
dataset_name: str = "main_dataset",
user: User = None,
):
# Create tables for databases
await create_relational_db_and_tables()
await create_pgvector_db_and_tables()
# Initialize first_run attribute if it doesn't exist
if not hasattr(add, "first_run"):
add.first_run = True
if add.first_run:
from cognee.infrastructure.llm.utils import test_llm_connection, test_embedding_connection
# Test LLM and Embedding configuration once before running Cognee
await test_llm_connection()
await test_embedding_connection()
add.first_run = False # Update flag after first run
if user is None:
user = await get_default_user()
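The first_run guard relies on storing state as an attribute of the add function itself, so the LLM and embedding connection tests run once per process. A self-contained sketch of that pattern, with a stand-in check in place of the real connection tests:

```python
import asyncio


async def check_connections():
    # Stand-in for test_llm_connection() / test_embedding_connection().
    print("connections verified")


async def add(text: str):
    # Initialize the flag on the function object the first time add() is called.
    if not hasattr(add, "first_run"):
        add.first_run = True

    if add.first_run:
        await check_connections()
        add.first_run = False  # subsequent calls skip the check

    print(f"added: {text}")


async def main():
    await add("first call")   # runs the connection check
    await add("second call")  # check is skipped


asyncio.run(main())
```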

View file

@ -25,7 +25,7 @@ from cognee.tasks.documents import (
)
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.storage import add_data_points
from cognee.tasks.storage.descriptive_metrics import store_descriptive_metrics
from cognee.modules.data.methods import store_descriptive_metrics
from cognee.tasks.storage.index_graph_edges import index_graph_edges
from cognee.tasks.summarization import summarize_text
@ -165,7 +165,7 @@ async def get_default_tasks(
task_config={"batch_size": 10},
),
Task(add_data_points, task_config={"batch_size": 10}),
Task(store_descriptive_metrics),
Task(store_descriptive_metrics, include_optional=True),
]
except Exception as error:
send_telemetry("cognee.cognify DEFAULT TASKS CREATION ERRORED", user.id)

View file

@ -1,5 +1,5 @@
from fastapi import APIRouter
from pydantic import BaseModel
from cognee.api.DTO import InDTO
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.modules.retrieval.description_to_codepart_search import (
code_description_to_code_part_search,
@ -7,14 +7,14 @@ from cognee.modules.retrieval.description_to_codepart_search import (
from fastapi.responses import JSONResponse
class CodePipelineIndexPayloadDTO(BaseModel):
class CodePipelineIndexPayloadDTO(InDTO):
repo_path: str
include_docs: bool = False
class CodePipelineRetrievePayloadDTO(BaseModel):
class CodePipelineRetrievePayloadDTO(InDTO):
query: str
fullInput: str
full_input: str
def get_code_pipeline_router() -> APIRouter:
@ -34,9 +34,9 @@ def get_code_pipeline_router() -> APIRouter:
"""This endpoint is responsible for retrieving the context."""
try:
query = (
payload.fullInput.replace("cognee ", "")
if payload.fullInput.startswith("cognee ")
else payload.fullInput
payload.full_input.replace("cognee ", "")
if payload.full_input.startswith("cognee ")
else payload.full_input
)
retrieved_codeparts, __ = await code_description_to_code_part_search(
@ -45,8 +45,8 @@ def get_code_pipeline_router() -> APIRouter:
return [
{
"name": codepart.attributes["id"],
"description": codepart.attributes["id"],
"name": codepart.attributes["file_path"],
"description": codepart.attributes["file_path"],
"content": codepart.attributes["source_code"],
}
for codepart in retrieved_codeparts

View file

@ -2,7 +2,7 @@ from uuid import UUID
from datetime import datetime
from fastapi import Depends, APIRouter
from fastapi.responses import JSONResponse
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.api.DTO import InDTO, OutDTO
from cognee.modules.users.models import User
from cognee.modules.search.operations import get_history

View file

@ -1,100 +0,0 @@
"""This module contains the search function that is used to search for nodes in the graph."""
import asyncio
from enum import Enum
from typing import Dict, Any, Callable, List
from pydantic import BaseModel, field_validator
from cognee.modules.search.graph import search_cypher
from cognee.modules.search.graph.search_adjacent import search_adjacent
from cognee.modules.search.vector.search_traverse import search_traverse
from cognee.modules.search.graph.search_summary import search_summary
from cognee.modules.search.graph.search_similarity import search_similarity
from cognee.exceptions import UserNotFoundError
from cognee.shared.utils import send_telemetry
from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
class SearchType(Enum):
ADJACENT = "ADJACENT"
TRAVERSE = "TRAVERSE"
SIMILARITY = "SIMILARITY"
SUMMARY = "SUMMARY"
SUMMARY_CLASSIFICATION = "SUMMARY_CLASSIFICATION"
NODE_CLASSIFICATION = "NODE_CLASSIFICATION"
DOCUMENT_CLASSIFICATION = ("DOCUMENT_CLASSIFICATION",)
CYPHER = "CYPHER"
@staticmethod
def from_str(name: str):
try:
return SearchType[name.upper()]
except KeyError as error:
raise ValueError(f"{name} is not a valid SearchType") from error
class SearchParameters(BaseModel):
search_type: SearchType
params: Dict[str, Any]
@field_validator("search_type", mode="before")
def convert_string_to_enum(cls, value): # pylint: disable=no-self-argument
if isinstance(value, str):
return SearchType.from_str(value)
return value
async def search(search_type: str, params: Dict[str, Any], user: User = None) -> List:
if user is None:
user = await get_default_user()
if user is None:
raise UserNotFoundError
own_document_ids = await get_document_ids_for_user(user.id)
search_params = SearchParameters(search_type=search_type, params=params)
search_results = await specific_search([search_params], user)
from uuid import UUID
filtered_search_results = []
for search_result in search_results:
document_id = search_result["document_id"] if "document_id" in search_result else None
document_id = UUID(document_id) if isinstance(document_id, str) else document_id
if document_id is None or document_id in own_document_ids:
filtered_search_results.append(search_result)
return filtered_search_results
async def specific_search(query_params: List[SearchParameters], user) -> List:
search_functions: Dict[SearchType, Callable] = {
SearchType.ADJACENT: search_adjacent,
SearchType.SUMMARY: search_summary,
SearchType.CYPHER: search_cypher,
SearchType.TRAVERSE: search_traverse,
SearchType.SIMILARITY: search_similarity,
}
search_tasks = []
send_telemetry("cognee.search EXECUTION STARTED", user.id)
for search_param in query_params:
search_func = search_functions.get(search_param.search_type)
if search_func:
# Schedule the coroutine for execution and store the task
task = search_func(**search_param.params)
search_tasks.append(task)
# Use asyncio.gather to run all scheduled tasks concurrently
search_results = await asyncio.gather(*search_tasks)
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
return search_results[0] if len(search_results) == 1 else search_results

View file

@ -1,29 +1,10 @@
import json
from uuid import UUID
from enum import Enum
from typing import Callable, Dict, Union
from typing import Union
from cognee.exceptions import InvalidValueError
from cognee.modules.search.operations import log_query, log_result
from cognee.modules.storage.utils import JSONEncoder
from cognee.shared.utils import send_telemetry
from cognee.modules.search.types import SearchType
from cognee.modules.users.exceptions import UserNotFoundError
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.tasks.chunks import query_chunks
from cognee.tasks.graph import query_graph_connections
from cognee.tasks.summarization import query_summaries
from cognee.tasks.completion import query_completion
from cognee.tasks.completion import graph_query_completion
class SearchType(Enum):
SUMMARIES = "SUMMARIES"
INSIGHTS = "INSIGHTS"
CHUNKS = "CHUNKS"
COMPLETION = "COMPLETION"
GRAPH_COMPLETION = "GRAPH_COMPLETION"
from cognee.modules.search.methods import search as search_function
async def search(
@ -42,43 +23,6 @@ async def search(
if user is None:
raise UserNotFoundError
query = await log_query(query_text, str(query_type), user.id)
own_document_ids = await get_document_ids_for_user(user.id, datasets)
search_results = await specific_search(query_type, query_text, user)
filtered_search_results = []
for search_result in search_results:
document_id = search_result["document_id"] if "document_id" in search_result else None
document_id = UUID(document_id) if isinstance(document_id, str) else document_id
if document_id is None or document_id in own_document_ids:
filtered_search_results.append(search_result)
await log_result(query.id, json.dumps(filtered_search_results, cls=JSONEncoder), user.id)
filtered_search_results = await search_function(query_text, query_type, datasets, user)
return filtered_search_results
async def specific_search(query_type: SearchType, query: str, user) -> list:
search_tasks: Dict[SearchType, Callable] = {
SearchType.SUMMARIES: query_summaries,
SearchType.INSIGHTS: query_graph_connections,
SearchType.CHUNKS: query_chunks,
SearchType.COMPLETION: query_completion,
SearchType.GRAPH_COMPLETION: graph_query_completion,
}
search_task = search_tasks.get(query_type)
if search_task is None:
raise InvalidValueError(message=f"Unsupported search type: {query_type}")
send_telemetry("cognee.search EXECUTION STARTED", user.id)
results = await search_task(query)
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
return results

View file

@ -25,11 +25,24 @@ class GraphConfig(BaseSettings):
return {
"graph_filename": self.graph_filename,
"graph_database_provider": self.graph_database_provider,
"graph_file_path": self.graph_file_path,
"graph_database_url": self.graph_database_url,
"graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port,
"graph_file_path": self.graph_file_path,
"graph_model": self.graph_model,
"graph_topology": self.graph_topology,
"model_config": self.model_config,
}
def to_hashable_dict(self) -> dict:
return {
"graph_database_provider": self.graph_database_provider,
"graph_database_url": self.graph_database_url,
"graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port,
"graph_file_path": self.graph_file_path,
}

View file

@ -8,12 +8,12 @@ from .graph_db_interface import GraphDBInterface
async def get_graph_engine() -> GraphDBInterface:
"""Factory function to get the appropriate graph client based on the graph type."""
graph_client = create_graph_engine()
config = get_graph_config()
graph_client = create_graph_engine(**get_graph_config().to_hashable_dict())
# Async functions can't be cached. After creating and caching the graph engine
# handle all necessary async operations for different graph types below.
config = get_graph_config()
# Handle loading of graph for NetworkX
if config.graph_database_provider.lower() == "networkx" and graph_client.graph is None:
await graph_client.load_graph_from_file()
@ -22,28 +22,30 @@ async def get_graph_engine() -> GraphDBInterface:
@lru_cache
def create_graph_engine() -> GraphDBInterface:
def create_graph_engine(
graph_database_provider,
graph_database_url,
graph_database_username,
graph_database_password,
graph_database_port,
graph_file_path,
):
"""Factory function to create the appropriate graph client based on the graph type."""
config = get_graph_config()
if config.graph_database_provider == "neo4j":
if not (
config.graph_database_url
and config.graph_database_username
and config.graph_database_password
):
if graph_database_provider == "neo4j":
if not (graph_database_url and graph_database_username and graph_database_password):
raise EnvironmentError("Missing required Neo4j credentials.")
from .neo4j_driver.adapter import Neo4jAdapter
return Neo4jAdapter(
graph_database_url=config.graph_database_url,
graph_database_username=config.graph_database_username,
graph_database_password=config.graph_database_password,
graph_database_url=graph_database_url,
graph_database_username=graph_database_username,
graph_database_password=graph_database_password,
)
elif config.graph_database_provider == "falkordb":
if not (config.graph_database_url and config.graph_database_port):
elif graph_database_provider == "falkordb":
if not (graph_database_url and graph_database_port):
raise EnvironmentError("Missing required FalkorDB credentials.")
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
@ -52,13 +54,13 @@ def create_graph_engine() -> GraphDBInterface:
embedding_engine = get_embedding_engine()
return FalkorDBAdapter(
database_url=config.graph_database_url,
database_port=config.graph_database_port,
database_url=graph_database_url,
database_port=graph_database_port,
embedding_engine=embedding_engine,
)
from .networkx.adapter import NetworkXAdapter
graph_client = NetworkXAdapter(filename=config.graph_file_path)
graph_client = NetworkXAdapter(filename=graph_file_path)
return graph_client
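The refactor above exists so that create_graph_engine can sit behind lru_cache: the cache needs hashable arguments, which the plain values from to_hashable_dict() provide, whereas the old factory read the config object internally and could not be keyed. A small sketch of that caching behaviour with a stand-in factory, not the real adapter constructors:

```python
from functools import lru_cache


@lru_cache
def create_engine(graph_database_provider: str, graph_database_url: str, graph_file_path: str):
    # Runs only once per distinct combination of (hashable) arguments.
    print(f"constructing {graph_database_provider} engine")
    return object()


config = {
    "graph_database_provider": "networkx",
    "graph_database_url": "",
    "graph_file_path": "/tmp/graph.pkl",
}

engine_a = create_engine(**config)  # constructs the engine
engine_b = create_engine(**config)  # returned from the cache
assert engine_a is engine_b
```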

View file

@ -54,3 +54,7 @@ class GraphDBInterface(Protocol):
@abstractmethod
async def get_graph_data(self):
raise NotImplementedError
@abstractmethod
async def get_graph_metrics(self, include_optional):
raise NotImplementedError

View file

@ -530,3 +530,17 @@ class Neo4jAdapter(GraphDBInterface):
]
return (nodes, edges)
async def get_graph_metrics(self, include_optional=False):
return {
"num_nodes": -1,
"num_edges": -1,
"mean_degree": -1,
"edge_density": -1,
"num_connected_components": -1,
"sizes_of_connected_components": -1,
"num_selfloops": -1,
"diameter": -1,
"avg_shortest_path_length": -1,
"avg_clustering": -1,
}

View file

@ -14,8 +14,9 @@ import networkx as nx
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import JSONEncoder
import numpy as np
logger = logging.getLogger("NetworkXAdapter")
logger = logging.getLogger(__name__)
class NetworkXAdapter(GraphDBInterface):
@ -269,8 +270,8 @@ class NetworkXAdapter(GraphDBInterface):
if not isinstance(node["id"], UUID):
node["id"] = UUID(node["id"])
except Exception as e:
print(e)
pass
logger.error(e)
raise e
if isinstance(node.get("updated_at"), int):
node["updated_at"] = datetime.fromtimestamp(
@ -298,8 +299,8 @@ class NetworkXAdapter(GraphDBInterface):
edge["source_node_id"] = source_id
edge["target_node_id"] = target_id
except Exception as e:
print(e)
pass
logger.error(e)
raise e
if isinstance(edge["updated_at"], int): # Handle timestamp in milliseconds
edge["updated_at"] = datetime.fromtimestamp(
@ -327,8 +328,9 @@ class NetworkXAdapter(GraphDBInterface):
await self.save_graph_to_file(file_path)
except Exception:
except Exception as e:
logger.error("Failed to load graph from file: %s", file_path)
raise e
async def delete_graph(self, file_path: str = None):
"""Asynchronously delete the graph file from the filesystem."""
@ -344,6 +346,7 @@ class NetworkXAdapter(GraphDBInterface):
logger.info("Graph deleted successfully.")
except Exception as error:
logger.error("Failed to delete graph: %s", error)
raise error
async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
@ -385,3 +388,64 @@ class NetworkXAdapter(GraphDBInterface):
]
return filtered_nodes, filtered_edges
async def get_graph_metrics(self, include_optional=False):
graph = self.graph
def _get_mean_degree(graph):
degrees = [d for _, d in graph.degree()]
return np.mean(degrees) if degrees else 0
def _get_edge_density(graph):
num_nodes = graph.number_of_nodes()
num_edges = graph.number_of_edges()
num_possible_edges = num_nodes * (num_nodes - 1)
edge_density = num_edges / num_possible_edges if num_possible_edges > 0 else 0
return edge_density
def _get_diameter(graph):
if nx.is_strongly_connected(graph):
return nx.diameter(graph.to_undirected())
else:
return None
def _get_avg_shortest_path_length(graph):
if nx.is_strongly_connected(graph):
return nx.average_shortest_path_length(graph)
else:
return None
def _get_avg_clustering(graph):
try:
return nx.average_clustering(nx.DiGraph(graph))
except Exception as e:
logger.warning("Failed to calculate clustering coefficient: %s", e)
return None
mandatory_metrics = {
"num_nodes": graph.number_of_nodes(),
"num_edges": graph.number_of_edges(),
"mean_degree": _get_mean_degree(graph),
"edge_density": _get_edge_density(graph),
"num_connected_components": nx.number_weakly_connected_components(graph),
"sizes_of_connected_components": [
len(c) for c in nx.weakly_connected_components(graph)
],
}
if include_optional:
optional_metrics = {
"num_selfloops": sum(1 for u, v in graph.edges() if u == v),
"diameter": _get_diameter(graph),
"avg_shortest_path_length": _get_avg_shortest_path_length(graph),
"avg_clustering": _get_avg_clustering(graph),
}
else:
optional_metrics = {
"num_selfloops": -1,
"diameter": -1,
"avg_shortest_path_length": -1,
"avg_clustering": -1,
}
return mandatory_metrics | optional_metrics
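To make the mandatory metrics concrete, a tiny worked example of the same degree and density formulas used by the helpers above, on a three-node directed cycle:

```python
import networkx as nx
import numpy as np

# Three-node directed cycle: a -> b -> c -> a.
graph = nx.DiGraph([("a", "b"), ("b", "c"), ("c", "a")])

num_nodes = graph.number_of_nodes()  # 3
num_edges = graph.number_of_edges()  # 3

# Mean degree counts in- plus out-degree per node: (2 + 2 + 2) / 3 = 2.0
mean_degree = np.mean([degree for _, degree in graph.degree()])

# Directed edge density: edges over the n * (n - 1) possible ordered pairs = 3 / 6 = 0.5
edge_density = num_edges / (num_nodes * (num_nodes - 1))

print(num_nodes, num_edges, mean_degree, edge_density)
```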

View file

@ -1,6 +1,8 @@
from .sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from functools import lru_cache
@lru_cache
def create_relational_engine(
db_path: str,
db_name: str,

View file

@ -1,10 +1,7 @@
# from functools import lru_cache
from .config import get_relational_config
from .create_relational_engine import create_relational_engine
# @lru_cache
def get_relational_engine():
relational_config = get_relational_config()

View file

@ -303,9 +303,10 @@ class SQLAlchemyAdapter:
await connection.execute(text("DROP TABLE IF EXISTS group_permission CASCADE"))
await connection.execute(text("DROP TABLE IF EXISTS permissions CASCADE"))
# Add more DROP TABLE statements for other tables as needed
print("Database tables dropped successfully.")
logger.debug("Database tables dropped successfully.")
except Exception as e:
print(f"Error dropping database tables: {e}")
logger.error(f"Error dropping database tables: {e}")
raise e
async def create_database(self):
if self.engine.dialect.name == "sqlite":
@ -340,6 +341,7 @@ class SQLAlchemyAdapter:
await connection.execute(drop_table_query)
metadata.clear()
except Exception as e:
print(f"Error deleting database: {e}")
logger.error(f"Error deleting database: {e}")
raise e
print("Database deleted successfully.")
logger.info("Database deleted successfully.")

View file

@ -1,49 +1,47 @@
from typing import Dict
from functools import lru_cache
class VectorConfig(Dict):
vector_db_url: str
vector_db_port: str
vector_db_key: str
vector_db_provider: str
def create_vector_engine(config: VectorConfig, embedding_engine):
if config["vector_db_provider"] == "weaviate":
@lru_cache
def create_vector_engine(
embedding_engine,
vector_db_url: str,
vector_db_port: str,
vector_db_key: str,
vector_db_provider: str,
):
if vector_db_provider == "weaviate":
from .weaviate_db import WeaviateAdapter
if not (config["vector_db_url"] and config["vector_db_key"]):
if not (vector_db_url and vector_db_key):
raise EnvironmentError("Missing required Weaviate credentials!")
return WeaviateAdapter(
config["vector_db_url"], config["vector_db_key"], embedding_engine=embedding_engine
)
return WeaviateAdapter(vector_db_url, vector_db_key, embedding_engine=embedding_engine)
elif config["vector_db_provider"] == "qdrant":
if not (config["vector_db_url"] and config["vector_db_key"]):
elif vector_db_provider == "qdrant":
if not (vector_db_url and vector_db_key):
raise EnvironmentError("Missing required Qdrant credentials!")
from .qdrant.QDrantAdapter import QDrantAdapter
return QDrantAdapter(
url=config["vector_db_url"],
api_key=config["vector_db_key"],
url=vector_db_url,
api_key=vector_db_key,
embedding_engine=embedding_engine,
)
elif config["vector_db_provider"] == "milvus":
elif vector_db_provider == "milvus":
from .milvus.MilvusAdapter import MilvusAdapter
if not config["vector_db_url"]:
if not vector_db_url:
raise EnvironmentError("Missing required Milvus credentials!")
return MilvusAdapter(
url=config["vector_db_url"],
api_key=config["vector_db_key"],
url=vector_db_url,
api_key=vector_db_key,
embedding_engine=embedding_engine,
)
elif config["vector_db_provider"] == "pgvector":
elif vector_db_provider == "pgvector":
from cognee.infrastructure.databases.relational import get_relational_config
# Get configuration for postgres database
@ -65,19 +63,19 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
return PGVectorAdapter(
connection_string,
config["vector_db_key"],
vector_db_key,
embedding_engine,
)
elif config["vector_db_provider"] == "falkordb":
if not (config["vector_db_url"] and config["vector_db_port"]):
elif vector_db_provider == "falkordb":
if not (vector_db_url and vector_db_port):
raise EnvironmentError("Missing required FalkorDB credentials!")
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
return FalkorDBAdapter(
database_url=config["vector_db_url"],
database_port=config["vector_db_port"],
database_url=vector_db_url,
database_port=vector_db_port,
embedding_engine=embedding_engine,
)
@ -85,7 +83,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
from .lancedb.LanceDBAdapter import LanceDBAdapter
return LanceDBAdapter(
url=config["vector_db_url"],
api_key=config["vector_db_key"],
url=vector_db_url,
api_key=vector_db_key,
embedding_engine=embedding_engine,
)

View file

@ -1,9 +1,7 @@
from .config import get_vectordb_config
from .embeddings import get_embedding_engine
from .create_vector_engine import create_vector_engine
from functools import lru_cache
@lru_cache
def get_vector_engine():
return create_vector_engine(get_vectordb_config().to_dict(), get_embedding_engine())
return create_vector_engine(get_embedding_engine(), **get_vectordb_config().to_dict())

View file

@ -1,6 +1,6 @@
import asyncio
from typing import List, Optional, get_type_hints
from uuid import UUID
from uuid import UUID, uuid4
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
@ -69,7 +69,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)
id: Mapped[data_point_types["id"]]
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
@ -103,7 +103,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)
id: Mapped[data_point_types["id"]]
payload = Column(JSON)
vector = Column(self.Vector(vector_size))

View file

@ -1,2 +1,4 @@
from .config import get_llm_config
from .utils import get_max_chunk_tokens
from .utils import test_llm_connection
from .utils import test_embedding_connection

View file

@ -36,3 +36,26 @@ def get_model_max_tokens(model_name: str):
logger.info("Model not found in LiteLLM's model_cost.")
return max_tokens
async def test_llm_connection():
try:
llm_adapter = get_llm_client()
await llm_adapter.acreate_structured_output(
text_input="test",
system_prompt='Respond to me with the following string: "test"',
response_model=str,
)
except Exception as e:
logger.error(e)
logger.error("Connection to LLM could not be established.")
raise e
async def test_embedding_connection():
try:
await get_vector_engine().embedding_engine.embed_text("test")
except Exception as e:
logger.error(e)
logger.error("Connection to Embedding handler could not be established.")
raise e

View file

@ -1,10 +1,12 @@
from typing import Optional
import logging
from uuid import NAMESPACE_OID, uuid5
from cognee.tasks.chunks import chunk_by_paragraph
from .models.DocumentChunk import DocumentChunk
logger = logging.getLogger(__name__)
class TextChunker:
document = None
@ -76,7 +78,8 @@ class TextChunker:
},
)
except Exception as e:
print(e)
logger.error(e)
raise e
paragraph_chunks = [chunk_data]
self.chunk_size = chunk_data["word_count"]
self.token_count = chunk_data["token_count"]
@ -97,4 +100,5 @@ class TextChunker:
_metadata={"index_fields": ["text"]},
)
except Exception as e:
print(e)
logger.error(e)
raise e

View file

@ -11,3 +11,5 @@ from .get_data import get_data
# Delete
from .delete_dataset import delete_dataset
from .delete_data import delete_data
from .store_descriptive_metrics import store_descriptive_metrics

View file

@ -0,0 +1,50 @@
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.relational import get_relational_engine
from sqlalchemy import select
from sqlalchemy.sql import func
from cognee.modules.data.models import Data
from cognee.modules.data.models import GraphMetrics
import uuid
from cognee.infrastructure.databases.graph import get_graph_engine
async def fetch_token_count(db_engine) -> int:
"""
Fetches and sums token counts from the database.
Returns:
int: The total number of tokens across all documents.
"""
async with db_engine.get_async_session() as session:
token_count_sum = await session.execute(select(func.sum(Data.token_count)))
token_count_sum = token_count_sum.scalar()
return token_count_sum
async def store_descriptive_metrics(data_points: list[DataPoint], include_optional: bool):
db_engine = get_relational_engine()
graph_engine = await get_graph_engine()
graph_metrics = await graph_engine.get_graph_metrics(include_optional)
async with db_engine.get_async_session() as session:
metrics = GraphMetrics(
id=uuid.uuid4(),
num_tokens=await fetch_token_count(db_engine),
num_nodes=graph_metrics["num_nodes"],
num_edges=graph_metrics["num_edges"],
mean_degree=graph_metrics["mean_degree"],
edge_density=graph_metrics["edge_density"],
num_connected_components=graph_metrics["num_connected_components"],
sizes_of_connected_components=graph_metrics["sizes_of_connected_components"],
num_selfloops=graph_metrics["num_selfloops"],
diameter=graph_metrics["diameter"],
avg_shortest_path_length=graph_metrics["avg_shortest_path_length"],
avg_clustering=graph_metrics["avg_clustering"],
)
session.add(metrics)
await session.commit()
return data_points
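fetch_token_count above is a single SQL aggregate over the Data table. A minimal sketch of the same select(func.sum(...)) pattern against an in-memory SQLite database, using a stand-in Data model rather than cognee's, and a synchronous session for brevity where cognee uses an async one:

```python
from sqlalchemy import Column, Integer, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Session
from sqlalchemy.sql import func


class Base(DeclarativeBase):
    pass


class Data(Base):
    # Stand-in for cognee's Data model; only the token_count column matters here.
    __tablename__ = "data"
    id = Column(Integer, primary_key=True)
    token_count = Column(Integer)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Data(token_count=120), Data(token_count=80)])
    session.commit()

    # The same aggregate fetch_token_count runs: sum token counts across all rows.
    total = session.execute(select(func.sum(Data.token_count))).scalar()
    print(total)  # 200
```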

View file

@ -1,4 +1,5 @@
from datetime import datetime, timezone
from sqlalchemy.sql import func
from sqlalchemy import Column, DateTime, Float, Integer, JSON, UUID
@ -7,7 +8,7 @@ from uuid import uuid4
class GraphMetrics(Base):
__tablename__ = "graph_metrics_table"
__tablename__ = "graph_metrics"
# TODO: Change ID to reflect unique id of graph database
id = Column(UUID, primary_key=True, default=uuid4)

View file

@ -1,3 +1,4 @@
from .Data import Data
from .Dataset import Dataset
from .DatasetData import DatasetData
from .GraphMetrics import GraphMetrics

View file

@ -34,5 +34,6 @@ async def detect_language(text: str):
except Exception as e:
logger.error(f"Unexpected error: {e}")
raise e
return None

View file

@ -152,6 +152,7 @@ class CogneeGraph(CogneeAbstractGraph):
except Exception as ex:
print(f"Error mapping vector distances to edges: {ex}")
raise ex
async def calculate_top_triplet_importances(self, k: int) -> List:
min_heap = []

View file

@ -0,0 +1,18 @@
from cognee.modules.retrieval.description_to_codepart_search import (
code_description_to_code_part_search,
)
async def code_graph_retrieval(query, include_docs=False):
retrieved_codeparts, __ = await code_description_to_code_part_search(
query, include_docs=include_docs
)
return [
{
"name": codepart.attributes["file_path"],
"description": codepart.attributes["file_path"],
"content": codepart.attributes["source_code"],
}
for codepart in retrieved_codeparts
]

View file

@ -1,17 +1,18 @@
import asyncio
import logging
from typing import Set, List
from typing import List
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry
from cognee.api.v1.search import SearchType
from cognee.api.v1.search.search_v2 import search
from cognee.modules.search.methods import search
from cognee.infrastructure.llm.get_llm_client import get_llm_client
logger = logging.getLogger(__name__)
async def code_description_to_code_part_search(
query: str, include_docs=False, user: User = None, top_k=5
@ -39,7 +40,7 @@ async def code_description_to_code_part(
include_docs(bool): Boolean showing whether we have the docs in the graph or not
Returns:
Set[str]: A set of unique code parts matching the query.
List[str]: A list of unique code parts matching the query.
Raises:
ValueError: If arguments are invalid.
@ -66,7 +67,7 @@ async def code_description_to_code_part(
try:
if include_docs:
search_results = await search(SearchType.INSIGHTS, query_text=query)
search_results = await search(query_text=query, query_type="INSIGHTS", user=user)
concatenated_descriptions = " ".join(
obj["description"]
@ -98,6 +99,7 @@ async def code_description_to_code_part(
"id",
"type",
"text",
"file_path",
"source_code",
"pydantic_type",
],
@ -154,8 +156,9 @@ if __name__ == "__main__":
user = None
try:
results = await code_description_to_code_part_search(query, user)
print("Retrieved Code Parts:", results)
logger.debug("Retrieved Code Parts: %s", results)
except Exception as e:
print(f"An error occurred: {e}")
logger.error(f"An error occurred: {e}")
raise e
asyncio.run(main())

View file

@ -0,0 +1 @@
from .search import search

View file

@ -0,0 +1,66 @@
import json
from uuid import UUID
from typing import Callable
from cognee.exceptions import InvalidValueError
from cognee.modules.retrieval.code_graph_retrieval import code_graph_retrieval
from cognee.modules.search.types import SearchType
from cognee.modules.storage.utils import JSONEncoder
from cognee.modules.users.models import User
from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.shared.utils import send_telemetry
from cognee.tasks.chunks import query_chunks
from cognee.tasks.graph import query_graph_connections
from cognee.tasks.summarization import query_summaries
from cognee.tasks.completion import query_completion
from cognee.tasks.completion import graph_query_completion
from ..operations import log_query, log_result
async def search(
query_text: str,
query_type: str,
datasets: list[str],
user: User,
):
query = await log_query(query_text, str(query_type), user.id)
own_document_ids = await get_document_ids_for_user(user.id, datasets)
search_results = await specific_search(query_type, query_text, user)
filtered_search_results = []
for search_result in search_results:
document_id = search_result["document_id"] if "document_id" in search_result else None
document_id = UUID(document_id) if isinstance(document_id, str) else document_id
if document_id is None or document_id in own_document_ids:
filtered_search_results.append(search_result)
await log_result(query.id, json.dumps(filtered_search_results, cls=JSONEncoder), user.id)
return filtered_search_results
async def specific_search(query_type: SearchType, query: str, user: User) -> list:
search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: query_summaries,
SearchType.INSIGHTS: query_graph_connections,
SearchType.CHUNKS: query_chunks,
SearchType.COMPLETION: query_completion,
SearchType.GRAPH_COMPLETION: graph_query_completion,
SearchType.CODE: code_graph_retrieval,
}
search_task = search_tasks.get(query_type)
if search_task is None:
raise InvalidValueError(message=f"Unsupported search type: {query_type}")
send_telemetry("cognee.search EXECUTION STARTED", user.id)
results = await search_task(query)
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
return results
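The permission filter in search() above keeps a result only if it carries no document_id or its document_id belongs to the caller. A self-contained sketch of just that filtering rule, with made-up in-memory results standing in for real search output:

```python
from uuid import UUID, uuid4

# Stand-in data: the caller "owns" exactly one document.
owned_id = uuid4()
own_document_ids = {owned_id}

search_results = [
    {"document_id": str(owned_id), "text": "kept: belongs to the caller"},
    {"document_id": str(uuid4()), "text": "dropped: someone else's document"},
    {"text": "kept: no document_id, e.g. a graph-level insight"},
]

# Same rule as the loop in search(): keep results without a document_id
# or whose document_id is among the caller's own documents.
filtered_search_results = []
for search_result in search_results:
    document_id = search_result.get("document_id")
    document_id = UUID(document_id) if isinstance(document_id, str) else document_id
    if document_id is None or document_id in own_document_ids:
        filtered_search_results.append(search_result)

print([result["text"] for result in filtered_search_results])
```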

View file

@ -0,0 +1,10 @@
from enum import Enum
class SearchType(Enum):
SUMMARIES = "SUMMARIES"
INSIGHTS = "INSIGHTS"
CHUNKS = "CHUNKS"
COMPLETION = "COMPLETION"
GRAPH_COMPLETION = "GRAPH_COMPLETION"
CODE = "CODE"

View file

@ -0,0 +1 @@
from .SearchType import SearchType

View file

@ -23,6 +23,7 @@ class CodeFile(DataPoint):
class CodePart(DataPoint):
__tablename__ = "codepart"
file_path: str # file path
# part_of: Optional[CodeFile] = None
pydantic_type: str = "CodePart"
source_code: Optional[str] = None

View file

@ -30,6 +30,7 @@ def _add_code_parts_nodes_and_edges(code_file: CodeFile, part_type, code_parts)
id=part_node_id,
type=part_type,
# part_of = code_file,
file_path=code_file.extracted_id[len(code_file.part_of.path) + 1 :],
source_code=code_part,
)
)

View file

@ -3,6 +3,7 @@ import asyncio
import sys
from contextlib import contextmanager
from pathlib import Path
from pickle import UnpicklingError
from typing import List, Dict, Optional
import aiofiles
@ -60,9 +61,22 @@ def _update_code_entity(script: jedi.Script, code_entity: Dict[str, any]) -> Non
code_entity["full_name"] = getattr(result, "full_name", None)
code_entity["module_name"] = getattr(result, "module_name", None)
code_entity["module_path"] = getattr(result, "module_path", None)
except KeyError as e:
# TODO: See if there is a way to handle KeyError properly
logger.error(f"Failed to analyze code entity {code_entity['name']}: {e}")
return
except UnpicklingError as e:
# TODO: See if there is a way to handle UnpicklingError properly
logger.error(f"Failed to analyze code entity {code_entity['name']}: {e}")
return
except EOFError as e:
# TODO: See if there is a way to handle EOFError properly
logger.error(f"Failed to analyze code entity {code_entity['name']}: {e}")
return
except Exception as e:
# logging.warning(f"Failed to analyze code entity {code_entity['name']}: {e}")
logger.error(f"Failed to analyze code entity {code_entity['name']}: {e}")
raise e
async def _extract_dependencies(script_path: str) -> List[str]:

View file

@ -1,48 +0,0 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document
from cognee.infrastructure.databases.relational import get_relational_engine
from sqlalchemy import select
from sqlalchemy.sql import func
from cognee.modules.data.models import Data
from cognee.modules.data.models import GraphMetrics
import uuid
from cognee.infrastructure.databases.graph import get_graph_engine
async def fetch_token_count(db_engine) -> int:
"""
Fetches and sums token counts from the database.
Returns:
int: The total number of tokens across all documents.
"""
async with db_engine.get_async_session() as session:
token_count_sum = await session.execute(select(func.sum(Data.token_count)))
token_count_sum = token_count_sum.scalar()
return token_count_sum
async def calculate_graph_metrics(graph_data):
nodes, edges = graph_data
graph_metrics = {
"num_nodes": len(nodes),
"num_edges": len(edges),
}
return graph_metrics
async def store_descriptive_metrics(data_points: list[DataPoint]):
db_engine = get_relational_engine()
graph_engine = await get_graph_engine()
graph_data = await graph_engine.get_graph_data()
token_count_sum = await fetch_token_count(db_engine)
graph_metrics = await calculate_graph_metrics(graph_data)
table_name = "graph_metrics_table"
metrics_dict = {"id": uuid.uuid4(), "num_tokens": token_count_sum} | graph_metrics
await db_engine.insert_data(table_name, metrics_dict)
return data_points

View file

@ -2,8 +2,8 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.shared.utils import render_graph
from cognee.modules.search.types import SearchType
# from cognee.shared.utils import render_graph
logging.basicConfig(level=logging.DEBUG)

View file

@ -2,7 +2,7 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
logging.basicConfig(level=logging.DEBUG)

View file

@ -2,7 +2,7 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
logging.basicConfig(level=logging.DEBUG)

View file

@ -2,7 +2,7 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
logging.basicConfig(level=logging.DEBUG)

View file

@ -4,7 +4,7 @@ import pathlib
import cognee
from cognee.modules.data.models import Data
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.users.methods import get_default_user

View file

@ -2,7 +2,7 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
logging.basicConfig(level=logging.DEBUG)

View file

@ -2,7 +2,7 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
logging.basicConfig(level=logging.DEBUG)

View file

@ -8,7 +8,6 @@ from swebench.harness.utils import load_swebench_dataset
from swebench.inference.make_datasets.create_instance import PATCH_EXAMPLE
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.api.v1.search import SearchType
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.retrieval.description_to_codepart_search import (

View file

@ -1,5 +1,5 @@
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string

View file

@ -94,9 +94,7 @@ async def cognify_search_base_rag(content: str, context: str):
async def cognify_search_graph(content: str, context: str):
from cognee.api.v1.search import search, SearchType
params = {"query": "Donald Trump"}
results = await search(SearchType.INSIGHTS, params)
results = await search(SearchType.INSIGHTS, query_text="Donald Trump")
print("results", results)
return results

View file

@ -5,8 +5,8 @@ import asyncio
import cognee
import signal
from cognee.api.v1.search import SearchType
from cognee.shared.utils import setup_logging
from cognee.modules.search.types import SearchType
app = modal.App("cognee-runner")

8 changed lines in poetry.lock (generated)
View file

@ -1071,7 +1071,6 @@ files = [
{file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:761817a3377ef15ac23cd7834715081791d4ec77f9297ee694ca1ee9c2c7e5eb"},
{file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3c672a53c0fb4725a29c303be906d3c1fa99c32f58abe008a82705f9ee96f40b"},
{file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:4ac4c9f37eba52cb6fbeaf5b59c152ea976726b865bd4cf87883a7e7006cc543"},
{file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:60eb32934076fa07e4316b7b2742fa52cbb190b42c2df2863dbc4230a0a9b385"},
{file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ed3534eb1090483c96178fcb0f8893719d96d5274dfde98aa6add34614e97c8e"},
{file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f3f6fdfa89ee2d9d496e2c087cebef9d4fcbb0ad63c40e821b39f74bf48d9c5e"},
{file = "cryptography-44.0.0-cp37-abi3-win32.whl", hash = "sha256:eb33480f1bad5b78233b0ad3e1b0be21e8ef1da745d8d2aecbb20671658b9053"},
@ -1082,7 +1081,6 @@ files = [
{file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:c5eb858beed7835e5ad1faba59e865109f3e52b3783b9ac21e7e47dc5554e289"},
{file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f53c2c87e0fb4b0c00fa9571082a057e37690a8f12233306161c8f4b819960b7"},
{file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:9e6fc8a08e116fb7c7dd1f040074c9d7b51d74a8ea40d4df2fc7aa08b76b9e6c"},
{file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:9abcc2e083cbe8dde89124a47e5e53ec38751f0d7dfd36801008f316a127d7ba"},
{file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:d2436114e46b36d00f8b72ff57e598978b37399d2786fd39793c36c6d5cb1c64"},
{file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a01956ddfa0a6790d594f5b34fc1bfa6098aca434696a03cfdbe469b8ed79285"},
{file = "cryptography-44.0.0-cp39-abi3-win32.whl", hash = "sha256:eca27345e1214d1b9f9490d200f9db5a874479be914199194e746c893788d417"},
@ -3095,6 +3093,8 @@ optional = false
python-versions = "*"
files = [
{file = "jsonpath-ng-1.7.0.tar.gz", hash = "sha256:f6f5f7fd4e5ff79c785f1573b394043b39849fb2bb47bcead935d12b00beab3c"},
{file = "jsonpath_ng-1.7.0-py2-none-any.whl", hash = "sha256:898c93fc173f0c336784a3fa63d7434297544b7198124a68f9a3ef9597b0ae6e"},
{file = "jsonpath_ng-1.7.0-py3-none-any.whl", hash = "sha256:f3d7f9e848cba1b6da28c55b1c26ff915dc9e0b1ba7e752a53d6da8d5cbd00b6"},
]
[package.dependencies]
@ -3694,13 +3694,17 @@ proxy = ["PyJWT (>=2.8.0,<3.0.0)", "apscheduler (>=3.10.4,<4.0.0)", "backoff", "
[[package]]
name = "llama-index-core"
version = "0.12.14"
description = "Interface between LLMs and your data"
optional = true
python-versions = "<4.0,>=3.9"
files = [
{file = "llama_index_core-0.12.14-py3-none-any.whl", hash = "sha256:6fdb30e3fadf98e7df75f9db5d06f6a7f8503ca545a71e048d786ff88012bd50"},
{file = "llama_index_core-0.12.14.tar.gz", hash = "sha256:378bbf5bf4d1a8c692d3a980c1a6ed3be7a9afb676a4960429dea15f62d06cd3"},
]
[package.dependencies]

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "cognee"
version = "0.1.23"
version = "0.1.24"
description = "Cognee is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
authors = ["Vasilije Markovic", "Boris Arzentar"]
readme = "README.md"