save edge update (#721)

This commit is contained in:
Preston Rasmussen 2025-07-14 11:15:38 -04:00 committed by GitHub
parent 3200afa363
commit e56ba1a71c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 53 additions and 44 deletions

View file

@ -37,7 +37,7 @@ from .client import EmbedderClient, EmbedderConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_EMBEDDING_MODEL = 'text-embedding-001' # gemini-embedding-001 or text-embedding-005 DEFAULT_EMBEDDING_MODEL = 'text-embedding-001' # gemini-embedding-001 or text-embedding-005
DEFAULT_BATCH_SIZE = 100 DEFAULT_BATCH_SIZE = 100
@ -78,7 +78,7 @@ class GeminiEmbedder(EmbedderClient):
if batch_size is None and self.config.embedding_model == 'gemini-embedding-001': if batch_size is None and self.config.embedding_model == 'gemini-embedding-001':
# Gemini API has a limit on the number of instances per request # Gemini API has a limit on the number of instances per request
#https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api # https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api
self.batch_size = 1 self.batch_size = 1
elif batch_size is None: elif batch_size is None:
self.batch_size = DEFAULT_BATCH_SIZE self.batch_size = DEFAULT_BATCH_SIZE
@ -113,32 +113,34 @@ class GeminiEmbedder(EmbedderClient):
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]: async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
""" """
Create embeddings for a batch of input data using Google's Gemini embedding model. Create embeddings for a batch of input data using Google's Gemini embedding model.
This method handles batching to respect the Gemini API's limits on the number This method handles batching to respect the Gemini API's limits on the number
of instances that can be processed in a single request. of instances that can be processed in a single request.
Args: Args:
input_data_list: A list of strings to create embeddings for. input_data_list: A list of strings to create embeddings for.
Returns: Returns:
A list of embedding vectors (each vector is a list of floats). A list of embedding vectors (each vector is a list of floats).
""" """
if not input_data_list: if not input_data_list:
return [] return []
batch_size = self.batch_size batch_size = self.batch_size
all_embeddings = [] all_embeddings = []
# Process inputs in batches # Process inputs in batches
for i in range(0, len(input_data_list), batch_size): for i in range(0, len(input_data_list), batch_size):
batch = input_data_list[i:i + batch_size] batch = input_data_list[i : i + batch_size]
try: try:
# Generate embeddings for this batch # Generate embeddings for this batch
result = await self.client.aio.models.embed_content( result = await self.client.aio.models.embed_content(
model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL, model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
contents=batch, # type: ignore[arg-type] # mypy fails on broad union type contents=batch, # type: ignore[arg-type] # mypy fails on broad union type
config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim), config=types.EmbedContentConfig(
output_dimensionality=self.config.embedding_dim
),
) )
if not result.embeddings or len(result.embeddings) == 0: if not result.embeddings or len(result.embeddings) == 0:
@ -149,29 +151,33 @@ class GeminiEmbedder(EmbedderClient):
if not embedding.values: if not embedding.values:
raise ValueError('Empty embedding values returned') raise ValueError('Empty embedding values returned')
all_embeddings.append(embedding.values) all_embeddings.append(embedding.values)
except Exception as e: except Exception as e:
# If batch processing fails, fall back to individual processing # If batch processing fails, fall back to individual processing
logger.warning(f"Batch embedding failed for batch {i//batch_size + 1}, falling back to individual processing: {e}") logger.warning(
f'Batch embedding failed for batch {i // batch_size + 1}, falling back to individual processing: {e}'
)
for item in batch: for item in batch:
try: try:
# Process each item individually # Process each item individually
result = await self.client.aio.models.embed_content( result = await self.client.aio.models.embed_content(
model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL, model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
contents=[item], # type: ignore[arg-type] # mypy fails on broad union type contents=[item], # type: ignore[arg-type] # mypy fails on broad union type
config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim), config=types.EmbedContentConfig(
output_dimensionality=self.config.embedding_dim
),
) )
if not result.embeddings or len(result.embeddings) == 0: if not result.embeddings or len(result.embeddings) == 0:
raise ValueError('No embeddings returned from Gemini API') raise ValueError('No embeddings returned from Gemini API')
if not result.embeddings[0].values: if not result.embeddings[0].values:
raise ValueError('Empty embedding values returned') raise ValueError('Empty embedding values returned')
all_embeddings.append(result.embeddings[0].values) all_embeddings.append(result.embeddings[0].values)
except Exception as individual_error: except Exception as individual_error:
logger.error(f"Failed to embed individual item: {individual_error}") logger.error(f'Failed to embed individual item: {individual_error}')
raise individual_error raise individual_error
return all_embeddings return all_embeddings

