From b73ca24cfb42dd04b220cd61c98bd009d5dc779f Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Wed, 26 Mar 2025 23:10:09 -0700 Subject: [PATCH] feat: enhance episode processing and node search with new entity types and update dependencies (#304) * add custom entities * improve prompts * ruff fixes * models + queue * moar --- mcp_server/README.md | 33 +++- mcp_server/cursor_rules.md | 34 ++++ mcp_server/graphiti_mcp_server.py | 267 ++++++++++++++++++++++++------ mcp_server/pyproject.toml | 2 +- mcp_server/uv.lock | 41 +---- 5 files changed, 285 insertions(+), 92 deletions(-) create mode 100644 mcp_server/cursor_rules.md diff --git a/mcp_server/README.md b/mcp_server/README.md index 9bed8dde..515d3efe 100644 --- a/mcp_server/README.md +++ b/mcp_server/README.md @@ -68,6 +68,7 @@ Available arguments: - `--transport`: Choose the transport method (sse or stdio, default: sse) - `--group-id`: Set a namespace for the graph (optional) - `--destroy-graph`: Destroy all Graphiti graphs (use with caution) +- `--use-custom-entities`: Enable entity extraction using the predefined ENTITY_TYPES ### Docker Deployment @@ -249,11 +250,41 @@ Assistant: I'll search for node summaries related to the company. [Assistant uses the search_nodes tool to find relevant entity summaries] ``` +## Integrating with the Cursor IDE + +To integrate the Graphiti MCP Server with the Cursor IDE, follow these steps: + +1. Run the Graphiti MCP server using the SSE transport: + +```bash +python graphiti_mcp_server.py --transport sse --use-custom-entities --group-id <your_group_id> +``` + +Hint: specify a `group_id` to retain prior graph data. If you do not specify a `group_id`, the server will create a new graph. + +2. Configure Cursor to connect to the Graphiti MCP server. + +```json +{ + "mcpServers": { + "Graphiti": { + "url": "http://localhost:8000/sse" + } + } +} +``` + +3. Add the Graphiti rules to Cursor's User Rules. See [cursor_rules.md](cursor_rules.md) for details. 
+ +4. Kick off an agent session in Cursor. + +The integration enables AI assistants in Cursor to maintain persistent memory through Graphiti's knowledge graph capabilities. + ## Requirements - Python 3.10 or higher - Neo4j database (version 5.26 or later required) -- OpenAI API key (for LLM operations) +- OpenAI API key (for LLM operations and embeddings) - MCP-compatible client ## License diff --git a/mcp_server/cursor_rules.md b/mcp_server/cursor_rules.md new file mode 100644 index 00000000..1d296c21 --- /dev/null +++ b/mcp_server/cursor_rules.md @@ -0,0 +1,34 @@ +## Instructions for Using Graphiti's MCP Tools for Agent Memory + +### Before Starting Any Task + +- **Always search first:** Use the `search_nodes` tool to look for relevant preferences and procedures before beginning work. +- **Search for facts too:** Use the `search_facts` tool to discover relationships and factual information that may be relevant to your task. +- **Filter by entity type:** Specify `Preference`, `Procedure`, or `Requirement` in your node search to get targeted results. +- **Review all matches:** Carefully examine any preferences, procedures, or facts that match your current task. + +### Always Save New or Updated Information + +- **Capture requirements and preferences immediately:** When a user expresses a requirement or preference, use `add_episode` to store it right away. + - _Best practice:_ Split very long requirements into shorter, logical chunks. +- **Be explicit if something is an update to existing knowledge.** Only add what's changed or new to the graph. +- **Document procedures clearly:** When you discover how a user wants things done, record it as a procedure. +- **Record factual relationships:** When you learn about connections between entities, store these as facts. +- **Be specific with categories:** Label preferences and procedures with clear categories for better retrieval later. 
+ +### During Your Work + +- **Respect discovered preferences:** Align your work with any preferences you've found. +- **Follow procedures exactly:** If you find a procedure for your current task, follow it step by step. +- **Apply relevant facts:** Use factual information to inform your decisions and recommendations. +- **Stay consistent:** Maintain consistency with previously identified preferences, procedures, and facts. + +### Best Practices + +- **Search before suggesting:** Always check if there's established knowledge before making recommendations. +- **Combine node and fact searches:** For complex tasks, search both nodes and facts to build a complete picture. +- **Use `center_node_uuid`:** When exploring related information, center your search around a specific node. +- **Prioritize specific matches:** More specific information takes precedence over general information. +- **Be proactive:** If you notice patterns in user behavior, consider storing them as preferences or procedures. + +**Remember:** The knowledge graph is your memory. Use it consistently to provide personalized assistance that respects the user's established preferences, procedures, and factual context. 
diff --git a/mcp_server/graphiti_mcp_server.py b/mcp_server/graphiti_mcp_server.py index 38d424c7..d20e3522 100644 --- a/mcp_server/graphiti_mcp_server.py +++ b/mcp_server/graphiti_mcp_server.py @@ -9,27 +9,110 @@ import logging import os import sys import uuid -from collections.abc import Awaitable, Callable from datetime import datetime, timezone -from typing import Any, Optional, TypedDict, TypeVar, Union, cast +from typing import Any, Optional, TypedDict, Union, cast from dotenv import load_dotenv from mcp.server.fastmcp import FastMCP -from pydantic import BaseModel +from pydantic import BaseModel, Field -# graphiti_core imports from graphiti_core import Graphiti from graphiti_core.edges import EntityEdge from graphiti_core.llm_client import LLMClient from graphiti_core.llm_client.config import LLMConfig from graphiti_core.llm_client.openai_client import OpenAIClient from graphiti_core.nodes import EpisodeType, EpisodicNode -from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF +from graphiti_core.search.search_config_recipes import ( + NODE_HYBRID_SEARCH_NODE_DISTANCE, + NODE_HYBRID_SEARCH_RRF, +) from graphiti_core.search.search_filters import SearchFilters from graphiti_core.utils.maintenance.graph_data_operations import clear_data load_dotenv() +DEFAULT_LLM_MODEL = 'gpt-4o' + + +class Requirement(BaseModel): + """A Requirement represents a specific need, feature, or functionality that a product or service must fulfill. + + Always ensure an edge is created between the requirement and the project it belongs to, and clearly indicate on the + edge that the requirement is a requirement. + + Instructions for identifying and extracting requirements: + 1. Look for explicit statements of needs or necessities ("We need X", "X is required", "X must have Y") + 2. Identify functional specifications that describe what the system should do + 3. Pay attention to non-functional requirements like performance, security, or usability criteria + 4. 
Extract constraints or limitations that must be adhered to + 5. Focus on clear, specific, and measurable requirements rather than vague wishes + 6. Capture the priority or importance if mentioned ("critical", "high priority", etc.) + 7. Include any dependencies between requirements when explicitly stated + 8. Preserve the original intent and scope of the requirement + 9. Categorize requirements appropriately based on their domain or function + """ + + project_name: str = Field( + ..., + description='The name of the project to which the requirement belongs.', + ) + description: str = Field( + ..., + description='Description of the requirement. Only use information mentioned in the context to write this description.', + ) + + +class Preference(BaseModel): + """A Preference represents a user's expressed like, dislike, or preference for something. + + Instructions for identifying and extracting preferences: + 1. Look for explicit statements of preference such as "I like/love/enjoy/prefer X" or "I don't like/hate/dislike X" + 2. Pay attention to comparative statements ("I prefer X over Y") + 3. Consider the emotional tone when users mention certain topics + 4. Extract only preferences that are clearly expressed, not assumptions + 5. Categorize the preference appropriately based on its domain (food, music, brands, etc.) + 6. Include relevant qualifiers (e.g., "likes spicy food" rather than just "likes food") + 7. Only extract preferences directly stated by the user, not preferences of others they mention + 8. Provide a concise but specific description that captures the nature of the preference + """ + + category: str = Field( + ..., + description="The category of the preference. (e.g., 'Brands', 'Food', 'Music')", + ) + description: str = Field( + ..., + description='Brief description of the preference. 
Only use information mentioned in the context to write this description.', + ) + + +class Procedure(BaseModel): + """A Procedure informing the agent what actions to take or how to perform in certain scenarios. Procedures are typically composed of several steps. + + Instructions for identifying and extracting procedures: + 1. Look for sequential instructions or steps ("First do X, then do Y") + 2. Identify explicit directives or commands ("Always do X when Y happens") + 3. Pay attention to conditional statements ("If X occurs, then do Y") + 4. Extract procedures that have clear beginning and end points + 5. Focus on actionable instructions rather than general information + 6. Preserve the original sequence and dependencies between steps + 7. Include any specified conditions or triggers for the procedure + 8. Capture any stated purpose or goal of the procedure + 9. Summarize complex procedures while maintaining critical details + """ + + description: str = Field( + ..., + description='Brief description of the procedure. 
Only use information mentioned in the context to write this description.', + ) + + +ENTITY_TYPES: dict[str, BaseModel] = { + 'Requirement': Requirement, # type: ignore + 'Preference': Preference, # type: ignore + 'Procedure': Procedure, # type: ignore +} + # Type definitions for API responses class ErrorResponse(TypedDict): @@ -85,6 +168,7 @@ class GraphitiConfig(BaseModel): openai_base_url: Optional[str] = None model_name: Optional[str] = None group_id: Optional[str] = None + use_custom_entities: bool = False @classmethod def from_env(cls) -> 'GraphitiConfig': @@ -158,14 +242,6 @@ mcp = FastMCP( # Initialize Graphiti client graphiti_client: Optional[Graphiti] = None -# Type for functions that can be wrapped with graphiti_error_handler -T = TypeVar('T') -GraphitiFunc = Callable[..., Awaitable[T]] - - -# Note: We've removed the error handler decorator in favor of inline error handling -# This is to avoid type checking issues with the global graphiti_client variable - async def initialize_graphiti(llm_client: Optional[LLMClient] = None, destroy_graph: bool = False): """Initialize the Graphiti client with the provided settings. 
@@ -185,7 +261,6 @@ async def initialize_graphiti(llm_client: Optional[LLMClient] = None, destroy_gr if config.model_name: llm_config.model = config.model_name llm_client = OpenAIClient(config=llm_config) - logger.info('Using OpenAI as LLM client') else: raise ValueError('OPENAI_API_KEY must be set when not using a custom LLM client') @@ -219,16 +294,55 @@ def format_fact_result(edge: EntityEdge) -> dict[str, Any]: Returns: A dictionary representation of the edge with serialized dates and excluded embeddings """ - # Convert to dict using Pydantic's model_dump method with mode='json' - # This automatically handles datetime serialization and other complex types return edge.model_dump( - mode='json', # Properly handle datetime serialization for JSON + mode='json', exclude={ - 'fact_embedding', # Exclude embedding data + 'fact_embedding', }, ) +# Dictionary to store queues for each group_id +# Each queue is a list of tasks to be processed sequentially +episode_queues: dict[str, asyncio.Queue] = {} +# Dictionary to track if a worker is running for each group_id +queue_workers: dict[str, bool] = {} + + +async def process_episode_queue(group_id: str): + """Process episodes for a specific group_id sequentially. + + This function runs as a long-lived task that processes episodes + from the queue one at a time. 
+ """ + global queue_workers + + logger.info(f'Starting episode queue worker for group_id: {group_id}') + queue_workers[group_id] = True + + try: + while True: + # Get the next episode processing function from the queue + # This will wait if the queue is empty + process_func = await episode_queues[group_id].get() + + try: + # Process the episode + await process_func() + except Exception as e: + logger.error(f'Error processing queued episode for group_id {group_id}: {str(e)}') + finally: + # Mark the task as done regardless of success/failure + episode_queues[group_id].task_done() + except asyncio.CancelledError: + logger.info(f'Episode queue worker for group_id {group_id} was cancelled') + except Exception as e: + logger.error(f'Unexpected error in queue worker for group_id {group_id}: {str(e)}') + finally: + queue_workers[group_id] = False + logger.info(f'Stopped episode queue worker for group_id: {group_id}') + + @mcp.tool() async def add_episode( name: str, @@ -240,6 +354,9 @@ async def add_episode( ) -> Union[SuccessResponse, ErrorResponse]: """Add an episode to the Graphiti knowledge graph. This is the primary way to add information to the graph. + This function returns immediately and processes the episode addition in the background. + Episodes for the same group_id are processed sequentially to avoid race conditions. + Args: name (str): Name of the episode episode_body (str): The content of the episode. When source='json', this must be a properly escaped JSON string, @@ -265,7 +382,7 @@ async def add_episode( ) # Adding structured JSON data - # NOTE: episode_body must be a properly escaped JSON string + # NOTE: episode_body must be a properly escaped JSON string. 
Note the triple backslashes add_episode( name="Customer Profile", episode_body="{\\\"company\\\": {\\\"name\\\": \\\"Acme Technologies\\\"}, \\\"products\\\": [{\\\"id\\\": \\\"P001\\\", \\\"name\\\": \\\"CloudSync\\\"}, {\\\"id\\\": \\\"P002\\\", \\\"name\\\": \\\"DataMiner\\\"}]}", @@ -273,14 +390,6 @@ async def add_episode( source_description="CRM data" ) - # Adding more complex JSON with arrays and nested objects - add_episode( - name="Product Catalog", - episode_body='\{"catalog": \{"company": "Tech Solutions Inc.", "products": [\{"id": "P001", "name": "Product X", "features": ["Feature A", "Feature B"]\}]\}\}', - source="json", - source_description="Product catalog" - ) - # Adding message-style content add_episode( name="Customer Conversation", @@ -298,7 +407,7 @@ async def add_episode( - Entities will be created from appropriate JSON properties - Relationships between entities will be established based on the JSON structure """ - global graphiti_client + global graphiti_client, episode_queues, queue_workers if graphiti_client is None: return {'error': 'Graphiti client not initialized'} @@ -320,29 +429,59 @@ async def add_episode( # We've already checked that graphiti_client is not None above # This assert statement helps type checkers understand that graphiti_client is defined - # from this point forward in the function assert graphiti_client is not None, 'graphiti_client should not be None here' # Use cast to help the type checker understand that graphiti_client is not None - # This doesn't change the runtime behavior, only helps with static type checking client = cast(Graphiti, graphiti_client) - # Type checking will now know that client is a Graphiti instance (not None) - # and group_id is a str, not Optional[str] - await client.add_episode( - name=name, - episode_body=episode_body, - source=source_type, - source_description=source_description, - group_id=group_id_str, # Using the string version of group_id - uuid=uuid, - 
reference_time=datetime.now(timezone.utc), - ) - return {'message': f"Episode '{name}' added successfully"} + # Define the episode processing function + async def process_episode(): + try: + logger.info(f"Processing queued episode '{name}' for group_id: {group_id_str}") + # Use all entity types if use_custom_entities is enabled, otherwise use empty dict + entity_types = ENTITY_TYPES if config.use_custom_entities else {} + + await client.add_episode( + name=name, + episode_body=episode_body, + source=source_type, + source_description=source_description, + group_id=group_id_str, # Using the string version of group_id + uuid=uuid, + reference_time=datetime.now(timezone.utc), + entity_types=entity_types, + ) + logger.info(f"Episode '{name}' added successfully") + + logger.info(f"Building communities after episode '{name}'") + await client.build_communities() + + logger.info(f"Episode '{name}' processed successfully") + except Exception as e: + error_msg = str(e) + logger.error( + f"Error processing episode '{name}' for group_id {group_id_str}: {error_msg}" + ) + + # Initialize queue for this group_id if it doesn't exist + if group_id_str not in episode_queues: + episode_queues[group_id_str] = asyncio.Queue() + + # Add the episode processing function to the queue + await episode_queues[group_id_str].put(process_episode) + + # Start a worker for this queue if one isn't already running + if not queue_workers.get(group_id_str, False): + asyncio.create_task(process_episode_queue(group_id_str)) + + # Return immediately with a success message + return { + 'message': f"Episode '{name}' queued for processing (position: {episode_queues[group_id_str].qsize()})" + } except Exception as e: error_msg = str(e) - logger.error(f'Error adding episode: {error_msg}') - return {'error': f'Error adding episode: {error_msg}'} + logger.error(f'Error queuing episode task: {error_msg}') + return {'error': f'Error queuing episode task: {error_msg}'} @mcp.tool() @@ -351,15 +490,19 @@ async def 
search_nodes( group_ids: Optional[list[str]] = None, max_nodes: int = 10, center_node_uuid: Optional[str] = None, + entity: str = '', # cursor seems to break with None ) -> Union[NodeSearchResponse, ErrorResponse]: """Search the Graphiti knowledge graph for relevant node summaries. These contain a summary of all of a node's relationships with other nodes. + Note: entity is a single entity type to filter results (permitted: "Preference", "Procedure", "Requirement"). + Args: query: The search query group_ids: Optional list of group IDs to filter results max_nodes: Maximum number of nodes to return (default: 10) center_node_uuid: Optional UUID of a node to center the search around + entity: Optional single entity type to filter results (permitted: "Preference", "Procedure", "Requirement") """ global graphiti_client @@ -373,9 +516,16 @@ async def search_nodes( ) # Configure the search - search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True) + if center_node_uuid is not None: + search_config = NODE_HYBRID_SEARCH_NODE_DISTANCE.model_copy(deep=True) + else: + search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True) search_config.limit = max_nodes + filters = SearchFilters() + if entity != '': + filters.node_labels = [entity] + # We've already checked that graphiti_client is not None above assert graphiti_client is not None @@ -388,7 +538,7 @@ async def search_nodes( config=search_config, group_ids=effective_group_ids, center_node_uuid=center_node_uuid, - search_filter=SearchFilters(), + search_filter=filters, ) if not search_results.nodes: @@ -658,13 +808,10 @@ async def get_status() -> StatusResponse: } -def create_llm_client( - client_type: str = 'openai', api_key: Optional[str] = None, model: Optional[str] = None -) -> LLMClient: +def create_llm_client(api_key: Optional[str] = None, model: Optional[str] = None) -> LLMClient: """Create an OpenAI LLM client. 
Args: - client_type: Type of LLM client to create (only 'openai' or 'openai_generic' supported) api_key: API key for the OpenAI service model: Model name to use @@ -703,6 +850,11 @@ async def initialize_server() -> MCPConfig: # OpenAI is the only supported LLM client parser.add_argument('--model', help='Model name to use with the LLM client') parser.add_argument('--destroy-graph', action='store_true', help='Destroy all Graphiti graphs') + parser.add_argument( + '--use-custom-entities', + action='store_true', + help='Enable entity extraction using the predefined ENTITY_TYPES', + ) args = parser.parse_args() @@ -714,18 +866,23 @@ async def initialize_server() -> MCPConfig: config.group_id = f'graph_{uuid.uuid4().hex[:8]}' logger.info(f'Generated random group_id: {config.group_id}') + # Set use_custom_entities flag if specified + if args.use_custom_entities: + config.use_custom_entities = True + logger.info('Entity extraction enabled using predefined ENTITY_TYPES') + else: + logger.info('Entity extraction disabled (no custom entities will be used)') + llm_client = None # Create OpenAI client if model is specified or if OPENAI_API_KEY is available if args.model or config.openai_api_key: # Override model from command line if specified - if args.model: - config.model_name = args.model + + config.model_name = args.model or DEFAULT_LLM_MODEL # Create the OpenAI client - llm_client = create_llm_client( - client_type='openai', api_key=config.openai_api_key, model=config.model_name - ) + llm_client = create_llm_client(api_key=config.openai_api_key, model=config.model_name) # Initialize Graphiti with the specified LLM client await initialize_graphiti(llm_client, destroy_graph=args.destroy_graph) diff --git a/mcp_server/pyproject.toml b/mcp_server/pyproject.toml index bd66ea5a..d907fd33 100644 --- a/mcp_server/pyproject.toml +++ b/mcp_server/pyproject.toml @@ -7,5 +7,5 @@ requires-python = ">=3.10" dependencies = [ "mcp>=1.5.0", "openai>=1.68.2", - "graphiti-core>=0.8.1", + 
"graphiti-core>=0.8.2", ] diff --git a/mcp_server/uv.lock b/mcp_server/uv.lock index 7987a72d..c40fb8b8 100644 --- a/mcp_server/uv.lock +++ b/mcp_server/uv.lock @@ -10,24 +10,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 }, ] -[[package]] -name = "anthropic" -version = "0.49.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "distro" }, - { name = "httpx" }, - { name = "jiter" }, - { name = "pydantic" }, - { name = "sniffio" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/86/e3/a88c8494ce4d1a88252b9e053607e885f9b14d0a32273d47b727cbee4228/anthropic-0.49.0.tar.gz", hash = "sha256:c09e885b0f674b9119b4f296d8508907f6cff0009bc20d5cf6b35936c40b4398", size = 210016 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/76/74/5d90ad14d55fbe3f9c474fdcb6e34b4bed99e3be8efac98734a5ddce88c1/anthropic-0.49.0-py3-none-any.whl", hash = "sha256:bbc17ad4e7094988d2fa86b87753ded8dce12498f4b85fe5810f208f454a8375", size = 243368 }, -] - [[package]] name = "anyio" version = "4.9.0" @@ -102,12 +84,10 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.8.1" -source = { directory = "../" } +version = "0.8.2" +source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "anthropic" }, { name = "diskcache" }, - { name = "mcp" }, { name = "neo4j" }, { name = "numpy" }, { name = "openai" }, @@ -115,18 +95,9 @@ dependencies = [ { name = "python-dotenv" }, { name = "tenacity" }, ] - -[package.metadata] -requires-dist = [ - { name = "anthropic", specifier = ">=0.49.0,<0.50.0" }, - { name = "diskcache", specifier = ">=5.6.3,<6.0.0" }, - { name = "mcp", specifier = ">=1.5.0,<2.0.0" }, - { name = "neo4j", specifier = ">=5.23.0,<6.0.0" }, - { 
name = "numpy", specifier = ">=1.0.0" }, - { name = "openai", specifier = ">=1.53.0,<2.0.0" }, - { name = "pydantic", specifier = ">=2.8.2,<3.0.0" }, - { name = "python-dotenv", specifier = ">=1.0.1,<2.0.0" }, - { name = "tenacity", specifier = "==9.0.0" }, +sdist = { url = "https://files.pythonhosted.org/packages/b7/29/f0d74adc687514a226a164340a7fa9254f9cffa2ff7ab353fc1edf4016cb/graphiti_core-0.8.2.tar.gz", hash = "sha256:477a9172728b92ba1875222f8d9c6337d9fa465363b75afd67db9a159c39d379", size = 62464 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/19/4b/d2065544beb24402bb41197b2adf7ed7af287453c82290d1ca757532f6d3/graphiti_core-0.8.2-py3-none-any.whl", hash = "sha256:ce992f9f3dee168d3ca442e0d0fd685a1a66dee4df94c47909037295aceac45e", size = 95052 }, ] [[package]] @@ -274,7 +245,7 @@ dependencies = [ [package.metadata] requires-dist = [ - { name = "graphiti-core", directory = "../" }, + { name = "graphiti-core", specifier = ">=0.8.2" }, { name = "mcp", specifier = ">=1.5.0" }, { name = "openai", specifier = ">=1.68.2" }, ]