From e56ba1a71c0ec952bce0d1195ae7b1e337cd286d Mon Sep 17 00:00:00 2001
From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Date: Mon, 14 Jul 2025 11:15:38 -0400
Subject: [PATCH] save edge update (#721)

---
 graphiti_core/embedder/gemini.py              | 44 +++++++++++--------
 graphiti_core/llm_client/client.py            | 10 ++---
 graphiti_core/llm_client/gemini_client.py     | 20 ++++++---
 graphiti_core/models/edges/edge_db_queries.py |  6 +--
 tests/embedder/test_gemini.py                 |  6 +--
 tests/llm_client/test_gemini_client.py        |  9 ++--
 uv.lock                                       |  2 +-
 7 files changed, 53 insertions(+), 44 deletions(-)

diff --git a/graphiti_core/embedder/gemini.py b/graphiti_core/embedder/gemini.py
index f144256f..2c16884c 100644
--- a/graphiti_core/embedder/gemini.py
+++ b/graphiti_core/embedder/gemini.py
@@ -37,7 +37,7 @@ from .client import EmbedderClient, EmbedderConfig
 
 logger = logging.getLogger(__name__)
 
-DEFAULT_EMBEDDING_MODEL = 'text-embedding-001' # gemini-embedding-001 or text-embedding-005
+DEFAULT_EMBEDDING_MODEL = 'text-embedding-001'  # gemini-embedding-001 or text-embedding-005
 
 DEFAULT_BATCH_SIZE = 100
 
@@ -78,7 +78,7 @@ class GeminiEmbedder(EmbedderClient):
 
         if batch_size is None and self.config.embedding_model == 'gemini-embedding-001':
             # Gemini API has a limit on the number of instances per request
-            #https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api
+            # https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api
             self.batch_size = 1
         elif batch_size is None:
             self.batch_size = DEFAULT_BATCH_SIZE
@@ -113,32 +113,34 @@ class GeminiEmbedder(EmbedderClient):
     async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
         """
         Create embeddings for a batch of input data using Google's Gemini embedding model.
-        
+
         This method handles batching to respect the Gemini API's limits on the number of
         instances that can be processed in a single request.
-        
+
         Args:
             input_data_list: A list of strings to create embeddings for.
-        
+
         Returns:
             A list of embedding vectors (each vector is a list of floats).
         """
         if not input_data_list:
             return []
-        
+
         batch_size = self.batch_size
         all_embeddings = []
-        
+
         # Process inputs in batches
         for i in range(0, len(input_data_list), batch_size):
-            batch = input_data_list[i:i + batch_size]
-            
+            batch = input_data_list[i : i + batch_size]
+
             try:
                 # Generate embeddings for this batch
                 result = await self.client.aio.models.embed_content(
                     model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
                     contents=batch,  # type: ignore[arg-type]  # mypy fails on broad union type
-                    config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
+                    config=types.EmbedContentConfig(
+                        output_dimensionality=self.config.embedding_dim
+                    ),
                 )
 
                 if not result.embeddings or len(result.embeddings) == 0:
@@ -149,29 +151,33 @@ class GeminiEmbedder(EmbedderClient):
                     if not embedding.values:
                         raise ValueError('Empty embedding values returned')
                     all_embeddings.append(embedding.values)
-            
+
             except Exception as e:
                 # If batch processing fails, fall back to individual processing
-                logger.warning(f"Batch embedding failed for batch {i//batch_size + 1}, falling back to individual processing: {e}")
-                
+                logger.warning(
+                    f'Batch embedding failed for batch {i // batch_size + 1}, falling back to individual processing: {e}'
+                )
+
                 for item in batch:
                     try:
                         # Process each item individually
                         result = await self.client.aio.models.embed_content(
                             model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
                             contents=[item],  # type: ignore[arg-type]  # mypy fails on broad union type
-                            config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
+                            config=types.EmbedContentConfig(
+                                output_dimensionality=self.config.embedding_dim
+                            ),
                         )
-                        
+
                         if not result.embeddings or len(result.embeddings) == 0:
                             raise ValueError('No embeddings returned from Gemini API')
                         if not result.embeddings[0].values:
                             raise ValueError('Empty embedding values returned')
-                        
+
                         all_embeddings.append(result.embeddings[0].values)
-                    
+
                     except Exception as individual_error:
-                        logger.error(f"Failed to embed individual item: {individual_error}")
+                        logger.error(f'Failed to embed individual item: {individual_error}')
                         raise individual_error
-        
+
         return all_embeddings
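
The create_batch hunks above embed a full batch per request and, only when the batch call fails, retry the items one at a time so that a single bad input is isolated instead of sinking the whole batch. A minimal standalone sketch of that pattern — not the shipped code — assuming a hypothetical async callable embed in place of the real client.aio.models.embed_content:

    import logging
    from collections.abc import Awaitable, Callable

    logger = logging.getLogger(__name__)


    async def embed_with_fallback(
        embed: Callable[[list[str]], Awaitable[list[list[float]]]],
        inputs: list[str],
        batch_size: int,
    ) -> list[list[float]]:
        """Embed inputs batch-by-batch; retry a failed batch one item at a time."""
        embeddings: list[list[float]] = []
        for i in range(0, len(inputs), batch_size):
            batch = inputs[i : i + batch_size]
            try:
                # Happy path: one API call for the whole batch.
                embeddings.extend(await embed(batch))
            except Exception as batch_error:
                logger.warning('Batch %d failed, falling back: %s', i // batch_size + 1, batch_error)
                for item in batch:
                    # A per-item failure propagates, mirroring the hunk above:
                    # the fallback isolates the failing item but does not swallow it.
                    embeddings.extend(await embed([item]))
        return embeddings
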
""" if not input_data_list: return [] - + batch_size = self.batch_size all_embeddings = [] - + # Process inputs in batches for i in range(0, len(input_data_list), batch_size): - batch = input_data_list[i:i + batch_size] - + batch = input_data_list[i : i + batch_size] + try: # Generate embeddings for this batch result = await self.client.aio.models.embed_content( model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL, contents=batch, # type: ignore[arg-type] # mypy fails on broad union type - config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim), + config=types.EmbedContentConfig( + output_dimensionality=self.config.embedding_dim + ), ) if not result.embeddings or len(result.embeddings) == 0: @@ -149,29 +151,33 @@ class GeminiEmbedder(EmbedderClient): if not embedding.values: raise ValueError('Empty embedding values returned') all_embeddings.append(embedding.values) - + except Exception as e: # If batch processing fails, fall back to individual processing - logger.warning(f"Batch embedding failed for batch {i//batch_size + 1}, falling back to individual processing: {e}") - + logger.warning( + f'Batch embedding failed for batch {i // batch_size + 1}, falling back to individual processing: {e}' + ) + for item in batch: try: # Process each item individually result = await self.client.aio.models.embed_content( model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL, contents=[item], # type: ignore[arg-type] # mypy fails on broad union type - config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim), + config=types.EmbedContentConfig( + output_dimensionality=self.config.embedding_dim + ), ) - + if not result.embeddings or len(result.embeddings) == 0: raise ValueError('No embeddings returned from Gemini API') if not result.embeddings[0].values: raise ValueError('Empty embedding values returned') - + all_embeddings.append(result.embeddings[0].values) - + except Exception as individual_error: - logger.error(f"Failed to embed individual item: {individual_error}") + logger.error(f'Failed to embed individual item: {individual_error}') raise individual_error - + return all_embeddings diff --git a/graphiti_core/llm_client/client.py b/graphiti_core/llm_client/client.py index 2f64de5a..9f7558c1 100644 --- a/graphiti_core/llm_client/client.py +++ b/graphiti_core/llm_client/client.py @@ -172,13 +172,13 @@ class LLMClient(ABC): """ Log the full input messages, the raw output (if any), and the exception for debugging failed generations. """ - log = "" - log += f"Input messages: {json.dumps([m.model_dump() for m in messages], indent=2)}\n" + log = '' + log += f'Input messages: {json.dumps([m.model_dump() for m in messages], indent=2)}\n' if output is not None: if len(output) > 4000: - log += f"Raw output: {output[:2000]}... (truncated) ...{output[-2000:]}\n" + log += f'Raw output: {output[:2000]}... 
diff --git a/graphiti_core/models/edges/edge_db_queries.py b/graphiti_core/models/edges/edge_db_queries.py
index f3cda2f2..47b6f04f 100644
--- a/graphiti_core/models/edges/edge_db_queries.py
+++ b/graphiti_core/models/edges/edge_db_queries.py
@@ -31,9 +31,9 @@ EPISODIC_EDGE_SAVE_BULK = """
 """
 
 ENTITY_EDGE_SAVE = """
-    MATCH (source:Entity {uuid: $source_uuid})
-    MATCH (target:Entity {uuid: $target_uuid})
-    MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
+    MATCH (source:Entity {uuid: $edge_data.source_uuid})
+    MATCH (target:Entity {uuid: $edge_data.target_uuid})
+    MERGE (source)-[r:RELATES_TO {uuid: $edge_data.uuid}]->(target)
     SET r = $edge_data
     WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $edge_data.fact_embedding)
     RETURN r.uuid AS uuid"""
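
With this change the query reads source_uuid, target_uuid, and uuid out of the $edge_data map itself, so a caller now binds a single edge_data parameter instead of three separate UUID parameters plus the map. A hypothetical call site, assuming the Neo4j v5 async Python driver; save_entity_edge and its argument names are illustrative, not the project's API:

    from graphiti_core.models.edges.edge_db_queries import ENTITY_EDGE_SAVE


    async def save_entity_edge(driver, edge_data: dict) -> str | None:
        """Persist one entity edge; driver is assumed to be a neo4j AsyncDriver."""
        # A single map parameter carries everything: the Cypher above dereferences
        # $edge_data.source_uuid, $edge_data.target_uuid, and $edge_data.uuid itself,
        # so the old top-level $source_uuid/$target_uuid/$uuid bindings disappear.
        async with driver.session() as session:
            result = await session.run(ENTITY_EDGE_SAVE, edge_data=edge_data)
            record = await result.single()
            return record['uuid'] if record else None

Here edge_data is expected to contain at least uuid, source_uuid, target_uuid, and fact_embedding, since SET r = $edge_data writes the whole map onto the relationship.
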
diff --git a/tests/embedder/test_gemini.py b/tests/embedder/test_gemini.py
index c851b5f1..3a0d4101 100644
--- a/tests/embedder/test_gemini.py
+++ b/tests/embedder/test_gemini.py
@@ -345,9 +345,9 @@ class TestGeminiEmbedderCreateBatch:
 
         # Set side_effect for embed_content to control return values for each call
         mock_gemini_client.aio.models.embed_content.side_effect = [
-            mock_batch_response,  # First call for the batch
-            mock_individual_response_1,  # Second call for individual item 1
-            mock_individual_response_2  # Third call for individual item 2
+            mock_batch_response,  # First call for the batch
+            mock_individual_response_1,  # Second call for individual item 1
+            mock_individual_response_2,  # Third call for individual item 2
         ]
 
         input_batch = ['Input 1', 'Input 2']

diff --git a/tests/llm_client/test_gemini_client.py b/tests/llm_client/test_gemini_client.py
index 7ea98f58..263f93c6 100644
--- a/tests/llm_client/test_gemini_client.py
+++ b/tests/llm_client/test_gemini_client.py
@@ -273,7 +273,7 @@ class TestGeminiClientGenerateResponse:
         messages = [Message(role='user', content='Test message')]
         with pytest.raises(Exception):  # noqa: B017
             await gemini_client.generate_response(messages, response_model=ResponseModel)
-        
+
         # Should have called generate_content MAX_RETRIES times (2 attempts total)
         assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
 
@@ -344,10 +344,7 @@ class TestGeminiClientGenerateResponse:
             await gemini_client.generate_response(messages, response_model=ResponseModel)
 
         # Should have called generate_content MAX_RETRIES times (2 attempts total)
-        assert (
-            mock_gemini_client.aio.models.generate_content.call_count
-            == GeminiClient.MAX_RETRIES
-        )
+        assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
 
     @pytest.mark.asyncio
     async def test_empty_response_handling(self, gemini_client, mock_gemini_client):
@@ -363,7 +360,7 @@ class TestGeminiClientGenerateResponse:
         messages = [Message(role='user', content='Test message')]
         with pytest.raises(Exception):  # noqa: B017
             await gemini_client.generate_response(messages, response_model=ResponseModel)
-        
+
         # Should have exhausted retries due to empty response (2 attempts total)
         assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES

diff --git a/uv.lock b/uv.lock
index 94d63db4..77ec9b23 100644
--- a/uv.lock
+++ b/uv.lock
@@ -746,7 +746,7 @@ wheels = [
 
 [[package]]
 name = "graphiti-core"
-version = "0.17.1"
+version = "0.17.2"
 source = { editable = "." }
 dependencies = [
     { name = "diskcache" },
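
The test updates above pin down the retry contract that the gemini_client changes implement: generate_content is attempted exactly GeminiClient.MAX_RETRIES times (two attempts total), rate-limit errors fail fast without retrying, and the last error is raised once retries are exhausted. Compressed into a sketch, with generate standing in for the actual model call and RateLimitError for graphiti_core's error type:

    class RateLimitError(Exception):
        """Stand-in for graphiti_core's rate-limit error type."""


    MAX_RETRIES = 2  # the '2 attempts total' that the assertions above count


    async def generate_with_retries(generate, messages):
        """Retry transient failures, fail fast on rate limits, surface the last error."""
        last_error: Exception | None = None
        for _attempt in range(MAX_RETRIES):
            try:
                return await generate(messages)
            except RateLimitError:
                raise  # rate limits must not consume retries
            except Exception as e:
                last_error = e
        raise last_error or Exception('Max retries exceeded')
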