Merge a0693dd0bf into 7c38ce7830
Commit a7430b2187
3 changed files with 343 additions and 35 deletions
@@ -100,6 +100,25 @@ The server uses the following environment variables:
- `AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME`: Optional Azure OpenAI embedding deployment name
- `AZURE_OPENAI_EMBEDDING_API_VERSION`: Optional Azure OpenAI API version
- `AZURE_OPENAI_USE_MANAGED_IDENTITY`: Optional; use Azure Managed Identities for authentication
+
+### Gemini Configuration
+
+The server also supports Google Gemini models for LLM, embedding, and cross-encoding operations:
+
+- `GOOGLE_API_KEY` or `GEMINI_API_KEY`: Google API key for Gemini models
+- `USE_GEMINI`: Set to "true" to use Gemini models (automatically enabled when a Gemini API key is provided)
+- `MODEL_NAME`: When using Gemini, defaults to "gemini-2.5-flash"
+- `SMALL_MODEL_NAME`: When using Gemini, defaults to "models/gemini-2.5-flash-lite-preview-06-17"
+- `EMBEDDER_MODEL_NAME`: When using Gemini, defaults to "embedding-001"
+
+When Gemini is detected (via an API key or `USE_GEMINI=true`), the server automatically uses it for the following (see the sketch after this list):
+
+- LLM operations (text generation, entity extraction)
+- Embedding operations (text embeddings)
+- Cross-encoding operations (passage reranking)
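
Concretely, detection reduces to the following check (a sketch condensed from the server's configuration code):

```python
import os

# A Google/Gemini API key enables Gemini implicitly; USE_GEMINI=true forces it.
gemini_api_key = os.environ.get('GOOGLE_API_KEY') or os.environ.get('GEMINI_API_KEY')
use_gemini = (
    gemini_api_key is not None
    or os.environ.get('USE_GEMINI', 'false').lower() == 'true'
)
```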
+
+### Other Configuration
+
- `SEMAPHORE_LIMIT`: Episode processing concurrency. See [Concurrency and LLM Provider 429 Rate Limit Errors](#concurrency-and-llm-provider-429-rate-limit-errors)

You can set these variables in a `.env` file in the project directory.
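
For example, a minimal `.env` for a Gemini-backed setup might look like this (illustrative placeholder values, mirroring the client configuration shown later):

```bash
NEO4J_URI=bolt://localhost:7687
NEO4J_USER=neo4j
NEO4J_PASSWORD=password
GOOGLE_API_KEY=your-google-api-key
USE_GEMINI=true
```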
@@ -118,11 +137,25 @@ With options:
uv run graphiti_mcp_server.py --model gpt-4.1-mini --transport sse
```

+Using Gemini:
+
+```bash
+uv run graphiti_mcp_server.py --use-gemini --transport sse
+```
+
+Or with a specific Gemini API key:
+
+```bash
+uv run graphiti_mcp_server.py --gemini-api-key your-api-key --transport sse
+```
+
Available arguments (a combined example follows the list):

- `--model`: Overrides the `MODEL_NAME` environment variable.
- `--small-model`: Overrides the `SMALL_MODEL_NAME` environment variable.
- `--temperature`: Overrides the `LLM_TEMPERATURE` environment variable.
+- `--use-gemini`: Use Gemini models instead of OpenAI (requires the `GOOGLE_API_KEY` or `GEMINI_API_KEY` environment variable).
+- `--gemini-api-key`: Google API key for Gemini models (can also be supplied via the `GOOGLE_API_KEY` or `GEMINI_API_KEY` environment variable).
- `--transport`: Choose the transport method (sse or stdio; default: sse).
- `--group-id`: Set a namespace for the graph (optional). If not provided, defaults to "default".
- `--destroy-graph`: If set, destroys all Graphiti graphs on startup.
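
For example, several of these flags can be combined in one invocation (illustrative):

```bash
uv run graphiti_mcp_server.py --use-gemini --group-id my-project --temperature 0.0 --transport stdio
```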
@@ -239,6 +272,37 @@ To use the Graphiti MCP server with an MCP-compatible client, configure it to co
}
```

+For Gemini models, use this configuration:
+
+```json
+{
+  "mcpServers": {
+    "graphiti-memory": {
+      "transport": "stdio",
+      "command": "/Users/<user>/.local/bin/uv",
+      "args": [
+        "run",
+        "--isolated",
+        "--directory",
+        "/Users/<user>/dev/zep/graphiti/mcp_server",
+        "--project",
+        ".",
+        "graphiti_mcp_server.py",
+        "--transport",
+        "stdio"
+      ],
+      "env": {
+        "NEO4J_URI": "bolt://localhost:7687",
+        "NEO4J_USER": "neo4j",
+        "NEO4J_PASSWORD": "password",
+        "GOOGLE_API_KEY": "your-google-api-key",
+        "USE_GEMINI": "true"
+      }
+    }
+  }
+}
+```
+
For SSE transport (HTTP-based), you can use this configuration:

```json
@@ -19,13 +19,17 @@ from openai import AsyncAzureOpenAI
from pydantic import BaseModel, Field

from graphiti_core import Graphiti
+from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.edges import EntityEdge
from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
from graphiti_core.embedder.client import EmbedderClient
+from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.config import LLMConfig
+from graphiti_core.llm_client.gemini_client import GeminiClient
from graphiti_core.llm_client.openai_client import OpenAIClient
from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.search.search_config_recipes import (
@@ -42,6 +46,11 @@ DEFAULT_LLM_MODEL = 'gpt-4.1-mini'
SMALL_LLM_MODEL = 'gpt-4.1-nano'
DEFAULT_EMBEDDER_MODEL = 'text-embedding-3-small'

+# Gemini model defaults
+DEFAULT_GEMINI_LLM_MODEL = 'gemini-2.5-flash'
+DEFAULT_GEMINI_SMALL_MODEL = 'models/gemini-2.5-flash-lite-preview-06-17'
+DEFAULT_GEMINI_EMBEDDER_MODEL = 'embedding-001'
+
# Semaphore limit for concurrent Graphiti operations.
# Decrease this if you're experiencing 429 rate limit errors from your LLM provider.
# Increase if you have high rate limits.
@@ -200,39 +209,65 @@ class GraphitiLLMConfig(BaseModel):
    azure_openai_deployment_name: str | None = None
    azure_openai_api_version: str | None = None
    azure_openai_use_managed_identity: bool = False
+    # Gemini-specific fields
+    gemini_api_key: str | None = None
+    use_gemini: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiLLMConfig':
        """Create LLM configuration from environment variables."""
-        # Get model from environment, or use default if not set or empty
+        # Check if Gemini should be used
+        gemini_api_key = os.environ.get('GOOGLE_API_KEY', None) or os.environ.get(
+            'GEMINI_API_KEY', None
+        )
+        use_gemini = (
+            gemini_api_key is not None
+            or os.environ.get('USE_GEMINI', 'false').lower() == 'true'
+        )
+
+        # Get model from environment, with appropriate defaults based on provider
+        if use_gemini:
+            default_model = DEFAULT_GEMINI_LLM_MODEL
+            default_small_model = DEFAULT_GEMINI_SMALL_MODEL
+        else:
+            default_model = DEFAULT_LLM_MODEL
+            default_small_model = SMALL_LLM_MODEL
+
        model_env = os.environ.get('MODEL_NAME', '')
-        model = model_env if model_env.strip() else DEFAULT_LLM_MODEL
+        model = model_env if model_env.strip() else default_model

        # Get small_model from environment, or use default if not set or empty
        small_model_env = os.environ.get('SMALL_MODEL_NAME', '')
-        small_model = small_model_env if small_model_env.strip() else SMALL_LLM_MODEL
+        small_model = (
+            small_model_env if small_model_env.strip() else default_small_model
+        )

        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT', None)
        azure_openai_api_version = os.environ.get('AZURE_OPENAI_API_VERSION', None)
-        azure_openai_deployment_name = os.environ.get('AZURE_OPENAI_DEPLOYMENT_NAME', None)
+        azure_openai_deployment_name = os.environ.get(
+            'AZURE_OPENAI_DEPLOYMENT_NAME', None
+        )
        azure_openai_use_managed_identity = (
-            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
+            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower()
+            == 'true'
        )

        if azure_openai_endpoint is None:
-            # Setup for OpenAI API
+            # Setup for OpenAI or Gemini API
            # Log if empty model was provided
            if model_env == '':
                logger.debug(
-                    f'MODEL_NAME environment variable not set, using default: {DEFAULT_LLM_MODEL}'
+                    f'MODEL_NAME environment variable not set, using default: {model}'
                )
            elif not model_env.strip():
                logger.warning(
-                    f'Empty MODEL_NAME environment variable, using default: {DEFAULT_LLM_MODEL}'
+                    f'Empty MODEL_NAME environment variable, using default: {model}'
                )

            return cls(
                api_key=os.environ.get('OPENAI_API_KEY'),
+                gemini_api_key=gemini_api_key,
+                use_gemini=use_gemini,
                model=model,
                small_model=small_model,
                temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')),
@@ -241,9 +276,13 @@ class GraphitiLLMConfig(BaseModel):
            # Setup for Azure OpenAI API
            # Log if empty deployment name was provided
            if azure_openai_deployment_name is None:
-                logger.error('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')
+                logger.error(
+                    'AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set'
+                )

-                raise ValueError('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')
+                raise ValueError(
+                    'AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set'
+                )
            if not azure_openai_use_managed_identity:
                # api key
                api_key = os.environ.get('OPENAI_API_KEY', None)
@@ -260,6 +299,8 @@ class GraphitiLLMConfig(BaseModel):
                model=model,
                small_model=small_model,
                temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')),
+                gemini_api_key=gemini_api_key,
+                use_gemini=use_gemini,
            )

    @classmethod
@@ -268,6 +309,19 @@ class GraphitiLLMConfig(BaseModel):
        # Start with environment-based config
        config = cls.from_env()

+        # Handle Gemini CLI arguments first, as they may affect model defaults
+        if hasattr(args, 'use_gemini') and args.use_gemini:
+            config.use_gemini = True
+            # If switching to Gemini and no model was explicitly set, use Gemini defaults
+            if hasattr(args, 'model') and not args.model:
+                config.model = DEFAULT_GEMINI_LLM_MODEL
+            if hasattr(args, 'small_model') and not args.small_model:
+                config.small_model = DEFAULT_GEMINI_SMALL_MODEL
+
+        if hasattr(args, 'gemini_api_key') and args.gemini_api_key:
+            config.gemini_api_key = args.gemini_api_key
+            config.use_gemini = True
+
        # CLI arguments override environment variables when provided
        if hasattr(args, 'model') and args.model:
            # Only use CLI model if it's not empty
@@ -275,13 +329,23 @@ class GraphitiLLMConfig(BaseModel):
                config.model = args.model
            else:
                # Log that empty model was provided and default is used
-                logger.warning(f'Empty model name provided, using default: {DEFAULT_LLM_MODEL}')
+                default_model = (
+                    DEFAULT_GEMINI_LLM_MODEL if config.use_gemini else DEFAULT_LLM_MODEL
+                )
+                logger.warning(
+                    f'Empty model name provided, using default: {default_model}'
+                )

        if hasattr(args, 'small_model') and args.small_model:
            if args.small_model.strip():
                config.small_model = args.small_model
            else:
-                logger.warning(f'Empty small_model name provided, using default: {SMALL_LLM_MODEL}')
+                default_small_model = (
+                    DEFAULT_GEMINI_SMALL_MODEL if config.use_gemini else SMALL_LLM_MODEL
+                )
+                logger.warning(
+                    f'Empty small_model name provided, using default: {default_small_model}'
+                )

        if hasattr(args, 'temperature') and args.temperature is not None:
            config.temperature = args.temperature
@@ -331,8 +395,27 @@ class GraphitiLLMConfig(BaseModel):
                ),
            )
        else:
-            raise ValueError('OPENAI_API_KEY must be set when using Azure OpenAI API')
+            raise ValueError(
+                'OPENAI_API_KEY must be set when using Azure OpenAI API'
+            )

+        # Check if Gemini should be used
+        if self.use_gemini:
+            if not self.gemini_api_key:
+                raise ValueError(
+                    'GOOGLE_API_KEY or GEMINI_API_KEY must be set when using Gemini'
+                )
+
+            llm_client_config = LLMConfig(
+                api_key=self.gemini_api_key,
+                model=self.model,
+                small_model=self.small_model,
+                temperature=self.temperature,
+            )
+
+            return GeminiClient(config=llm_client_config)
+
+        # Default to OpenAI
        if not self.api_key:
            raise ValueError('OPENAI_API_KEY must be set when using OpenAI API')
@@ -358,22 +441,42 @@ class GraphitiEmbedderConfig(BaseModel):
    azure_openai_deployment_name: str | None = None
    azure_openai_api_version: str | None = None
    azure_openai_use_managed_identity: bool = False
+    # Gemini-specific fields
+    gemini_api_key: str | None = None
+    use_gemini: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiEmbedderConfig':
        """Create embedder configuration from environment variables."""

-        # Get model from environment, or use default if not set or empty
+        # Check if Gemini should be used for embeddings
+        gemini_api_key = os.environ.get('GOOGLE_API_KEY', None) or os.environ.get(
+            'GEMINI_API_KEY', None
+        )
+        use_gemini = (
+            gemini_api_key is not None
+            or os.environ.get('USE_GEMINI', 'false').lower() == 'true'
+        )
+
+        # Get model from environment, with appropriate defaults based on provider
+        if use_gemini:
+            default_model = DEFAULT_GEMINI_EMBEDDER_MODEL
+        else:
+            default_model = DEFAULT_EMBEDDER_MODEL
+
        model_env = os.environ.get('EMBEDDER_MODEL_NAME', '')
-        model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL
+        model = model_env if model_env.strip() else default_model

        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_EMBEDDING_ENDPOINT', None)
-        azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION', None)
+        azure_openai_api_version = os.environ.get(
+            'AZURE_OPENAI_EMBEDDING_API_VERSION', None
+        )
        azure_openai_deployment_name = os.environ.get(
            'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
        )
        azure_openai_use_managed_identity = (
-            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
+            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower()
+            == 'true'
        )
        if azure_openai_endpoint is not None:
            # Setup for Azure OpenAI API
@@ -382,7 +485,9 @@ class GraphitiEmbedderConfig(BaseModel):
                'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
            )
            if azure_openai_deployment_name is None:
-                logger.error('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set')
+                logger.error(
+                    'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set'
+                )

                raise ValueError(
                    'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set'
@@ -390,9 +495,9 @@ class GraphitiEmbedderConfig(BaseModel):

            if not azure_openai_use_managed_identity:
                # api key
-                api_key = os.environ.get('AZURE_OPENAI_EMBEDDING_API_KEY', None) or os.environ.get(
-                    'OPENAI_API_KEY', None
-                )
+                api_key = os.environ.get(
+                    'AZURE_OPENAI_EMBEDDING_API_KEY', None
+                ) or os.environ.get('OPENAI_API_KEY', None)
            else:
                # Managed identity
                api_key = None
@@ -403,11 +508,16 @@ class GraphitiEmbedderConfig(BaseModel):
                api_key=api_key,
                azure_openai_api_version=azure_openai_api_version,
                azure_openai_deployment_name=azure_openai_deployment_name,
                model=model,
+                gemini_api_key=gemini_api_key,
+                use_gemini=use_gemini,
            )
        else:
            return cls(
                model=model,
                api_key=os.environ.get('OPENAI_API_KEY'),
+                gemini_api_key=gemini_api_key,
+                use_gemini=use_gemini,
            )

    def create_client(self) -> EmbedderClient | None:
@@ -440,11 +550,27 @@ class GraphitiEmbedderConfig(BaseModel):
                logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API')
                return None
        else:
-            # OpenAI API setup
+            # Check if Gemini should be used
+            if self.use_gemini:
+                if not self.gemini_api_key:
+                    logger.warning(
+                        'GOOGLE_API_KEY or GEMINI_API_KEY must be set when using Gemini embeddings. Embedder will be disabled.'
+                    )
+                    return None
+
+                embedder_config = GeminiEmbedderConfig(
+                    api_key=self.gemini_api_key, embedding_model=self.model
+                )
+
+                return GeminiEmbedder(config=embedder_config)
+
+            # Default to OpenAI API setup
            if not self.api_key:
                return None

-            embedder_config = OpenAIEmbedderConfig(api_key=self.api_key, embedding_model=self.model)
+            embedder_config = OpenAIEmbedderConfig(
+                api_key=self.api_key, embedding_model=self.model
+            )

            return OpenAIEmbedder(config=embedder_config)
@@ -466,6 +592,63 @@ class Neo4jConfig(BaseModel):
    )


+class GraphitiCrossEncoderConfig(BaseModel):
+    """Configuration for the cross-encoder client.
+
+    Centralizes all cross-encoder/reranker configuration parameters.
+    """
+
+    api_key: str | None = None
+    model: str | None = None
+    # Gemini-specific fields
+    gemini_api_key: str | None = None
+    use_gemini: bool = False
+
+    @classmethod
+    def from_env(cls) -> 'GraphitiCrossEncoderConfig':
+        """Create cross-encoder configuration from environment variables."""
+
+        # Check if Gemini should be used for cross-encoding
+        gemini_api_key = os.environ.get('GOOGLE_API_KEY', None) or os.environ.get(
+            'GEMINI_API_KEY', None
+        )
+        use_gemini = (
+            gemini_api_key is not None
+            or os.environ.get('USE_GEMINI', 'false').lower() == 'true'
+        )
+
+        return cls(
+            api_key=os.environ.get('OPENAI_API_KEY'),
+            gemini_api_key=gemini_api_key,
+            use_gemini=use_gemini,
+        )
+
+    def create_client(self):
+        """Create a cross-encoder client based on this configuration.
+
+        Returns:
+            CrossEncoderClient instance or None
+        """
+        # Check if Gemini should be used
+        if self.use_gemini:
+            if not self.gemini_api_key:
+                logger.warning(
+                    'GOOGLE_API_KEY or GEMINI_API_KEY must be set when using Gemini cross-encoder. Cross-encoder will be disabled.'
+                )
+                return None
+
+            cross_encoder_config = LLMConfig(api_key=self.gemini_api_key)
+            return GeminiRerankerClient(config=cross_encoder_config)
+
+        # Default to OpenAI cross-encoder
+        if not self.api_key:
+            logger.warning('OPENAI_API_KEY not set. Cross-encoder will be disabled.')
+            return None
+
+        cross_encoder_config = LLMConfig(api_key=self.api_key)
+        return OpenAIRerankerClient(config=cross_encoder_config)
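
# Illustrative usage sketch (not part of the diff): the cross-encoder is
# selected purely from the environment:
#     reranker = GraphitiCrossEncoderConfig.from_env().create_client()
#     # -> GeminiRerankerClient, OpenAIRerankerClient, or None when no key is set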
+
+
class GraphitiConfig(BaseModel):
    """Configuration for Graphiti client.

@@ -474,6 +657,9 @@ class GraphitiConfig(BaseModel):

    llm: GraphitiLLMConfig = Field(default_factory=GraphitiLLMConfig)
    embedder: GraphitiEmbedderConfig = Field(default_factory=GraphitiEmbedderConfig)
+    cross_encoder: GraphitiCrossEncoderConfig = Field(
+        default_factory=GraphitiCrossEncoderConfig
+    )
    neo4j: Neo4jConfig = Field(default_factory=Neo4jConfig)
    group_id: str | None = None
    use_custom_entities: bool = False
@@ -485,6 +671,7 @@ class GraphitiConfig(BaseModel):
        return cls(
            llm=GraphitiLLMConfig.from_env(),
            embedder=GraphitiEmbedderConfig.from_env(),
+            cross_encoder=GraphitiCrossEncoderConfig.from_env(),
            neo4j=Neo4jConfig.from_env(),
        )

@@ -581,13 +768,19 @@ async def initialize_graphiti():
    llm_client = config.llm.create_client()
    if not llm_client and config.use_custom_entities:
        # If custom entities are enabled, we must have an LLM client
-        raise ValueError('OPENAI_API_KEY must be set when custom entities are enabled')
+        provider = (
+            'GOOGLE_API_KEY/GEMINI_API_KEY'
+            if config.llm.use_gemini
+            else 'OPENAI_API_KEY'
+        )
+        raise ValueError(f'{provider} must be set when custom entities are enabled')

    # Validate Neo4j configuration
    if not config.neo4j.uri or not config.neo4j.user or not config.neo4j.password:
        raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')

    embedder_client = config.embedder.create_client()
+    cross_encoder_client = config.cross_encoder.create_client()

    # Initialize Graphiti client
    graphiti_client = Graphiti(
@@ -596,6 +789,7 @@ async def initialize_graphiti():
        password=config.neo4j.password,
        llm_client=llm_client,
        embedder=embedder_client,
+        cross_encoder=cross_encoder_client,
        max_coroutines=SEMAPHORE_LIMIT,
    )

@@ -610,11 +804,24 @@ async def initialize_graphiti():

    # Log configuration details for transparency
    if llm_client:
-        logger.info(f'Using OpenAI model: {config.llm.model}')
+        provider = 'Gemini' if config.llm.use_gemini else 'OpenAI'
+        logger.info(f'Using {provider} LLM model: {config.llm.model}')
        logger.info(f'Using temperature: {config.llm.temperature}')
    else:
        logger.info('No LLM client configured - entity extraction will be limited')

+    if embedder_client:
+        provider = 'Gemini' if config.embedder.use_gemini else 'OpenAI'
+        logger.info(f'Using {provider} embedder model: {config.embedder.model}')
+    else:
+        logger.info('No embedder client configured')
+
+    if cross_encoder_client:
+        provider = 'Gemini' if config.cross_encoder.use_gemini else 'OpenAI'
+        logger.info(f'Using {provider} cross-encoder')
+    else:
+        logger.info('No cross-encoder configured')
+
    logger.info(f'Using group_id: {config.group_id}')
    logger.info(
        f'Custom entity extraction: {"enabled" if config.use_custom_entities else "disabled"}'
@@ -675,14 +882,18 @@ async def process_episode_queue(group_id: str):
                # Process the episode
                await process_func()
            except Exception as e:
-                logger.error(f'Error processing queued episode for group_id {group_id}: {str(e)}')
+                logger.error(
+                    f'Error processing queued episode for group_id {group_id}: {str(e)}'
+                )
            finally:
                # Mark the task as done regardless of success/failure
                episode_queues[group_id].task_done()
    except asyncio.CancelledError:
        logger.info(f'Episode queue worker for group_id {group_id} was cancelled')
    except Exception as e:
-        logger.error(f'Unexpected error in queue worker for group_id {group_id}: {str(e)}')
+        logger.error(
+            f'Unexpected error in queue worker for group_id {group_id}: {str(e)}'
+        )
    finally:
        queue_workers[group_id] = False
        logger.info(f'Stopped episode queue worker for group_id: {group_id}')
@@ -782,7 +993,9 @@ async def add_memory(
        # Define the episode processing function
        async def process_episode():
            try:
-                logger.info(f"Processing queued episode '{name}' for group_id: {group_id_str}")
+                logger.info(
+                    f"Processing queued episode '{name}' for group_id: {group_id_str}"
+                )
                # Use all entity types if use_custom_entities is enabled, otherwise use empty dict
                entity_types = ENTITY_TYPES if config.use_custom_entities else {}

@@ -854,7 +1067,11 @@ async def search_memory_nodes(
    try:
        # Use the provided group_ids or fall back to the default from config if none provided
        effective_group_ids = (
-            group_ids if group_ids is not None else [config.group_id] if config.group_id else []
+            group_ids
+            if group_ids is not None
+            else [config.group_id]
+            if config.group_id
+            else []
        )

        # Configure the search
@@ -900,7 +1117,9 @@ async def search_memory_nodes(
            for node in search_results.nodes
        ]

-        return NodeSearchResponse(message='Nodes retrieved successfully', nodes=formatted_nodes)
+        return NodeSearchResponse(
+            message='Nodes retrieved successfully', nodes=formatted_nodes
+        )
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error searching nodes: {error_msg}')
@@ -934,7 +1153,11 @@ async def search_memory_facts(

        # Use the provided group_ids or fall back to the default from config if none provided
        effective_group_ids = (
-            group_ids if group_ids is not None else [config.group_id] if config.group_id else []
+            group_ids
+            if group_ids is not None
+            else [config.group_id]
+            if config.group_id
+            else []
        )

        # We've already checked that graphiti_client is not None above
@@ -984,7 +1207,9 @@ async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse:
        entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)
        # Delete the edge using its delete method
        await entity_edge.delete(client.driver)
-        return SuccessResponse(message=f'Entity edge with UUID {uuid} deleted successfully')
+        return SuccessResponse(
+            message=f'Entity edge with UUID {uuid} deleted successfully'
+        )
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error deleting entity edge: {error_msg}')
@@ -1081,7 +1306,9 @@ async def get_episodes(
    client = cast(Graphiti, graphiti_client)

    episodes = await client.retrieve_episodes(
-        group_ids=[effective_group_id], last_n=last_n, reference_time=datetime.now(timezone.utc)
+        group_ids=[effective_group_id],
+        last_n=last_n,
+        reference_time=datetime.now(timezone.utc),
    )

    if not episodes:
@@ -1178,7 +1405,8 @@ async def initialize_server() -> MCPConfig:
        help='Transport to use for communication with the client. (default: sse)',
    )
    parser.add_argument(
-        '--model', help=f'Model name to use with the LLM client. (default: {DEFAULT_LLM_MODEL})'
+        '--model',
+        help=f'Model name to use with the LLM client. (default: {DEFAULT_LLM_MODEL})',
    )
    parser.add_argument(
        '--small-model',
@@ -1189,7 +1417,18 @@ async def initialize_server() -> MCPConfig:
        type=float,
        help='Temperature setting for the LLM (0.0-2.0). Lower values make output more deterministic. (default: 0.7)',
    )
-    parser.add_argument('--destroy-graph', action='store_true', help='Destroy all Graphiti graphs')
+    parser.add_argument(
+        '--use-gemini',
+        action='store_true',
+        help='Use Gemini models instead of OpenAI (requires GOOGLE_API_KEY or GEMINI_API_KEY environment variable)',
+    )
+    parser.add_argument(
+        '--gemini-api-key',
+        help='Google API key for Gemini models (can also use GOOGLE_API_KEY or GEMINI_API_KEY environment variable)',
+    )
+    parser.add_argument(
+        '--destroy-graph', action='store_true', help='Destroy all Graphiti graphs'
+    )
    parser.add_argument(
        '--use-custom-entities',
        action='store_true',
@@ -11,3 +11,8 @@ dependencies = [
    "azure-identity>=1.21.0",
    "graphiti-core",
]
+
+[tool.ruff.format]
+quote-style = "single"
+indent-style = "space"
+docstring-code-format = true
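
With this configuration in place, formatting can be applied through uv (assuming ruff is available in the project environment):

```bash
uv run ruff format .
```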