Compare commits

...
Sign in to create a new pull request.

4 commits

Author SHA1 Message Date
Daniel Chalef
bb476ff0bd improve client; add reranker 2025-06-28 11:37:34 -07:00
Daniel Chalef
d57a3bb6c4 merge 2025-06-27 18:21:40 -07:00
realugbun
4cfd7baf36 allow adding thinking config to support current and future gemini models 2025-06-27 18:13:53 -07:00
realugbun
756734be01 add support for Gemini 2.5 model thinking budget 2025-06-27 18:12:30 -07:00
5 changed files with 393 additions and 28 deletions

View file

@ -242,7 +242,6 @@ graphiti = Graphiti(
),
client=azure_openai_client
),
# Optional: Configure the OpenAI cross encoder with Azure OpenAI
cross_encoder=OpenAIRerankerClient(
llm_config=azure_llm_config,
client=azure_openai_client
@ -256,7 +255,7 @@ Make sure to replace the placeholder values with your actual Azure OpenAI creden
## Using Graphiti with Google Gemini
Graphiti supports Google's Gemini models for both LLM inference and embeddings. To use Gemini, you'll need to configure both the LLM client and embedder with your Google API key.
Graphiti supports Google's Gemini models for LLM inference, embeddings, and cross-encoding/reranking. To use Gemini, you'll need to configure the LLM client, embedder, and the cross-encoder with your Google API key.
Install Graphiti:
@ -272,6 +271,7 @@ pip install "graphiti-core[google-genai]"
from graphiti_core import Graphiti
from graphiti_core.llm_client.gemini_client import GeminiClient, LLMConfig
from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig
from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
# Google API key configuration
api_key = "<your-google-api-key>"
@ -292,12 +292,20 @@ graphiti = Graphiti(
api_key=api_key,
embedding_model="embedding-001"
)
),
cross_encoder=GeminiRerankerClient(
config=LLMConfig(
api_key=api_key,
model="gemini-2.5-flash-lite-preview-06-17"
)
)
)
# Now you can use Graphiti with Google Gemini
# Now you can use Graphiti with Google Gemini for all components
```
The Gemini reranker uses the `gemini-2.5-flash-lite-preview-06-17` model by default, which is optimized for cost-effective and low-latency classification tasks. It uses the same boolean classification approach as the OpenAI reranker, leveraging Gemini's log probabilities feature to rank passage relevance.
## Using Graphiti with Ollama (Local LLM)
Graphiti supports Ollama for running local LLMs and embedding models via Ollama's OpenAI-compatible API. This is ideal for privacy-focused applications or when you want to avoid API costs.

View file

@ -15,6 +15,7 @@ limitations under the License.
"""
from .client import CrossEncoderClient
from .gemini_reranker_client import GeminiRerankerClient
from .openai_reranker_client import OpenAIRerankerClient
__all__ = ['CrossEncoderClient', 'OpenAIRerankerClient']
__all__ = ['CrossEncoderClient', 'GeminiRerankerClient', 'OpenAIRerankerClient']

View file

@ -0,0 +1,216 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import math
from typing import Any
from google import genai # type: ignore
from google.genai import types # type: ignore
from ..helpers import semaphore_gather
from ..llm_client import LLMConfig, RateLimitError
from .client import CrossEncoderClient
logger = logging.getLogger(__name__)
DEFAULT_MODEL = 'gemini-2.5-flash-lite-preview-06-17'
class GeminiRerankerClient(CrossEncoderClient):
def __init__(
self,
config: LLMConfig | None = None,
client: genai.Client | None = None,
):
"""
Initialize the GeminiRerankerClient with the provided configuration and client.
This reranker uses the Gemini API to run a simple boolean classifier prompt concurrently
for each passage. Log-probabilities are used to rank the passages, equivalent to the OpenAI approach.
Args:
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
"""
if config is None:
config = LLMConfig()
self.config = config
if client is None:
self.client = genai.Client(api_key=config.api_key)
else:
self.client = client
async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
"""
Rank passages based on their relevance to the query using Gemini.
Uses log probabilities from boolean classification responses, equivalent to the OpenAI approach.
The model responds with "True" or "False" and we use the log probabilities of these tokens.
"""
gemini_messages_list: Any = [
[
types.Content(
role='system',
parts=[
types.Part.from_text(
text='You are an expert tasked with determining whether the passage is relevant to the query'
)
],
),
types.Content(
role='user',
parts=[
types.Part.from_text(
text=f"""Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
<PASSAGE>
{passage}
</PASSAGE>
<QUERY>
{query}
</QUERY>"""
)
],
),
]
for passage in passages
]
try:
responses = await semaphore_gather(
*[
self.client.aio.models.generate_content(
model=self.config.model or DEFAULT_MODEL,
contents=gemini_messages,
config=types.GenerateContentConfig(
temperature=0.0,
max_output_tokens=1,
response_logprobs=True,
logprobs=5, # Get top 5 candidate tokens for better coverage
),
)
for gemini_messages in gemini_messages_list
]
)
scores: list[float] = []
for response in responses:
try:
# Check if we have logprobs result in the response
if (
hasattr(response, 'candidates')
and response.candidates
and len(response.candidates) > 0
and hasattr(response.candidates[0], 'logprobs_result')
and response.candidates[0].logprobs_result
):
logprobs_result = response.candidates[0].logprobs_result
# Get the chosen candidates (tokens actually selected by the model)
if (
hasattr(logprobs_result, 'chosen_candidates')
and logprobs_result.chosen_candidates
and len(logprobs_result.chosen_candidates) > 0
):
# Get the first token's log probability
first_token = logprobs_result.chosen_candidates[0]
if hasattr(first_token, 'log_probability') and hasattr(
first_token, 'token'
):
# Convert log probability to probability (similar to OpenAI approach)
log_prob = first_token.log_probability
probability = math.exp(log_prob)
# Check if the token indicates relevance (starts with "True" or similar)
token_text = first_token.token.strip().lower()
if token_text.startswith(('true', 't')):
scores.append(probability)
else:
# For "False" or other tokens, use 1 - probability
scores.append(1.0 - probability)
else:
# Fallback: try to get from top candidates
if (
hasattr(logprobs_result, 'top_candidates')
and logprobs_result.top_candidates
and len(logprobs_result.top_candidates) > 0
):
top_step = logprobs_result.top_candidates[0]
if (
hasattr(top_step, 'candidates')
and top_step.candidates
and len(top_step.candidates) > 0
):
# Look for "True" or "False" in top candidates
true_prob = 0.0
false_prob = 0.0
for candidate in top_step.candidates:
if hasattr(candidate, 'token') and hasattr(
candidate, 'log_probability'
):
token_text = candidate.token.strip().lower()
prob = math.exp(candidate.log_probability)
if token_text.startswith(('true', 't')):
true_prob = max(true_prob, prob)
elif token_text.startswith(('false', 'f')):
false_prob = max(false_prob, prob)
# Use the probability of "True" as the relevance score
scores.append(true_prob)
else:
scores.append(0.0)
else:
scores.append(0.0)
else:
scores.append(0.0)
else:
# Fallback: parse the response text if no logprobs available
if hasattr(response, 'text') and response.text:
response_text = response.text.strip().lower()
if response_text.startswith(('true', 't')):
scores.append(0.9) # High confidence for "True"
elif response_text.startswith(('false', 'f')):
scores.append(0.1) # Low confidence for "False"
else:
scores.append(0.0)
else:
scores.append(0.0)
except (ValueError, AttributeError) as e:
logger.warning(f'Error parsing log probabilities from Gemini response: {e}')
scores.append(0.0)
results = [(passage, score) for passage, score in zip(passages, scores, strict=True)]
results.sort(reverse=True, key=lambda x: x[1])
return results
except Exception as e:
# Check if it's a rate limit error based on Gemini API error codes
error_message = str(e).lower()
if (
'rate limit' in error_message
or 'quota' in error_message
or 'resource_exhausted' in error_message
or '429' in str(e)
):
raise RateLimitError from e
logger.error(f'Error in generating LLM response: {e}')
raise

View file

@ -17,19 +17,21 @@ limitations under the License.
import json
import logging
import typing
from typing import ClassVar
from google import genai # type: ignore
from google.genai import types # type: ignore
from pydantic import BaseModel
from ..prompts.models import Message
from .client import LLMClient
from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
from .errors import RateLimitError
logger = logging.getLogger(__name__)
DEFAULT_MODEL = 'gemini-2.0-flash'
DEFAULT_MODEL = 'gemini-2.5-flash'
DEFAULT_SMALL_MODEL = 'models/gemini-2.5-flash-lite-preview-06-17'
class GeminiClient(LLMClient):
@ -43,27 +45,34 @@ class GeminiClient(LLMClient):
model (str): The model name to use for generating responses.
temperature (float): The temperature to use for generating responses.
max_tokens (int): The maximum number of tokens to generate in a response.
thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
Methods:
__init__(config: LLMConfig | None = None, cache: bool = False):
Initializes the GeminiClient with the provided configuration and cache setting.
__init__(config: LLMConfig | None = None, cache: bool = False, thinking_config: types.ThinkingConfig | None = None):
Initializes the GeminiClient with the provided configuration, cache setting, and optional thinking config.
_generate_response(messages: list[Message]) -> dict[str, typing.Any]:
Generates a response from the language model based on the provided messages.
"""
# Class-level constants
MAX_RETRIES: ClassVar[int] = 2
def __init__(
self,
config: LLMConfig | None = None,
cache: bool = False,
max_tokens: int = DEFAULT_MAX_TOKENS,
thinking_config: types.ThinkingConfig | None = None,
):
"""
Initialize the GeminiClient with the provided configuration and cache setting.
Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.
Args:
config (LLMConfig | None): The configuration for the LLM client, including API key, model, temperature, and max tokens.
cache (bool): Whether to use caching for responses. Defaults to False.
thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
Only use with models that support thinking (gemini-2.5+). Defaults to None.
"""
if config is None:
config = LLMConfig()
@ -76,6 +85,50 @@ class GeminiClient(LLMClient):
api_key=config.api_key,
)
self.max_tokens = max_tokens
self.thinking_config = thinking_config
def _check_safety_blocks(self, response) -> None:
"""Check if response was blocked for safety reasons and raise appropriate exceptions."""
# Check if the response was blocked for safety reasons
if not (hasattr(response, 'candidates') and response.candidates):
return
candidate = response.candidates[0]
if not (hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'SAFETY'):
return
# Content was blocked for safety reasons - collect safety details
safety_info = []
safety_ratings = getattr(candidate, 'safety_ratings', None)
if safety_ratings:
for rating in safety_ratings:
if getattr(rating, 'blocked', False):
category = getattr(rating, 'category', 'Unknown')
probability = getattr(rating, 'probability', 'Unknown')
safety_info.append(f'{category}: {probability}')
safety_details = (
', '.join(safety_info) if safety_info else 'Content blocked for safety reasons'
)
raise Exception(f'Response blocked by Gemini safety filters: {safety_details}')
def _check_prompt_blocks(self, response) -> None:
"""Check if prompt was blocked and raise appropriate exceptions."""
prompt_feedback = getattr(response, 'prompt_feedback', None)
if not prompt_feedback:
return
block_reason = getattr(prompt_feedback, 'block_reason', None)
if block_reason:
raise Exception(f'Prompt blocked by Gemini: {block_reason}')
def _get_model_for_size(self, model_size: ModelSize) -> str:
"""Get the appropriate model name based on the requested size."""
if model_size == ModelSize.small:
return self.small_model or DEFAULT_SMALL_MODEL
else:
return self.model or DEFAULT_MODEL
async def _generate_response(
self,
@ -91,14 +144,14 @@ class GeminiClient(LLMClient):
messages (list[Message]): A list of messages to send to the language model.
response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
max_tokens (int): The maximum number of tokens to generate in the response.
model_size (ModelSize): The size of the model to use (small or medium).
Returns:
dict[str, typing.Any]: The response from the language model.
Raises:
RateLimitError: If the API rate limit is exceeded.
RefusalError: If the content is blocked by the model.
Exception: If there is an error generating the response.
Exception: If there is an error generating the response or content is blocked.
"""
try:
gemini_messages: list[types.Content] = []
@ -127,6 +180,9 @@ class GeminiClient(LLMClient):
types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)])
)
# Get the appropriate model for the requested size
model = self._get_model_for_size(model_size)
# Create generation config
generation_config = types.GenerateContentConfig(
temperature=self.temperature,
@ -134,15 +190,20 @@ class GeminiClient(LLMClient):
response_mime_type='application/json' if response_model else None,
response_schema=response_model if response_model else None,
system_instruction=system_prompt,
thinking_config=self.thinking_config,
)
# Generate content using the simple string approach
response = await self.client.aio.models.generate_content(
model=self.model or DEFAULT_MODEL,
contents=gemini_messages, # type: ignore[arg-type] # mypy fails on broad union type
model=model,
contents=gemini_messages,
config=generation_config,
)
# Check for safety and prompt blocks
self._check_safety_blocks(response)
self._check_prompt_blocks(response)
# If this was a structured output request, parse the response into the Pydantic model
if response_model is not None:
try:
@ -160,9 +221,16 @@ class GeminiClient(LLMClient):
return {'content': response.text}
except Exception as e:
# Check if it's a rate limit error
if 'rate limit' in str(e).lower() or 'quota' in str(e).lower():
# Check if it's a rate limit error based on Gemini API error codes
error_message = str(e).lower()
if (
'rate limit' in error_message
or 'quota' in error_message
or 'resource_exhausted' in error_message
or '429' in str(e)
):
raise RateLimitError from e
logger.error(f'Error in generating LLM response: {e}')
raise
@ -174,13 +242,14 @@ class GeminiClient(LLMClient):
model_size: ModelSize = ModelSize.medium,
) -> dict[str, typing.Any]:
"""
Generate a response from the Gemini language model.
This method overrides the parent class method to provide a direct implementation.
Generate a response from the Gemini language model with retry logic and error handling.
This method overrides the parent class method to provide a direct implementation with advanced retry logic.
Args:
messages (list[Message]): A list of messages to send to the language model.
response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
max_tokens (int): The maximum number of tokens to generate in the response.
max_tokens (int | None): The maximum number of tokens to generate in the response.
model_size (ModelSize): The size of the model to use (small or medium).
Returns:
dict[str, typing.Any]: The response from the language model.
@ -188,10 +257,53 @@ class GeminiClient(LLMClient):
if max_tokens is None:
max_tokens = self.max_tokens
# Call the internal _generate_response method
return await self._generate_response(
messages=messages,
response_model=response_model,
max_tokens=max_tokens,
model_size=model_size,
)
retry_count = 0
last_error = None
# Add multilingual extraction instructions
messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
while retry_count <= self.MAX_RETRIES:
try:
response = await self._generate_response(
messages=messages,
response_model=response_model,
max_tokens=max_tokens,
model_size=model_size,
)
return response
except RateLimitError:
# Rate limit errors should not trigger retries (fail fast)
raise
except Exception as e:
last_error = e
# Check if this is a safety block - these typically shouldn't be retried
if 'safety' in str(e).lower() or 'blocked' in str(e).lower():
logger.warning(f'Content blocked by safety filters: {e}')
raise
# Don't retry if we've hit the max retries
if retry_count >= self.MAX_RETRIES:
logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
raise
retry_count += 1
# Construct a detailed error message for the LLM
error_context = (
f'The previous response attempt was invalid. '
f'Error type: {e.__class__.__name__}. '
f'Error details: {str(e)}. '
f'Please try again with a valid response, ensuring the output matches '
f'the expected format and constraints.'
)
error_message = Message(role='user', content=error_context)
messages.append(error_message)
logger.warning(
f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
)
# If we somehow get here, raise the last error
raise last_error or Exception('Max retries exceeded with no specific error')

32
uv.lock generated
View file

@ -266,6 +266,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537 },
]
[[package]]
name = "backoff"
version = "2.2.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/47/d7/5bbeb12c44d7c4f2fb5b56abce497eb5ed9f34d85701de869acedd602619/backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba", size = 17001 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/df/73/b6e24bd22e6720ca8ee9a85a0c4a2971af8497d8f3193fa05390cbd46e09/backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8", size = 15148 },
]
[[package]]
name = "beautifulsoup4"
version = "4.13.4"
@ -531,7 +540,7 @@ name = "exceptiongroup"
version = "1.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
{ name = "typing-extensions", marker = "python_full_version < '3.12'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749 }
wheels = [
@ -729,13 +738,14 @@ wheels = [
[[package]]
name = "graphiti-core"
version = "0.13.2"
version = "0.14.0"
source = { editable = "." }
dependencies = [
{ name = "diskcache" },
{ name = "neo4j" },
{ name = "numpy" },
{ name = "openai" },
{ name = "posthog" },
{ name = "pydantic" },
{ name = "python-dotenv" },
{ name = "tenacity" },
@ -796,6 +806,7 @@ requires-dist = [
{ name = "neo4j", specifier = ">=5.26.0" },
{ name = "numpy", specifier = ">=1.0.0" },
{ name = "openai", specifier = ">=1.91.0" },
{ name = "posthog", specifier = ">=3.0.0" },
{ name = "pydantic", specifier = ">=2.11.5" },
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.3.3" },
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24.0" },
@ -2243,6 +2254,23 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538 },
]
[[package]]
name = "posthog"
version = "6.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "backoff" },
{ name = "distro" },
{ name = "python-dateutil" },
{ name = "requests" },
{ name = "six" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f3/c3/c83883af8cc5e3b45d1bee85edce546a4db369fb8dc8eb6339fad764178b/posthog-6.0.0.tar.gz", hash = "sha256:b7bfa0da03bd9240891885d3e44b747e62192ac9ee6da280f45320f4ad3479e0", size = 88066 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ab/ec/7a44533c9fe7046ffcfe48ca0e7472ada2633854f474be633f4afed7b044/posthog-6.0.0-py3-none-any.whl", hash = "sha256:01f5d11046a6267d4384f552e819f0f4a7dc885eb19f606c36d44d662df9ff89", size = 104945 },
]
[[package]]
name = "prometheus-client"
version = "0.22.1"