From 4d3acc358a134d7ccc232160ae343dc15d73f37d Mon Sep 17 00:00:00 2001
From: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
Date: Tue, 4 Feb 2025 08:47:31 +0100
Subject: [PATCH] fix: mcp improvements (#472)

## Description

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

## Summary by CodeRabbit

- **Dependency Updates**
  - Downgraded the `mcp` package from 1.2.0 to 1.1.3.
  - Switched the `cognee` dependency to `cognee[codegraph]` to pull in the code-graph extras.
- **New Features**
  - Introduced a new tool, "codify", that transforms a codebase into a knowledge graph.
  - Extended the existing "search" tool to accept a search-type parameter.
- **Improvements**
  - Streamlined search functionality with a new modular approach.
  - Added a new asynchronous function for retrieving and formatting code parts.
- **Documentation**
  - Updated import paths for `SearchType` in various modules and tests to reflect the structural changes.
- **Code Cleanup**
  - Removed the legacy search module and its associated classes and functions.
  - Refined data transfer object classes for consistency and clarity.

---------

Co-authored-by: Boris Arzentar
---
 README.md                                      |   2 +-
 cognee-mcp/pyproject.toml                      |   4 +-
 cognee-mcp/src/server.py                       |  46 ++++++--
 cognee-mcp/uv.lock                             |  90 +++++++---
 .../routers/get_code_pipeline_router.py        |  18 ++--
 .../v1/search/routers/get_search_router.py     |   2 +-
 cognee/api/v1/search/search.legacy.py          | 100 ------------------
 cognee/api/v1/search/search_v2.py              |  64 +----------
 .../modules/retrieval/code_graph_retrieval.py  |  18 ++++
 .../description_to_codepart_search.py          |  10 +-
 cognee/modules/search/methods/__init__.py      |   1 +
 cognee/modules/search/methods/search.py        |  66 ++++++++++++
 cognee/modules/search/types/SearchType.py      |  10 ++
 cognee/modules/search/types/__init__.py        |   1 +
 cognee/shared/CodeGraphEntities.py             |   1 +
 .../repo_processor/expand_dependency_graph.py  |   1 +
 cognee/tests/test_falkordb.py                  |   4 +-
 cognee/tests/test_library.py                   |   2 +-
 cognee/tests/test_milvus.py                    |   2 +-
 cognee/tests/test_neo4j.py                     |   2 +-
 cognee/tests/test_pgvector.py                  |   2 +-
 cognee/tests/test_qdrant.py                    |   2 +-
 cognee/tests/test_weaviate.py                  |   2 +-
 evals/eval_swe_bench.py                        |   1 -
 evals/qa_context_provider_utils.py             |   2 +-
 evals/simple_rag_vs_cognee_eval.py             |   4 +-
 modal_deployment.py                            |   2 +-
 poetry.lock                                    |   8 +-
 28 files changed, 212 insertions(+), 255 deletions(-)
 delete mode 100644 cognee/api/v1/search/search.legacy.py
 create mode 100644 cognee/modules/retrieval/code_graph_retrieval.py
 create mode 100644 cognee/modules/search/methods/__init__.py
 create mode 100644 cognee/modules/search/methods/search.py
 create mode 100644 cognee/modules/search/types/SearchType.py
 create mode 100644 cognee/modules/search/types/__init__.py

diff --git a/README.md b/README.md
index abf854b24..07cf51859 100644
--- a/README.md
+++ b/README.md
@@ -130,7 +130,7 @@ This script will run the default pipeline:
 ```python
 import cognee
 import asyncio
-from cognee.api.v1.search import SearchType
+from cognee.modules.search.types import SearchType

 async def main():
     # Create a clean slate for cognee -- reset data and system state
diff --git a/cognee-mcp/pyproject.toml b/cognee-mcp/pyproject.toml
index 5b7ea57e8..03c353f0f 100644
--- a/cognee-mcp/pyproject.toml
+++ b/cognee-mcp/pyproject.toml
@@ -6,8 +6,8 @@ readme = "README.md"
 requires-python = ">=3.10"

 dependencies = [
-    "cognee",
-    "mcp==1.2.0",
+    "cognee[codegraph]",
+    "mcp==1.1.3",
 ]

 [[project.authors]]
diff --git a/cognee-mcp/src/server.py
b/cognee-mcp/src/server.py index 4cc7440f8..ec0a6564b 100755 --- a/cognee-mcp/src/server.py +++ b/cognee-mcp/src/server.py @@ -1,4 +1,5 @@ import os + import cognee import logging import importlib.util @@ -8,7 +9,8 @@ from contextlib import redirect_stderr, redirect_stdout import mcp.types as types from mcp.server import Server, NotificationOptions from mcp.server.models import InitializationOptions -from cognee.api.v1.search import SearchType +from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline +from cognee.modules.search.types import SearchType from cognee.shared.data_models import KnowledgeGraph mcp = Server("cognee") @@ -41,6 +43,19 @@ async def list_tools() -> list[types.Tool]: "required": ["text"], }, ), + types.Tool( + name="codify", + description="Transforms codebase into knowledge graph", + inputSchema={ + "type": "object", + "properties": { + "repo_path": { + "type": "string", + }, + }, + "required": ["repo_path"], + }, + ), types.Tool( name="search", description="Searches for information in knowledge graph", @@ -51,6 +66,10 @@ async def list_tools() -> list[types.Tool]: "type": "string", "description": "The query to search for", }, + "search_type": { + "type": "string", + "description": "The type of search to perform (e.g., INSIGHTS, CODE)", + }, }, "required": ["search_query"], }, @@ -72,15 +91,21 @@ async def call_tools(name: str, arguments: dict) -> list[types.TextContent]: with open(os.devnull, "w") as fnull: with redirect_stdout(fnull), redirect_stderr(fnull): if name == "cognify": - await cognify( + cognify( text=arguments["text"], graph_model_file=arguments.get("graph_model_file", None), graph_model_name=arguments.get("graph_model_name", None), ) return [types.TextContent(type="text", text="Ingested")] + if name == "codify": + await codify(arguments.get("repo_path")) + + return [types.TextContent(type="text", text="Indexed")] elif name == "search": - search_results = await search(arguments["search_query"]) + search_results = await search( + arguments["search_query"], arguments["search_type"] + ) return [types.TextContent(type="text", text=search_results)] elif name == "prune": @@ -102,21 +127,28 @@ async def cognify(text: str, graph_model_file: str = None, graph_model_name: str await cognee.add(text) try: - await cognee.cognify(graph_model=graph_model) + asyncio.create_task(cognee.cognify(graph_model=graph_model)) except Exception as e: raise ValueError(f"Failed to cognify: {str(e)}") -async def search(search_query: str) -> str: +async def codify(repo_path: str): + async for result in run_code_graph_pipeline(repo_path, False): + print(result) + + +async def search(search_query: str, search_type: str) -> str: """Search the knowledge graph""" - search_results = await cognee.search(SearchType.INSIGHTS, query_text=search_query) + search_results = await cognee.search( + query_type=SearchType[search_type.upper()], query_text=search_query + ) results = retrieved_edges_to_string(search_results) return results -async def prune() -> str: +async def prune(): """Reset the knowledge graph""" await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) diff --git a/cognee-mcp/uv.lock b/cognee-mcp/uv.lock index e467727f8..cdc363168 100644 --- a/cognee-mcp/uv.lock +++ b/cognee-mcp/uv.lock @@ -474,7 +474,7 @@ wheels = [ [[package]] name = "cognee" -version = "0.1.22" +version = "0.1.23" source = { directory = "../" } dependencies = [ { name = "aiofiles" }, @@ -516,11 +516,16 @@ dependencies = [ { name = "sqlalchemy" }, { name = "tenacity" }, { 
name = "tiktoken" }, - { name = "transformers" }, { name = "typing-extensions" }, { name = "uvicorn" }, ] +[package.optional-dependencies] +codegraph = [ + { name = "jedi" }, + { name = "parso" }, +] + [package.metadata] requires-dist = [ { name = "aiofiles", specifier = ">=23.2.1,<24.0.0" }, @@ -539,6 +544,7 @@ requires-dist = [ { name = "fastapi", specifier = "==0.115.7" }, { name = "fastapi-users", extras = ["sqlalchemy"], specifier = "==14.0.0" }, { name = "filetype", specifier = ">=1.2.0,<2.0.0" }, + { name = "google-generativeai", marker = "extra == 'gemini'", specifier = ">=0.8.4,<0.9.0" }, { name = "graphistry", specifier = ">=0.33.5,<0.34.0" }, { name = "groq", marker = "extra == 'groq'", specifier = "==0.8.0" }, { name = "gunicorn", specifier = ">=20.1.0,<21.0.0" }, @@ -578,7 +584,7 @@ requires-dist = [ { name = "sqlalchemy", specifier = "==2.0.36" }, { name = "tenacity", specifier = ">=9.0.0,<10.0.0" }, { name = "tiktoken", specifier = "==0.7.0" }, - { name = "transformers", specifier = ">=4.46.3,<5.0.0" }, + { name = "transformers", marker = "extra == 'huggingface'", specifier = ">=4.46.3,<5.0.0" }, { name = "typing-extensions", specifier = "==4.12.2" }, { name = "unstructured", extras = ["csv", "doc", "docx", "epub", "md", "odt", "org", "ppt", "pptx", "rst", "rtf", "tsv", "xlsx"], marker = "extra == 'docs'", specifier = ">=0.16.13,<0.17.0" }, { name = "uvicorn", specifier = "==0.34.0" }, @@ -590,14 +596,14 @@ name = "cognee-mcp" version = "0.1.0" source = { editable = "." } dependencies = [ - { name = "cognee" }, + { name = "cognee", extra = ["codegraph"] }, { name = "mcp" }, ] [package.metadata] requires-dist = [ - { name = "cognee", directory = "../" }, - { name = "mcp", specifier = "==1.2.0" }, + { name = "cognee", extras = ["codegraph"], directory = "../" }, + { name = "mcp", specifier = "==1.1.3" }, ] [[package]] @@ -1307,6 +1313,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c5/82/fd319382c1a33d7021cf151007b4cbd5daddf09d9ca5fb670e476668f9fc/instructor-1.7.2-py3-none-any.whl", hash = "sha256:cb43d27f6d7631c31762b936b2fcb44d2a3f9d8a020430a0f4d3484604ffb95b", size = 71353 }, ] +[[package]] +name = "jedi" +version = "0.19.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "parso" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278 }, +] + [[package]] name = "jinja2" version = "3.1.5" @@ -1739,21 +1757,19 @@ wheels = [ [[package]] name = "mcp" -version = "1.2.0" +version = "1.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, { name = "httpx" }, { name = "httpx-sse" }, { name = "pydantic" }, - { name = "pydantic-settings" }, { name = "sse-starlette" }, { name = "starlette" }, - { name = "uvicorn" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ab/a5/b08dc846ebedae9f17ced878e6975826e90e448cd4592f532f6a88a925a7/mcp-1.2.0.tar.gz", hash = "sha256:2b06c7ece98d6ea9e6379caa38d74b432385c338fb530cb82e2c70ea7add94f5", size = 102973 } +sdist = { url = 
"https://files.pythonhosted.org/packages/f7/60/66ebfd280b197f9a9d074c9e46cb1ac3186a32d12e6bd0425c24fe7cf7e8/mcp-1.1.3.tar.gz", hash = "sha256:af11018b8e9153cdd25f3722ec639fe7a462c00213a330fd6f593968341a9883", size = 57903 } wheels = [ - { url = "https://files.pythonhosted.org/packages/af/84/fca78f19ac8ce6c53ba416247c71baa53a9e791e98d3c81edbc20a77d6d1/mcp-1.2.0-py3-none-any.whl", hash = "sha256:1d0e77d8c14955a5aea1f5aa1f444c8e531c09355c829b20e42f7a142bc0755f", size = 66468 }, + { url = "https://files.pythonhosted.org/packages/b8/08/cfcfa13e41f8d27503c51a8cbf1939d720073ace92469d08655bb5de1b24/mcp-1.1.3-py3-none-any.whl", hash = "sha256:71462d6cd7c06c14689dfcf110ff22286ba1b608cfc3515c0a5cbe33d131731a", size = 36997 }, ] [[package]] @@ -2083,6 +2099,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ab/5f/b38085618b950b79d2d9164a711c52b10aefc0ae6833b96f626b7021b2ed/pandas-2.2.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ad5b65698ab28ed8d7f18790a0dc58005c7629f227be9ecc1072aa74c0c1d43a", size = 13098436 }, ] +[[package]] +name = "parso" +version = "0.8.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/66/94/68e2e17afaa9169cf6412ab0f28623903be73d1b32e208d9e8e541bb086d/parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d", size = 400609 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650 }, +] + [[package]] name = "pathvalidate" version = "3.2.3" @@ -2850,28 +2875,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5f/ce/22673f4a85ccc640735b4f8d12178a0f41b5d3c6eda7f33756d10ce56901/s3transfer-0.11.1-py3-none-any.whl", hash = "sha256:8fa0aa48177be1f3425176dfe1ab85dcd3d962df603c3dbfc585e6bf857ef0ff", size = 84111 }, ] -[[package]] -name = "safetensors" -version = "0.5.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f4/4f/2ef9ef1766f8c194b01b67a63a444d2e557c8fe1d82faf3ebd85f370a917/safetensors-0.5.2.tar.gz", hash = "sha256:cb4a8d98ba12fa016f4241932b1fc5e702e5143f5374bba0bbcf7ddc1c4cf2b8", size = 66957 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/96/d1/017e31e75e274492a11a456a9e7c171f8f7911fe50735b4ec6ff37221220/safetensors-0.5.2-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:45b6092997ceb8aa3801693781a71a99909ab9cc776fbc3fa9322d29b1d3bef2", size = 427067 }, - { url = "https://files.pythonhosted.org/packages/24/84/e9d3ff57ae50dd0028f301c9ee064e5087fe8b00e55696677a0413c377a7/safetensors-0.5.2-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:6d0d6a8ee2215a440e1296b843edf44fd377b055ba350eaba74655a2fe2c4bae", size = 408856 }, - { url = "https://files.pythonhosted.org/packages/f1/1d/fe95f5dd73db16757b11915e8a5106337663182d0381811c81993e0014a9/safetensors-0.5.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86016d40bcaa3bcc9a56cd74d97e654b5f4f4abe42b038c71e4f00a089c4526c", size = 450088 }, - { url = "https://files.pythonhosted.org/packages/cf/21/e527961b12d5ab528c6e47b92d5f57f33563c28a972750b238b871924e49/safetensors-0.5.2-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:990833f70a5f9c7d3fc82c94507f03179930ff7d00941c287f73b6fcbf67f19e", size = 458966 }, - { url = 
"https://files.pythonhosted.org/packages/a5/8b/1a037d7a57f86837c0b41905040369aea7d8ca1ec4b2a77592372b2ec380/safetensors-0.5.2-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3dfa7c2f3fe55db34eba90c29df94bcdac4821043fc391cb5d082d9922013869", size = 509915 }, - { url = "https://files.pythonhosted.org/packages/61/3d/03dd5cfd33839df0ee3f4581a20bd09c40246d169c0e4518f20b21d5f077/safetensors-0.5.2-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:46ff2116150ae70a4e9c490d2ab6b6e1b1b93f25e520e540abe1b81b48560c3a", size = 527664 }, - { url = "https://files.pythonhosted.org/packages/c5/dc/8952caafa9a10a3c0f40fa86bacf3190ae7f55fa5eef87415b97b29cb97f/safetensors-0.5.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ab696dfdc060caffb61dbe4066b86419107a24c804a4e373ba59be699ebd8d5", size = 461978 }, - { url = "https://files.pythonhosted.org/packages/60/da/82de1fcf1194e3dbefd4faa92dc98b33c06bed5d67890e0962dd98e18287/safetensors-0.5.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:03c937100f38c9ff4c1507abea9928a6a9b02c9c1c9c3609ed4fb2bf413d4975", size = 491253 }, - { url = "https://files.pythonhosted.org/packages/5a/9a/d90e273c25f90c3ba1b0196a972003786f04c39e302fbd6649325b1272bb/safetensors-0.5.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:a00e737948791b94dad83cf0eafc09a02c4d8c2171a239e8c8572fe04e25960e", size = 628644 }, - { url = "https://files.pythonhosted.org/packages/70/3c/acb23e05aa34b4f5edd2e7f393f8e6480fbccd10601ab42cd03a57d4ab5f/safetensors-0.5.2-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:d3a06fae62418ec8e5c635b61a8086032c9e281f16c63c3af46a6efbab33156f", size = 721648 }, - { url = "https://files.pythonhosted.org/packages/71/45/eaa3dba5253a7c6931230dc961641455710ab231f8a89cb3c4c2af70f8c8/safetensors-0.5.2-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:1506e4c2eda1431099cebe9abf6c76853e95d0b7a95addceaa74c6019c65d8cf", size = 659588 }, - { url = "https://files.pythonhosted.org/packages/b0/71/2f9851164f821064d43b481ddbea0149c2d676c4f4e077b178e7eeaa6660/safetensors-0.5.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5c5b5d9da594f638a259fca766046f44c97244cc7ab8bef161b3e80d04becc76", size = 632533 }, - { url = "https://files.pythonhosted.org/packages/00/f1/5680e2ef61d9c61454fad82c344f0e40b8741a9dbd1e31484f0d31a9b1c3/safetensors-0.5.2-cp38-abi3-win32.whl", hash = "sha256:fe55c039d97090d1f85277d402954dd6ad27f63034fa81985a9cc59655ac3ee2", size = 291167 }, - { url = "https://files.pythonhosted.org/packages/86/ca/aa489392ec6fb59223ffce825461e1f811a3affd417121a2088be7a5758b/safetensors-0.5.2-cp38-abi3-win_amd64.whl", hash = "sha256:78abdddd03a406646107f973c7843276e7b64e5e32623529dc17f3d94a20f589", size = 303756 }, -] - [[package]] name = "scikit-learn" version = "1.6.1" @@ -3348,27 +3351,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540 }, ] -[[package]] -name = "transformers" -version = "4.48.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "filelock" }, - { name = "huggingface-hub" }, - { name = "numpy" }, - { name = "packaging" }, - { name = "pyyaml" }, - { name = "regex" }, - { name = "requests" }, - { name = "safetensors" }, - { name = "tokenizers" }, - { name = "tqdm" }, -] -sdist = { url = 
"https://files.pythonhosted.org/packages/21/6b/caf620fae7fbf35947c81e7dd0834493b9ad9b71bb9e433025ac7a07e79a/transformers-4.48.1.tar.gz", hash = "sha256:7c1931facc3ee8adcbf86fc7a87461d54c1e40eca3bb57fef1ee9f3ecd32187e", size = 8365872 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7b/9f/92d3091c44cb19add044064af1bf1345cd35fbb84d32a3690f912800a295/transformers-4.48.1-py3-none-any.whl", hash = "sha256:24be0564b0a36d9e433d9a65de248f1545b6f6edce1737669605eb6a8141bbbb", size = 9665001 }, -] - [[package]] name = "typer" version = "0.15.1" diff --git a/cognee/api/v1/cognify/routers/get_code_pipeline_router.py b/cognee/api/v1/cognify/routers/get_code_pipeline_router.py index c4d436ce1..afda10e77 100644 --- a/cognee/api/v1/cognify/routers/get_code_pipeline_router.py +++ b/cognee/api/v1/cognify/routers/get_code_pipeline_router.py @@ -1,5 +1,5 @@ from fastapi import APIRouter -from pydantic import BaseModel +from cognee.api.DTO import InDTO from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline from cognee.modules.retrieval.description_to_codepart_search import ( code_description_to_code_part_search, @@ -7,14 +7,14 @@ from cognee.modules.retrieval.description_to_codepart_search import ( from fastapi.responses import JSONResponse -class CodePipelineIndexPayloadDTO(BaseModel): +class CodePipelineIndexPayloadDTO(InDTO): repo_path: str include_docs: bool = False -class CodePipelineRetrievePayloadDTO(BaseModel): +class CodePipelineRetrievePayloadDTO(InDTO): query: str - fullInput: str + full_input: str def get_code_pipeline_router() -> APIRouter: @@ -34,9 +34,9 @@ def get_code_pipeline_router() -> APIRouter: """This endpoint is responsible for retrieving the context.""" try: query = ( - payload.fullInput.replace("cognee ", "") - if payload.fullInput.startswith("cognee ") - else payload.fullInput + payload.full_input.replace("cognee ", "") + if payload.full_input.startswith("cognee ") + else payload.full_input ) retrieved_codeparts, __ = await code_description_to_code_part_search( @@ -45,8 +45,8 @@ def get_code_pipeline_router() -> APIRouter: return [ { - "name": codepart.attributes["id"], - "description": codepart.attributes["id"], + "name": codepart.attributes["file_path"], + "description": codepart.attributes["file_path"], "content": codepart.attributes["source_code"], } for codepart in retrieved_codeparts diff --git a/cognee/api/v1/search/routers/get_search_router.py b/cognee/api/v1/search/routers/get_search_router.py index a97e84cf4..8d8d62f98 100644 --- a/cognee/api/v1/search/routers/get_search_router.py +++ b/cognee/api/v1/search/routers/get_search_router.py @@ -2,7 +2,7 @@ from uuid import UUID from datetime import datetime from fastapi import Depends, APIRouter from fastapi.responses import JSONResponse -from cognee.api.v1.search import SearchType +from cognee.modules.search.types import SearchType from cognee.api.DTO import InDTO, OutDTO from cognee.modules.users.models import User from cognee.modules.search.operations import get_history diff --git a/cognee/api/v1/search/search.legacy.py b/cognee/api/v1/search/search.legacy.py deleted file mode 100644 index c4e490f01..000000000 --- a/cognee/api/v1/search/search.legacy.py +++ /dev/null @@ -1,100 +0,0 @@ -"""This module contains the search function that is used to search for nodes in the graph.""" - -import asyncio -from enum import Enum -from typing import Dict, Any, Callable, List -from pydantic import BaseModel, field_validator - -from cognee.modules.search.graph import search_cypher -from 
cognee.modules.search.graph.search_adjacent import search_adjacent -from cognee.modules.search.vector.search_traverse import search_traverse -from cognee.modules.search.graph.search_summary import search_summary -from cognee.modules.search.graph.search_similarity import search_similarity - -from cognee.exceptions import UserNotFoundError -from cognee.shared.utils import send_telemetry -from cognee.modules.users.permissions.methods import get_document_ids_for_user -from cognee.modules.users.methods import get_default_user -from cognee.modules.users.models import User - - -class SearchType(Enum): - ADJACENT = "ADJACENT" - TRAVERSE = "TRAVERSE" - SIMILARITY = "SIMILARITY" - SUMMARY = "SUMMARY" - SUMMARY_CLASSIFICATION = "SUMMARY_CLASSIFICATION" - NODE_CLASSIFICATION = "NODE_CLASSIFICATION" - DOCUMENT_CLASSIFICATION = ("DOCUMENT_CLASSIFICATION",) - CYPHER = "CYPHER" - - @staticmethod - def from_str(name: str): - try: - return SearchType[name.upper()] - except KeyError as error: - raise ValueError(f"{name} is not a valid SearchType") from error - - -class SearchParameters(BaseModel): - search_type: SearchType - params: Dict[str, Any] - - @field_validator("search_type", mode="before") - def convert_string_to_enum(cls, value): # pylint: disable=no-self-argument - if isinstance(value, str): - return SearchType.from_str(value) - return value - - -async def search(search_type: str, params: Dict[str, Any], user: User = None) -> List: - if user is None: - user = await get_default_user() - - if user is None: - raise UserNotFoundError - - own_document_ids = await get_document_ids_for_user(user.id) - search_params = SearchParameters(search_type=search_type, params=params) - search_results = await specific_search([search_params], user) - - from uuid import UUID - - filtered_search_results = [] - - for search_result in search_results: - document_id = search_result["document_id"] if "document_id" in search_result else None - document_id = UUID(document_id) if isinstance(document_id, str) else document_id - - if document_id is None or document_id in own_document_ids: - filtered_search_results.append(search_result) - - return filtered_search_results - - -async def specific_search(query_params: List[SearchParameters], user) -> List: - search_functions: Dict[SearchType, Callable] = { - SearchType.ADJACENT: search_adjacent, - SearchType.SUMMARY: search_summary, - SearchType.CYPHER: search_cypher, - SearchType.TRAVERSE: search_traverse, - SearchType.SIMILARITY: search_similarity, - } - - search_tasks = [] - - send_telemetry("cognee.search EXECUTION STARTED", user.id) - - for search_param in query_params: - search_func = search_functions.get(search_param.search_type) - if search_func: - # Schedule the coroutine for execution and store the task - task = search_func(**search_param.params) - search_tasks.append(task) - - # Use asyncio.gather to run all scheduled tasks concurrently - search_results = await asyncio.gather(*search_tasks) - - send_telemetry("cognee.search EXECUTION COMPLETED", user.id) - - return search_results[0] if len(search_results) == 1 else search_results diff --git a/cognee/api/v1/search/search_v2.py b/cognee/api/v1/search/search_v2.py index 4166fd3f3..e187181d5 100644 --- a/cognee/api/v1/search/search_v2.py +++ b/cognee/api/v1/search/search_v2.py @@ -1,29 +1,10 @@ -import json -from uuid import UUID -from enum import Enum -from typing import Callable, Dict, Union +from typing import Union -from cognee.exceptions import InvalidValueError -from cognee.modules.search.operations import log_query, 
log_result -from cognee.modules.storage.utils import JSONEncoder -from cognee.shared.utils import send_telemetry +from cognee.modules.search.types import SearchType from cognee.modules.users.exceptions import UserNotFoundError from cognee.modules.users.models import User from cognee.modules.users.methods import get_default_user -from cognee.modules.users.permissions.methods import get_document_ids_for_user -from cognee.tasks.chunks import query_chunks -from cognee.tasks.graph import query_graph_connections -from cognee.tasks.summarization import query_summaries -from cognee.tasks.completion import query_completion -from cognee.tasks.completion import graph_query_completion - - -class SearchType(Enum): - SUMMARIES = "SUMMARIES" - INSIGHTS = "INSIGHTS" - CHUNKS = "CHUNKS" - COMPLETION = "COMPLETION" - GRAPH_COMPLETION = "GRAPH_COMPLETION" +from cognee.modules.search.methods import search as search_function async def search( @@ -42,43 +23,6 @@ async def search( if user is None: raise UserNotFoundError - query = await log_query(query_text, str(query_type), user.id) - - own_document_ids = await get_document_ids_for_user(user.id, datasets) - search_results = await specific_search(query_type, query_text, user) - - filtered_search_results = [] - - for search_result in search_results: - document_id = search_result["document_id"] if "document_id" in search_result else None - document_id = UUID(document_id) if isinstance(document_id, str) else document_id - - if document_id is None or document_id in own_document_ids: - filtered_search_results.append(search_result) - - await log_result(query.id, json.dumps(filtered_search_results, cls=JSONEncoder), user.id) + filtered_search_results = await search_function(query_text, query_type, datasets, user) return filtered_search_results - - -async def specific_search(query_type: SearchType, query: str, user) -> list: - search_tasks: Dict[SearchType, Callable] = { - SearchType.SUMMARIES: query_summaries, - SearchType.INSIGHTS: query_graph_connections, - SearchType.CHUNKS: query_chunks, - SearchType.COMPLETION: query_completion, - SearchType.GRAPH_COMPLETION: graph_query_completion, - } - - search_task = search_tasks.get(query_type) - - if search_task is None: - raise InvalidValueError(message=f"Unsupported search type: {query_type}") - - send_telemetry("cognee.search EXECUTION STARTED", user.id) - - results = await search_task(query) - - send_telemetry("cognee.search EXECUTION COMPLETED", user.id) - - return results diff --git a/cognee/modules/retrieval/code_graph_retrieval.py b/cognee/modules/retrieval/code_graph_retrieval.py new file mode 100644 index 000000000..96e48d388 --- /dev/null +++ b/cognee/modules/retrieval/code_graph_retrieval.py @@ -0,0 +1,18 @@ +from cognee.modules.retrieval.description_to_codepart_search import ( + code_description_to_code_part_search, +) + + +async def code_graph_retrieval(query, include_docs=False): + retrieved_codeparts, __ = await code_description_to_code_part_search( + query, include_docs=include_docs + ) + + return [ + { + "name": codepart.attributes["file_path"], + "description": codepart.attributes["file_path"], + "content": codepart.attributes["source_code"], + } + for codepart in retrieved_codeparts + ] diff --git a/cognee/modules/retrieval/description_to_codepart_search.py b/cognee/modules/retrieval/description_to_codepart_search.py index 3e0728a3d..acc289d0f 100644 --- a/cognee/modules/retrieval/description_to_codepart_search.py +++ b/cognee/modules/retrieval/description_to_codepart_search.py @@ -1,15 +1,14 @@ import 
asyncio import logging -from typing import Set, List +from typing import List from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.users.methods import get_default_user from cognee.modules.users.models import User from cognee.shared.utils import send_telemetry -from cognee.api.v1.search import SearchType -from cognee.api.v1.search.search_v2 import search +from cognee.modules.search.methods import search from cognee.infrastructure.llm.get_llm_client import get_llm_client @@ -39,7 +38,7 @@ async def code_description_to_code_part( include_docs(bool): Boolean showing whether we have the docs in the graph or not Returns: - Set[str]: A set of unique code parts matching the query. + List[str]: A set of unique code parts matching the query. Raises: ValueError: If arguments are invalid. @@ -66,7 +65,7 @@ async def code_description_to_code_part( try: if include_docs: - search_results = await search(SearchType.INSIGHTS, query_text=query) + search_results = await search(query_text=query, query_type="INSIGHTS", user=user) concatenated_descriptions = " ".join( obj["description"] @@ -98,6 +97,7 @@ async def code_description_to_code_part( "id", "type", "text", + "file_path", "source_code", "pydantic_type", ], diff --git a/cognee/modules/search/methods/__init__.py b/cognee/modules/search/methods/__init__.py new file mode 100644 index 000000000..005c520d1 --- /dev/null +++ b/cognee/modules/search/methods/__init__.py @@ -0,0 +1 @@ +from .search import search diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py new file mode 100644 index 000000000..ede465d11 --- /dev/null +++ b/cognee/modules/search/methods/search.py @@ -0,0 +1,66 @@ +import json +from uuid import UUID +from typing import Callable + +from cognee.exceptions import InvalidValueError +from cognee.modules.retrieval.code_graph_retrieval import code_graph_retrieval +from cognee.modules.search.types import SearchType +from cognee.modules.storage.utils import JSONEncoder +from cognee.modules.users.models import User +from cognee.modules.users.permissions.methods import get_document_ids_for_user +from cognee.shared.utils import send_telemetry +from cognee.tasks.chunks import query_chunks +from cognee.tasks.graph import query_graph_connections +from cognee.tasks.summarization import query_summaries +from cognee.tasks.completion import query_completion +from cognee.tasks.completion import graph_query_completion +from ..operations import log_query, log_result + + +async def search( + query_text: str, + query_type: str, + datasets: list[str], + user: User, +): + query = await log_query(query_text, str(query_type), user.id) + + own_document_ids = await get_document_ids_for_user(user.id, datasets) + search_results = await specific_search(query_type, query_text, user) + + filtered_search_results = [] + + for search_result in search_results: + document_id = search_result["document_id"] if "document_id" in search_result else None + document_id = UUID(document_id) if isinstance(document_id, str) else document_id + + if document_id is None or document_id in own_document_ids: + filtered_search_results.append(search_result) + + await log_result(query.id, json.dumps(filtered_search_results, cls=JSONEncoder), user.id) + + return filtered_search_results + + +async def specific_search(query_type: SearchType, query: str, user: User) -> list: + search_tasks: 
dict[SearchType, Callable] = { + SearchType.SUMMARIES: query_summaries, + SearchType.INSIGHTS: query_graph_connections, + SearchType.CHUNKS: query_chunks, + SearchType.COMPLETION: query_completion, + SearchType.GRAPH_COMPLETION: graph_query_completion, + SearchType.CODE: code_graph_retrieval, + } + + search_task = search_tasks.get(query_type) + + if search_task is None: + raise InvalidValueError(message=f"Unsupported search type: {query_type}") + + send_telemetry("cognee.search EXECUTION STARTED", user.id) + + results = await search_task(query) + + send_telemetry("cognee.search EXECUTION COMPLETED", user.id) + + return results diff --git a/cognee/modules/search/types/SearchType.py b/cognee/modules/search/types/SearchType.py new file mode 100644 index 000000000..b01db2237 --- /dev/null +++ b/cognee/modules/search/types/SearchType.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class SearchType(Enum): + SUMMARIES = "SUMMARIES" + INSIGHTS = "INSIGHTS" + CHUNKS = "CHUNKS" + COMPLETION = "COMPLETION" + GRAPH_COMPLETION = "GRAPH_COMPLETION" + CODE = "CODE" diff --git a/cognee/modules/search/types/__init__.py b/cognee/modules/search/types/__init__.py new file mode 100644 index 000000000..62b49fa74 --- /dev/null +++ b/cognee/modules/search/types/__init__.py @@ -0,0 +1 @@ +from .SearchType import SearchType diff --git a/cognee/shared/CodeGraphEntities.py b/cognee/shared/CodeGraphEntities.py index 926aae9fa..5c5f65b07 100644 --- a/cognee/shared/CodeGraphEntities.py +++ b/cognee/shared/CodeGraphEntities.py @@ -23,6 +23,7 @@ class CodeFile(DataPoint): class CodePart(DataPoint): __tablename__ = "codepart" + file_path: str # file path # part_of: Optional[CodeFile] = None pydantic_type: str = "CodePart" source_code: Optional[str] = None diff --git a/cognee/tasks/repo_processor/expand_dependency_graph.py b/cognee/tasks/repo_processor/expand_dependency_graph.py index cc957742b..9a237e890 100644 --- a/cognee/tasks/repo_processor/expand_dependency_graph.py +++ b/cognee/tasks/repo_processor/expand_dependency_graph.py @@ -30,6 +30,7 @@ def _add_code_parts_nodes_and_edges(code_file: CodeFile, part_type, code_parts) id=part_node_id, type=part_type, # part_of = code_file, + file_path=code_file.extracted_id[len(code_file.part_of.path) + 1 :], source_code=code_part, ) ) diff --git a/cognee/tests/test_falkordb.py b/cognee/tests/test_falkordb.py index af0e87916..501c61af4 100755 --- a/cognee/tests/test_falkordb.py +++ b/cognee/tests/test_falkordb.py @@ -2,8 +2,8 @@ import os import logging import pathlib import cognee -from cognee.api.v1.search import SearchType -from cognee.shared.utils import render_graph +from cognee.modules.search.types import SearchType +# from cognee.shared.utils import render_graph logging.basicConfig(level=logging.DEBUG) diff --git a/cognee/tests/test_library.py b/cognee/tests/test_library.py index 192b67506..cd78b144e 100755 --- a/cognee/tests/test_library.py +++ b/cognee/tests/test_library.py @@ -2,7 +2,7 @@ import os import logging import pathlib import cognee -from cognee.api.v1.search import SearchType +from cognee.modules.search.types import SearchType logging.basicConfig(level=logging.DEBUG) diff --git a/cognee/tests/test_milvus.py b/cognee/tests/test_milvus.py index da02ca936..bd16c04f0 100644 --- a/cognee/tests/test_milvus.py +++ b/cognee/tests/test_milvus.py @@ -2,7 +2,7 @@ import os import logging import pathlib import cognee -from cognee.api.v1.search import SearchType +from cognee.modules.search.types import SearchType logging.basicConfig(level=logging.DEBUG) diff --git 
a/cognee/tests/test_neo4j.py b/cognee/tests/test_neo4j.py index 07274c010..bf93e2c52 100644 --- a/cognee/tests/test_neo4j.py +++ b/cognee/tests/test_neo4j.py @@ -2,7 +2,7 @@ import os import logging import pathlib import cognee -from cognee.api.v1.search import SearchType +from cognee.modules.search.types import SearchType from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search logging.basicConfig(level=logging.DEBUG) diff --git a/cognee/tests/test_pgvector.py b/cognee/tests/test_pgvector.py index 73b6be974..99b5ca724 100644 --- a/cognee/tests/test_pgvector.py +++ b/cognee/tests/test_pgvector.py @@ -4,7 +4,7 @@ import pathlib import cognee from cognee.modules.data.models import Data -from cognee.api.v1.search import SearchType +from cognee.modules.search.types import SearchType from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.users.methods import get_default_user diff --git a/cognee/tests/test_qdrant.py b/cognee/tests/test_qdrant.py index 7f82a569f..16adc0494 100644 --- a/cognee/tests/test_qdrant.py +++ b/cognee/tests/test_qdrant.py @@ -2,7 +2,7 @@ import os import logging import pathlib import cognee -from cognee.api.v1.search import SearchType +from cognee.modules.search.types import SearchType from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search logging.basicConfig(level=logging.DEBUG) diff --git a/cognee/tests/test_weaviate.py b/cognee/tests/test_weaviate.py index 874f21347..01021e6fc 100644 --- a/cognee/tests/test_weaviate.py +++ b/cognee/tests/test_weaviate.py @@ -2,7 +2,7 @@ import os import logging import pathlib import cognee -from cognee.api.v1.search import SearchType +from cognee.modules.search.types import SearchType from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search logging.basicConfig(level=logging.DEBUG) diff --git a/evals/eval_swe_bench.py b/evals/eval_swe_bench.py index b10eab3e2..c3178ce40 100644 --- a/evals/eval_swe_bench.py +++ b/evals/eval_swe_bench.py @@ -8,7 +8,6 @@ from swebench.harness.utils import load_swebench_dataset from swebench.inference.make_datasets.create_instance import PATCH_EXAMPLE from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline -from cognee.api.v1.search import SearchType from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.modules.retrieval.description_to_codepart_search import ( diff --git a/evals/qa_context_provider_utils.py b/evals/qa_context_provider_utils.py index 2cef1e628..100ceb290 100644 --- a/evals/qa_context_provider_utils.py +++ b/evals/qa_context_provider_utils.py @@ -1,5 +1,5 @@ import cognee -from cognee.api.v1.search import SearchType +from cognee.modules.search.types import SearchType from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string diff --git a/evals/simple_rag_vs_cognee_eval.py b/evals/simple_rag_vs_cognee_eval.py index 82b1df600..ab4acbe53 100644 --- a/evals/simple_rag_vs_cognee_eval.py +++ b/evals/simple_rag_vs_cognee_eval.py @@ -94,9 +94,7 @@ async def cognify_search_base_rag(content: str, context: str): async def cognify_search_graph(content: str, context: str): from cognee.api.v1.search import search, SearchType - params = {"query": "Donald 
Trump"} - - results = await search(SearchType.INSIGHTS, params) + results = await search(SearchType.INSIGHTS, query_text="Donald Trump") print("results", results) return results diff --git a/modal_deployment.py b/modal_deployment.py index 5622f96e6..28e36d8ef 100644 --- a/modal_deployment.py +++ b/modal_deployment.py @@ -5,8 +5,8 @@ import asyncio import cognee import signal -from cognee.api.v1.search import SearchType from cognee.shared.utils import setup_logging +from cognee.modules.search.types import SearchType app = modal.App("cognee-runner") diff --git a/poetry.lock b/poetry.lock index ee4e41039..b12dbd9b3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1071,7 +1071,6 @@ files = [ {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:761817a3377ef15ac23cd7834715081791d4ec77f9297ee694ca1ee9c2c7e5eb"}, {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3c672a53c0fb4725a29c303be906d3c1fa99c32f58abe008a82705f9ee96f40b"}, {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:4ac4c9f37eba52cb6fbeaf5b59c152ea976726b865bd4cf87883a7e7006cc543"}, - {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:60eb32934076fa07e4316b7b2742fa52cbb190b42c2df2863dbc4230a0a9b385"}, {file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ed3534eb1090483c96178fcb0f8893719d96d5274dfde98aa6add34614e97c8e"}, {file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f3f6fdfa89ee2d9d496e2c087cebef9d4fcbb0ad63c40e821b39f74bf48d9c5e"}, {file = "cryptography-44.0.0-cp37-abi3-win32.whl", hash = "sha256:eb33480f1bad5b78233b0ad3e1b0be21e8ef1da745d8d2aecbb20671658b9053"}, @@ -1082,7 +1081,6 @@ files = [ {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:c5eb858beed7835e5ad1faba59e865109f3e52b3783b9ac21e7e47dc5554e289"}, {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f53c2c87e0fb4b0c00fa9571082a057e37690a8f12233306161c8f4b819960b7"}, {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:9e6fc8a08e116fb7c7dd1f040074c9d7b51d74a8ea40d4df2fc7aa08b76b9e6c"}, - {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:9abcc2e083cbe8dde89124a47e5e53ec38751f0d7dfd36801008f316a127d7ba"}, {file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:d2436114e46b36d00f8b72ff57e598978b37399d2786fd39793c36c6d5cb1c64"}, {file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a01956ddfa0a6790d594f5b34fc1bfa6098aca434696a03cfdbe469b8ed79285"}, {file = "cryptography-44.0.0-cp39-abi3-win32.whl", hash = "sha256:eca27345e1214d1b9f9490d200f9db5a874479be914199194e746c893788d417"}, @@ -3095,6 +3093,8 @@ optional = false python-versions = "*" files = [ {file = "jsonpath-ng-1.7.0.tar.gz", hash = "sha256:f6f5f7fd4e5ff79c785f1573b394043b39849fb2bb47bcead935d12b00beab3c"}, + {file = "jsonpath_ng-1.7.0-py2-none-any.whl", hash = "sha256:898c93fc173f0c336784a3fa63d7434297544b7198124a68f9a3ef9597b0ae6e"}, + {file = "jsonpath_ng-1.7.0-py3-none-any.whl", hash = "sha256:f3d7f9e848cba1b6da28c55b1c26ff915dc9e0b1ba7e752a53d6da8d5cbd00b6"}, ] [package.dependencies] @@ -3694,13 +3694,17 @@ proxy = ["PyJWT (>=2.8.0,<3.0.0)", "apscheduler (>=3.10.4,<4.0.0)", "backoff", " [[package]] name = "llama-index-core" + version = "0.12.14" + description = "Interface between LLMs and your data" optional = true python-versions = "<4.0,>=3.9" files = [ 
+ {file = "llama_index_core-0.12.14-py3-none-any.whl", hash = "sha256:6fdb30e3fadf98e7df75f9db5d06f6a7f8503ca545a71e048d786ff88012bd50"}, {file = "llama_index_core-0.12.14.tar.gz", hash = "sha256:378bbf5bf4d1a8c692d3a980c1a6ed3be7a9afb676a4960429dea15f62d06cd3"}, + ] [package.dependencies]