Compare commits
4 commits
main
...
config-emb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3efad4c225 | ||
|
|
e4bc756c31 | ||
|
|
9a0971552a | ||
|
|
124363a3bc |
12 changed files with 61 additions and 23 deletions
|
|
@ -26,7 +26,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
||||
from graphiti_core.helpers import parse_db_date
|
||||
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
||||
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
|
||||
from graphiti_core.nodes import Node
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -171,7 +171,7 @@ class EntityEdge(Edge):
|
|||
default=None, description='datetime of when the fact stopped being true'
|
||||
)
|
||||
|
||||
async def generate_embedding(self, embedder, model='text-embedding-3-small'):
|
||||
async def generate_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
|
||||
start = time()
|
||||
|
||||
text = self.fact.replace('\n', ' ')
|
||||
|
|
|
|||
|
|
@ -318,7 +318,7 @@ class Graphiti:
|
|||
# Calculate Embeddings
|
||||
|
||||
await asyncio.gather(
|
||||
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
|
||||
*[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in extracted_nodes]
|
||||
)
|
||||
|
||||
# Resolve extracted nodes with nodes already in the graph and extract facts
|
||||
|
|
@ -346,7 +346,7 @@ class Graphiti:
|
|||
# calculate embeddings
|
||||
await asyncio.gather(
|
||||
*[
|
||||
edge.generate_embedding(embedder)
|
||||
edge.generate_embedding(embedder, self.llm_client.embedding_model)
|
||||
for edge in extracted_edges_with_resolved_pointers
|
||||
]
|
||||
)
|
||||
|
|
@ -520,8 +520,8 @@ class Graphiti:
|
|||
|
||||
# Generate embeddings
|
||||
await asyncio.gather(
|
||||
*[node.generate_name_embedding(embedder) for node in extracted_nodes],
|
||||
*[edge.generate_embedding(embedder) for edge in extracted_edges],
|
||||
*[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in extracted_nodes],
|
||||
*[edge.generate_embedding(embedder, self.llm_client.embedding_model) for edge in extracted_edges],
|
||||
)
|
||||
|
||||
# Dedupe extracted nodes, compress extracted edges
|
||||
|
|
@ -571,7 +571,7 @@ class Graphiti:
|
|||
|
||||
community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
|
||||
|
||||
await asyncio.gather(*[node.generate_name_embedding(embedder) for node in community_nodes])
|
||||
await asyncio.gather(*[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in community_nodes])
|
||||
|
||||
await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
|
||||
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
|
||||
|
|
@ -618,6 +618,7 @@ class Graphiti:
|
|||
EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
|
||||
)
|
||||
search_config.limit = num_results
|
||||
search_config.embedding_model = self.llm_client.embedding_model
|
||||
|
||||
edges = (
|
||||
await search(
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ class LLMClient(ABC):
|
|||
self.max_tokens = config.max_tokens
|
||||
self.cache_enabled = cache
|
||||
self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory
|
||||
self.embedding_model = config.embedding_model
|
||||
|
||||
@abstractmethod
|
||||
def get_embedder(self) -> typing.Any:
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ limitations under the License.
|
|||
EMBEDDING_DIM = 1024
|
||||
DEFAULT_MAX_TOKENS = 16384
|
||||
DEFAULT_TEMPERATURE = 0
|
||||
|
||||
DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small'
|
||||
|
||||
class LLMConfig:
|
||||
"""
|
||||
|
|
@ -33,6 +33,7 @@ class LLMConfig:
|
|||
api_key: str | None = None,
|
||||
model: str | None = None,
|
||||
base_url: str | None = None,
|
||||
embedding_model: str | None = DEFAULT_EMBEDDING_MODEL,
|
||||
temperature: float = DEFAULT_TEMPERATURE,
|
||||
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||
):
|
||||
|
|
@ -50,9 +51,15 @@ class LLMConfig:
|
|||
base_url (str, optional): The base URL of the LLM API service.
|
||||
Defaults to "https://api.openai.com", which is OpenAI's standard API endpoint.
|
||||
This can be changed if using a different provider or a custom endpoint.
|
||||
embedding_model (str, optional): The specific embedding model.
|
||||
Defaults to openai "text-embedding-3-small" model.
|
||||
We currently only support openai embedding models, such as "text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002".
|
||||
"""
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
if not embedding_model:
|
||||
embedding_model = DEFAULT_EMBEDDING_MODEL
|
||||
self.embedding_model = embedding_model
|
||||
|
|
|
|||
|
|
@ -18,13 +18,13 @@ import logging
|
|||
import typing
|
||||
from time import time
|
||||
|
||||
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
||||
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def generate_embedding(
|
||||
embedder: typing.Any, text: str, model: str = 'text-embedding-3-small'
|
||||
embedder: typing.Any, text: str, model: str = DEFAULT_EMBEDDING_MODEL
|
||||
):
|
||||
start = time()
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from neo4j import AsyncDriver
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphiti_core.errors import NodeNotFoundError
|
||||
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
||||
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -212,7 +212,7 @@ class EntityNode(Node):
|
|||
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
||||
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
|
||||
|
||||
async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
|
||||
async def generate_name_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
|
||||
start = time()
|
||||
text = self.name.replace('\n', ' ')
|
||||
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
|
||||
|
|
@ -323,7 +323,7 @@ class CommunityNode(Node):
|
|||
|
||||
return result
|
||||
|
||||
async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
|
||||
async def generate_name_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
|
||||
start = time()
|
||||
text = self.name.replace('\n', ' ')
|
||||
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from neo4j import AsyncDriver
|
|||
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.errors import SearchRerankerError
|
||||
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
||||
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode
|
||||
from graphiti_core.search.search_config import (
|
||||
DEFAULT_SEARCH_LIMIT,
|
||||
|
|
@ -68,12 +68,34 @@ async def search(
|
|||
group_ids = group_ids if group_ids else None
|
||||
edges, nodes, communities = await asyncio.gather(
|
||||
edge_search(
|
||||
driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit
|
||||
driver,
|
||||
embedder,
|
||||
query,
|
||||
group_ids,
|
||||
config.edge_config,
|
||||
center_node_uuid,
|
||||
config.limit,
|
||||
config.embedding_model,
|
||||
),
|
||||
node_search(
|
||||
driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit
|
||||
driver,
|
||||
embedder,
|
||||
query,
|
||||
group_ids,
|
||||
config.node_config,
|
||||
center_node_uuid,
|
||||
config.limit,
|
||||
config.embedding_model,
|
||||
),
|
||||
community_search(
|
||||
driver,
|
||||
embedder,
|
||||
query,
|
||||
group_ids,
|
||||
config.community_config,
|
||||
config.limit,
|
||||
config.embedding_model,
|
||||
),
|
||||
community_search(driver, embedder, query, group_ids, config.community_config, config.limit),
|
||||
)
|
||||
|
||||
results = SearchResults(
|
||||
|
|
@ -97,6 +119,7 @@ async def edge_search(
|
|||
config: EdgeSearchConfig | None,
|
||||
center_node_uuid: str | None = None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
embedding_model: str | None = None,
|
||||
) -> list[EntityEdge]:
|
||||
if config is None:
|
||||
return []
|
||||
|
|
@ -109,7 +132,7 @@ async def edge_search(
|
|||
|
||||
if EdgeSearchMethod.cosine_similarity in config.search_methods:
|
||||
search_vector = (
|
||||
(await embedder.create(input=[query], model='text-embedding-3-small'))
|
||||
(await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
|
||||
.data[0]
|
||||
.embedding[:EMBEDDING_DIM]
|
||||
)
|
||||
|
|
@ -165,6 +188,7 @@ async def node_search(
|
|||
config: NodeSearchConfig | None,
|
||||
center_node_uuid: str | None = None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
embedding_model: str | None = None,
|
||||
) -> list[EntityNode]:
|
||||
if config is None:
|
||||
return []
|
||||
|
|
@ -177,7 +201,7 @@ async def node_search(
|
|||
|
||||
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
||||
search_vector = (
|
||||
(await embedder.create(input=[query], model='text-embedding-3-small'))
|
||||
(await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
|
||||
.data[0]
|
||||
.embedding[:EMBEDDING_DIM]
|
||||
)
|
||||
|
|
@ -217,6 +241,7 @@ async def community_search(
|
|||
group_ids: list[str] | None,
|
||||
config: CommunitySearchConfig | None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
embedding_model: str | None = None,
|
||||
) -> list[CommunityNode]:
|
||||
if config is None:
|
||||
return []
|
||||
|
|
@ -229,7 +254,7 @@ async def community_search(
|
|||
|
||||
if CommunitySearchMethod.cosine_similarity in config.search_methods:
|
||||
search_vector = (
|
||||
(await embedder.create(input=[query], model='text-embedding-3-small'))
|
||||
(await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
|
||||
.data[0]
|
||||
.embedding[:EMBEDDING_DIM]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -74,6 +74,7 @@ class SearchConfig(BaseModel):
|
|||
edge_config: EdgeSearchConfig | None = Field(default=None)
|
||||
node_config: NodeSearchConfig | None = Field(default=None)
|
||||
community_config: CommunitySearchConfig | None = Field(default=None)
|
||||
embedding_model: str | None = Field(default=None)
|
||||
limit: int = Field(default=DEFAULT_SEARCH_LIMIT)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -305,6 +305,6 @@ async def update_community(
|
|||
community_edge = (build_community_edges([entity], community, datetime.now()))[0]
|
||||
await community_edge.save(driver)
|
||||
|
||||
await community.generate_name_embedding(embedder)
|
||||
await community.generate_name_embedding(embedder, llm_client.embedding_model)
|
||||
|
||||
await community.save(driver)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ class Settings(BaseSettings):
|
|||
openai_api_key: str
|
||||
openai_base_url: str | None = Field(None)
|
||||
model_name: str | None = Field(None)
|
||||
embedding_model_name: str | None = Field(None)
|
||||
neo4j_uri: str
|
||||
neo4j_user: str
|
||||
neo4j_password: str
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class ZepGraphiti(Graphiti):
|
|||
group_id=group_id,
|
||||
summary=summary,
|
||||
)
|
||||
await new_node.generate_name_embedding(self.llm_client.get_embedder())
|
||||
await new_node.generate_name_embedding(self.llm_client.get_embedder(), self.llm_client.embedding_model)
|
||||
await new_node.save(self.driver)
|
||||
return new_node
|
||||
|
||||
|
|
@ -83,6 +83,8 @@ async def get_graphiti(settings: ZepEnvDep):
|
|||
client.llm_client.config.api_key = settings.openai_api_key
|
||||
if settings.model_name is not None:
|
||||
client.llm_client.model = settings.model_name
|
||||
if settings.embedding_model_name is not None:
|
||||
client.llm_client.embedding_model = settings.embedding_model_name
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ async def test_graph_integration():
|
|||
invalid_at=now,
|
||||
)
|
||||
|
||||
await entity_edge.generate_embedding(embedder)
|
||||
await entity_edge.generate_embedding(embedder, client.llm_client.embedding_model)
|
||||
|
||||
nodes = [episode, alice_node, bob_node]
|
||||
edges = [episodic_edge_1, episodic_edge_2, entity_edge]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue