use hnsw indexes (#859)

* use hnsw indexes

* add migration

* updates

* add group_id validation

* updates

* add type annotation

* updates

* update

* swap to prerelease
This commit is contained in:
Preston Rasmussen 2025-08-25 12:31:35 -04:00 committed by GitHub
parent b31e74e5d2
commit 0ac7ded4d1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 189 additions and 26 deletions

View file

@ -89,6 +89,7 @@ from graphiti_core.utils.maintenance.edge_operations import (
) )
from graphiti_core.utils.maintenance.graph_data_operations import ( from graphiti_core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN, EPISODE_WINDOW_LEN,
build_dynamic_indexes,
build_indices_and_constraints, build_indices_and_constraints,
retrieve_episodes, retrieve_episodes,
) )
@ -450,6 +451,7 @@ class Graphiti:
validate_excluded_entity_types(excluded_entity_types, entity_types) validate_excluded_entity_types(excluded_entity_types, entity_types)
validate_group_id(group_id) validate_group_id(group_id)
await build_dynamic_indexes(self.driver, group_id)
previous_episodes = ( previous_episodes = (
await self.retrieve_episodes( await self.retrieve_episodes(

View file

View file

@ -0,0 +1,53 @@
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.helpers import validate_group_id
from graphiti_core.utils.maintenance.graph_data_operations import build_dynamic_indexes
async def neo4j_node_group_labels(driver: GraphDriver, group_id: str, batch_size: int = 100):
    """Backfill group-scoped labels onto existing nodes for one group_id.

    One-off migration helper: adds ``Episodic_<gid>``, ``Entity_<gid>`` and
    ``Community_<gid>`` labels (dashes stripped from the group id, since Neo4j
    label identifiers cannot contain '-') so pre-existing data matches the
    per-group indexes created by ``build_dynamic_indexes``.

    Args:
        driver: graph driver providing ``session()`` (Neo4j expected).
        group_id: group identifier; validated before any writes.
        batch_size: rows per committed transaction in the batched relabel.
    """
    validate_group_id(group_id)
    await build_dynamic_indexes(driver, group_id)

    label_suffix = group_id.replace('-', '')

    # One batched relabel pass per base label. The label must be applied via the
    # dynamic-label syntax `SET n:$($group_label)` (Neo4j 5.26+); a bare
    # `SET n:$group_label` is a Cypher syntax error — a parameter cannot follow
    # ':' directly.
    for base_label in ('Episodic', 'Entity', 'Community'):
        query = (
            'MATCH (n:' + base_label + ' {group_id: $group_id})\n'
            'CALL {\n'
            '    WITH n\n'
            '    SET n:$($group_label)\n'
            '} IN TRANSACTIONS OF $batch_size ROWS'
        )
        # CALL { ... } IN TRANSACTIONS requires an implicit (auto-commit)
        # transaction, which session.run provides.
        async with driver.session() as session:
            await session.run(
                query,
                group_id=group_id,
                group_label=base_label + '_' + label_suffix,
                batch_size=batch_size,
            )

View file

@ -28,13 +28,21 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at} entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid RETURN n.uuid AS uuid
""" """
case _: # Neo4j and FalkorDB case GraphProvider.FALKORDB:
return """ return """
MERGE (n:Episodic {uuid: $uuid}) MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content, SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at} entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid RETURN n.uuid AS uuid
""" """
case _: # Neo4j
return """
MERGE (n:Episodic {uuid: $uuid})
SET n:$($group_label)
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid
"""
def get_episode_node_save_bulk_query(provider: GraphProvider) -> str: def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
@ -48,7 +56,7 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at} entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid RETURN n.uuid AS uuid
""" """
case _: # Neo4j and FalkorDB case GraphProvider.FALKORDB:
return """ return """
UNWIND $episodes AS episode UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid}) MERGE (n:Episodic {uuid: episode.uuid})
@ -56,6 +64,15 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at} entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid RETURN n.uuid AS uuid
""" """
case _: # Neo4j
return """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
SET n:$(episode.group_label)
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
EPISODIC_NODE_RETURN = """ EPISODIC_NODE_RETURN = """

View file

@ -212,6 +212,7 @@ class EpisodicNode(Node):
uuid=self.uuid, uuid=self.uuid,
name=self.name, name=self.name,
group_id=self.group_id, group_id=self.group_id,
group_label='Episodic_' + self.group_id.replace('-', ''),
source_description=self.source_description, source_description=self.source_description,
content=self.content, content=self.content,
entity_edges=self.entity_edges, entity_edges=self.entity_edges,
@ -380,7 +381,7 @@ class EntityNode(Node):
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
labels = ':'.join(self.labels + ['Entity']) labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
result = await driver.execute_query( result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels), get_entity_node_save_query(driver.provider, labels),

View file

@ -15,6 +15,7 @@ limitations under the License.
""" """
import logging import logging
import os
from collections import defaultdict from collections import defaultdict
from time import time from time import time
from typing import Any from typing import Any
@ -54,6 +55,7 @@ from graphiti_core.search.search_filters import (
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
USE_HNSW = os.getenv('USE_HNSW', '').lower() in ('true', '1', 'yes')
RELEVANT_SCHEMA_LIMIT = 10 RELEVANT_SCHEMA_LIMIT = 10
DEFAULT_MIN_SCORE = 0.6 DEFAULT_MIN_SCORE = 0.6
@ -178,11 +180,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values # Match the edge ids and return the values
query = ( query = (
""" """
UNWIND $ids as id UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids WHERE e.group_id IN $group_ids
AND id(e)=id AND id(e)=id
""" """
+ filter_query + filter_query
+ """ + """
WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m
@ -477,11 +479,11 @@ async def node_fulltext_search(
# Match the edge ids and return the values # Match the edge ids and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE n.uuid=i.id WHERE n.uuid=i.id
RETURN RETURN
""" """
+ ENTITY_NODE_RETURN + ENTITY_NODE_RETURN
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -500,8 +502,14 @@ async def node_fulltext_search(
else: else:
return [] return []
else: else:
index_name = (
'node_name_and_summary'
if not USE_HNSW
else 'node_name_and_summary_'
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
)
query = ( query = (
get_nodes_query(driver.provider, 'node_name_and_summary', '$query') get_nodes_query(driver.provider, index_name, '$query')
+ """ + """
YIELD node AS n, score YIELD node AS n, score
WHERE n:Entity AND n.group_id IN $group_ids WHERE n:Entity AND n.group_id IN $group_ids
@ -585,11 +593,11 @@ async def node_similarity_search(
# Match the edge ids and return the values # Match the edge ids and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE id(n)=i.id WHERE id(n)=i.id
RETURN RETURN
""" """
+ ENTITY_NODE_RETURN + ENTITY_NODE_RETURN
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -607,6 +615,36 @@ async def node_similarity_search(
) )
else: else:
return [] return []
elif driver.provider == GraphProvider.NEO4J and USE_HNSW:
index_name = 'group_entity_vector_' + (
group_ids[0].replace('-', '') if group_ids is not None else ''
)
query = (
f"""
CALL db.index.vector.queryNodes('{index_name}', {limit}, $search_vector) YIELD node AS n, score
"""
+ group_filter_query
+ filter_query
+ """
AND score > $min_score
RETURN
"""
+ ENTITY_NODE_RETURN
+ """
ORDER BY score DESC
LIMIT $limit
"""
)
records, _, _ = await driver.execute_query(
query,
search_vector=search_vector,
limit=limit,
min_score=min_score,
routing_='r',
**query_params,
)
else: else:
query = ( query = (
RUNTIME_QUERY RUNTIME_QUERY
@ -754,8 +792,14 @@ async def episode_fulltext_search(
else: else:
return [] return []
else: else:
index_name = (
'episode_content'
if not USE_HNSW
else 'episode_content_'
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
)
query = ( query = (
get_nodes_query(driver.provider, 'episode_content', '$query') get_nodes_query(driver.provider, index_name, '$query')
+ """ + """
YIELD node AS episode, score YIELD node AS episode, score
MATCH (e:Episodic) MATCH (e:Episodic)

View file

@ -116,6 +116,7 @@ async def add_nodes_and_edges_bulk_tx(
episodes = [dict(episode) for episode in episodic_nodes] episodes = [dict(episode) for episode in episodic_nodes]
for episode in episodes: for episode in episodes:
episode['source'] = str(episode['source'].value) episode['source'] = str(episode['source'].value)
episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')
nodes: list[dict[str, Any]] = [] nodes: list[dict[str, Any]] = []
for node in entity_nodes: for node in entity_nodes:
if node.name_embedding is None: if node.name_embedding is None:
@ -130,7 +131,9 @@ async def add_nodes_and_edges_bulk_tx(
} }
entity_data.update(node.attributes or {}) entity_data.update(node.attributes or {})
entity_data['labels'] = list(set(node.labels + ['Entity'])) entity_data['labels'] = list(
set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
)
nodes.append(entity_data) nodes.append(entity_data)
edges: list[dict[str, Any]] = [] edges: list[dict[str, Any]] = []

View file

@ -114,9 +114,9 @@ async def retrieve_episodes(
query: LiteralString = ( query: LiteralString = (
""" """
MATCH (e:Episodic) MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time WHERE e.valid_at <= $reference_time
""" """
+ group_id_filter + group_id_filter
+ source_filter + source_filter
+ """ + """
@ -142,3 +142,46 @@ async def retrieve_episodes(
episodes = [get_episodic_node_from_record(record) for record in result] episodes = [get_episodic_node_from_record(record) for record in result]
return list(reversed(episodes)) # Return in chronological order return list(reversed(episodes)) # Return in chronological order
async def build_dynamic_indexes(driver: GraphDriver, group_id: str):
    """Ensure the per-group fulltext and vector indexes exist (Neo4j only).

    Idempotently (``IF NOT EXISTS``) creates indexes scoped to the
    group-specific labels ``Episodic_<gid>`` / ``Entity_<gid>`` /
    ``Community_<gid>``. No-op for non-Neo4j providers.

    Args:
        driver: graph driver; only acts when ``driver.provider`` is NEO4J.
        group_id: group identifier (dashes are stripped for label/index names,
            since Neo4j identifiers cannot contain '-').
    """
    if driver.provider != GraphProvider.NEO4J:
        return

    suffix = group_id.replace('-', '')

    # Index names in CREATE ... INDEX are schema identifiers and cannot be
    # supplied as query parameters (the original passed them as $params, which
    # Neo4j rejects), so they are interpolated directly. The suffix is
    # sanitized above and group_id is validated by callers.
    await semaphore_gather(
        driver.execute_query(
            f"""CREATE FULLTEXT INDEX episode_content_{suffix} IF NOT EXISTS
            FOR (e:Episodic_{suffix})
            ON EACH [e.content, e.source, e.source_description, e.group_id]"""
        ),
        driver.execute_query(
            f"""CREATE FULLTEXT INDEX node_name_and_summary_{suffix} IF NOT EXISTS
            FOR (n:Entity_{suffix})
            ON EACH [n.name, n.summary, n.group_id]"""
        ),
        # NOTE(review): this index name keeps the original 'Community_<suffix>'
        # spelling, which is inconsistent with the lowercase descriptive
        # prefixes above — confirm consumers before renaming.
        driver.execute_query(
            f"""CREATE FULLTEXT INDEX Community_{suffix} IF NOT EXISTS
            FOR (n:Community_{suffix})
            ON EACH [n.name, n.group_id]"""
        ),
        driver.execute_query(
            f"""CREATE VECTOR INDEX group_entity_vector_{suffix} IF NOT EXISTS
            FOR (n:Entity_{suffix})
            ON n.embedding
            OPTIONS {{ indexConfig: {{
                `vector.dimensions`: 1024,
                `vector.similarity_function`: 'cosine'
            }} }}"""
        ),
    )

View file

@ -1,7 +1,7 @@
[project] [project]
name = "graphiti-core" name = "graphiti-core"
description = "A temporal graph building library" description = "A temporal graph building library"
version = "0.19.0" version = "0.20.0pre1"
authors = [ authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" },

2
uv.lock generated
View file

@ -783,7 +783,7 @@ wheels = [
[[package]] [[package]]
name = "graphiti-core" name = "graphiti-core"
version = "0.19.0" version = "0.20.0rc1"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "diskcache" }, { name = "diskcache" },