Add get_entities_by_type and compare_facts_over_time MCP tools
- get_entities_by_type: Retrieve entities by type classification (essential for PKM) - compare_facts_over_time: Compare facts between time periods to track knowledge evolution - Enhanced add_memory UUID documentation to prevent LLM misuse Both tools use only public Graphiti APIs for upstream compatibility. Follows MCP specification best practices. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
674619cc89
commit
6696a05b50
1 changed files with 320 additions and 29 deletions
|
|
@ -245,35 +245,35 @@ class GraphitiService:
|
|||
db_provider = self.config.database.provider
|
||||
if db_provider.lower() == 'falkordb':
|
||||
raise RuntimeError(
|
||||
f"\n{'='*70}\n"
|
||||
f"Database Connection Error: FalkorDB is not running\n"
|
||||
f"{'='*70}\n\n"
|
||||
f"FalkorDB at {db_config['host']}:{db_config['port']} is not accessible.\n\n"
|
||||
f"To start FalkorDB:\n"
|
||||
f" - Using Docker Compose: cd mcp_server && docker compose up\n"
|
||||
f" - Or run FalkorDB manually: docker run -p 6379:6379 falkordb/falkordb\n\n"
|
||||
f"{'='*70}\n"
|
||||
f'\n{"=" * 70}\n'
|
||||
f'Database Connection Error: FalkorDB is not running\n'
|
||||
f'{"=" * 70}\n\n'
|
||||
f'FalkorDB at {db_config["host"]}:{db_config["port"]} is not accessible.\n\n'
|
||||
f'To start FalkorDB:\n'
|
||||
f' - Using Docker Compose: cd mcp_server && docker compose up\n'
|
||||
f' - Or run FalkorDB manually: docker run -p 6379:6379 falkordb/falkordb\n\n'
|
||||
f'{"=" * 70}\n'
|
||||
) from db_error
|
||||
elif db_provider.lower() == 'neo4j':
|
||||
raise RuntimeError(
|
||||
f"\n{'='*70}\n"
|
||||
f"Database Connection Error: Neo4j is not running\n"
|
||||
f"{'='*70}\n\n"
|
||||
f"Neo4j at {db_config.get('uri', 'unknown')} is not accessible.\n\n"
|
||||
f"To start Neo4j:\n"
|
||||
f" - Using Docker Compose: cd mcp_server && docker compose -f docker/docker-compose-neo4j.yml up\n"
|
||||
f" - Or install Neo4j Desktop from: https://neo4j.com/download/\n"
|
||||
f" - Or run Neo4j manually: docker run -p 7474:7474 -p 7687:7687 neo4j:latest\n\n"
|
||||
f"{'='*70}\n"
|
||||
f'\n{"=" * 70}\n'
|
||||
f'Database Connection Error: Neo4j is not running\n'
|
||||
f'{"=" * 70}\n\n'
|
||||
f'Neo4j at {db_config.get("uri", "unknown")} is not accessible.\n\n'
|
||||
f'To start Neo4j:\n'
|
||||
f' - Using Docker Compose: cd mcp_server && docker compose -f docker/docker-compose-neo4j.yml up\n'
|
||||
f' - Or install Neo4j Desktop from: https://neo4j.com/download/\n'
|
||||
f' - Or run Neo4j manually: docker run -p 7474:7474 -p 7687:7687 neo4j:latest\n\n'
|
||||
f'{"=" * 70}\n'
|
||||
) from db_error
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"\n{'='*70}\n"
|
||||
f"Database Connection Error: {db_provider} is not running\n"
|
||||
f"{'='*70}\n\n"
|
||||
f"{db_provider} at {db_config.get('uri', 'unknown')} is not accessible.\n\n"
|
||||
f"Please ensure {db_provider} is running and accessible.\n\n"
|
||||
f"{'='*70}\n"
|
||||
f'\n{"=" * 70}\n'
|
||||
f'Database Connection Error: {db_provider} is not running\n'
|
||||
f'{"=" * 70}\n\n'
|
||||
f'{db_provider} at {db_config.get("uri", "unknown")} is not accessible.\n\n'
|
||||
f'Please ensure {db_provider} is running and accessible.\n\n'
|
||||
f'{"=" * 70}\n'
|
||||
) from db_error
|
||||
# Re-raise other errors
|
||||
raise
|
||||
|
|
@ -344,20 +344,20 @@ async def add_memory(
|
|||
- 'json': For structured data
|
||||
- 'message': For conversation-style content
|
||||
source_description (str, optional): Description of the source
|
||||
uuid (str, optional): Optional UUID for the episode
|
||||
uuid (str, optional): NEVER provide a UUID for new episodes - UUIDs are auto-generated. This parameter
|
||||
can ONLY be used for updating an existing episode by providing its existing UUID.
|
||||
Providing a UUID will update/replace the episode with that UUID if it exists.
|
||||
|
||||
Examples:
|
||||
# Adding plain text content
|
||||
# Adding plain text content (NEW episode - no UUID)
|
||||
add_memory(
|
||||
name="Company News",
|
||||
episode_body="Acme Corp announced a new product line today.",
|
||||
source="text",
|
||||
source_description="news article",
|
||||
group_id="some_arbitrary_string"
|
||||
source_description="news article"
|
||||
)
|
||||
|
||||
# Adding structured JSON data
|
||||
# NOTE: episode_body should be a JSON string (standard JSON escaping)
|
||||
# Adding structured JSON data (NEW episode - no UUID)
|
||||
add_memory(
|
||||
name="Customer Profile",
|
||||
episode_body='{"company": {"name": "Acme Technologies"}, "products": [{"id": "P001", "name": "CloudSync"}, {"id": "P002", "name": "DataMiner"}]}',
|
||||
|
|
@ -484,6 +484,106 @@ async def search_nodes(
|
|||
return ErrorResponse(error=f'Error searching nodes: {error_msg}')
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_entities_by_type(
|
||||
entity_types: list[str],
|
||||
group_ids: list[str] | None = None,
|
||||
max_entities: int = 20,
|
||||
query: str | None = None,
|
||||
) -> NodeSearchResponse | ErrorResponse:
|
||||
"""Retrieve entities by their type classification.
|
||||
|
||||
Useful for browsing entities by type (e.g., Pattern, Insight, Preference)
|
||||
in personal knowledge management workflows.
|
||||
|
||||
Args:
|
||||
entity_types: List of entity type names to retrieve (e.g., ["Pattern", "Insight"])
|
||||
group_ids: Optional list of group IDs to filter results
|
||||
max_entities: Maximum number of entities to return (default: 20)
|
||||
query: Optional search query to filter entities
|
||||
|
||||
Examples:
|
||||
# Get all preferences
|
||||
get_entities_by_type(entity_types=["Preference"])
|
||||
|
||||
# Get insights and patterns related to productivity
|
||||
get_entities_by_type(
|
||||
entity_types=["Insight", "Pattern"],
|
||||
query="productivity"
|
||||
)
|
||||
"""
|
||||
global graphiti_service
|
||||
|
||||
if graphiti_service is None:
|
||||
return ErrorResponse(error='Graphiti service not initialized')
|
||||
|
||||
try:
|
||||
# Validate entity_types parameter
|
||||
if not entity_types or len(entity_types) == 0:
|
||||
return ErrorResponse(error='entity_types cannot be empty')
|
||||
|
||||
client = await graphiti_service.get_client()
|
||||
|
||||
# Use the provided group_ids or fall back to the default from config
|
||||
effective_group_ids = (
|
||||
group_ids
|
||||
if group_ids is not None
|
||||
else [config.graphiti.group_id]
|
||||
if config.graphiti.group_id
|
||||
else []
|
||||
)
|
||||
|
||||
# Create search filters with entity type labels
|
||||
search_filters = SearchFilters(node_labels=entity_types)
|
||||
|
||||
# Use the search_ method with node search config
|
||||
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
|
||||
|
||||
# Use query if provided, otherwise use a generic query to get all of the type
|
||||
search_query = query if query else ' '
|
||||
|
||||
results = await client.search_(
|
||||
query=search_query,
|
||||
config=NODE_HYBRID_SEARCH_RRF,
|
||||
group_ids=effective_group_ids,
|
||||
search_filter=search_filters,
|
||||
)
|
||||
|
||||
# Extract nodes from results
|
||||
nodes = results.nodes[:max_entities] if results.nodes else []
|
||||
|
||||
if not nodes:
|
||||
return NodeSearchResponse(
|
||||
message=f'No entities found with types: {", ".join(entity_types)}', nodes=[]
|
||||
)
|
||||
|
||||
# Format the results (same as search_nodes)
|
||||
node_results = []
|
||||
for node in nodes:
|
||||
# Get attributes and ensure no embeddings are included
|
||||
attrs = node.attributes if hasattr(node, 'attributes') else {}
|
||||
# Remove any embedding keys that might be in attributes
|
||||
attrs = {k: v for k, v in attrs.items() if 'embedding' not in k.lower()}
|
||||
|
||||
node_results.append(
|
||||
NodeResult(
|
||||
uuid=node.uuid,
|
||||
name=node.name,
|
||||
labels=node.labels if node.labels else [],
|
||||
created_at=node.created_at.isoformat() if node.created_at else None,
|
||||
summary=node.summary,
|
||||
group_id=node.group_id,
|
||||
attributes=attrs,
|
||||
)
|
||||
)
|
||||
|
||||
return NodeSearchResponse(message=f'Found {len(node_results)} entities', nodes=node_results)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error getting entities by type: {error_msg}')
|
||||
return ErrorResponse(error=f'Error getting entities by type: {error_msg}')
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def search_memory_facts(
|
||||
query: str,
|
||||
|
|
@ -538,6 +638,197 @@ async def search_memory_facts(
|
|||
return ErrorResponse(error=f'Error searching facts: {error_msg}')
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def compare_facts_over_time(
|
||||
query: str,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
group_ids: list[str] | None = None,
|
||||
max_facts_per_period: int = 10,
|
||||
) -> dict[str, Any] | ErrorResponse:
|
||||
"""Compare facts between two time periods.
|
||||
|
||||
Track how understanding evolved by comparing facts valid at different times.
|
||||
Returns facts at start, facts at end, facts invalidated, and facts added.
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
start_time: Start timestamp in ISO 8601 format (e.g., "2024-01-01" or "2024-01-01T10:30:00Z")
|
||||
end_time: End timestamp in ISO 8601 format
|
||||
group_ids: Optional list of group IDs to filter results
|
||||
max_facts_per_period: Maximum number of facts to return per period (default: 10)
|
||||
|
||||
Examples:
|
||||
# Track how understanding evolved
|
||||
compare_facts_over_time(
|
||||
query="productivity patterns",
|
||||
start_time="2024-01-01",
|
||||
end_time="2024-03-01"
|
||||
)
|
||||
"""
|
||||
global graphiti_service
|
||||
|
||||
if graphiti_service is None:
|
||||
return ErrorResponse(error='Graphiti service not initialized')
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
if not query or not query.strip():
|
||||
return ErrorResponse(error='query cannot be empty')
|
||||
if not start_time or not start_time.strip():
|
||||
return ErrorResponse(error='start_time cannot be empty')
|
||||
if not end_time or not end_time.strip():
|
||||
return ErrorResponse(error='end_time cannot be empty')
|
||||
if max_facts_per_period <= 0:
|
||||
return ErrorResponse(error='max_facts_per_period must be a positive integer')
|
||||
|
||||
# Parse timestamps
|
||||
from datetime import datetime
|
||||
|
||||
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
|
||||
from graphiti_core.search.search_filters import ComparisonOperator, DateFilter
|
||||
|
||||
try:
|
||||
start_dt = datetime.fromisoformat(start_time.replace('Z', '+00:00'))
|
||||
end_dt = datetime.fromisoformat(end_time.replace('Z', '+00:00'))
|
||||
except ValueError as e:
|
||||
return ErrorResponse(
|
||||
error=f'Invalid timestamp format: {e}. Use ISO 8601 (e.g., "2024-03-15T10:30:00Z" or "2024-03-15")'
|
||||
)
|
||||
|
||||
if start_dt >= end_dt:
|
||||
return ErrorResponse(error='start_time must be before end_time')
|
||||
|
||||
client = await graphiti_service.get_client()
|
||||
|
||||
# Use the provided group_ids or fall back to the default from config
|
||||
effective_group_ids = (
|
||||
group_ids
|
||||
if group_ids is not None
|
||||
else [config.graphiti.group_id]
|
||||
if config.graphiti.group_id
|
||||
else []
|
||||
)
|
||||
|
||||
# Query 1: Facts valid at start_time
|
||||
# valid_at <= start_time AND (invalid_at > start_time OR invalid_at IS NULL)
|
||||
start_filters = SearchFilters(
|
||||
valid_at=[
|
||||
[DateFilter(date=start_dt, comparison_operator=ComparisonOperator.less_than_equal)]
|
||||
],
|
||||
invalid_at=[
|
||||
[DateFilter(date=start_dt, comparison_operator=ComparisonOperator.greater_than)],
|
||||
[DateFilter(date=None, comparison_operator=ComparisonOperator.is_null)],
|
||||
],
|
||||
)
|
||||
|
||||
start_results = await client.search_(
|
||||
query=query,
|
||||
config=EDGE_HYBRID_SEARCH_RRF,
|
||||
group_ids=effective_group_ids,
|
||||
search_filter=start_filters,
|
||||
)
|
||||
|
||||
# Query 2: Facts valid at end_time
|
||||
end_filters = SearchFilters(
|
||||
valid_at=[
|
||||
[DateFilter(date=end_dt, comparison_operator=ComparisonOperator.less_than_equal)]
|
||||
],
|
||||
invalid_at=[
|
||||
[DateFilter(date=end_dt, comparison_operator=ComparisonOperator.greater_than)],
|
||||
[DateFilter(date=None, comparison_operator=ComparisonOperator.is_null)],
|
||||
],
|
||||
)
|
||||
|
||||
end_results = await client.search_(
|
||||
query=query,
|
||||
config=EDGE_HYBRID_SEARCH_RRF,
|
||||
group_ids=effective_group_ids,
|
||||
search_filter=end_filters,
|
||||
)
|
||||
|
||||
# Query 3: Facts invalidated between start and end
|
||||
# invalid_at > start_time AND invalid_at <= end_time
|
||||
invalidated_filters = SearchFilters(
|
||||
invalid_at=[
|
||||
[
|
||||
DateFilter(date=start_dt, comparison_operator=ComparisonOperator.greater_than),
|
||||
DateFilter(date=end_dt, comparison_operator=ComparisonOperator.less_than_equal),
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
invalidated_results = await client.search_(
|
||||
query=query,
|
||||
config=EDGE_HYBRID_SEARCH_RRF,
|
||||
group_ids=effective_group_ids,
|
||||
search_filter=invalidated_filters,
|
||||
)
|
||||
|
||||
# Query 4: Facts added between start and end
|
||||
# created_at > start_time AND created_at <= end_time
|
||||
added_filters = SearchFilters(
|
||||
created_at=[
|
||||
[
|
||||
DateFilter(date=start_dt, comparison_operator=ComparisonOperator.greater_than),
|
||||
DateFilter(date=end_dt, comparison_operator=ComparisonOperator.less_than_equal),
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
added_results = await client.search_(
|
||||
query=query,
|
||||
config=EDGE_HYBRID_SEARCH_RRF,
|
||||
group_ids=effective_group_ids,
|
||||
search_filter=added_filters,
|
||||
)
|
||||
|
||||
# Format results
|
||||
facts_from_start = [
|
||||
format_fact_result(edge)
|
||||
for edge in (start_results.edges[:max_facts_per_period] if start_results.edges else [])
|
||||
]
|
||||
|
||||
facts_at_end = [
|
||||
format_fact_result(edge)
|
||||
for edge in (end_results.edges[:max_facts_per_period] if end_results.edges else [])
|
||||
]
|
||||
|
||||
facts_invalidated = [
|
||||
format_fact_result(edge)
|
||||
for edge in (
|
||||
invalidated_results.edges[:max_facts_per_period]
|
||||
if invalidated_results.edges
|
||||
else []
|
||||
)
|
||||
]
|
||||
|
||||
facts_added = [
|
||||
format_fact_result(edge)
|
||||
for edge in (added_results.edges[:max_facts_per_period] if added_results.edges else [])
|
||||
]
|
||||
|
||||
return {
|
||||
'message': f'Comparison completed between {start_time} and {end_time}',
|
||||
'start_time': start_time,
|
||||
'end_time': end_time,
|
||||
'summary': {
|
||||
'facts_at_start_count': len(facts_from_start),
|
||||
'facts_at_end_count': len(facts_at_end),
|
||||
'facts_invalidated_count': len(facts_invalidated),
|
||||
'facts_added_count': len(facts_added),
|
||||
},
|
||||
'facts_from_start': facts_from_start,
|
||||
'facts_at_end': facts_at_end,
|
||||
'facts_invalidated': facts_invalidated,
|
||||
'facts_added': facts_added,
|
||||
}
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error comparing facts over time: {error_msg}')
|
||||
return ErrorResponse(error=f'Error comparing facts over time: {error_msg}')
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse:
|
||||
"""Delete an entity edge from the graph memory.
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue