From 3d7e1a4b794f8fb676c78d8aa9679d4156a3f080 Mon Sep 17 00:00:00 2001
From: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
Date: Fri, 13 Jun 2025 08:55:08 -0700
Subject: [PATCH] feat: add Azure client wrappers for embedding and LLM,
 integrate into server (#581)

* create wrappers for azure clients

* remove unused crossencoder client

* format

* chore: update graphiti-core to 0.12.0rc5 and pydantic to 2.11.5

* Update graphiti_core/llm_client/azure_openai_client.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
---
 .../cross_encoder/openai_reranker_client.py |  2 +-
 graphiti_core/embedder/azure_openai.py      | 64 +++++++++++++
 .../llm_client/azure_openai_client.py       | 73 +++++++++++++++
 mcp_server/graphiti_mcp_server.py           | 89 ++++++++++---------
 mcp_server/pyproject.toml                   |  2 +-
 mcp_server/uv.lock                          | 28 ++++--
 6 files changed, 207 insertions(+), 51 deletions(-)
 create mode 100644 graphiti_core/embedder/azure_openai.py
 create mode 100644 graphiti_core/llm_client/azure_openai_client.py

diff --git a/graphiti_core/cross_encoder/openai_reranker_client.py b/graphiti_core/cross_encoder/openai_reranker_client.py
index 8cc03f5c..3901be9e 100644
--- a/graphiti_core/cross_encoder/openai_reranker_client.py
+++ b/graphiti_core/cross_encoder/openai_reranker_client.py
@@ -106,7 +106,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
                 if len(top_logprobs) == 0:
                     continue
                 norm_logprobs = np.exp(top_logprobs[0].logprob)
-                if top_logprobs[0].token.strip().split(" ")[0].lower() == "true":
+                if top_logprobs[0].token.strip().split(' ')[0].lower() == 'true':
                     scores.append(norm_logprobs)
                 else:
                     scores.append(1 - norm_logprobs)
+""" + +import logging +from typing import Any + +from openai import AsyncAzureOpenAI + +from .client import EmbedderClient + +logger = logging.getLogger(__name__) + + +class AzureOpenAIEmbedderClient(EmbedderClient): + """Wrapper class for AsyncAzureOpenAI that implements the EmbedderClient interface.""" + + def __init__(self, azure_client: AsyncAzureOpenAI, model: str = 'text-embedding-3-small'): + self.azure_client = azure_client + self.model = model + + async def create(self, input_data: str | list[str] | Any) -> list[float]: + """Create embeddings using Azure OpenAI client.""" + try: + # Handle different input types + if isinstance(input_data, str): + text_input = [input_data] + elif isinstance(input_data, list) and all(isinstance(item, str) for item in input_data): + text_input = input_data + else: + # Convert to string list for other types + text_input = [str(input_data)] + + response = await self.azure_client.embeddings.create(model=self.model, input=text_input) + + # Return the first embedding as a list of floats + return response.data[0].embedding + except Exception as e: + logger.error(f'Error in Azure OpenAI embedding: {e}') + raise + + async def create_batch(self, input_data_list: list[str]) -> list[list[float]]: + """Create batch embeddings using Azure OpenAI client.""" + try: + response = await self.azure_client.embeddings.create( + model=self.model, input=input_data_list + ) + + return [embedding.embedding for embedding in response.data] + except Exception as e: + logger.error(f'Error in Azure OpenAI batch embedding: {e}') + raise diff --git a/graphiti_core/llm_client/azure_openai_client.py b/graphiti_core/llm_client/azure_openai_client.py new file mode 100644 index 00000000..60787145 --- /dev/null +++ b/graphiti_core/llm_client/azure_openai_client.py @@ -0,0 +1,73 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import json +import logging +from typing import Any + +from openai import AsyncAzureOpenAI +from openai.types.chat import ChatCompletionMessageParam +from pydantic import BaseModel + +from ..prompts.models import Message +from .client import LLMClient +from .config import LLMConfig, ModelSize + +logger = logging.getLogger(__name__) + + +class AzureOpenAILLMClient(LLMClient): + """Wrapper class for AsyncAzureOpenAI that implements the LLMClient interface.""" + + def __init__(self, azure_client: AsyncAzureOpenAI, config: LLMConfig | None = None): + super().__init__(config, cache=False) + self.azure_client = azure_client + + async def _generate_response( + self, + messages: list[Message], + response_model: type[BaseModel] | None = None, + max_tokens: int = 1024, + model_size: ModelSize = ModelSize.medium, + ) -> dict[str, Any]: + """Generate response using Azure OpenAI client.""" + # Convert messages to OpenAI format + openai_messages: list[ChatCompletionMessageParam] = [] + for message in messages: + message.content = self._clean_input(message.content) + if message.role == 'user': + openai_messages.append({'role': 'user', 'content': message.content}) + elif message.role == 'system': + openai_messages.append({'role': 'system', 'content': message.content}) + + # Ensure model is a string + model_name = self.model if self.model else 'gpt-4o-mini' + + try: + response = await self.azure_client.chat.completions.create( + model=model_name, + messages=openai_messages, + temperature=float(self.temperature) if self.temperature is not None else 0.7, + max_tokens=max_tokens, + response_format={'type': 'json_object'}, + ) + result = response.choices[0].message.content or '{}' + + # Parse JSON response + return json.loads(result) + except Exception as e: + logger.error(f'Error in Azure OpenAI LLM response: {e}') + raise diff --git a/mcp_server/graphiti_mcp_server.py b/mcp_server/graphiti_mcp_server.py index 9b8330b4..d60e2014 100644 --- a/mcp_server/graphiti_mcp_server.py +++ b/mcp_server/graphiti_mcp_server.py @@ -19,12 +19,12 @@ from openai import AsyncAzureOpenAI from pydantic import BaseModel, Field from graphiti_core import Graphiti -from graphiti_core.cross_encoder.client import CrossEncoderClient -from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient from graphiti_core.edges import EntityEdge +from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient from graphiti_core.embedder.client import EmbedderClient from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig from graphiti_core.llm_client import LLMClient +from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient from graphiti_core.llm_client.config import LLMConfig from graphiti_core.llm_client.openai_client import OpenAIClient from graphiti_core.nodes import EpisodeType, EpisodicNode @@ -37,6 +37,7 @@ from graphiti_core.utils.maintenance.graph_data_operations import clear_data load_dotenv() + DEFAULT_LLM_MODEL = 'gpt-4.1-mini' SMALL_LLM_MODEL = 'gpt-4.1-nano' DEFAULT_EMBEDDER_MODEL = 'text-embedding-3-small' @@ -282,11 +283,11 @@ class GraphitiLLMConfig(BaseModel): return config - def create_client(self) -> LLMClient | None: + def create_client(self) -> LLMClient: """Create an LLM client based on this configuration. 
diff --git a/mcp_server/graphiti_mcp_server.py b/mcp_server/graphiti_mcp_server.py
index 9b8330b4..d60e2014 100644
--- a/mcp_server/graphiti_mcp_server.py
+++ b/mcp_server/graphiti_mcp_server.py
@@ -19,12 +19,12 @@ from openai import AsyncAzureOpenAI
 from pydantic import BaseModel, Field
 
 from graphiti_core import Graphiti
-from graphiti_core.cross_encoder.client import CrossEncoderClient
-from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
 from graphiti_core.edges import EntityEdge
+from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
 from graphiti_core.embedder.client import EmbedderClient
 from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
 from graphiti_core.llm_client import LLMClient
+from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
 from graphiti_core.llm_client.config import LLMConfig
 from graphiti_core.llm_client.openai_client import OpenAIClient
 from graphiti_core.nodes import EpisodeType, EpisodicNode
@@ -37,6 +37,7 @@ from graphiti_core.utils.maintenance.graph_data_operations import clear_data
 
 load_dotenv()
 
+
 DEFAULT_LLM_MODEL = 'gpt-4.1-mini'
 SMALL_LLM_MODEL = 'gpt-4.1-nano'
 DEFAULT_EMBEDDER_MODEL = 'text-embedding-3-small'
@@ -282,11 +283,11 @@ class GraphitiLLMConfig(BaseModel):
 
         return config
 
-    def create_client(self) -> LLMClient | None:
+    def create_client(self) -> LLMClient:
         """Create an LLM client based on this configuration.
 
         Returns:
-            LLMClient instance if able, None otherwise
+            LLMClient instance
         """
 
         if self.azure_openai_endpoint is not None:
@@ -294,26 +295,41 @@
             if self.azure_openai_use_managed_identity:
                 # Use managed identity for authentication
                 token_provider = create_azure_credential_token_provider()
-                return AsyncAzureOpenAI(
-                    azure_endpoint=self.azure_openai_endpoint,
-                    azure_deployment=self.azure_openai_deployment_name,
-                    api_version=self.azure_openai_api_version,
-                    azure_ad_token_provider=token_provider,
+                return AzureOpenAILLMClient(
+                    azure_client=AsyncAzureOpenAI(
+                        azure_endpoint=self.azure_openai_endpoint,
+                        azure_deployment=self.azure_openai_deployment_name,
+                        api_version=self.azure_openai_api_version,
+                        azure_ad_token_provider=token_provider,
+                    ),
+                    config=LLMConfig(
+                        api_key=self.api_key,
+                        model=self.model,
+                        small_model=self.small_model,
+                        temperature=self.temperature,
+                    ),
                 )
             elif self.api_key:
                 # Use API key for authentication
-                return AsyncAzureOpenAI(
-                    azure_endpoint=self.azure_openai_endpoint,
-                    azure_deployment=self.azure_openai_deployment_name,
-                    api_version=self.azure_openai_api_version,
-                    api_key=self.api_key,
+                return AzureOpenAILLMClient(
+                    azure_client=AsyncAzureOpenAI(
+                        azure_endpoint=self.azure_openai_endpoint,
+                        azure_deployment=self.azure_openai_deployment_name,
+                        api_version=self.azure_openai_api_version,
+                        api_key=self.api_key,
+                    ),
+                    config=LLMConfig(
+                        api_key=self.api_key,
+                        model=self.model,
+                        small_model=self.small_model,
+                        temperature=self.temperature,
+                    ),
                 )
             else:
-                logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API')
-                return None
+                raise ValueError('OPENAI_API_KEY must be set when using Azure OpenAI API')
 
         if not self.api_key:
-            return None
+            raise ValueError('OPENAI_API_KEY must be set when using OpenAI API')
 
         llm_client_config = LLMConfig(
             api_key=self.api_key, model=self.model, small_model=self.small_model
@@ -324,17 +340,6 @@
 
         return OpenAIClient(config=llm_client_config)
 
-    def create_cross_encoder_client(self) -> CrossEncoderClient | None:
-        """Create a cross-encoder client based on this configuration."""
-        if self.azure_openai_endpoint is not None:
-            client = self.create_client()
-            return OpenAIRerankerClient(client=client)
-        else:
-            llm_client_config = LLMConfig(
-                api_key=self.api_key, model=self.model, small_model=self.small_model
-            )
-            return OpenAIRerankerClient(config=llm_client_config)
-
 
 class GraphitiEmbedderConfig(BaseModel):
     """Configuration for the embedder client.
@@ -404,19 +409,25 @@
             if self.azure_openai_use_managed_identity:
                 # Use managed identity for authentication
                 token_provider = create_azure_credential_token_provider()
-                return AsyncAzureOpenAI(
-                    azure_endpoint=self.azure_openai_endpoint,
-                    azure_deployment=self.azure_openai_deployment_name,
-                    api_version=self.azure_openai_api_version,
-                    azure_ad_token_provider=token_provider,
+                return AzureOpenAIEmbedderClient(
+                    azure_client=AsyncAzureOpenAI(
+                        azure_endpoint=self.azure_openai_endpoint,
+                        azure_deployment=self.azure_openai_deployment_name,
+                        api_version=self.azure_openai_api_version,
+                        azure_ad_token_provider=token_provider,
+                    ),
+                    model=self.model,
                 )
             elif self.api_key:
                 # Use API key for authentication
-                return AsyncAzureOpenAI(
-                    azure_endpoint=self.azure_openai_endpoint,
-                    azure_deployment=self.azure_openai_deployment_name,
-                    api_version=self.azure_openai_api_version,
-                    api_key=self.api_key,
+                return AzureOpenAIEmbedderClient(
+                    azure_client=AsyncAzureOpenAI(
+                        azure_endpoint=self.azure_openai_endpoint,
+                        azure_deployment=self.azure_openai_deployment_name,
+                        api_version=self.azure_openai_api_version,
+                        api_key=self.api_key,
+                    ),
+                    model=self.model,
                )
             else:
                 logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API')
@@ -570,7 +581,6 @@ async def initialize_graphiti():
         raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')
 
     embedder_client = config.embedder.create_client()
-    cross_encoder_client = config.llm.create_cross_encoder_client()
 
     # Initialize Graphiti client
    graphiti_client = Graphiti(
@@ -579,7 +589,6 @@
         password=config.neo4j.password,
         llm_client=llm_client,
         embedder=embedder_client,
-        cross_encoder=cross_encoder_client,
     )
 
     # Destroy graph if requested
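create_azure_credential_token_provider(), called in both managed-identity branches above, is not defined in this patch. A typical implementation (an assumption here, not taken from this diff) pairs azure-identity's DefaultAzureCredential with the Cognitive Services token scope:

    from collections.abc import Callable

    from azure.identity import DefaultAzureCredential, get_bearer_token_provider


    def create_azure_credential_token_provider() -> Callable[[], str]:
        # DefaultAzureCredential resolves managed identity, workload identity,
        # Azure CLI login, and other credential sources in order.
        credential = DefaultAzureCredential()
        return get_bearer_token_provider(
            credential, 'https://cognitiveservices.azure.com/.default'
        )

The azure-identity>=1.21.0 dependency already declared in mcp_server/pyproject.toml covers both names used in this sketch.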
diff --git a/mcp_server/pyproject.toml b/mcp_server/pyproject.toml
index 793fe6dc..01dfa29d 100644
--- a/mcp_server/pyproject.toml
+++ b/mcp_server/pyproject.toml
@@ -7,6 +7,6 @@ requires-python = ">=3.10"
 dependencies = [
     "mcp>=1.5.0",
     "openai>=1.68.2",
-    "graphiti-core>=0.8.2",
+    "graphiti-core>=0.11.6",
     "azure-identity>=1.21.0",
 ]
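The lockfile changes below follow from this PR's dependency edits: graphiti-core now resolves to the local checkout (source = { directory = "../" }) at version 0.12.0rc5, and pydantic moves to 2.11.5. With a directory source, uv stops recording registry sdist/wheel hashes and instead captures the project's declared requirements under [package.metadata]; regenerating the file after such an edit is normally a matter of rerunning 'uv lock' inside mcp_server/.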
{ name = "openai", specifier = ">=1.53.0" }, + { name = "pydantic", specifier = ">=2.11.5" }, + { name = "python-dotenv", specifier = ">=1.0.1" }, + { name = "tenacity", specifier = ">=9.0.0" }, ] [[package]] @@ -459,7 +469,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "azure-identity", specifier = ">=1.21.0" }, - { name = "graphiti-core", specifier = ">=0.8.2" }, + { name = "graphiti-core", directory = "../" }, { name = "mcp", specifier = ">=1.5.0" }, { name = "openai", specifier = ">=1.68.2" }, ] @@ -594,7 +604,7 @@ wheels = [ [[package]] name = "pydantic" -version = "2.11.4" +version = "2.11.5" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "annotated-types" }, @@ -602,9 +612,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "typing-inspection" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/77/ab/5250d56ad03884ab5efd07f734203943c8a8ab40d551e208af81d0257bf2/pydantic-2.11.4.tar.gz", hash = "sha256:32738d19d63a226a52eed76645a98ee07c1f410ee41d93b4afbfa85ed8111c2d", size = 786540 } +sdist = { url = "https://files.pythonhosted.org/packages/f0/86/8ce9040065e8f924d642c58e4a344e33163a07f6b57f836d0d734e0ad3fb/pydantic-2.11.5.tar.gz", hash = "sha256:7f853db3d0ce78ce8bbb148c401c2cdd6431b3473c0cdff2755c7690952a7b7a", size = 787102 } wheels = [ - { url = "https://files.pythonhosted.org/packages/e7/12/46b65f3534d099349e38ef6ec98b1a5a81f42536d17e0ba382c28c67ba67/pydantic-2.11.4-py3-none-any.whl", hash = "sha256:d9615eaa9ac5a063471da949c8fc16376a84afb5024688b3ff885693506764eb", size = 443900 }, + { url = "https://files.pythonhosted.org/packages/b5/69/831ed22b38ff9b4b64b66569f0e5b7b97cf3638346eb95a2147fdb49ad5f/pydantic-2.11.5-py3-none-any.whl", hash = "sha256:f9c26ba06f9747749ca1e5c94d6a85cb84254577553c8785576fd38fa64dc0f7", size = 444229 }, ] [[package]]