Return embeddings option in get_by_uuids (#736)

* add with_embeddings option

* update
This commit is contained in:
Preston Rasmussen 2025-07-16 11:09:10 -04:00 committed by GitHub
parent 5d45d71259
commit 748464dfa5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 37 additions and 15 deletions

View file

@ -50,8 +50,7 @@ ENTITY_EDGE_RETURN: LiteralString = """
e.expired_at AS expired_at, e.expired_at AS expired_at,
e.valid_at AS valid_at, e.valid_at AS valid_at,
e.invalid_at AS invalid_at, e.invalid_at AS invalid_at,
properties(e) AS attributes properties(e) AS attributes"""
"""
class Edge(BaseModel, ABC): class Edge(BaseModel, ABC):
@ -303,21 +302,34 @@ class EntityEdge(Edge):
group_ids: list[str], group_ids: list[str],
limit: int | None = None, limit: int | None = None,
uuid_cursor: str | None = None, uuid_cursor: str | None = None,
with_embeddings: bool = False,
): ):
cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else '' cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else '' limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
with_embeddings_query: LiteralString = (
""",
e.fact_embedding AS fact_embedding
"""
if with_embeddings
else ''
)
records, _, _ = await driver.execute_query( query: LiteralString = (
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
""" """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
"""
+ cursor_query + cursor_query
+ ENTITY_EDGE_RETURN + ENTITY_EDGE_RETURN
+ with_embeddings_query
+ """ + """
ORDER BY e.uuid DESC ORDER BY e.uuid DESC
""" """
+ limit_query, + limit_query
)
records, _, _ = await driver.execute_query(
query,
group_ids=group_ids, group_ids=group_ids,
uuid=uuid_cursor, uuid=uuid_cursor,
limit=limit, limit=limit,
@ -334,8 +346,8 @@ class EntityEdge(Edge):
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str): async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
query: LiteralString = ( query: LiteralString = (
""" """
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
""" """
+ ENTITY_EDGE_RETURN + ENTITY_EDGE_RETURN
) )
records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r') records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r')
@ -456,6 +468,7 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
source_node_uuid=record['source_node_uuid'], source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'], target_node_uuid=record['target_node_uuid'],
fact=record['fact'], fact=record['fact'],
fact_embedding=record.get('fact_embedding'),
name=record['name'], name=record['name'],
group_id=record['group_id'], group_id=record['group_id'],
episodes=record['episodes'], episodes=record['episodes'],

View file

@ -46,8 +46,7 @@ ENTITY_NODE_RETURN: LiteralString = """
n.created_at AS created_at, n.created_at AS created_at,
n.summary AS summary, n.summary AS summary,
labels(n) AS labels, labels(n) AS labels,
properties(n) AS attributes properties(n) AS attributes"""
"""
class EpisodeType(Enum): class EpisodeType(Enum):
@ -335,8 +334,8 @@ class EntityNode(Node):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
query = ( query = (
""" """
MATCH (n:Entity {uuid: $uuid}) MATCH (n:Entity {uuid: $uuid})
""" """
+ ENTITY_NODE_RETURN + ENTITY_NODE_RETURN
) )
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
@ -374,9 +373,17 @@ class EntityNode(Node):
group_ids: list[str], group_ids: list[str],
limit: int | None = None, limit: int | None = None,
uuid_cursor: str | None = None, uuid_cursor: str | None = None,
with_embeddings: bool = False,
): ):
cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else '' cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else '' limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
with_embeddings_query: LiteralString = (
""",
n.name_embedding AS name_embedding
"""
if with_embeddings
else ''
)
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
@ -384,6 +391,7 @@ class EntityNode(Node):
""" """
+ cursor_query + cursor_query
+ ENTITY_NODE_RETURN + ENTITY_NODE_RETURN
+ with_embeddings_query
+ """ + """
ORDER BY n.uuid DESC ORDER BY n.uuid DESC
""" """
@ -546,6 +554,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
entity_node = EntityNode( entity_node = EntityNode(
uuid=record['uuid'], uuid=record['uuid'],
name=record['name'], name=record['name'],
name_embedding=record.get('name_embedding'),
group_id=record['group_id'], group_id=record['group_id'],
labels=record['labels'], labels=record['labels'],
created_at=parse_db_date(record['created_at']), # type: ignore created_at=parse_db_date(record['created_at']), # type: ignore

View file

@ -1,7 +1,7 @@
[project] [project]
name = "graphiti-core" name = "graphiti-core"
description = "A temporal graph building library" description = "A temporal graph building library"
version = "0.17.4" version = "0.17.5"
authors = [ authors = [
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" }, { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" }, { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },

2
uv.lock generated
View file

@ -746,7 +746,7 @@ wheels = [
[[package]] [[package]]
name = "graphiti-core" name = "graphiti-core"
version = "0.17.4" version = "0.17.5"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "diskcache" }, { name = "diskcache" },