graphiti/mcp_server/embedder_config.py
Daniel Chalef 452a45cb4e wip
2025-08-30 08:50:48 -07:00

124 lines
4.9 KiB
Python

"""Embedder configuration for Graphiti MCP Server."""
import logging
import os
from openai import AsyncAzureOpenAI
from pydantic import BaseModel
from utils import create_azure_credential_token_provider
from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
from graphiti_core.embedder.client import EmbedderClient
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
# Embedding model used when EMBEDDER_MODEL_NAME is unset or blank.
DEFAULT_EMBEDDER_MODEL = 'text-embedding-3-small'

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
class GraphitiEmbedderConfig(BaseModel):
    """Configuration for the embedder client.

    Centralizes all embedding-related configuration parameters. When
    ``azure_openai_endpoint`` is set, ``create_client`` builds an Azure OpenAI
    embedder; otherwise it builds a plain OpenAI embedder.
    """

    # Name of the embedding model handed to the embedder client.
    model: str = DEFAULT_EMBEDDER_MODEL
    # API key; not passed to the client when managed identity is enabled.
    api_key: str | None = None
    # Azure OpenAI endpoint; its presence selects the Azure code path.
    azure_openai_endpoint: str | None = None
    # Azure deployment name forwarded to AsyncAzureOpenAI.
    azure_openai_deployment_name: str | None = None
    # Azure API version forwarded to AsyncAzureOpenAI.
    azure_openai_api_version: str | None = None
    # When True, authenticate with a token provider instead of an API key.
    azure_openai_use_managed_identity: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiEmbedderConfig':
        """Create embedder configuration from environment variables.

        Environment variables read:
            EMBEDDER_MODEL_NAME: embedding model; falls back to
                ``DEFAULT_EMBEDDER_MODEL`` when unset or blank.
            AZURE_OPENAI_EMBEDDING_ENDPOINT: when set to a non-blank value,
                an Azure OpenAI configuration is produced.
            AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME: required on the Azure path.
            AZURE_OPENAI_EMBEDDING_API_VERSION: optional, Azure path only.
            AZURE_OPENAI_USE_MANAGED_IDENTITY: ``'true'`` (case-insensitive)
                enables managed identity; anything else disables it.
            AZURE_OPENAI_EMBEDDING_API_KEY: Azure API key; ``OPENAI_API_KEY``
                is used as a fallback when it is unset or empty.
            OPENAI_API_KEY: API key on the non-Azure path.

        Returns:
            A populated ``GraphitiEmbedderConfig``.

        Raises:
            ValueError: if the Azure endpoint is configured but the deployment
                name is unset or blank.
        """
        # Get model from environment, or use default if not set or empty
        model_env = os.environ.get('EMBEDDER_MODEL_NAME', '')
        model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL

        # A blank endpoint is treated like an unset one, mirroring how a blank
        # model name is handled above; an empty endpoint can never address a
        # real Azure resource.
        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_EMBEDDING_ENDPOINT')
        if azure_openai_endpoint is None or not azure_openai_endpoint.strip():
            # Plain OpenAI API setup
            return cls(
                model=model,
                api_key=os.environ.get('OPENAI_API_KEY'),
            )

        # Setup for Azure OpenAI API
        azure_openai_deployment_name = os.environ.get('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME')
        # Reject a missing *or* blank deployment name up front rather than
        # letting an empty value reach the Azure client.
        if azure_openai_deployment_name is None or not azure_openai_deployment_name.strip():
            logger.error('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set')
            raise ValueError('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set')

        azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION')
        azure_openai_use_managed_identity = (
            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
        )

        if azure_openai_use_managed_identity:
            # Managed identity: no API key is needed
            api_key = None
        else:
            # Prefer the embedding-specific Azure key, fall back to the generic one
            api_key = os.environ.get('AZURE_OPENAI_EMBEDDING_API_KEY') or os.environ.get(
                'OPENAI_API_KEY'
            )

        return cls(
            azure_openai_use_managed_identity=azure_openai_use_managed_identity,
            azure_openai_endpoint=azure_openai_endpoint,
            api_key=api_key,
            azure_openai_api_version=azure_openai_api_version,
            azure_openai_deployment_name=azure_openai_deployment_name,
            model=model,
        )

    def create_client(self) -> EmbedderClient | None:
        """Create an embedder client based on this configuration.

        Returns:
            An ``AzureOpenAIEmbedderClient`` when ``azure_openai_endpoint`` is
            set, an ``OpenAIEmbedder`` otherwise, or ``None`` when the selected
            path needs an API key and none is configured.
        """
        if self.azure_openai_endpoint is None:
            # OpenAI API setup
            if not self.api_key:
                return None
            embedder_config = OpenAIEmbedderConfig(api_key=self.api_key, embedding_model=self.model)
            return OpenAIEmbedder(config=embedder_config)

        # Azure OpenAI API setup: the two auth modes differ only in the
        # credential argument handed to the Azure client.
        if self.azure_openai_use_managed_identity:
            # Use managed identity for authentication
            token_provider = create_azure_credential_token_provider()
            azure_client = AsyncAzureOpenAI(
                azure_endpoint=self.azure_openai_endpoint,
                azure_deployment=self.azure_openai_deployment_name,
                api_version=self.azure_openai_api_version,
                azure_ad_token_provider=token_provider,
            )
        elif self.api_key:
            # Use API key for authentication
            azure_client = AsyncAzureOpenAI(
                azure_endpoint=self.azure_openai_endpoint,
                azure_deployment=self.azure_openai_deployment_name,
                api_version=self.azure_openai_api_version,
                api_key=self.api_key,
            )
        else:
            # Name both variables from_env consults so the log points at the
            # right setting.
            logger.error(
                'AZURE_OPENAI_EMBEDDING_API_KEY or OPENAI_API_KEY must be set '
                'when using Azure OpenAI API'
            )
            return None

        return AzureOpenAIEmbedderClient(azure_client=azure_client, model=self.model)