feat: Dedicated embedder interface (#159)
* feat: Add Embedder interface and implement OpenAI embedder
* feat: Add Voyage AI embedder
This commit is contained in:
parent
790c37de38
commit
a7148d6260
18 changed files with 182 additions and 102 deletions
|
|
@ -24,9 +24,9 @@ from uuid import uuid4
|
|||
from neo4j import AsyncDriver
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphiti_core.embedder import EmbedderClient
|
||||
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
||||
from graphiti_core.helpers import parse_db_date
|
||||
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
|
||||
from graphiti_core.nodes import Node
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -171,17 +171,16 @@ class EntityEdge(Edge):
|
|||
default=None, description='datetime of when the fact stopped being true'
|
||||
)
|
||||
|
||||
async def generate_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
|
||||
async def generate_embedding(self, embedder: EmbedderClient):
|
||||
start = time()
|
||||
|
||||
text = self.fact.replace('\n', ' ')
|
||||
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
|
||||
self.fact_embedding = embedding[:EMBEDDING_DIM]
|
||||
self.fact_embedding = await embedder.create(input=[text])
|
||||
|
||||
end = time()
|
||||
logger.info(f'embedded {text} in {end - start} ms')
|
||||
|
||||
return embedding
|
||||
return self.fact_embedding
|
||||
|
||||
async def save(self, driver: AsyncDriver):
|
||||
result = await driver.execute_query(
|
||||
|
|
|
|||
4
graphiti_core/embedder/__init__.py
Normal file
4
graphiti_core/embedder/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
"""Public API of the ``graphiti_core.embedder`` package.

Re-exports the abstract ``EmbedderClient`` interface and the default
OpenAI implementation.

NOTE(review): ``VoyageAIEmbedder`` (voyage.py) is deliberately not
re-exported here — presumably to avoid a hard import-time dependency on
the optional ``voyageai`` package; confirm before adding it to __all__.
"""

from .client import EmbedderClient
from .openai import OpenAIEmbedder, OpenAIEmbedderConfig

__all__ = ['EmbedderClient', 'OpenAIEmbedder', 'OpenAIEmbedderConfig']
|
||||
34
graphiti_core/embedder/client.py
Normal file
34
graphiti_core/embedder/client.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Iterable, List, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Output dimensionality shared by every embedder implementation; results
# are truncated to this length by the concrete clients.
EMBEDDING_DIM = 1024


class EmbedderConfig(BaseModel):
    """Base configuration common to all embedder clients.

    ``embedding_dim`` is pinned: ``Literal[1024]`` rejects any other value
    and ``frozen=True`` makes the field immutable after construction, so
    every subclass config carries the same fixed dimension.
    """

    embedding_dim: Literal[1024] = Field(default=EMBEDDING_DIM, frozen=True)
|
||||
|
||||
|
||||
class EmbedderClient(ABC):
|
||||
@abstractmethod
|
||||
async def create(
|
||||
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
||||
) -> list[float]:
|
||||
pass
|
||||
48
graphiti_core/embedder/openai.py
Normal file
48
graphiti_core/embedder/openai.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Iterable, List
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types import EmbeddingModel
|
||||
|
||||
from .client import EmbedderClient, EmbedderConfig
|
||||
|
||||
# Default OpenAI embedding model used when no override is configured.
DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small'


class OpenAIEmbedderConfig(EmbedderConfig):
    """Configuration for :class:`OpenAIEmbedder`."""

    # ``EmbeddingModel`` is the SDK's literal of known model names; the
    # ``str`` alternative permits custom/proxy model identifiers.
    embedding_model: EmbeddingModel | str = DEFAULT_EMBEDDING_MODEL
    # NOTE(review): when None these are passed straight to AsyncOpenAI —
    # presumably falling back to the SDK's env-var defaults; confirm.
    api_key: str | None = None
    base_url: str | None = None
|
||||
|
||||
|
||||
class OpenAIEmbedder(EmbedderClient):
    """
    OpenAI Embedder Client
    """

    def __init__(self, config: OpenAIEmbedderConfig | None = None):
        """Create the client, using a default config when none is given."""
        self.config = OpenAIEmbedderConfig() if config is None else config
        self.client = AsyncOpenAI(
            api_key=self.config.api_key, base_url=self.config.base_url
        )

    async def create(
        self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
    ) -> list[float]:
        """Embed ``input`` via the OpenAI embeddings endpoint.

        Returns only the first embedding from the response, truncated to
        the configured dimension — callers pass single-item lists.
        """
        response = await self.client.embeddings.create(
            input=input, model=self.config.embedding_model
        )
        first_embedding = response.data[0].embedding
        return first_embedding[: self.config.embedding_dim]
|
||||
47
graphiti_core/embedder/voyage.py
Normal file
47
graphiti_core/embedder/voyage.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Iterable, List
|
||||
|
||||
import voyageai # type: ignore
|
||||
from pydantic import Field
|
||||
|
||||
from .client import EmbedderClient, EmbedderConfig
|
||||
|
||||
# Default Voyage AI embedding model used when no override is configured.
DEFAULT_EMBEDDING_MODEL = 'voyage-3'


class VoyageAIEmbedderConfig(EmbedderConfig):
    """Configuration for :class:`VoyageAIEmbedder`."""

    embedding_model: str = Field(default=DEFAULT_EMBEDDING_MODEL)
    # NOTE(review): passed to voyageai.AsyncClient; presumably falls back
    # to the VOYAGE_API_KEY env var when None — confirm.
    api_key: str | None = None
|
||||
|
||||
|
||||
class VoyageAIEmbedder(EmbedderClient):
    """
    VoyageAI Embedder Client
    """

    def __init__(self, config: VoyageAIEmbedderConfig | None = None):
        """Create the client, using a default config when none is given."""
        self.config = VoyageAIEmbedderConfig() if config is None else config
        self.client = voyageai.AsyncClient(api_key=self.config.api_key)

    async def create(
        self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
    ) -> list[float]:
        """Embed ``input`` via the Voyage AI API.

        Returns the first embedding of the response, truncated to the
        configured dimension.
        """
        # NOTE(review): ``input`` is forwarded as-is; the Voyage client's
        # embed() expects a list of texts, so a bare str argument may not
        # be accepted — confirm against the voyageai SDK.
        response = await self.client.embed(input, model=self.config.embedding_model)
        return response.embeddings[0][: self.config.embedding_dim]
|
||||
|
|
@ -23,6 +23,7 @@ from dotenv import load_dotenv
|
|||
from neo4j import AsyncGraphDatabase
|
||||
|
||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
||||
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.search.search import SearchConfig, search
|
||||
|
|
@ -83,6 +84,7 @@ class Graphiti:
|
|||
user: str,
|
||||
password: str,
|
||||
llm_client: LLMClient | None = None,
|
||||
embedder: EmbedderClient | None = None,
|
||||
store_raw_episode_content: bool = True,
|
||||
):
|
||||
"""
|
||||
|
|
@ -128,6 +130,10 @@ class Graphiti:
|
|||
self.llm_client = llm_client
|
||||
else:
|
||||
self.llm_client = OpenAIClient()
|
||||
if embedder:
|
||||
self.embedder = embedder
|
||||
else:
|
||||
self.embedder = OpenAIEmbedder()
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
|
|
@ -290,7 +296,6 @@ class Graphiti:
|
|||
start = time()
|
||||
|
||||
entity_edges: list[EntityEdge] = []
|
||||
embedder = self.llm_client.get_embedder()
|
||||
now = datetime.now()
|
||||
|
||||
previous_episodes = await self.retrieve_episodes(
|
||||
|
|
@ -318,7 +323,7 @@ class Graphiti:
|
|||
# Calculate Embeddings
|
||||
|
||||
await asyncio.gather(
|
||||
*[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in extracted_nodes]
|
||||
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
|
||||
)
|
||||
|
||||
# Resolve extracted nodes with nodes already in the graph and extract facts
|
||||
|
|
@ -346,7 +351,7 @@ class Graphiti:
|
|||
# calculate embeddings
|
||||
await asyncio.gather(
|
||||
*[
|
||||
edge.generate_embedding(embedder, self.llm_client.embedding_model)
|
||||
edge.generate_embedding(self.embedder)
|
||||
for edge in extracted_edges_with_resolved_pointers
|
||||
]
|
||||
)
|
||||
|
|
@ -439,7 +444,7 @@ class Graphiti:
|
|||
if update_communities:
|
||||
await asyncio.gather(
|
||||
*[
|
||||
update_community(self.driver, self.llm_client, embedder, node)
|
||||
update_community(self.driver, self.llm_client, self.embedder, node)
|
||||
for node in nodes
|
||||
]
|
||||
)
|
||||
|
|
@ -488,7 +493,6 @@ class Graphiti:
|
|||
"""
|
||||
try:
|
||||
start = time()
|
||||
embedder = self.llm_client.get_embedder()
|
||||
now = datetime.now()
|
||||
|
||||
episodes = [
|
||||
|
|
@ -520,8 +524,8 @@ class Graphiti:
|
|||
|
||||
# Generate embeddings
|
||||
await asyncio.gather(
|
||||
*[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in extracted_nodes],
|
||||
*[edge.generate_embedding(embedder, self.llm_client.embedding_model) for edge in extracted_edges],
|
||||
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
|
||||
*[edge.generate_embedding(self.embedder) for edge in extracted_edges],
|
||||
)
|
||||
|
||||
# Dedupe extracted nodes, compress extracted edges
|
||||
|
|
@ -564,14 +568,14 @@ class Graphiti:
|
|||
raise e
|
||||
|
||||
async def build_communities(self):
|
||||
embedder = self.llm_client.get_embedder()
|
||||
|
||||
# Clear existing communities
|
||||
await remove_communities(self.driver)
|
||||
|
||||
community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
|
||||
|
||||
await asyncio.gather(*[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in community_nodes])
|
||||
await asyncio.gather(
|
||||
*[node.generate_name_embedding(self.embedder) for node in community_nodes]
|
||||
)
|
||||
|
||||
await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
|
||||
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
|
||||
|
|
@ -618,12 +622,11 @@ class Graphiti:
|
|||
EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
|
||||
)
|
||||
search_config.limit = num_results
|
||||
search_config.embedding_model = self.llm_client.embedding_model
|
||||
|
||||
edges = (
|
||||
await search(
|
||||
self.driver,
|
||||
self.llm_client.get_embedder(),
|
||||
self.embedder,
|
||||
query,
|
||||
group_ids,
|
||||
search_config,
|
||||
|
|
@ -640,9 +643,7 @@ class Graphiti:
|
|||
group_ids: list[str] | None = None,
|
||||
center_node_uuid: str | None = None,
|
||||
) -> SearchResults:
|
||||
return await search(
|
||||
self.driver, self.llm_client.get_embedder(), query, group_ids, config, center_node_uuid
|
||||
)
|
||||
return await search(self.driver, self.embedder, query, group_ids, config, center_node_uuid)
|
||||
|
||||
async def get_nodes_by_query(
|
||||
self,
|
||||
|
|
@ -687,14 +688,15 @@ class Graphiti:
|
|||
to each individual search method before results are combined and deduplicated.
|
||||
If not specified, a default limit (defined in the search functions) will be used.
|
||||
"""
|
||||
embedder = self.llm_client.get_embedder()
|
||||
search_config = (
|
||||
NODE_HYBRID_SEARCH_RRF if center_node_uuid is None else NODE_HYBRID_SEARCH_NODE_DISTANCE
|
||||
)
|
||||
search_config.limit = limit
|
||||
|
||||
nodes = (
|
||||
await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
|
||||
await search(
|
||||
self.driver, self.embedder, query, group_ids, search_config, center_node_uuid
|
||||
)
|
||||
).nodes
|
||||
return nodes
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ import typing
|
|||
|
||||
import anthropic
|
||||
from anthropic import AsyncAnthropic
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from ..prompts.models import Message
|
||||
from .client import LLMClient
|
||||
|
|
@ -47,10 +46,6 @@ class AnthropicClient(LLMClient):
|
|||
max_retries=1,
|
||||
)
|
||||
|
||||
def get_embedder(self) -> typing.Any:
|
||||
openai_client = AsyncOpenAI()
|
||||
return openai_client.embeddings
|
||||
|
||||
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
|
||||
system_message = messages[0]
|
||||
user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
|
||||
|
|
|
|||
|
|
@ -54,11 +54,6 @@ class LLMClient(ABC):
|
|||
self.max_tokens = config.max_tokens
|
||||
self.cache_enabled = cache
|
||||
self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory
|
||||
self.embedding_model = config.embedding_model
|
||||
|
||||
@abstractmethod
|
||||
def get_embedder(self) -> typing.Any:
|
||||
pass
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(4),
|
||||
|
|
|
|||
|
|
@ -14,10 +14,9 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
|
||||
EMBEDDING_DIM = 1024
|
||||
DEFAULT_MAX_TOKENS = 16384
|
||||
DEFAULT_TEMPERATURE = 0
|
||||
DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small'
|
||||
|
||||
|
||||
class LLMConfig:
|
||||
"""
|
||||
|
|
@ -33,7 +32,6 @@ class LLMConfig:
|
|||
api_key: str | None = None,
|
||||
model: str | None = None,
|
||||
base_url: str | None = None,
|
||||
embedding_model: str | None = DEFAULT_EMBEDDING_MODEL,
|
||||
temperature: float = DEFAULT_TEMPERATURE,
|
||||
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||
):
|
||||
|
|
@ -51,15 +49,9 @@ class LLMConfig:
|
|||
base_url (str, optional): The base URL of the LLM API service.
|
||||
Defaults to "https://api.openai.com", which is OpenAI's standard API endpoint.
|
||||
This can be changed if using a different provider or a custom endpoint.
|
||||
embedding_model (str, optional): The specific embedding model.
|
||||
Defaults to openai "text-embedding-3-small" model.
|
||||
We currently only support openai embedding models, such as "text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002".
|
||||
"""
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
if not embedding_model:
|
||||
embedding_model = DEFAULT_EMBEDDING_MODEL
|
||||
self.embedding_model = embedding_model
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ import typing
|
|||
import groq
|
||||
from groq import AsyncGroq
|
||||
from groq.types.chat import ChatCompletionMessageParam
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from ..prompts.models import Message
|
||||
from .client import LLMClient
|
||||
|
|
@ -44,10 +43,6 @@ class GroqClient(LLMClient):
|
|||
|
||||
self.client = AsyncGroq(api_key=config.api_key)
|
||||
|
||||
def get_embedder(self) -> typing.Any:
|
||||
openai_client = AsyncOpenAI()
|
||||
return openai_client.embeddings
|
||||
|
||||
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
|
||||
msgs: list[ChatCompletionMessageParam] = []
|
||||
for m in messages:
|
||||
|
|
|
|||
|
|
@ -49,9 +49,6 @@ class OpenAIClient(LLMClient):
|
|||
__init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
|
||||
Initializes the OpenAIClient with the provided configuration, cache setting, and client.
|
||||
|
||||
get_embedder() -> typing.Any:
|
||||
Returns the embedder from the OpenAI client.
|
||||
|
||||
_generate_response(messages: list[Message]) -> dict[str, typing.Any]:
|
||||
Generates a response from the language model based on the provided messages.
|
||||
"""
|
||||
|
|
@ -78,9 +75,6 @@ class OpenAIClient(LLMClient):
|
|||
else:
|
||||
self.client = client
|
||||
|
||||
def get_embedder(self) -> typing.Any:
|
||||
return self.client.embeddings
|
||||
|
||||
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
|
||||
openai_messages: list[ChatCompletionMessageParam] = []
|
||||
for m in messages:
|
||||
|
|
|
|||
|
|
@ -15,22 +15,18 @@ limitations under the License.
|
|||
"""
|
||||
|
||||
import logging
|
||||
import typing
|
||||
from time import time
|
||||
|
||||
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
|
||||
from graphiti_core.embedder.client import EmbedderClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def generate_embedding(
|
||||
embedder: typing.Any, text: str, model: str = DEFAULT_EMBEDDING_MODEL
|
||||
):
|
||||
async def generate_embedding(embedder: EmbedderClient, text: str):
|
||||
start = time()
|
||||
|
||||
text = text.replace('\n', ' ')
|
||||
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
|
||||
embedding = embedding[:EMBEDDING_DIM]
|
||||
embedding = await embedder.create(input=[text])
|
||||
|
||||
end = time()
|
||||
logger.debug(f'embedded text of length {len(text)} in {end - start} ms')
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ from uuid import uuid4
|
|||
from neo4j import AsyncDriver
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphiti_core.embedder import EmbedderClient
|
||||
from graphiti_core.errors import NodeNotFoundError
|
||||
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -212,15 +212,14 @@ class EntityNode(Node):
|
|||
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
||||
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
|
||||
|
||||
async def generate_name_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
|
||||
async def generate_name_embedding(self, embedder: EmbedderClient):
|
||||
start = time()
|
||||
text = self.name.replace('\n', ' ')
|
||||
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
|
||||
self.name_embedding = embedding[:EMBEDDING_DIM]
|
||||
self.name_embedding = await embedder.create(input=[text])
|
||||
end = time()
|
||||
logger.info(f'embedded {text} in {end - start} ms')
|
||||
|
||||
return embedding
|
||||
return self.name_embedding
|
||||
|
||||
async def save(self, driver: AsyncDriver):
|
||||
result = await driver.execute_query(
|
||||
|
|
@ -323,15 +322,14 @@ class CommunityNode(Node):
|
|||
|
||||
return result
|
||||
|
||||
async def generate_name_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
|
||||
async def generate_name_embedding(self, embedder: EmbedderClient):
|
||||
start = time()
|
||||
text = self.name.replace('\n', ' ')
|
||||
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
|
||||
self.name_embedding = embedding[:EMBEDDING_DIM]
|
||||
self.name_embedding = await embedder.create(input=[text])
|
||||
end = time()
|
||||
logger.info(f'embedded {text} in {end - start} ms')
|
||||
|
||||
return embedding
|
||||
return self.name_embedding
|
||||
|
||||
@classmethod
|
||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||
|
|
|
|||
|
|
@ -22,8 +22,8 @@ from time import time
|
|||
from neo4j import AsyncDriver
|
||||
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.embedder import EmbedderClient
|
||||
from graphiti_core.errors import SearchRerankerError
|
||||
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode
|
||||
from graphiti_core.search.search_config import (
|
||||
DEFAULT_SEARCH_LIMIT,
|
||||
|
|
@ -56,7 +56,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
async def search(
|
||||
driver: AsyncDriver,
|
||||
embedder,
|
||||
embedder: EmbedderClient,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: SearchConfig,
|
||||
|
|
@ -75,7 +75,6 @@ async def search(
|
|||
config.edge_config,
|
||||
center_node_uuid,
|
||||
config.limit,
|
||||
config.embedding_model,
|
||||
),
|
||||
node_search(
|
||||
driver,
|
||||
|
|
@ -85,7 +84,6 @@ async def search(
|
|||
config.node_config,
|
||||
center_node_uuid,
|
||||
config.limit,
|
||||
config.embedding_model,
|
||||
),
|
||||
community_search(
|
||||
driver,
|
||||
|
|
@ -94,7 +92,6 @@ async def search(
|
|||
group_ids,
|
||||
config.community_config,
|
||||
config.limit,
|
||||
config.embedding_model,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -113,13 +110,12 @@ async def search(
|
|||
|
||||
async def edge_search(
|
||||
driver: AsyncDriver,
|
||||
embedder,
|
||||
embedder: EmbedderClient,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: EdgeSearchConfig | None,
|
||||
center_node_uuid: str | None = None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
embedding_model: str | None = None,
|
||||
) -> list[EntityEdge]:
|
||||
if config is None:
|
||||
return []
|
||||
|
|
@ -131,11 +127,7 @@ async def edge_search(
|
|||
search_results.append(text_search)
|
||||
|
||||
if EdgeSearchMethod.cosine_similarity in config.search_methods:
|
||||
search_vector = (
|
||||
(await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
|
||||
.data[0]
|
||||
.embedding[:EMBEDDING_DIM]
|
||||
)
|
||||
search_vector = await embedder.create(input=[query])
|
||||
|
||||
similarity_search = await edge_similarity_search(
|
||||
driver, search_vector, None, None, group_ids, 2 * limit
|
||||
|
|
@ -182,13 +174,12 @@ async def edge_search(
|
|||
|
||||
async def node_search(
|
||||
driver: AsyncDriver,
|
||||
embedder,
|
||||
embedder: EmbedderClient,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: NodeSearchConfig | None,
|
||||
center_node_uuid: str | None = None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
embedding_model: str | None = None,
|
||||
) -> list[EntityNode]:
|
||||
if config is None:
|
||||
return []
|
||||
|
|
@ -200,11 +191,7 @@ async def node_search(
|
|||
search_results.append(text_search)
|
||||
|
||||
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
||||
search_vector = (
|
||||
(await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
|
||||
.data[0]
|
||||
.embedding[:EMBEDDING_DIM]
|
||||
)
|
||||
search_vector = await embedder.create(input=[query])
|
||||
|
||||
similarity_search = await node_similarity_search(
|
||||
driver, search_vector, group_ids, 2 * limit
|
||||
|
|
@ -236,12 +223,11 @@ async def node_search(
|
|||
|
||||
async def community_search(
|
||||
driver: AsyncDriver,
|
||||
embedder,
|
||||
embedder: EmbedderClient,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: CommunitySearchConfig | None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
embedding_model: str | None = None,
|
||||
) -> list[CommunityNode]:
|
||||
if config is None:
|
||||
return []
|
||||
|
|
@ -253,11 +239,7 @@ async def community_search(
|
|||
search_results.append(text_search)
|
||||
|
||||
if CommunitySearchMethod.cosine_similarity in config.search_methods:
|
||||
search_vector = (
|
||||
(await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
|
||||
.data[0]
|
||||
.embedding[:EMBEDDING_DIM]
|
||||
)
|
||||
search_vector = await embedder.create(input=[query])
|
||||
|
||||
similarity_search = await community_similarity_search(
|
||||
driver, search_vector, group_ids, 2 * limit
|
||||
|
|
|
|||
|
|
@ -74,7 +74,6 @@ class SearchConfig(BaseModel):
|
|||
edge_config: EdgeSearchConfig | None = Field(default=None)
|
||||
node_config: NodeSearchConfig | None = Field(default=None)
|
||||
community_config: CommunitySearchConfig | None = Field(default=None)
|
||||
embedding_model: str | None = Field(default=None)
|
||||
limit: int = Field(default=DEFAULT_SEARCH_LIMIT)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from neo4j import AsyncDriver
|
|||
from pydantic import BaseModel
|
||||
|
||||
from graphiti_core.edges import CommunityEdge
|
||||
from graphiti_core.embedder import EmbedderClient
|
||||
from graphiti_core.llm_client import LLMClient
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
|
||||
from graphiti_core.prompts import prompt_library
|
||||
|
|
@ -288,7 +289,7 @@ async def determine_entity_community(
|
|||
|
||||
|
||||
async def update_community(
|
||||
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
|
||||
driver: AsyncDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
|
||||
):
|
||||
community, is_new = await determine_entity_community(driver, entity)
|
||||
|
||||
|
|
@ -305,6 +306,6 @@ async def update_community(
|
|||
community_edge = (build_community_edges([entity], community, datetime.now()))[0]
|
||||
await community_edge.save(driver)
|
||||
|
||||
await community.generate_name_embedding(embedder, llm_client.embedding_model)
|
||||
await community.generate_name_embedding(embedder)
|
||||
|
||||
await community.save(driver)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class ZepGraphiti(Graphiti):
|
|||
group_id=group_id,
|
||||
summary=summary,
|
||||
)
|
||||
await new_node.generate_name_embedding(self.llm_client.get_embedder(), self.llm_client.embedding_model)
|
||||
await new_node.generate_name_embedding(self.embedder)
|
||||
await new_node.save(self.driver)
|
||||
return new_node
|
||||
|
||||
|
|
@ -83,8 +83,7 @@ async def get_graphiti(settings: ZepEnvDep):
|
|||
client.llm_client.config.api_key = settings.openai_api_key
|
||||
if settings.model_name is not None:
|
||||
client.llm_client.model = settings.model_name
|
||||
if settings.embedding_model_name is not None:
|
||||
client.llm_client.embedding_model = settings.embedding_model_name
|
||||
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ async def test_graphiti_init():
|
|||
@pytest.mark.asyncio
|
||||
async def test_graph_integration():
|
||||
client = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||
embedder = client.llm_client.get_embedder()
|
||||
embedder = client.embedder
|
||||
driver = client.driver
|
||||
|
||||
now = datetime.now()
|
||||
|
|
@ -145,7 +145,7 @@ async def test_graph_integration():
|
|||
invalid_at=now,
|
||||
)
|
||||
|
||||
await entity_edge.generate_embedding(embedder, client.llm_client.embedding_model)
|
||||
await entity_edge.generate_embedding(embedder)
|
||||
|
||||
nodes = [episode, alice_node, bob_node]
|
||||
edges = [episodic_edge_1, episodic_edge_2, entity_edge]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue