Merge remote-tracking branch 'origin/dev'
commit 65d51d4aa7
54 changed files with 506 additions and 384 deletions
@@ -130,7 +130,7 @@ This script will run the default pipeline:
```python
import cognee
import asyncio
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType

async def main():
    # Create a clean slate for cognee -- reset data and system state
@@ -6,8 +6,8 @@ readme = "README.md"
requires-python = ">=3.10"

dependencies = [
"cognee",
"mcp==1.2.0",
"cognee[codegraph]",
"mcp==1.1.3",
]

[[project.authors]]
@@ -1,3 +1,4 @@
import json
import os
import cognee
import logging

@@ -8,8 +9,10 @@ from contextlib import redirect_stderr, redirect_stdout
import mcp.types as types
from mcp.server import Server, NotificationOptions
from mcp.server.models import InitializationOptions
from cognee.api.v1.search import SearchType
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.modules.search.types import SearchType
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.storage.utils import JSONEncoder

mcp = Server("cognee")
@@ -41,6 +44,19 @@ async def list_tools() -> list[types.Tool]:
"required": ["text"],
},
),
types.Tool(
name="codify",
description="Transforms codebase into knowledge graph",
inputSchema={
"type": "object",
"properties": {
"repo_path": {
"type": "string",
},
},
"required": ["repo_path"],
},
),
types.Tool(
name="search",
description="Searches for information in knowledge graph",

@@ -51,6 +67,10 @@ async def list_tools() -> list[types.Tool]:
"type": "string",
"description": "The query to search for",
},
"search_type": {
"type": "string",
"description": "The type of search to perform (e.g., INSIGHTS, CODE)",
},
},
"required": ["search_query"],
},
@@ -72,15 +92,21 @@ async def call_tools(name: str, arguments: dict) -> list[types.TextContent]:
with open(os.devnull, "w") as fnull:
with redirect_stdout(fnull), redirect_stderr(fnull):
if name == "cognify":
await cognify(
cognify(
text=arguments["text"],
graph_model_file=arguments.get("graph_model_file", None),
graph_model_name=arguments.get("graph_model_name", None),
)

return [types.TextContent(type="text", text="Ingested")]
if name == "codify":
await codify(arguments.get("repo_path"))

return [types.TextContent(type="text", text="Indexed")]
elif name == "search":
search_results = await search(arguments["search_query"])
search_results = await search(
arguments["search_query"], arguments["search_type"]
)

return [types.TextContent(type="text", text=search_results)]
elif name == "prune":
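For orientation (not part of the diff): with the tool schemas above, a client would pass argument payloads shaped roughly like the following. The query strings and path are made-up examples.

```python
# Hypothetical tool-call arguments matching the schemas above (illustrative only)
cognify_arguments = {"text": "Natural language content to ingest"}
codify_arguments = {"repo_path": "/path/to/repo"}
search_arguments = {"search_query": "How are code files connected?", "search_type": "CODE"}
```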
@@ -102,21 +128,30 @@ async def cognify(text: str, graph_model_file: str = None, graph_model_name: str
await cognee.add(text)

try:
await cognee.cognify(graph_model=graph_model)
asyncio.create_task(cognee.cognify(graph_model=graph_model))
except Exception as e:
raise ValueError(f"Failed to cognify: {str(e)}")


async def search(search_query: str) -> str:
async def codify(repo_path: str):
async for result in run_code_graph_pipeline(repo_path, False):
print(result)


async def search(search_query: str, search_type: str) -> str:
"""Search the knowledge graph"""
search_results = await cognee.search(SearchType.INSIGHTS, query_text=search_query)
search_results = await cognee.search(
query_type=SearchType[search_type.upper()], query_text=search_query
)

results = retrieved_edges_to_string(search_results)

return results
if search_type.upper() == "CODE":
return json.dumps(search_results, cls=JSONEncoder)
else:
results = retrieved_edges_to_string(search_results)
return results


async def prune() -> str:
async def prune():
"""Reset the knowledge graph"""
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
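A side note on the lookup used in the updated search helper (sketch, not part of the diff): the tool's search_type string is resolved to a SearchType member by name. The enum below is a simplified stand-in for cognee.modules.search.types.SearchType.

```python
from enum import Enum


class SearchType(Enum):
    # simplified stand-in for cognee.modules.search.types.SearchType
    INSIGHTS = "INSIGHTS"
    CODE = "CODE"


def resolve(search_type: str) -> SearchType:
    # Enum[name] looks a member up by its name; upper() makes the argument case-insensitive
    return SearchType[search_type.upper()]


assert resolve("code") is SearchType.CODE
assert resolve("INSIGHTS") is SearchType.INSIGHTS
```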
cognee-mcp/uv.lock (generated, 90 lines changed)
@@ -474,7 +474,7 @@ wheels = [

[[package]]
name = "cognee"
version = "0.1.22"
version = "0.1.23"
source = { directory = "../" }
dependencies = [
{ name = "aiofiles" },

@@ -516,11 +516,16 @@ dependencies = [
{ name = "sqlalchemy" },
{ name = "tenacity" },
{ name = "tiktoken" },
{ name = "transformers" },
{ name = "typing-extensions" },
{ name = "uvicorn" },
]

[package.optional-dependencies]
codegraph = [
{ name = "jedi" },
{ name = "parso" },
]

[package.metadata]
requires-dist = [
{ name = "aiofiles", specifier = ">=23.2.1,<24.0.0" },
@@ -539,6 +544,7 @@ requires-dist = [
{ name = "fastapi", specifier = "==0.115.7" },
{ name = "fastapi-users", extras = ["sqlalchemy"], specifier = "==14.0.0" },
{ name = "filetype", specifier = ">=1.2.0,<2.0.0" },
{ name = "google-generativeai", marker = "extra == 'gemini'", specifier = ">=0.8.4,<0.9.0" },
{ name = "graphistry", specifier = ">=0.33.5,<0.34.0" },
{ name = "groq", marker = "extra == 'groq'", specifier = "==0.8.0" },
{ name = "gunicorn", specifier = ">=20.1.0,<21.0.0" },

@@ -578,7 +584,7 @@ requires-dist = [
{ name = "sqlalchemy", specifier = "==2.0.36" },
{ name = "tenacity", specifier = ">=9.0.0,<10.0.0" },
{ name = "tiktoken", specifier = "==0.7.0" },
{ name = "transformers", specifier = ">=4.46.3,<5.0.0" },
{ name = "transformers", marker = "extra == 'huggingface'", specifier = ">=4.46.3,<5.0.0" },
{ name = "typing-extensions", specifier = "==4.12.2" },
{ name = "unstructured", extras = ["csv", "doc", "docx", "epub", "md", "odt", "org", "ppt", "pptx", "rst", "rtf", "tsv", "xlsx"], marker = "extra == 'docs'", specifier = ">=0.16.13,<0.17.0" },
{ name = "uvicorn", specifier = "==0.34.0" },
@@ -590,14 +596,14 @@ name = "cognee-mcp"
version = "0.1.0"
source = { editable = "." }
dependencies = [
{ name = "cognee" },
{ name = "cognee", extra = ["codegraph"] },
{ name = "mcp" },
]

[package.metadata]
requires-dist = [
{ name = "cognee", directory = "../" },
{ name = "mcp", specifier = "==1.2.0" },
{ name = "cognee", extras = ["codegraph"], directory = "../" },
{ name = "mcp", specifier = "==1.1.3" },
]

[[package]]
@@ -1307,6 +1313,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c5/82/fd319382c1a33d7021cf151007b4cbd5daddf09d9ca5fb670e476668f9fc/instructor-1.7.2-py3-none-any.whl", hash = "sha256:cb43d27f6d7631c31762b936b2fcb44d2a3f9d8a020430a0f4d3484604ffb95b", size = 71353 },
]

[[package]]
name = "jedi"
version = "0.19.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "parso" },
]
sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278 },
]

[[package]]
name = "jinja2"
version = "3.1.5"
@@ -1739,21 +1757,19 @@ wheels = [

[[package]]
name = "mcp"
version = "1.2.0"
version = "1.1.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "httpx" },
{ name = "httpx-sse" },
{ name = "pydantic" },
{ name = "pydantic-settings" },
{ name = "sse-starlette" },
{ name = "starlette" },
{ name = "uvicorn" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ab/a5/b08dc846ebedae9f17ced878e6975826e90e448cd4592f532f6a88a925a7/mcp-1.2.0.tar.gz", hash = "sha256:2b06c7ece98d6ea9e6379caa38d74b432385c338fb530cb82e2c70ea7add94f5", size = 102973 }
sdist = { url = "https://files.pythonhosted.org/packages/f7/60/66ebfd280b197f9a9d074c9e46cb1ac3186a32d12e6bd0425c24fe7cf7e8/mcp-1.1.3.tar.gz", hash = "sha256:af11018b8e9153cdd25f3722ec639fe7a462c00213a330fd6f593968341a9883", size = 57903 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/af/84/fca78f19ac8ce6c53ba416247c71baa53a9e791e98d3c81edbc20a77d6d1/mcp-1.2.0-py3-none-any.whl", hash = "sha256:1d0e77d8c14955a5aea1f5aa1f444c8e531c09355c829b20e42f7a142bc0755f", size = 66468 },
{ url = "https://files.pythonhosted.org/packages/b8/08/cfcfa13e41f8d27503c51a8cbf1939d720073ace92469d08655bb5de1b24/mcp-1.1.3-py3-none-any.whl", hash = "sha256:71462d6cd7c06c14689dfcf110ff22286ba1b608cfc3515c0a5cbe33d131731a", size = 36997 },
]

[[package]]
@@ -2083,6 +2099,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ab/5f/b38085618b950b79d2d9164a711c52b10aefc0ae6833b96f626b7021b2ed/pandas-2.2.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ad5b65698ab28ed8d7f18790a0dc58005c7629f227be9ecc1072aa74c0c1d43a", size = 13098436 },
]

[[package]]
name = "parso"
version = "0.8.4"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/66/94/68e2e17afaa9169cf6412ab0f28623903be73d1b32e208d9e8e541bb086d/parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d", size = 400609 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650 },
]

[[package]]
name = "pathvalidate"
version = "3.2.3"
@@ -2850,28 +2875,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/5f/ce/22673f4a85ccc640735b4f8d12178a0f41b5d3c6eda7f33756d10ce56901/s3transfer-0.11.1-py3-none-any.whl", hash = "sha256:8fa0aa48177be1f3425176dfe1ab85dcd3d962df603c3dbfc585e6bf857ef0ff", size = 84111 },
]

[[package]]
name = "safetensors"
version = "0.5.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f4/4f/2ef9ef1766f8c194b01b67a63a444d2e557c8fe1d82faf3ebd85f370a917/safetensors-0.5.2.tar.gz", hash = "sha256:cb4a8d98ba12fa016f4241932b1fc5e702e5143f5374bba0bbcf7ddc1c4cf2b8", size = 66957 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/96/d1/017e31e75e274492a11a456a9e7c171f8f7911fe50735b4ec6ff37221220/safetensors-0.5.2-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:45b6092997ceb8aa3801693781a71a99909ab9cc776fbc3fa9322d29b1d3bef2", size = 427067 },
{ url = "https://files.pythonhosted.org/packages/24/84/e9d3ff57ae50dd0028f301c9ee064e5087fe8b00e55696677a0413c377a7/safetensors-0.5.2-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:6d0d6a8ee2215a440e1296b843edf44fd377b055ba350eaba74655a2fe2c4bae", size = 408856 },
{ url = "https://files.pythonhosted.org/packages/f1/1d/fe95f5dd73db16757b11915e8a5106337663182d0381811c81993e0014a9/safetensors-0.5.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86016d40bcaa3bcc9a56cd74d97e654b5f4f4abe42b038c71e4f00a089c4526c", size = 450088 },
{ url = "https://files.pythonhosted.org/packages/cf/21/e527961b12d5ab528c6e47b92d5f57f33563c28a972750b238b871924e49/safetensors-0.5.2-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:990833f70a5f9c7d3fc82c94507f03179930ff7d00941c287f73b6fcbf67f19e", size = 458966 },
{ url = "https://files.pythonhosted.org/packages/a5/8b/1a037d7a57f86837c0b41905040369aea7d8ca1ec4b2a77592372b2ec380/safetensors-0.5.2-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3dfa7c2f3fe55db34eba90c29df94bcdac4821043fc391cb5d082d9922013869", size = 509915 },
{ url = "https://files.pythonhosted.org/packages/61/3d/03dd5cfd33839df0ee3f4581a20bd09c40246d169c0e4518f20b21d5f077/safetensors-0.5.2-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:46ff2116150ae70a4e9c490d2ab6b6e1b1b93f25e520e540abe1b81b48560c3a", size = 527664 },
{ url = "https://files.pythonhosted.org/packages/c5/dc/8952caafa9a10a3c0f40fa86bacf3190ae7f55fa5eef87415b97b29cb97f/safetensors-0.5.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ab696dfdc060caffb61dbe4066b86419107a24c804a4e373ba59be699ebd8d5", size = 461978 },
{ url = "https://files.pythonhosted.org/packages/60/da/82de1fcf1194e3dbefd4faa92dc98b33c06bed5d67890e0962dd98e18287/safetensors-0.5.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:03c937100f38c9ff4c1507abea9928a6a9b02c9c1c9c3609ed4fb2bf413d4975", size = 491253 },
{ url = "https://files.pythonhosted.org/packages/5a/9a/d90e273c25f90c3ba1b0196a972003786f04c39e302fbd6649325b1272bb/safetensors-0.5.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:a00e737948791b94dad83cf0eafc09a02c4d8c2171a239e8c8572fe04e25960e", size = 628644 },
{ url = "https://files.pythonhosted.org/packages/70/3c/acb23e05aa34b4f5edd2e7f393f8e6480fbccd10601ab42cd03a57d4ab5f/safetensors-0.5.2-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:d3a06fae62418ec8e5c635b61a8086032c9e281f16c63c3af46a6efbab33156f", size = 721648 },
{ url = "https://files.pythonhosted.org/packages/71/45/eaa3dba5253a7c6931230dc961641455710ab231f8a89cb3c4c2af70f8c8/safetensors-0.5.2-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:1506e4c2eda1431099cebe9abf6c76853e95d0b7a95addceaa74c6019c65d8cf", size = 659588 },
{ url = "https://files.pythonhosted.org/packages/b0/71/2f9851164f821064d43b481ddbea0149c2d676c4f4e077b178e7eeaa6660/safetensors-0.5.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5c5b5d9da594f638a259fca766046f44c97244cc7ab8bef161b3e80d04becc76", size = 632533 },
{ url = "https://files.pythonhosted.org/packages/00/f1/5680e2ef61d9c61454fad82c344f0e40b8741a9dbd1e31484f0d31a9b1c3/safetensors-0.5.2-cp38-abi3-win32.whl", hash = "sha256:fe55c039d97090d1f85277d402954dd6ad27f63034fa81985a9cc59655ac3ee2", size = 291167 },
{ url = "https://files.pythonhosted.org/packages/86/ca/aa489392ec6fb59223ffce825461e1f811a3affd417121a2088be7a5758b/safetensors-0.5.2-cp38-abi3-win_amd64.whl", hash = "sha256:78abdddd03a406646107f973c7843276e7b64e5e32623529dc17f3d94a20f589", size = 303756 },
]

[[package]]
name = "scikit-learn"
version = "1.6.1"
@@ -3348,27 +3351,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540 },
]

[[package]]
name = "transformers"
version = "4.48.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock" },
{ name = "huggingface-hub" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "pyyaml" },
{ name = "regex" },
{ name = "requests" },
{ name = "safetensors" },
{ name = "tokenizers" },
{ name = "tqdm" },
]
sdist = { url = "https://files.pythonhosted.org/packages/21/6b/caf620fae7fbf35947c81e7dd0834493b9ad9b71bb9e433025ac7a07e79a/transformers-4.48.1.tar.gz", hash = "sha256:7c1931facc3ee8adcbf86fc7a87461d54c1e40eca3bb57fef1ee9f3ecd32187e", size = 8365872 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7b/9f/92d3091c44cb19add044064af1bf1345cd35fbb84d32a3690f912800a295/transformers-4.48.1-py3-none-any.whl", hash = "sha256:24be0564b0a36d9e433d9a65de248f1545b6f6edce1737669605eb6a8141bbbb", size = 9665001 },
]

[[package]]
name = "typer"
version = "0.15.1"
@@ -188,6 +188,7 @@ def start_api_server(host: str = "0.0.0.0", port: int = 8000):
except Exception as e:
logger.exception(f"Failed to start server: {e}")
# Here you could add any cleanup code or error recovery code.
raise e


if __name__ == "__main__":
@@ -16,9 +16,22 @@ async def add(
dataset_name: str = "main_dataset",
user: User = None,
):
# Create tables for databases
await create_relational_db_and_tables()
await create_pgvector_db_and_tables()

# Initialize first_run attribute if it doesn't exist
if not hasattr(add, "first_run"):
add.first_run = True

if add.first_run:
from cognee.infrastructure.llm.utils import test_llm_connection, test_embedding_connection

# Test LLM and Embedding configuration once before running Cognee
await test_llm_connection()
await test_embedding_connection()
add.first_run = False  # Update flag after first run

if user is None:
user = await get_default_user()
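A side note on the pattern used above (sketch, not part of the diff): the one-time connectivity check is keyed off an attribute stored on the function object itself, so it runs only on the first call in a process.

```python
import asyncio


async def connectivity_check() -> None:
    # stand-in for test_llm_connection() / test_embedding_connection()
    print("checking connections")


async def add_like(text: str) -> None:
    # state lives on the function object, so it survives across calls
    if not hasattr(add_like, "first_run"):
        add_like.first_run = True
    if add_like.first_run:
        await connectivity_check()
        add_like.first_run = False


async def main() -> None:
    await add_like("first call")   # runs the check
    await add_like("second call")  # skips it


asyncio.run(main())
```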
@@ -25,7 +25,7 @@ from cognee.tasks.documents import (
)
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.storage import add_data_points
from cognee.tasks.storage.descriptive_metrics import store_descriptive_metrics
from cognee.modules.data.methods import store_descriptive_metrics
from cognee.tasks.storage.index_graph_edges import index_graph_edges
from cognee.tasks.summarization import summarize_text

@@ -165,7 +165,7 @@ async def get_default_tasks(
task_config={"batch_size": 10},
),
Task(add_data_points, task_config={"batch_size": 10}),
Task(store_descriptive_metrics),
Task(store_descriptive_metrics, include_optional=True),
]
except Exception as error:
send_telemetry("cognee.cognify DEFAULT TASKS CREATION ERRORED", user.id)
@@ -1,5 +1,5 @@
from fastapi import APIRouter
from pydantic import BaseModel
from cognee.api.DTO import InDTO
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.modules.retrieval.description_to_codepart_search import (
code_description_to_code_part_search,

@@ -7,14 +7,14 @@ from cognee.modules.retrieval.description_to_codepart_search import (
from fastapi.responses import JSONResponse


class CodePipelineIndexPayloadDTO(BaseModel):
class CodePipelineIndexPayloadDTO(InDTO):
repo_path: str
include_docs: bool = False


class CodePipelineRetrievePayloadDTO(BaseModel):
class CodePipelineRetrievePayloadDTO(InDTO):
query: str
fullInput: str
full_input: str


def get_code_pipeline_router() -> APIRouter:
@@ -34,9 +34,9 @@ def get_code_pipeline_router() -> APIRouter:
"""This endpoint is responsible for retrieving the context."""
try:
query = (
payload.fullInput.replace("cognee ", "")
if payload.fullInput.startswith("cognee ")
else payload.fullInput
payload.full_input.replace("cognee ", "")
if payload.full_input.startswith("cognee ")
else payload.full_input
)

retrieved_codeparts, __ = await code_description_to_code_part_search(

@@ -45,8 +45,8 @@ def get_code_pipeline_router() -> APIRouter:

return [
{
"name": codepart.attributes["id"],
"description": codepart.attributes["id"],
"name": codepart.attributes["file_path"],
"description": codepart.attributes["file_path"],
"content": codepart.attributes["source_code"],
}
for codepart in retrieved_codeparts
@@ -2,7 +2,7 @@ from uuid import UUID
from datetime import datetime
from fastapi import Depends, APIRouter
from fastapi.responses import JSONResponse
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.api.DTO import InDTO, OutDTO
from cognee.modules.users.models import User
from cognee.modules.search.operations import get_history
@@ -1,100 +0,0 @@
"""This module contains the search function that is used to search for nodes in the graph."""

import asyncio
from enum import Enum
from typing import Dict, Any, Callable, List
from pydantic import BaseModel, field_validator

from cognee.modules.search.graph import search_cypher
from cognee.modules.search.graph.search_adjacent import search_adjacent
from cognee.modules.search.vector.search_traverse import search_traverse
from cognee.modules.search.graph.search_summary import search_summary
from cognee.modules.search.graph.search_similarity import search_similarity

from cognee.exceptions import UserNotFoundError
from cognee.shared.utils import send_telemetry
from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User


class SearchType(Enum):
ADJACENT = "ADJACENT"
TRAVERSE = "TRAVERSE"
SIMILARITY = "SIMILARITY"
SUMMARY = "SUMMARY"
SUMMARY_CLASSIFICATION = "SUMMARY_CLASSIFICATION"
NODE_CLASSIFICATION = "NODE_CLASSIFICATION"
DOCUMENT_CLASSIFICATION = ("DOCUMENT_CLASSIFICATION",)
CYPHER = "CYPHER"

@staticmethod
def from_str(name: str):
try:
return SearchType[name.upper()]
except KeyError as error:
raise ValueError(f"{name} is not a valid SearchType") from error


class SearchParameters(BaseModel):
search_type: SearchType
params: Dict[str, Any]

@field_validator("search_type", mode="before")
def convert_string_to_enum(cls, value):  # pylint: disable=no-self-argument
if isinstance(value, str):
return SearchType.from_str(value)
return value


async def search(search_type: str, params: Dict[str, Any], user: User = None) -> List:
if user is None:
user = await get_default_user()

if user is None:
raise UserNotFoundError

own_document_ids = await get_document_ids_for_user(user.id)
search_params = SearchParameters(search_type=search_type, params=params)
search_results = await specific_search([search_params], user)

from uuid import UUID

filtered_search_results = []

for search_result in search_results:
document_id = search_result["document_id"] if "document_id" in search_result else None
document_id = UUID(document_id) if isinstance(document_id, str) else document_id

if document_id is None or document_id in own_document_ids:
filtered_search_results.append(search_result)

return filtered_search_results


async def specific_search(query_params: List[SearchParameters], user) -> List:
search_functions: Dict[SearchType, Callable] = {
SearchType.ADJACENT: search_adjacent,
SearchType.SUMMARY: search_summary,
SearchType.CYPHER: search_cypher,
SearchType.TRAVERSE: search_traverse,
SearchType.SIMILARITY: search_similarity,
}

search_tasks = []

send_telemetry("cognee.search EXECUTION STARTED", user.id)

for search_param in query_params:
search_func = search_functions.get(search_param.search_type)
if search_func:
# Schedule the coroutine for execution and store the task
task = search_func(**search_param.params)
search_tasks.append(task)

# Use asyncio.gather to run all scheduled tasks concurrently
search_results = await asyncio.gather(*search_tasks)

send_telemetry("cognee.search EXECUTION COMPLETED", user.id)

return search_results[0] if len(search_results) == 1 else search_results
@@ -1,29 +1,10 @@
import json
from uuid import UUID
from enum import Enum
from typing import Callable, Dict, Union
from typing import Union

from cognee.exceptions import InvalidValueError
from cognee.modules.search.operations import log_query, log_result
from cognee.modules.storage.utils import JSONEncoder
from cognee.shared.utils import send_telemetry
from cognee.modules.search.types import SearchType
from cognee.modules.users.exceptions import UserNotFoundError
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.tasks.chunks import query_chunks
from cognee.tasks.graph import query_graph_connections
from cognee.tasks.summarization import query_summaries
from cognee.tasks.completion import query_completion
from cognee.tasks.completion import graph_query_completion


class SearchType(Enum):
SUMMARIES = "SUMMARIES"
INSIGHTS = "INSIGHTS"
CHUNKS = "CHUNKS"
COMPLETION = "COMPLETION"
GRAPH_COMPLETION = "GRAPH_COMPLETION"
from cognee.modules.search.methods import search as search_function


async def search(

@@ -42,43 +23,6 @@ async def search(
if user is None:
raise UserNotFoundError

query = await log_query(query_text, str(query_type), user.id)

own_document_ids = await get_document_ids_for_user(user.id, datasets)
search_results = await specific_search(query_type, query_text, user)

filtered_search_results = []

for search_result in search_results:
document_id = search_result["document_id"] if "document_id" in search_result else None
document_id = UUID(document_id) if isinstance(document_id, str) else document_id

if document_id is None or document_id in own_document_ids:
filtered_search_results.append(search_result)

await log_result(query.id, json.dumps(filtered_search_results, cls=JSONEncoder), user.id)
filtered_search_results = await search_function(query_text, query_type, datasets, user)

return filtered_search_results


async def specific_search(query_type: SearchType, query: str, user) -> list:
search_tasks: Dict[SearchType, Callable] = {
SearchType.SUMMARIES: query_summaries,
SearchType.INSIGHTS: query_graph_connections,
SearchType.CHUNKS: query_chunks,
SearchType.COMPLETION: query_completion,
SearchType.GRAPH_COMPLETION: graph_query_completion,
}

search_task = search_tasks.get(query_type)

if search_task is None:
raise InvalidValueError(message=f"Unsupported search type: {query_type}")

send_telemetry("cognee.search EXECUTION STARTED", user.id)

results = await search_task(query)

send_telemetry("cognee.search EXECUTION COMPLETED", user.id)

return results
@@ -25,11 +25,24 @@ class GraphConfig(BaseSettings):
return {
"graph_filename": self.graph_filename,
"graph_database_provider": self.graph_database_provider,
"graph_file_path": self.graph_file_path,
"graph_database_url": self.graph_database_url,
"graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port,
"graph_file_path": self.graph_file_path,
"graph_model": self.graph_model,
"graph_topology": self.graph_topology,
"model_config": self.model_config,
}

def to_hashable_dict(self) -> dict:
return {
"graph_database_provider": self.graph_database_provider,
"graph_database_url": self.graph_database_url,
"graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port,
"graph_file_path": self.graph_file_path,
}
@@ -8,12 +8,12 @@ from .graph_db_interface import GraphDBInterface

async def get_graph_engine() -> GraphDBInterface:
"""Factory function to get the appropriate graph client based on the graph type."""
graph_client = create_graph_engine()
config = get_graph_config()

graph_client = create_graph_engine(**get_graph_config().to_hashable_dict())

# Async functions can't be cached. After creating and caching the graph engine
# handle all necessary async operations for different graph types bellow.
config = get_graph_config()

# Handle loading of graph for NetworkX
if config.graph_database_provider.lower() == "networkx" and graph_client.graph is None:
await graph_client.load_graph_from_file()

@@ -22,28 +22,30 @@ async def get_graph_engine() -> GraphDBInterface:

@lru_cache
def create_graph_engine() -> GraphDBInterface:
def create_graph_engine(
graph_database_provider,
graph_database_url,
graph_database_username,
graph_database_password,
graph_database_port,
graph_file_path,
):
"""Factory function to create the appropriate graph client based on the graph type."""
config = get_graph_config()

if config.graph_database_provider == "neo4j":
if not (
config.graph_database_url
and config.graph_database_username
and config.graph_database_password
):
if graph_database_provider == "neo4j":
if not (graph_database_url and graph_database_username and graph_database_password):
raise EnvironmentError("Missing required Neo4j credentials.")

from .neo4j_driver.adapter import Neo4jAdapter

return Neo4jAdapter(
graph_database_url=config.graph_database_url,
graph_database_username=config.graph_database_username,
graph_database_password=config.graph_database_password,
graph_database_url=graph_database_url,
graph_database_username=graph_database_username,
graph_database_password=graph_database_password,
)

elif config.graph_database_provider == "falkordb":
if not (config.graph_database_url and config.graph_database_port):
elif graph_database_provider == "falkordb":
if not (graph_database_url and graph_database_port):
raise EnvironmentError("Missing required FalkorDB credentials.")

from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine

@@ -52,13 +54,13 @@ def create_graph_engine() -> GraphDBInterface:
embedding_engine = get_embedding_engine()

return FalkorDBAdapter(
database_url=config.graph_database_url,
database_port=config.graph_database_port,
database_url=graph_database_url,
database_port=graph_database_port,
embedding_engine=embedding_engine,
)

from .networkx.adapter import NetworkXAdapter

graph_client = NetworkXAdapter(filename=config.graph_file_path)
graph_client = NetworkXAdapter(filename=graph_file_path)

return graph_client
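Why the factory signature changed (sketch, not taken from the diff): functools.lru_cache memoizes by argument value and requires hashable arguments, so the config object is flattened into scalar parameters via to_hashable_dict() before the cached call. A minimal illustration with a made-up factory:

```python
from functools import lru_cache


@lru_cache
def make_engine(provider: str, url: str, port: int) -> tuple:
    # In the real factory this would construct a graph adapter; a tuple stands in here.
    return (provider, url, port)


first = make_engine("networkx", "", 0)
second = make_engine("networkx", "", 0)
assert first is second  # identical arguments hit the cache and return the same object

# make_engine({"provider": "networkx"})  # would raise TypeError: unhashable type: 'dict'
```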
@@ -54,3 +54,7 @@ class GraphDBInterface(Protocol):
@abstractmethod
async def get_graph_data(self):
raise NotImplementedError

@abstractmethod
async def get_graph_metrics(self, include_optional):
raise NotImplementedError

@@ -530,3 +530,17 @@ class Neo4jAdapter(GraphDBInterface):
]

return (nodes, edges)

async def get_graph_metrics(self, include_optional=False):
return {
"num_nodes": -1,
"num_edges": -1,
"mean_degree": -1,
"edge_density": -1,
"num_connected_components": -1,
"sizes_of_connected_components": -1,
"num_selfloops": -1,
"diameter": -1,
"avg_shortest_path_length": -1,
"avg_clustering": -1,
}
@@ -14,8 +14,9 @@ import networkx as nx
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import JSONEncoder
import numpy as np

logger = logging.getLogger("NetworkXAdapter")
logger = logging.getLogger(__name__)


class NetworkXAdapter(GraphDBInterface):

@@ -269,8 +270,8 @@ class NetworkXAdapter(GraphDBInterface):
if not isinstance(node["id"], UUID):
node["id"] = UUID(node["id"])
except Exception as e:
print(e)
pass
logger.error(e)
raise e

if isinstance(node.get("updated_at"), int):
node["updated_at"] = datetime.fromtimestamp(

@@ -298,8 +299,8 @@ class NetworkXAdapter(GraphDBInterface):
edge["source_node_id"] = source_id
edge["target_node_id"] = target_id
except Exception as e:
print(e)
pass
logger.error(e)
raise e

if isinstance(edge["updated_at"], int):  # Handle timestamp in milliseconds
edge["updated_at"] = datetime.fromtimestamp(

@@ -327,8 +328,9 @@ class NetworkXAdapter(GraphDBInterface):

await self.save_graph_to_file(file_path)

except Exception:
except Exception as e:
logger.error("Failed to load graph from file: %s", file_path)
raise e

async def delete_graph(self, file_path: str = None):
"""Asynchronously delete the graph file from the filesystem."""

@@ -344,6 +346,7 @@ class NetworkXAdapter(GraphDBInterface):
logger.info("Graph deleted successfully.")
except Exception as error:
logger.error("Failed to delete graph: %s", error)
raise error

async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]

@@ -385,3 +388,64 @@ class NetworkXAdapter(GraphDBInterface):
]

return filtered_nodes, filtered_edges

async def get_graph_metrics(self, include_optional=False):
graph = self.graph

def _get_mean_degree(graph):
degrees = [d for _, d in graph.degree()]
return np.mean(degrees) if degrees else 0

def _get_edge_density(graph):
num_nodes = graph.number_of_nodes()
num_edges = graph.number_of_edges()
num_possible_edges = num_nodes * (num_nodes - 1)
edge_density = num_edges / num_possible_edges if num_possible_edges > 0 else 0
return edge_density

def _get_diameter(graph):
if nx.is_strongly_connected(graph):
return nx.diameter(graph.to_undirected())
else:
return None

def _get_avg_shortest_path_length(graph):
if nx.is_strongly_connected(graph):
return nx.average_shortest_path_length(graph)
else:
return None

def _get_avg_clustering(graph):
try:
return nx.average_clustering(nx.DiGraph(graph))
except Exception as e:
logger.warning("Failed to calculate clustering coefficient: %s", e)
return None

mandatory_metrics = {
"num_nodes": graph.number_of_nodes(),
"num_edges": graph.number_of_edges(),
"mean_degree": _get_mean_degree(graph),
"edge_density": _get_edge_density(graph),
"num_connected_components": nx.number_weakly_connected_components(graph),
"sizes_of_connected_components": [
len(c) for c in nx.weakly_connected_components(graph)
],
}

if include_optional:
optional_metrics = {
"num_selfloops": sum(1 for u, v in graph.edges() if u == v),
"diameter": _get_diameter(graph),
"avg_shortest_path_length": _get_avg_shortest_path_length(graph),
"avg_clustering": _get_avg_clustering(graph),
}
else:
optional_metrics = {
"num_selfloops": -1,
"diameter": -1,
"avg_shortest_path_length": -1,
"avg_clustering": -1,
}

return mandatory_metrics | optional_metrics
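A quick sanity check of the metric formulas above on a toy directed graph (sketch, not part of the diff): mean degree counts in-degree plus out-degree, and edge density divides by n*(n-1) for a directed graph.

```python
import networkx as nx
import numpy as np

toy = nx.DiGraph([("a", "b"), ("b", "c")])        # 3 nodes, 2 directed edges

degrees = [degree for _, degree in toy.degree()]   # [1, 2, 1]
assert np.isclose(np.mean(degrees), 4 / 3)

possible_edges = toy.number_of_nodes() * (toy.number_of_nodes() - 1)
assert np.isclose(toy.number_of_edges() / possible_edges, 1 / 3)

assert nx.number_weakly_connected_components(toy) == 1
```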
@@ -1,6 +1,8 @@
from .sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from functools import lru_cache


@lru_cache
def create_relational_engine(
db_path: str,
db_name: str,

@@ -1,10 +1,7 @@
# from functools import lru_cache

from .config import get_relational_config
from .create_relational_engine import create_relational_engine


# @lru_cache
def get_relational_engine():
relational_config = get_relational_config()
@@ -303,9 +303,10 @@ class SQLAlchemyAdapter:
await connection.execute(text("DROP TABLE IF EXISTS group_permission CASCADE"))
await connection.execute(text("DROP TABLE IF EXISTS permissions CASCADE"))
# Add more DROP TABLE statements for other tables as needed
print("Database tables dropped successfully.")
logger.debug("Database tables dropped successfully.")
except Exception as e:
print(f"Error dropping database tables: {e}")
logger.error(f"Error dropping database tables: {e}")
raise e

async def create_database(self):
if self.engine.dialect.name == "sqlite":

@@ -340,6 +341,7 @@ class SQLAlchemyAdapter:
await connection.execute(drop_table_query)
metadata.clear()
except Exception as e:
print(f"Error deleting database: {e}")
logger.error(f"Error deleting database: {e}")
raise e

print("Database deleted successfully.")
logger.info("Database deleted successfully.")
@@ -1,49 +1,47 @@
from typing import Dict
from functools import lru_cache


class VectorConfig(Dict):
vector_db_url: str
vector_db_port: str
vector_db_key: str
vector_db_provider: str


def create_vector_engine(config: VectorConfig, embedding_engine):
if config["vector_db_provider"] == "weaviate":
@lru_cache
def create_vector_engine(
embedding_engine,
vector_db_url: str,
vector_db_port: str,
vector_db_key: str,
vector_db_provider: str,
):
if vector_db_provider == "weaviate":
from .weaviate_db import WeaviateAdapter

if not (config["vector_db_url"] and config["vector_db_key"]):
if not (vector_db_url and vector_db_key):
raise EnvironmentError("Missing requred Weaviate credentials!")

return WeaviateAdapter(
config["vector_db_url"], config["vector_db_key"], embedding_engine=embedding_engine
)
return WeaviateAdapter(vector_db_url, vector_db_key, embedding_engine=embedding_engine)

elif config["vector_db_provider"] == "qdrant":
if not (config["vector_db_url"] and config["vector_db_key"]):
elif vector_db_provider == "qdrant":
if not (vector_db_url and vector_db_key):
raise EnvironmentError("Missing requred Qdrant credentials!")

from .qdrant.QDrantAdapter import QDrantAdapter

return QDrantAdapter(
url=config["vector_db_url"],
api_key=config["vector_db_key"],
url=vector_db_url,
api_key=vector_db_key,
embedding_engine=embedding_engine,
)

elif config["vector_db_provider"] == "milvus":
elif vector_db_provider == "milvus":
from .milvus.MilvusAdapter import MilvusAdapter

if not config["vector_db_url"]:
if not vector_db_url:
raise EnvironmentError("Missing required Milvus credentials!")

return MilvusAdapter(
url=config["vector_db_url"],
api_key=config["vector_db_key"],
url=vector_db_url,
api_key=vector_db_key,
embedding_engine=embedding_engine,
)

elif config["vector_db_provider"] == "pgvector":
elif vector_db_provider == "pgvector":
from cognee.infrastructure.databases.relational import get_relational_config

# Get configuration for postgres database

@@ -65,19 +63,19 @@ def create_vector_engine(config: VectorConfig, embedding_engine):

return PGVectorAdapter(
connection_string,
config["vector_db_key"],
vector_db_key,
embedding_engine,
)

elif config["vector_db_provider"] == "falkordb":
if not (config["vector_db_url"] and config["vector_db_port"]):
elif vector_db_provider == "falkordb":
if not (vector_db_url and vector_db_port):
raise EnvironmentError("Missing requred FalkorDB credentials!")

from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter

return FalkorDBAdapter(
database_url=config["vector_db_url"],
database_port=config["vector_db_port"],
database_url=vector_db_url,
database_port=vector_db_port,
embedding_engine=embedding_engine,
)

@@ -85,7 +83,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
from .lancedb.LanceDBAdapter import LanceDBAdapter

return LanceDBAdapter(
url=config["vector_db_url"],
api_key=config["vector_db_key"],
url=vector_db_url,
api_key=vector_db_key,
embedding_engine=embedding_engine,
)
@@ -1,9 +1,7 @@
from .config import get_vectordb_config
from .embeddings import get_embedding_engine
from .create_vector_engine import create_vector_engine
from functools import lru_cache


@lru_cache
def get_vector_engine():
return create_vector_engine(get_vectordb_config().to_dict(), get_embedding_engine())
return create_vector_engine(get_embedding_engine(), **get_vectordb_config().to_dict())
@@ -1,6 +1,6 @@
import asyncio
from typing import List, Optional, get_type_hints
from uuid import UUID
from uuid import UUID, uuid4

from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import JSON, Column, Table, select, delete, MetaData

@@ -69,7 +69,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)
id: Mapped[data_point_types["id"]]
payload = Column(JSON)
vector = Column(self.Vector(vector_size))

@@ -103,7 +103,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)
id: Mapped[data_point_types["id"]]
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
@@ -1,2 +1,4 @@
from .config import get_llm_config
from .utils import get_max_chunk_tokens
from .utils import test_llm_connection
from .utils import test_embedding_connection

@@ -36,3 +36,26 @@ def get_model_max_tokens(model_name: str):
logger.info("Model not found in LiteLLM's model_cost.")

return max_tokens


async def test_llm_connection():
try:
llm_adapter = get_llm_client()
await llm_adapter.acreate_structured_output(
text_input="test",
system_prompt='Respond to me with the following string: "test"',
response_model=str,
)
except Exception as e:
logger.error(e)
logger.error("Connection to LLM could not be established.")
raise e


async def test_embedding_connection():
try:
await get_vector_engine().embedding_engine.embed_text("test")
except Exception as e:
logger.error(e)
logger.error("Connection to Embedding handler could not be established.")
raise e
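Hedged usage sketch (mirrors how cognee.add() calls these helpers on its first run): both functions raise on failure, so callers can gate startup on them.

```python
import asyncio

from cognee.infrastructure.llm.utils import test_llm_connection, test_embedding_connection


async def verify_connections() -> None:
    # Both helpers raise if the configured LLM or embedding backend is unreachable.
    await test_llm_connection()
    await test_embedding_connection()


# asyncio.run(verify_connections())  # requires a configured cognee environment
```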
@@ -1,10 +1,12 @@
from typing import Optional
import logging
from uuid import NAMESPACE_OID, uuid5

from cognee.tasks.chunks import chunk_by_paragraph

from .models.DocumentChunk import DocumentChunk

logger = logging.getLogger(__name__)


class TextChunker:
document = None

@@ -76,7 +78,8 @@ class TextChunker:
},
)
except Exception as e:
print(e)
logger.error(e)
raise e
paragraph_chunks = [chunk_data]
self.chunk_size = chunk_data["word_count"]
self.token_count = chunk_data["token_count"]

@@ -97,4 +100,5 @@ class TextChunker:
_metadata={"index_fields": ["text"]},
)
except Exception as e:
print(e)
logger.error(e)
raise e
@@ -11,3 +11,5 @@ from .get_data import get_data
# Delete
from .delete_dataset import delete_dataset
from .delete_data import delete_data

from .store_descriptive_metrics import store_descriptive_metrics
cognee/modules/data/methods/store_descriptive_metrics.py (new file, 50 lines)
@@ -0,0 +1,50 @@
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.relational import get_relational_engine
from sqlalchemy import select
from sqlalchemy.sql import func
from cognee.modules.data.models import Data
from cognee.modules.data.models import GraphMetrics
import uuid
from cognee.infrastructure.databases.graph import get_graph_engine


async def fetch_token_count(db_engine) -> int:
"""
Fetches and sums token counts from the database.

Returns:
int: The total number of tokens across all documents.
"""

async with db_engine.get_async_session() as session:
token_count_sum = await session.execute(select(func.sum(Data.token_count)))
token_count_sum = token_count_sum.scalar()

return token_count_sum


async def store_descriptive_metrics(data_points: list[DataPoint], include_optional: bool):
db_engine = get_relational_engine()
graph_engine = await get_graph_engine()
graph_metrics = await graph_engine.get_graph_metrics(include_optional)

async with db_engine.get_async_session() as session:
metrics = GraphMetrics(
id=uuid.uuid4(),
num_tokens=await fetch_token_count(db_engine),
num_nodes=graph_metrics["num_nodes"],
num_edges=graph_metrics["num_edges"],
mean_degree=graph_metrics["mean_degree"],
edge_density=graph_metrics["edge_density"],
num_connected_components=graph_metrics["num_connected_components"],
sizes_of_connected_components=graph_metrics["sizes_of_connected_components"],
num_selfloops=graph_metrics["num_selfloops"],
diameter=graph_metrics["diameter"],
avg_shortest_path_length=graph_metrics["avg_shortest_path_length"],
avg_clustering=graph_metrics["avg_clustering"],
)

session.add(metrics)
await session.commit()

return data_points
@@ -1,4 +1,5 @@
from datetime import datetime, timezone
from sqlalchemy.sql import func

from sqlalchemy import Column, DateTime, Float, Integer, JSON, UUID

@@ -7,7 +8,7 @@ from uuid import uuid4


class GraphMetrics(Base):
__tablename__ = "graph_metrics_table"
__tablename__ = "graph_metrics"

# TODO: Change ID to reflect unique id of graph database
id = Column(UUID, primary_key=True, default=uuid4)
@@ -1,3 +1,4 @@
from .Data import Data
from .Dataset import Dataset
from .DatasetData import DatasetData
from .GraphMetrics import GraphMetrics

@@ -34,5 +34,6 @@ async def detect_language(text: str):

except Exception as e:
logger.error(f"Unexpected error: {e}")
raise e

return None
@@ -152,6 +152,7 @@ class CogneeGraph(CogneeAbstractGraph):

except Exception as ex:
print(f"Error mapping vector distances to edges: {ex}")
raise ex

async def calculate_top_triplet_importances(self, k: int) -> List:
min_heap = []
cognee/modules/retrieval/code_graph_retrieval.py (new file, 18 lines)
@@ -0,0 +1,18 @@
from cognee.modules.retrieval.description_to_codepart_search import (
code_description_to_code_part_search,
)


async def code_graph_retrieval(query, include_docs=False):
retrieved_codeparts, __ = await code_description_to_code_part_search(
query, include_docs=include_docs
)

return [
{
"name": codepart.attributes["file_path"],
"description": codepart.attributes["file_path"],
"content": codepart.attributes["source_code"],
}
for codepart in retrieved_codeparts
]
@@ -1,17 +1,18 @@
import asyncio
import logging

from typing import Set, List
from typing import List
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry
from cognee.api.v1.search import SearchType
from cognee.api.v1.search.search_v2 import search
from cognee.modules.search.methods import search
from cognee.infrastructure.llm.get_llm_client import get_llm_client

logger = logging.getLogger(__name__)


async def code_description_to_code_part_search(
query: str, include_docs=False, user: User = None, top_k=5

@@ -39,7 +40,7 @@ async def code_description_to_code_part(
include_docs(bool): Boolean showing whether we have the docs in the graph or not

Returns:
Set[str]: A set of unique code parts matching the query.
List[str]: A set of unique code parts matching the query.

Raises:
ValueError: If arguments are invalid.

@@ -66,7 +67,7 @@ async def code_description_to_code_part(

try:
if include_docs:
search_results = await search(SearchType.INSIGHTS, query_text=query)
search_results = await search(query_text=query, query_type="INSIGHTS", user=user)

concatenated_descriptions = " ".join(
obj["description"]

@@ -98,6 +99,7 @@ async def code_description_to_code_part(
"id",
"type",
"text",
"file_path",
"source_code",
"pydantic_type",
],

@@ -154,8 +156,9 @@ if __name__ == "__main__":
user = None
try:
results = await code_description_to_code_part_search(query, user)
print("Retrieved Code Parts:", results)
logger.debug("Retrieved Code Parts:", results)
except Exception as e:
print(f"An error occurred: {e}")
logger.error(f"An error occurred: {e}")
raise e

asyncio.run(main())
cognee/modules/search/methods/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .search import search
cognee/modules/search/methods/search.py (new file, 66 lines)
@ -0,0 +1,66 @@
|
|||
import json
|
||||
from uuid import UUID
|
||||
from typing import Callable
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.modules.retrieval.code_graph_retrieval import code_graph_retrieval
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.tasks.chunks import query_chunks
|
||||
from cognee.tasks.graph import query_graph_connections
|
||||
from cognee.tasks.summarization import query_summaries
|
||||
from cognee.tasks.completion import query_completion
|
||||
from cognee.tasks.completion import graph_query_completion
|
||||
from ..operations import log_query, log_result
|
||||
|
||||
|
||||
async def search(
|
||||
query_text: str,
|
||||
query_type: str,
|
||||
datasets: list[str],
|
||||
user: User,
|
||||
):
|
||||
query = await log_query(query_text, str(query_type), user.id)
|
||||
|
||||
own_document_ids = await get_document_ids_for_user(user.id, datasets)
|
||||
search_results = await specific_search(query_type, query_text, user)
|
||||
|
||||
filtered_search_results = []
|
||||
|
||||
for search_result in search_results:
|
||||
document_id = search_result["document_id"] if "document_id" in search_result else None
|
||||
document_id = UUID(document_id) if isinstance(document_id, str) else document_id
|
||||
|
||||
if document_id is None or document_id in own_document_ids:
|
||||
filtered_search_results.append(search_result)
|
||||
|
||||
await log_result(query.id, json.dumps(filtered_search_results, cls=JSONEncoder), user.id)
|
||||
|
||||
return filtered_search_results
|
||||
|
||||
|
||||
async def specific_search(query_type: SearchType, query: str, user: User) -> list:
|
||||
search_tasks: dict[SearchType, Callable] = {
|
||||
SearchType.SUMMARIES: query_summaries,
|
||||
SearchType.INSIGHTS: query_graph_connections,
|
||||
SearchType.CHUNKS: query_chunks,
|
||||
SearchType.COMPLETION: query_completion,
|
||||
SearchType.GRAPH_COMPLETION: graph_query_completion,
|
||||
SearchType.CODE: code_graph_retrieval,
|
||||
}
|
||||
|
||||
search_task = search_tasks.get(query_type)
|
||||
|
||||
if search_task is None:
|
||||
raise InvalidValueError(message=f"Unsupported search type: {query_type}")
|
||||
|
||||
send_telemetry("cognee.search EXECUTION STARTED", user.id)
|
||||
|
||||
results = await search_task(query)
|
||||
|
||||
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
|
||||
|
||||
return results
|
||||
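The new module-level search method is not exercised directly in this diff. Below is a minimal usage sketch, assuming a configured cognee environment, data that has already been cognified, and that `get_default_user` (imported by one of the test changes further down) returns a usable `User`; the dataset name is hypothetical.

```python
import asyncio

from cognee.modules.search.methods import search
from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user


async def main():
    user = await get_default_user()
    # "main_dataset" is a placeholder -- pass whatever datasets you ingested.
    results = await search(
        query_text="What does the pipeline ingest?",
        query_type=SearchType.INSIGHTS,
        datasets=["main_dataset"],
        user=user,
    )
    # Results are already filtered down to documents this user may access.
    print(results)


asyncio.run(main())
```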
10  cognee/modules/search/types/SearchType.py  Normal file
@ -0,0 +1,10 @@
from enum import Enum


class SearchType(Enum):
    SUMMARIES = "SUMMARIES"
    INSIGHTS = "INSIGHTS"
    CHUNKS = "CHUNKS"
    COMPLETION = "COMPLETION"
    GRAPH_COMPLETION = "GRAPH_COMPLETION"
    CODE = "CODE"
1  cognee/modules/search/types/__init__.py  Normal file
@ -0,0 +1 @@
from .SearchType import SearchType
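Because each member's value mirrors its name, a plain string arriving from an API layer can be mapped onto the enum either by value or by name; a small illustrative sketch:

```python
from cognee.modules.search.types import SearchType

# Both lookups resolve to the same member because name and value coincide.
by_value = SearchType("INSIGHTS")
by_name = SearchType["INSIGHTS"]

assert by_value is by_name is SearchType.INSIGHTS
```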
@ -23,6 +23,7 @@ class CodeFile(DataPoint):

class CodePart(DataPoint):
    __tablename__ = "codepart"
    file_path: str  # file path
    # part_of: Optional[CodeFile] = None
    pydantic_type: str = "CodePart"
    source_code: Optional[str] = None
@ -30,6 +30,7 @@ def _add_code_parts_nodes_and_edges(code_file: CodeFile, part_type, code_parts)
                id=part_node_id,
                type=part_type,
                # part_of = code_file,
                file_path=code_file.extracted_id[len(code_file.part_of.path) + 1 :],
                source_code=code_part,
            )
        )
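The `file_path` expression above strips the repository root (plus one path separator) from the file's extracted identifier. A small illustration with hypothetical paths standing in for `code_file.extracted_id` and `code_file.part_of.path`:

```python
# Hypothetical values, shown only to illustrate the slice used above.
extracted_id = "/home/dev/my_repo/src/utils/parser.py"  # stands in for code_file.extracted_id
repo_path = "/home/dev/my_repo"                         # stands in for code_file.part_of.path

file_path = extracted_id[len(repo_path) + 1 :]
print(file_path)  # -> "src/utils/parser.py"
```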
@ -3,6 +3,7 @@ import asyncio
import sys
from contextlib import contextmanager
from pathlib import Path
from pickle import UnpicklingError
from typing import List, Dict, Optional

import aiofiles
@ -60,9 +61,22 @@ def _update_code_entity(script: jedi.Script, code_entity: Dict[str, any]) -> None:
        code_entity["full_name"] = getattr(result, "full_name", None)
        code_entity["module_name"] = getattr(result, "module_name", None)
        code_entity["module_path"] = getattr(result, "module_path", None)
    except KeyError as e:
        # TODO: See if there is a way to handle KeyError properly
        logger.error(f"Failed to analyze code entity {code_entity['name']}: {e}")
        return
    except UnpicklingError as e:
        # TODO: See if there is a way to handle UnpicklingError properly
        logger.error(f"Failed to analyze code entity {code_entity['name']}: {e}")
        return
    except EOFError as e:
        # TODO: See if there is a way to handle EOFError properly
        logger.error(f"Failed to analyze code entity {code_entity['name']}: {e}")
        return
    except Exception as e:
        # logging.warning(f"Failed to analyze code entity {code_entity['name']}: {e}")
        logger.error(f"Failed to analyze code entity {code_entity['name']}: {e}")
        raise e


async def _extract_dependencies(script_path: str) -> List[str]:
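The widened except chain exists because jedi's name resolution can fail in several unrelated ways: stale analysis caches on disk surface as `UnpicklingError` or `EOFError`, and some grammars raise `KeyError`. A minimal sketch of the same defensive pattern, assuming the parsed entity carries `line`/`column` fields (the exact fields cognee uses are not visible in this hunk):

```python
from pickle import UnpicklingError

import jedi


def resolve_entity(source: str, path: str, entity: dict) -> None:
    """Best-effort enrichment of a parsed code entity with jedi metadata."""
    script = jedi.Script(code=source, path=path)
    try:
        # `line` / `column` are assumed fields on the entity; adjust to the real schema.
        results = script.infer(entity["line"], entity["column"])
        if results:
            entity["full_name"] = getattr(results[0], "full_name", None)
    except (KeyError, UnpicklingError, EOFError) as error:
        # Recoverable analysis failures: log and leave the entity un-enriched.
        print(f"Failed to analyze code entity {entity.get('name')}: {error}")
```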
@ -1,48 +0,0 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document
from cognee.infrastructure.databases.relational import get_relational_engine
from sqlalchemy import select
from sqlalchemy.sql import func
from cognee.modules.data.models import Data
from cognee.modules.data.models import GraphMetrics
import uuid
from cognee.infrastructure.databases.graph import get_graph_engine


async def fetch_token_count(db_engine) -> int:
    """
    Fetches and sums token counts from the database.

    Returns:
        int: The total number of tokens across all documents.
    """

    async with db_engine.get_async_session() as session:
        token_count_sum = await session.execute(select(func.sum(Data.token_count)))
        token_count_sum = token_count_sum.scalar()

    return token_count_sum


async def calculate_graph_metrics(graph_data):
    nodes, edges = graph_data
    graph_metrics = {
        "num_nodes": len(nodes),
        "num_edges": len(edges),
    }
    return graph_metrics


async def store_descriptive_metrics(data_points: list[DataPoint]):
    db_engine = get_relational_engine()
    graph_engine = await get_graph_engine()
    graph_data = await graph_engine.get_graph_data()

    token_count_sum = await fetch_token_count(db_engine)
    graph_metrics = await calculate_graph_metrics(graph_data)

    table_name = "graph_metrics_table"
    metrics_dict = {"id": uuid.uuid4(), "num_tokens": token_count_sum} | graph_metrics

    await db_engine.insert_data(table_name, metrics_dict)
    return data_points
@ -2,8 +2,8 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.shared.utils import render_graph
from cognee.modules.search.types import SearchType
# from cognee.shared.utils import render_graph

logging.basicConfig(level=logging.DEBUG)
@ -2,7 +2,7 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType

logging.basicConfig(level=logging.DEBUG)
@ -2,7 +2,7 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType

logging.basicConfig(level=logging.DEBUG)
@ -2,7 +2,7 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search

logging.basicConfig(level=logging.DEBUG)
@ -4,7 +4,7 @@ import pathlib
import cognee

from cognee.modules.data.models import Data
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.users.methods import get_default_user
@ -2,7 +2,7 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search

logging.basicConfig(level=logging.DEBUG)
@ -2,7 +2,7 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search

logging.basicConfig(level=logging.DEBUG)
@ -8,7 +8,6 @@ from swebench.harness.utils import load_swebench_dataset
from swebench.inference.make_datasets.create_instance import PATCH_EXAMPLE

from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.api.v1.search import SearchType
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.retrieval.description_to_codepart_search import (
@ -1,5 +1,5 @@
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string
@ -94,9 +94,7 @@ async def cognify_search_base_rag(content: str, context: str):
async def cognify_search_graph(content: str, context: str):
    from cognee.api.v1.search import search, SearchType

    params = {"query": "Donald Trump"}

    results = await search(SearchType.INSIGHTS, params)
    results = await search(SearchType.INSIGHTS, query_text="Donald Trump")
    print("results", results)
    return results
@ -5,8 +5,8 @@ import asyncio
import cognee
import signal

from cognee.api.v1.search import SearchType
from cognee.shared.utils import setup_logging
from cognee.modules.search.types import SearchType

app = modal.App("cognee-runner")
8  poetry.lock  generated
@ -1071,7 +1071,6 @@ files = [
    {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:761817a3377ef15ac23cd7834715081791d4ec77f9297ee694ca1ee9c2c7e5eb"},
    {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3c672a53c0fb4725a29c303be906d3c1fa99c32f58abe008a82705f9ee96f40b"},
    {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:4ac4c9f37eba52cb6fbeaf5b59c152ea976726b865bd4cf87883a7e7006cc543"},
    {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:60eb32934076fa07e4316b7b2742fa52cbb190b42c2df2863dbc4230a0a9b385"},
    {file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ed3534eb1090483c96178fcb0f8893719d96d5274dfde98aa6add34614e97c8e"},
    {file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f3f6fdfa89ee2d9d496e2c087cebef9d4fcbb0ad63c40e821b39f74bf48d9c5e"},
    {file = "cryptography-44.0.0-cp37-abi3-win32.whl", hash = "sha256:eb33480f1bad5b78233b0ad3e1b0be21e8ef1da745d8d2aecbb20671658b9053"},
@ -1082,7 +1081,6 @@ files = [
    {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:c5eb858beed7835e5ad1faba59e865109f3e52b3783b9ac21e7e47dc5554e289"},
    {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f53c2c87e0fb4b0c00fa9571082a057e37690a8f12233306161c8f4b819960b7"},
    {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:9e6fc8a08e116fb7c7dd1f040074c9d7b51d74a8ea40d4df2fc7aa08b76b9e6c"},
    {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:9abcc2e083cbe8dde89124a47e5e53ec38751f0d7dfd36801008f316a127d7ba"},
    {file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:d2436114e46b36d00f8b72ff57e598978b37399d2786fd39793c36c6d5cb1c64"},
    {file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a01956ddfa0a6790d594f5b34fc1bfa6098aca434696a03cfdbe469b8ed79285"},
    {file = "cryptography-44.0.0-cp39-abi3-win32.whl", hash = "sha256:eca27345e1214d1b9f9490d200f9db5a874479be914199194e746c893788d417"},
@ -3095,6 +3093,8 @@ optional = false
python-versions = "*"
files = [
    {file = "jsonpath-ng-1.7.0.tar.gz", hash = "sha256:f6f5f7fd4e5ff79c785f1573b394043b39849fb2bb47bcead935d12b00beab3c"},
    {file = "jsonpath_ng-1.7.0-py2-none-any.whl", hash = "sha256:898c93fc173f0c336784a3fa63d7434297544b7198124a68f9a3ef9597b0ae6e"},
    {file = "jsonpath_ng-1.7.0-py3-none-any.whl", hash = "sha256:f3d7f9e848cba1b6da28c55b1c26ff915dc9e0b1ba7e752a53d6da8d5cbd00b6"},
]

[package.dependencies]
@ -3694,13 +3694,17 @@ proxy = ["PyJWT (>=2.8.0,<3.0.0)", "apscheduler (>=3.10.4,<4.0.0)", "backoff", "

[[package]]
name = "llama-index-core"
version = "0.12.14"
description = "Interface between LLMs and your data"
optional = true
python-versions = "<4.0,>=3.9"
files = [
    {file = "llama_index_core-0.12.14-py3-none-any.whl", hash = "sha256:6fdb30e3fadf98e7df75f9db5d06f6a7f8503ca545a71e048d786ff88012bd50"},
    {file = "llama_index_core-0.12.14.tar.gz", hash = "sha256:378bbf5bf4d1a8c692d3a980c1a6ed3be7a9afb676a4960429dea15f62d06cd3"},
]

[package.dependencies]
@ -1,6 +1,6 @@
[tool.poetry]
name = "cognee"
version = "0.1.23"
version = "0.1.24"
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
authors = ["Vasilije Markovic", "Boris Arzentar"]
readme = "README.md"