Add get_entities_by_type and compare_facts_over_time MCP tools

- get_entities_by_type: Retrieve entities by type classification (essential for personal knowledge management)
- compare_facts_over_time: Compare facts between two time periods to track how knowledge evolved
- Clarified add_memory's uuid parameter documentation to stop LLMs from supplying UUIDs for new episodes

Both tools use only public Graphiti APIs for upstream compatibility and follow MCP specification best practices.
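For illustration, the new tools are invoked like the existing search tools (argument values here are hypothetical):

get_entities_by_type(entity_types=["Preference", "Pattern"], max_entities=10)
compare_facts_over_time(query="productivity patterns", start_time="2024-01-01", end_time="2024-03-01")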

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Lars Varming 2025-11-08 20:36:11 +01:00
parent 674619cc89
commit 6696a05b50


@@ -245,35 +245,35 @@ class GraphitiService:
db_provider = self.config.database.provider
if db_provider.lower() == 'falkordb':
raise RuntimeError(
f"\n{'='*70}\n"
f"Database Connection Error: FalkorDB is not running\n"
f"{'='*70}\n\n"
f"FalkorDB at {db_config['host']}:{db_config['port']} is not accessible.\n\n"
f"To start FalkorDB:\n"
f" - Using Docker Compose: cd mcp_server && docker compose up\n"
f" - Or run FalkorDB manually: docker run -p 6379:6379 falkordb/falkordb\n\n"
f"{'='*70}\n"
f'\n{"=" * 70}\n'
f'Database Connection Error: FalkorDB is not running\n'
f'{"=" * 70}\n\n'
f'FalkorDB at {db_config["host"]}:{db_config["port"]} is not accessible.\n\n'
f'To start FalkorDB:\n'
f' - Using Docker Compose: cd mcp_server && docker compose up\n'
f' - Or run FalkorDB manually: docker run -p 6379:6379 falkordb/falkordb\n\n'
f'{"=" * 70}\n'
) from db_error
elif db_provider.lower() == 'neo4j':
raise RuntimeError(
f"\n{'='*70}\n"
f"Database Connection Error: Neo4j is not running\n"
f"{'='*70}\n\n"
f"Neo4j at {db_config.get('uri', 'unknown')} is not accessible.\n\n"
f"To start Neo4j:\n"
f" - Using Docker Compose: cd mcp_server && docker compose -f docker/docker-compose-neo4j.yml up\n"
f" - Or install Neo4j Desktop from: https://neo4j.com/download/\n"
f" - Or run Neo4j manually: docker run -p 7474:7474 -p 7687:7687 neo4j:latest\n\n"
f"{'='*70}\n"
f'\n{"=" * 70}\n'
f'Database Connection Error: Neo4j is not running\n'
f'{"=" * 70}\n\n'
f'Neo4j at {db_config.get("uri", "unknown")} is not accessible.\n\n'
f'To start Neo4j:\n'
f' - Using Docker Compose: cd mcp_server && docker compose -f docker/docker-compose-neo4j.yml up\n'
f' - Or install Neo4j Desktop from: https://neo4j.com/download/\n'
f' - Or run Neo4j manually: docker run -p 7474:7474 -p 7687:7687 neo4j:latest\n\n'
f'{"=" * 70}\n'
) from db_error
else:
raise RuntimeError(
f"\n{'='*70}\n"
f"Database Connection Error: {db_provider} is not running\n"
f"{'='*70}\n\n"
f"{db_provider} at {db_config.get('uri', 'unknown')} is not accessible.\n\n"
f"Please ensure {db_provider} is running and accessible.\n\n"
f"{'='*70}\n"
f'\n{"=" * 70}\n'
f'Database Connection Error: {db_provider} is not running\n'
f'{"=" * 70}\n\n'
f'{db_provider} at {db_config.get("uri", "unknown")} is not accessible.\n\n'
f'Please ensure {db_provider} is running and accessible.\n\n'
f'{"=" * 70}\n'
) from db_error
# Re-raise other errors
raise
@@ -344,20 +344,20 @@ async def add_memory(
- 'json': For structured data
- 'message': For conversation-style content
source_description (str, optional): Description of the source
uuid (str, optional): Optional UUID for the episode
uuid (str, optional): NEVER provide a UUID for new episodes - UUIDs are auto-generated. This parameter
can ONLY be used for updating an existing episode by providing its existing UUID.
Providing a UUID will update/replace the episode with that UUID if it exists.
Examples:
# Adding plain text content
# Adding plain text content (NEW episode - no UUID)
add_memory(
name="Company News",
episode_body="Acme Corp announced a new product line today.",
source="text",
source_description="news article",
group_id="some_arbitrary_string"
source_description="news article"
)
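# Updating an EXISTING episode (illustrative - replace the placeholder with the
# UUID of a previously created episode; never invent one)
add_memory(
name="Company News",
episode_body="Correction: Acme Corp delayed the product line to next quarter.",
source="text",
source_description="news article",
uuid="<existing-episode-uuid>"
)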
# Adding structured JSON data
# NOTE: episode_body should be a JSON string (standard JSON escaping)
# Adding structured JSON data (NEW episode - no UUID)
add_memory(
name="Customer Profile",
episode_body='{"company": {"name": "Acme Technologies"}, "products": [{"id": "P001", "name": "CloudSync"}, {"id": "P002", "name": "DataMiner"}]}',
@@ -484,6 +484,106 @@ async def search_nodes(
return ErrorResponse(error=f'Error searching nodes: {error_msg}')
@mcp.tool()
async def get_entities_by_type(
entity_types: list[str],
group_ids: list[str] | None = None,
max_entities: int = 20,
query: str | None = None,
) -> NodeSearchResponse | ErrorResponse:
"""Retrieve entities by their type classification.
Useful for browsing entities by type (e.g., Pattern, Insight, Preference)
in personal knowledge management workflows.
Args:
entity_types: List of entity type names to retrieve (e.g., ["Pattern", "Insight"])
group_ids: Optional list of group IDs to filter results
max_entities: Maximum number of entities to return (default: 20)
query: Optional search query to filter entities
Examples:
# Get all preferences
get_entities_by_type(entity_types=["Preference"])
# Get insights and patterns related to productivity
get_entities_by_type(
entity_types=["Insight", "Pattern"],
query="productivity"
)
"""
global graphiti_service
if graphiti_service is None:
return ErrorResponse(error='Graphiti service not initialized')
try:
# Validate entity_types parameter
if not entity_types:
return ErrorResponse(error='entity_types cannot be empty')
client = await graphiti_service.get_client()
# Use the provided group_ids or fall back to the default from config
effective_group_ids = (
group_ids
if group_ids is not None
else [config.graphiti.group_id]
if config.graphiti.group_id
else []
)
# Create search filters with entity type labels
search_filters = SearchFilters(node_labels=entity_types)
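# Assumption: node_labels matches against the entity type labels stored on
# nodes at extraction time; an unknown type name simply matches nothing
# rather than raising an error.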
# Use the search_ method with node search config
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
# Use the query if provided; otherwise fall back to a whitespace placeholder query so retrieval is constrained mainly by the entity type filter
search_query = query if query else ' '
results = await client.search_(
query=search_query,
config=NODE_HYBRID_SEARCH_RRF,
group_ids=effective_group_ids,
search_filter=search_filters,
)
# Extract nodes from results
nodes = results.nodes[:max_entities] if results.nodes else []
if not nodes:
return NodeSearchResponse(
message=f'No entities found with types: {", ".join(entity_types)}', nodes=[]
)
# Format the results (same as search_nodes)
node_results = []
for node in nodes:
# Get attributes and ensure no embeddings are included
attrs = node.attributes if hasattr(node, 'attributes') else {}
# Remove any embedding keys that might be in attributes
attrs = {k: v for k, v in attrs.items() if 'embedding' not in k.lower()}
node_results.append(
NodeResult(
uuid=node.uuid,
name=node.name,
labels=node.labels if node.labels else [],
created_at=node.created_at.isoformat() if node.created_at else None,
summary=node.summary,
group_id=node.group_id,
attributes=attrs,
)
)
return NodeSearchResponse(message=f'Found {len(node_results)} entities', nodes=node_results)
except Exception as e:
error_msg = str(e)
logger.error(f'Error getting entities by type: {error_msg}')
return ErrorResponse(error=f'Error getting entities by type: {error_msg}')
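# Illustrative round trip (names and values hypothetical):
#   get_entities_by_type(entity_types=['Preference'], max_entities=5)
#   -> NodeSearchResponse(message='Found 2 entities', nodes=[NodeResult(...), ...])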
@mcp.tool()
async def search_memory_facts(
query: str,
@@ -538,6 +638,197 @@ async def search_memory_facts(
return ErrorResponse(error=f'Error searching facts: {error_msg}')
@mcp.tool()
async def compare_facts_over_time(
query: str,
start_time: str,
end_time: str,
group_ids: list[str] | None = None,
max_facts_per_period: int = 10,
) -> dict[str, Any] | ErrorResponse:
"""Compare facts between two time periods.
Track how understanding evolved by comparing facts valid at different times.
Returns facts at start, facts at end, facts invalidated, and facts added.
Args:
query: The search query
start_time: Start timestamp in ISO 8601 format (e.g., "2024-01-01" or "2024-01-01T10:30:00Z")
end_time: End timestamp in ISO 8601 format
group_ids: Optional list of group IDs to filter results
max_facts_per_period: Maximum number of facts to return per period (default: 10)
Examples:
# Track how understanding evolved
compare_facts_over_time(
query="productivity patterns",
start_time="2024-01-01",
end_time="2024-03-01"
)
"""
global graphiti_service
if graphiti_service is None:
return ErrorResponse(error='Graphiti service not initialized')
try:
# Validate inputs
if not query or not query.strip():
return ErrorResponse(error='query cannot be empty')
if not start_time or not start_time.strip():
return ErrorResponse(error='start_time cannot be empty')
if not end_time or not end_time.strip():
return ErrorResponse(error='end_time cannot be empty')
if max_facts_per_period <= 0:
return ErrorResponse(error='max_facts_per_period must be a positive integer')
# Parse timestamps
from datetime import datetime
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
from graphiti_core.search.search_filters import ComparisonOperator, DateFilter
try:
start_dt = datetime.fromisoformat(start_time.replace('Z', '+00:00'))
end_dt = datetime.fromisoformat(end_time.replace('Z', '+00:00'))
except ValueError as e:
return ErrorResponse(
error=f'Invalid timestamp format: {e}. Use ISO 8601 (e.g., "2024-03-15T10:30:00Z" or "2024-03-15")'
)
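# Note: the 'Z' -> '+00:00' replacement above is needed because
# datetime.fromisoformat() only accepts a trailing 'Z' from Python 3.11 onward.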
if start_dt >= end_dt:
return ErrorResponse(error='start_time must be before end_time')
client = await graphiti_service.get_client()
# Use the provided group_ids or fall back to the default from config
effective_group_ids = (
group_ids
if group_ids is not None
else [config.graphiti.group_id]
if config.graphiti.group_id
else []
)
# Query 1: Facts valid at start_time
# valid_at <= start_time AND (invalid_at > start_time OR invalid_at IS NULL)
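# Assumed SearchFilters semantics (consistent with the comments in this
# function): entries in the outer list are OR'ed together, while DateFilters
# within one inner list are AND'ed - that is how the predicate above maps
# onto the nested-list structure below.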
start_filters = SearchFilters(
valid_at=[
[DateFilter(date=start_dt, comparison_operator=ComparisonOperator.less_than_equal)]
],
invalid_at=[
[DateFilter(date=start_dt, comparison_operator=ComparisonOperator.greater_than)],
[DateFilter(date=None, comparison_operator=ComparisonOperator.is_null)],
],
)
start_results = await client.search_(
query=query,
config=EDGE_HYBRID_SEARCH_RRF,
group_ids=effective_group_ids,
search_filter=start_filters,
)
# Query 2: Facts valid at end_time
end_filters = SearchFilters(
valid_at=[
[DateFilter(date=end_dt, comparison_operator=ComparisonOperator.less_than_equal)]
],
invalid_at=[
[DateFilter(date=end_dt, comparison_operator=ComparisonOperator.greater_than)],
[DateFilter(date=None, comparison_operator=ComparisonOperator.is_null)],
],
)
end_results = await client.search_(
query=query,
config=EDGE_HYBRID_SEARCH_RRF,
group_ids=effective_group_ids,
search_filter=end_filters,
)
# Query 3: Facts invalidated between start and end
# invalid_at > start_time AND invalid_at <= end_time
invalidated_filters = SearchFilters(
invalid_at=[
[
DateFilter(date=start_dt, comparison_operator=ComparisonOperator.greater_than),
DateFilter(date=end_dt, comparison_operator=ComparisonOperator.less_than_equal),
]
],
)
invalidated_results = await client.search_(
query=query,
config=EDGE_HYBRID_SEARCH_RRF,
group_ids=effective_group_ids,
search_filter=invalidated_filters,
)
# Query 4: Facts added between start and end
# created_at > start_time AND created_at <= end_time
added_filters = SearchFilters(
created_at=[
[
DateFilter(date=start_dt, comparison_operator=ComparisonOperator.greater_than),
DateFilter(date=end_dt, comparison_operator=ComparisonOperator.less_than_equal),
]
],
)
added_results = await client.search_(
query=query,
config=EDGE_HYBRID_SEARCH_RRF,
group_ids=effective_group_ids,
search_filter=added_filters,
)
# Format results
facts_from_start = [
format_fact_result(edge)
for edge in (start_results.edges[:max_facts_per_period] if start_results.edges else [])
]
facts_at_end = [
format_fact_result(edge)
for edge in (end_results.edges[:max_facts_per_period] if end_results.edges else [])
]
facts_invalidated = [
format_fact_result(edge)
for edge in (
invalidated_results.edges[:max_facts_per_period]
if invalidated_results.edges
else []
)
]
facts_added = [
format_fact_result(edge)
for edge in (added_results.edges[:max_facts_per_period] if added_results.edges else [])
]
return {
'message': f'Comparison completed between {start_time} and {end_time}',
'start_time': start_time,
'end_time': end_time,
'summary': {
'facts_at_start_count': len(facts_from_start),
'facts_at_end_count': len(facts_at_end),
'facts_invalidated_count': len(facts_invalidated),
'facts_added_count': len(facts_added),
},
'facts_from_start': facts_from_start,
'facts_at_end': facts_at_end,
'facts_invalidated': facts_invalidated,
'facts_added': facts_added,
}
except Exception as e:
error_msg = str(e)
logger.error(f'Error comparing facts over time: {error_msg}')
return ErrorResponse(error=f'Error comparing facts over time: {error_msg}')
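# Worked example of the four buckets (dates hypothetical): for the window
# 2024-01-01..2024-03-01, a fact with valid_at=2024-01-15 and
# invalid_at=2024-02-20 appears in facts_invalidated (invalid_at falls inside
# the window) and, assuming it was also ingested inside the window, in
# facts_added (created_at filter); it appears in neither facts_from_start
# (not yet valid on 2024-01-01) nor facts_at_end (already invalid by 2024-03-01).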
@mcp.tool()
async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse:
"""Delete an entity edge from the graph memory.