save edge update (#721)

This commit is contained in:
Preston Rasmussen 2025-07-14 11:15:38 -04:00 committed by GitHub
parent 3200afa363
commit e56ba1a71c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 53 additions and 44 deletions

View file

@ -37,7 +37,7 @@ from .client import EmbedderClient, EmbedderConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_EMBEDDING_MODEL = 'text-embedding-001' # gemini-embedding-001 or text-embedding-005 DEFAULT_EMBEDDING_MODEL = 'text-embedding-001' # gemini-embedding-001 or text-embedding-005
DEFAULT_BATCH_SIZE = 100 DEFAULT_BATCH_SIZE = 100
@ -78,7 +78,7 @@ class GeminiEmbedder(EmbedderClient):
if batch_size is None and self.config.embedding_model == 'gemini-embedding-001': if batch_size is None and self.config.embedding_model == 'gemini-embedding-001':
# Gemini API has a limit on the number of instances per request # Gemini API has a limit on the number of instances per request
#https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api # https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api
self.batch_size = 1 self.batch_size = 1
elif batch_size is None: elif batch_size is None:
self.batch_size = DEFAULT_BATCH_SIZE self.batch_size = DEFAULT_BATCH_SIZE
@ -131,14 +131,16 @@ class GeminiEmbedder(EmbedderClient):
# Process inputs in batches # Process inputs in batches
for i in range(0, len(input_data_list), batch_size): for i in range(0, len(input_data_list), batch_size):
batch = input_data_list[i:i + batch_size] batch = input_data_list[i : i + batch_size]
try: try:
# Generate embeddings for this batch # Generate embeddings for this batch
result = await self.client.aio.models.embed_content( result = await self.client.aio.models.embed_content(
model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL, model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
contents=batch, # type: ignore[arg-type] # mypy fails on broad union type contents=batch, # type: ignore[arg-type] # mypy fails on broad union type
config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim), config=types.EmbedContentConfig(
output_dimensionality=self.config.embedding_dim
),
) )
if not result.embeddings or len(result.embeddings) == 0: if not result.embeddings or len(result.embeddings) == 0:
@ -152,7 +154,9 @@ class GeminiEmbedder(EmbedderClient):
except Exception as e: except Exception as e:
# If batch processing fails, fall back to individual processing # If batch processing fails, fall back to individual processing
logger.warning(f"Batch embedding failed for batch {i//batch_size + 1}, falling back to individual processing: {e}") logger.warning(
f'Batch embedding failed for batch {i // batch_size + 1}, falling back to individual processing: {e}'
)
for item in batch: for item in batch:
try: try:
@ -160,7 +164,9 @@ class GeminiEmbedder(EmbedderClient):
result = await self.client.aio.models.embed_content( result = await self.client.aio.models.embed_content(
model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL, model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
contents=[item], # type: ignore[arg-type] # mypy fails on broad union type contents=[item], # type: ignore[arg-type] # mypy fails on broad union type
config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim), config=types.EmbedContentConfig(
output_dimensionality=self.config.embedding_dim
),
) )
if not result.embeddings or len(result.embeddings) == 0: if not result.embeddings or len(result.embeddings) == 0:
@ -171,7 +177,7 @@ class GeminiEmbedder(EmbedderClient):
all_embeddings.append(result.embeddings[0].values) all_embeddings.append(result.embeddings[0].values)
except Exception as individual_error: except Exception as individual_error:
logger.error(f"Failed to embed individual item: {individual_error}") logger.error(f'Failed to embed individual item: {individual_error}')
raise individual_error raise individual_error
return all_embeddings return all_embeddings

View file

@ -172,13 +172,13 @@ class LLMClient(ABC):
""" """
Log the full input messages, the raw output (if any), and the exception for debugging failed generations. Log the full input messages, the raw output (if any), and the exception for debugging failed generations.
""" """
log = "" log = ''
log += f"Input messages: {json.dumps([m.model_dump() for m in messages], indent=2)}\n" log += f'Input messages: {json.dumps([m.model_dump() for m in messages], indent=2)}\n'
if output is not None: if output is not None:
if len(output) > 4000: if len(output) > 4000:
log += f"Raw output: {output[:2000]}... (truncated) ...{output[-2000:]}\n" log += f'Raw output: {output[:2000]}... (truncated) ...{output[-2000:]}\n'
else: else:
log += f"Raw output: {output}\n" log += f'Raw output: {output}\n'
else: else:
log += "No raw output available" log += 'No raw output available'
return log return log

View file

@ -219,14 +219,14 @@ class GeminiClient(LLMClient):
array_match = re.search(r'\]\s*$', raw_output) array_match = re.search(r'\]\s*$', raw_output)
if array_match: if array_match:
try: try:
return json.loads(raw_output[:array_match.end()]) return json.loads(raw_output[: array_match.end()])
except Exception: except Exception:
pass pass
# Try to salvage a JSON object # Try to salvage a JSON object
obj_match = re.search(r'\}\s*$', raw_output) obj_match = re.search(r'\}\s*$', raw_output)
if obj_match: if obj_match:
try: try:
return json.loads(raw_output[:obj_match.end()]) return json.loads(raw_output[: obj_match.end()])
except Exception: except Exception:
pass pass
return None return None
@ -323,12 +323,14 @@ class GeminiClient(LLMClient):
return validated_model.model_dump() return validated_model.model_dump()
except Exception as e: except Exception as e:
if raw_output: if raw_output:
logger.error("🦀 LLM generation failed parsing as JSON, will try to salvage.") logger.error(
'🦀 LLM generation failed parsing as JSON, will try to salvage.'
)
logger.error(self._get_failed_generation_log(gemini_messages, raw_output)) logger.error(self._get_failed_generation_log(gemini_messages, raw_output))
# Try to salvage # Try to salvage
salvaged = self.salvage_json(raw_output) salvaged = self.salvage_json(raw_output)
if salvaged is not None: if salvaged is not None:
logger.warning("Salvaged partial JSON from truncated/malformed output.") logger.warning('Salvaged partial JSON from truncated/malformed output.')
return salvaged return salvaged
raise Exception(f'Failed to parse structured response: {e}') from e raise Exception(f'Failed to parse structured response: {e}') from e
@ -384,7 +386,11 @@ class GeminiClient(LLMClient):
max_tokens=max_tokens, max_tokens=max_tokens,
model_size=model_size, model_size=model_size,
) )
last_output = response.get('content') if isinstance(response, dict) and 'content' in response else None last_output = (
response.get('content')
if isinstance(response, dict) and 'content' in response
else None
)
return response return response
except RateLimitError as e: except RateLimitError as e:
# Rate limit errors should not trigger retries (fail fast) # Rate limit errors should not trigger retries (fail fast)
@ -416,7 +422,7 @@ class GeminiClient(LLMClient):
) )
# If we exit the loop without returning, all retries are exhausted # If we exit the loop without returning, all retries are exhausted
logger.error("🦀 LLM generation failed and retries are exhausted.") logger.error('🦀 LLM generation failed and retries are exhausted.')
logger.error(self._get_failed_generation_log(messages, last_output)) logger.error(self._get_failed_generation_log(messages, last_output))
logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {last_error}') logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {last_error}')
raise last_error or Exception("Max retries exceeded") raise last_error or Exception('Max retries exceeded')

View file

@ -31,9 +31,9 @@ EPISODIC_EDGE_SAVE_BULK = """
""" """
ENTITY_EDGE_SAVE = """ ENTITY_EDGE_SAVE = """
MATCH (source:Entity {uuid: $source_uuid}) MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $target_uuid}) MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target) MERGE (source)-[r:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET r = $edge_data SET r = $edge_data
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $edge_data.fact_embedding) WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $edge_data.fact_embedding)
RETURN r.uuid AS uuid""" RETURN r.uuid AS uuid"""

View file

@ -345,9 +345,9 @@ class TestGeminiEmbedderCreateBatch:
# Set side_effect for embed_content to control return values for each call # Set side_effect for embed_content to control return values for each call
mock_gemini_client.aio.models.embed_content.side_effect = [ mock_gemini_client.aio.models.embed_content.side_effect = [
mock_batch_response, # First call for the batch mock_batch_response, # First call for the batch
mock_individual_response_1, # Second call for individual item 1 mock_individual_response_1, # Second call for individual item 1
mock_individual_response_2 # Third call for individual item 2 mock_individual_response_2, # Third call for individual item 2
] ]
input_batch = ['Input 1', 'Input 2'] input_batch = ['Input 1', 'Input 2']

View file

@ -344,10 +344,7 @@ class TestGeminiClientGenerateResponse:
await gemini_client.generate_response(messages, response_model=ResponseModel) await gemini_client.generate_response(messages, response_model=ResponseModel)
# Should have called generate_content MAX_RETRIES times (2 attempts total) # Should have called generate_content MAX_RETRIES times (2 attempts total)
assert ( assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
mock_gemini_client.aio.models.generate_content.call_count
== GeminiClient.MAX_RETRIES
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_empty_response_handling(self, gemini_client, mock_gemini_client): async def test_empty_response_handling(self, gemini_client, mock_gemini_client):

2
uv.lock generated
View file

@ -746,7 +746,7 @@ wheels = [
[[package]] [[package]]
name = "graphiti-core" name = "graphiti-core"
version = "0.17.1" version = "0.17.2"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "diskcache" }, { name = "diskcache" },