update new names with input_data (#204)

This commit is contained in:
Preston Rasmussen 2024-10-29 11:03:31 -04:00 committed by GitHub
parent 7bb0c78d5d
commit 63a1b11142
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 7 additions and 7 deletions

View file

@ -180,7 +180,7 @@ class EntityEdge(Edge):
start = time()
text = self.fact.replace('\n', ' ')
self.fact_embedding = await embedder.create(input=[text])
self.fact_embedding = await embedder.create(input_data=[text])
end = time()
logger.debug(f'embedded {text} in {end - start} ms')

View file

@ -29,6 +29,6 @@ class EmbedderConfig(BaseModel):
class EmbedderClient(ABC):
@abstractmethod
async def create(
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]:
pass

View file

@ -26,7 +26,7 @@ async def generate_embedding(embedder: EmbedderClient, text: str):
start = time()
text = text.replace('\n', ' ')
embedding = await embedder.create(input=[text])
embedding = await embedder.create(input_data=[text])
end = time()
logger.debug(f'embedded text of length {len(text)} in {end - start} ms')

View file

@ -222,7 +222,7 @@ class EntityNode(Node):
async def generate_name_embedding(self, embedder: EmbedderClient):
start = time()
text = self.name.replace('\n', ' ')
self.name_embedding = await embedder.create(input=[text])
self.name_embedding = await embedder.create(input_data=[text])
end = time()
logger.debug(f'embedded {text} in {end - start} ms')
@ -334,7 +334,7 @@ class CommunityNode(Node):
async def generate_name_embedding(self, embedder: EmbedderClient):
start = time()
text = self.name.replace('\n', ' ')
self.name_embedding = await embedder.create(input=[text])
self.name_embedding = await embedder.create(input_data=[text])
end = time()
logger.debug(f'embedded {text} in {end - start} ms')

View file

@ -66,7 +66,7 @@ async def search(
bfs_origin_node_uuids: list[str] | None = None,
) -> SearchResults:
start = time()
query_vector = await embedder.create(input=[query.replace('\n', ' ')])
query_vector = await embedder.create(input_data=[query.replace('\n', ' ')])
# if group_ids is empty, set it to None
group_ids = group_ids if group_ids else None

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.3.18"
version = "0.3.19"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <paul@getzep.com>",