save edge update (#721)
This commit is contained in:
parent
3200afa363
commit
e56ba1a71c
7 changed files with 53 additions and 44 deletions
|
|
@ -37,7 +37,7 @@ from .client import EmbedderClient, EmbedderConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_EMBEDDING_MODEL = 'text-embedding-001' # gemini-embedding-001 or text-embedding-005
|
DEFAULT_EMBEDDING_MODEL = 'text-embedding-001' # gemini-embedding-001 or text-embedding-005
|
||||||
|
|
||||||
DEFAULT_BATCH_SIZE = 100
|
DEFAULT_BATCH_SIZE = 100
|
||||||
|
|
||||||
|
|
@ -78,7 +78,7 @@ class GeminiEmbedder(EmbedderClient):
|
||||||
|
|
||||||
if batch_size is None and self.config.embedding_model == 'gemini-embedding-001':
|
if batch_size is None and self.config.embedding_model == 'gemini-embedding-001':
|
||||||
# Gemini API has a limit on the number of instances per request
|
# Gemini API has a limit on the number of instances per request
|
||||||
#https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api
|
# https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api
|
||||||
self.batch_size = 1
|
self.batch_size = 1
|
||||||
elif batch_size is None:
|
elif batch_size is None:
|
||||||
self.batch_size = DEFAULT_BATCH_SIZE
|
self.batch_size = DEFAULT_BATCH_SIZE
|
||||||
|
|
@ -113,32 +113,34 @@ class GeminiEmbedder(EmbedderClient):
|
||||||
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
|
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
|
||||||
"""
|
"""
|
||||||
Create embeddings for a batch of input data using Google's Gemini embedding model.
|
Create embeddings for a batch of input data using Google's Gemini embedding model.
|
||||||
|
|
||||||
This method handles batching to respect the Gemini API's limits on the number
|
This method handles batching to respect the Gemini API's limits on the number
|
||||||
of instances that can be processed in a single request.
|
of instances that can be processed in a single request.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_data_list: A list of strings to create embeddings for.
|
input_data_list: A list of strings to create embeddings for.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of embedding vectors (each vector is a list of floats).
|
A list of embedding vectors (each vector is a list of floats).
|
||||||
"""
|
"""
|
||||||
if not input_data_list:
|
if not input_data_list:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
batch_size = self.batch_size
|
batch_size = self.batch_size
|
||||||
all_embeddings = []
|
all_embeddings = []
|
||||||
|
|
||||||
# Process inputs in batches
|
# Process inputs in batches
|
||||||
for i in range(0, len(input_data_list), batch_size):
|
for i in range(0, len(input_data_list), batch_size):
|
||||||
batch = input_data_list[i:i + batch_size]
|
batch = input_data_list[i : i + batch_size]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate embeddings for this batch
|
# Generate embeddings for this batch
|
||||||
result = await self.client.aio.models.embed_content(
|
result = await self.client.aio.models.embed_content(
|
||||||
model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
|
model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
|
||||||
contents=batch, # type: ignore[arg-type] # mypy fails on broad union type
|
contents=batch, # type: ignore[arg-type] # mypy fails on broad union type
|
||||||
config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
|
config=types.EmbedContentConfig(
|
||||||
|
output_dimensionality=self.config.embedding_dim
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result.embeddings or len(result.embeddings) == 0:
|
if not result.embeddings or len(result.embeddings) == 0:
|
||||||
|
|
@ -149,29 +151,33 @@ class GeminiEmbedder(EmbedderClient):
|
||||||
if not embedding.values:
|
if not embedding.values:
|
||||||
raise ValueError('Empty embedding values returned')
|
raise ValueError('Empty embedding values returned')
|
||||||
all_embeddings.append(embedding.values)
|
all_embeddings.append(embedding.values)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# If batch processing fails, fall back to individual processing
|
# If batch processing fails, fall back to individual processing
|
||||||
logger.warning(f"Batch embedding failed for batch {i//batch_size + 1}, falling back to individual processing: {e}")
|
logger.warning(
|
||||||
|
f'Batch embedding failed for batch {i // batch_size + 1}, falling back to individual processing: {e}'
|
||||||
|
)
|
||||||
|
|
||||||
for item in batch:
|
for item in batch:
|
||||||
try:
|
try:
|
||||||
# Process each item individually
|
# Process each item individually
|
||||||
result = await self.client.aio.models.embed_content(
|
result = await self.client.aio.models.embed_content(
|
||||||
model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
|
model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
|
||||||
contents=[item], # type: ignore[arg-type] # mypy fails on broad union type
|
contents=[item], # type: ignore[arg-type] # mypy fails on broad union type
|
||||||
config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
|
config=types.EmbedContentConfig(
|
||||||
|
output_dimensionality=self.config.embedding_dim
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result.embeddings or len(result.embeddings) == 0:
|
if not result.embeddings or len(result.embeddings) == 0:
|
||||||
raise ValueError('No embeddings returned from Gemini API')
|
raise ValueError('No embeddings returned from Gemini API')
|
||||||
if not result.embeddings[0].values:
|
if not result.embeddings[0].values:
|
||||||
raise ValueError('Empty embedding values returned')
|
raise ValueError('Empty embedding values returned')
|
||||||
|
|
||||||
all_embeddings.append(result.embeddings[0].values)
|
all_embeddings.append(result.embeddings[0].values)
|
||||||
|
|
||||||
except Exception as individual_error:
|
except Exception as individual_error:
|
||||||
logger.error(f"Failed to embed individual item: {individual_error}")
|
logger.error(f'Failed to embed individual item: {individual_error}')
|
||||||
raise individual_error
|
raise individual_error
|
||||||
|
|
||||||
return all_embeddings
|
return all_embeddings
|
||||||
|
|
|
||||||
|
|
@ -172,13 +172,13 @@ class LLMClient(ABC):
|
||||||
"""
|
"""
|
||||||
Log the full input messages, the raw output (if any), and the exception for debugging failed generations.
|
Log the full input messages, the raw output (if any), and the exception for debugging failed generations.
|
||||||
"""
|
"""
|
||||||
log = ""
|
log = ''
|
||||||
log += f"Input messages: {json.dumps([m.model_dump() for m in messages], indent=2)}\n"
|
log += f'Input messages: {json.dumps([m.model_dump() for m in messages], indent=2)}\n'
|
||||||
if output is not None:
|
if output is not None:
|
||||||
if len(output) > 4000:
|
if len(output) > 4000:
|
||||||
log += f"Raw output: {output[:2000]}... (truncated) ...{output[-2000:]}\n"
|
log += f'Raw output: {output[:2000]}... (truncated) ...{output[-2000:]}\n'
|
||||||
else:
|
else:
|
||||||
log += f"Raw output: {output}\n"
|
log += f'Raw output: {output}\n'
|
||||||
else:
|
else:
|
||||||
log += "No raw output available"
|
log += 'No raw output available'
|
||||||
return log
|
return log
|
||||||
|
|
|
||||||
|
|
@ -219,14 +219,14 @@ class GeminiClient(LLMClient):
|
||||||
array_match = re.search(r'\]\s*$', raw_output)
|
array_match = re.search(r'\]\s*$', raw_output)
|
||||||
if array_match:
|
if array_match:
|
||||||
try:
|
try:
|
||||||
return json.loads(raw_output[:array_match.end()])
|
return json.loads(raw_output[: array_match.end()])
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
# Try to salvage a JSON object
|
# Try to salvage a JSON object
|
||||||
obj_match = re.search(r'\}\s*$', raw_output)
|
obj_match = re.search(r'\}\s*$', raw_output)
|
||||||
if obj_match:
|
if obj_match:
|
||||||
try:
|
try:
|
||||||
return json.loads(raw_output[:obj_match.end()])
|
return json.loads(raw_output[: obj_match.end()])
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return None
|
return None
|
||||||
|
|
@ -323,12 +323,14 @@ class GeminiClient(LLMClient):
|
||||||
return validated_model.model_dump()
|
return validated_model.model_dump()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if raw_output:
|
if raw_output:
|
||||||
logger.error("🦀 LLM generation failed parsing as JSON, will try to salvage.")
|
logger.error(
|
||||||
|
'🦀 LLM generation failed parsing as JSON, will try to salvage.'
|
||||||
|
)
|
||||||
logger.error(self._get_failed_generation_log(gemini_messages, raw_output))
|
logger.error(self._get_failed_generation_log(gemini_messages, raw_output))
|
||||||
# Try to salvage
|
# Try to salvage
|
||||||
salvaged = self.salvage_json(raw_output)
|
salvaged = self.salvage_json(raw_output)
|
||||||
if salvaged is not None:
|
if salvaged is not None:
|
||||||
logger.warning("Salvaged partial JSON from truncated/malformed output.")
|
logger.warning('Salvaged partial JSON from truncated/malformed output.')
|
||||||
return salvaged
|
return salvaged
|
||||||
raise Exception(f'Failed to parse structured response: {e}') from e
|
raise Exception(f'Failed to parse structured response: {e}') from e
|
||||||
|
|
||||||
|
|
@ -384,7 +386,11 @@ class GeminiClient(LLMClient):
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
model_size=model_size,
|
model_size=model_size,
|
||||||
)
|
)
|
||||||
last_output = response.get('content') if isinstance(response, dict) and 'content' in response else None
|
last_output = (
|
||||||
|
response.get('content')
|
||||||
|
if isinstance(response, dict) and 'content' in response
|
||||||
|
else None
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
except RateLimitError as e:
|
except RateLimitError as e:
|
||||||
# Rate limit errors should not trigger retries (fail fast)
|
# Rate limit errors should not trigger retries (fail fast)
|
||||||
|
|
@ -416,7 +422,7 @@ class GeminiClient(LLMClient):
|
||||||
)
|
)
|
||||||
|
|
||||||
# If we exit the loop without returning, all retries are exhausted
|
# If we exit the loop without returning, all retries are exhausted
|
||||||
logger.error("🦀 LLM generation failed and retries are exhausted.")
|
logger.error('🦀 LLM generation failed and retries are exhausted.')
|
||||||
logger.error(self._get_failed_generation_log(messages, last_output))
|
logger.error(self._get_failed_generation_log(messages, last_output))
|
||||||
logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {last_error}')
|
logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {last_error}')
|
||||||
raise last_error or Exception("Max retries exceeded")
|
raise last_error or Exception('Max retries exceeded')
|
||||||
|
|
|
||||||
|
|
@ -31,9 +31,9 @@ EPISODIC_EDGE_SAVE_BULK = """
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ENTITY_EDGE_SAVE = """
|
ENTITY_EDGE_SAVE = """
|
||||||
MATCH (source:Entity {uuid: $source_uuid})
|
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
||||||
MATCH (target:Entity {uuid: $target_uuid})
|
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
||||||
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
MERGE (source)-[r:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
||||||
SET r = $edge_data
|
SET r = $edge_data
|
||||||
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $edge_data.fact_embedding)
|
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $edge_data.fact_embedding)
|
||||||
RETURN r.uuid AS uuid"""
|
RETURN r.uuid AS uuid"""
|
||||||
|
|
|
||||||
|
|
@ -345,9 +345,9 @@ class TestGeminiEmbedderCreateBatch:
|
||||||
|
|
||||||
# Set side_effect for embed_content to control return values for each call
|
# Set side_effect for embed_content to control return values for each call
|
||||||
mock_gemini_client.aio.models.embed_content.side_effect = [
|
mock_gemini_client.aio.models.embed_content.side_effect = [
|
||||||
mock_batch_response, # First call for the batch
|
mock_batch_response, # First call for the batch
|
||||||
mock_individual_response_1, # Second call for individual item 1
|
mock_individual_response_1, # Second call for individual item 1
|
||||||
mock_individual_response_2 # Third call for individual item 2
|
mock_individual_response_2, # Third call for individual item 2
|
||||||
]
|
]
|
||||||
|
|
||||||
input_batch = ['Input 1', 'Input 2']
|
input_batch = ['Input 1', 'Input 2']
|
||||||
|
|
|
||||||
|
|
@ -273,7 +273,7 @@ class TestGeminiClientGenerateResponse:
|
||||||
messages = [Message(role='user', content='Test message')]
|
messages = [Message(role='user', content='Test message')]
|
||||||
with pytest.raises(Exception): # noqa: B017
|
with pytest.raises(Exception): # noqa: B017
|
||||||
await gemini_client.generate_response(messages, response_model=ResponseModel)
|
await gemini_client.generate_response(messages, response_model=ResponseModel)
|
||||||
|
|
||||||
# Should have called generate_content MAX_RETRIES times (2 attempts total)
|
# Should have called generate_content MAX_RETRIES times (2 attempts total)
|
||||||
assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
|
assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
|
||||||
|
|
||||||
|
|
@ -344,10 +344,7 @@ class TestGeminiClientGenerateResponse:
|
||||||
await gemini_client.generate_response(messages, response_model=ResponseModel)
|
await gemini_client.generate_response(messages, response_model=ResponseModel)
|
||||||
|
|
||||||
# Should have called generate_content MAX_RETRIES times (2 attempts total)
|
# Should have called generate_content MAX_RETRIES times (2 attempts total)
|
||||||
assert (
|
assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
|
||||||
mock_gemini_client.aio.models.generate_content.call_count
|
|
||||||
== GeminiClient.MAX_RETRIES
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_empty_response_handling(self, gemini_client, mock_gemini_client):
|
async def test_empty_response_handling(self, gemini_client, mock_gemini_client):
|
||||||
|
|
@ -363,7 +360,7 @@ class TestGeminiClientGenerateResponse:
|
||||||
messages = [Message(role='user', content='Test message')]
|
messages = [Message(role='user', content='Test message')]
|
||||||
with pytest.raises(Exception): # noqa: B017
|
with pytest.raises(Exception): # noqa: B017
|
||||||
await gemini_client.generate_response(messages, response_model=ResponseModel)
|
await gemini_client.generate_response(messages, response_model=ResponseModel)
|
||||||
|
|
||||||
# Should have exhausted retries due to empty response (2 attempts total)
|
# Should have exhausted retries due to empty response (2 attempts total)
|
||||||
assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
|
assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
|
||||||
|
|
||||||
|
|
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -746,7 +746,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.17.1"
|
version = "0.17.2"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "diskcache" },
|
{ name = "diskcache" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue