Gemini support (#324)

* first cut

* Update dependencies and enhance README for optional LLM providers

- Bump aiohttp version from 3.11.14 to 3.11.16
- Update yarl version from 1.18.3 to 1.19.0
- Modify pyproject.toml to include optional extras for Anthropic, Groq, and Google Gemini
- Revise README.md to reflect new optional LLM provider installation instructions and clarify API key requirements

* Remove deprecated packages from poetry.lock and update content hash

- Removed cachetools, google-auth, google-genai, pyasn1, pyasn1-modules, rsa, and websockets from the lock file.
- Added new extras for anthropic, google-genai, and groq.
- Updated content hash to reflect changes.

* Refactor import paths for GeminiClient in README and __init__.py

- Updated import statement in README.md to reflect the new module structure for GeminiClient.
- Removed GeminiClient from the __all__ list in __init__.py as it is no longer directly imported.

* Refactor import paths for GeminiEmbedder in README and __init__.py

- Updated import statement in README.md to reflect the new module structure for GeminiEmbedder.
- Removed GeminiEmbedder and GeminiEmbedderConfig from the __all__ list in __init__.py as they are no longer directly imported.
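
Taken together, these two refactors mean the Gemini classes are now imported from their provider submodules rather than the package roots. A hedged sketch (the old root-level paths are assumptions; the new paths match the README changes below):

```python
# Previously (assumed): imported from the package roots, e.g.
#   from graphiti_core.llm_client import GeminiClient
#   from graphiti_core.embedder import GeminiEmbedder, GeminiEmbedderConfig
# Now: import directly from the provider submodules.
from graphiti_core.llm_client.gemini_client import GeminiClient
from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig
```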
Daniel Chalef, 2025-04-06 09:27:04 -07:00 (committed by GitHub)
parent d3c83adb04
commit 9e78890f2e
13 changed files with 1221 additions and 962 deletions

@@ -35,8 +35,8 @@ Use Graphiti to:
<br />
A knowledge graph is a network of interconnected facts, such as _“Kendra loves Adidas shoes.”_ Each fact is a “triplet” represented by two entities, or
nodes (_”Kendra”_, _“Adidas shoes”_), and their relationship, or edge (_”loves”_). Knowledge Graphs have been explored
A knowledge graph is a network of interconnected facts, such as _"Kendra loves Adidas shoes."_ Each fact is a "triplet" represented by two entities, or
nodes ("Kendra", "Adidas shoes"), and their relationship, or edge ("loves"). Knowledge Graphs have been explored
extensively for information retrieval. What makes Graphiti unique is its ability to autonomously build a knowledge graph
while handling changing relationships and maintaining historical context.
@@ -96,7 +96,7 @@ Requirements:
Optional:
- Anthropic or Groq API key (for alternative LLM providers)
- Google Gemini, Anthropic, or Groq API key (for alternative LLM providers)
> [!TIP]
> The simplest way to install Neo4j is via [Neo4j Desktop](https://neo4j.com/download/). It provides a user-friendly
@@ -112,6 +112,22 @@ or
poetry add graphiti-core
```
You can also install optional LLM providers as extras:
```bash
# Install with Anthropic support
pip install graphiti-core[anthropic]
# Install with Groq support
pip install graphiti-core[groq]
# Install with Google Gemini support
pip install graphiti-core[google-genai]
# Install with multiple providers
pip install graphiti-core[anthropic,groq,google-genai]
```
## Quick Start
> [!IMPORTANT]
@@ -211,6 +227,42 @@ graphiti = Graphiti(
Make sure to replace the placeholder values with your actual Azure OpenAI credentials and specify the correct embedding model name that's deployed in your Azure OpenAI service.
## Using Graphiti with Google Gemini
Graphiti supports Google's Gemini models for both LLM inference and embeddings. To use Gemini, you'll need to configure both the LLM client and embedder with your Google API key.
```python
from graphiti_core import Graphiti
from graphiti_core.llm_client.gemini_client import GeminiClient, LLMConfig
from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig
# Google API key configuration
api_key = "<your-google-api-key>"
# Initialize Graphiti with Gemini clients
graphiti = Graphiti(
"bolt://localhost:7687",
"neo4j",
"password",
llm_client=GeminiClient(
config=LLMConfig(
api_key=api_key,
model="gemini-2.0-flash"
)
),
embedder=GeminiEmbedder(
config=GeminiEmbedderConfig(
api_key=api_key,
embedding_model="embedding-001"
)
)
)
# Now you can use Graphiti with Google Gemini
```
Make sure to replace the placeholder value with your actual Google API key. You can find more details in the example file at `examples/gemini_example.py`.
## Documentation
- [Guides and API documentation](https://help.getzep.com/graphiti).

@@ -1,4 +1,8 @@
from .client import EmbedderClient
from .openai import OpenAIEmbedder, OpenAIEmbedderConfig
__all__ = ['EmbedderClient', 'OpenAIEmbedder', 'OpenAIEmbedderConfig']
__all__ = [
'EmbedderClient',
'OpenAIEmbedder',
'OpenAIEmbedderConfig',
]

@@ -0,0 +1,68 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from collections.abc import Iterable
from google import genai # type: ignore
from google.genai import types # type: ignore
from pydantic import Field
from .client import EmbedderClient, EmbedderConfig
DEFAULT_EMBEDDING_MODEL = 'embedding-001'
class GeminiEmbedderConfig(EmbedderConfig):
embedding_model: str = Field(default=DEFAULT_EMBEDDING_MODEL)
api_key: str | None = None
class GeminiEmbedder(EmbedderClient):
"""
Google Gemini Embedder Client
"""
def __init__(self, config: GeminiEmbedderConfig | None = None):
if config is None:
config = GeminiEmbedderConfig()
self.config = config
# Configure the Gemini API
self.client = genai.Client(
api_key=config.api_key,
)
async def create(
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]:
"""
Create embeddings for the given input data using Google's Gemini embedding model.
Args:
input_data: The input data to create embeddings for. Can be a string, list of strings,
or an iterable of integers or iterables of integers.
Returns:
A list of floats representing the embedding vector.
"""
# Generate embeddings
result = await self.client.aio.models.embed_content(
model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
contents=[input_data],
config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
)
return result.embeddings[0].values
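
Not part of the diff: a minimal usage sketch of the new embedder, assuming the `google-genai` extra is installed and a valid Google API key is supplied (the key below is a placeholder):

```python
import asyncio

from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig


async def main() -> None:
    embedder = GeminiEmbedder(
        config=GeminiEmbedderConfig(api_key='<your-google-api-key>')
    )
    # create() returns a single embedding vector (a list of floats) for the input text
    vector = await embedder.create('Kendra loves Adidas shoes.')
    print(len(vector))


asyncio.run(main())
```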

@@ -71,10 +71,7 @@ from graphiti_core.utils.maintenance.graph_data_operations import (
build_indices_and_constraints,
retrieve_episodes,
)
from graphiti_core.utils.maintenance.node_operations import (
extract_nodes,
resolve_extracted_nodes,
)
from graphiti_core.utils.maintenance.node_operations import extract_nodes, resolve_extracted_nodes
from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types

@@ -0,0 +1,186 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
import typing
from google import genai # type: ignore
from google.genai import types # type: ignore
from pydantic import BaseModel
from ..prompts.models import Message
from .client import LLMClient
from .config import DEFAULT_MAX_TOKENS, LLMConfig
from .errors import RateLimitError
logger = logging.getLogger(__name__)
DEFAULT_MODEL = 'gemini-2.0-flash'
class GeminiClient(LLMClient):
"""
GeminiClient is a client class for interacting with Google's Gemini language models.
This class extends the LLMClient and provides methods to initialize the client
and generate responses from the Gemini language model.
Attributes:
model (str): The model name to use for generating responses.
temperature (float): The temperature to use for generating responses.
max_tokens (int): The maximum number of tokens to generate in a response.
Methods:
__init__(config: LLMConfig | None = None, cache: bool = False):
Initializes the GeminiClient with the provided configuration and cache setting.
_generate_response(messages: list[Message]) -> dict[str, typing.Any]:
Generates a response from the language model based on the provided messages.
"""
def __init__(
self,
config: LLMConfig | None = None,
cache: bool = False,
max_tokens: int = DEFAULT_MAX_TOKENS,
):
"""
Initialize the GeminiClient with the provided configuration and cache setting.
Args:
config (LLMConfig | None): The configuration for the LLM client, including API key, model, temperature, and max tokens.
cache (bool): Whether to use caching for responses. Defaults to False.
"""
if config is None:
config = LLMConfig()
super().__init__(config, cache)
self.model = config.model
# Configure the Gemini API
self.client = genai.Client(
api_key=config.api_key,
)
self.max_tokens = max_tokens
async def _generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
) -> dict[str, typing.Any]:
"""
Generate a response from the Gemini language model.
Args:
messages (list[Message]): A list of messages to send to the language model.
response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
max_tokens (int): The maximum number of tokens to generate in the response.
Returns:
dict[str, typing.Any]: The response from the language model.
Raises:
RateLimitError: If the API rate limit is exceeded.
RefusalError: If the content is blocked by the model.
Exception: If there is an error generating the response.
"""
try:
gemini_messages: list[types.Content] = []
# If a response model is provided, add schema for structured output
system_prompt = ''
if response_model is not None:
# Get the schema from the Pydantic model
pydantic_schema = response_model.model_json_schema()
# Create instruction to output in the desired JSON format
system_prompt += (
f'Output ONLY valid JSON matching this schema: {json.dumps(pydantic_schema)}.\n'
'Do not include any explanatory text before or after the JSON.\n\n'
)
# Add messages content
# First check for a system message
if messages and messages[0].role == 'system':
system_prompt = f'{messages[0].content}\n\n {system_prompt}'
messages = messages[1:]
# Add the rest of the messages
for m in messages:
m.content = self._clean_input(m.content)
gemini_messages.append(
types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)])
)
# Create generation config
generation_config = types.GenerateContentConfig(
temperature=self.temperature,
max_output_tokens=max_tokens or self.max_tokens,
response_mime_type='application/json' if response_model else None,
response_schema=response_model if response_model else None,
system_instruction=system_prompt,
)
# Generate content using the simple string approach
response = await self.client.aio.models.generate_content(
model=self.model or DEFAULT_MODEL,
contents=gemini_messages,
config=generation_config,
)
# If this was a structured output request, parse the response into the Pydantic model
if response_model is not None:
try:
validated_model = response_model.model_validate(json.loads(response.text))
# Return as a dictionary for API consistency
return validated_model.model_dump()
except Exception as e:
raise Exception(f'Failed to parse structured response: {e}') from e
# Otherwise, return the response text as a dictionary
return {'content': response.text}
except Exception as e:
# Check if it's a rate limit error
if 'rate limit' in str(e).lower() or 'quota' in str(e).lower():
raise RateLimitError from e
logger.error(f'Error in generating LLM response: {e}')
raise
async def generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
) -> dict[str, typing.Any]:
"""
Generate a response from the Gemini language model.
This method overrides the parent class method to provide a direct implementation.
Args:
messages (list[Message]): A list of messages to send to the language model.
response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
max_tokens (int): The maximum number of tokens to generate in the response.
Returns:
dict[str, typing.Any]: The response from the language model.
"""
# Call the internal _generate_response method
return await self._generate_response(
messages=messages, response_model=response_model, max_tokens=max_tokens
)
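
Likewise not part of the diff: a hedged sketch of calling the new client with structured output. `LLMConfig` is imported as in the README example above, and the `Message` fields (`role`, `content`) are inferred from how this diff uses them:

```python
import asyncio

from pydantic import BaseModel

from graphiti_core.llm_client.gemini_client import GeminiClient, LLMConfig
from graphiti_core.prompts.models import Message


class Triplet(BaseModel):
    source: str
    relationship: str
    target: str


async def main() -> None:
    client = GeminiClient(
        config=LLMConfig(api_key='<your-google-api-key>', model='gemini-2.0-flash')
    )
    # With response_model set, the client adds the JSON schema to the system
    # instruction and returns the validated model as a plain dict.
    result = await client.generate_response(
        messages=[Message(role='user', content='Kendra loves Adidas shoes.')],
        response_model=Triplet,
    )
    print(result)


asyncio.run(main())
```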

@@ -16,63 +16,27 @@ limitations under the License.
from typing import Any, Protocol, TypedDict
from .dedupe_edges import (
Prompt as DedupeEdgesPrompt,
)
from .dedupe_edges import (
Versions as DedupeEdgesVersions,
)
from .dedupe_edges import (
versions as dedupe_edges_versions,
)
from .dedupe_nodes import (
Prompt as DedupeNodesPrompt,
)
from .dedupe_nodes import (
Versions as DedupeNodesVersions,
)
from .dedupe_nodes import (
versions as dedupe_nodes_versions,
)
from .dedupe_edges import Prompt as DedupeEdgesPrompt
from .dedupe_edges import Versions as DedupeEdgesVersions
from .dedupe_edges import versions as dedupe_edges_versions
from .dedupe_nodes import Prompt as DedupeNodesPrompt
from .dedupe_nodes import Versions as DedupeNodesVersions
from .dedupe_nodes import versions as dedupe_nodes_versions
from .eval import Prompt as EvalPrompt
from .eval import Versions as EvalVersions
from .eval import versions as eval_versions
from .extract_edge_dates import (
Prompt as ExtractEdgeDatesPrompt,
)
from .extract_edge_dates import (
Versions as ExtractEdgeDatesVersions,
)
from .extract_edge_dates import (
versions as extract_edge_dates_versions,
)
from .extract_edges import (
Prompt as ExtractEdgesPrompt,
)
from .extract_edges import (
Versions as ExtractEdgesVersions,
)
from .extract_edges import (
versions as extract_edges_versions,
)
from .extract_nodes import (
Prompt as ExtractNodesPrompt,
)
from .extract_nodes import (
Versions as ExtractNodesVersions,
)
from .extract_nodes import (
versions as extract_nodes_versions,
)
from .invalidate_edges import (
Prompt as InvalidateEdgesPrompt,
)
from .invalidate_edges import (
Versions as InvalidateEdgesVersions,
)
from .invalidate_edges import (
versions as invalidate_edges_versions,
)
from .extract_edge_dates import Prompt as ExtractEdgeDatesPrompt
from .extract_edge_dates import Versions as ExtractEdgeDatesVersions
from .extract_edge_dates import versions as extract_edge_dates_versions
from .extract_edges import Prompt as ExtractEdgesPrompt
from .extract_edges import Versions as ExtractEdgesVersions
from .extract_edges import versions as extract_edges_versions
from .extract_nodes import Prompt as ExtractNodesPrompt
from .extract_nodes import Versions as ExtractNodesVersions
from .extract_nodes import versions as extract_nodes_versions
from .invalidate_edges import Prompt as InvalidateEdgesPrompt
from .invalidate_edges import Versions as InvalidateEdgesVersions
from .invalidate_edges import versions as invalidate_edges_versions
from .models import Message, PromptFunction
from .prompt_helpers import DO_NOT_ESCAPE_UNICODE
from .summarize_nodes import Prompt as SummarizeNodesPrompt

@@ -1,8 +1,5 @@
from .edge_operations import build_episodic_edges, extract_edges
from .graph_data_operations import (
clear_data,
retrieve_episodes,
)
from .graph_data_operations import clear_data, retrieve_episodes
from .node_operations import extract_nodes
__all__ = [

@@ -9,11 +9,7 @@ from graphiti_core.edges import CommunityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
get_community_node_from_record,
)
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription
from graphiti_core.utils.datetime_utils import utc_now

@@ -26,11 +26,7 @@ from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
from graphiti_core.prompts.extract_nodes import (
EntityClassification,
ExtractedNodes,
MissedEntities,
)
from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
from graphiti_core.prompts.summarize_nodes import Summary
from graphiti_core.utils.datetime_utils import utc_now

poetry.lock (generated, 1753 changed lines): diff suppressed because it is too large.

@@ -21,7 +21,11 @@ openai = "^1.53.0"
tenacity = "9.0.0"
numpy = ">=1.0.0"
python-dotenv = "^1.0.1"
anthropic = "~0.49.0"
[tool.poetry.extras]
anthropic = ["anthropic"]
groq = ["groq"]
google-genai = ["google-genai"]
[tool.poetry.group.dev.dependencies]
mypy = "^1.11.1"

@@ -1,12 +1,6 @@
from .common import Message, Result
from .ingest import AddEntityNodeRequest, AddMessagesRequest
from .retrieve import (
FactResult,
GetMemoryRequest,
GetMemoryResponse,
SearchQuery,
SearchResults,
)
from .retrieve import FactResult, GetMemoryRequest, GetMemoryResponse, SearchQuery, SearchResults
__all__ = [
'SearchQuery',

@@ -26,9 +26,7 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti
from graphiti_core.helpers import semaphore_gather
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.search.search_config_recipes import (
COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
)
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_CROSS_ENCODER
from graphiti_core.search.search_filters import SearchFilters
pytestmark = pytest.mark.integration