Compare commits
4 commits
main
...
config-emb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3efad4c225 | ||
|
|
e4bc756c31 | ||
|
|
9a0971552a | ||
|
|
124363a3bc |
12 changed files with 61 additions and 23 deletions
|
|
@ -26,7 +26,7 @@ from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
||||||
from graphiti_core.helpers import parse_db_date
|
from graphiti_core.helpers import parse_db_date
|
||||||
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
|
||||||
from graphiti_core.nodes import Node
|
from graphiti_core.nodes import Node
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -171,7 +171,7 @@ class EntityEdge(Edge):
|
||||||
default=None, description='datetime of when the fact stopped being true'
|
default=None, description='datetime of when the fact stopped being true'
|
||||||
)
|
)
|
||||||
|
|
||||||
async def generate_embedding(self, embedder, model='text-embedding-3-small'):
|
async def generate_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
text = self.fact.replace('\n', ' ')
|
text = self.fact.replace('\n', ' ')
|
||||||
|
|
|
||||||
|
|
@ -318,7 +318,7 @@ class Graphiti:
|
||||||
# Calculate Embeddings
|
# Calculate Embeddings
|
||||||
|
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
|
*[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in extracted_nodes]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Resolve extracted nodes with nodes already in the graph and extract facts
|
# Resolve extracted nodes with nodes already in the graph and extract facts
|
||||||
|
|
@ -346,7 +346,7 @@ class Graphiti:
|
||||||
# calculate embeddings
|
# calculate embeddings
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
edge.generate_embedding(embedder)
|
edge.generate_embedding(embedder, self.llm_client.embedding_model)
|
||||||
for edge in extracted_edges_with_resolved_pointers
|
for edge in extracted_edges_with_resolved_pointers
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
@ -520,8 +520,8 @@ class Graphiti:
|
||||||
|
|
||||||
# Generate embeddings
|
# Generate embeddings
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[node.generate_name_embedding(embedder) for node in extracted_nodes],
|
*[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in extracted_nodes],
|
||||||
*[edge.generate_embedding(embedder) for edge in extracted_edges],
|
*[edge.generate_embedding(embedder, self.llm_client.embedding_model) for edge in extracted_edges],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Dedupe extracted nodes, compress extracted edges
|
# Dedupe extracted nodes, compress extracted edges
|
||||||
|
|
@ -571,7 +571,7 @@ class Graphiti:
|
||||||
|
|
||||||
community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
|
community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
|
||||||
|
|
||||||
await asyncio.gather(*[node.generate_name_embedding(embedder) for node in community_nodes])
|
await asyncio.gather(*[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in community_nodes])
|
||||||
|
|
||||||
await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
|
await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
|
||||||
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
|
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
|
||||||
|
|
@ -618,6 +618,7 @@ class Graphiti:
|
||||||
EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
|
EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
|
||||||
)
|
)
|
||||||
search_config.limit = num_results
|
search_config.limit = num_results
|
||||||
|
search_config.embedding_model = self.llm_client.embedding_model
|
||||||
|
|
||||||
edges = (
|
edges = (
|
||||||
await search(
|
await search(
|
||||||
|
|
|
||||||
|
|
@ -54,6 +54,7 @@ class LLMClient(ABC):
|
||||||
self.max_tokens = config.max_tokens
|
self.max_tokens = config.max_tokens
|
||||||
self.cache_enabled = cache
|
self.cache_enabled = cache
|
||||||
self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory
|
self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory
|
||||||
|
self.embedding_model = config.embedding_model
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_embedder(self) -> typing.Any:
|
def get_embedder(self) -> typing.Any:
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||||
EMBEDDING_DIM = 1024
|
EMBEDDING_DIM = 1024
|
||||||
DEFAULT_MAX_TOKENS = 16384
|
DEFAULT_MAX_TOKENS = 16384
|
||||||
DEFAULT_TEMPERATURE = 0
|
DEFAULT_TEMPERATURE = 0
|
||||||
|
DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small'
|
||||||
|
|
||||||
class LLMConfig:
|
class LLMConfig:
|
||||||
"""
|
"""
|
||||||
|
|
@ -33,6 +33,7 @@ class LLMConfig:
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
|
embedding_model: str | None = DEFAULT_EMBEDDING_MODEL,
|
||||||
temperature: float = DEFAULT_TEMPERATURE,
|
temperature: float = DEFAULT_TEMPERATURE,
|
||||||
max_tokens: int = DEFAULT_MAX_TOKENS,
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
):
|
):
|
||||||
|
|
@ -50,9 +51,15 @@ class LLMConfig:
|
||||||
base_url (str, optional): The base URL of the LLM API service.
|
base_url (str, optional): The base URL of the LLM API service.
|
||||||
Defaults to "https://api.openai.com", which is OpenAI's standard API endpoint.
|
Defaults to "https://api.openai.com", which is OpenAI's standard API endpoint.
|
||||||
This can be changed if using a different provider or a custom endpoint.
|
This can be changed if using a different provider or a custom endpoint.
|
||||||
|
embedding_model (str, optional): The specific embedding model.
|
||||||
|
Defaults to openai "text-embedding-3-small" model.
|
||||||
|
We currently only support openai embedding models, such as "text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002".
|
||||||
"""
|
"""
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.model = model
|
self.model = model
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
|
if not embedding_model:
|
||||||
|
embedding_model = DEFAULT_EMBEDDING_MODEL
|
||||||
|
self.embedding_model = embedding_model
|
||||||
|
|
|
||||||
|
|
@ -18,13 +18,13 @@ import logging
|
||||||
import typing
|
import typing
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def generate_embedding(
|
async def generate_embedding(
|
||||||
embedder: typing.Any, text: str, model: str = 'text-embedding-3-small'
|
embedder: typing.Any, text: str, model: str = DEFAULT_EMBEDDING_MODEL
|
||||||
):
|
):
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ from neo4j import AsyncDriver
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from graphiti_core.errors import NodeNotFoundError
|
from graphiti_core.errors import NodeNotFoundError
|
||||||
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -212,7 +212,7 @@ class EntityNode(Node):
|
||||||
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
||||||
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
|
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
|
||||||
|
|
||||||
async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
|
async def generate_name_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
|
||||||
start = time()
|
start = time()
|
||||||
text = self.name.replace('\n', ' ')
|
text = self.name.replace('\n', ' ')
|
||||||
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
|
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
|
||||||
|
|
@ -323,7 +323,7 @@ class CommunityNode(Node):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
|
async def generate_name_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
|
||||||
start = time()
|
start = time()
|
||||||
text = self.name.replace('\n', ' ')
|
text = self.name.replace('\n', ' ')
|
||||||
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
|
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ from neo4j import AsyncDriver
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
from graphiti_core.errors import SearchRerankerError
|
from graphiti_core.errors import SearchRerankerError
|
||||||
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
|
||||||
from graphiti_core.nodes import CommunityNode, EntityNode
|
from graphiti_core.nodes import CommunityNode, EntityNode
|
||||||
from graphiti_core.search.search_config import (
|
from graphiti_core.search.search_config import (
|
||||||
DEFAULT_SEARCH_LIMIT,
|
DEFAULT_SEARCH_LIMIT,
|
||||||
|
|
@ -68,12 +68,34 @@ async def search(
|
||||||
group_ids = group_ids if group_ids else None
|
group_ids = group_ids if group_ids else None
|
||||||
edges, nodes, communities = await asyncio.gather(
|
edges, nodes, communities = await asyncio.gather(
|
||||||
edge_search(
|
edge_search(
|
||||||
driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit
|
driver,
|
||||||
|
embedder,
|
||||||
|
query,
|
||||||
|
group_ids,
|
||||||
|
config.edge_config,
|
||||||
|
center_node_uuid,
|
||||||
|
config.limit,
|
||||||
|
config.embedding_model,
|
||||||
),
|
),
|
||||||
node_search(
|
node_search(
|
||||||
driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit
|
driver,
|
||||||
|
embedder,
|
||||||
|
query,
|
||||||
|
group_ids,
|
||||||
|
config.node_config,
|
||||||
|
center_node_uuid,
|
||||||
|
config.limit,
|
||||||
|
config.embedding_model,
|
||||||
|
),
|
||||||
|
community_search(
|
||||||
|
driver,
|
||||||
|
embedder,
|
||||||
|
query,
|
||||||
|
group_ids,
|
||||||
|
config.community_config,
|
||||||
|
config.limit,
|
||||||
|
config.embedding_model,
|
||||||
),
|
),
|
||||||
community_search(driver, embedder, query, group_ids, config.community_config, config.limit),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
results = SearchResults(
|
results = SearchResults(
|
||||||
|
|
@ -97,6 +119,7 @@ async def edge_search(
|
||||||
config: EdgeSearchConfig | None,
|
config: EdgeSearchConfig | None,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
limit=DEFAULT_SEARCH_LIMIT,
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
|
embedding_model: str | None = None,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
@ -109,7 +132,7 @@ async def edge_search(
|
||||||
|
|
||||||
if EdgeSearchMethod.cosine_similarity in config.search_methods:
|
if EdgeSearchMethod.cosine_similarity in config.search_methods:
|
||||||
search_vector = (
|
search_vector = (
|
||||||
(await embedder.create(input=[query], model='text-embedding-3-small'))
|
(await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
|
||||||
.data[0]
|
.data[0]
|
||||||
.embedding[:EMBEDDING_DIM]
|
.embedding[:EMBEDDING_DIM]
|
||||||
)
|
)
|
||||||
|
|
@ -165,6 +188,7 @@ async def node_search(
|
||||||
config: NodeSearchConfig | None,
|
config: NodeSearchConfig | None,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
limit=DEFAULT_SEARCH_LIMIT,
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
|
embedding_model: str | None = None,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
@ -177,7 +201,7 @@ async def node_search(
|
||||||
|
|
||||||
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
||||||
search_vector = (
|
search_vector = (
|
||||||
(await embedder.create(input=[query], model='text-embedding-3-small'))
|
(await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
|
||||||
.data[0]
|
.data[0]
|
||||||
.embedding[:EMBEDDING_DIM]
|
.embedding[:EMBEDDING_DIM]
|
||||||
)
|
)
|
||||||
|
|
@ -217,6 +241,7 @@ async def community_search(
|
||||||
group_ids: list[str] | None,
|
group_ids: list[str] | None,
|
||||||
config: CommunitySearchConfig | None,
|
config: CommunitySearchConfig | None,
|
||||||
limit=DEFAULT_SEARCH_LIMIT,
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
|
embedding_model: str | None = None,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
@ -229,7 +254,7 @@ async def community_search(
|
||||||
|
|
||||||
if CommunitySearchMethod.cosine_similarity in config.search_methods:
|
if CommunitySearchMethod.cosine_similarity in config.search_methods:
|
||||||
search_vector = (
|
search_vector = (
|
||||||
(await embedder.create(input=[query], model='text-embedding-3-small'))
|
(await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
|
||||||
.data[0]
|
.data[0]
|
||||||
.embedding[:EMBEDDING_DIM]
|
.embedding[:EMBEDDING_DIM]
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -74,6 +74,7 @@ class SearchConfig(BaseModel):
|
||||||
edge_config: EdgeSearchConfig | None = Field(default=None)
|
edge_config: EdgeSearchConfig | None = Field(default=None)
|
||||||
node_config: NodeSearchConfig | None = Field(default=None)
|
node_config: NodeSearchConfig | None = Field(default=None)
|
||||||
community_config: CommunitySearchConfig | None = Field(default=None)
|
community_config: CommunitySearchConfig | None = Field(default=None)
|
||||||
|
embedding_model: str | None = Field(default=None)
|
||||||
limit: int = Field(default=DEFAULT_SEARCH_LIMIT)
|
limit: int = Field(default=DEFAULT_SEARCH_LIMIT)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -305,6 +305,6 @@ async def update_community(
|
||||||
community_edge = (build_community_edges([entity], community, datetime.now()))[0]
|
community_edge = (build_community_edges([entity], community, datetime.now()))[0]
|
||||||
await community_edge.save(driver)
|
await community_edge.save(driver)
|
||||||
|
|
||||||
await community.generate_name_embedding(embedder)
|
await community.generate_name_embedding(embedder, llm_client.embedding_model)
|
||||||
|
|
||||||
await community.save(driver)
|
await community.save(driver)
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ class Settings(BaseSettings):
|
||||||
openai_api_key: str
|
openai_api_key: str
|
||||||
openai_base_url: str | None = Field(None)
|
openai_base_url: str | None = Field(None)
|
||||||
model_name: str | None = Field(None)
|
model_name: str | None = Field(None)
|
||||||
|
embedding_model_name: str | None = Field(None)
|
||||||
neo4j_uri: str
|
neo4j_uri: str
|
||||||
neo4j_user: str
|
neo4j_user: str
|
||||||
neo4j_password: str
|
neo4j_password: str
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ class ZepGraphiti(Graphiti):
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
summary=summary,
|
summary=summary,
|
||||||
)
|
)
|
||||||
await new_node.generate_name_embedding(self.llm_client.get_embedder())
|
await new_node.generate_name_embedding(self.llm_client.get_embedder(), self.llm_client.embedding_model)
|
||||||
await new_node.save(self.driver)
|
await new_node.save(self.driver)
|
||||||
return new_node
|
return new_node
|
||||||
|
|
||||||
|
|
@ -83,6 +83,8 @@ async def get_graphiti(settings: ZepEnvDep):
|
||||||
client.llm_client.config.api_key = settings.openai_api_key
|
client.llm_client.config.api_key = settings.openai_api_key
|
||||||
if settings.model_name is not None:
|
if settings.model_name is not None:
|
||||||
client.llm_client.model = settings.model_name
|
client.llm_client.model = settings.model_name
|
||||||
|
if settings.embedding_model_name is not None:
|
||||||
|
client.llm_client.embedding_model = settings.embedding_model_name
|
||||||
try:
|
try:
|
||||||
yield client
|
yield client
|
||||||
finally:
|
finally:
|
||||||
|
|
|
||||||
|
|
@ -145,7 +145,7 @@ async def test_graph_integration():
|
||||||
invalid_at=now,
|
invalid_at=now,
|
||||||
)
|
)
|
||||||
|
|
||||||
await entity_edge.generate_embedding(embedder)
|
await entity_edge.generate_embedding(embedder, client.llm_client.embedding_model)
|
||||||
|
|
||||||
nodes = [episode, alice_node, bob_node]
|
nodes = [episode, alice_node, bob_node]
|
||||||
edges = [episodic_edge_1, episodic_edge_2, entity_edge]
|
edges = [episodic_edge_1, episodic_edge_2, entity_edge]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue