Return embeddings option in get_by_uuids (#736)

* add with_embeddings option

* update
This commit is contained in:
Preston Rasmussen 2025-07-16 11:09:10 -04:00 committed by GitHub
parent 5d45d71259
commit 748464dfa5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 37 additions and 15 deletions

View file

@ -50,8 +50,7 @@ ENTITY_EDGE_RETURN: LiteralString = """
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at,
properties(e) AS attributes
"""
properties(e) AS attributes"""
class Edge(BaseModel, ABC):
@ -303,21 +302,34 @@ class EntityEdge(Edge):
group_ids: list[str],
limit: int | None = None,
uuid_cursor: str | None = None,
with_embeddings: bool = False,
):
cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
with_embeddings_query: LiteralString = (
""",
e.fact_embedding AS fact_embedding
"""
if with_embeddings
else ''
)
records, _, _ = await driver.execute_query(
query: LiteralString = (
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ ENTITY_EDGE_RETURN
+ with_embeddings_query
+ """
ORDER BY e.uuid DESC
"""
+ limit_query,
+ limit_query
)
records, _, _ = await driver.execute_query(
query,
group_ids=group_ids,
uuid=uuid_cursor,
limit=limit,
@ -334,8 +346,8 @@ class EntityEdge(Edge):
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
query: LiteralString = (
"""
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
"""
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
"""
+ ENTITY_EDGE_RETURN
)
records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r')
@ -456,6 +468,7 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
fact_embedding=record.get('fact_embedding'),
name=record['name'],
group_id=record['group_id'],
episodes=record['episodes'],

View file

@ -46,8 +46,7 @@ ENTITY_NODE_RETURN: LiteralString = """
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
"""
properties(n) AS attributes"""
class EpisodeType(Enum):
@ -335,8 +334,8 @@ class EntityNode(Node):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
query = (
"""
MATCH (n:Entity {uuid: $uuid})
"""
MATCH (n:Entity {uuid: $uuid})
"""
+ ENTITY_NODE_RETURN
)
records, _, _ = await driver.execute_query(
@ -374,9 +373,17 @@ class EntityNode(Node):
group_ids: list[str],
limit: int | None = None,
uuid_cursor: str | None = None,
with_embeddings: bool = False,
):
cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
with_embeddings_query: LiteralString = (
""",
n.name_embedding AS name_embedding
"""
if with_embeddings
else ''
)
records, _, _ = await driver.execute_query(
"""
@ -384,6 +391,7 @@ class EntityNode(Node):
"""
+ cursor_query
+ ENTITY_NODE_RETURN
+ with_embeddings_query
+ """
ORDER BY n.uuid DESC
"""
@ -546,6 +554,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
entity_node = EntityNode(
uuid=record['uuid'],
name=record['name'],
name_embedding=record.get('name_embedding'),
group_id=record['group_id'],
labels=record['labels'],
created_at=parse_db_date(record['created_at']), # type: ignore

View file

@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.17.4"
version = "0.17.5"
authors = [
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },

2
uv.lock generated
View file

@ -746,7 +746,7 @@ wheels = [
[[package]]
name = "graphiti-core"
version = "0.17.4"
version = "0.17.5"
source = { editable = "." }
dependencies = [
{ name = "diskcache" },