diff --git a/README.md b/README.md index 005e0b0f..6f813118 100644 --- a/README.md +++ b/README.md @@ -205,6 +205,7 @@ from openai import AsyncAzureOpenAI from graphiti_core import Graphiti from graphiti_core.llm_client import OpenAIClient from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig +from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient # Azure OpenAI configuration api_key = "" @@ -231,6 +232,10 @@ graphiti = Graphiti( embedding_model="text-embedding-3-small" # Use your Azure deployed embedding model name ), client=azure_openai_client + ), + # Optional: Configure the OpenAI cross encoder with Azure OpenAI + cross_encoder=OpenAIRerankerClient( + client=azure_openai_client ) ) diff --git a/graphiti_core/cross_encoder/__init__.py b/graphiti_core/cross_encoder/__init__.py index e69de29b..2399f773 100644 --- a/graphiti_core/cross_encoder/__init__.py +++ b/graphiti_core/cross_encoder/__init__.py @@ -0,0 +1,21 @@ +""" +Copyright 2025, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from .bge_reranker_client import BGERerankerClient +from .client import CrossEncoderClient +from .openai_reranker_client import OpenAIRerankerClient + +__all__ = ['CrossEncoderClient', 'BGERerankerClient', 'OpenAIRerankerClient'] diff --git a/graphiti_core/cross_encoder/openai_reranker_client.py b/graphiti_core/cross_encoder/openai_reranker_client.py index e41cb61e..0694f043 100644 --- a/graphiti_core/cross_encoder/openai_reranker_client.py +++ b/graphiti_core/cross_encoder/openai_reranker_client.py @@ -18,7 +18,7 @@ import logging from typing import Any import openai -from openai import AsyncOpenAI +from openai import AsyncAzureOpenAI, AsyncOpenAI from pydantic import BaseModel from ..helpers import semaphore_gather @@ -36,21 +36,29 @@ class BooleanClassifier(BaseModel): class OpenAIRerankerClient(CrossEncoderClient): - def __init__(self, config: LLMConfig | None = None): + def __init__( + self, + config: LLMConfig | None = None, + client: AsyncOpenAI | AsyncAzureOpenAI | None = None, + ): """ - Initialize the OpenAIClient with the provided configuration, cache setting, and client. + Initialize the OpenAIRerankerClient with the provided configuration and client. + + This reranker uses the OpenAI API to run a simple boolean classifier prompt concurrently + for each passage. Log-probabilities are used to rank the passages. Args: config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens. - cache (bool): Whether to use caching for responses. Defaults to False. - client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created. - + client (AsyncOpenAI | AsyncAzureOpenAI | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created. """ if config is None: config = LLMConfig() self.config = config - self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) + if client is None: + self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) + else: + self.client = client async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]: openai_messages_list: Any = [ @@ -62,7 +70,7 @@ class OpenAIRerankerClient(CrossEncoderClient): Message( role='user', content=f""" - Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise. + Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise. {passage}