Return embeddings option in get_by_uuids (#736)
* add with_embeddings option * update
This commit is contained in:
parent
5d45d71259
commit
748464dfa5
4 changed files with 37 additions and 15 deletions
|
|
@ -50,8 +50,7 @@ ENTITY_EDGE_RETURN: LiteralString = """
|
||||||
e.expired_at AS expired_at,
|
e.expired_at AS expired_at,
|
||||||
e.valid_at AS valid_at,
|
e.valid_at AS valid_at,
|
||||||
e.invalid_at AS invalid_at,
|
e.invalid_at AS invalid_at,
|
||||||
properties(e) AS attributes
|
properties(e) AS attributes"""
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class Edge(BaseModel, ABC):
|
class Edge(BaseModel, ABC):
|
||||||
|
|
@ -303,21 +302,34 @@ class EntityEdge(Edge):
|
||||||
group_ids: list[str],
|
group_ids: list[str],
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
uuid_cursor: str | None = None,
|
uuid_cursor: str | None = None,
|
||||||
|
with_embeddings: bool = False,
|
||||||
):
|
):
|
||||||
cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
|
cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
|
||||||
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
||||||
|
with_embeddings_query: LiteralString = (
|
||||||
|
""",
|
||||||
|
e.fact_embedding AS fact_embedding
|
||||||
|
"""
|
||||||
|
if with_embeddings
|
||||||
|
else ''
|
||||||
|
)
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
query: LiteralString = (
|
||||||
|
"""
|
||||||
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
|
WHERE e.group_id IN $group_ids
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
||||||
WHERE e.group_id IN $group_ids
|
|
||||||
"""
|
|
||||||
+ cursor_query
|
+ cursor_query
|
||||||
+ ENTITY_EDGE_RETURN
|
+ ENTITY_EDGE_RETURN
|
||||||
|
+ with_embeddings_query
|
||||||
+ """
|
+ """
|
||||||
ORDER BY e.uuid DESC
|
ORDER BY e.uuid DESC
|
||||||
"""
|
"""
|
||||||
+ limit_query,
|
+ limit_query
|
||||||
|
)
|
||||||
|
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
uuid=uuid_cursor,
|
uuid=uuid_cursor,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
|
@ -334,8 +346,8 @@ class EntityEdge(Edge):
|
||||||
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
|
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
|
||||||
query: LiteralString = (
|
query: LiteralString = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||||
"""
|
"""
|
||||||
+ ENTITY_EDGE_RETURN
|
+ ENTITY_EDGE_RETURN
|
||||||
)
|
)
|
||||||
records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r')
|
records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r')
|
||||||
|
|
@ -456,6 +468,7 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
||||||
source_node_uuid=record['source_node_uuid'],
|
source_node_uuid=record['source_node_uuid'],
|
||||||
target_node_uuid=record['target_node_uuid'],
|
target_node_uuid=record['target_node_uuid'],
|
||||||
fact=record['fact'],
|
fact=record['fact'],
|
||||||
|
fact_embedding=record.get('fact_embedding'),
|
||||||
name=record['name'],
|
name=record['name'],
|
||||||
group_id=record['group_id'],
|
group_id=record['group_id'],
|
||||||
episodes=record['episodes'],
|
episodes=record['episodes'],
|
||||||
|
|
|
||||||
|
|
@ -46,8 +46,7 @@ ENTITY_NODE_RETURN: LiteralString = """
|
||||||
n.created_at AS created_at,
|
n.created_at AS created_at,
|
||||||
n.summary AS summary,
|
n.summary AS summary,
|
||||||
labels(n) AS labels,
|
labels(n) AS labels,
|
||||||
properties(n) AS attributes
|
properties(n) AS attributes"""
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class EpisodeType(Enum):
|
class EpisodeType(Enum):
|
||||||
|
|
@ -335,8 +334,8 @@ class EntityNode(Node):
|
||||||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity {uuid: $uuid})
|
MATCH (n:Entity {uuid: $uuid})
|
||||||
"""
|
"""
|
||||||
+ ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
)
|
)
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -374,9 +373,17 @@ class EntityNode(Node):
|
||||||
group_ids: list[str],
|
group_ids: list[str],
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
uuid_cursor: str | None = None,
|
uuid_cursor: str | None = None,
|
||||||
|
with_embeddings: bool = False,
|
||||||
):
|
):
|
||||||
cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
|
cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
|
||||||
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
||||||
|
with_embeddings_query: LiteralString = (
|
||||||
|
""",
|
||||||
|
n.name_embedding AS name_embedding
|
||||||
|
"""
|
||||||
|
if with_embeddings
|
||||||
|
else ''
|
||||||
|
)
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
|
|
@ -384,6 +391,7 @@ class EntityNode(Node):
|
||||||
"""
|
"""
|
||||||
+ cursor_query
|
+ cursor_query
|
||||||
+ ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
|
+ with_embeddings_query
|
||||||
+ """
|
+ """
|
||||||
ORDER BY n.uuid DESC
|
ORDER BY n.uuid DESC
|
||||||
"""
|
"""
|
||||||
|
|
@ -546,6 +554,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
|
||||||
entity_node = EntityNode(
|
entity_node = EntityNode(
|
||||||
uuid=record['uuid'],
|
uuid=record['uuid'],
|
||||||
name=record['name'],
|
name=record['name'],
|
||||||
|
name_embedding=record.get('name_embedding'),
|
||||||
group_id=record['group_id'],
|
group_id=record['group_id'],
|
||||||
labels=record['labels'],
|
labels=record['labels'],
|
||||||
created_at=parse_db_date(record['created_at']), # type: ignore
|
created_at=parse_db_date(record['created_at']), # type: ignore
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.17.4"
|
version = "0.17.5"
|
||||||
authors = [
|
authors = [
|
||||||
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
||||||
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
||||||
|
|
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -746,7 +746,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.17.4"
|
version = "0.17.5"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "diskcache" },
|
{ name = "diskcache" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue