From 432d2295c62aee0d6a4c20d79fb41570f4fd21f6 Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Tue, 22 Apr 2025 12:03:09 -0400 Subject: [PATCH] Revert episodes (#387) * episode search fixes and optimizations * remove extra return string * Update graphiti_core/utils/maintenance/graph_data_operations.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- graphiti_core/nodes.py | 39 ++++++++++++++----- graphiti_core/search/search_utils.py | 30 ++++---------- .../maintenance/graph_data_operations.py | 18 ++++++--- 3 files changed, 50 insertions(+), 37 deletions(-) diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 9dfea464..ed68958a 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -39,8 +39,6 @@ from graphiti_core.utils.datetime_utils import utc_now logger = logging.getLogger(__name__) ENTITY_NODE_RETURN: LiteralString = """ - OPTIONAL MATCH (e:Episodic)-[r:MENTIONS]->(n) - WITH n, collect(e.uuid) AS episodes RETURN n.uuid As uuid, n.name AS name, @@ -49,8 +47,8 @@ ENTITY_NODE_RETURN: LiteralString = """ n.created_at AS created_at, n.summary AS summary, labels(n) AS labels, - properties(n) AS attributes, - episodes""" + properties(n) AS attributes + """ class EpisodeType(Enum): @@ -265,13 +263,35 @@ class EpisodicNode(Node): return episodes + @classmethod + async def get_by_entity_node_uuid(cls, driver: AsyncDriver, entity_node_uuid: str): + records, _, _ = await driver.execute_query( + """ + MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid}) + RETURN DISTINCT + e.content AS content, + e.created_at AS created_at, + e.valid_at AS valid_at, + e.uuid AS uuid, + e.name AS name, + e.group_id AS group_id, + e.source_description AS source_description, + e.source AS source, + e.entity_edges AS entity_edges + """, + entity_node_uuid=entity_node_uuid, + database_=DEFAULT_DATABASE, + routing_='r', + ) + + episodes = [get_episodic_node_from_record(record) for record in records] + + return episodes + class EntityNode(Node): name_embedding: list[float] | None = Field(default=None, description='embedding of the name') summary: str = Field(description='regional summary of surrounding edges', default_factory=str) - episodes: list[str] | None = Field( - default=None, description='List of episode uuids that mention this node.' - ) attributes: dict[str, Any] = Field( default={}, description='Additional attributes of the node. Dependent on node labels' ) @@ -312,8 +332,8 @@ class EntityNode(Node): async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): query = ( """ - MATCH (n:Entity {uuid: $uuid}) - """ + MATCH (n:Entity {uuid: $uuid}) + """ + ENTITY_NODE_RETURN ) records, _, _ = await driver.execute_query( @@ -519,7 +539,6 @@ def get_entity_node_from_record(record: Any) -> EntityNode: created_at=record['created_at'].to_native(), summary=record['summary'], attributes=record['attributes'], - episodes=record['episodes'], ) entity_node.attributes.pop('uuid', None) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 4095a88b..9a8079eb 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -54,20 +54,6 @@ DEFAULT_MMR_LAMBDA = 0.5 MAX_SEARCH_DEPTH = 3 MAX_QUERY_LENGTH = 32 -SEARCH_ENTITY_NODE_RETURN: LiteralString = """ - OPTIONAL MATCH (e:Episodic)-[r:MENTIONS]->(n) - WITH n, score, collect(e.uuid) AS episodes - RETURN - n.uuid As uuid, - n.name AS name, - n.name_embedding AS name_embedding, - n.group_id AS group_id, - n.created_at AS created_at, - n.summary AS summary, - labels(n) AS labels, - properties(n) AS attributes, - episodes""" - def fulltext_query(query: str, group_ids: list[str] | None = None): group_ids_filter_list = ( @@ -245,8 +231,8 @@ async def edge_similarity_search( query: LiteralString = ( """ - MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) + """ + group_filter_query + filter_query + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score @@ -358,12 +344,12 @@ async def node_fulltext_search( query = ( """ - CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) - YIELD node AS n, score - WHERE n:Entity - """ + CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) + YIELD node AS n, score + WHERE n:Entity + """ + filter_query - + SEARCH_ENTITY_NODE_RETURN + + ENTITY_NODE_RETURN + """ ORDER BY score DESC """ @@ -416,7 +402,7 @@ async def node_similarity_search( + """ WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score WHERE score > $min_score""" - + SEARCH_ENTITY_NODE_RETURN + + ENTITY_NODE_RETURN + """ ORDER BY score DESC LIMIT $limit diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index 23fcafb8..32e64a30 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -132,10 +132,14 @@ async def retrieve_episodes( Returns: list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes. """ - result = await driver.execute_query( + group_id_filter: LiteralString = 'AND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else '' + + query: LiteralString = ( """ - MATCH (e:Episodic) WHERE e.valid_at <= $reference_time - AND ($group_ids IS NULL) OR e.group_id in $group_ids + MATCH (e:Episodic) WHERE e.valid_at <= $reference_time + """ + + group_id_filter + + """ RETURN e.content AS content, e.created_at AS created_at, e.valid_at AS valid_at, @@ -144,9 +148,13 @@ async def retrieve_episodes( e.name AS name, e.source_description AS source_description, e.source AS source - ORDER BY e.created_at DESC + ORDER BY e.valid_at DESC LIMIT $num_episodes - """, + """ + ) + + result = await driver.execute_query( + query, reference_time=reference_time, num_episodes=last_n, group_ids=group_ids,