From 806298d508b8bc22a1665f828020284ca177ab2f Mon Sep 17 00:00:00 2001
From: Daulet Amirkhanov
Date: Wed, 8 Oct 2025 18:36:08 +0100
Subject: [PATCH] create cognee_client for handling api calls

---
 cognee-mcp/pyproject.toml       |   1 +
 cognee-mcp/src/cognee_client.py | 339 ++++++++++++++++++++++++++++++++
 cognee-mcp/uv.lock              |  30 +--
 3 files changed, 356 insertions(+), 14 deletions(-)
 create mode 100644 cognee-mcp/src/cognee_client.py

diff --git a/cognee-mcp/pyproject.toml b/cognee-mcp/pyproject.toml
index bc0ebeac5..c5a21a1a7 100644
--- a/cognee-mcp/pyproject.toml
+++ b/cognee-mcp/pyproject.toml
@@ -13,6 +13,7 @@ dependencies = [
     "fastmcp>=2.10.0,<3.0.0",
     "mcp>=1.12.0,<2.0.0",
     "uv>=0.6.3,<1.0.0",
+    "httpx>=0.27.0,<1.0.0",
 ]

 authors = [
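Usage sketch (not part of the patch): the client picks its mode from its constructor arguments, so a caller such as the MCP server entry point can wire it to configuration. The environment variable names below are illustrative assumptions, not something this patch defines:

    import os

    from src.cognee_client import CogneeClient

    # Hypothetical configuration: if COGNEE_API_URL is unset, api_url is None
    # and the client falls back to direct mode (imports cognee in-process).
    client = CogneeClient(
        api_url=os.environ.get("COGNEE_API_URL"),      # e.g. "http://localhost:8000"
        api_token=os.environ.get("COGNEE_API_TOKEN"),  # sent as a Bearer token
    )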
diff --git a/cognee-mcp/src/cognee_client.py b/cognee-mcp/src/cognee_client.py
new file mode 100644
index 000000000..e22b8c060
--- /dev/null
+++ b/cognee-mcp/src/cognee_client.py
@@ -0,0 +1,339 @@
+"""
+Cognee Client abstraction that supports both direct function calls and HTTP API calls.
+
+This module provides a unified interface for interacting with Cognee, supporting:
+- Direct mode: Directly imports and calls cognee functions (default behavior)
+- API mode: Makes HTTP requests to a running Cognee FastAPI server
+"""
+
+import sys
+from typing import Optional, Any, List, Dict
+from uuid import UUID
+from contextlib import redirect_stdout
+
+import httpx
+
+from cognee.shared.logging_utils import get_logger
+
+logger = get_logger()
+
+
+class CogneeClient:
+    """
+    Unified client for interacting with Cognee via direct calls or HTTP API.
+
+    Parameters
+    ----------
+    api_url : str, optional
+        Base URL of the Cognee API server (e.g., "http://localhost:8000").
+        If None, uses direct cognee function calls.
+    api_token : str, optional
+        Authentication token for API requests. Required if api_url is provided.
+    """
+
+    def __init__(self, api_url: Optional[str] = None, api_token: Optional[str] = None):
+        self.api_url = api_url.rstrip("/") if api_url else None
+        self.api_token = api_token
+        self.use_api = api_url is not None
+
+        if self.use_api:
+            logger.info(f"Cognee client initialized in API mode: {self.api_url}")
+            self.client = httpx.AsyncClient(timeout=300.0)  # 5 minute timeout for long operations
+        else:
+            logger.info("Cognee client initialized in direct mode")
+            # Import cognee only if we're using direct mode
+            import cognee as _cognee
+
+            self.cognee = _cognee
+
+    def _get_headers(self) -> Dict[str, str]:
+        """Get headers for API requests."""
+        headers = {"Content-Type": "application/json"}
+        if self.api_token:
+            headers["Authorization"] = f"Bearer {self.api_token}"
+        return headers
+
+    async def add(
+        self, data: Any, dataset_name: str = "main_dataset", node_set: Optional[List[str]] = None
+    ) -> Dict[str, Any]:
+        """
+        Add data to Cognee for processing.
+
+        Parameters
+        ----------
+        data : Any
+            Data to add (text, file path, etc.)
+        dataset_name : str
+            Name of the dataset to add data to
+        node_set : List[str], optional
+            List of node identifiers for graph organization
+
+        Returns
+        -------
+        Dict[str, Any]
+            Result of the add operation
+        """
+        if self.use_api:
+            # API mode: Make HTTP request
+            endpoint = f"{self.api_url}/api/v1/add"
+
+            # For API mode, we need to handle file uploads differently
+            # For now, we'll assume data is text content
+            files = {"data": ("data.txt", str(data), "text/plain")}
+            form_data = {
+                "datasetName": dataset_name,
+            }
+            if node_set:
+                form_data["node_set"] = node_set
+
+            response = await self.client.post(
+                endpoint,
+                files=files,
+                data=form_data,
+                headers={"Authorization": f"Bearer {self.api_token}"} if self.api_token else {},
+            )
+            response.raise_for_status()
+            return response.json()
+        else:
+            # Direct mode: Call cognee directly
+            with redirect_stdout(sys.stderr):
+                await self.cognee.add(data, dataset_name=dataset_name, node_set=node_set)
+            return {"status": "success", "message": "Data added successfully"}
+    async def cognify(
+        self,
+        datasets: Optional[List[str]] = None,
+        custom_prompt: Optional[str] = None,
+        graph_model: Any = None,
+    ) -> Dict[str, Any]:
+        """
+        Transform data into a knowledge graph.
+
+        Parameters
+        ----------
+        datasets : List[str], optional
+            List of dataset names to process
+        custom_prompt : str, optional
+            Custom prompt for entity extraction
+        graph_model : Any, optional
+            Custom graph model (only used in direct mode)
+
+        Returns
+        -------
+        Dict[str, Any]
+            Result of the cognify operation
+        """
+        if self.use_api:
+            # API mode: Make HTTP request
+            endpoint = f"{self.api_url}/api/v1/cognify"
+            payload = {
+                "datasets": datasets or ["main_dataset"],
+                "run_in_background": False,
+            }
+            if custom_prompt:
+                payload["custom_prompt"] = custom_prompt
+
+            response = await self.client.post(endpoint, json=payload, headers=self._get_headers())
+            response.raise_for_status()
+            return response.json()
+        else:
+            # Direct mode: Call cognee directly
+            with redirect_stdout(sys.stderr):
+                kwargs = {}
+                if custom_prompt:
+                    kwargs["custom_prompt"] = custom_prompt
+                if graph_model:
+                    kwargs["graph_model"] = graph_model
+
+                await self.cognee.cognify(**kwargs)
+            return {"status": "success", "message": "Cognify completed successfully"}
+
+    async def search(
+        self,
+        query_text: str,
+        query_type: str,
+        datasets: Optional[List[str]] = None,
+        system_prompt: Optional[str] = None,
+        top_k: int = 10,
+    ) -> Any:
+        """
+        Search the knowledge graph.
+
+        Parameters
+        ----------
+        query_text : str
+            The search query
+        query_type : str
+            Type of search (e.g., "GRAPH_COMPLETION", "INSIGHTS", etc.)
+        datasets : List[str], optional
+            List of datasets to search
+        system_prompt : str, optional
+            System prompt for completion searches
+        top_k : int
+            Maximum number of results
+
+        Returns
+        -------
+        Any
+            Search results
+        """
+        if self.use_api:
+            # API mode: Make HTTP request
+            endpoint = f"{self.api_url}/api/v1/search"
+            payload = {"query": query_text, "search_type": query_type.upper(), "top_k": top_k}
+            if datasets:
+                payload["datasets"] = datasets
+            if system_prompt:
+                payload["system_prompt"] = system_prompt
+
+            response = await self.client.post(endpoint, json=payload, headers=self._get_headers())
+            response.raise_for_status()
+            return response.json()
+        else:
+            # Direct mode: Call cognee directly
+            from cognee.modules.search.types import SearchType
+
+            with redirect_stdout(sys.stderr):
+                results = await self.cognee.search(
+                    query_type=SearchType[query_type.upper()], query_text=query_text
+                )
+            return results
+
+    async def delete(self, data_id: UUID, dataset_id: UUID, mode: str = "soft") -> Dict[str, Any]:
+        """
+        Delete data from a dataset.
+
+        Parameters
+        ----------
+        data_id : UUID
+            ID of the data to delete
+        dataset_id : UUID
+            ID of the dataset containing the data
+        mode : str
+            Deletion mode ("soft" or "hard")
+
+        Returns
+        -------
+        Dict[str, Any]
+            Result of the deletion
+        """
+        if self.use_api:
+            # API mode: Make HTTP request
+            endpoint = f"{self.api_url}/api/v1/delete"
+            params = {"data_id": str(data_id), "dataset_id": str(dataset_id), "mode": mode}
+
+            response = await self.client.delete(
+                endpoint, params=params, headers=self._get_headers()
+            )
+            response.raise_for_status()
+            return response.json()
+        else:
+            # Direct mode: Call cognee directly
+            from cognee.modules.users.methods import get_default_user
+
+            with redirect_stdout(sys.stderr):
+                user = await get_default_user()
+                result = await self.cognee.delete(
+                    data_id=data_id, dataset_id=dataset_id, mode=mode, user=user
+                )
+            return result
+
+    async def prune_data(self) -> Dict[str, Any]:
+        """
+        Prune all data from the knowledge graph.
+
+        Returns
+        -------
+        Dict[str, Any]
+            Result of the prune operation
+        """
+        if self.use_api:
+            # Note: The API doesn't expose a prune endpoint, so we'll need to handle this
+            # For now, raise an error
+            raise NotImplementedError("Prune operation is not available via API")
+        else:
+            # Direct mode: Call cognee directly
+            with redirect_stdout(sys.stderr):
+                await self.cognee.prune.prune_data()
+            return {"status": "success", "message": "Data pruned successfully"}
+
+    async def prune_system(self, metadata: bool = True) -> Dict[str, Any]:
+        """
+        Prune system data from the knowledge graph.
+
+        Parameters
+        ----------
+        metadata : bool
+            Whether to prune metadata
+
+        Returns
+        -------
+        Dict[str, Any]
+            Result of the prune operation
+        """
+        if self.use_api:
+            # Note: The API doesn't expose a prune endpoint
+            raise NotImplementedError("Prune system operation is not available via API")
+        else:
+            # Direct mode: Call cognee directly
+            with redirect_stdout(sys.stderr):
+                await self.cognee.prune.prune_system(metadata=metadata)
+            return {"status": "success", "message": "System pruned successfully"}
+
+    async def get_pipeline_status(self, dataset_ids: List[UUID], pipeline_name: str) -> str:
+        """
+        Get the status of a pipeline run.
+
+        Parameters
+        ----------
+        dataset_ids : List[UUID]
+            List of dataset IDs
+        pipeline_name : str
+            Name of the pipeline
+
+        Returns
+        -------
+        str
+            Status information
+        """
+        if self.use_api:
+            # Note: This would need a custom endpoint on the API side
+            raise NotImplementedError("Pipeline status is not available via API")
+        else:
+            # Direct mode: Call cognee directly
+            from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
+
+            with redirect_stdout(sys.stderr):
+                status = await get_pipeline_status(dataset_ids, pipeline_name)
+            return str(status)
+
+    async def list_datasets(self) -> List[Dict[str, Any]]:
+        """
+        List all datasets.
+
+        Returns
+        -------
+        List[Dict[str, Any]]
+            List of datasets
+        """
+        if self.use_api:
+            # API mode: Make HTTP request
+            endpoint = f"{self.api_url}/api/v1/datasets"
+            response = await self.client.get(endpoint, headers=self._get_headers())
+            response.raise_for_status()
+            return response.json()
+        else:
+            # Direct mode: Call cognee directly
+            from cognee.modules.users.methods import get_default_user
+            from cognee.modules.data.methods import get_datasets
+
+            with redirect_stdout(sys.stderr):
+                user = await get_default_user()
+                datasets = await get_datasets(user.id)
+            return [
+                {"id": str(d.id), "name": d.name, "created_at": str(d.created_at)}
+                for d in datasets
+            ]
+
+    async def close(self):
+        """Close the HTTP client if in API mode."""
+        if self.use_api and hasattr(self, "client"):
+            await self.client.aclose()
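A minimal end-to-end sketch of the resulting client in direct mode, using only the methods defined above (the dataset name and query text are illustrative, and the import path assumes the file lands at cognee-mcp/src/cognee_client.py):

    import asyncio

    from src.cognee_client import CogneeClient

    async def main():
        client = CogneeClient()  # no api_url -> direct mode, calls cognee in-process
        await client.add("Cognee turns documents into a knowledge graph.", dataset_name="main_dataset")
        await client.cognify(datasets=["main_dataset"])
        results = await client.search("What does Cognee do?", "GRAPH_COMPLETION")
        print(results)
        await client.close()  # no-op in direct mode; closes the httpx client in API mode

    asyncio.run(main())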
diff --git a/cognee-mcp/uv.lock b/cognee-mcp/uv.lock
index e31741613..dbbc6542c 100644
--- a/cognee-mcp/uv.lock
+++ b/cognee-mcp/uv.lock
@@ -1,5 +1,5 @@
 version = 1
-revision = 2
+revision = 3
 requires-python = ">=3.10"
 resolution-markers = [
     "python_full_version >= '3.14' and platform_python_implementation != 'PyPy' and sys_platform != 'emscripten'",
@@ -737,6 +737,7 @@ source = { editable = "." }
 dependencies = [
     { name = "cognee", extra = ["codegraph", "docs", "gemini", "huggingface", "neo4j", "postgres"] },
     { name = "fastmcp" },
+    { name = "httpx" },
     { name = "mcp" },
     { name = "uv" },
 ]
@@ -750,6 +751,7 @@ dev = [
 requires-dist = [
     { name = "cognee", extras = ["postgres", "codegraph", "gemini", "huggingface", "docs", "neo4j"], specifier = "==0.3.4" },
     { name = "fastmcp", specifier = ">=2.10.0,<3.0.0" },
+    { name = "httpx", specifier = ">=0.27.0,<1.0.0" },
    { name = "mcp", specifier = ">=1.12.0,<2.0.0" },
     { name = "uv", specifier = ">=0.6.3,<1.0.0" },
 ]
@@ -1026,7 +1028,7 @@ version = "3.24.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "attrs" },
-    { name = "docstring-parser", marker = "python_full_version < '4.0'" },
+    { name = "docstring-parser", marker = "python_full_version < '4'" },
     { name = "rich" },
     { name = "rich-rst" },
     { name = "typing-extensions", marker = "python_full_version < '3.11'" },
 ]
@@ -1309,17 +1311,17 @@ name = "fastembed"
 version = "0.6.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "huggingface-hub" },
-    { name = "loguru" },
-    { name = "mmh3" },
+    { name = "huggingface-hub", marker = "python_full_version < '3.13'" },
+    { name = "loguru", marker = "python_full_version < '3.13'" },
+    { name = "mmh3", marker = "python_full_version < '3.13'" },
     { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
-    { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
-    { name = "onnxruntime" },
-    { name = "pillow" },
-    { name = "py-rust-stemmers" },
-    { name = "requests" },
-    { name = "tokenizers" },
-    { name = "tqdm" },
+    { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and python_full_version < '3.13'" },
+    { name = "onnxruntime", marker = "python_full_version < '3.13'" },
+    { name = "pillow", marker = "python_full_version < '3.13'" },
+    { name = "py-rust-stemmers", marker = "python_full_version < '3.13'" },
+    { name = "requests", marker = "python_full_version < '3.13'" },
+    { name = "tokenizers", marker = "python_full_version < '3.13'" },
+    { name = "tqdm", marker = "python_full_version < '3.13'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/c6/f4/036a656c605f63dc25f11284f60f69900a54a19c513e1ae60d21d6977e75/fastembed-0.6.0.tar.gz", hash = "sha256:5c9ead25f23449535b07243bbe1f370b820dcc77ec2931e61674e3fe7ff24733", size = 50731, upload-time = "2025-02-26T13:50:33.031Z" }
 wheels = [
@@ -2526,8 +2528,8 @@ name = "loguru"
 version = "0.7.3"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "colorama", marker = "sys_platform == 'win32'" },
-    { name = "win32-setctime", marker = "sys_platform == 'win32'" },
+    { name = "colorama", marker = "python_full_version < '3.13' and sys_platform == 'win32'" },
+    { name = "win32-setctime", marker = "python_full_version < '3.13' and sys_platform == 'win32'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/3a/05/a1dae3dffd1116099471c643b8924f5aa6524411dc6c63fdae648c4f1aca/loguru-0.7.3.tar.gz", hash = "sha256:19480589e77d47b8d85b2c827ad95d49bf31b0dcde16593892eb51dd18706eb6", size = 63559, upload-time = "2024-12-06T11:20:56.608Z" }
 wheels = [