update new names with input_data (#204)
This commit is contained in:
parent
7bb0c78d5d
commit
63a1b11142
6 changed files with 7 additions and 7 deletions
|
|
@ -180,7 +180,7 @@ class EntityEdge(Edge):
|
|||
start = time()
|
||||
|
||||
text = self.fact.replace('\n', ' ')
|
||||
self.fact_embedding = await embedder.create(input=[text])
|
||||
self.fact_embedding = await embedder.create(input_data=[text])
|
||||
|
||||
end = time()
|
||||
logger.debug(f'embedded {text} in {end - start} ms')
|
||||
|
|
|
|||
|
|
@ -29,6 +29,6 @@ class EmbedderConfig(BaseModel):
|
|||
class EmbedderClient(ABC):
|
||||
@abstractmethod
|
||||
async def create(
|
||||
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
||||
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
||||
) -> list[float]:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ async def generate_embedding(embedder: EmbedderClient, text: str):
|
|||
start = time()
|
||||
|
||||
text = text.replace('\n', ' ')
|
||||
embedding = await embedder.create(input=[text])
|
||||
embedding = await embedder.create(input_data=[text])
|
||||
|
||||
end = time()
|
||||
logger.debug(f'embedded text of length {len(text)} in {end - start} ms')
|
||||
|
|
|
|||
|
|
@ -222,7 +222,7 @@ class EntityNode(Node):
|
|||
async def generate_name_embedding(self, embedder: EmbedderClient):
|
||||
start = time()
|
||||
text = self.name.replace('\n', ' ')
|
||||
self.name_embedding = await embedder.create(input=[text])
|
||||
self.name_embedding = await embedder.create(input_data=[text])
|
||||
end = time()
|
||||
logger.debug(f'embedded {text} in {end - start} ms')
|
||||
|
||||
|
|
@ -334,7 +334,7 @@ class CommunityNode(Node):
|
|||
async def generate_name_embedding(self, embedder: EmbedderClient):
|
||||
start = time()
|
||||
text = self.name.replace('\n', ' ')
|
||||
self.name_embedding = await embedder.create(input=[text])
|
||||
self.name_embedding = await embedder.create(input_data=[text])
|
||||
end = time()
|
||||
logger.debug(f'embedded {text} in {end - start} ms')
|
||||
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ async def search(
|
|||
bfs_origin_node_uuids: list[str] | None = None,
|
||||
) -> SearchResults:
|
||||
start = time()
|
||||
query_vector = await embedder.create(input=[query.replace('\n', ' ')])
|
||||
query_vector = await embedder.create(input_data=[query.replace('\n', ' ')])
|
||||
|
||||
# if group_ids is empty, set it to None
|
||||
group_ids = group_ids if group_ids else None
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "graphiti-core"
|
||||
version = "0.3.18"
|
||||
version = "0.3.19"
|
||||
description = "A temporal graph building library"
|
||||
authors = [
|
||||
"Paul Paliychuk <paul@getzep.com>",
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue