diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index f251e1dd..cee5b870 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -50,8 +50,7 @@ ENTITY_EDGE_RETURN: LiteralString = """ e.expired_at AS expired_at, e.valid_at AS valid_at, e.invalid_at AS invalid_at, - properties(e) AS attributes - """ + properties(e) AS attributes""" class Edge(BaseModel, ABC): @@ -303,21 +302,34 @@ class EntityEdge(Edge): group_ids: list[str], limit: int | None = None, uuid_cursor: str | None = None, + with_embeddings: bool = False, ): cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else '' limit_query: LiteralString = 'LIMIT $limit' if limit is not None else '' + with_embeddings_query: LiteralString = ( + """, + e.fact_embedding AS fact_embedding + """ + if with_embeddings + else '' + ) - records, _, _ = await driver.execute_query( + query: LiteralString = ( + """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - """ + cursor_query + ENTITY_EDGE_RETURN + + with_embeddings_query + """ ORDER BY e.uuid DESC """ - + limit_query, + + limit_query + ) + + records, _, _ = await driver.execute_query( + query, group_ids=group_ids, uuid=uuid_cursor, limit=limit, @@ -334,8 +346,8 @@ class EntityEdge(Edge): async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str): query: LiteralString = ( """ - MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) - """ + MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) + """ + ENTITY_EDGE_RETURN ) records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r') @@ -456,6 +468,7 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge: source_node_uuid=record['source_node_uuid'], target_node_uuid=record['target_node_uuid'], fact=record['fact'], + fact_embedding=record.get('fact_embedding'), name=record['name'], group_id=record['group_id'], episodes=record['episodes'], diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 95b98ed2..d2fb79ca 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -46,8 +46,7 @@ ENTITY_NODE_RETURN: LiteralString = """ n.created_at AS created_at, n.summary AS summary, labels(n) AS labels, - properties(n) AS attributes - """ + properties(n) AS attributes""" class EpisodeType(Enum): @@ -335,8 +334,8 @@ class EntityNode(Node): async def get_by_uuid(cls, driver: GraphDriver, uuid: str): query = ( """ - MATCH (n:Entity {uuid: $uuid}) - """ + MATCH (n:Entity {uuid: $uuid}) + """ + ENTITY_NODE_RETURN ) records, _, _ = await driver.execute_query( @@ -374,9 +373,17 @@ class EntityNode(Node): group_ids: list[str], limit: int | None = None, uuid_cursor: str | None = None, + with_embeddings: bool = False, ): cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else '' limit_query: LiteralString = 'LIMIT $limit' if limit is not None else '' + with_embeddings_query: LiteralString = ( + """, + n.name_embedding AS name_embedding + """ + if with_embeddings + else '' + ) records, _, _ = await driver.execute_query( """ @@ -384,6 +391,7 @@ class EntityNode(Node): """ + cursor_query + ENTITY_NODE_RETURN + + with_embeddings_query + """ ORDER BY n.uuid DESC """ @@ -546,6 +554,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode: entity_node = EntityNode( uuid=record['uuid'], name=record['name'], + name_embedding=record.get('name_embedding'), group_id=record['group_id'], labels=record['labels'], created_at=parse_db_date(record['created_at']), # type: ignore diff --git a/pyproject.toml b/pyproject.toml index aa75211e..833fcd90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.17.4" +version = "0.17.5" authors = [ { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" }, { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" }, diff --git a/uv.lock b/uv.lock index ca86c214..9842f430 100644 --- a/uv.lock +++ b/uv.lock @@ -746,7 +746,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.17.4" +version = "0.17.5" source = { editable = "." } dependencies = [ { name = "diskcache" },