View file

@ -172,13 +172,13 @@ class LLMClient(ABC):
""" """
Log the full input messages, the raw output (if any), and the exception for debugging failed generations. Log the full input messages, the raw output (if any), and the exception for debugging failed generations.
""" """
log = "" log = ''
log += f"Input messages: {json.dumps([m.model_dump() for m in messages], indent=2)}\n" log += f'Input messages: {json.dumps([m.model_dump() for m in messages], indent=2)}\n'
if output is not None: if output is not None:
if len(output) > 4000: if len(output) > 4000:
log += f"Raw output: {output[:2000]}... (truncated) ...{output[-2000:]}\n" log += f'Raw output: {output[:2000]}... (truncated) ...{output[-2000:]}\n'
else: else:
log += f"Raw output: {output}\n" log += f'Raw output: {output}\n'
else: else:
log += "No raw output available" log += 'No raw output available'
return log return log

View file

@ -219,14 +219,14 @@ class GeminiClient(LLMClient):
array_match = re.search(r'\]\s*$', raw_output) array_match = re.search(r'\]\s*$', raw_output)
if array_match: if array_match:
try: try:
return json.loads(raw_output[:array_match.end()]) return json.loads(raw_output[: array_match.end()])
except Exception: except Exception:
pass pass
# Try to salvage a JSON object # Try to salvage a JSON object
obj_match = re.search(r'\}\s*$', raw_output) obj_match = re.search(r'\}\s*$', raw_output)
if obj_match: if obj_match:
try: try:
return json.loads(raw_output[:obj_match.end()]) return json.loads(raw_output[: obj_match.end()])
except Exception: except Exception:
pass pass
return None return None
@ -323,12 +323,14 @@ class GeminiClient(LLMClient):
return validated_model.model_dump() return validated_model.model_dump()
except Exception as e: except Exception as e:
if raw_output: if raw_output:
logger.error("🦀 LLM generation failed parsing as JSON, will try to salvage.") logger.error(
'🦀 LLM generation failed parsing as JSON, will try to salvage.'
)
logger.error(self._get_failed_generation_log(gemini_messages, raw_output)) logger.error(self._get_failed_generation_log(gemini_messages, raw_output))
# Try to salvage # Try to salvage
salvaged = self.salvage_json(raw_output) salvaged = self.salvage_json(raw_output)
if salvaged is not None: if salvaged is not None:
logger.warning("Salvaged partial JSON from truncated/malformed output.") logger.warning('Salvaged partial JSON from truncated/malformed output.')
return salvaged return salvaged
raise Exception(f'Failed to parse structured response: {e}') from e raise Exception(f'Failed to parse structured response: {e}') from e
@ -384,7 +386,11 @@ class GeminiClient(LLMClient):
max_tokens=max_tokens, max_tokens=max_tokens,
model_size=model_size, model_size=model_size,
) )
last_output = response.get('content') if isinstance(response, dict) and 'content' in response else None last_output = (
response.get('content')
if isinstance(response, dict) and 'content' in response
else None
)
return response return response
except RateLimitError as e: except RateLimitError as e:
# Rate limit errors should not trigger retries (fail fast) # Rate limit errors should not trigger retries (fail fast)
@ -416,7 +422,7 @@ class GeminiClient(LLMClient):
) )
# If we exit the loop without returning, all retries are exhausted # If we exit the loop without returning, all retries are exhausted
logger.error("🦀 LLM generation failed and retries are exhausted.") logger.error('🦀 LLM generation failed and retries are exhausted.')
logger.error(self._get_failed_generation_log(messages, last_output)) logger.error(self._get_failed_generation_log(messages, last_output))
logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {last_error}') logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {last_error}')
raise last_error or Exception("Max retries exceeded") raise last_error or Exception('Max retries exceeded')

