update
This commit is contained in:
parent
dd0d42855e
commit
c5ec136b1d
2 changed files with 22 additions and 26 deletions
|
|
@ -649,5 +649,5 @@ async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[En
|
|||
# filter out falsey values from edges
|
||||
filtered_edges = [edge for edge in edges if edge.fact]
|
||||
fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges])
|
||||
for edge, fact_embedding in zip(edges, fact_embeddings, strict=True):
|
||||
for edge, fact_embedding in zip(filtered_edges, fact_embeddings, strict=True):
|
||||
edge.fact_embedding = fact_embedding
|
||||
|
|
|
|||
|
|
@ -96,8 +96,7 @@ class Node(BaseModel, ABC):
|
|||
created_at: datetime = Field(default_factory=lambda: utc_now())
|
||||
|
||||
@abstractmethod
|
||||
async def save(self, driver: GraphDriver):
|
||||
...
|
||||
async def save(self, driver: GraphDriver): ...
|
||||
|
||||
async def delete(self, driver: GraphDriver):
|
||||
match driver.provider:
|
||||
|
|
@ -345,12 +344,10 @@ class Node(BaseModel, ABC):
|
|||
await driver.aoss_client.bulk(body=actions)
|
||||
|
||||
@classmethod
|
||||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||
...
|
||||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
||||
|
||||
@classmethod
|
||||
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||
...
|
||||
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...
|
||||
|
||||
|
||||
class EpisodicNode(Node):
|
||||
|
|
@ -438,11 +435,11 @@ class EpisodicNode(Node):
|
|||
|
||||
@classmethod
|
||||
async def get_by_group_ids(
|
||||
cls,
|
||||
driver: GraphDriver,
|
||||
group_ids: list[str],
|
||||
limit: int | None = None,
|
||||
uuid_cursor: str | None = None,
|
||||
cls,
|
||||
driver: GraphDriver,
|
||||
group_ids: list[str],
|
||||
limit: int | None = None,
|
||||
uuid_cursor: str | None = None,
|
||||
):
|
||||
cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
|
||||
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
||||
|
|
@ -572,8 +569,7 @@ class EntityNode(Node):
|
|||
labels = ':'.join(self.labels + ['Entity'])
|
||||
|
||||
if driver.aoss_client:
|
||||
await driver.save_to_aoss(ENTITY_INDEX_NAME,
|
||||
[entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
|
||||
result = await driver.execute_query(
|
||||
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
|
||||
|
|
@ -622,12 +618,12 @@ class EntityNode(Node):
|
|||
|
||||
@classmethod
|
||||
async def get_by_group_ids(
|
||||
cls,
|
||||
driver: GraphDriver,
|
||||
group_ids: list[str],
|
||||
limit: int | None = None,
|
||||
uuid_cursor: str | None = None,
|
||||
with_embeddings: bool = False,
|
||||
cls,
|
||||
driver: GraphDriver,
|
||||
group_ids: list[str],
|
||||
limit: int | None = None,
|
||||
uuid_cursor: str | None = None,
|
||||
with_embeddings: bool = False,
|
||||
):
|
||||
cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
|
||||
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
||||
|
|
@ -767,11 +763,11 @@ class CommunityNode(Node):
|
|||
|
||||
@classmethod
|
||||
async def get_by_group_ids(
|
||||
cls,
|
||||
driver: GraphDriver,
|
||||
group_ids: list[str],
|
||||
limit: int | None = None,
|
||||
uuid_cursor: str | None = None,
|
||||
cls,
|
||||
driver: GraphDriver,
|
||||
group_ids: list[str],
|
||||
limit: int | None = None,
|
||||
uuid_cursor: str | None = None,
|
||||
):
|
||||
cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
|
||||
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
||||
|
|
@ -878,5 +874,5 @@ async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[En
|
|||
# filter out falsey values from nodes
|
||||
filtered_nodes = [node for node in nodes if node.name]
|
||||
name_embeddings = await embedder.create_batch([node.name for node in filtered_nodes])
|
||||
for node, name_embedding in zip(nodes, name_embeddings, strict=True):
|
||||
for node, name_embedding in zip(filtered_nodes, name_embeddings, strict=True):
|
||||
node.name_embedding = name_embedding
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue