Compare commits

...
Sign in to create a new pull request.

4 commits

Author SHA1 Message Date
paulpaliychuk
3efad4c225 chore: Pass embedding model in search utils 2024-09-26 16:20:01 -04:00
paulpaliychuk
e4bc756c31 Merge branch 'main' of github.com:getzep/graphiti into config-embedding-model
# Conflicts:
#	graphiti_core/search/search.py
2024-09-26 16:18:00 -04:00
paulpaliychuk
9a0971552a chore: Update comment 2024-09-26 14:46:10 -04:00
Arno
124363a3bc feat: configurable embedding model
format
2024-09-26 13:37:46 +08:00
12 changed files with 61 additions and 23 deletions

View file

@@ -26,7 +26,7 @@ from pydantic import BaseModel, Field
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.llm_client.config import EMBEDDING_DIM
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
from graphiti_core.nodes import Node
logger = logging.getLogger(__name__)
@@ -171,7 +171,7 @@ class EntityEdge(Edge):
default=None, description='datetime of when the fact stopped being true'
)
async def generate_embedding(self, embedder, model='text-embedding-3-small'):
async def generate_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
start = time()
text = self.fact.replace('\n', ' ')

View file

@@ -318,7 +318,7 @@ class Graphiti:
# Calculate Embeddings
await asyncio.gather(
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
*[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in extracted_nodes]
)
# Resolve extracted nodes with nodes already in the graph and extract facts
@@ -346,7 +346,7 @@
# calculate embeddings
await asyncio.gather(
*[
edge.generate_embedding(embedder)
edge.generate_embedding(embedder, self.llm_client.embedding_model)
for edge in extracted_edges_with_resolved_pointers
]
)
@@ -520,8 +520,8 @@ class Graphiti:
# Generate embeddings
await asyncio.gather(
*[node.generate_name_embedding(embedder) for node in extracted_nodes],
*[edge.generate_embedding(embedder) for edge in extracted_edges],
*[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in extracted_nodes],
*[edge.generate_embedding(embedder, self.llm_client.embedding_model) for edge in extracted_edges],
)
# Dedupe extracted nodes, compress extracted edges
@@ -571,7 +571,7 @@
community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
await asyncio.gather(*[node.generate_name_embedding(embedder) for node in community_nodes])
await asyncio.gather(*[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in community_nodes])
await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
@@ -618,6 +618,7 @@ class Graphiti:
EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
)
search_config.limit = num_results
search_config.embedding_model = self.llm_client.embedding_model
edges = (
await search(

View file

@@ -54,6 +54,7 @@ class LLMClient(ABC):
self.max_tokens = config.max_tokens
self.cache_enabled = cache
self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory
self.embedding_model = config.embedding_model
@abstractmethod
def get_embedder(self) -> typing.Any:

View file

@@ -17,7 +17,7 @@ limitations under the License.
EMBEDDING_DIM = 1024
DEFAULT_MAX_TOKENS = 16384
DEFAULT_TEMPERATURE = 0
DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small'
class LLMConfig:
"""
@@ -33,6 +33,7 @@ class LLMConfig:
api_key: str | None = None,
model: str | None = None,
base_url: str | None = None,
embedding_model: str | None = DEFAULT_EMBEDDING_MODEL,
temperature: float = DEFAULT_TEMPERATURE,
max_tokens: int = DEFAULT_MAX_TOKENS,
):
@@ -50,9 +51,15 @@ class LLMConfig:
base_url (str, optional): The base URL of the LLM API service.
Defaults to "https://api.openai.com", which is OpenAI's standard API endpoint.
This can be changed if using a different provider or a custom endpoint.
embedding_model (str, optional): The specific embedding model.
Defaults to openai "text-embedding-3-small" model.
We currently only support openai embedding models, such as "text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002".
"""
self.base_url = base_url
self.api_key = api_key
self.model = model
self.temperature = temperature
self.max_tokens = max_tokens
if not embedding_model:
embedding_model = DEFAULT_EMBEDDING_MODEL
self.embedding_model = embedding_model

View file

@@ -18,13 +18,13 @@ import logging
import typing
from time import time
from graphiti_core.llm_client.config import EMBEDDING_DIM
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
logger = logging.getLogger(__name__)
async def generate_embedding(
embedder: typing.Any, text: str, model: str = 'text-embedding-3-small'
embedder: typing.Any, text: str, model: str = DEFAULT_EMBEDDING_MODEL
):
start = time()

View file

@@ -26,7 +26,7 @@ from neo4j import AsyncDriver
from pydantic import BaseModel, Field
from graphiti_core.errors import NodeNotFoundError
from graphiti_core.llm_client.config import EMBEDDING_DIM
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
logger = logging.getLogger(__name__)
@@ -212,7 +212,7 @@ class EntityNode(Node):
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
async def generate_name_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
start = time()
text = self.name.replace('\n', ' ')
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
@@ -323,7 +323,7 @@ class CommunityNode(Node):
return result
async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
async def generate_name_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
start = time()
text = self.name.replace('\n', ' ')
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding

View file

@@ -23,7 +23,7 @@ from neo4j import AsyncDriver
from graphiti_core.edges import EntityEdge
from graphiti_core.errors import SearchRerankerError
from graphiti_core.llm_client.config import EMBEDDING_DIM
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
from graphiti_core.nodes import CommunityNode, EntityNode
from graphiti_core.search.search_config import (
DEFAULT_SEARCH_LIMIT,
@@ -68,12 +68,34 @@ async def search(
group_ids = group_ids if group_ids else None
edges, nodes, communities = await asyncio.gather(
edge_search(
driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit
driver,
embedder,
query,
group_ids,
config.edge_config,
center_node_uuid,
config.limit,
config.embedding_model,
),
node_search(
driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit
driver,
embedder,
query,
group_ids,
config.node_config,
center_node_uuid,
config.limit,
config.embedding_model,
),
community_search(
driver,
embedder,
query,
group_ids,
config.community_config,
config.limit,
config.embedding_model,
),
community_search(driver, embedder, query, group_ids, config.community_config, config.limit),
)
results = SearchResults(
@@ -97,6 +119,7 @@ async def edge_search(
config: EdgeSearchConfig | None,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
embedding_model: str | None = None,
) -> list[EntityEdge]:
if config is None:
return []
@@ -109,7 +132,7 @@
if EdgeSearchMethod.cosine_similarity in config.search_methods:
search_vector = (
(await embedder.create(input=[query], model='text-embedding-3-small'))
(await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
.data[0]
.embedding[:EMBEDDING_DIM]
)
@@ -165,6 +188,7 @@ async def node_search(
config: NodeSearchConfig | None,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
embedding_model: str | None = None,
) -> list[EntityNode]:
if config is None:
return []
@@ -177,7 +201,7 @@
if NodeSearchMethod.cosine_similarity in config.search_methods:
search_vector = (
(await embedder.create(input=[query], model='text-embedding-3-small'))
(await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
.data[0]
.embedding[:EMBEDDING_DIM]
)
@@ -217,6 +241,7 @@ async def community_search(
group_ids: list[str] | None,
config: CommunitySearchConfig | None,
limit=DEFAULT_SEARCH_LIMIT,
embedding_model: str | None = None,
) -> list[CommunityNode]:
if config is None:
return []
@@ -229,7 +254,7 @@
if CommunitySearchMethod.cosine_similarity in config.search_methods:
search_vector = (
(await embedder.create(input=[query], model='text-embedding-3-small'))
(await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
.data[0]
.embedding[:EMBEDDING_DIM]
)

View file

@@ -74,6 +74,7 @@ class SearchConfig(BaseModel):
edge_config: EdgeSearchConfig | None = Field(default=None)
node_config: NodeSearchConfig | None = Field(default=None)
community_config: CommunitySearchConfig | None = Field(default=None)
embedding_model: str | None = Field(default=None)
limit: int = Field(default=DEFAULT_SEARCH_LIMIT)

View file

@@ -305,6 +305,6 @@ async def update_community(
community_edge = (build_community_edges([entity], community, datetime.now()))[0]
await community_edge.save(driver)
await community.generate_name_embedding(embedder)
await community.generate_name_embedding(embedder, llm_client.embedding_model)
await community.save(driver)

View file

@@ -10,6 +10,7 @@ class Settings(BaseSettings):
openai_api_key: str
openai_base_url: str | None = Field(None)
model_name: str | None = Field(None)
embedding_model_name: str | None = Field(None)
neo4j_uri: str
neo4j_user: str
neo4j_password: str

View file

@@ -25,7 +25,7 @@ class ZepGraphiti(Graphiti):
group_id=group_id,
summary=summary,
)
await new_node.generate_name_embedding(self.llm_client.get_embedder())
await new_node.generate_name_embedding(self.llm_client.get_embedder(), self.llm_client.embedding_model)
await new_node.save(self.driver)
return new_node
@@ -83,6 +83,8 @@ async def get_graphiti(settings: ZepEnvDep):
client.llm_client.config.api_key = settings.openai_api_key
if settings.model_name is not None:
client.llm_client.model = settings.model_name
if settings.embedding_model_name is not None:
client.llm_client.embedding_model = settings.embedding_model_name
try:
yield client
finally:

View file

@@ -145,7 +145,7 @@ async def test_graph_integration():
invalid_at=now,
)
await entity_edge.generate_embedding(embedder)
await entity_edge.generate_embedding(embedder, client.llm_client.embedding_model)
nodes = [episode, alice_node, bob_node]
edges = [episodic_edge_1, episodic_edge_2, entity_edge]