Compare commits
4 commits: main ... chore/gemi

| Author | SHA1 | Date |
|---|---|---|
|  | bb476ff0bd |  |
|  | d57a3bb6c4 |  |
|  | 4cfd7baf36 |  |
|  | 756734be01 |  |
5 changed files with 393 additions and 28 deletions

README.md (14 changes)
@@ -242,7 +242,6 @@ graphiti = Graphiti(
         ),
         client=azure_openai_client
     ),
-    # Optional: Configure the OpenAI cross encoder with Azure OpenAI
     cross_encoder=OpenAIRerankerClient(
         llm_config=azure_llm_config,
         client=azure_openai_client
@@ -256,7 +255,7 @@ Make sure to replace the placeholder values with your actual Azure OpenAI credentials
 
 ## Using Graphiti with Google Gemini
 
-Graphiti supports Google's Gemini models for both LLM inference and embeddings. To use Gemini, you'll need to configure both the LLM client and embedder with your Google API key.
+Graphiti supports Google's Gemini models for LLM inference, embeddings, and cross-encoding/reranking. To use Gemini, you'll need to configure the LLM client, embedder, and the cross-encoder with your Google API key.
 
 Install Graphiti:
 
@@ -272,6 +271,7 @@ pip install "graphiti-core[google-genai]"
 from graphiti_core import Graphiti
 from graphiti_core.llm_client.gemini_client import GeminiClient, LLMConfig
 from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig
+from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
 
 # Google API key configuration
 api_key = "<your-google-api-key>"
@@ -292,12 +292,20 @@ graphiti = Graphiti(
             api_key=api_key,
             embedding_model="embedding-001"
         )
+    ),
+    cross_encoder=GeminiRerankerClient(
+        config=LLMConfig(
+            api_key=api_key,
+            model="gemini-2.5-flash-lite-preview-06-17"
+        )
     )
 )
 
-# Now you can use Graphiti with Google Gemini
+# Now you can use Graphiti with Google Gemini for all components
 ```
 
+The Gemini reranker uses the `gemini-2.5-flash-lite-preview-06-17` model by default, which is optimized for cost-effective and low-latency classification tasks. It uses the same boolean classification approach as the OpenAI reranker, leveraging Gemini's log probabilities feature to rank passage relevance.
+
 ## Using Graphiti with Ollama (Local LLM)
 
 Graphiti supports Ollama for running local LLMs and embedding models via Ollama's OpenAI-compatible API. This is ideal for privacy-focused applications or when you want to avoid API costs.
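The README hunk above wires the new `GeminiRerankerClient` into the `Graphiti` constructor. As a side note (not part of the diff), the reranker can also be exercised on its own; the sketch below assumes the import paths shown in the README and uses made-up query and passage strings plus a placeholder API key:

```python
import asyncio

from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
from graphiti_core.llm_client.gemini_client import LLMConfig


async def main() -> None:
    # Placeholder key; substitute your real Google API key.
    reranker = GeminiRerankerClient(config=LLMConfig(api_key='<your-google-api-key>'))

    # rank() returns (passage, score) tuples sorted by descending relevance score.
    ranked = await reranker.rank(
        query='What is Graphiti?',
        passages=[
            'Graphiti builds temporally aware knowledge graphs for agents.',
            'Bananas are a good source of potassium.',
        ],
    )
    for passage, score in ranked:
        print(f'{score:.3f}  {passage}')


asyncio.run(main())
```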
graphiti_core/cross_encoder/__init__.py

@@ -15,6 +15,7 @@ limitations under the License.
 """
 
 from .client import CrossEncoderClient
+from .gemini_reranker_client import GeminiRerankerClient
 from .openai_reranker_client import OpenAIRerankerClient
 
-__all__ = ['CrossEncoderClient', 'OpenAIRerankerClient']
+__all__ = ['CrossEncoderClient', 'GeminiRerankerClient', 'OpenAIRerankerClient']
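Because `__init__.py` now re-exports the class, the reranker can be imported from the package root as well as from its module; a tiny sketch:

```python
# Both imports resolve to the same class after this change.
from graphiti_core.cross_encoder import GeminiRerankerClient
from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient as Direct

assert GeminiRerankerClient is Direct
```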
graphiti_core/cross_encoder/gemini_reranker_client.py (new file, 216 lines)

@@ -0,0 +1,216 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
import math
from typing import Any

from google import genai  # type: ignore
from google.genai import types  # type: ignore

from ..helpers import semaphore_gather
from ..llm_client import LLMConfig, RateLimitError
from .client import CrossEncoderClient

logger = logging.getLogger(__name__)

DEFAULT_MODEL = 'gemini-2.5-flash-lite-preview-06-17'


class GeminiRerankerClient(CrossEncoderClient):
    def __init__(
        self,
        config: LLMConfig | None = None,
        client: genai.Client | None = None,
    ):
        """
        Initialize the GeminiRerankerClient with the provided configuration and client.

        This reranker uses the Gemini API to run a simple boolean classifier prompt concurrently
        for each passage. Log-probabilities are used to rank the passages, equivalent to the OpenAI approach.

        Args:
            config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
        """
        if config is None:
            config = LLMConfig()

        self.config = config
        if client is None:
            self.client = genai.Client(api_key=config.api_key)
        else:
            self.client = client

    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
        """
        Rank passages based on their relevance to the query using Gemini.

        Uses log probabilities from boolean classification responses, equivalent to the OpenAI approach.
        The model responds with "True" or "False" and we use the log probabilities of these tokens.
        """
        gemini_messages_list: Any = [
            [
                types.Content(
                    role='system',
                    parts=[
                        types.Part.from_text(
                            text='You are an expert tasked with determining whether the passage is relevant to the query'
                        )
                    ],
                ),
                types.Content(
                    role='user',
                    parts=[
                        types.Part.from_text(
                            text=f"""Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
<PASSAGE>
{passage}
</PASSAGE>
<QUERY>
{query}
</QUERY>"""
                        )
                    ],
                ),
            ]
            for passage in passages
        ]

        try:
            responses = await semaphore_gather(
                *[
                    self.client.aio.models.generate_content(
                        model=self.config.model or DEFAULT_MODEL,
                        contents=gemini_messages,
                        config=types.GenerateContentConfig(
                            temperature=0.0,
                            max_output_tokens=1,
                            response_logprobs=True,
                            logprobs=5,  # Get top 5 candidate tokens for better coverage
                        ),
                    )
                    for gemini_messages in gemini_messages_list
                ]
            )

            scores: list[float] = []
            for response in responses:
                try:
                    # Check if we have logprobs result in the response
                    if (
                        hasattr(response, 'candidates')
                        and response.candidates
                        and len(response.candidates) > 0
                        and hasattr(response.candidates[0], 'logprobs_result')
                        and response.candidates[0].logprobs_result
                    ):
                        logprobs_result = response.candidates[0].logprobs_result

                        # Get the chosen candidates (tokens actually selected by the model)
                        if (
                            hasattr(logprobs_result, 'chosen_candidates')
                            and logprobs_result.chosen_candidates
                            and len(logprobs_result.chosen_candidates) > 0
                        ):
                            # Get the first token's log probability
                            first_token = logprobs_result.chosen_candidates[0]

                            if hasattr(first_token, 'log_probability') and hasattr(
                                first_token, 'token'
                            ):
                                # Convert log probability to probability (similar to OpenAI approach)
                                log_prob = first_token.log_probability
                                probability = math.exp(log_prob)

                                # Check if the token indicates relevance (starts with "True" or similar)
                                token_text = first_token.token.strip().lower()
                                if token_text.startswith(('true', 't')):
                                    scores.append(probability)
                                else:
                                    # For "False" or other tokens, use 1 - probability
                                    scores.append(1.0 - probability)
                            else:
                                # Fallback: try to get from top candidates
                                if (
                                    hasattr(logprobs_result, 'top_candidates')
                                    and logprobs_result.top_candidates
                                    and len(logprobs_result.top_candidates) > 0
                                ):
                                    top_step = logprobs_result.top_candidates[0]
                                    if (
                                        hasattr(top_step, 'candidates')
                                        and top_step.candidates
                                        and len(top_step.candidates) > 0
                                    ):
                                        # Look for "True" or "False" in top candidates
                                        true_prob = 0.0
                                        false_prob = 0.0

                                        for candidate in top_step.candidates:
                                            if hasattr(candidate, 'token') and hasattr(
                                                candidate, 'log_probability'
                                            ):
                                                token_text = candidate.token.strip().lower()
                                                prob = math.exp(candidate.log_probability)

                                                if token_text.startswith(('true', 't')):
                                                    true_prob = max(true_prob, prob)
                                                elif token_text.startswith(('false', 'f')):
                                                    false_prob = max(false_prob, prob)

                                        # Use the probability of "True" as the relevance score
                                        scores.append(true_prob)
                                    else:
                                        scores.append(0.0)
                                else:
                                    scores.append(0.0)
                        else:
                            scores.append(0.0)
                    else:
                        # Fallback: parse the response text if no logprobs available
                        if hasattr(response, 'text') and response.text:
                            response_text = response.text.strip().lower()
                            if response_text.startswith(('true', 't')):
                                scores.append(0.9)  # High confidence for "True"
                            elif response_text.startswith(('false', 'f')):
                                scores.append(0.1)  # Low confidence for "False"
                            else:
                                scores.append(0.0)
                        else:
                            scores.append(0.0)

                except (ValueError, AttributeError) as e:
                    logger.warning(f'Error parsing log probabilities from Gemini response: {e}')
                    scores.append(0.0)

            results = [(passage, score) for passage, score in zip(passages, scores, strict=True)]
            results.sort(reverse=True, key=lambda x: x[1])
            return results

        except Exception as e:
            # Check if it's a rate limit error based on Gemini API error codes
            error_message = str(e).lower()
            if (
                'rate limit' in error_message
                or 'quota' in error_message
                or 'resource_exhausted' in error_message
                or '429' in str(e)
            ):
                raise RateLimitError from e

            logger.error(f'Error in generating LLM response: {e}')
            raise
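The scoring branch at the heart of `rank` reduces to a little arithmetic on the first output token: exponentiate its log probability and treat the result as P(relevant) when the token looks like "True", or 1 - P otherwise. A self-contained sketch of that mapping (the helper name is illustrative, not part of the file):

```python
import math


def relevance_score(token: str, log_probability: float) -> float:
    """Mirror the primary scoring branch: P(relevant) for 'True'-like tokens, 1 - P otherwise."""
    probability = math.exp(log_probability)
    if token.strip().lower().startswith(('true', 't')):
        return probability
    return 1.0 - probability


# A confident "True" (logprob ~ -0.105, i.e. p ~ 0.90) scores high...
print(round(relevance_score('True', -0.105), 2))   # 0.9
# ...while an equally confident "False" maps to a low score.
print(round(relevance_score('False', -0.105), 2))  # 0.1
```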
graphiti_core/llm_client/gemini_client.py

@@ -17,19 +17,21 @@ limitations under the License.
 import json
 import logging
 import typing
+from typing import ClassVar
 
 from google import genai  # type: ignore
 from google.genai import types  # type: ignore
 from pydantic import BaseModel
 
 from ..prompts.models import Message
-from .client import LLMClient
+from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
 from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
 from .errors import RateLimitError
 
 logger = logging.getLogger(__name__)
 
-DEFAULT_MODEL = 'gemini-2.0-flash'
+DEFAULT_MODEL = 'gemini-2.5-flash'
+DEFAULT_SMALL_MODEL = 'models/gemini-2.5-flash-lite-preview-06-17'
 
 
 class GeminiClient(LLMClient):
@@ -43,27 +45,34 @@ class GeminiClient(LLMClient):
         model (str): The model name to use for generating responses.
         temperature (float): The temperature to use for generating responses.
         max_tokens (int): The maximum number of tokens to generate in a response.
+        thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
     Methods:
-        __init__(config: LLMConfig | None = None, cache: bool = False):
-            Initializes the GeminiClient with the provided configuration and cache setting.
+        __init__(config: LLMConfig | None = None, cache: bool = False, thinking_config: types.ThinkingConfig | None = None):
+            Initializes the GeminiClient with the provided configuration, cache setting, and optional thinking config.
 
         _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
             Generates a response from the language model based on the provided messages.
     """
 
+    # Class-level constants
+    MAX_RETRIES: ClassVar[int] = 2
+
     def __init__(
         self,
         config: LLMConfig | None = None,
         cache: bool = False,
         max_tokens: int = DEFAULT_MAX_TOKENS,
+        thinking_config: types.ThinkingConfig | None = None,
     ):
         """
-        Initialize the GeminiClient with the provided configuration and cache setting.
+        Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.
 
         Args:
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, temperature, and max tokens.
             cache (bool): Whether to use caching for responses. Defaults to False.
+            thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
+                Only use with models that support thinking (gemini-2.5+). Defaults to None.
 
         """
         if config is None:
             config = LLMConfig()
@@ -76,6 +85,50 @@ class GeminiClient(LLMClient):
             api_key=config.api_key,
         )
         self.max_tokens = max_tokens
+        self.thinking_config = thinking_config
+
+    def _check_safety_blocks(self, response) -> None:
+        """Check if response was blocked for safety reasons and raise appropriate exceptions."""
+        # Check if the response was blocked for safety reasons
+        if not (hasattr(response, 'candidates') and response.candidates):
+            return
+
+        candidate = response.candidates[0]
+        if not (hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'SAFETY'):
+            return
+
+        # Content was blocked for safety reasons - collect safety details
+        safety_info = []
+        safety_ratings = getattr(candidate, 'safety_ratings', None)
+
+        if safety_ratings:
+            for rating in safety_ratings:
+                if getattr(rating, 'blocked', False):
+                    category = getattr(rating, 'category', 'Unknown')
+                    probability = getattr(rating, 'probability', 'Unknown')
+                    safety_info.append(f'{category}: {probability}')
+
+        safety_details = (
+            ', '.join(safety_info) if safety_info else 'Content blocked for safety reasons'
+        )
+        raise Exception(f'Response blocked by Gemini safety filters: {safety_details}')
+
+    def _check_prompt_blocks(self, response) -> None:
+        """Check if prompt was blocked and raise appropriate exceptions."""
+        prompt_feedback = getattr(response, 'prompt_feedback', None)
+        if not prompt_feedback:
+            return
+
+        block_reason = getattr(prompt_feedback, 'block_reason', None)
+        if block_reason:
+            raise Exception(f'Prompt blocked by Gemini: {block_reason}')
+
+    def _get_model_for_size(self, model_size: ModelSize) -> str:
+        """Get the appropriate model name based on the requested size."""
+        if model_size == ModelSize.small:
+            return self.small_model or DEFAULT_SMALL_MODEL
+        else:
+            return self.model or DEFAULT_MODEL
+
     async def _generate_response(
         self,
@@ -91,14 +144,14 @@ class GeminiClient(LLMClient):
             messages (list[Message]): A list of messages to send to the language model.
             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
             max_tokens (int): The maximum number of tokens to generate in the response.
+            model_size (ModelSize): The size of the model to use (small or medium).
 
         Returns:
             dict[str, typing.Any]: The response from the language model.
 
         Raises:
             RateLimitError: If the API rate limit is exceeded.
-            RefusalError: If the content is blocked by the model.
-            Exception: If there is an error generating the response.
+            Exception: If there is an error generating the response or content is blocked.
         """
         try:
             gemini_messages: list[types.Content] = []
@@ -127,6 +180,9 @@ class GeminiClient(LLMClient):
                     types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)])
                 )
 
+            # Get the appropriate model for the requested size
+            model = self._get_model_for_size(model_size)
+
             # Create generation config
             generation_config = types.GenerateContentConfig(
                 temperature=self.temperature,
@@ -134,15 +190,20 @@ class GeminiClient(LLMClient):
                 response_mime_type='application/json' if response_model else None,
                 response_schema=response_model if response_model else None,
                 system_instruction=system_prompt,
+                thinking_config=self.thinking_config,
             )
 
             # Generate content using the simple string approach
             response = await self.client.aio.models.generate_content(
-                model=self.model or DEFAULT_MODEL,
-                contents=gemini_messages,  # type: ignore[arg-type]  # mypy fails on broad union type
+                model=model,
+                contents=gemini_messages,
                 config=generation_config,
             )
 
+            # Check for safety and prompt blocks
+            self._check_safety_blocks(response)
+            self._check_prompt_blocks(response)
+
             # If this was a structured output request, parse the response into the Pydantic model
             if response_model is not None:
                 try:
@@ -160,9 +221,16 @@ class GeminiClient(LLMClient):
                 return {'content': response.text}
 
         except Exception as e:
-            # Check if it's a rate limit error
-            if 'rate limit' in str(e).lower() or 'quota' in str(e).lower():
+            # Check if it's a rate limit error based on Gemini API error codes
+            error_message = str(e).lower()
+            if (
+                'rate limit' in error_message
+                or 'quota' in error_message
+                or 'resource_exhausted' in error_message
+                or '429' in str(e)
+            ):
                 raise RateLimitError from e
 
             logger.error(f'Error in generating LLM response: {e}')
             raise
@@ -174,13 +242,14 @@ class GeminiClient(LLMClient):
         model_size: ModelSize = ModelSize.medium,
     ) -> dict[str, typing.Any]:
         """
-        Generate a response from the Gemini language model.
-        This method overrides the parent class method to provide a direct implementation.
+        Generate a response from the Gemini language model with retry logic and error handling.
+        This method overrides the parent class method to provide a direct implementation with advanced retry logic.
 
         Args:
             messages (list[Message]): A list of messages to send to the language model.
             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
-            max_tokens (int): The maximum number of tokens to generate in the response.
+            max_tokens (int | None): The maximum number of tokens to generate in the response.
+            model_size (ModelSize): The size of the model to use (small or medium).
 
         Returns:
             dict[str, typing.Any]: The response from the language model.
@@ -188,10 +257,53 @@ class GeminiClient(LLMClient):
         if max_tokens is None:
             max_tokens = self.max_tokens
 
-        # Call the internal _generate_response method
-        return await self._generate_response(
-            messages=messages,
-            response_model=response_model,
-            max_tokens=max_tokens,
-            model_size=model_size,
-        )
+        retry_count = 0
+        last_error = None
+
+        # Add multilingual extraction instructions
+        messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
+
+        while retry_count <= self.MAX_RETRIES:
+            try:
+                response = await self._generate_response(
+                    messages=messages,
+                    response_model=response_model,
+                    max_tokens=max_tokens,
+                    model_size=model_size,
+                )
+                return response
+            except RateLimitError:
+                # Rate limit errors should not trigger retries (fail fast)
+                raise
+            except Exception as e:
+                last_error = e
+
+                # Check if this is a safety block - these typically shouldn't be retried
+                if 'safety' in str(e).lower() or 'blocked' in str(e).lower():
+                    logger.warning(f'Content blocked by safety filters: {e}')
+                    raise
+
+                # Don't retry if we've hit the max retries
+                if retry_count >= self.MAX_RETRIES:
+                    logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
+                    raise
+
+                retry_count += 1
+
+                # Construct a detailed error message for the LLM
+                error_context = (
+                    f'The previous response attempt was invalid. '
+                    f'Error type: {e.__class__.__name__}. '
+                    f'Error details: {str(e)}. '
+                    f'Please try again with a valid response, ensuring the output matches '
+                    f'the expected format and constraints.'
+                )
+
+                error_message = Message(role='user', content=error_context)
+                messages.append(error_message)
+                logger.warning(
+                    f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
+                )
+
+        # If we somehow get here, raise the last error
+        raise last_error or Exception('Max retries exceeded with no specific error')
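Taken together, the `gemini_client.py` changes above add a small/medium model split, an optional thinking configuration, bounded retries, and safety-block handling. Below is a hedged sketch of opting into the new knobs; the thinking budget is an illustrative value, and `small_model` is assumed to be an `LLMConfig` field, as `_get_model_for_size` suggests:

```python
from google.genai import types

from graphiti_core.llm_client.gemini_client import GeminiClient, LLMConfig

# thinking_config should only be used with models that support thinking (gemini-2.5+).
client = GeminiClient(
    config=LLMConfig(
        api_key='<your-google-api-key>',
        model='gemini-2.5-flash',                           # used for ModelSize.medium
        small_model='gemini-2.5-flash-lite-preview-06-17',  # used for ModelSize.small
    ),
    thinking_config=types.ThinkingConfig(thinking_budget=1024),  # illustrative budget
)

# generate_response() now retries up to MAX_RETRIES times on application errors,
# fails fast on rate limits, and raises immediately on safety blocks.
```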
uv.lock (32 changes, generated)

@@ -266,6 +266,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537 },
 ]
 
+[[package]]
+name = "backoff"
+version = "2.2.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/47/d7/5bbeb12c44d7c4f2fb5b56abce497eb5ed9f34d85701de869acedd602619/backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba", size = 17001 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/df/73/b6e24bd22e6720ca8ee9a85a0c4a2971af8497d8f3193fa05390cbd46e09/backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8", size = 15148 },
+]
+
 [[package]]
 name = "beautifulsoup4"
 version = "4.13.4"
@@ -531,7 +540,7 @@ name = "exceptiongroup"
 version = "1.3.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "typing-extensions", marker = "python_full_version < '3.13'" },
+    { name = "typing-extensions", marker = "python_full_version < '3.12'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749 }
 wheels = [
@@ -729,13 +738,14 @@ wheels = [
 
 [[package]]
 name = "graphiti-core"
-version = "0.13.2"
+version = "0.14.0"
 source = { editable = "." }
 dependencies = [
     { name = "diskcache" },
     { name = "neo4j" },
     { name = "numpy" },
     { name = "openai" },
+    { name = "posthog" },
     { name = "pydantic" },
     { name = "python-dotenv" },
     { name = "tenacity" },
@@ -796,6 +806,7 @@ requires-dist = [
     { name = "neo4j", specifier = ">=5.26.0" },
     { name = "numpy", specifier = ">=1.0.0" },
     { name = "openai", specifier = ">=1.91.0" },
+    { name = "posthog", specifier = ">=3.0.0" },
     { name = "pydantic", specifier = ">=2.11.5" },
     { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.3.3" },
     { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24.0" },
@@ -2243,6 +2254,23 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538 },
 ]
 
+[[package]]
+name = "posthog"
+version = "6.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "backoff" },
+    { name = "distro" },
+    { name = "python-dateutil" },
+    { name = "requests" },
+    { name = "six" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/f3/c3/c83883af8cc5e3b45d1bee85edce546a4db369fb8dc8eb6339fad764178b/posthog-6.0.0.tar.gz", hash = "sha256:b7bfa0da03bd9240891885d3e44b747e62192ac9ee6da280f45320f4ad3479e0", size = 88066 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/ab/ec/7a44533c9fe7046ffcfe48ca0e7472ada2633854f474be633f4afed7b044/posthog-6.0.0-py3-none-any.whl", hash = "sha256:01f5d11046a6267d4384f552e819f0f4a7dc885eb19f606c36d44d662df9ff89", size = 104945 },
+]
+
 [[package]]
 name = "prometheus-client"
 version = "0.22.1"