From 6696a05b50542c686ea390d8612d151fa534705c Mon Sep 17 00:00:00 2001 From: Lars Varming Date: Sat, 8 Nov 2025 20:36:11 +0100 Subject: [PATCH] Add get_entities_by_type and compare_facts_over_time MCP tools MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - get_entities_by_type: Retrieve entities by type classification (essential for PKM) - compare_facts_over_time: Compare facts between time periods to track knowledge evolution - Enhanced add_memory UUID documentation to prevent LLM misuse Both tools use only public Graphiti APIs for upstream compatibility. Follows MCP specification best practices. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- mcp_server/src/graphiti_mcp_server.py | 349 +++++++++++++++++++++++--- 1 file changed, 320 insertions(+), 29 deletions(-) diff --git a/mcp_server/src/graphiti_mcp_server.py b/mcp_server/src/graphiti_mcp_server.py index 0c9a568a..b276f6ec 100644 --- a/mcp_server/src/graphiti_mcp_server.py +++ b/mcp_server/src/graphiti_mcp_server.py @@ -245,35 +245,35 @@ class GraphitiService: db_provider = self.config.database.provider if db_provider.lower() == 'falkordb': raise RuntimeError( - f"\n{'='*70}\n" - f"Database Connection Error: FalkorDB is not running\n" - f"{'='*70}\n\n" - f"FalkorDB at {db_config['host']}:{db_config['port']} is not accessible.\n\n" - f"To start FalkorDB:\n" - f" - Using Docker Compose: cd mcp_server && docker compose up\n" - f" - Or run FalkorDB manually: docker run -p 6379:6379 falkordb/falkordb\n\n" - f"{'='*70}\n" + f'\n{"=" * 70}\n' + f'Database Connection Error: FalkorDB is not running\n' + f'{"=" * 70}\n\n' + f'FalkorDB at {db_config["host"]}:{db_config["port"]} is not accessible.\n\n' + f'To start FalkorDB:\n' + f' - Using Docker Compose: cd mcp_server && docker compose up\n' + f' - Or run FalkorDB manually: docker run -p 6379:6379 falkordb/falkordb\n\n' + f'{"=" * 70}\n' ) from db_error elif db_provider.lower() == 'neo4j': raise RuntimeError( - f"\n{'='*70}\n" - f"Database Connection Error: Neo4j is not running\n" - f"{'='*70}\n\n" - f"Neo4j at {db_config.get('uri', 'unknown')} is not accessible.\n\n" - f"To start Neo4j:\n" - f" - Using Docker Compose: cd mcp_server && docker compose -f docker/docker-compose-neo4j.yml up\n" - f" - Or install Neo4j Desktop from: https://neo4j.com/download/\n" - f" - Or run Neo4j manually: docker run -p 7474:7474 -p 7687:7687 neo4j:latest\n\n" - f"{'='*70}\n" + f'\n{"=" * 70}\n' + f'Database Connection Error: Neo4j is not running\n' + f'{"=" * 70}\n\n' + f'Neo4j at {db_config.get("uri", "unknown")} is not accessible.\n\n' + f'To start Neo4j:\n' + f' - Using Docker Compose: cd mcp_server && docker compose -f docker/docker-compose-neo4j.yml up\n' + f' - Or install Neo4j Desktop from: https://neo4j.com/download/\n' + f' - Or run Neo4j manually: docker run -p 7474:7474 -p 7687:7687 neo4j:latest\n\n' + f'{"=" * 70}\n' ) from db_error else: raise RuntimeError( - f"\n{'='*70}\n" - f"Database Connection Error: {db_provider} is not running\n" - f"{'='*70}\n\n" - f"{db_provider} at {db_config.get('uri', 'unknown')} is not accessible.\n\n" - f"Please ensure {db_provider} is running and accessible.\n\n" - f"{'='*70}\n" + f'\n{"=" * 70}\n' + f'Database Connection Error: {db_provider} is not running\n' + f'{"=" * 70}\n\n' + f'{db_provider} at {db_config.get("uri", "unknown")} is not accessible.\n\n' + f'Please ensure {db_provider} is running and accessible.\n\n' + f'{"=" * 70}\n' ) from db_error # Re-raise other errors raise @@ 
-344,20 +344,20 @@ async def add_memory(
             - 'json': For structured data
             - 'message': For conversation-style content
         source_description (str, optional): Description of the source
-        uuid (str, optional): Optional UUID for the episode
+        uuid (str, optional): NEVER provide a UUID for a new episode; UUIDs are
+            auto-generated. Pass a UUID only to update an existing episode, in
+            which case the episode with that UUID is updated/replaced.
 
     Examples:
-        # Adding plain text content
+        # Adding plain text content (NEW episode - no UUID)
         add_memory(
             name="Company News",
             episode_body="Acme Corp announced a new product line today.",
             source="text",
-            source_description="news article",
-            group_id="some_arbitrary_string"
+            source_description="news article"
         )
 
-        # Adding structured JSON data
-        # NOTE: episode_body should be a JSON string (standard JSON escaping)
+        # Adding structured JSON data (NEW episode - no UUID)
         add_memory(
             name="Customer Profile",
             episode_body='{"company": {"name": "Acme Technologies"}, "products": [{"id": "P001", "name": "CloudSync"}, {"id": "P002", "name": "DataMiner"}]}',
@@ -484,6 +484,109 @@ async def search_nodes(
         return ErrorResponse(error=f'Error searching nodes: {error_msg}')
 
 
+@mcp.tool()
+async def get_entities_by_type(
+    entity_types: list[str],
+    group_ids: list[str] | None = None,
+    max_entities: int = 20,
+    query: str | None = None,
+) -> NodeSearchResponse | ErrorResponse:
+    """Retrieve entities by their type classification.
+
+    Useful for browsing entities by type (e.g., Pattern, Insight, Preference)
+    in personal knowledge management workflows.
+
+    Args:
+        entity_types: List of entity type names to retrieve (e.g., ["Pattern", "Insight"])
+        group_ids: Optional list of group IDs to filter results
+        max_entities: Maximum number of entities to return (default: 20)
+        query: Optional search query to filter entities
+
+    Examples:
+        # Get all preferences
+        get_entities_by_type(entity_types=["Preference"])
+
+        # Get insights and patterns related to productivity
+        get_entities_by_type(
+            entity_types=["Insight", "Pattern"],
+            query="productivity"
+        )
+    """
+    global graphiti_service
+
+    if graphiti_service is None:
+        return ErrorResponse(error='Graphiti service not initialized')
+
+    try:
+        # Validate entity_types parameter
+        if not entity_types:
+            return ErrorResponse(error='entity_types cannot be empty')
+
+        client = await graphiti_service.get_client()
+
+        # Use the provided group_ids or fall back to the default from config
+        effective_group_ids = (
+            group_ids
+            if group_ids is not None
+            else [config.graphiti.group_id]
+            if config.graphiti.group_id
+            else []
+        )
+
+        # Create search filters with entity type labels
+        search_filters = SearchFilters(node_labels=entity_types)
+
+        # Use the lower-level search_ method with a node search config
+        from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
+
+        # Fall back to a single-space query so the type filter alone drives the results
+        search_query = query if query else ' '
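+        # NODE_HYBRID_SEARCH_RRF is a public graphiti-core search recipe; as
+        # understood here, it fuses semantic and keyword retrieval over nodes
+        # with reciprocal rank fusion, and node_labels still filters the types.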
+
+        results = await client.search_(
+            query=search_query,
+            config=NODE_HYBRID_SEARCH_RRF,
+            group_ids=effective_group_ids,
+            search_filter=search_filters,
+        )
+
+        # Extract nodes from results
+        nodes = results.nodes[:max_entities] if results.nodes else []
+
+        if not nodes:
+            return NodeSearchResponse(
+                message=f'No entities found with types: {", ".join(entity_types)}', nodes=[]
+            )
+
+        # Format the results (same as search_nodes)
+        node_results = []
+        for node in nodes:
+            # Get attributes and ensure no embeddings are included
+            attrs = node.attributes if hasattr(node, 'attributes') else {}
+            # Remove any embedding keys that might be in attributes
+            attrs = {k: v for k, v in attrs.items() if 'embedding' not in k.lower()}
+
+            node_results.append(
+                NodeResult(
+                    uuid=node.uuid,
+                    name=node.name,
+                    labels=node.labels if node.labels else [],
+                    created_at=node.created_at.isoformat() if node.created_at else None,
+                    summary=node.summary,
+                    group_id=node.group_id,
+                    attributes=attrs,
+                )
+            )
+
+        return NodeSearchResponse(message=f'Found {len(node_results)} entities', nodes=node_results)
+    except Exception as e:
+        error_msg = str(e)
+        logger.error(f'Error getting entities by type: {error_msg}')
+        return ErrorResponse(error=f'Error getting entities by type: {error_msg}')
+
+
 @mcp.tool()
 async def search_memory_facts(
     query: str,
@@ -538,6 +641,204 @@ async def search_memory_facts(
         return ErrorResponse(error=f'Error searching facts: {error_msg}')
 
 
+@mcp.tool()
+async def compare_facts_over_time(
+    query: str,
+    start_time: str,
+    end_time: str,
+    group_ids: list[str] | None = None,
+    max_facts_per_period: int = 10,
+) -> dict[str, Any] | ErrorResponse:
+    """Compare facts between two time periods.
+
+    Track how understanding evolved by comparing facts valid at different times.
+    Returns facts at start, facts at end, facts invalidated, and facts added.
+
+    Args:
+        query: The search query
+        start_time: Start timestamp in ISO 8601 format (e.g., "2024-01-01" or "2024-01-01T10:30:00Z")
+        end_time: End timestamp in ISO 8601 format
+        group_ids: Optional list of group IDs to filter results
+        max_facts_per_period: Maximum number of facts to return per period (default: 10)
+
+    Examples:
+        # Track how understanding evolved
+        compare_facts_over_time(
+            query="productivity patterns",
+            start_time="2024-01-01",
+            end_time="2024-03-01"
+        )
+    """
+    global graphiti_service
+
+    if graphiti_service is None:
+        return ErrorResponse(error='Graphiti service not initialized')
+
+    try:
+        # Validate inputs
+        if not query or not query.strip():
+            return ErrorResponse(error='query cannot be empty')
+        if not start_time or not start_time.strip():
+            return ErrorResponse(error='start_time cannot be empty')
+        if not end_time or not end_time.strip():
+            return ErrorResponse(error='end_time cannot be empty')
+        if max_facts_per_period <= 0:
+            return ErrorResponse(error='max_facts_per_period must be a positive integer')
+
+        # Parse timestamps ('Z' is rewritten so fromisoformat accepts it on Python < 3.11)
+        from datetime import datetime
+
+        from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
+        from graphiti_core.search.search_filters import ComparisonOperator, DateFilter
+
+        try:
+            start_dt = datetime.fromisoformat(start_time.replace('Z', '+00:00'))
+            end_dt = datetime.fromisoformat(end_time.replace('Z', '+00:00'))
+        except ValueError as e:
+            return ErrorResponse(
+                error=f'Invalid timestamp format: {e}. Use ISO 8601 (e.g., "2024-03-15T10:30:00Z" or "2024-03-15")'
+            )
+
+        if start_dt >= end_dt:
+            return ErrorResponse(error='start_time must be before end_time')
+
+        client = await graphiti_service.get_client()
+
+        # Use the provided group_ids or fall back to the default from config
+        effective_group_ids = (
+            group_ids
+            if group_ids is not None
+            else [config.graphiti.group_id]
+            if config.graphiti.group_id
+            else []
+        )
+
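+        # Filter semantics, as this patch reads SearchFilters: inner DateFilter
+        # lists are ANDed together, outer lists are ORed, so [[a], [b]] = "a OR b".
+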
+        # Query 1: Facts valid at start_time
+        # valid_at <= start_time AND (invalid_at > start_time OR invalid_at IS NULL)
+        start_filters = SearchFilters(
+            valid_at=[
+                [DateFilter(date=start_dt, comparison_operator=ComparisonOperator.less_than_equal)]
+            ],
+            invalid_at=[
+                [DateFilter(date=start_dt, comparison_operator=ComparisonOperator.greater_than)],
+                [DateFilter(date=None, comparison_operator=ComparisonOperator.is_null)],
+            ],
+        )
+
+        start_results = await client.search_(
+            query=query,
+            config=EDGE_HYBRID_SEARCH_RRF,
+            group_ids=effective_group_ids,
+            search_filter=start_filters,
+        )
+
+        # Query 2: Facts valid at end_time
+        end_filters = SearchFilters(
+            valid_at=[
+                [DateFilter(date=end_dt, comparison_operator=ComparisonOperator.less_than_equal)]
+            ],
+            invalid_at=[
+                [DateFilter(date=end_dt, comparison_operator=ComparisonOperator.greater_than)],
+                [DateFilter(date=None, comparison_operator=ComparisonOperator.is_null)],
+            ],
+        )
+
+        end_results = await client.search_(
+            query=query,
+            config=EDGE_HYBRID_SEARCH_RRF,
+            group_ids=effective_group_ids,
+            search_filter=end_filters,
+        )
+
+        # Query 3: Facts invalidated between start and end
+        # invalid_at > start_time AND invalid_at <= end_time
+        invalidated_filters = SearchFilters(
+            invalid_at=[
+                [
+                    DateFilter(date=start_dt, comparison_operator=ComparisonOperator.greater_than),
+                    DateFilter(date=end_dt, comparison_operator=ComparisonOperator.less_than_equal),
+                ]
+            ],
+        )
+
+        invalidated_results = await client.search_(
+            query=query,
+            config=EDGE_HYBRID_SEARCH_RRF,
+            group_ids=effective_group_ids,
+            search_filter=invalidated_filters,
+        )
+
+        # Query 4: Facts added between start and end
+        # created_at > start_time AND created_at <= end_time
+        added_filters = SearchFilters(
+            created_at=[
+                [
+                    DateFilter(date=start_dt, comparison_operator=ComparisonOperator.greater_than),
+                    DateFilter(date=end_dt, comparison_operator=ComparisonOperator.less_than_equal),
+                ]
+            ],
+        )
+
+        added_results = await client.search_(
+            query=query,
+            config=EDGE_HYBRID_SEARCH_RRF,
+            group_ids=effective_group_ids,
+            search_filter=added_filters,
+        )
+
+        # Format results
+        facts_at_start = [
+            format_fact_result(edge)
+            for edge in (start_results.edges[:max_facts_per_period] if start_results.edges else [])
+        ]
+
+        facts_at_end = [
+            format_fact_result(edge)
+            for edge in (end_results.edges[:max_facts_per_period] if end_results.edges else [])
+        ]
+
+        facts_invalidated = [
+            format_fact_result(edge)
+            for edge in (
+                invalidated_results.edges[:max_facts_per_period]
+                if invalidated_results.edges
+                else []
+            )
+        ]
+
+        facts_added = [
+            format_fact_result(edge)
+            for edge in (added_results.edges[:max_facts_per_period] if added_results.edges else [])
+        ]
+
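+        # Illustrative response shape (values hypothetical):
+        #   {'summary': {'facts_at_start_count': 3, 'facts_added_count': 2, ...},
+        #    'facts_at_start': [...], 'facts_at_end': [...],
+        #    'facts_invalidated': [...], 'facts_added': [...]}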
+        return {
+            'message': f'Comparison completed between {start_time} and {end_time}',
+            'start_time': start_time,
+            'end_time': end_time,
+            'summary': {
+                'facts_at_start_count': len(facts_at_start),
+                'facts_at_end_count': len(facts_at_end),
+                'facts_invalidated_count': len(facts_invalidated),
+                'facts_added_count': len(facts_added),
+            },
+            'facts_at_start': facts_at_start,
+            'facts_at_end': facts_at_end,
+            'facts_invalidated': facts_invalidated,
+            'facts_added': facts_added,
+        }
+    except Exception as e:
+        error_msg = str(e)
+        logger.error(f'Error comparing facts over time: {error_msg}')
+        return ErrorResponse(error=f'Error comparing facts over time: {error_msg}')
+
+
 @mcp.tool()
 async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse:
     """Delete an entity edge from the graph memory.