View file

@ -31,9 +31,9 @@ EPISODIC_EDGE_SAVE_BULK = """
""" """
ENTITY_EDGE_SAVE = """ ENTITY_EDGE_SAVE = """
MATCH (source:Entity {uuid: $source_uuid}) MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $target_uuid}) MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target) MERGE (source)-[r:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET r = $edge_data SET r = $edge_data
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $edge_data.fact_embedding) WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $edge_data.fact_embedding)
RETURN r.uuid AS uuid""" RETURN r.uuid AS uuid"""

View file

@ -345,9 +345,9 @@ class TestGeminiEmbedderCreateBatch:
# Set side_effect for embed_content to control return values for each call # Set side_effect for embed_content to control return values for each call
mock_gemini_client.aio.models.embed_content.side_effect = [ mock_gemini_client.aio.models.embed_content.side_effect = [
mock_batch_response, # First call for the batch mock_batch_response, # First call for the batch
mock_individual_response_1, # Second call for individual item 1 mock_individual_response_1, # Second call for individual item 1
mock_individual_response_2 # Third call for individual item 2 mock_individual_response_2, # Third call for individual item 2
] ]
input_batch = ['Input 1', 'Input 2'] input_batch = ['Input 1', 'Input 2']

View file

@ -273,7 +273,7 @@ class TestGeminiClientGenerateResponse:
messages = [Message(role='user', content='Test message')] messages = [Message(role='user', content='Test message')]
with pytest.raises(Exception): # noqa: B017 with pytest.raises(Exception): # noqa: B017
await gemini_client.generate_response(messages, response_model=ResponseModel) await gemini_client.generate_response(messages, response_model=ResponseModel)
# Should have called generate_content MAX_RETRIES times (2 attempts total) # Should have called generate_content MAX_RETRIES times (2 attempts total)
assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
@ -344,10 +344,7 @@ class TestGeminiClientGenerateResponse:
await gemini_client.generate_response(messages, response_model=ResponseModel) await gemini_client.generate_response(messages, response_model=ResponseModel)
# Should have called generate_content MAX_RETRIES times (2 attempts total) # Should have called generate_content MAX_RETRIES times (2 attempts total)
assert ( assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
mock_gemini_client.aio.models.generate_content.call_count
== GeminiClient.MAX_RETRIES
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_empty_response_handling(self, gemini_client, mock_gemini_client): async def test_empty_response_handling(self, gemini_client, mock_gemini_client):
@ -363,7 +360,7 @@ class TestGeminiClientGenerateResponse:
messages = [Message(role='user', content='Test message')] messages = [Message(role='user', content='Test message')]
with pytest.raises(Exception): # noqa: B017 with pytest.raises(Exception): # noqa: B017
await gemini_client.generate_response(messages, response_model=ResponseModel) await gemini_client.generate_response(messages, response_model=ResponseModel)
# Should have exhausted retries due to empty response (2 attempts total) # Should have exhausted retries due to empty response (2 attempts total)
assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES

2
uv.lock generated
View file

@ -746,7 +746,7 @@ wheels = [
[[package]] [[package]]
name = "graphiti-core" name = "graphiti-core"
version = "0.17.1" version = "0.17.2"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "diskcache" }, { name = "diskcache" },