feat: enhance episode processing and node search with new entity types and update dependencies (#304)

* add custom entities

* improve prompts

* ruff fixes

* models + queue

* moar

Author: Daniel Chalef · 2025-03-26 23:10:09 -07:00 · committed by GitHub
parent 04203506d9
commit b73ca24cfb
5 changed files with 285 additions and 92 deletions

mcp_server/README.md

@@ -68,6 +68,7 @@ Available arguments:
- `--transport`: Choose the transport method (sse or stdio, default: sse)
- `--group-id`: Set a namespace for the graph (optional)
- `--destroy-graph`: Destroy all Graphiti graphs (use with caution)
- `--use-custom-entities`: Enable entity extraction using the predefined ENTITY_TYPES
### Docker Deployment
@@ -249,11 +250,41 @@ Assistant: I'll search for node summaries related to the company.
[Assistant uses the search_nodes tool to find relevant entity summaries]
```
## Integrating with the Cursor IDE
To integrate the Graphiti MCP Server with the Cursor IDE, follow these steps:
1. Run the Graphiti MCP server using the SSE transport:
```bash
python graphiti_mcp_server.py --transport sse --use-custom-entities --group-id <your_group_id>
```
Hint: specify a `group_id` to retain prior graph data. If you do not specify a `group_id`, the server generates a random one and creates a new graph.
2. Configure Cursor to connect to the Graphiti MCP server.
```json
{
"mcpServers": {
"Graphiti": {
"url": "http://localhost:8000/sse"
}
}
}
```
3. Add the Graphiti rules to Cursor's User Rules. See [cursor_rules.md](cursor_rules.md) for details.
4. Kick off an agent session in Cursor.
The integration enables AI assistants in Cursor to maintain persistent memory through Graphiti's knowledge graph capabilities.
## Requirements
- Python 3.10 or higher
- Neo4j database (version 5.26 or later required)
- OpenAI API key (for LLM operations)
- OpenAI API key (for LLM operations and embeddings)
- MCP-compatible client
## License

mcp_server/cursor_rules.md

@@ -0,0 +1,34 @@
## Instructions for Using Graphiti's MCP Tools for Agent Memory
### Before Starting Any Task
- **Always search first:** Use the `search_nodes` tool to look for relevant preferences and procedures before beginning work.
- **Search for facts too:** Use the `search_facts` tool to discover relationships and factual information that may be relevant to your task.
- **Filter by entity type:** Specify `Preference`, `Procedure`, or `Requirement` in your node search to get targeted results.
- **Review all matches:** Carefully examine any preferences, procedures, or facts that match your current task.
### Always Save New or Updated Information
- **Capture requirements and preferences immediately:** When a user expresses a requirement or preference, use `add_episode` to store it right away.
- _Best practice:_ Split very long requirements into shorter, logical chunks.
- **Be explicit if something is an update to existing knowledge.** Only add what's changed or new to the graph.
- **Document procedures clearly:** When you discover how a user wants things done, record it as a procedure.
- **Record factual relationships:** When you learn about connections between entities, store these as facts.
- **Be specific with categories:** Label preferences and procedures with clear categories for better retrieval later.
### During Your Work
- **Respect discovered preferences:** Align your work with any preferences you've found.
- **Follow procedures exactly:** If you find a procedure for your current task, follow it step by step.
- **Apply relevant facts:** Use factual information to inform your decisions and recommendations.
- **Stay consistent:** Maintain consistency with previously identified preferences, procedures, and facts.
### Best Practices
- **Search before suggesting:** Always check if there's established knowledge before making recommendations.
- **Combine node and fact searches:** For complex tasks, search both nodes and facts to build a complete picture.
- **Use `center_node_uuid`:** When exploring related information, center your search around a specific node.
- **Prioritize specific matches:** More specific information takes precedence over general information.
- **Be proactive:** If you notice patterns in user behavior, consider storing them as preferences or procedures.
**Remember:** The knowledge graph is your memory. Use it consistently to provide personalized assistance that respects the user's established preferences, procedures, and factual context.
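A minimal sketch of this search-then-save loop, written as direct calls to the MCP tools defined in `graphiti_mcp_server.py` (the queries, episode content, and the `search_facts` parameters are illustrative assumptions, not prescribed by these rules):
```python
# Illustrative only: one agent turn against the Graphiti MCP tools.
# 1. Search first: check for stored preferences relevant to the task.
prefs = await search_nodes(
    query='code formatting preferences',
    entity='Preference',  # single-type filter; 'Procedure' is also permitted
    max_nodes=5,
)

# 2. Search for facts too (parameter name assumed to mirror search_nodes).
facts = await search_facts(query='project formatting requirements')

# 3. Capture a newly expressed preference immediately.
await add_episode(
    name='Preference: indentation',
    episode_body='User prefers 4-space indentation over tabs.',
    source='text',
    source_description='conversation',
)
```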

mcp_server/graphiti_mcp_server.py

@@ -9,27 +9,110 @@ import logging
import os
import sys
import uuid
from collections.abc import Awaitable, Callable
from datetime import datetime, timezone
from typing import Any, Optional, TypedDict, TypeVar, Union, cast
from typing import Any, Optional, TypedDict, Union, cast
from dotenv import load_dotenv
from mcp.server.fastmcp import FastMCP
from pydantic import BaseModel
from pydantic import BaseModel, Field
# graphiti_core imports
from graphiti_core import Graphiti
from graphiti_core.edges import EntityEdge
from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_client import OpenAIClient
from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
from graphiti_core.search.search_config_recipes import (
NODE_HYBRID_SEARCH_NODE_DISTANCE,
NODE_HYBRID_SEARCH_RRF,
)
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
load_dotenv()
DEFAULT_LLM_MODEL = 'gpt-4o'
class Requirement(BaseModel):
"""A Requirement represents a specific need, feature, or functionality that a product or service must fulfill.
Always ensure an edge is created between the requirement and the project it belongs to, and clearly indicate on the
edge that the requirement is a requirement.
Instructions for identifying and extracting requirements:
1. Look for explicit statements of needs or necessities ("We need X", "X is required", "X must have Y")
2. Identify functional specifications that describe what the system should do
3. Pay attention to non-functional requirements like performance, security, or usability criteria
4. Extract constraints or limitations that must be adhered to
5. Focus on clear, specific, and measurable requirements rather than vague wishes
6. Capture the priority or importance if mentioned ("critical", "high priority", etc.)
7. Include any dependencies between requirements when explicitly stated
8. Preserve the original intent and scope of the requirement
9. Categorize requirements appropriately based on their domain or function
"""
project_name: str = Field(
...,
description='The name of the project to which the requirement belongs.',
)
description: str = Field(
...,
description='Description of the requirement. Only use information mentioned in the context to write this description.',
)
class Preference(BaseModel):
"""A Preference represents a user's expressed like, dislike, or preference for something.
Instructions for identifying and extracting preferences:
1. Look for explicit statements of preference such as "I like/love/enjoy/prefer X" or "I don't like/hate/dislike X"
2. Pay attention to comparative statements ("I prefer X over Y")
3. Consider the emotional tone when users mention certain topics
4. Extract only preferences that are clearly expressed, not assumptions
5. Categorize the preference appropriately based on its domain (food, music, brands, etc.)
6. Include relevant qualifiers (e.g., "likes spicy food" rather than just "likes food")
7. Only extract preferences directly stated by the user, not preferences of others they mention
8. Provide a concise but specific description that captures the nature of the preference
"""
category: str = Field(
...,
description="The category of the preference. (e.g., 'Brands', 'Food', 'Music')",
)
description: str = Field(
...,
description='Brief description of the preference. Only use information mentioned in the context to write this description.',
)
class Procedure(BaseModel):
"""A Procedure informing the agent what actions to take or how to perform in certain scenarios. Procedures are typically composed of several steps.
Instructions for identifying and extracting procedures:
1. Look for sequential instructions or steps ("First do X, then do Y")
2. Identify explicit directives or commands ("Always do X when Y happens")
3. Pay attention to conditional statements ("If X occurs, then do Y")
4. Extract procedures that have clear beginning and end points
5. Focus on actionable instructions rather than general information
6. Preserve the original sequence and dependencies between steps
7. Include any specified conditions or triggers for the procedure
8. Capture any stated purpose or goal of the procedure
9. Summarize complex procedures while maintaining critical details
"""
description: str = Field(
...,
description='Brief description of the procedure. Only use information mentioned in the context to write this description.',
)
ENTITY_TYPES: dict[str, BaseModel] = {
'Requirement': Requirement, # type: ignore
'Preference': Preference, # type: ignore
'Procedure': Procedure, # type: ignore
}
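# --- Editor's sketch (not part of this commit) ---
# The registry above generalizes: each entry pairs a label with a Pydantic
# model whose docstring doubles as the extraction prompt. A hypothetical
# fourth type would follow the same shape:
#
#     class Tool(BaseModel):
#         """A Tool represents a software tool, library, or service the user works with.
#
#         Extract only tools explicitly named in the context.
#         """
#
#         name: str = Field(..., description='The name of the tool as mentioned in the context.')
#         description: str = Field(
#             ...,
#             description='Brief description of the tool. Only use information mentioned in the context.',
#         )
#
#     ENTITY_TYPES['Tool'] = Tool  # type: ignore
# --- end sketch ---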
# Type definitions for API responses
class ErrorResponse(TypedDict):
@@ -85,6 +168,7 @@ class GraphitiConfig(BaseModel):
openai_base_url: Optional[str] = None
model_name: Optional[str] = None
group_id: Optional[str] = None
use_custom_entities: bool = False
@classmethod
def from_env(cls) -> 'GraphitiConfig':
@@ -158,14 +242,6 @@ mcp = FastMCP(
# Initialize Graphiti client
graphiti_client: Optional[Graphiti] = None
# Type for functions that can be wrapped with graphiti_error_handler
T = TypeVar('T')
GraphitiFunc = Callable[..., Awaitable[T]]
# Note: We've removed the error handler decorator in favor of inline error handling
# This is to avoid type checking issues with the global graphiti_client variable
async def initialize_graphiti(llm_client: Optional[LLMClient] = None, destroy_graph: bool = False):
"""Initialize the Graphiti client with the provided settings.
@@ -185,7 +261,6 @@ async def initialize_graphiti(llm_client: Optional[LLMClient] = None, destroy_gr
if config.model_name:
llm_config.model = config.model_name
llm_client = OpenAIClient(config=llm_config)
logger.info('Using OpenAI as LLM client')
else:
raise ValueError('OPENAI_API_KEY must be set when not using a custom LLM client')
@@ -219,16 +294,55 @@ def format_fact_result(edge: EntityEdge) -> dict[str, Any]:
Returns:
A dictionary representation of the edge with serialized dates and excluded embeddings
"""
# Convert to dict using Pydantic's model_dump method with mode='json'
# This automatically handles datetime serialization and other complex types
return edge.model_dump(
mode='json', # Properly handle datetime serialization for JSON
mode='json',
exclude={
'fact_embedding', # Exclude embedding data
'fact_embedding',
},
)
# Dictionary to store queues for each group_id
# Each queue is a list of tasks to be processed sequentially
episode_queues: dict[str, asyncio.Queue] = {}
# Dictionary to track if a worker is running for each group_id
queue_workers: dict[str, bool] = {}
async def process_episode_queue(group_id: str):
"""Process episodes for a specific group_id sequentially.
This function runs as a long-lived task that processes episodes
from the queue one at a time.
"""
global queue_workers
logger.info(f'Starting episode queue worker for group_id: {group_id}')
queue_workers[group_id] = True
try:
while True:
# Get the next episode processing function from the queue
# This will wait if the queue is empty
process_func = await episode_queues[group_id].get()
try:
# Process the episode
await process_func()
except Exception as e:
logger.error(f'Error processing queued episode for group_id {group_id}: {str(e)}')
finally:
# Mark the task as done regardless of success/failure
episode_queues[group_id].task_done()
except asyncio.CancelledError:
logger.info(f'Episode queue worker for group_id {group_id} was cancelled')
except Exception as e:
logger.error(f'Unexpected error in queue worker for group_id {group_id}: {str(e)}')
finally:
queue_workers[group_id] = False
logger.info(f'Stopped episode queue worker for group_id: {group_id}')
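# --- Editor's sketch (not part of this commit) ---
# The worker above is standard asyncio fan-in: one long-lived task per
# group_id drains a queue of zero-argument coroutine functions, so episodes
# within a group never run concurrently. The same pattern in isolation,
# using the module's own dictionaries:
#
#     async def submit(key: str, job) -> None:
#         if key not in episode_queues:
#             episode_queues[key] = asyncio.Queue()
#         await episode_queues[key].put(job)  # job is an async def () -> None
#         if not queue_workers.get(key, False):
#             asyncio.create_task(process_episode_queue(key))
# --- end sketch ---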
@mcp.tool()
async def add_episode(
name: str,
@@ -240,6 +354,9 @@ async def add_episode(
) -> Union[SuccessResponse, ErrorResponse]:
"""Add an episode to the Graphiti knowledge graph. This is the primary way to add information to the graph.
This function returns immediately and processes the episode addition in the background.
Episodes for the same group_id are processed sequentially to avoid race conditions.
Args:
name (str): Name of the episode
episode_body (str): The content of the episode. When source='json', this must be a properly escaped JSON string,
@@ -265,7 +382,7 @@ async def add_episode(
)
# Adding structured JSON data
# NOTE: episode_body must be a properly escaped JSON string
# NOTE: episode_body must be a properly escaped JSON string. Note the triple backslashes
add_episode(
name="Customer Profile",
episode_body="{\\\"company\\\": {\\\"name\\\": \\\"Acme Technologies\\\"}, \\\"products\\\": [{\\\"id\\\": \\\"P001\\\", \\\"name\\\": \\\"CloudSync\\\"}, {\\\"id\\\": \\\"P002\\\", \\\"name\\\": \\\"DataMiner\\\"}]}",
@@ -273,14 +390,6 @@ async def add_episode(
source_description="CRM data"
)
# Adding more complex JSON with arrays and nested objects
add_episode(
name="Product Catalog",
episode_body='{"catalog": {"company": "Tech Solutions Inc.", "products": [{"id": "P001", "name": "Product X", "features": ["Feature A", "Feature B"]}]}}',
source="json",
source_description="Product catalog"
)
# Adding message-style content
add_episode(
name="Customer Conversation",
@@ -298,7 +407,7 @@ async def add_episode(
- Entities will be created from appropriate JSON properties
- Relationships between entities will be established based on the JSON structure
"""
global graphiti_client
global graphiti_client, episode_queues, queue_workers
if graphiti_client is None:
return {'error': 'Graphiti client not initialized'}
@@ -320,29 +429,59 @@ async def add_episode(
# We've already checked that graphiti_client is not None above
# This assert statement helps type checkers understand that graphiti_client is defined
# from this point forward in the function
assert graphiti_client is not None, 'graphiti_client should not be None here'
# Use cast to help the type checker understand that graphiti_client is not None
# This doesn't change the runtime behavior, only helps with static type checking
client = cast(Graphiti, graphiti_client)
# Type checking will now know that client is a Graphiti instance (not None)
# and group_id is a str, not Optional[str]
await client.add_episode(
name=name,
episode_body=episode_body,
source=source_type,
source_description=source_description,
group_id=group_id_str, # Using the string version of group_id
uuid=uuid,
reference_time=datetime.now(timezone.utc),
)
return {'message': f"Episode '{name}' added successfully"}
# Define the episode processing function
async def process_episode():
try:
logger.info(f"Processing queued episode '{name}' for group_id: {group_id_str}")
# Use all entity types if use_custom_entities is enabled, otherwise use empty dict
entity_types = ENTITY_TYPES if config.use_custom_entities else {}
await client.add_episode(
name=name,
episode_body=episode_body,
source=source_type,
source_description=source_description,
group_id=group_id_str, # Using the string version of group_id
uuid=uuid,
reference_time=datetime.now(timezone.utc),
entity_types=entity_types,
)
logger.info(f"Episode '{name}' added successfully")
logger.info(f"Building communities after episode '{name}'")
await client.build_communities()
logger.info(f"Episode '{name}' processed successfully")
except Exception as e:
error_msg = str(e)
logger.error(
f"Error processing episode '{name}' for group_id {group_id_str}: {error_msg}"
)
# Initialize queue for this group_id if it doesn't exist
if group_id_str not in episode_queues:
episode_queues[group_id_str] = asyncio.Queue()
# Add the episode processing function to the queue
await episode_queues[group_id_str].put(process_episode)
# Start a worker for this queue if one isn't already running
if not queue_workers.get(group_id_str, False):
asyncio.create_task(process_episode_queue(group_id_str))
# Return immediately with a success message
return {
'message': f"Episode '{name}' queued for processing (position: {episode_queues[group_id_str].qsize()})"
}
except Exception as e:
error_msg = str(e)
logger.error(f'Error adding episode: {error_msg}')
return {'error': f'Error adding episode: {error_msg}'}
logger.error(f'Error queuing episode task: {error_msg}')
return {'error': f'Error queuing episode task: {error_msg}'}
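# --- Editor's sketch (not part of this commit) ---
# From the caller's side, the tool now acknowledges before extraction runs;
# the response wording is taken from the code above, the arguments are made up:
#
#     response = await add_episode(
#         name='Standup notes',
#         episode_body='Alice will own the Neo4j upgrade.',
#         source='text',
#         source_description='meeting transcript',
#     )
#     # -> {'message': "Episode 'Standup notes' queued for processing (position: 1)"}
# --- end sketch ---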
@mcp.tool()
@@ -351,15 +490,19 @@ async def search_nodes(
group_ids: Optional[list[str]] = None,
max_nodes: int = 10,
center_node_uuid: Optional[str] = None,
entity: str = '', # cursor seems to break with None
) -> Union[NodeSearchResponse, ErrorResponse]:
"""Search the Graphiti knowledge graph for relevant node summaries.
These contain a summary of all of a node's relationships with other nodes.
Note: entity is a single entity type to filter results (permitted: "Preference", "Procedure").
Args:
query: The search query
group_ids: Optional list of group IDs to filter results
max_nodes: Maximum number of nodes to return (default: 10)
center_node_uuid: Optional UUID of a node to center the search around
entity: Optional single entity type to filter results (permitted: "Preference", "Procedure")
"""
global graphiti_client
@@ -373,9 +516,16 @@ async def search_nodes(
)
# Configure the search
search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
if center_node_uuid is not None:
search_config = NODE_HYBRID_SEARCH_NODE_DISTANCE.model_copy(deep=True)
else:
search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
search_config.limit = max_nodes
filters = SearchFilters()
if entity != '':
filters.node_labels = [entity]
# We've already checked that graphiti_client is not None above
assert graphiti_client is not None
@@ -388,7 +538,7 @@ async def search_nodes(
config=search_config,
group_ids=effective_group_ids,
center_node_uuid=center_node_uuid,
search_filter=SearchFilters(),
search_filter=filters,
)
if not search_results.nodes:
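# --- Editor's sketch (not part of this commit) ---
# Two details above are easy to miss: passing center_node_uuid swaps the
# recipe from RRF reranking to node-distance reranking, and a non-empty
# entity narrows results via SearchFilters(node_labels=[entity]). Illustrative
# calls (the response shape is assumed from NodeSearchResponse):
#
#     prefs = await search_nodes(query='UI conventions', entity='Preference')
#     nearby = await search_nodes(
#         query='UI conventions',
#         center_node_uuid=prefs['nodes'][0]['uuid'],
#         max_nodes=5,
#     )
# --- end sketch ---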
@@ -658,13 +808,10 @@ async def get_status() -> StatusResponse:
}
def create_llm_client(
client_type: str = 'openai', api_key: Optional[str] = None, model: Optional[str] = None
) -> LLMClient:
def create_llm_client(api_key: Optional[str] = None, model: Optional[str] = None) -> LLMClient:
"""Create an OpenAI LLM client.
Args:
client_type: Type of LLM client to create (only 'openai' or 'openai_generic' supported)
api_key: API key for the OpenAI service
model: Model name to use
@@ -703,6 +850,11 @@ async def initialize_server() -> MCPConfig:
# OpenAI is the only supported LLM client
parser.add_argument('--model', help='Model name to use with the LLM client')
parser.add_argument('--destroy-graph', action='store_true', help='Destroy all Graphiti graphs')
parser.add_argument(
'--use-custom-entities',
action='store_true',
help='Enable entity extraction using the predefined ENTITY_TYPES',
)
args = parser.parse_args()
@@ -714,18 +866,23 @@ async def initialize_server() -> MCPConfig:
config.group_id = f'graph_{uuid.uuid4().hex[:8]}'
logger.info(f'Generated random group_id: {config.group_id}')
# Set use_custom_entities flag if specified
if args.use_custom_entities:
config.use_custom_entities = True
logger.info('Entity extraction enabled using predefined ENTITY_TYPES')
else:
logger.info('Entity extraction disabled (no custom entities will be used)')
llm_client = None
# Create OpenAI client if model is specified or if OPENAI_API_KEY is available
if args.model or config.openai_api_key:
# Override model from command line if specified
if args.model:
config.model_name = args.model
config.model_name = args.model or DEFAULT_LLM_MODEL
# Create the OpenAI client
llm_client = create_llm_client(
client_type='openai', api_key=config.openai_api_key, model=config.model_name
)
llm_client = create_llm_client(api_key=config.openai_api_key, model=config.model_name)
# Initialize Graphiti with the specified LLM client
await initialize_graphiti(llm_client, destroy_graph=args.destroy_graph)

mcp_server/pyproject.toml

@@ -7,5 +7,5 @@ requires-python = ">=3.10"
dependencies = [
"mcp>=1.5.0",
"openai>=1.68.2",
"graphiti-core>=0.8.1",
"graphiti-core>=0.8.2",
]

mcp_server/uv.lock (generated)

@@ -10,24 +10,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 },
]
[[package]]
name = "anthropic"
version = "0.49.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "distro" },
{ name = "httpx" },
{ name = "jiter" },
{ name = "pydantic" },
{ name = "sniffio" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/86/e3/a88c8494ce4d1a88252b9e053607e885f9b14d0a32273d47b727cbee4228/anthropic-0.49.0.tar.gz", hash = "sha256:c09e885b0f674b9119b4f296d8508907f6cff0009bc20d5cf6b35936c40b4398", size = 210016 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/76/74/5d90ad14d55fbe3f9c474fdcb6e34b4bed99e3be8efac98734a5ddce88c1/anthropic-0.49.0-py3-none-any.whl", hash = "sha256:bbc17ad4e7094988d2fa86b87753ded8dce12498f4b85fe5810f208f454a8375", size = 243368 },
]
[[package]]
name = "anyio"
version = "4.9.0"
@@ -102,12 +84,10 @@ wheels = [
[[package]]
name = "graphiti-core"
version = "0.8.1"
source = { directory = "../" }
version = "0.8.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anthropic" },
{ name = "diskcache" },
{ name = "mcp" },
{ name = "neo4j" },
{ name = "numpy" },
{ name = "openai" },
@@ -115,18 +95,9 @@ dependencies = [
{ name = "python-dotenv" },
{ name = "tenacity" },
]
[package.metadata]
requires-dist = [
{ name = "anthropic", specifier = ">=0.49.0,<0.50.0" },
{ name = "diskcache", specifier = ">=5.6.3,<6.0.0" },
{ name = "mcp", specifier = ">=1.5.0,<2.0.0" },
{ name = "neo4j", specifier = ">=5.23.0,<6.0.0" },
{ name = "numpy", specifier = ">=1.0.0" },
{ name = "openai", specifier = ">=1.53.0,<2.0.0" },
{ name = "pydantic", specifier = ">=2.8.2,<3.0.0" },
{ name = "python-dotenv", specifier = ">=1.0.1,<2.0.0" },
{ name = "tenacity", specifier = "==9.0.0" },
sdist = { url = "https://files.pythonhosted.org/packages/b7/29/f0d74adc687514a226a164340a7fa9254f9cffa2ff7ab353fc1edf4016cb/graphiti_core-0.8.2.tar.gz", hash = "sha256:477a9172728b92ba1875222f8d9c6337d9fa465363b75afd67db9a159c39d379", size = 62464 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/19/4b/d2065544beb24402bb41197b2adf7ed7af287453c82290d1ca757532f6d3/graphiti_core-0.8.2-py3-none-any.whl", hash = "sha256:ce992f9f3dee168d3ca442e0d0fd685a1a66dee4df94c47909037295aceac45e", size = 95052 },
]
[[package]]
@@ -274,7 +245,7 @@ dependencies = [
[package.metadata]
requires-dist = [
{ name = "graphiti-core", directory = "../" },
{ name = "graphiti-core", specifier = ">=0.8.2" },
{ name = "mcp", specifier = ">=1.5.0" },
{ name = "openai", specifier = ">=1.68.2" },
]