save edge update (#721)

Preston Rasmussen 2025-07-14 11:15:38 -04:00 committed by GitHub
parent 3200afa363
commit e56ba1a71c
7 changed files with 53 additions and 44 deletions


@@ -37,7 +37,7 @@ from .client import EmbedderClient, EmbedderConfig
logger = logging.getLogger(__name__)
DEFAULT_EMBEDDING_MODEL = 'text-embedding-001' # gemini-embedding-001 or text-embedding-005
DEFAULT_EMBEDDING_MODEL = 'text-embedding-001' # gemini-embedding-001 or text-embedding-005
DEFAULT_BATCH_SIZE = 100
@@ -78,7 +78,7 @@ class GeminiEmbedder(EmbedderClient):
if batch_size is None and self.config.embedding_model == 'gemini-embedding-001':
# Gemini API has a limit on the number of instances per request
#https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api
# https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api
self.batch_size = 1
elif batch_size is None:
self.batch_size = DEFAULT_BATCH_SIZE
@@ -113,32 +113,34 @@ class GeminiEmbedder(EmbedderClient):
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
"""
Create embeddings for a batch of input data using Google's Gemini embedding model.
This method handles batching to respect the Gemini API's limits on the number
of instances that can be processed in a single request.
Args:
input_data_list: A list of strings to create embeddings for.
Returns:
A list of embedding vectors (each vector is a list of floats).
"""
if not input_data_list:
return []
batch_size = self.batch_size
all_embeddings = []
# Process inputs in batches
for i in range(0, len(input_data_list), batch_size):
batch = input_data_list[i:i + batch_size]
batch = input_data_list[i : i + batch_size]
try:
# Generate embeddings for this batch
result = await self.client.aio.models.embed_content(
model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
contents=batch, # type: ignore[arg-type] # mypy fails on broad union type
config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
config=types.EmbedContentConfig(
output_dimensionality=self.config.embedding_dim
),
)
if not result.embeddings or len(result.embeddings) == 0:
@@ -149,29 +151,33 @@ class GeminiEmbedder(EmbedderClient):
if not embedding.values:
raise ValueError('Empty embedding values returned')
all_embeddings.append(embedding.values)
except Exception as e:
# If batch processing fails, fall back to individual processing
logger.warning(f"Batch embedding failed for batch {i//batch_size + 1}, falling back to individual processing: {e}")
logger.warning(
f'Batch embedding failed for batch {i // batch_size + 1}, falling back to individual processing: {e}'
)
for item in batch:
try:
# Process each item individually
result = await self.client.aio.models.embed_content(
model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
contents=[item], # type: ignore[arg-type] # mypy fails on broad union type
config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
config=types.EmbedContentConfig(
output_dimensionality=self.config.embedding_dim
),
)
if not result.embeddings or len(result.embeddings) == 0:
raise ValueError('No embeddings returned from Gemini API')
if not result.embeddings[0].values:
raise ValueError('Empty embedding values returned')
all_embeddings.append(result.embeddings[0].values)
except Exception as individual_error:
logger.error(f"Failed to embed individual item: {individual_error}")
logger.error(f'Failed to embed individual item: {individual_error}')
raise individual_error
return all_embeddings
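
For context, the control flow in this file is batch-first with a per-item fallback, and the batch size drops to 1 for gemini-embedding-001 per the comment above. A minimal, self-contained sketch of that pattern (illustrative only, not part of the diff; embed_many and embed_one are hypothetical stand-ins for the Gemini calls):

import logging

logger = logging.getLogger(__name__)

async def embed_with_fallback(items: list[str], embed_many, embed_one, batch_size: int = 100) -> list[list[float]]:
    # Process inputs in fixed-size batches; if a batch call fails, retry each item individually.
    results: list[list[float]] = []
    for i in range(0, len(items), batch_size):
        batch = items[i : i + batch_size]
        try:
            results.extend(await embed_many(batch))
        except Exception as e:
            logger.warning(f'Batch {i // batch_size + 1} failed, falling back to per-item calls: {e}')
            for item in batch:
                results.append(await embed_one(item))  # any per-item failure propagates to the caller
    return results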


@@ -172,13 +172,13 @@ class LLMClient(ABC):
"""
Log the full input messages, the raw output (if any), and the exception for debugging failed generations.
"""
log = ""
log += f"Input messages: {json.dumps([m.model_dump() for m in messages], indent=2)}\n"
log = ''
log += f'Input messages: {json.dumps([m.model_dump() for m in messages], indent=2)}\n'
if output is not None:
if len(output) > 4000:
log += f"Raw output: {output[:2000]}... (truncated) ...{output[-2000:]}\n"
log += f'Raw output: {output[:2000]}... (truncated) ...{output[-2000:]}\n'
else:
log += f"Raw output: {output}\n"
log += f'Raw output: {output}\n'
else:
log += "No raw output available"
log += 'No raw output available'
return log


@@ -219,14 +219,14 @@ class GeminiClient(LLMClient):
array_match = re.search(r'\]\s*$', raw_output)
if array_match:
try:
return json.loads(raw_output[:array_match.end()])
return json.loads(raw_output[: array_match.end()])
except Exception:
pass
# Try to salvage a JSON object
obj_match = re.search(r'\}\s*$', raw_output)
if obj_match:
try:
return json.loads(raw_output[:obj_match.end()])
return json.loads(raw_output[: obj_match.end()])
except Exception:
pass
return None
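
The salvage path above anchors a regex at the end of the raw output and re-parses the prefix up to the last closing bracket or brace. A small stand-alone sketch of that idea (assumption: output that still ends cleanly in ] or } is worth one more parse attempt; anything else yields None):

import json
import re

def salvage_trailing_json(raw: str):
    # Try to recover a JSON array or object whose closing bracket survives at the end of the text.
    for pattern in (r'\]\s*$', r'\}\s*$'):
        match = re.search(pattern, raw)
        if match:
            try:
                return json.loads(raw[: match.end()])
            except Exception:
                continue
    return None

# e.g. salvage_trailing_json('[{"fact": "a"}]\n') -> [{'fact': 'a'}]
#      salvage_trailing_json('{"fact": "a"')      -> None (no closing brace at the end)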
@@ -323,12 +323,14 @@ class GeminiClient(LLMClient):
return validated_model.model_dump()
except Exception as e:
if raw_output:
logger.error("🦀 LLM generation failed parsing as JSON, will try to salvage.")
logger.error(
'🦀 LLM generation failed parsing as JSON, will try to salvage.'
)
logger.error(self._get_failed_generation_log(gemini_messages, raw_output))
# Try to salvage
salvaged = self.salvage_json(raw_output)
if salvaged is not None:
logger.warning("Salvaged partial JSON from truncated/malformed output.")
logger.warning('Salvaged partial JSON from truncated/malformed output.')
return salvaged
raise Exception(f'Failed to parse structured response: {e}') from e
@@ -384,7 +386,11 @@ class GeminiClient(LLMClient):
max_tokens=max_tokens,
model_size=model_size,
)
last_output = response.get('content') if isinstance(response, dict) and 'content' in response else None
last_output = (
response.get('content')
if isinstance(response, dict) and 'content' in response
else None
)
return response
except RateLimitError as e:
# Rate limit errors should not trigger retries (fail fast)
@@ -416,7 +422,7 @@ class GeminiClient(LLMClient):
)
# If we exit the loop without returning, all retries are exhausted
logger.error("🦀 LLM generation failed and retries are exhausted.")
logger.error('🦀 LLM generation failed and retries are exhausted.')
logger.error(self._get_failed_generation_log(messages, last_output))
logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {last_error}')
raise last_error or Exception("Max retries exceeded")
raise last_error or Exception('Max retries exceeded')
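
The retry wrapper above fails fast on rate limits, retries other errors, and keeps the last error around for the final log line. A compact sketch of that shape (not the project's implementation; RateLimitError here is a local stand-in for the exception type the client imports):

import logging

logger = logging.getLogger(__name__)

class RateLimitError(Exception):
    """Stand-in for the client's rate-limit error type."""

async def generate_with_retries(call, max_retries: int = 2):
    # Retry transient failures, fail fast on rate limits, and keep the last error for the final log.
    last_error: Exception | None = None
    for attempt in range(max_retries):
        try:
            return await call()
        except RateLimitError:
            raise  # rate limit errors should not trigger retries
        except Exception as e:
            last_error = e
            logger.warning(f'Attempt {attempt + 1}/{max_retries} failed: {e}')
    raise last_error or Exception('Max retries exceeded')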


@@ -31,9 +31,9 @@ EPISODIC_EDGE_SAVE_BULK = """
"""
ENTITY_EDGE_SAVE = """
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET r = $edge_data
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $edge_data.fact_embedding)
RETURN r.uuid AS uuid"""
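
This hunk appears to be the edge-save update the commit title refers to: the query now reads source_uuid, target_uuid, and uuid from the single $edge_data map instead of separate parameters. A hedged sketch of how such a query could be driven with the official neo4j Python driver (the surrounding function and parameter names are assumptions; graphiti's own driver wrapper may differ):

from neo4j import AsyncGraphDatabase

async def save_entity_edge(uri: str, auth: tuple[str, str], edge_data: dict) -> str:
    # edge_data is expected to carry uuid, source_uuid, target_uuid, fact_embedding, and the edge's other properties.
    async with AsyncGraphDatabase.driver(uri, auth=auth) as driver:
        # ENTITY_EDGE_SAVE is the Cypher query defined above; edge_data is passed as a single query parameter.
        records, _, _ = await driver.execute_query(ENTITY_EDGE_SAVE, edge_data=edge_data)
        return records[0]['uuid']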


@@ -345,9 +345,9 @@ class TestGeminiEmbedderCreateBatch:
# Set side_effect for embed_content to control return values for each call
mock_gemini_client.aio.models.embed_content.side_effect = [
mock_batch_response, # First call for the batch
mock_individual_response_1, # Second call for individual item 1
mock_individual_response_2 # Third call for individual item 2
mock_batch_response, # First call for the batch
mock_individual_response_1, # Second call for individual item 1
mock_individual_response_2, # Third call for individual item 2
]
input_batch = ['Input 1', 'Input 2']
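
The test leans on unittest.mock sequencing: when side_effect is a list, each await of the mock returns the next element in order, so the first element stands for the whole-batch call and the later ones for the per-item fallback calls. A minimal illustration (generic names, not the project's fixtures):

from unittest.mock import AsyncMock

mock_embed = AsyncMock(side_effect=[
    'batch response',         # first call: the whole-batch request
    'individual response 1',  # second call: fallback for item 1
    'individual response 2',  # third call: fallback for item 2
])
# Each `await mock_embed(...)` returns the next element of side_effect in order.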


@@ -273,7 +273,7 @@ class TestGeminiClientGenerateResponse:
messages = [Message(role='user', content='Test message')]
with pytest.raises(Exception): # noqa: B017
await gemini_client.generate_response(messages, response_model=ResponseModel)
# Should have called generate_content MAX_RETRIES times (2 attempts total)
assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
@@ -344,10 +344,7 @@ class TestGeminiClientGenerateResponse:
await gemini_client.generate_response(messages, response_model=ResponseModel)
# Should have called generate_content MAX_RETRIES times (2 attempts total)
assert (
mock_gemini_client.aio.models.generate_content.call_count
== GeminiClient.MAX_RETRIES
)
assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
@pytest.mark.asyncio
async def test_empty_response_handling(self, gemini_client, mock_gemini_client):
@@ -363,7 +360,7 @@ class TestGeminiClientGenerateResponse:
messages = [Message(role='user', content='Test message')]
with pytest.raises(Exception): # noqa: B017
await gemini_client.generate_response(messages, response_model=ResponseModel)
# Should have exhausted retries due to empty response (2 attempts total)
assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES

uv.lock generated

@@ -746,7 +746,7 @@ wheels = [
[[package]]
name = "graphiti-core"
version = "0.17.1"
version = "0.17.2"
source = { editable = "." }
dependencies = [
{ name = "diskcache" },