create cognee_client for handling api calls
This commit is contained in:
parent
c54c3546fc
commit
806298d508
3 changed files with 356 additions and 14 deletions
|
|
@ -13,6 +13,7 @@ dependencies = [
|
||||||
"fastmcp>=2.10.0,<3.0.0",
|
"fastmcp>=2.10.0,<3.0.0",
|
||||||
"mcp>=1.12.0,<2.0.0",
|
"mcp>=1.12.0,<2.0.0",
|
||||||
"uv>=0.6.3,<1.0.0",
|
"uv>=0.6.3,<1.0.0",
|
||||||
|
"httpx>=0.27.0,<1.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
authors = [
|
authors = [
|
||||||
|
|
|
||||||
339
cognee-mcp/src/cognee_client.py
Normal file
339
cognee-mcp/src/cognee_client.py
Normal file
|
|
@ -0,0 +1,339 @@
|
||||||
|
"""
|
||||||
|
Cognee Client abstraction that supports both direct function calls and HTTP API calls.
|
||||||
|
|
||||||
|
This module provides a unified interface for interacting with Cognee, supporting:
|
||||||
|
- Direct mode: Directly imports and calls cognee functions (default behavior)
|
||||||
|
- API mode: Makes HTTP requests to a running Cognee FastAPI server
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from typing import Optional, Any, List, Dict
|
||||||
|
from uuid import UUID
|
||||||
|
from contextlib import redirect_stdout
|
||||||
|
import httpx
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
|
# Module-level logger obtained from cognee's shared logging helper.
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class CogneeClient:
    """
    Unified client for interacting with Cognee via direct calls or HTTP API.

    In API mode (``api_url`` given) every operation is an HTTP request made
    with a shared ``httpx.AsyncClient``; in direct mode the ``cognee`` package
    is imported lazily and its functions are called in-process, with stdout
    redirected to stderr so library prints cannot corrupt the MCP stdio stream.

    Parameters
    ----------
    api_url : str, optional
        Base URL of the Cognee API server (e.g., "http://localhost:8000").
        If None, uses direct cognee function calls.
    api_token : str, optional
        Authentication token for API requests. Required if api_url is provided.
    """

    def __init__(self, api_url: Optional[str] = None, api_token: Optional[str] = None):
        # Normalize the base URL so endpoint strings can always append "/api/v1/...".
        self.api_url = api_url.rstrip("/") if api_url else None
        self.api_token = api_token
        # Mode flag checked by every public method below.
        self.use_api = api_url is not None

        if self.use_api:
            logger.info(f"Cognee client initialized in API mode: {self.api_url}")
            # NOTE: self.client exists only in API mode; close() guards on hasattr.
            self.client = httpx.AsyncClient(timeout=300.0)  # 5 minute timeout for long operations
        else:
            logger.info("Cognee client initialized in direct mode")
            # Import cognee only if we're using direct mode
            import cognee as _cognee

            self.cognee = _cognee

    def _get_headers(self) -> Dict[str, str]:
        """Get headers for API requests (JSON Content-Type plus optional Bearer token).

        NOTE(review): not used by add(), which sends multipart form data and must
        not force "Content-Type: application/json" — that bypass is deliberate.
        """
        headers = {"Content-Type": "application/json"}
        if self.api_token:
            headers["Authorization"] = f"Bearer {self.api_token}"
        return headers

    async def add(
        self, data: Any, dataset_name: str = "main_dataset", node_set: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        Add data to Cognee for processing.

        Parameters
        ----------
        data : Any
            Data to add (text, file path, etc.)
        dataset_name : str
            Name of the dataset to add data to
        node_set : List[str], optional
            List of node identifiers for graph organization

        Returns
        -------
        Dict[str, Any]
            Result of the add operation. In direct mode this is a synthesized
            success dict; in API mode it is the server's JSON response.
        """
        if self.use_api:
            # API mode: Make HTTP request
            endpoint = f"{self.api_url}/api/v1/add"

            # For API mode, we need to handle file uploads differently
            # For now, we'll assume data is text content
            # NOTE(review): str(data) means file paths / binary inputs are uploaded
            # as their textual representation, not their contents — direct mode and
            # API mode diverge here; confirm intended behavior for non-text data.
            files = {"data": ("data.txt", str(data), "text/plain")}
            form_data = {
                "datasetName": dataset_name,
            }
            if node_set:
                # httpx sends a list value as repeated form fields — presumably the
                # server parses node_set that way; verify against the API handler.
                form_data["node_set"] = node_set

            # Headers built inline (not via _get_headers) so httpx can set the
            # multipart Content-Type boundary itself.
            response = await self.client.post(
                endpoint,
                files=files,
                data=form_data,
                headers={"Authorization": f"Bearer {self.api_token}"} if self.api_token else {},
            )
            response.raise_for_status()
            return response.json()
        else:
            # Direct mode: Call cognee directly
            # Redirect stdout so cognee's prints don't pollute the MCP protocol stream.
            with redirect_stdout(sys.stderr):
                await self.cognee.add(data, dataset_name=dataset_name, node_set=node_set)
            return {"status": "success", "message": "Data added successfully"}

    async def cognify(
        self,
        datasets: Optional[List[str]] = None,
        custom_prompt: Optional[str] = None,
        graph_model: Any = None,
    ) -> Dict[str, Any]:
        """
        Transform data into a knowledge graph.

        Parameters
        ----------
        datasets : List[str], optional
            List of dataset names to process
        custom_prompt : str, optional
            Custom prompt for entity extraction
        graph_model : Any, optional
            Custom graph model (only used in direct mode)

        Returns
        -------
        Dict[str, Any]
            Result of the cognify operation
        """
        if self.use_api:
            # API mode: Make HTTP request
            endpoint = f"{self.api_url}/api/v1/cognify"
            payload = {
                # Fall back to the default dataset when none are specified.
                "datasets": datasets or ["main_dataset"],
                # Block until processing finishes so the caller gets a final result.
                "run_in_background": False,
            }
            if custom_prompt:
                payload["custom_prompt"] = custom_prompt

            response = await self.client.post(endpoint, json=payload, headers=self._get_headers())
            response.raise_for_status()
            return response.json()
        else:
            # Direct mode: Call cognee directly
            # NOTE(review): `datasets` is NOT forwarded in direct mode — cognee.cognify
            # is called without it, so dataset selection only works via the API path.
            # Confirm whether cognee.cognify accepts a datasets argument.
            with redirect_stdout(sys.stderr):
                kwargs = {}
                if custom_prompt:
                    kwargs["custom_prompt"] = custom_prompt
                if graph_model:
                    kwargs["graph_model"] = graph_model

                await self.cognee.cognify(**kwargs)
            return {"status": "success", "message": "Cognify completed successfully"}

    async def search(
        self,
        query_text: str,
        query_type: str,
        datasets: Optional[List[str]] = None,
        system_prompt: Optional[str] = None,
        top_k: int = 10,
    ) -> Any:
        """
        Search the knowledge graph.

        Parameters
        ----------
        query_text : str
            The search query
        query_type : str
            Type of search (e.g., "GRAPH_COMPLETION", "INSIGHTS", etc.);
            case-insensitive, upper-cased before use.
        datasets : List[str], optional
            List of datasets to search
        system_prompt : str, optional
            System prompt for completion searches
        top_k : int
            Maximum number of results

        Returns
        -------
        Any
            Search results (server JSON in API mode; cognee's native result
            objects in direct mode)

        Raises
        ------
        KeyError
            In direct mode, if query_type does not name a SearchType member.
        """
        if self.use_api:
            # API mode: Make HTTP request
            endpoint = f"{self.api_url}/api/v1/search"
            payload = {"query": query_text, "search_type": query_type.upper(), "top_k": top_k}
            if datasets:
                payload["datasets"] = datasets
            if system_prompt:
                payload["system_prompt"] = system_prompt

            response = await self.client.post(endpoint, json=payload, headers=self._get_headers())
            response.raise_for_status()
            return response.json()
        else:
            # Direct mode: Call cognee directly
            # Imported here so API-mode processes never need cognee installed.
            from cognee.modules.search.types import SearchType

            # NOTE(review): datasets, system_prompt, and top_k are accepted but
            # silently dropped in direct mode — only query_type/query_text are
            # forwarded. Confirm whether cognee.search supports these parameters
            # and forward them if so.
            with redirect_stdout(sys.stderr):
                results = await self.cognee.search(
                    query_type=SearchType[query_type.upper()], query_text=query_text
                )
            return results

    async def delete(self, data_id: UUID, dataset_id: UUID, mode: str = "soft") -> Dict[str, Any]:
        """
        Delete data from a dataset.

        Parameters
        ----------
        data_id : UUID
            ID of the data to delete
        dataset_id : UUID
            ID of the dataset containing the data
        mode : str
            Deletion mode ("soft" or "hard")

        Returns
        -------
        Dict[str, Any]
            Result of the deletion
        """
        if self.use_api:
            # API mode: Make HTTP request
            endpoint = f"{self.api_url}/api/v1/delete"
            # UUIDs are not JSON/query-string serializable; stringify explicitly.
            params = {"data_id": str(data_id), "dataset_id": str(dataset_id), "mode": mode}

            response = await self.client.delete(
                endpoint, params=params, headers=self._get_headers()
            )
            response.raise_for_status()
            return response.json()
        else:
            # Direct mode: Call cognee directly
            from cognee.modules.users.methods import get_default_user

            with redirect_stdout(sys.stderr):
                # Direct calls run as the default local user.
                user = await get_default_user()
                result = await self.cognee.delete(
                    data_id=data_id, dataset_id=dataset_id, mode=mode, user=user
                )
            return result

    async def prune_data(self) -> Dict[str, Any]:
        """
        Prune all data from the knowledge graph.

        Returns
        -------
        Dict[str, Any]
            Result of the prune operation

        Raises
        ------
        NotImplementedError
            In API mode; the server exposes no prune endpoint.
        """
        if self.use_api:
            # Note: The API doesn't expose a prune endpoint, so we'll need to handle this
            # For now, raise an error
            raise NotImplementedError("Prune operation is not available via API")
        else:
            # Direct mode: Call cognee directly
            with redirect_stdout(sys.stderr):
                await self.cognee.prune.prune_data()
            return {"status": "success", "message": "Data pruned successfully"}

    async def prune_system(self, metadata: bool = True) -> Dict[str, Any]:
        """
        Prune system data from the knowledge graph.

        Parameters
        ----------
        metadata : bool
            Whether to prune metadata

        Returns
        -------
        Dict[str, Any]
            Result of the prune operation

        Raises
        ------
        NotImplementedError
            In API mode; the server exposes no prune endpoint.
        """
        if self.use_api:
            # Note: The API doesn't expose a prune endpoint
            raise NotImplementedError("Prune system operation is not available via API")
        else:
            # Direct mode: Call cognee directly
            with redirect_stdout(sys.stderr):
                await self.cognee.prune.prune_system(metadata=metadata)
            return {"status": "success", "message": "System pruned successfully"}

    async def get_pipeline_status(self, dataset_ids: List[UUID], pipeline_name: str) -> str:
        """
        Get the status of a pipeline run.

        Parameters
        ----------
        dataset_ids : List[UUID]
            List of dataset IDs
        pipeline_name : str
            Name of the pipeline

        Returns
        -------
        str
            Status information (stringified form of cognee's status object)

        Raises
        ------
        NotImplementedError
            In API mode; no status endpoint is available.
        """
        if self.use_api:
            # Note: This would need a custom endpoint on the API side
            raise NotImplementedError("Pipeline status is not available via API")
        else:
            # Direct mode: Call cognee directly
            from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status

            with redirect_stdout(sys.stderr):
                status = await get_pipeline_status(dataset_ids, pipeline_name)
            return str(status)

    async def list_datasets(self) -> List[Dict[str, Any]]:
        """
        List all datasets.

        Returns
        -------
        List[Dict[str, Any]]
            List of datasets. In direct mode each entry has string-valued
            "id", "name", and "created_at" keys; in API mode the shape is
            whatever the server's /api/v1/datasets endpoint returns.
        """
        if self.use_api:
            # API mode: Make HTTP request
            endpoint = f"{self.api_url}/api/v1/datasets"
            response = await self.client.get(endpoint, headers=self._get_headers())
            response.raise_for_status()
            return response.json()
        else:
            # Direct mode: Call cognee directly
            from cognee.modules.users.methods import get_default_user
            from cognee.modules.data.methods import get_datasets

            with redirect_stdout(sys.stderr):
                # Datasets are scoped to the default local user.
                user = await get_default_user()
                datasets = await get_datasets(user.id)
            # Flatten ORM objects to plain JSON-friendly dicts.
            return [
                {"id": str(d.id), "name": d.name, "created_at": str(d.created_at)}
                for d in datasets
            ]

    async def close(self):
        """Close the HTTP client if in API mode (no-op in direct mode)."""
        # hasattr guard: self.client is only created in API mode.
        if self.use_api and hasattr(self, "client"):
            await self.client.aclose()
|
||||||
30
cognee-mcp/uv.lock
generated
30
cognee-mcp/uv.lock
generated
|
|
@ -1,5 +1,5 @@
|
||||||
version = 1
|
version = 1
|
||||||
revision = 2
|
revision = 3
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
resolution-markers = [
|
resolution-markers = [
|
||||||
"python_full_version >= '3.14' and platform_python_implementation != 'PyPy' and sys_platform != 'emscripten'",
|
"python_full_version >= '3.14' and platform_python_implementation != 'PyPy' and sys_platform != 'emscripten'",
|
||||||
|
|
@ -737,6 +737,7 @@ source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "cognee", extra = ["codegraph", "docs", "gemini", "huggingface", "neo4j", "postgres"] },
|
{ name = "cognee", extra = ["codegraph", "docs", "gemini", "huggingface", "neo4j", "postgres"] },
|
||||||
{ name = "fastmcp" },
|
{ name = "fastmcp" },
|
||||||
|
{ name = "httpx" },
|
||||||
{ name = "mcp" },
|
{ name = "mcp" },
|
||||||
{ name = "uv" },
|
{ name = "uv" },
|
||||||
]
|
]
|
||||||
|
|
@ -750,6 +751,7 @@ dev = [
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "cognee", extras = ["postgres", "codegraph", "gemini", "huggingface", "docs", "neo4j"], specifier = "==0.3.4" },
|
{ name = "cognee", extras = ["postgres", "codegraph", "gemini", "huggingface", "docs", "neo4j"], specifier = "==0.3.4" },
|
||||||
{ name = "fastmcp", specifier = ">=2.10.0,<3.0.0" },
|
{ name = "fastmcp", specifier = ">=2.10.0,<3.0.0" },
|
||||||
|
{ name = "httpx", specifier = ">=0.27.0,<1.0.0" },
|
||||||
{ name = "mcp", specifier = ">=1.12.0,<2.0.0" },
|
{ name = "mcp", specifier = ">=1.12.0,<2.0.0" },
|
||||||
{ name = "uv", specifier = ">=0.6.3,<1.0.0" },
|
{ name = "uv", specifier = ">=0.6.3,<1.0.0" },
|
||||||
]
|
]
|
||||||
|
|
@ -1026,7 +1028,7 @@ version = "3.24.0"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "attrs" },
|
{ name = "attrs" },
|
||||||
{ name = "docstring-parser", marker = "python_full_version < '4.0'" },
|
{ name = "docstring-parser", marker = "python_full_version < '4'" },
|
||||||
{ name = "rich" },
|
{ name = "rich" },
|
||||||
{ name = "rich-rst" },
|
{ name = "rich-rst" },
|
||||||
{ name = "typing-extensions", marker = "python_full_version < '3.11'" },
|
{ name = "typing-extensions", marker = "python_full_version < '3.11'" },
|
||||||
|
|
@ -1309,17 +1311,17 @@ name = "fastembed"
|
||||||
version = "0.6.0"
|
version = "0.6.0"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "huggingface-hub" },
|
{ name = "huggingface-hub", marker = "python_full_version < '3.13'" },
|
||||||
{ name = "loguru" },
|
{ name = "loguru", marker = "python_full_version < '3.13'" },
|
||||||
{ name = "mmh3" },
|
{ name = "mmh3", marker = "python_full_version < '3.13'" },
|
||||||
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||||
{ name = "numpy", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
{ name = "numpy", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and python_full_version < '3.13'" },
|
||||||
{ name = "onnxruntime" },
|
{ name = "onnxruntime", marker = "python_full_version < '3.13'" },
|
||||||
{ name = "pillow" },
|
{ name = "pillow", marker = "python_full_version < '3.13'" },
|
||||||
{ name = "py-rust-stemmers" },
|
{ name = "py-rust-stemmers", marker = "python_full_version < '3.13'" },
|
||||||
{ name = "requests" },
|
{ name = "requests", marker = "python_full_version < '3.13'" },
|
||||||
{ name = "tokenizers" },
|
{ name = "tokenizers", marker = "python_full_version < '3.13'" },
|
||||||
{ name = "tqdm" },
|
{ name = "tqdm", marker = "python_full_version < '3.13'" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/c6/f4/036a656c605f63dc25f11284f60f69900a54a19c513e1ae60d21d6977e75/fastembed-0.6.0.tar.gz", hash = "sha256:5c9ead25f23449535b07243bbe1f370b820dcc77ec2931e61674e3fe7ff24733", size = 50731, upload-time = "2025-02-26T13:50:33.031Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/c6/f4/036a656c605f63dc25f11284f60f69900a54a19c513e1ae60d21d6977e75/fastembed-0.6.0.tar.gz", hash = "sha256:5c9ead25f23449535b07243bbe1f370b820dcc77ec2931e61674e3fe7ff24733", size = 50731, upload-time = "2025-02-26T13:50:33.031Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
|
|
@ -2526,8 +2528,8 @@ name = "loguru"
|
||||||
version = "0.7.3"
|
version = "0.7.3"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
{ name = "colorama", marker = "python_full_version < '3.13' and sys_platform == 'win32'" },
|
||||||
{ name = "win32-setctime", marker = "sys_platform == 'win32'" },
|
{ name = "win32-setctime", marker = "python_full_version < '3.13' and sys_platform == 'win32'" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/3a/05/a1dae3dffd1116099471c643b8924f5aa6524411dc6c63fdae648c4f1aca/loguru-0.7.3.tar.gz", hash = "sha256:19480589e77d47b8d85b2c827ad95d49bf31b0dcde16593892eb51dd18706eb6", size = 63559, upload-time = "2024-12-06T11:20:56.608Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/3a/05/a1dae3dffd1116099471c643b8924f5aa6524411dc6c63fdae648c4f1aca/loguru-0.7.3.tar.gz", hash = "sha256:19480589e77d47b8d85b2c827ad95d49bf31b0dcde16593892eb51dd18706eb6", size = 63559, upload-time = "2024-12-06T11:20:56.608Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue