save edge update (#721)
parent 3200afa363
commit e56ba1a71c

7 changed files with 53 additions and 44 deletions
@@ -37,7 +37,7 @@ from .client import EmbedderClient, EmbedderConfig
 
 logger = logging.getLogger(__name__)
 
-DEFAULT_EMBEDDING_MODEL = 'text-embedding-001' # gemini-embedding-001 or text-embedding-005
+DEFAULT_EMBEDDING_MODEL = 'text-embedding-001' # gemini-embedding-001 or text-embedding-005
 
 DEFAULT_BATCH_SIZE = 100
 
@@ -78,7 +78,7 @@ class GeminiEmbedder(EmbedderClient):
 
         if batch_size is None and self.config.embedding_model == 'gemini-embedding-001':
             # Gemini API has a limit on the number of instances per request
-            #https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api
+            # https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api
             self.batch_size = 1
         elif batch_size is None:
             self.batch_size = DEFAULT_BATCH_SIZE
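The constructor above pins the batch size to 1 for gemini-embedding-001, since the Vertex AI text-embeddings endpoint limits how many instances a single request may carry; other models fall back to DEFAULT_BATCH_SIZE (100). A minimal standalone sketch of that selection logic; the helper name resolve_batch_size and the "explicit caller value wins" rule are assumptions, not part of the diff:

# Sketch only: a hypothetical helper mirroring the batch-size choice made in the
# constructor above; resolve_batch_size is not a real graphiti function.
DEFAULT_BATCH_SIZE = 100


def resolve_batch_size(batch_size: int | None, embedding_model: str) -> int:
    """Pick an effective embedding batch size for a given Gemini model."""
    if batch_size is not None:
        return batch_size  # assumed: an explicit caller value wins
    if embedding_model == 'gemini-embedding-001':
        # Vertex AI caps the number of instances per request for this model,
        # so embed one text at a time.
        return 1
    return DEFAULT_BATCH_SIZE


assert resolve_batch_size(None, 'gemini-embedding-001') == 1
assert resolve_batch_size(None, 'text-embedding-005') == 100
assert resolve_batch_size(32, 'gemini-embedding-001') == 32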
@@ -113,32 +113,34 @@ class GeminiEmbedder(EmbedderClient):
     async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
         """
         Create embeddings for a batch of input data using Google's Gemini embedding model.
 
         This method handles batching to respect the Gemini API's limits on the number
         of instances that can be processed in a single request.
 
         Args:
             input_data_list: A list of strings to create embeddings for.
 
         Returns:
             A list of embedding vectors (each vector is a list of floats).
         """
         if not input_data_list:
             return []
 
         batch_size = self.batch_size
         all_embeddings = []
 
         # Process inputs in batches
         for i in range(0, len(input_data_list), batch_size):
-            batch = input_data_list[i:i + batch_size]
+            batch = input_data_list[i : i + batch_size]
 
             try:
                 # Generate embeddings for this batch
                 result = await self.client.aio.models.embed_content(
                     model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
                     contents=batch, # type: ignore[arg-type] # mypy fails on broad union type
-                    config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
+                    config=types.EmbedContentConfig(
+                        output_dimensionality=self.config.embedding_dim
+                    ),
                 )
 
                 if not result.embeddings or len(result.embeddings) == 0:
@@ -149,29 +151,33 @@ class GeminiEmbedder(EmbedderClient):
                     if not embedding.values:
                         raise ValueError('Empty embedding values returned')
                     all_embeddings.append(embedding.values)
 
             except Exception as e:
                 # If batch processing fails, fall back to individual processing
-                logger.warning(f"Batch embedding failed for batch {i//batch_size + 1}, falling back to individual processing: {e}")
+                logger.warning(
+                    f'Batch embedding failed for batch {i // batch_size + 1}, falling back to individual processing: {e}'
+                )
 
                 for item in batch:
                     try:
                         # Process each item individually
                         result = await self.client.aio.models.embed_content(
                             model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
                             contents=[item], # type: ignore[arg-type] # mypy fails on broad union type
-                            config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
+                            config=types.EmbedContentConfig(
+                                output_dimensionality=self.config.embedding_dim
+                            ),
                         )
 
                         if not result.embeddings or len(result.embeddings) == 0:
                             raise ValueError('No embeddings returned from Gemini API')
                         if not result.embeddings[0].values:
                             raise ValueError('Empty embedding values returned')
 
                         all_embeddings.append(result.embeddings[0].values)
 
                     except Exception as individual_error:
-                        logger.error(f"Failed to embed individual item: {individual_error}")
+                        logger.error(f'Failed to embed individual item: {individual_error}')
                         raise individual_error
 
         return all_embeddings
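The two hunks above implement batch-then-individual fallback: a whole chunk is embedded in one request, and only if that request fails is each item retried on its own. A framework-free sketch of the same pattern; embed_many is a stand-in for the real Gemini call and every name here is illustrative, not the library's API:

# Sketch only, under the assumption that `embed_many` is any awaitable that maps
# a list of texts to a list of embedding vectors.
import asyncio
from collections.abc import Awaitable, Callable


async def embed_with_fallback(
    texts: list[str],
    embed_many: Callable[[list[str]], Awaitable[list[list[float]]]],
    batch_size: int,
) -> list[list[float]]:
    out: list[list[float]] = []
    for i in range(0, len(texts), batch_size):
        chunk = texts[i : i + batch_size]
        try:
            # First try the whole chunk in one request.
            out.extend(await embed_many(chunk))
        except Exception:
            # If the chunk fails, retry each item on its own; a second failure
            # here propagates, mirroring the diff's behavior.
            for item in chunk:
                out.extend(await embed_many([item]))
    return out


async def fake_embed(batch: list[str]) -> list[list[float]]:
    return [[float(len(text))] for text in batch]  # toy embedding: text length


print(asyncio.run(embed_with_fallback(['a', 'bb', 'ccc'], fake_embed, batch_size=2)))
# [[1.0], [2.0], [3.0]]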
@@ -172,13 +172,13 @@ class LLMClient(ABC):
         """
         Log the full input messages, the raw output (if any), and the exception for debugging failed generations.
         """
-        log = ""
-        log += f"Input messages: {json.dumps([m.model_dump() for m in messages], indent=2)}\n"
+        log = ''
+        log += f'Input messages: {json.dumps([m.model_dump() for m in messages], indent=2)}\n'
         if output is not None:
             if len(output) > 4000:
-                log += f"Raw output: {output[:2000]}... (truncated) ...{output[-2000:]}\n"
+                log += f'Raw output: {output[:2000]}... (truncated) ...{output[-2000:]}\n'
             else:
-                log += f"Raw output: {output}\n"
+                log += f'Raw output: {output}\n'
         else:
-            log += "No raw output available"
+            log += 'No raw output available'
         return log
@@ -219,14 +219,14 @@ class GeminiClient(LLMClient):
         array_match = re.search(r'\]\s*$', raw_output)
         if array_match:
             try:
-                return json.loads(raw_output[:array_match.end()])
+                return json.loads(raw_output[: array_match.end()])
             except Exception:
                 pass
         # Try to salvage a JSON object
         obj_match = re.search(r'\}\s*$', raw_output)
         if obj_match:
             try:
-                return json.loads(raw_output[:obj_match.end()])
+                return json.loads(raw_output[: obj_match.end()])
             except Exception:
                 pass
         return None
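The salvage helper above only attempts a parse when the raw output already ends in a closing ] or }; everything up to that delimiter is handed to json.loads, and any failure falls through to None. A standalone sketch of the same idea (a re-implementation for illustration, not the library function itself):

# Sketch of the trailing-delimiter salvage strategy shown in the diff above.
import json
import re


def salvage_json(raw_output: str):
    """Try to parse whatever JSON array or object the text ends with."""
    for pattern in (r'\]\s*$', r'\}\s*$'):
        match = re.search(pattern, raw_output)
        if match:
            try:
                return json.loads(raw_output[: match.end()])
            except Exception:
                continue
    return None


print(salvage_json('[{"name": "alice"}, {"name": "bob"}]\n'))  # parsed list
print(salvage_json('{"name": "alice"'))                        # None: no closing brace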
@@ -323,12 +323,14 @@ class GeminiClient(LLMClient):
                 return validated_model.model_dump()
             except Exception as e:
                 if raw_output:
-                    logger.error("🦀 LLM generation failed parsing as JSON, will try to salvage.")
+                    logger.error(
+                        '🦀 LLM generation failed parsing as JSON, will try to salvage.'
+                    )
                     logger.error(self._get_failed_generation_log(gemini_messages, raw_output))
                     # Try to salvage
                     salvaged = self.salvage_json(raw_output)
                     if salvaged is not None:
-                        logger.warning("Salvaged partial JSON from truncated/malformed output.")
+                        logger.warning('Salvaged partial JSON from truncated/malformed output.')
                         return salvaged
                 raise Exception(f'Failed to parse structured response: {e}') from e
 
@@ -384,7 +386,11 @@ class GeminiClient(LLMClient):
                     max_tokens=max_tokens,
                     model_size=model_size,
                 )
-                last_output = response.get('content') if isinstance(response, dict) and 'content' in response else None
+                last_output = (
+                    response.get('content')
+                    if isinstance(response, dict) and 'content' in response
+                    else None
+                )
                 return response
             except RateLimitError as e:
                 # Rate limit errors should not trigger retries (fail fast)
@@ -416,7 +422,7 @@
                 )
 
         # If we exit the loop without returning, all retries are exhausted
-        logger.error("🦀 LLM generation failed and retries are exhausted.")
+        logger.error('🦀 LLM generation failed and retries are exhausted.')
        logger.error(self._get_failed_generation_log(messages, last_output))
         logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {last_error}')
-        raise last_error or Exception("Max retries exceeded")
+        raise last_error or Exception('Max retries exceeded')
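The last three hunks sit at the tail of a retry loop: rate-limit errors are re-raised immediately, other failures set last_error and the loop tries again, and falling out of the loop logs the failed generation and raises. A schematic sketch of that control flow; the names call_llm and generate_with_retries, and the RateLimitError stand-in, are placeholders rather than the actual GeminiClient code:

# Schematic only: the loop shape implied by the hunks above.
MAX_RETRIES = 2


class RateLimitError(Exception):
    """Stand-in for the client's rate-limit exception."""


async def generate_with_retries(call_llm, messages):
    last_error: Exception | None = None
    last_output: str | None = None
    for attempt in range(MAX_RETRIES):
        try:
            response = await call_llm(messages)
            last_output = response.get('content') if isinstance(response, dict) else None
            return response
        except RateLimitError:
            raise  # fail fast, never retry rate limits
        except Exception as e:
            last_error = e  # remember the failure and try again
    # Loop exhausted without returning: surface the last error (and, in the real
    # client, log the failed generation via last_output).
    raise last_error or Exception('Max retries exceeded')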
@@ -31,9 +31,9 @@ EPISODIC_EDGE_SAVE_BULK = """
 """
 
 ENTITY_EDGE_SAVE = """
-    MATCH (source:Entity {uuid: $source_uuid})
-    MATCH (target:Entity {uuid: $target_uuid})
-    MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
+    MATCH (source:Entity {uuid: $edge_data.source_uuid})
+    MATCH (target:Entity {uuid: $edge_data.target_uuid})
+    MERGE (source)-[r:RELATES_TO {uuid: $edge_data.uuid}]->(target)
     SET r = $edge_data
     WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $edge_data.fact_embedding)
     RETURN r.uuid AS uuid"""
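After this change the query reads the source, target, and edge UUIDs from the same $edge_data map that is written onto the relationship, so the caller binds a single parameter instead of several. A hedged sketch of executing such a query with the official neo4j Python driver; the connection details, function name, and edge_data fields are illustrative assumptions, not graphiti's own save path:

# Illustrative only: assumes a reachable Neo4j instance with vector-property
# support and the `neo4j` async driver installed.
import asyncio
from neo4j import AsyncGraphDatabase

ENTITY_EDGE_SAVE = """
    MATCH (source:Entity {uuid: $edge_data.source_uuid})
    MATCH (target:Entity {uuid: $edge_data.target_uuid})
    MERGE (source)-[r:RELATES_TO {uuid: $edge_data.uuid}]->(target)
    SET r = $edge_data
    WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $edge_data.fact_embedding)
    RETURN r.uuid AS uuid"""


async def save_entity_edge(uri: str, auth: tuple[str, str], edge_data: dict) -> str:
    async with AsyncGraphDatabase.driver(uri, auth=auth) as driver:
        # The whole edge is passed as one map parameter named edge_data.
        records, _, _ = await driver.execute_query(ENTITY_EDGE_SAVE, edge_data=edge_data)
        return records[0]['uuid']


# edge_data = {
#     'uuid': 'edge-1', 'source_uuid': 'node-a', 'target_uuid': 'node-b',
#     'fact': 'a relates to b', 'fact_embedding': [0.1, 0.2, 0.3],
# }
# asyncio.run(save_entity_edge('bolt://localhost:7687', ('neo4j', 'password'), edge_data))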
@@ -345,9 +345,9 @@ class TestGeminiEmbedderCreateBatch:
 
         # Set side_effect for embed_content to control return values for each call
         mock_gemini_client.aio.models.embed_content.side_effect = [
-            mock_batch_response, # First call for the batch
-            mock_individual_response_1, # Second call for individual item 1
-            mock_individual_response_2 # Third call for individual item 2
+            mock_batch_response, # First call for the batch
+            mock_individual_response_1, # Second call for individual item 1
+            mock_individual_response_2, # Third call for individual item 2
         ]
 
         input_batch = ['Input 1', 'Input 2']
@@ -273,7 +273,7 @@ class TestGeminiClientGenerateResponse:
         messages = [Message(role='user', content='Test message')]
         with pytest.raises(Exception): # noqa: B017
             await gemini_client.generate_response(messages, response_model=ResponseModel)
 
         # Should have called generate_content MAX_RETRIES times (2 attempts total)
         assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
 
@@ -344,10 +344,7 @@ class TestGeminiClientGenerateResponse:
         await gemini_client.generate_response(messages, response_model=ResponseModel)
 
         # Should have called generate_content MAX_RETRIES times (2 attempts total)
-        assert (
-            mock_gemini_client.aio.models.generate_content.call_count
-            == GeminiClient.MAX_RETRIES
-        )
+        assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
 
     @pytest.mark.asyncio
     async def test_empty_response_handling(self, gemini_client, mock_gemini_client):
@@ -363,7 +360,7 @@ class TestGeminiClientGenerateResponse:
         messages = [Message(role='user', content='Test message')]
         with pytest.raises(Exception): # noqa: B017
             await gemini_client.generate_response(messages, response_model=ResponseModel)
 
         # Should have exhausted retries due to empty response (2 attempts total)
         assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
 
uv.lock (generated): 2 changes
@@ -746,7 +746,7 @@ wheels = [
 
 [[package]]
 name = "graphiti-core"
-version = "0.17.1"
+version = "0.17.2"
 source = { editable = "." }
 dependencies = [
     { name = "diskcache" },