Allow usage of different openai compatible clients in embedder and encoder (#279)
* allow usage of different openai compatible clients in embedder and encoder
* azure openai
* cross encoder example

---------

Co-authored-by: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
parent 55e308fb9f
commit 5cad6c8504
3 changed files with 42 additions and 8 deletions
README.md
@@ -205,6 +205,7 @@ from openai import AsyncAzureOpenAI
 from graphiti_core import Graphiti
 from graphiti_core.llm_client import OpenAIClient
 from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
+from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient

 # Azure OpenAI configuration
 api_key = "<your-api-key>"
@@ -231,6 +231,10 @@ graphiti = Graphiti(
             embedding_model="text-embedding-3-small"  # Use your Azure deployed embedding model name
         ),
         client=azure_openai_client
+    ),
+    # Optional: Configure the OpenAI cross encoder with Azure OpenAI
+    cross_encoder=OpenAIRerankerClient(
+        client=azure_openai_client
     )
 )
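For reference, the azure_openai_client shared by the embedder and the new cross encoder is constructed earlier in this README section; a minimal sketch, assuming the standard AsyncAzureOpenAI constructor arguments, with placeholder values:

from openai import AsyncAzureOpenAI

api_key = "<your-api-key>"
api_version = "<your-api-version>"
azure_endpoint = "<your-azure-endpoint>"

# One shared async client can back the LLM client, the embedder, and the
# cross encoder, as the diff above wires up.
azure_openai_client = AsyncAzureOpenAI(
    api_key=api_key,
    api_version=api_version,
    azure_endpoint=azure_endpoint,
)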
graphiti_core/cross_encoder/__init__.py
@@ -0,0 +1,21 @@
+"""
+Copyright 2025, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from .bge_reranker_client import BGERerankerClient
+from .client import CrossEncoderClient
+from .openai_reranker_client import OpenAIRerankerClient
+
+__all__ = ['CrossEncoderClient', 'BGERerankerClient', 'OpenAIRerankerClient']
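With this new __init__.py in place, the reranker classes become importable from the package root as well as from their submodules; a small sketch of the equivalence:

# Sketch: the package-level re-exports make both import styles resolve to
# the same class object.
from graphiti_core.cross_encoder import OpenAIRerankerClient
from graphiti_core.cross_encoder.openai_reranker_client import (
    OpenAIRerankerClient as DirectOpenAIRerankerClient,
)

assert OpenAIRerankerClient is DirectOpenAIRerankerClient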
graphiti_core/cross_encoder/openai_reranker_client.py
@@ -18,7 +18,7 @@ import logging
 from typing import Any

 import openai
-from openai import AsyncOpenAI
+from openai import AsyncAzureOpenAI, AsyncOpenAI
 from pydantic import BaseModel

 from ..helpers import semaphore_gather
@@ -36,21 +36,29 @@ class BooleanClassifier(BaseModel):


 class OpenAIRerankerClient(CrossEncoderClient):
-    def __init__(self, config: LLMConfig | None = None):
+    def __init__(
+        self,
+        config: LLMConfig | None = None,
+        client: AsyncOpenAI | AsyncAzureOpenAI | None = None,
+    ):
         """
-        Initialize the OpenAIClient with the provided configuration, cache setting, and client.
+        Initialize the OpenAIRerankerClient with the provided configuration and client.
+
+        This reranker uses the OpenAI API to run a simple boolean classifier prompt concurrently
+        for each passage. Log-probabilities are used to rank the passages.

         Args:
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
-            cache (bool): Whether to use caching for responses. Defaults to False.
-            client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
+            client (AsyncOpenAI | AsyncAzureOpenAI | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
         """
         if config is None:
             config = LLMConfig()

         self.config = config
-        self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
+        if client is None:
+            self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
+        else:
+            self.client = client

     async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
         openai_messages_list: Any = [
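The refactored constructor enables straightforward dependency injection. A minimal sketch of both construction paths, assuming LLMConfig is exported from graphiti_core.llm_client (as OpenAIClient is in the README) and using placeholder endpoints:

from openai import AsyncAzureOpenAI, AsyncOpenAI

from graphiti_core.cross_encoder import OpenAIRerankerClient
from graphiti_core.llm_client import LLMConfig  # assumed export

# Path 1: no client passed, so the reranker builds its own AsyncOpenAI
# client from the config's api_key and base_url (the `client is None` branch).
reranker = OpenAIRerankerClient(config=LLMConfig(api_key="<your-api-key>"))

# Path 2: inject any preconfigured OpenAI-compatible async client, e.g. the
# Azure client from the README example above (placeholder values).
azure_openai_client = AsyncAzureOpenAI(
    api_key="<your-api-key>",
    api_version="<your-api-version>",
    azure_endpoint="<your-azure-endpoint>",
)
azure_reranker = OpenAIRerankerClient(client=azure_openai_client)

# The same path covers self-hosted OpenAI-compatible servers (hypothetical URL).
local_reranker = OpenAIRerankerClient(
    client=AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
)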
@@ -62,7 +70,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
             Message(
                 role='user',
                 content=f"""
                Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
                <PASSAGE>
                {passage}
                </PASSAGE>
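The remainder of the rank() hunk did not load above, but the updated docstring describes the mechanism: one boolean classifier call per passage, with log-probabilities used for ranking. A hedged illustration of that scoring idea, not the verbatim graphiti implementation:

import math

def score_from_top_logprob(token: str, logprob: float) -> float:
    """Map a classifier's top token and its logprob to a relevance score.

    Illustrative only: if the model answered "True", relevance is its
    probability; if it answered "False", relevance is the complement.
    """
    p = math.exp(logprob)  # convert log-probability back to a probability
    return p if token.strip().lower().startswith('t') else 1.0 - p

# Passages would then be sorted by score, highest first, producing the
# list[tuple[str, float]] that rank() returns:
# ranked = sorted(zip(passages, scores), key=lambda item: item[1], reverse=True)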