* Add OpenTelemetry distributed tracing support

  - Add tracer abstraction with no-op and OpenTelemetry implementations
  - Instrument add_episode and add_episode_bulk with tracing spans
  - Instrument LLM client with cache-aware tracing
  - Add configurable span name prefix support
  - Refactor add_episode methods to improve code quality
  - Add OTEL_TRACING.md documentation

* Fix linting errors in tracing implementation

  - Remove unused episodes_by_uuid variable
  - Fix tracer type annotations for context manager support
  - Replace isinstance tuple with union syntax
  - Use contextlib.suppress for exception handling
  - Fix import ordering and use AbstractContextManager

* Address PR review feedback on tracing implementation

  Critical fixes:
  - Remove flawed error span creation in graphiti.py that created orphaned spans
  - Restructure LLM client tracing to create the span once at the start, eliminating code duplication
  - Initialize the LLM client tracer to NoOpTracer by default to fix type checking

  Enhancements:
  - Add comprehensive span attributes to add_episode: reference_time, entity/edge type counts, previous episodes count, invalidated edge count, community count
  - Optimize isinstance check for better performance

* Add prompt name tracking to OpenTelemetry tracing spans

  Add a prompt_name parameter to all LLM client generate_response() methods and set it as a span attribute on the llm.generate span. This enables better observability by identifying which prompt template was used for each LLM call.

  Changes:
  - Add prompt_name parameter to the LLMClient.generate_response() base method
  - Add prompt_name parameter and tracing to OpenAIBaseClient, AnthropicClient, GeminiClient, and OpenAIGenericClient
  - Update all 14 LLM call sites across maintenance operations to include prompt_name:
    - edge_operations.py: 4 calls
    - node_operations.py: 6 calls (note: 7 listed but only 6 unique)
    - temporal_operations.py: 2 calls
    - community_operations.py: 2 calls

* Fix exception handling in add_episode to record errors in the OpenTelemetry span

  Moved the try-except block inside the OpenTelemetry span context and added proper error recording with span.set_status() and span.record_exception(). This ensures exceptions are captured in the distributed trace, matching the pattern used in add_episode_bulk.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
"""
|
|
Copyright 2024, Zep Software, Inc.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
import typing
|
|
from typing import TYPE_CHECKING, ClassVar
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from ..prompts.models import Message
|
|
from .client import LLMClient, get_extraction_language_instruction
|
|
from .config import LLMConfig, ModelSize
|
|
from .errors import RateLimitError
|
|
|
|
if TYPE_CHECKING:
|
|
from google import genai
|
|
from google.genai import types
|
|
else:
|
|
try:
|
|
from google import genai
|
|
from google.genai import types
|
|
except ImportError:
|
|
# If gemini client is not installed, raise an ImportError
|
|
raise ImportError(
|
|
'google-genai is required for GeminiClient. '
|
|
'Install it with: pip install graphiti-core[google-genai]'
|
|
) from None
|
|
|
|
|
|
logger = logging.getLogger(__name__)

DEFAULT_MODEL = 'gemini-2.5-flash'
DEFAULT_SMALL_MODEL = 'gemini-2.5-flash-lite-preview-06-17'

# Maximum output tokens for different Gemini models
GEMINI_MODEL_MAX_TOKENS = {
    # Gemini 2.5 models
    'gemini-2.5-pro': 65536,
    'gemini-2.5-flash': 65536,
    'gemini-2.5-flash-lite': 64000,
    'models/gemini-2.5-flash-lite-preview-06-17': 64000,
    # Gemini 2.0 models
    'gemini-2.0-flash': 8192,
    'gemini-2.0-flash-lite': 8192,
    # Gemini 1.5 models
    'gemini-1.5-pro': 8192,
    'gemini-1.5-flash': 8192,
    'gemini-1.5-flash-8b': 8192,
}

# Default max tokens for models not in the mapping
DEFAULT_GEMINI_MAX_TOKENS = 8192

class GeminiClient(LLMClient):
    """
    GeminiClient is a client class for interacting with Google's Gemini language models.

    This class extends LLMClient and provides methods to initialize the client
    and generate responses from the Gemini language model.

    Attributes:
        model (str): The model name to use for generating responses.
        temperature (float): The temperature to use for generating responses.
        max_tokens (int): The maximum number of tokens to generate in a response.
        thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.

    Methods:
        __init__(config: LLMConfig | None = None, cache: bool = False, thinking_config: types.ThinkingConfig | None = None):
            Initializes the GeminiClient with the provided configuration, cache setting, and optional thinking config.

        _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
            Generates a response from the language model based on the provided messages.
    """

    # Class-level constants
    MAX_RETRIES: ClassVar[int] = 2

    def __init__(
        self,
        config: LLMConfig | None = None,
        cache: bool = False,
        max_tokens: int | None = None,
        thinking_config: types.ThinkingConfig | None = None,
        client: 'genai.Client | None' = None,
    ):
        """
        Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.

        Args:
            config (LLMConfig | None): The configuration for the LLM client, including API key, model, temperature, and max tokens.
            cache (bool): Whether to use caching for responses. Defaults to False.
            max_tokens (int | None): The maximum number of output tokens to generate. Defaults to None,
                in which case model-specific limits apply (see _resolve_max_tokens).
            thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
                Only use with models that support thinking (gemini-2.5+). Defaults to None.
            client (genai.Client | None): An optional genai.Client instance to use. If not provided, a new client is created.
        """
        if config is None:
            config = LLMConfig()

        super().__init__(config, cache)

        self.model = config.model

        if client is None:
            self.client = genai.Client(api_key=config.api_key)
        else:
            self.client = client

        self.max_tokens = max_tokens
        self.thinking_config = thinking_config
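
    # Example usage (an illustrative sketch, not from the original file; the
    # LLMConfig fields shown follow the docstring above, and thinking_budget is
    # an assumed google-genai ThinkingConfig field):
    #
    #     client = GeminiClient(
    #         config=LLMConfig(api_key='...', model='gemini-2.5-flash', temperature=0.0),
    #         max_tokens=4096,
    #         thinking_config=types.ThinkingConfig(thinking_budget=1024),
    #     )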

    def _check_safety_blocks(self, response) -> None:
        """Check if the response was blocked for safety reasons and raise appropriate exceptions."""
        # Check if the response was blocked for safety reasons
        if not (hasattr(response, 'candidates') and response.candidates):
            return

        candidate = response.candidates[0]
        if not (hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'SAFETY'):
            return

        # Content was blocked for safety reasons - collect safety details
        safety_info = []
        safety_ratings = getattr(candidate, 'safety_ratings', None)

        if safety_ratings:
            for rating in safety_ratings:
                if getattr(rating, 'blocked', False):
                    category = getattr(rating, 'category', 'Unknown')
                    probability = getattr(rating, 'probability', 'Unknown')
                    safety_info.append(f'{category}: {probability}')

        safety_details = (
            ', '.join(safety_info) if safety_info else 'Content blocked for safety reasons'
        )
        raise Exception(f'Response blocked by Gemini safety filters: {safety_details}')

    def _check_prompt_blocks(self, response) -> None:
        """Check if the prompt was blocked and raise appropriate exceptions."""
        prompt_feedback = getattr(response, 'prompt_feedback', None)
        if not prompt_feedback:
            return

        block_reason = getattr(prompt_feedback, 'block_reason', None)
        if block_reason:
            raise Exception(f'Prompt blocked by Gemini: {block_reason}')

    def _get_model_for_size(self, model_size: ModelSize) -> str:
        """Get the appropriate model name based on the requested size."""
        if model_size == ModelSize.small:
            return self.small_model or DEFAULT_SMALL_MODEL
        else:
            return self.model or DEFAULT_MODEL

    def _get_max_tokens_for_model(self, model: str) -> int:
        """Get the maximum output tokens for a specific Gemini model."""
        return GEMINI_MODEL_MAX_TOKENS.get(model, DEFAULT_GEMINI_MAX_TOKENS)

    def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
        """
        Resolve the maximum output tokens to use based on precedence rules.

        Precedence order (highest to lowest):
        1. Explicit max_tokens parameter passed to generate_response()
        2. Instance max_tokens set during client initialization
        3. Model-specific maximum tokens from the GEMINI_MODEL_MAX_TOKENS mapping
        4. DEFAULT_GEMINI_MAX_TOKENS as the final fallback

        Args:
            requested_max_tokens: The max_tokens parameter passed to generate_response()
            model: The model name to look up model-specific limits

        Returns:
            int: The resolved maximum tokens to use
        """
        # 1. Use the explicit parameter if provided
        if requested_max_tokens is not None:
            return requested_max_tokens

        # 2. Use the instance max_tokens if set during initialization
        if self.max_tokens is not None:
            return self.max_tokens

        # 3. Use the model-specific maximum, falling back to DEFAULT_GEMINI_MAX_TOKENS
        return self._get_max_tokens_for_model(model)
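
    # Worked example of the precedence rules (illustrative, not from the
    # original file):
    #
    #     client = GeminiClient(max_tokens=None)
    #     client._resolve_max_tokens(None, 'gemini-2.0-flash')  # -> 8192 (model limit)
    #     client._resolve_max_tokens(None, 'unknown-model')     # -> 8192 (DEFAULT_GEMINI_MAX_TOKENS)
    #
    #     client = GeminiClient(max_tokens=2048)
    #     client._resolve_max_tokens(None, 'gemini-2.5-flash')  # -> 2048 (instance setting)
    #     client._resolve_max_tokens(512, 'gemini-2.5-flash')   # -> 512  (explicit argument wins)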

    def salvage_json(self, raw_output: str) -> dict[str, typing.Any] | None:
        """
        Attempt to salvage a JSON object if the raw output is truncated.

        This is done by looking for a closing bracket for an array or object at
        the end of the output. If one is found, the method tries to load the
        JSON from the raw output up to that point. If the JSON is still not
        valid, it returns None.

        Args:
            raw_output (str): The raw output from the LLM.

        Returns:
            dict[str, typing.Any]: The salvaged JSON object.
            None: If no salvage is possible.
        """
        if not raw_output:
            return None
        # Try to salvage a JSON array
        array_match = re.search(r'\]\s*$', raw_output)
        if array_match:
            try:
                return json.loads(raw_output[: array_match.end()])
            except Exception:
                pass
        # Try to salvage a JSON object
        obj_match = re.search(r'\}\s*$', raw_output)
        if obj_match:
            try:
                return json.loads(raw_output[: obj_match.end()])
            except Exception:
                pass
        return None
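
    # Illustrative behavior (assumed inputs, not from the original file):
    #
    #     salvage_json('{"entities": [{"name": "Alice"}]}')  # -> parsed dict
    #     salvage_json('[{"name": "Alice"}] trailing text')  # -> None (bracket not at end)
    #     salvage_json('{"entities": [{"name": "Ali')        # -> None (no closing bracket)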

    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the Gemini language model.

        Args:
            messages (list[Message]): A list of messages to send to the language model.
            response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
            max_tokens (int | None): The maximum number of tokens to generate in the response. If None, uses precedence rules.
            model_size (ModelSize): The size of the model to use (small or medium).

        Returns:
            dict[str, typing.Any]: The response from the language model.

        Raises:
            RateLimitError: If the API rate limit is exceeded.
            Exception: If there is an error generating the response or content is blocked.
        """
        try:
            gemini_messages: typing.Any = []
            # If a response model is provided, add its schema for structured output
            system_prompt = ''
            if response_model is not None:
                # Get the schema from the Pydantic model
                pydantic_schema = response_model.model_json_schema()

                # Create an instruction to output in the desired JSON format
                system_prompt += (
                    f'Output ONLY valid JSON matching this schema: {json.dumps(pydantic_schema)}.\n'
                    'Do not include any explanatory text before or after the JSON.\n\n'
                )
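
                # For example (illustrative): for a Pydantic model with a single
                # string field `name`, the injected instruction looks like:
                #   Output ONLY valid JSON matching this schema:
                #   {"properties": {"name": {"title": "Name", "type": "string"}},
                #    "required": ["name"], "title": "...", "type": "object"}.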

            # Add messages content
            # First check for a system message
            if messages and messages[0].role == 'system':
                system_prompt = f'{messages[0].content}\n\n {system_prompt}'
                messages = messages[1:]

            # Add the rest of the messages
            for m in messages:
                m.content = self._clean_input(m.content)
                gemini_messages.append(
                    types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)])
                )

            # Get the appropriate model for the requested size
            model = self._get_model_for_size(model_size)

            # Resolve max_tokens using precedence rules (see _resolve_max_tokens for details)
            resolved_max_tokens = self._resolve_max_tokens(max_tokens, model)

            # Create the generation config
            generation_config = types.GenerateContentConfig(
                temperature=self.temperature,
                max_output_tokens=resolved_max_tokens,
                response_mime_type='application/json' if response_model else None,
                response_schema=response_model if response_model else None,
                system_instruction=system_prompt,
                thinking_config=self.thinking_config,
            )

            # Generate content
            response = await self.client.aio.models.generate_content(
                model=model,
                contents=gemini_messages,
                config=generation_config,
            )

            # Always capture the raw output for debugging
            raw_output = getattr(response, 'text', None)

            # Check for safety and prompt blocks
            self._check_safety_blocks(response)
            self._check_prompt_blocks(response)

            # If this was a structured output request, parse the response into the Pydantic model
            if response_model is not None:
                try:
                    if not raw_output:
                        raise ValueError('No response text')

                    validated_model = response_model.model_validate(json.loads(raw_output))

                    # Return as a dictionary for API consistency
                    return validated_model.model_dump()
                except Exception as e:
                    if raw_output:
                        logger.error(
                            '🦀 LLM generation failed parsing as JSON, will try to salvage.'
                        )
                        logger.error(self._get_failed_generation_log(gemini_messages, raw_output))
                        # Try to salvage
                        salvaged = self.salvage_json(raw_output)
                        if salvaged is not None:
                            logger.warning('Salvaged partial JSON from truncated/malformed output.')
                            return salvaged
                    raise Exception(f'Failed to parse structured response: {e}') from e

            # Otherwise, return the response text as a dictionary
            return {'content': raw_output}

        except Exception as e:
            # Check if it's a rate limit error based on Gemini API error codes
            error_message = str(e).lower()
            if (
                'rate limit' in error_message
                or 'quota' in error_message
                or 'resource_exhausted' in error_message
                or '429' in str(e)
            ):
                raise RateLimitError from e

            logger.error(f'Error in generating LLM response: {e}')
            raise Exception from e

    async def generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
        group_id: str | None = None,
        prompt_name: str | None = None,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the Gemini language model with retry logic and error handling.

        This method overrides the parent class method to provide a direct implementation
        with retry logic that feeds validation errors back to the model.

        Args:
            messages (list[Message]): A list of messages to send to the language model.
            response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
            max_tokens (int | None): The maximum number of tokens to generate in the response.
            model_size (ModelSize): The size of the model to use (small or medium).
            group_id (str | None): Optional partition identifier for the graph.
            prompt_name (str | None): Optional name of the prompt for tracing.

        Returns:
            dict[str, typing.Any]: The response from the language model.
        """
        # Add multilingual extraction instructions
        messages[0].content += get_extraction_language_instruction(group_id)

        # Wrap the entire operation in a tracing span
        with self.tracer.start_span('llm.generate') as span:
            attributes = {
                'llm.provider': 'gemini',
                'model.size': model_size.value,
                'max_tokens': max_tokens or self.max_tokens,
            }
            if prompt_name:
                attributes['prompt.name'] = prompt_name
            span.add_attributes(attributes)

            retry_count = 0
            last_error = None
            last_output = None

            while retry_count < self.MAX_RETRIES:
                try:
                    response = await self._generate_response(
                        messages=messages,
                        response_model=response_model,
                        max_tokens=max_tokens,
                        model_size=model_size,
                    )
                    last_output = (
                        response.get('content')
                        if isinstance(response, dict) and 'content' in response
                        else None
                    )
                    return response
                except RateLimitError as e:
                    # Rate limit errors should not trigger retries (fail fast)
                    span.set_status('error', str(e))
                    raise e
                except Exception as e:
                    last_error = e

                    # Check if this is a safety block - these typically shouldn't be retried
                    error_text = str(e) or (str(e.__cause__) if e.__cause__ else '')
                    if 'safety' in error_text.lower() or 'blocked' in error_text.lower():
                        logger.warning(f'Content blocked by safety filters: {e}')
                        span.set_status('error', str(e))
                        raise Exception(f'Content blocked by safety filters: {e}') from e

                    retry_count += 1

                    # Construct a detailed error message for the LLM
                    error_context = (
                        f'The previous response attempt was invalid. '
                        f'Error type: {e.__class__.__name__}. '
                        f'Error details: {str(e)}. '
                        f'Please try again with a valid response, ensuring the output matches '
                        f'the expected format and constraints.'
                    )

                    error_message = Message(role='user', content=error_context)
                    messages.append(error_message)
                    logger.warning(
                        f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
                    )

            # If we exit the loop without returning, all retries are exhausted
            logger.error('🦀 LLM generation failed and retries are exhausted.')
            logger.error(self._get_failed_generation_log(messages, last_output))
            logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {last_error}')
            span.set_status('error', str(last_error))
            if last_error is not None:
                span.record_exception(last_error)
            raise last_error or Exception('Max retries exceeded')
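

# A minimal sketch of the tracer contract this client relies on (assumed names,
# inferred from the start_span / add_attributes / set_status / record_exception
# calls above; graphiti's actual tracer module may differ):
#
#     from contextlib import contextmanager
#     from typing import Any
#
#     class NoOpSpan:
#         def add_attributes(self, attributes: dict[str, Any]) -> None: ...
#         def set_status(self, status: str, description: str | None = None) -> None: ...
#         def record_exception(self, exc: BaseException) -> None: ...
#
#     class NoOpTracer:
#         @contextmanager
#         def start_span(self, name: str):
#             yield NoOpSpan()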
|