Return embeddings option in get_by_uuids (#736)
* add with_embeddings option * update
This commit is contained in:
parent
5d45d71259
commit
748464dfa5
4 changed files with 37 additions and 15 deletions
|
|
@ -50,8 +50,7 @@ ENTITY_EDGE_RETURN: LiteralString = """
|
|||
e.expired_at AS expired_at,
|
||||
e.valid_at AS valid_at,
|
||||
e.invalid_at AS invalid_at,
|
||||
properties(e) AS attributes
|
||||
"""
|
||||
properties(e) AS attributes"""
|
||||
|
||||
|
||||
class Edge(BaseModel, ABC):
|
||||
|
|
@ -303,21 +302,34 @@ class EntityEdge(Edge):
|
|||
group_ids: list[str],
|
||||
limit: int | None = None,
|
||||
uuid_cursor: str | None = None,
|
||||
with_embeddings: bool = False,
|
||||
):
|
||||
cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
|
||||
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
||||
with_embeddings_query: LiteralString = (
|
||||
""",
|
||||
e.fact_embedding AS fact_embedding
|
||||
"""
|
||||
if with_embeddings
|
||||
else ''
|
||||
)
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
query: LiteralString = (
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
"""
|
||||
+ cursor_query
|
||||
+ ENTITY_EDGE_RETURN
|
||||
+ with_embeddings_query
|
||||
+ """
|
||||
ORDER BY e.uuid DESC
|
||||
"""
|
||||
+ limit_query,
|
||||
+ limit_query
|
||||
)
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
group_ids=group_ids,
|
||||
uuid=uuid_cursor,
|
||||
limit=limit,
|
||||
|
|
@ -334,8 +346,8 @@ class EntityEdge(Edge):
|
|||
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
|
||||
query: LiteralString = (
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||
"""
|
||||
+ ENTITY_EDGE_RETURN
|
||||
)
|
||||
records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r')
|
||||
|
|
@ -456,6 +468,7 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|||
source_node_uuid=record['source_node_uuid'],
|
||||
target_node_uuid=record['target_node_uuid'],
|
||||
fact=record['fact'],
|
||||
fact_embedding=record.get('fact_embedding'),
|
||||
name=record['name'],
|
||||
group_id=record['group_id'],
|
||||
episodes=record['episodes'],
|
||||
|
|
|
|||
|
|
@ -46,8 +46,7 @@ ENTITY_NODE_RETURN: LiteralString = """
|
|||
n.created_at AS created_at,
|
||||
n.summary AS summary,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS attributes
|
||||
"""
|
||||
properties(n) AS attributes"""
|
||||
|
||||
|
||||
class EpisodeType(Enum):
|
||||
|
|
@ -335,8 +334,8 @@ class EntityNode(Node):
|
|||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
"""
|
||||
+ ENTITY_NODE_RETURN
|
||||
)
|
||||
records, _, _ = await driver.execute_query(
|
||||
|
|
@ -374,9 +373,17 @@ class EntityNode(Node):
|
|||
group_ids: list[str],
|
||||
limit: int | None = None,
|
||||
uuid_cursor: str | None = None,
|
||||
with_embeddings: bool = False,
|
||||
):
|
||||
cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
|
||||
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
||||
with_embeddings_query: LiteralString = (
|
||||
""",
|
||||
n.name_embedding AS name_embedding
|
||||
"""
|
||||
if with_embeddings
|
||||
else ''
|
||||
)
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
|
|
@ -384,6 +391,7 @@ class EntityNode(Node):
|
|||
"""
|
||||
+ cursor_query
|
||||
+ ENTITY_NODE_RETURN
|
||||
+ with_embeddings_query
|
||||
+ """
|
||||
ORDER BY n.uuid DESC
|
||||
"""
|
||||
|
|
@ -546,6 +554,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
|
|||
entity_node = EntityNode(
|
||||
uuid=record['uuid'],
|
||||
name=record['name'],
|
||||
name_embedding=record.get('name_embedding'),
|
||||
group_id=record['group_id'],
|
||||
labels=record['labels'],
|
||||
created_at=parse_db_date(record['created_at']), # type: ignore
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.17.4"
|
||||
version = "0.17.5"
|
||||
authors = [
|
||||
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
||||
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
||||
|
|
|
|||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -746,7 +746,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "graphiti-core"
|
||||
version = "0.17.4"
|
||||
version = "0.17.5"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "diskcache" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue