filter out falsey values
This commit is contained in:
parent
443f972f45
commit
dd0d42855e
4 changed files with 33 additions and 25 deletions
|
|
@ -646,6 +646,8 @@ def get_community_edge_from_record(record: Any):
|
||||||
async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
|
async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
|
||||||
if len(edges) == 0:
|
if len(edges) == 0:
|
||||||
return
|
return
|
||||||
fact_embeddings = await embedder.create_batch([edge.fact for edge in edges])
|
# filter out falsey values from edges
|
||||||
|
filtered_edges = [edge for edge in edges if edge.fact]
|
||||||
|
fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges])
|
||||||
for edge, fact_embedding in zip(edges, fact_embeddings, strict=True):
|
for edge, fact_embedding in zip(edges, fact_embeddings, strict=True):
|
||||||
edge.fact_embedding = fact_embedding
|
edge.fact_embedding = fact_embedding
|
||||||
|
|
|
||||||
|
|
@ -96,7 +96,8 @@ class Node(BaseModel, ABC):
|
||||||
created_at: datetime = Field(default_factory=lambda: utc_now())
|
created_at: datetime = Field(default_factory=lambda: utc_now())
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def save(self, driver: GraphDriver): ...
|
async def save(self, driver: GraphDriver):
|
||||||
|
...
|
||||||
|
|
||||||
async def delete(self, driver: GraphDriver):
|
async def delete(self, driver: GraphDriver):
|
||||||
match driver.provider:
|
match driver.provider:
|
||||||
|
|
@ -344,10 +345,12 @@ class Node(BaseModel, ABC):
|
||||||
await driver.aoss_client.bulk(body=actions)
|
await driver.aoss_client.bulk(body=actions)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||||
|
...
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class EpisodicNode(Node):
|
class EpisodicNode(Node):
|
||||||
|
|
@ -435,11 +438,11 @@ class EpisodicNode(Node):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(
|
async def get_by_group_ids(
|
||||||
cls,
|
cls,
|
||||||
driver: GraphDriver,
|
driver: GraphDriver,
|
||||||
group_ids: list[str],
|
group_ids: list[str],
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
uuid_cursor: str | None = None,
|
uuid_cursor: str | None = None,
|
||||||
):
|
):
|
||||||
cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
|
cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
|
||||||
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
||||||
|
|
@ -569,7 +572,8 @@ class EntityNode(Node):
|
||||||
labels = ':'.join(self.labels + ['Entity'])
|
labels = ':'.join(self.labels + ['Entity'])
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
await driver.save_to_aoss(ENTITY_INDEX_NAME,
|
||||||
|
[entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||||
|
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
|
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
|
||||||
|
|
@ -618,12 +622,12 @@ class EntityNode(Node):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(
|
async def get_by_group_ids(
|
||||||
cls,
|
cls,
|
||||||
driver: GraphDriver,
|
driver: GraphDriver,
|
||||||
group_ids: list[str],
|
group_ids: list[str],
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
uuid_cursor: str | None = None,
|
uuid_cursor: str | None = None,
|
||||||
with_embeddings: bool = False,
|
with_embeddings: bool = False,
|
||||||
):
|
):
|
||||||
cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
|
cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
|
||||||
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
||||||
|
|
@ -763,11 +767,11 @@ class CommunityNode(Node):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(
|
async def get_by_group_ids(
|
||||||
cls,
|
cls,
|
||||||
driver: GraphDriver,
|
driver: GraphDriver,
|
||||||
group_ids: list[str],
|
group_ids: list[str],
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
uuid_cursor: str | None = None,
|
uuid_cursor: str | None = None,
|
||||||
):
|
):
|
||||||
cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
|
cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
|
||||||
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
||||||
|
|
@ -871,6 +875,8 @@ async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[En
|
||||||
if not nodes: # Handle empty list case
|
if not nodes: # Handle empty list case
|
||||||
return
|
return
|
||||||
|
|
||||||
name_embeddings = await embedder.create_batch([node.name for node in nodes])
|
# filter out falsey values from nodes
|
||||||
|
filtered_nodes = [node for node in nodes if node.name]
|
||||||
|
name_embeddings = await embedder.create_batch([node.name for node in filtered_nodes])
|
||||||
for node, name_embedding in zip(nodes, name_embeddings, strict=True):
|
for node, name_embedding in zip(nodes, name_embeddings, strict=True):
|
||||||
node.name_embedding = name_embedding
|
node.name_embedding = name_embedding
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.21.0pre10"
|
version = "0.21.0pre11"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||||
|
|
|
||||||
4
uv.lock
generated
4
uv.lock
generated
|
|
@ -1,5 +1,5 @@
|
||||||
version = 1
|
version = 1
|
||||||
revision = 3
|
revision = 2
|
||||||
requires-python = ">=3.10, <4"
|
requires-python = ">=3.10, <4"
|
||||||
resolution-markers = [
|
resolution-markers = [
|
||||||
"python_full_version >= '3.14'",
|
"python_full_version >= '3.14'",
|
||||||
|
|
@ -783,7 +783,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.21.0rc8"
|
version = "0.21.0rc11"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "diskcache" },
|
{ name = "diskcache" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue