From bb476ff0bdba65a72e4774e09dbd94ed5322f8a1 Mon Sep 17 00:00:00 2001
From: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
Date: Sat, 28 Jun 2025 11:37:34 -0700
Subject: [PATCH] improve client; add reranker

---
 README.md                                    |  14 +-
 graphiti_core/cross_encoder/__init__.py      |   3 +-
 .../cross_encoder/gemini_reranker_client.py  | 216 ++++++++++++++++++
 graphiti_core/llm_client/gemini_client.py    | 138 +++++++++--
 uv.lock                                      |  32 ++-
 5 files changed, 381 insertions(+), 22 deletions(-)
 create mode 100644 graphiti_core/cross_encoder/gemini_reranker_client.py

diff --git a/README.md b/README.md
index 27dbaafc..e5deba7e 100644
--- a/README.md
+++ b/README.md
@@ -242,7 +242,6 @@ graphiti = Graphiti(
         ),
         client=azure_openai_client
     ),
-    # Optional: Configure the OpenAI cross encoder with Azure OpenAI
     cross_encoder=OpenAIRerankerClient(
         llm_config=azure_llm_config,
         client=azure_openai_client
@@ -256,7 +255,7 @@ Make sure to replace the placeholder values with your actual Azure OpenAI creden
 
 ## Using Graphiti with Google Gemini
 
-Graphiti supports Google's Gemini models for both LLM inference and embeddings. To use Gemini, you'll need to configure both the LLM client and embedder with your Google API key.
+Graphiti supports Google's Gemini models for LLM inference, embeddings, and cross-encoding/reranking. To use Gemini, you'll need to configure the LLM client, embedder, and cross-encoder with your Google API key.
 
 Install Graphiti:
 
@@ -272,6 +271,7 @@ pip install "graphiti-core[google-genai]"
 from graphiti_core import Graphiti
 from graphiti_core.llm_client.gemini_client import GeminiClient, LLMConfig
 from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig
+from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
 
 # Google API key configuration
 api_key = ""
@@ -292,12 +292,20 @@ graphiti = Graphiti(
             api_key=api_key,
             embedding_model="embedding-001"
         )
+    ),
+    cross_encoder=GeminiRerankerClient(
+        config=LLMConfig(
+            api_key=api_key,
+            model="gemini-2.5-flash-lite-preview-06-17"
+        )
     )
 )
 
