fix: mcp improvements (#472)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Dependency Update**
  - Downgraded `mcp` package version from 1.2.0 to 1.1.3
  - Updated the `cognee` dependency to include additional features via `cognee[codegraph]`

- **New Features**
  - Introduced a new tool, "codify", for transforming codebases into knowledge graphs
  - Enhanced the existing "search" tool to accept a new parameter for search type (a client-side sketch follows below)

- **Improvements**
  - Streamlined search functionality with a new modular approach
  - Added a new asynchronous function for retrieving and formatting code parts

- **Documentation**
  - Updated import paths for `SearchType` in various modules and tests to reflect structural changes

- **Code Cleanup**
  - Removed the legacy search module and associated classes/functions
  - Refined data transfer object classes for consistency and clarity
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
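
For reviewers, a minimal client-side sketch of the new tool surface described above, assuming the standard `mcp` stdio client helpers; the server launch command, repository path, and query are illustrative, not part of this PR:

```python
import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


async def main():
    # Illustrative launch command; point this at the cognee MCP server entry point.
    server = StdioServerParameters(command="python", args=["src/server.py"])

    async with stdio_client(server) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()

            # New tool: build a code knowledge graph from a repository.
            await session.call_tool("codify", {"repo_path": "/path/to/repo"})

            # "search" now takes a search_type (e.g. INSIGHTS, CODE) alongside the query.
            result = await session.call_tool(
                "search",
                {"search_query": "main entry point", "search_type": "CODE"},
            )
            print(result)


asyncio.run(main())
```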

---------

Co-authored-by: Boris Arzentar <borisarzentar@gmail.com>
Vasilije 2025-02-04 08:47:31 +01:00 committed by GitHub
parent 2858a674f5
commit 4d3acc358a
28 changed files with 212 additions and 255 deletions

View file

@ -130,7 +130,7 @@ This script will run the default pipeline:
```python
import cognee
import asyncio
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
async def main():
# Create a clean slate for cognee -- reset data and system state

View file

@ -6,8 +6,8 @@ readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"cognee",
"mcp==1.2.0",
"cognee[codegraph]",
"mcp==1.1.3",
]
[[project.authors]]

View file

@ -1,4 +1,5 @@
import os
import cognee
import logging
import importlib.util
@ -8,7 +9,8 @@ from contextlib import redirect_stderr, redirect_stdout
import mcp.types as types
from mcp.server import Server, NotificationOptions
from mcp.server.models import InitializationOptions
from cognee.api.v1.search import SearchType
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.modules.search.types import SearchType
from cognee.shared.data_models import KnowledgeGraph
mcp = Server("cognee")
@ -41,6 +43,19 @@ async def list_tools() -> list[types.Tool]:
"required": ["text"],
},
),
types.Tool(
name="codify",
description="Transforms codebase into knowledge graph",
inputSchema={
"type": "object",
"properties": {
"repo_path": {
"type": "string",
},
},
"required": ["repo_path"],
},
),
types.Tool(
name="search",
description="Searches for information in knowledge graph",
@ -51,6 +66,10 @@ async def list_tools() -> list[types.Tool]:
"type": "string",
"description": "The query to search for",
},
"search_type": {
"type": "string",
"description": "The type of search to perform (e.g., INSIGHTS, CODE)",
},
},
"required": ["search_query"],
},
@ -72,15 +91,21 @@ async def call_tools(name: str, arguments: dict) -> list[types.TextContent]:
with open(os.devnull, "w") as fnull:
with redirect_stdout(fnull), redirect_stderr(fnull):
if name == "cognify":
await cognify(
cognify(
text=arguments["text"],
graph_model_file=arguments.get("graph_model_file", None),
graph_model_name=arguments.get("graph_model_name", None),
)
return [types.TextContent(type="text", text="Ingested")]
if name == "codify":
await codify(arguments.get("repo_path"))
return [types.TextContent(type="text", text="Indexed")]
elif name == "search":
search_results = await search(arguments["search_query"])
search_results = await search(
arguments["search_query"], arguments["search_type"]
)
return [types.TextContent(type="text", text=search_results)]
elif name == "prune":
@ -102,21 +127,28 @@ async def cognify(text: str, graph_model_file: str = None, graph_model_name: str
await cognee.add(text)
try:
await cognee.cognify(graph_model=graph_model)
asyncio.create_task(cognee.cognify(graph_model=graph_model))
except Exception as e:
raise ValueError(f"Failed to cognify: {str(e)}")
async def search(search_query: str) -> str:
async def codify(repo_path: str):
async for result in run_code_graph_pipeline(repo_path, False):
print(result)
async def search(search_query: str, search_type: str) -> str:
"""Search the knowledge graph"""
search_results = await cognee.search(SearchType.INSIGHTS, query_text=search_query)
search_results = await cognee.search(
query_type=SearchType[search_type.upper()], query_text=search_query
)
results = retrieved_edges_to_string(search_results)
return results
async def prune() -> str:
async def prune():
"""Reset the knowledge graph"""
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
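
Context for the cognify change above: the pipeline is now scheduled with asyncio.create_task rather than awaited, so the tool call can return "Ingested" immediately. A minimal sketch of that fire-and-forget pattern, with illustrative names; the task-reference bookkeeping is general asyncio guidance, not something this PR adds:

```python
import asyncio

_background_tasks: set[asyncio.Task] = set()


async def slow_pipeline(text: str) -> None:
    # Stand-in for cognee.cognify(...); assume it takes a while to finish.
    await asyncio.sleep(5)


async def handle_cognify(text: str) -> str:
    # Schedule the work without awaiting it, mirroring
    # asyncio.create_task(cognee.cognify(graph_model=graph_model)) above.
    task = asyncio.create_task(slow_pipeline(text))
    # Keeping a strong reference prevents the task from being garbage collected
    # before it completes (a general asyncio precaution, not part of the diff).
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return "Ingested"
```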

cognee-mcp/uv.lock (generated): 90 lines changed
View file

@ -474,7 +474,7 @@ wheels = [
[[package]]
name = "cognee"
version = "0.1.22"
version = "0.1.23"
source = { directory = "../" }
dependencies = [
{ name = "aiofiles" },
@ -516,11 +516,16 @@ dependencies = [
{ name = "sqlalchemy" },
{ name = "tenacity" },
{ name = "tiktoken" },
{ name = "transformers" },
{ name = "typing-extensions" },
{ name = "uvicorn" },
]
[package.optional-dependencies]
codegraph = [
{ name = "jedi" },
{ name = "parso" },
]
[package.metadata]
requires-dist = [
{ name = "aiofiles", specifier = ">=23.2.1,<24.0.0" },
@ -539,6 +544,7 @@ requires-dist = [
{ name = "fastapi", specifier = "==0.115.7" },
{ name = "fastapi-users", extras = ["sqlalchemy"], specifier = "==14.0.0" },
{ name = "filetype", specifier = ">=1.2.0,<2.0.0" },
{ name = "google-generativeai", marker = "extra == 'gemini'", specifier = ">=0.8.4,<0.9.0" },
{ name = "graphistry", specifier = ">=0.33.5,<0.34.0" },
{ name = "groq", marker = "extra == 'groq'", specifier = "==0.8.0" },
{ name = "gunicorn", specifier = ">=20.1.0,<21.0.0" },
@ -578,7 +584,7 @@ requires-dist = [
{ name = "sqlalchemy", specifier = "==2.0.36" },
{ name = "tenacity", specifier = ">=9.0.0,<10.0.0" },
{ name = "tiktoken", specifier = "==0.7.0" },
{ name = "transformers", specifier = ">=4.46.3,<5.0.0" },
{ name = "transformers", marker = "extra == 'huggingface'", specifier = ">=4.46.3,<5.0.0" },
{ name = "typing-extensions", specifier = "==4.12.2" },
{ name = "unstructured", extras = ["csv", "doc", "docx", "epub", "md", "odt", "org", "ppt", "pptx", "rst", "rtf", "tsv", "xlsx"], marker = "extra == 'docs'", specifier = ">=0.16.13,<0.17.0" },
{ name = "uvicorn", specifier = "==0.34.0" },
@ -590,14 +596,14 @@ name = "cognee-mcp"
version = "0.1.0"
source = { editable = "." }
dependencies = [
{ name = "cognee" },
{ name = "cognee", extra = ["codegraph"] },
{ name = "mcp" },
]
[package.metadata]
requires-dist = [
{ name = "cognee", directory = "../" },
{ name = "mcp", specifier = "==1.2.0" },
{ name = "cognee", extras = ["codegraph"], directory = "../" },
{ name = "mcp", specifier = "==1.1.3" },
]
[[package]]
@ -1307,6 +1313,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c5/82/fd319382c1a33d7021cf151007b4cbd5daddf09d9ca5fb670e476668f9fc/instructor-1.7.2-py3-none-any.whl", hash = "sha256:cb43d27f6d7631c31762b936b2fcb44d2a3f9d8a020430a0f4d3484604ffb95b", size = 71353 },
]
[[package]]
name = "jedi"
version = "0.19.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "parso" },
]
sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278 },
]
[[package]]
name = "jinja2"
version = "3.1.5"
@ -1739,21 +1757,19 @@ wheels = [
[[package]]
name = "mcp"
version = "1.2.0"
version = "1.1.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "httpx" },
{ name = "httpx-sse" },
{ name = "pydantic" },
{ name = "pydantic-settings" },
{ name = "sse-starlette" },
{ name = "starlette" },
{ name = "uvicorn" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ab/a5/b08dc846ebedae9f17ced878e6975826e90e448cd4592f532f6a88a925a7/mcp-1.2.0.tar.gz", hash = "sha256:2b06c7ece98d6ea9e6379caa38d74b432385c338fb530cb82e2c70ea7add94f5", size = 102973 }
sdist = { url = "https://files.pythonhosted.org/packages/f7/60/66ebfd280b197f9a9d074c9e46cb1ac3186a32d12e6bd0425c24fe7cf7e8/mcp-1.1.3.tar.gz", hash = "sha256:af11018b8e9153cdd25f3722ec639fe7a462c00213a330fd6f593968341a9883", size = 57903 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/af/84/fca78f19ac8ce6c53ba416247c71baa53a9e791e98d3c81edbc20a77d6d1/mcp-1.2.0-py3-none-any.whl", hash = "sha256:1d0e77d8c14955a5aea1f5aa1f444c8e531c09355c829b20e42f7a142bc0755f", size = 66468 },
{ url = "https://files.pythonhosted.org/packages/b8/08/cfcfa13e41f8d27503c51a8cbf1939d720073ace92469d08655bb5de1b24/mcp-1.1.3-py3-none-any.whl", hash = "sha256:71462d6cd7c06c14689dfcf110ff22286ba1b608cfc3515c0a5cbe33d131731a", size = 36997 },
]
[[package]]
@ -2083,6 +2099,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ab/5f/b38085618b950b79d2d9164a711c52b10aefc0ae6833b96f626b7021b2ed/pandas-2.2.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ad5b65698ab28ed8d7f18790a0dc58005c7629f227be9ecc1072aa74c0c1d43a", size = 13098436 },
]
[[package]]
name = "parso"
version = "0.8.4"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/66/94/68e2e17afaa9169cf6412ab0f28623903be73d1b32e208d9e8e541bb086d/parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d", size = 400609 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650 },
]
[[package]]
name = "pathvalidate"
version = "3.2.3"
@ -2850,28 +2875,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/5f/ce/22673f4a85ccc640735b4f8d12178a0f41b5d3c6eda7f33756d10ce56901/s3transfer-0.11.1-py3-none-any.whl", hash = "sha256:8fa0aa48177be1f3425176dfe1ab85dcd3d962df603c3dbfc585e6bf857ef0ff", size = 84111 },
]
[[package]]
name = "safetensors"
version = "0.5.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f4/4f/2ef9ef1766f8c194b01b67a63a444d2e557c8fe1d82faf3ebd85f370a917/safetensors-0.5.2.tar.gz", hash = "sha256:cb4a8d98ba12fa016f4241932b1fc5e702e5143f5374bba0bbcf7ddc1c4cf2b8", size = 66957 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/96/d1/017e31e75e274492a11a456a9e7c171f8f7911fe50735b4ec6ff37221220/safetensors-0.5.2-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:45b6092997ceb8aa3801693781a71a99909ab9cc776fbc3fa9322d29b1d3bef2", size = 427067 },
{ url = "https://files.pythonhosted.org/packages/24/84/e9d3ff57ae50dd0028f301c9ee064e5087fe8b00e55696677a0413c377a7/safetensors-0.5.2-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:6d0d6a8ee2215a440e1296b843edf44fd377b055ba350eaba74655a2fe2c4bae", size = 408856 },
{ url = "https://files.pythonhosted.org/packages/f1/1d/fe95f5dd73db16757b11915e8a5106337663182d0381811c81993e0014a9/safetensors-0.5.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86016d40bcaa3bcc9a56cd74d97e654b5f4f4abe42b038c71e4f00a089c4526c", size = 450088 },
{ url = "https://files.pythonhosted.org/packages/cf/21/e527961b12d5ab528c6e47b92d5f57f33563c28a972750b238b871924e49/safetensors-0.5.2-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:990833f70a5f9c7d3fc82c94507f03179930ff7d00941c287f73b6fcbf67f19e", size = 458966 },
{ url = "https://files.pythonhosted.org/packages/a5/8b/1a037d7a57f86837c0b41905040369aea7d8ca1ec4b2a77592372b2ec380/safetensors-0.5.2-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3dfa7c2f3fe55db34eba90c29df94bcdac4821043fc391cb5d082d9922013869", size = 509915 },
{ url = "https://files.pythonhosted.org/packages/61/3d/03dd5cfd33839df0ee3f4581a20bd09c40246d169c0e4518f20b21d5f077/safetensors-0.5.2-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:46ff2116150ae70a4e9c490d2ab6b6e1b1b93f25e520e540abe1b81b48560c3a", size = 527664 },
{ url = "https://files.pythonhosted.org/packages/c5/dc/8952caafa9a10a3c0f40fa86bacf3190ae7f55fa5eef87415b97b29cb97f/safetensors-0.5.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ab696dfdc060caffb61dbe4066b86419107a24c804a4e373ba59be699ebd8d5", size = 461978 },
{ url = "https://files.pythonhosted.org/packages/60/da/82de1fcf1194e3dbefd4faa92dc98b33c06bed5d67890e0962dd98e18287/safetensors-0.5.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:03c937100f38c9ff4c1507abea9928a6a9b02c9c1c9c3609ed4fb2bf413d4975", size = 491253 },
{ url = "https://files.pythonhosted.org/packages/5a/9a/d90e273c25f90c3ba1b0196a972003786f04c39e302fbd6649325b1272bb/safetensors-0.5.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:a00e737948791b94dad83cf0eafc09a02c4d8c2171a239e8c8572fe04e25960e", size = 628644 },
{ url = "https://files.pythonhosted.org/packages/70/3c/acb23e05aa34b4f5edd2e7f393f8e6480fbccd10601ab42cd03a57d4ab5f/safetensors-0.5.2-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:d3a06fae62418ec8e5c635b61a8086032c9e281f16c63c3af46a6efbab33156f", size = 721648 },
{ url = "https://files.pythonhosted.org/packages/71/45/eaa3dba5253a7c6931230dc961641455710ab231f8a89cb3c4c2af70f8c8/safetensors-0.5.2-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:1506e4c2eda1431099cebe9abf6c76853e95d0b7a95addceaa74c6019c65d8cf", size = 659588 },
{ url = "https://files.pythonhosted.org/packages/b0/71/2f9851164f821064d43b481ddbea0149c2d676c4f4e077b178e7eeaa6660/safetensors-0.5.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5c5b5d9da594f638a259fca766046f44c97244cc7ab8bef161b3e80d04becc76", size = 632533 },
{ url = "https://files.pythonhosted.org/packages/00/f1/5680e2ef61d9c61454fad82c344f0e40b8741a9dbd1e31484f0d31a9b1c3/safetensors-0.5.2-cp38-abi3-win32.whl", hash = "sha256:fe55c039d97090d1f85277d402954dd6ad27f63034fa81985a9cc59655ac3ee2", size = 291167 },
{ url = "https://files.pythonhosted.org/packages/86/ca/aa489392ec6fb59223ffce825461e1f811a3affd417121a2088be7a5758b/safetensors-0.5.2-cp38-abi3-win_amd64.whl", hash = "sha256:78abdddd03a406646107f973c7843276e7b64e5e32623529dc17f3d94a20f589", size = 303756 },
]
[[package]]
name = "scikit-learn"
version = "1.6.1"
@ -3348,27 +3351,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540 },
]
[[package]]
name = "transformers"
version = "4.48.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock" },
{ name = "huggingface-hub" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "pyyaml" },
{ name = "regex" },
{ name = "requests" },
{ name = "safetensors" },
{ name = "tokenizers" },
{ name = "tqdm" },
]
sdist = { url = "https://files.pythonhosted.org/packages/21/6b/caf620fae7fbf35947c81e7dd0834493b9ad9b71bb9e433025ac7a07e79a/transformers-4.48.1.tar.gz", hash = "sha256:7c1931facc3ee8adcbf86fc7a87461d54c1e40eca3bb57fef1ee9f3ecd32187e", size = 8365872 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7b/9f/92d3091c44cb19add044064af1bf1345cd35fbb84d32a3690f912800a295/transformers-4.48.1-py3-none-any.whl", hash = "sha256:24be0564b0a36d9e433d9a65de248f1545b6f6edce1737669605eb6a8141bbbb", size = 9665001 },
]
[[package]]
name = "typer"
version = "0.15.1"

View file

@ -1,5 +1,5 @@
from fastapi import APIRouter
from pydantic import BaseModel
from cognee.api.DTO import InDTO
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.modules.retrieval.description_to_codepart_search import (
code_description_to_code_part_search,
@ -7,14 +7,14 @@ from cognee.modules.retrieval.description_to_codepart_search import (
from fastapi.responses import JSONResponse
class CodePipelineIndexPayloadDTO(BaseModel):
class CodePipelineIndexPayloadDTO(InDTO):
repo_path: str
include_docs: bool = False
class CodePipelineRetrievePayloadDTO(BaseModel):
class CodePipelineRetrievePayloadDTO(InDTO):
query: str
fullInput: str
full_input: str
def get_code_pipeline_router() -> APIRouter:
@ -34,9 +34,9 @@ def get_code_pipeline_router() -> APIRouter:
"""This endpoint is responsible for retrieving the context."""
try:
query = (
payload.fullInput.replace("cognee ", "")
if payload.fullInput.startswith("cognee ")
else payload.fullInput
payload.full_input.replace("cognee ", "")
if payload.full_input.startswith("cognee ")
else payload.full_input
)
retrieved_codeparts, __ = await code_description_to_code_part_search(
@ -45,8 +45,8 @@ def get_code_pipeline_router() -> APIRouter:
return [
{
"name": codepart.attributes["id"],
"description": codepart.attributes["id"],
"name": codepart.attributes["file_path"],
"description": codepart.attributes["file_path"],
"content": codepart.attributes["source_code"],
}
for codepart in retrieved_codeparts
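
For reference, the query normalization above as a small standalone helper; purely illustrative, mirroring the endpoint logic on the renamed `full_input` field:

```python
def normalize_query(full_input: str) -> str:
    # Mirrors the retrieve endpoint above: drop the "cognee " marker when the
    # input starts with it (note str.replace removes every occurrence).
    return (
        full_input.replace("cognee ", "")
        if full_input.startswith("cognee ")
        else full_input
    )


normalize_query("cognee where is auth handled")  # "where is auth handled"
normalize_query("where is auth handled")         # unchanged
```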

View file

@ -2,7 +2,7 @@ from uuid import UUID
from datetime import datetime
from fastapi import Depends, APIRouter
from fastapi.responses import JSONResponse
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.api.DTO import InDTO, OutDTO
from cognee.modules.users.models import User
from cognee.modules.search.operations import get_history

View file

@ -1,100 +0,0 @@
"""This module contains the search function that is used to search for nodes in the graph."""
import asyncio
from enum import Enum
from typing import Dict, Any, Callable, List
from pydantic import BaseModel, field_validator
from cognee.modules.search.graph import search_cypher
from cognee.modules.search.graph.search_adjacent import search_adjacent
from cognee.modules.search.vector.search_traverse import search_traverse
from cognee.modules.search.graph.search_summary import search_summary
from cognee.modules.search.graph.search_similarity import search_similarity
from cognee.exceptions import UserNotFoundError
from cognee.shared.utils import send_telemetry
from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
class SearchType(Enum):
ADJACENT = "ADJACENT"
TRAVERSE = "TRAVERSE"
SIMILARITY = "SIMILARITY"
SUMMARY = "SUMMARY"
SUMMARY_CLASSIFICATION = "SUMMARY_CLASSIFICATION"
NODE_CLASSIFICATION = "NODE_CLASSIFICATION"
DOCUMENT_CLASSIFICATION = ("DOCUMENT_CLASSIFICATION",)
CYPHER = "CYPHER"
@staticmethod
def from_str(name: str):
try:
return SearchType[name.upper()]
except KeyError as error:
raise ValueError(f"{name} is not a valid SearchType") from error
class SearchParameters(BaseModel):
search_type: SearchType
params: Dict[str, Any]
@field_validator("search_type", mode="before")
def convert_string_to_enum(cls, value): # pylint: disable=no-self-argument
if isinstance(value, str):
return SearchType.from_str(value)
return value
async def search(search_type: str, params: Dict[str, Any], user: User = None) -> List:
if user is None:
user = await get_default_user()
if user is None:
raise UserNotFoundError
own_document_ids = await get_document_ids_for_user(user.id)
search_params = SearchParameters(search_type=search_type, params=params)
search_results = await specific_search([search_params], user)
from uuid import UUID
filtered_search_results = []
for search_result in search_results:
document_id = search_result["document_id"] if "document_id" in search_result else None
document_id = UUID(document_id) if isinstance(document_id, str) else document_id
if document_id is None or document_id in own_document_ids:
filtered_search_results.append(search_result)
return filtered_search_results
async def specific_search(query_params: List[SearchParameters], user) -> List:
search_functions: Dict[SearchType, Callable] = {
SearchType.ADJACENT: search_adjacent,
SearchType.SUMMARY: search_summary,
SearchType.CYPHER: search_cypher,
SearchType.TRAVERSE: search_traverse,
SearchType.SIMILARITY: search_similarity,
}
search_tasks = []
send_telemetry("cognee.search EXECUTION STARTED", user.id)
for search_param in query_params:
search_func = search_functions.get(search_param.search_type)
if search_func:
# Schedule the coroutine for execution and store the task
task = search_func(**search_param.params)
search_tasks.append(task)
# Use asyncio.gather to run all scheduled tasks concurrently
search_results = await asyncio.gather(*search_tasks)
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
return search_results[0] if len(search_results) == 1 else search_results

View file

@ -1,29 +1,10 @@
import json
from uuid import UUID
from enum import Enum
from typing import Callable, Dict, Union
from typing import Union
from cognee.exceptions import InvalidValueError
from cognee.modules.search.operations import log_query, log_result
from cognee.modules.storage.utils import JSONEncoder
from cognee.shared.utils import send_telemetry
from cognee.modules.search.types import SearchType
from cognee.modules.users.exceptions import UserNotFoundError
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.tasks.chunks import query_chunks
from cognee.tasks.graph import query_graph_connections
from cognee.tasks.summarization import query_summaries
from cognee.tasks.completion import query_completion
from cognee.tasks.completion import graph_query_completion
class SearchType(Enum):
SUMMARIES = "SUMMARIES"
INSIGHTS = "INSIGHTS"
CHUNKS = "CHUNKS"
COMPLETION = "COMPLETION"
GRAPH_COMPLETION = "GRAPH_COMPLETION"
from cognee.modules.search.methods import search as search_function
async def search(
@ -42,43 +23,6 @@ async def search(
if user is None:
raise UserNotFoundError
query = await log_query(query_text, str(query_type), user.id)
own_document_ids = await get_document_ids_for_user(user.id, datasets)
search_results = await specific_search(query_type, query_text, user)
filtered_search_results = []
for search_result in search_results:
document_id = search_result["document_id"] if "document_id" in search_result else None
document_id = UUID(document_id) if isinstance(document_id, str) else document_id
if document_id is None or document_id in own_document_ids:
filtered_search_results.append(search_result)
await log_result(query.id, json.dumps(filtered_search_results, cls=JSONEncoder), user.id)
filtered_search_results = await search_function(query_text, query_type, datasets, user)
return filtered_search_results
async def specific_search(query_type: SearchType, query: str, user) -> list:
search_tasks: Dict[SearchType, Callable] = {
SearchType.SUMMARIES: query_summaries,
SearchType.INSIGHTS: query_graph_connections,
SearchType.CHUNKS: query_chunks,
SearchType.COMPLETION: query_completion,
SearchType.GRAPH_COMPLETION: graph_query_completion,
}
search_task = search_tasks.get(query_type)
if search_task is None:
raise InvalidValueError(message=f"Unsupported search type: {query_type}")
send_telemetry("cognee.search EXECUTION STARTED", user.id)
results = await search_task(query)
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
return results

View file

@ -0,0 +1,18 @@
from cognee.modules.retrieval.description_to_codepart_search import (
code_description_to_code_part_search,
)
async def code_graph_retrieval(query, include_docs=False):
retrieved_codeparts, __ = await code_description_to_code_part_search(
query, include_docs=include_docs
)
return [
{
"name": codepart.attributes["file_path"],
"description": codepart.attributes["file_path"],
"content": codepart.attributes["source_code"],
}
for codepart in retrieved_codeparts
]
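
A minimal usage sketch of the new retrieval helper; the query string is illustrative:

```python
import asyncio

from cognee.modules.retrieval.code_graph_retrieval import code_graph_retrieval


async def main():
    # Each result carries "name"/"description" (the file path) and "content"
    # (the source code), as assembled in the module above.
    codeparts = await code_graph_retrieval("function that builds the code graph")
    for part in codeparts:
        print(part["name"])


asyncio.run(main())
```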

View file

@ -1,15 +1,14 @@
import asyncio
import logging
from typing import Set, List
from typing import List
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry
from cognee.api.v1.search import SearchType
from cognee.api.v1.search.search_v2 import search
from cognee.modules.search.methods import search
from cognee.infrastructure.llm.get_llm_client import get_llm_client
@ -39,7 +38,7 @@ async def code_description_to_code_part(
include_docs(bool): Boolean showing whether we have the docs in the graph or not
Returns:
Set[str]: A set of unique code parts matching the query.
List[str]: A set of unique code parts matching the query.
Raises:
ValueError: If arguments are invalid.
@ -66,7 +65,7 @@ async def code_description_to_code_part(
try:
if include_docs:
search_results = await search(SearchType.INSIGHTS, query_text=query)
search_results = await search(query_text=query, query_type="INSIGHTS", user=user)
concatenated_descriptions = " ".join(
obj["description"]
@ -98,6 +97,7 @@ async def code_description_to_code_part(
"id",
"type",
"text",
"file_path",
"source_code",
"pydantic_type",
],

View file

@ -0,0 +1 @@
from .search import search

View file

@ -0,0 +1,66 @@
import json
from uuid import UUID
from typing import Callable
from cognee.exceptions import InvalidValueError
from cognee.modules.retrieval.code_graph_retrieval import code_graph_retrieval
from cognee.modules.search.types import SearchType
from cognee.modules.storage.utils import JSONEncoder
from cognee.modules.users.models import User
from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.shared.utils import send_telemetry
from cognee.tasks.chunks import query_chunks
from cognee.tasks.graph import query_graph_connections
from cognee.tasks.summarization import query_summaries
from cognee.tasks.completion import query_completion
from cognee.tasks.completion import graph_query_completion
from ..operations import log_query, log_result
async def search(
query_text: str,
query_type: str,
datasets: list[str],
user: User,
):
query = await log_query(query_text, str(query_type), user.id)
own_document_ids = await get_document_ids_for_user(user.id, datasets)
search_results = await specific_search(query_type, query_text, user)
filtered_search_results = []
for search_result in search_results:
document_id = search_result["document_id"] if "document_id" in search_result else None
document_id = UUID(document_id) if isinstance(document_id, str) else document_id
if document_id is None or document_id in own_document_ids:
filtered_search_results.append(search_result)
await log_result(query.id, json.dumps(filtered_search_results, cls=JSONEncoder), user.id)
return filtered_search_results
async def specific_search(query_type: SearchType, query: str, user: User) -> list:
search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: query_summaries,
SearchType.INSIGHTS: query_graph_connections,
SearchType.CHUNKS: query_chunks,
SearchType.COMPLETION: query_completion,
SearchType.GRAPH_COMPLETION: graph_query_completion,
SearchType.CODE: code_graph_retrieval,
}
search_task = search_tasks.get(query_type)
if search_task is None:
raise InvalidValueError(message=f"Unsupported search type: {query_type}")
send_telemetry("cognee.search EXECUTION STARTED", user.id)
results = await search_task(query)
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
return results
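
A minimal sketch of calling the relocated search method directly; the query text is illustrative, and whether `datasets=None` is accepted depends on `get_document_ids_for_user`, which is outside this diff:

```python
import asyncio

from cognee.modules.search.methods import search
from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user


async def main():
    user = await get_default_user()
    # SearchType.CODE dispatches to code_graph_retrieval via specific_search above.
    results = await search(
        query_text="where is the graph engine configured",
        query_type=SearchType.CODE,
        datasets=None,  # assumption: None means "no dataset filter"
        user=user,
    )
    print(results)


asyncio.run(main())
```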

View file

@ -0,0 +1,10 @@
from enum import Enum
class SearchType(Enum):
SUMMARIES = "SUMMARIES"
INSIGHTS = "INSIGHTS"
CHUNKS = "CHUNKS"
COMPLETION = "COMPLETION"
GRAPH_COMPLETION = "GRAPH_COMPLETION"
CODE = "CODE"

View file

@ -0,0 +1 @@
from .SearchType import SearchType
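
The MCP server converts the incoming string with `SearchType[search_type.upper()]`; a quick illustration of that name-based lookup and its failure mode:

```python
from cognee.modules.search.types import SearchType

assert SearchType["CODE"] is SearchType.CODE
assert SearchType["insights".upper()] is SearchType.INSIGHTS

try:
    SearchType["EVERYTHING"]  # not a defined member
except KeyError:
    # Enum name lookup raises KeyError for unknown names; callers taking user
    # input may want to handle this (illustrative note, not part of the PR).
    pass
```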

View file

@ -23,6 +23,7 @@ class CodeFile(DataPoint):
class CodePart(DataPoint):
__tablename__ = "codepart"
file_path: str # file path
# part_of: Optional[CodeFile] = None
pydantic_type: str = "CodePart"
source_code: Optional[str] = None

View file

@ -30,6 +30,7 @@ def _add_code_parts_nodes_and_edges(code_file: CodeFile, part_type, code_parts)
id=part_node_id,
type=part_type,
# part_of = code_file,
file_path=code_file.extracted_id[len(code_file.part_of.path) + 1 :],
source_code=code_part,
)
)

View file

@ -2,8 +2,8 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.shared.utils import render_graph
from cognee.modules.search.types import SearchType
# from cognee.shared.utils import render_graph
logging.basicConfig(level=logging.DEBUG)

View file

@ -2,7 +2,7 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
logging.basicConfig(level=logging.DEBUG)

View file

@ -2,7 +2,7 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
logging.basicConfig(level=logging.DEBUG)

View file

@ -2,7 +2,7 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
logging.basicConfig(level=logging.DEBUG)

View file

@ -4,7 +4,7 @@ import pathlib
import cognee
from cognee.modules.data.models import Data
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.users.methods import get_default_user

View file

@ -2,7 +2,7 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
logging.basicConfig(level=logging.DEBUG)

View file

@ -2,7 +2,7 @@ import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
logging.basicConfig(level=logging.DEBUG)

View file

@ -8,7 +8,6 @@ from swebench.harness.utils import load_swebench_dataset
from swebench.inference.make_datasets.create_instance import PATCH_EXAMPLE
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.api.v1.search import SearchType
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.retrieval.description_to_codepart_search import (

View file

@ -1,5 +1,5 @@
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.search.types import SearchType
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string

View file

@ -94,9 +94,7 @@ async def cognify_search_base_rag(content: str, context: str):
async def cognify_search_graph(content: str, context: str):
from cognee.api.v1.search import search, SearchType
params = {"query": "Donald Trump"}
results = await search(SearchType.INSIGHTS, params)
results = await search(SearchType.INSIGHTS, query_text="Donald Trump")
print("results", results)
return results

View file

@ -5,8 +5,8 @@ import asyncio
import cognee
import signal
from cognee.api.v1.search import SearchType
from cognee.shared.utils import setup_logging
from cognee.modules.search.types import SearchType
app = modal.App("cognee-runner")

poetry.lock (generated): 8 lines changed
View file

@ -1071,7 +1071,6 @@ files = [
{file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:761817a3377ef15ac23cd7834715081791d4ec77f9297ee694ca1ee9c2c7e5eb"},
{file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3c672a53c0fb4725a29c303be906d3c1fa99c32f58abe008a82705f9ee96f40b"},
{file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:4ac4c9f37eba52cb6fbeaf5b59c152ea976726b865bd4cf87883a7e7006cc543"},
{file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:60eb32934076fa07e4316b7b2742fa52cbb190b42c2df2863dbc4230a0a9b385"},
{file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ed3534eb1090483c96178fcb0f8893719d96d5274dfde98aa6add34614e97c8e"},
{file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f3f6fdfa89ee2d9d496e2c087cebef9d4fcbb0ad63c40e821b39f74bf48d9c5e"},
{file = "cryptography-44.0.0-cp37-abi3-win32.whl", hash = "sha256:eb33480f1bad5b78233b0ad3e1b0be21e8ef1da745d8d2aecbb20671658b9053"},
@ -1082,7 +1081,6 @@ files = [
{file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:c5eb858beed7835e5ad1faba59e865109f3e52b3783b9ac21e7e47dc5554e289"},
{file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f53c2c87e0fb4b0c00fa9571082a057e37690a8f12233306161c8f4b819960b7"},
{file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:9e6fc8a08e116fb7c7dd1f040074c9d7b51d74a8ea40d4df2fc7aa08b76b9e6c"},
{file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:9abcc2e083cbe8dde89124a47e5e53ec38751f0d7dfd36801008f316a127d7ba"},
{file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:d2436114e46b36d00f8b72ff57e598978b37399d2786fd39793c36c6d5cb1c64"},
{file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a01956ddfa0a6790d594f5b34fc1bfa6098aca434696a03cfdbe469b8ed79285"},
{file = "cryptography-44.0.0-cp39-abi3-win32.whl", hash = "sha256:eca27345e1214d1b9f9490d200f9db5a874479be914199194e746c893788d417"},
@ -3095,6 +3093,8 @@ optional = false
python-versions = "*"
files = [
{file = "jsonpath-ng-1.7.0.tar.gz", hash = "sha256:f6f5f7fd4e5ff79c785f1573b394043b39849fb2bb47bcead935d12b00beab3c"},
{file = "jsonpath_ng-1.7.0-py2-none-any.whl", hash = "sha256:898c93fc173f0c336784a3fa63d7434297544b7198124a68f9a3ef9597b0ae6e"},
{file = "jsonpath_ng-1.7.0-py3-none-any.whl", hash = "sha256:f3d7f9e848cba1b6da28c55b1c26ff915dc9e0b1ba7e752a53d6da8d5cbd00b6"},
]
[package.dependencies]
@ -3694,13 +3694,17 @@ proxy = ["PyJWT (>=2.8.0,<3.0.0)", "apscheduler (>=3.10.4,<4.0.0)", "backoff", "
[[package]]
name = "llama-index-core"
version = "0.12.14"
description = "Interface between LLMs and your data"
optional = true
python-versions = "<4.0,>=3.9"
files = [
{file = "llama_index_core-0.12.14-py3-none-any.whl", hash = "sha256:6fdb30e3fadf98e7df75f9db5d06f6a7f8503ca545a71e048d786ff88012bd50"},
{file = "llama_index_core-0.12.14.tar.gz", hash = "sha256:378bbf5bf4d1a8c692d3a980c1a6ed3be7a9afb676a4960429dea15f62d06cd3"},
]
[package.dependencies]