diff --git a/mcp_server/src/graphiti_mcp_server.py b/mcp_server/src/graphiti_mcp_server.py index de375384..617507fc 100644 --- a/mcp_server/src/graphiti_mcp_server.py +++ b/mcp_server/src/graphiti_mcp_server.py @@ -8,7 +8,6 @@ import asyncio import logging import os import sys -from datetime import datetime from pathlib import Path from typing import Any, Optional @@ -16,9 +15,6 @@ from dotenv import load_dotenv from graphiti_core import Graphiti from graphiti_core.edges import EntityEdge from graphiti_core.nodes import EpisodeType, EpisodicNode -from graphiti_core.search.search_config_recipes import ( - NODE_HYBRID_SEARCH_RRF, -) from graphiti_core.search.search_filters import SearchFilters from graphiti_core.utils.maintenance.graph_data_operations import clear_data from mcp.server.fastmcp import FastMCP @@ -279,7 +275,7 @@ async def add_memory( source_description=source_description, episode_type=episode_type, entity_types=graphiti_service.entity_types, - uuid=uuid, + uuid=uuid or None, # Ensure None is passed if uuid is None ) return SuccessResponse( @@ -325,18 +321,22 @@ async def search_nodes( # Create search filters search_filters = SearchFilters( - group_ids=effective_group_ids, node_labels=entity_types, ) - # Perform the search - nodes = await client.search_nodes( + # Use the search_ method with node search config + from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF + + results = await client.search_( query=query, - limit=max_nodes, - search_config=NODE_HYBRID_SEARCH_RRF, - search_filters=search_filters, + config=NODE_HYBRID_SEARCH_RRF, + group_ids=effective_group_ids, + search_filter=search_filters, ) + # Extract nodes from results + nodes = results.nodes[:max_nodes] if results.nodes else [] + if not nodes: return NodeSearchResponse(message='No relevant nodes found', nodes=[]) @@ -493,21 +493,15 @@ async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse: @mcp.tool() -async def search_episodes( - query: str | None = None, +async def get_episodes( group_ids: list[str] | None = None, max_episodes: int = 10, - start_date: str | None = None, - end_date: str | None = None, ) -> EpisodeSearchResponse | ErrorResponse: - """Search for episodes in the graph memory. + """Get episodes from the graph memory. Args: - query: Optional search query for semantic search group_ids: Optional list of group IDs to filter results max_episodes: Maximum number of episodes to return (default: 10) - start_date: Optional start date (ISO format) to filter episodes - end_date: Optional end date (ISO format) to filter episodes """ global graphiti_service @@ -526,18 +520,17 @@ async def search_episodes( else [] ) - # Convert date strings to datetime objects if provided - start_dt = datetime.fromisoformat(start_date) if start_date else None - end_dt = datetime.fromisoformat(end_date) if end_date else None + # Get episodes from the driver directly + from graphiti_core.nodes import EpisodicNode - # Search for episodes - episodes = await client.search_episodes( - query=query, - group_ids=effective_group_ids, - limit=max_episodes, - start_date=start_dt, - end_date=end_dt, - ) + if effective_group_ids: + episodes = await EpisodicNode.get_by_group_ids( + client.driver, effective_group_ids, limit=max_episodes + ) + else: + # If no group IDs, we need to use a different approach + # For now, return empty list when no group IDs specified + episodes = [] if not episodes: return EpisodeSearchResponse(message='No episodes found', episodes=[]) @@ -550,7 +543,9 @@ async def search_episodes( 'name': episode.name, 'content': episode.content, 'created_at': episode.created_at.isoformat() if episode.created_at else None, - 'source': episode.source, + 'source': episode.source.value + if hasattr(episode.source, 'value') + else str(episode.source), 'source_description': episode.source_description, 'group_id': episode.group_id, } @@ -561,8 +556,8 @@ async def search_episodes( ) except Exception as e: error_msg = str(e) - logger.error(f'Error searching episodes: {error_msg}') - return ErrorResponse(error=f'Error searching episodes: {error_msg}') + logger.error(f'Error getting episodes: {error_msg}') + return ErrorResponse(error=f'Error getting episodes: {error_msg}') @mcp.tool() diff --git a/mcp_server/src/models/response_types.py b/mcp_server/src/models/response_types.py index ac9a9844..81032afc 100644 --- a/mcp_server/src/models/response_types.py +++ b/mcp_server/src/models/response_types.py @@ -14,11 +14,9 @@ class SuccessResponse(TypedDict): class NodeResult(TypedDict): uuid: str name: str - summary: str - labels: list[str] - group_id: str - created_at: str - attributes: dict[str, Any] + type: str + created_at: str | None + summary: str | None class NodeSearchResponse(TypedDict): diff --git a/mcp_server/src/services/factories.py b/mcp_server/src/services/factories.py index 020d10ae..52131734 100644 --- a/mcp_server/src/services/factories.py +++ b/mcp_server/src/services/factories.py @@ -1,5 +1,7 @@ """Factory classes for creating LLM, Embedder, and Database clients.""" +from openai import AsyncAzureOpenAI + from config.schema import ( DatabaseConfig, EmbedderConfig, @@ -8,17 +10,18 @@ from config.schema import ( # Try to import FalkorDriver if available try: - from graphiti_core.driver import FalkorDriver # noqa: F401 + from graphiti_core.driver.falkordb_driver import FalkorDriver # noqa: F401 HAS_FALKOR = True except ImportError: HAS_FALKOR = False from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder from graphiti_core.llm_client import LLMClient, OpenAIClient +from graphiti_core.llm_client.config import LLMConfig as GraphitiLLMConfig # Try to import additional providers if available try: - from graphiti_core.embedder import AzureOpenAIEmbedderClient + from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient HAS_AZURE_EMBEDDER = True except ImportError: @@ -32,14 +35,14 @@ except ImportError: HAS_GEMINI_EMBEDDER = False try: - from graphiti_core.embedder.voyage import VoyageEmbedder + from graphiti_core.embedder.voyage import VoyageAIEmbedder HAS_VOYAGE_EMBEDDER = True except ImportError: HAS_VOYAGE_EMBEDDER = False try: - from graphiti_core.llm_client import AzureOpenAILLMClient + from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient HAS_AZURE_LLM = True except ImportError: @@ -108,17 +111,32 @@ class LLMClientFactory: else: api_key = azure_config.api_key - return AzureOpenAILLMClient( + # Create the Azure OpenAI client first + azure_client = AsyncAzureOpenAI( api_key=api_key, - api_url=azure_config.api_url, + azure_endpoint=azure_config.api_url, api_version=azure_config.api_version, azure_deployment=azure_config.deployment_name, azure_ad_token_provider=azure_ad_token_provider, + ) + + # Then create the LLMConfig + from graphiti_core.llm_client.config import LLMConfig as CoreLLMConfig + + llm_config = CoreLLMConfig( + api_key=api_key, + base_url=azure_config.api_url, model=config.model, temperature=config.temperature, max_tokens=config.max_tokens, ) + return AzureOpenAILLMClient( + azure_client=azure_client, + config=llm_config, + max_tokens=config.max_tokens, + ) + case 'anthropic': if not HAS_ANTHROPIC: raise ValueError( @@ -126,37 +144,40 @@ class LLMClientFactory: ) if not config.providers.anthropic: raise ValueError('Anthropic provider configuration not found') - return AnthropicClient( + llm_config = GraphitiLLMConfig( api_key=config.providers.anthropic.api_key, model=config.model, temperature=config.temperature, max_tokens=config.max_tokens, ) + return AnthropicClient(config=llm_config) case 'gemini': if not HAS_GEMINI: raise ValueError('Gemini client not available in current graphiti-core version') if not config.providers.gemini: raise ValueError('Gemini provider configuration not found') - return GeminiClient( + llm_config = GraphitiLLMConfig( api_key=config.providers.gemini.api_key, model=config.model, temperature=config.temperature, max_tokens=config.max_tokens, ) + return GeminiClient(config=llm_config) case 'groq': if not HAS_GROQ: raise ValueError('Groq client not available in current graphiti-core version') if not config.providers.groq: raise ValueError('Groq provider configuration not found') - return GroqClient( + llm_config = GraphitiLLMConfig( api_key=config.providers.groq.api_key, - api_url=config.providers.groq.api_url, + base_url=config.providers.groq.api_url, model=config.model, temperature=config.temperature, max_tokens=config.max_tokens, ) + return GroqClient(config=llm_config) case _: raise ValueError(f'Unsupported LLM provider: {provider}') @@ -201,14 +222,18 @@ class EmbedderFactory: else: api_key = azure_config.api_key - return AzureOpenAIEmbedderClient( + # Create the Azure OpenAI client first + azure_client = AsyncAzureOpenAI( api_key=api_key, - api_url=azure_config.api_url, + azure_endpoint=azure_config.api_url, api_version=azure_config.api_version, azure_deployment=azure_config.deployment_name, azure_ad_token_provider=azure_ad_token_provider, - model=config.model, - dimensions=config.dimensions, + ) + + return AzureOpenAIEmbedderClient( + azure_client=azure_client, + model=config.model or 'text-embedding-3-small', ) case 'gemini': @@ -218,11 +243,14 @@ class EmbedderFactory: ) if not config.providers.gemini: raise ValueError('Gemini provider configuration not found') - return GeminiEmbedder( + from graphiti_core.embedder.gemini import GeminiEmbedderConfig + + gemini_config = GeminiEmbedderConfig( api_key=config.providers.gemini.api_key, - model=config.model, - dimensions=config.dimensions, + embedding_model=config.model or 'models/text-embedding-004', + embedding_dim=config.dimensions or 768, ) + return GeminiEmbedder(config=gemini_config) case 'voyage': if not HAS_VOYAGE_EMBEDDER: @@ -231,10 +259,14 @@ class EmbedderFactory: ) if not config.providers.voyage: raise ValueError('Voyage provider configuration not found') - return VoyageEmbedder( + from graphiti_core.embedder.voyage import VoyageAIEmbedderConfig + + voyage_config = VoyageAIEmbedderConfig( api_key=config.providers.voyage.api_key, - model=config.providers.voyage.model, + embedding_model=config.model or 'voyage-3', + embedding_dim=config.dimensions or 1024, ) + return VoyageAIEmbedder(config=voyage_config) case _: raise ValueError(f'Unsupported Embedder provider: {provider}') diff --git a/mcp_server/src/services/queue_service.py b/mcp_server/src/services/queue_service.py index 17621208..93a0ba8e 100644 --- a/mcp_server/src/services/queue_service.py +++ b/mcp_server/src/services/queue_service.py @@ -105,7 +105,7 @@ class QueueService: source_description: str, episode_type: Any, entity_types: Any, - uuid: str, + uuid: str | None, ) -> int: """Add an episode for processing.