-# Now you can use Graphiti with Google Gemini
+# Now you can use Graphiti with Google Gemini for all components
 ```
 
+The Gemini reranker uses the `gemini-2.5-flash-lite-preview-06-17` model by default, which is optimized for cost-effective, low-latency classification tasks. It uses the same boolean classification approach as the OpenAI reranker, leveraging Gemini's log probabilities feature to rank passage relevance.
+
 ## Using Graphiti with Ollama (Local LLM)
 
 Graphiti supports Ollama for running local LLMs and embedding models via Ollama's OpenAI-compatible API. This is ideal for privacy-focused applications or when you want to avoid API costs.
diff --git a/graphiti_core/cross_encoder/__init__.py b/graphiti_core/cross_encoder/__init__.py
index 64a231cf..d4fb7281 100644
--- a/graphiti_core/cross_encoder/__init__.py
+++ b/graphiti_core/cross_encoder/__init__.py
@@ -15,6 +15,7 @@ limitations under the License.
 """
 
 from .client import CrossEncoderClient
+from .gemini_reranker_client import GeminiRerankerClient
 from .openai_reranker_client import OpenAIRerankerClient
 
-__all__ = ['CrossEncoderClient', 'OpenAIRerankerClient']
+__all__ = ['CrossEncoderClient', 'GeminiRerankerClient', 'OpenAIRerankerClient']
diff --git a/graphiti_core/cross_encoder/gemini_reranker_client.py b/graphiti_core/cross_encoder/gemini_reranker_client.py
new file mode 100644
index 00000000..34a3cbe2
--- /dev/null
+++ b/graphiti_core/cross_encoder/gemini_reranker_client.py
@@ -0,0 +1,216 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import logging
+import math
+from typing import Any
+
+from google import genai  # type: ignore
+from google.genai import types  # type: ignore
+
+from ..helpers import semaphore_gather
+from ..llm_client import LLMConfig, RateLimitError
+from .client import CrossEncoderClient
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_MODEL = 'gemini-2.5-flash-lite-preview-06-17'
+
+
+class GeminiRerankerClient(CrossEncoderClient):
+    def __init__(
+        self,
+        config: LLMConfig | None = None,
+        client: genai.Client | None = None,
+    ):
+        """
+        Initialize the GeminiRerankerClient with the provided configuration and client.
+
+        This reranker uses the Gemini API to run a simple boolean classifier prompt concurrently
+        for each passage. Log probabilities are used to rank the passages, equivalent to the OpenAI approach.
+
+        Args:
+            config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
+            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
+        """
+        if config is None:
+            config = LLMConfig()
+
+        self.config = config
+        if client is None:
+            self.client = genai.Client(api_key=config.api_key)
+        else:
+            self.client = client
+
+    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
+        """
+        Rank passages based on their relevance to the query using Gemini.
+
+        Uses log probabilities from boolean classification responses, equivalent to the OpenAI approach.
+        The model responds with "True" or "False" and we use the log probabilities of these tokens.
+        """
+        gemini_messages_list: Any = [
+            [
+                types.Content(
+                    role='system',
+                    parts=[
+                        types.Part.from_text(
+                            text='You are an expert tasked with determining whether the passage is relevant to the query'
+                        )
+                    ],
+                ),
+                types.Content(
+                    role='user',
+                    parts=[
+                        types.Part.from_text(
+                            text=f"""Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
+<PASSAGE>
+{passage}
+</PASSAGE>
+<QUERY>
+{query}
+</QUERY>
+"""
+                        )
+                    ],
+                ),
+            ]
+            for passage in passages
+        ]
+
+        try:
+            responses = await semaphore_gather(
+                *[
+                    self.client.aio.models.generate_content(
+                        model=self.config.model or DEFAULT_MODEL,
+                        contents=gemini_messages,
+                        config=types.GenerateContentConfig(
+                            temperature=0.0,
+                            max_output_tokens=1,
+                            response_logprobs=True,
+                            logprobs=5,  # Get top 5 candidate tokens for better coverage
+                        ),
+                    )
+                    for gemini_messages in gemini_messages_list
+                ]
+            )
+
+            scores: list[float] = []
+            for response in responses:
+                try:
+                    # Check if we have logprobs result in the response
+                    if (
+                        hasattr(response, 'candidates')
+                        and response.candidates
+                        and len(response.candidates) > 0
+                        and hasattr(response.candidates[0], 'logprobs_result')
+                        and response.candidates[0].logprobs_result
+                    ):
+                        logprobs_result = response.candidates[0].logprobs_result
+
+                        # Get the chosen candidates (tokens actually selected by the model)
+                        if (
+                            hasattr(logprobs_result, 'chosen_candidates')
+                            and logprobs_result.chosen_candidates
+                            and len(logprobs_result.chosen_candidates) > 0
+                        ):
+                            # Get the first token's log probability
+                            first_token = logprobs_result.chosen_candidates[0]
+
+                            if hasattr(first_token, 'log_probability') and hasattr(
+                                first_token, 'token'
+                            ):
+                                # Convert log probability to probability (similar to OpenAI approach)
+                                log_prob = first_token.log_probability
+                                probability = math.exp(log_prob)
+
+                                # Check if the token indicates relevance (starts with "True" or similar)
+                                token_text = first_token.token.strip().lower()
+                                if token_text.startswith(('true', 't')):
+                                    scores.append(probability)
+                                else:
+                                    # For "False" or other tokens, use 1 - probability
+                                    scores.append(1.0 - probability)
+                            else:
+                                # Fallback: try to get from top candidates
+                                if (
+                                    hasattr(logprobs_result, 'top_candidates')
+                                    and logprobs_result.top_candidates
+                                    and len(logprobs_result.top_candidates) > 0
+                                ):
+                                    top_step = logprobs_result.top_candidates[0]
+                                    if (
+                                        hasattr(top_step, 'candidates')
+                                        and top_step.candidates
+                                        and len(top_step.candidates) > 0
+                                    ):
+                                        # Look for "True" or "False" in top candidates
+                                        true_prob = 0.0
+                                        false_prob = 0.0
+
+                                        for candidate in top_step.candidates:
+                                            if hasattr(candidate, 'token') and hasattr(
+                                                candidate, 'log_probability'
+                                            ):
+                                                token_text = candidate.token.strip().lower()
+                                                prob = math.exp(candidate.log_probability)
+
+                                                if token_text.startswith(('true', 't')):
+                                                    true_prob = max(true_prob, prob)
+                                                elif token_text.startswith(('false', 'f')):
+                                                    false_prob = max(false_prob, prob)
+
+                                        # Use the probability of "True" as the relevance score
+                                        scores.append(true_prob)
+                                    else:
+                                        scores.append(0.0)
+                                else:
+                                    scores.append(0.0)
+                        else:
+                            scores.append(0.0)
+                    else:
+                        # Fallback: parse the response text if no logprobs available
+                        if hasattr(response, 'text') and response.text:
+                            response_text = response.text.strip().lower()
+                            if response_text.startswith(('true', 't')):
+                                scores.append(0.9)  # High confidence for "True"
+                            elif response_text.startswith(('false', 'f')):
+                                scores.append(0.1)  # Low confidence for "False"
+                            else:
+                                scores.append(0.0)
+                        else:
+                            scores.append(0.0)
+
+                except (ValueError, AttributeError) as e:
+                    logger.warning(f'Error parsing log probabilities from Gemini response: {e}')
+                    scores.append(0.0)
+
+            results = [(passage, score) for passage, score in zip(passages, scores, strict=True)]
+            results.sort(reverse=True, key=lambda x: x[1])
+            return results
+
+        except Exception as e:
+            # Check if it's a rate limit error based on Gemini API error codes
+            error_message = str(e).lower()
+            if (
+                'rate limit' in error_message
+                or 'quota' in error_message
+                or 'resource_exhausted' in error_message
+                or '429' in str(e)
+            ):
+                raise RateLimitError from e
+
+            logger.error(f'Error in generating LLM response: {e}')
+            raise
diff --git a/graphiti_core/llm_client/gemini_client.py b/graphiti_core/llm_client/gemini_client.py
index 2acd3866..bcdb2d8b 100644
--- a/graphiti_core/llm_client/gemini_client.py
+++ b/graphiti_core/llm_client/gemini_client.py
@@ -17,19 +17,21 @@ limitations under the License.
 import json
 import logging
 import typing
+from typing import ClassVar
 
 from google import genai  # type: ignore
 from google.genai import types  # type: ignore
 from pydantic import BaseModel
 
 from ..prompts.models import Message
-from .client import LLMClient
+from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
 from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
 from .errors import RateLimitError
 
 logger = logging.getLogger(__name__)
 
 DEFAULT_MODEL = 'gemini-2.5-flash'
+DEFAULT_SMALL_MODEL = 'models/gemini-2.5-flash-lite-preview-06-17'
 
 
 class GeminiClient(LLMClient):
@@ -52,6 +54,9 @@ class GeminiClient(LLMClient):
             Generates a response from the language model based on the provided messages.
     """
 
+    # Class-level constants
+    MAX_RETRIES: ClassVar[int] = 2
+
     def __init__(
         self,
         config: LLMConfig | None = None,
@@ -82,6 +87,49 @@ class GeminiClient(LLMClient):
         self.max_tokens = max_tokens
         self.thinking_config = thinking_config
 
+    def _check_safety_blocks(self, response) -> None:
+        """Check if response was blocked for safety reasons and raise appropriate exceptions."""
+        # Check if the response was blocked for safety reasons
+        if not (hasattr(response, 'candidates') and response.candidates):
+            return
+
+        candidate = response.candidates[0]
+        if not (hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'SAFETY'):
+            return
+
+        # Content was blocked for safety reasons - collect safety details
+        safety_info = []
+        safety_ratings = getattr(candidate, 'safety_ratings', None)
+
+        if safety_ratings:
+            for rating in safety_ratings:
+                if getattr(rating, 'blocked', False):
+                    category = getattr(rating, 'category', 'Unknown')
+                    probability = getattr(rating, 'probability', 'Unknown')
+                    safety_info.append(f'{category}: {probability}')
+
+        safety_details = (
+            ', '.join(safety_info) if safety_info else 'Content blocked for safety reasons'
+        )
+        raise Exception(f'Response blocked by Gemini safety filters: {safety_details}')
+
+    def _check_prompt_blocks(self, response) -> None:
+        """Check if prompt was blocked and raise appropriate exceptions."""
+        prompt_feedback = getattr(response, 'prompt_feedback', None)
+        if not prompt_feedback:
+            return
+
+        block_reason = getattr(prompt_feedback, 'block_reason', None)
+        if block_reason:
+            raise Exception(f'Prompt blocked by Gemini: {block_reason}')
+
+    def _get_model_for_size(self, model_size: ModelSize) -> str:
+        """Get the appropriate model name based on the requested size."""
+        if model_size == ModelSize.small:
+            return self.small_model or DEFAULT_SMALL_MODEL
+        else:
+            return self.model or DEFAULT_MODEL
+
     async def _generate_response(
         self,
         messages: list[Message],
@@ -96,14 +144,14 @@ class GeminiClient(LLMClient):
             messages (list[Message]): A list of messages to send to the language model.
             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
             max_tokens (int): The maximum number of tokens to generate in the response.
+            model_size (ModelSize): The size of the model to use (small or medium).
 
         Returns:
             dict[str, typing.Any]: The response from the language model.
 
         Raises:
             RateLimitError: If the API rate limit is exceeded.
-            RefusalError: If the content is blocked by the model.
-            Exception: If there is an error generating the response.
+            Exception: If there is an error generating the response or content is blocked.
         """
         try:
             gemini_messages: list[types.Content] = []
@@ -132,6 +180,9 @@ class GeminiClient(LLMClient):
                     types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)])
                 )
 
+            # Get the appropriate model for the requested size
+            model = self._get_model_for_size(model_size)
+
             # Create generation config
             generation_config = types.GenerateContentConfig(
                 temperature=self.temperature,
@@ -144,11 +195,15 @@ class GeminiClient(LLMClient):
 
             # Generate content using the simple string approach
             response = await self.client.aio.models.generate_content(
-                model=self.model or DEFAULT_MODEL,
+                model=model,
                 contents=gemini_messages,
                 config=generation_config,
             )
 
+            # Check for safety and prompt blocks
+            self._check_safety_blocks(response)
+            self._check_prompt_blocks(response)
+
             # If this was a structured output request, parse the response into the Pydantic model
             if response_model is not None:
                 try:
@@ -166,9 +221,16 @@ class GeminiClient(LLMClient):
             return {'content': response.text}
 
         except Exception as e:
-            # Check if it's a rate limit error
-            if 'rate limit' in str(e).lower() or 'quota' in str(e).lower():
+            # Check if it's a rate limit error based on Gemini API error codes
+            error_message = str(e).lower()
+            if (
+                'rate limit' in error_message
+                or 'quota' in error_message
+                or 'resource_exhausted' in error_message
+                or '429' in str(e)
+            ):
                 raise RateLimitError from e
+
             logger.error(f'Error in generating LLM response: {e}')
             raise
 
@@ -180,13 +242,14 @@ class GeminiClient(LLMClient):
         model_size: ModelSize = ModelSize.medium,
     ) -> dict[str, typing.Any]:
         """
-        Generate a response from the Gemini language model.
-        This method overrides the parent class method to provide a direct implementation.
+        Generate a response from the Gemini language model with retry logic and error handling.
+        This method overrides the parent class method to provide a direct implementation with advanced retry logic.
 
         Args:
             messages (list[Message]): A list of messages to send to the language model.
             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
-            max_tokens (int): The maximum number of tokens to generate in the response.
+            max_tokens (int | None): The maximum number of tokens to generate in the response.
+            model_size (ModelSize): The size of the model to use (small or medium).
 
         Returns:
             dict[str, typing.Any]: The response from the language model.
@@ -194,10 +257,53 @@ class GeminiClient(LLMClient):
         if max_tokens is None:
             max_tokens = self.max_tokens
 
-        # Call the internal _generate_response method
-        return await self._generate_response(
-            messages=messages,
-            response_model=response_model,
-            max_tokens=max_tokens,
-            model_size=model_size,
-        )
+        retry_count = 0
+        last_error = None
+
+        # Add multilingual extraction instructions
+        messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
+
+        while retry_count <= self.MAX_RETRIES:
+            try:
+                response = await self._generate_response(
+                    messages=messages,
+                    response_model=response_model,
+                    max_tokens=max_tokens,
+                    model_size=model_size,
+                )
+                return response
+            except RateLimitError:
+                # Rate limit errors should not trigger retries (fail fast)
+                raise
+            except Exception as e:
+                last_error = e
+
+                # Check if this is a safety block - these typically shouldn't be retried
+                if 'safety' in str(e).lower() or 'blocked' in str(e).lower():
+                    logger.warning(f'Content blocked by safety filters: {e}')
+                    raise
+
+                # Don't retry if we've hit the max retries
+                if retry_count >= self.MAX_RETRIES:
+                    logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
+                    raise
+
+                retry_count += 1
+
+                # Construct a detailed error message for the LLM
+                error_context = (
+                    f'The previous response attempt was invalid. '
+                    f'Error type: {e.__class__.__name__}. '
+                    f'Error details: {str(e)}. '
+                    f'Please try again with a valid response, ensuring the output matches '
+                    f'the expected format and constraints.'
+                )
+
+                error_message = Message(role='user', content=error_context)
+                messages.append(error_message)
+                logger.warning(
+                    f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
+                )
+
+        # If we somehow get here, raise the last error
+        raise last_error or Exception('Max retries exceeded with no specific error')
diff --git a/uv.lock b/uv.lock
index 73c3fceb..2cd74414 100644
--- a/uv.lock
+++ b/uv.lock
@@ -266,6 +266,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537 },
 ]
 
+[[package]]
+name = "backoff"
+version = "2.2.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/47/d7/5bbeb12c44d7c4f2fb5b56abce497eb5ed9f34d85701de869acedd602619/backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba", size = 17001 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/df/73/b6e24bd22e6720ca8ee9a85a0c4a2971af8497d8f3193fa05390cbd46e09/backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8", size = 15148 },
+]
+
 [[package]]
 name = "beautifulsoup4"
 version = "4.13.4"
@@ -531,7 +540,7 @@ name = "exceptiongroup"
 version = "1.3.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "typing-extensions", marker = "python_full_version < '3.13'" },
+    { name = "typing-extensions", marker = "python_full_version < '3.12'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749 }
 wheels = [
@@ -729,13 +738,14 @@ wheels = [
 
 [[package]]
 name = "graphiti-core"
-version = "0.13.2"
+version = "0.14.0"
 source = { editable = "." }
 dependencies = [
     { name = "diskcache" },
     { name = "neo4j" },
     { name = "numpy" },
     { name = "openai" },
+    { name = "posthog" },
     { name = "pydantic" },
     { name = "python-dotenv" },
     { name = "tenacity" },
@@ -796,6 +806,7 @@ requires-dist = [
     { name = "neo4j", specifier = ">=5.26.0" },
     { name = "numpy", specifier = ">=1.0.0" },
     { name = "openai", specifier = ">=1.91.0" },
+    { name = "posthog", specifier = ">=3.0.0" },
     { name = "pydantic", specifier = ">=2.11.5" },
     { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.3.3" },
     { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24.0" },
@@ -2243,6 +2254,23 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538 },
 ]
 
+[[package]]
+name = "posthog"
+version = "6.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "backoff" },
+    { name = "distro" },
+    { name = "python-dateutil" },
+    { name = "requests" },
+    { name = "six" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/f3/c3/c83883af8cc5e3b45d1bee85edce546a4db369fb8dc8eb6339fad764178b/posthog-6.0.0.tar.gz", hash = "sha256:b7bfa0da03bd9240891885d3e44b747e62192ac9ee6da280f45320f4ad3479e0", size = 88066 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/ab/ec/7a44533c9fe7046ffcfe48ca0e7472ada2633854f474be633f4afed7b044/posthog-6.0.0-py3-none-any.whl", hash = "sha256:01f5d11046a6267d4384f552e819f0f4a7dc885eb19f606c36d44d662df9ff89", size = 104945 },
+]
+
 [[package]]
 name = "prometheus-client"
 version = "0.22.1"