From 0ac7ded4d183bcda31dc35b6cd42ef9bb74957b7 Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Mon, 25 Aug 2025 12:31:35 -0400 Subject: [PATCH] use hnsw indexes (#859) * use hnsw indexes * add migration * updates * add group_id validation * updates * add type annotation * updates * update * swap to prerelease --- graphiti_core/graphiti.py | 2 + graphiti_core/migrations/__init__.py | 0 .../migrations/neo4j_node_group_labels.py | 53 +++++++++++++ graphiti_core/models/nodes/node_db_queries.py | 21 ++++- graphiti_core/nodes.py | 3 +- graphiti_core/search/search_utils.py | 78 +++++++++++++++---- graphiti_core/utils/bulk_utils.py | 5 +- .../maintenance/graph_data_operations.py | 49 +++++++++++- pyproject.toml | 2 +- uv.lock | 2 +- 10 files changed, 189 insertions(+), 26 deletions(-) create mode 100644 graphiti_core/migrations/__init__.py create mode 100644 graphiti_core/migrations/neo4j_node_group_labels.py diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 67f830fa..75d64f19 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -89,6 +89,7 @@ from graphiti_core.utils.maintenance.edge_operations import ( ) from graphiti_core.utils.maintenance.graph_data_operations import ( EPISODE_WINDOW_LEN, + build_dynamic_indexes, build_indices_and_constraints, retrieve_episodes, ) @@ -450,6 +451,7 @@ class Graphiti: validate_excluded_entity_types(excluded_entity_types, entity_types) validate_group_id(group_id) + await build_dynamic_indexes(self.driver, group_id) previous_episodes = ( await self.retrieve_episodes( diff --git a/graphiti_core/migrations/__init__.py b/graphiti_core/migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphiti_core/migrations/neo4j_node_group_labels.py b/graphiti_core/migrations/neo4j_node_group_labels.py new file mode 100644 index 00000000..f4cdb467 --- /dev/null +++ b/graphiti_core/migrations/neo4j_node_group_labels.py @@ -0,0 +1,53 @@ +from graphiti_core.driver.driver import GraphDriver +from graphiti_core.helpers import validate_group_id +from graphiti_core.utils.maintenance.graph_data_operations import build_dynamic_indexes + + +async def neo4j_node_group_labels(driver: GraphDriver, group_id: str, batch_size: int = 100): + validate_group_id(group_id) + await build_dynamic_indexes(driver, group_id) + + episode_query = """ + MATCH (n:Episodic {group_id: $group_id}) + CALL { + WITH n + SET n:$group_label + } IN TRANSACTIONS OF $batch_size ROWS""" + + entity_query = """ + MATCH (n:Entity {group_id: $group_id}) + CALL { + WITH n + SET n:$group_label + } IN TRANSACTIONS OF $batch_size ROWS""" + + community_query = """ + MATCH (n:Community {group_id: $group_id}) + CALL { + WITH n + SET n:$group_label + } IN TRANSACTIONS OF $batch_size ROWS""" + + async with driver.session() as session: + await session.run( + episode_query, + group_id=group_id, + group_label='Episodic_' + group_id.replace('-', ''), + batch_size=batch_size, + ) + + async with driver.session() as session: + await session.run( + entity_query, + group_id=group_id, + group_label='Entity_' + group_id.replace('-', ''), + batch_size=batch_size, + ) + + async with driver.session() as session: + await session.run( + community_query, + group_id=group_id, + group_label='Community_' + group_id.replace('-', ''), + batch_size=batch_size, + ) diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py index 40a612d6..16f4031e 100644 --- a/graphiti_core/models/nodes/node_db_queries.py +++ b/graphiti_core/models/nodes/node_db_queries.py @@ -28,13 +28,21 @@ def get_episode_node_save_query(provider: GraphProvider) -> str: entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at} RETURN n.uuid AS uuid """ - case _: # Neo4j and FalkorDB + case GraphProvider.FALKORDB: return """ MERGE (n:Episodic {uuid: $uuid}) SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content, entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at} RETURN n.uuid AS uuid """ + case _: # Neo4j + return """ + MERGE (n:Episodic {uuid: $uuid}) + SET n:$($group_label) + SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content, + entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at} + RETURN n.uuid AS uuid + """ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str: @@ -48,7 +56,7 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str: entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at} RETURN n.uuid AS uuid """ - case _: # Neo4j and FalkorDB + case GraphProvider.FALKORDB: return """ UNWIND $episodes AS episode MERGE (n:Episodic {uuid: episode.uuid}) @@ -56,6 +64,15 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str: entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at} RETURN n.uuid AS uuid """ + case _: # Neo4j + return """ + UNWIND $episodes AS episode + MERGE (n:Episodic {uuid: episode.uuid}) + SET n:$(episode.group_label) + SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content, + entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at} + RETURN n.uuid AS uuid + """ EPISODIC_NODE_RETURN = """ diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 976a96f5..98bafab6 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -212,6 +212,7 @@ class EpisodicNode(Node): uuid=self.uuid, name=self.name, group_id=self.group_id, + group_label='Episodic_' + self.group_id.replace('-', ''), source_description=self.source_description, content=self.content, entity_edges=self.entity_edges, @@ -380,7 +381,7 @@ class EntityNode(Node): if driver.provider == GraphProvider.NEPTUNE: driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue - labels = ':'.join(self.labels + ['Entity']) + labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')]) result = await driver.execute_query( get_entity_node_save_query(driver.provider, labels), diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 0c69d7e9..a24d36b7 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -15,6 +15,7 @@ limitations under the License. """ import logging +import os from collections import defaultdict from time import time from typing import Any @@ -54,6 +55,7 @@ from graphiti_core.search.search_filters import ( ) logger = logging.getLogger(__name__) +USE_HNSW = os.getenv('USE_HNSW', '').lower() in ('true', '1', 'yes') RELEVANT_SCHEMA_LIMIT = 10 DEFAULT_MIN_SCORE = 0.6 @@ -178,11 +180,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m @@ -477,11 +479,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + ENTITY_NODE_RETURN + """ ORDER BY i.score DESC @@ -500,8 +502,14 @@ async def node_fulltext_search( else: return [] else: + index_name = ( + 'node_name_and_summary' + if not USE_HNSW + else 'node_name_and_summary_' + + (group_ids[0].replace('-', '') if group_ids is not None else '') + ) query = ( - get_nodes_query(driver.provider, 'node_name_and_summary', '$query') + get_nodes_query(driver.provider, index_name, '$query') + """ YIELD node AS n, score WHERE n:Entity AND n.group_id IN $group_ids @@ -585,11 +593,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + ENTITY_NODE_RETURN + """ ORDER BY i.score DESC @@ -607,6 +615,36 @@ async def node_similarity_search( ) else: return [] + elif driver.provider == GraphProvider.NEO4J and USE_HNSW: + index_name = 'group_entity_vector_' + ( + group_ids[0].replace('-', '') if group_ids is not None else '' + ) + query = ( + f""" + CALL db.index.vector.queryNodes('{index_name}', {limit}, $search_vector) YIELD node AS n, score + """ + + group_filter_query + + filter_query + + """ + AND score > $min_score + RETURN + """ + + ENTITY_NODE_RETURN + + """ + ORDER BY score DESC + LIMIT $limit + """ + ) + + records, _, _ = await driver.execute_query( + query, + search_vector=search_vector, + limit=limit, + min_score=min_score, + routing_='r', + **query_params, + ) + else: query = ( RUNTIME_QUERY @@ -754,8 +792,14 @@ async def episode_fulltext_search( else: return [] else: + index_name = ( + 'episode_content' + if not USE_HNSW + else 'episode_content_' + + (group_ids[0].replace('-', '') if group_ids is not None else '') + ) query = ( - get_nodes_query(driver.provider, 'episode_content', '$query') + get_nodes_query(driver.provider, index_name, '$query') + """ YIELD node AS episode, score MATCH (e:Episodic) diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 2e964eda..e20b20b0 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -116,6 +116,7 @@ async def add_nodes_and_edges_bulk_tx( episodes = [dict(episode) for episode in episodic_nodes] for episode in episodes: episode['source'] = str(episode['source'].value) + episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '') nodes: list[dict[str, Any]] = [] for node in entity_nodes: if node.name_embedding is None: @@ -130,7 +131,9 @@ async def add_nodes_and_edges_bulk_tx( } entity_data.update(node.attributes or {}) - entity_data['labels'] = list(set(node.labels + ['Entity'])) + entity_data['labels'] = list( + set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')]) + ) nodes.append(entity_data) edges: list[dict[str, Any]] = [] diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index 4117c4ec..03950dd5 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -114,9 +114,9 @@ async def retrieve_episodes( query: LiteralString = ( """ - MATCH (e:Episodic) - WHERE e.valid_at <= $reference_time - """ + MATCH (e:Episodic) + WHERE e.valid_at <= $reference_time + """ + group_id_filter + source_filter + """ @@ -142,3 +142,46 @@ async def retrieve_episodes( episodes = [get_episodic_node_from_record(record) for record in result] return list(reversed(episodes)) # Return in chronological order + + +async def build_dynamic_indexes(driver: GraphDriver, group_id: str): + # Make sure indices exist for this group_id in Neo4j + if driver.provider == GraphProvider.NEO4J: + await semaphore_gather( + driver.execute_query( + """CREATE FULLTEXT INDEX $episode_content IF NOT EXISTS +FOR (e:""" + + 'Episodic_' + + group_id.replace('-', '') + + """) ON EACH [e.content, e.source, e.source_description, e.group_id]""", + episode_content='episode_content_' + group_id.replace('-', ''), + ), + driver.execute_query( + """CREATE FULLTEXT INDEX $node_name_and_summary IF NOT EXISTS FOR (n:""" + + 'Entity_' + + group_id.replace('-', '') + + """) ON EACH [n.name, n.summary, n.group_id]""", + node_name_and_summary='node_name_and_summary_' + group_id.replace('-', ''), + ), + driver.execute_query( + """CREATE FULLTEXT INDEX $community_name IF NOT EXISTS + FOR (n:""" + + 'Community_' + + group_id.replace('-', '') + + """) ON EACH [n.name, n.group_id]""", + community_name='Community_' + group_id.replace('-', ''), + ), + driver.execute_query( + """CREATE VECTOR INDEX $group_entity_vector IF NOT EXISTS + FOR (n:""" + + 'Entity_' + + group_id.replace('-', '') + + """) + ON n.embedding + OPTIONS { indexConfig: { + `vector.dimensions`: 1024, + `vector.similarity_function`: 'cosine' + }}""", + group_entity_vector='group_entity_vector_' + group_id.replace('-', ''), + ), + ) diff --git a/pyproject.toml b/pyproject.toml index f5bf5aa0..c469b88d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.19.0" +version = "0.20.0pre1" authors = [ { name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" }, diff --git a/uv.lock b/uv.lock index 230ce9f1..c96a7720 100644 --- a/uv.lock +++ b/uv.lock @@ -783,7 +783,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.19.0" +version = "0.20.0rc1" source = { editable = "." } dependencies = [ { name = "diskcache" },