use hnsw indexes (#859)

* use hnsw indexes

* add migration

* updates

* add group_id validation

* updates

* add type annotation

* updates

* update

* swap to prerelease
Preston Rasmussen 2025-08-25 12:31:35 -04:00 committed by GitHub
parent b31e74e5d2
commit 0ac7ded4d1
10 changed files with 189 additions and 26 deletions

View file

@@ -89,6 +89,7 @@ from graphiti_core.utils.maintenance.edge_operations import (
)
from graphiti_core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN,
build_dynamic_indexes,
build_indices_and_constraints,
retrieve_episodes,
)
@@ -450,6 +451,7 @@ class Graphiti:
validate_excluded_entity_types(excluded_entity_types, entity_types)
validate_group_id(group_id)
await build_dynamic_indexes(self.driver, group_id)
previous_episodes = (
await self.retrieve_episodes(
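With this change, add_episode validates the group id and calls build_dynamic_indexes before retrieving previous episodes, so the per-group indexes exist by the time search runs. A minimal usage sketch under assumptions not shown in this diff (connection details, episode content, and the 'my-group' id are placeholders):

import asyncio
from datetime import datetime, timezone

from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType

async def main() -> None:
    graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
    await graphiti.build_indices_and_constraints()
    # group_id is validated and also drives the per-group index/label names
    # (hyphens are stripped, e.g. 'my-group' -> 'Entity_mygroup').
    await graphiti.add_episode(
        name='note-1',
        episode_body='Alice met Bob in Paris.',
        source=EpisodeType.text,
        source_description='example note',
        reference_time=datetime.now(timezone.utc),
        group_id='my-group',
    )
    await graphiti.close()

asyncio.run(main())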

View file

View file

@@ -0,0 +1,53 @@
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.helpers import validate_group_id
from graphiti_core.utils.maintenance.graph_data_operations import build_dynamic_indexes
async def neo4j_node_group_labels(driver: GraphDriver, group_id: str, batch_size: int = 100):
validate_group_id(group_id)
await build_dynamic_indexes(driver, group_id)
episode_query = """
MATCH (n:Episodic {group_id: $group_id})
CALL {
WITH n
SET n:$($group_label)
} IN TRANSACTIONS OF $batch_size ROWS"""
entity_query = """
MATCH (n:Entity {group_id: $group_id})
CALL {
WITH n
SET n:$($group_label)
} IN TRANSACTIONS OF $batch_size ROWS"""
community_query = """
MATCH (n:Community {group_id: $group_id})
CALL {
WITH n
SET n:$($group_label)
} IN TRANSACTIONS OF $batch_size ROWS"""
async with driver.session() as session:
await session.run(
episode_query,
group_id=group_id,
group_label='Episodic_' + group_id.replace('-', ''),
batch_size=batch_size,
)
async with driver.session() as session:
await session.run(
entity_query,
group_id=group_id,
group_label='Entity_' + group_id.replace('-', ''),
batch_size=batch_size,
)
async with driver.session() as session:
await session.run(
community_query,
group_id=group_id,
group_label='Community_' + group_id.replace('-', ''),
batch_size=batch_size,
)
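For graphs that predate this change, the helper above needs to be run once per group so existing Episodic, Entity, and Community nodes pick up the per-group labels that the new indexes target. A sketch of how it might be invoked; the import path for the helper is a guess (the file name is not shown in this view) and the driver construction is an assumption:

import asyncio

from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.migrations.group_labels import neo4j_node_group_labels  # hypothetical path

async def main() -> None:
    driver = Neo4jDriver(uri='bolt://localhost:7687', user='neo4j', password='password')
    for group_id in ['my-group']:  # the group ids already present in your graph
        # Applies Episodic_/Entity_/Community_<group> labels in batches of 100 nodes.
        await neo4j_node_group_labels(driver, group_id)
    await driver.close()

asyncio.run(main())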

View file

@@ -28,13 +28,21 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid
"""
case _: # Neo4j and FalkorDB
case GraphProvider.FALKORDB:
return """
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid
"""
case _: # Neo4j
return """
MERGE (n:Episodic {uuid: $uuid})
SET n:$($group_label)
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid
"""
def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
@@ -48,7 +56,7 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
case _: # Neo4j and FalkorDB
case GraphProvider.FALKORDB:
return """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
@@ -56,6 +64,15 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
case _: # Neo4j
return """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
SET n:$(episode.group_label)
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
EPISODIC_NODE_RETURN = """

View file

@@ -212,6 +212,7 @@ class EpisodicNode(Node):
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
group_label='Episodic_' + self.group_id.replace('-', ''),
source_description=self.source_description,
content=self.content,
entity_edges=self.entity_edges,
@@ -380,7 +381,7 @@ class EntityNode(Node):
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
labels = ':'.join(self.labels + ['Entity'])
labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels),

View file

@@ -15,6 +15,7 @@ limitations under the License.
"""
import logging
import os
from collections import defaultdict
from time import time
from typing import Any
@@ -54,6 +55,7 @@ from graphiti_core.search.search_filters import (
)
logger = logging.getLogger(__name__)
USE_HNSW = os.getenv('USE_HNSW', '').lower() in ('true', '1', 'yes')
RELEVANT_SCHEMA_LIMIT = 10
DEFAULT_MIN_SCORE = 0.6
@@ -178,11 +180,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values
query = (
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
+ filter_query
+ """
WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m
@@ -477,11 +479,11 @@ async def node_fulltext_search(
# Match the node ids and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
+ ENTITY_NODE_RETURN
+ """
ORDER BY i.score DESC
@@ -500,8 +502,14 @@ async def node_fulltext_search(
else:
return []
else:
index_name = (
'node_name_and_summary'
if not USE_HNSW
else 'node_name_and_summary_'
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
)
query = (
get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
get_nodes_query(driver.provider, index_name, '$query')
+ """
YIELD node AS n, score
WHERE n:Entity AND n.group_id IN $group_ids
@@ -585,11 +593,11 @@ async def node_similarity_search(
# Match the node ids and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
+ ENTITY_NODE_RETURN
+ """
ORDER BY i.score DESC
@@ -607,6 +615,36 @@ async def node_similarity_search(
)
else:
return []
elif driver.provider == GraphProvider.NEO4J and USE_HNSW:
index_name = 'group_entity_vector_' + (
group_ids[0].replace('-', '') if group_ids is not None else ''
)
query = (
f"""
CALL db.index.vector.queryNodes('{index_name}', {limit}, $search_vector) YIELD node AS n, score
"""
+ group_filter_query
+ filter_query
+ """
AND score > $min_score
RETURN
"""
+ ENTITY_NODE_RETURN
+ """
ORDER BY score DESC
LIMIT $limit
"""
)
records, _, _ = await driver.execute_query(
query,
search_vector=search_vector,
limit=limit,
min_score=min_score,
routing_='r',
**query_params,
)
else:
query = (
RUNTIME_QUERY
@@ -754,8 +792,14 @@ async def episode_fulltext_search(
else:
return []
else:
index_name = (
'episode_content'
if not USE_HNSW
else 'episode_content_'
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
)
query = (
get_nodes_query(driver.provider, 'episode_content', '$query')
get_nodes_query(driver.provider, index_name, '$query')
+ """
YIELD node AS episode, score
MATCH (e:Episodic)
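Every HNSW-specific branch in this file is gated on the module-level USE_HNSW flag, which is read from the environment when the module is imported, so the variable has to be set before graphiti_core is loaded. A small sketch, assuming the flag lives in graphiti_core.search.search_utils as shown above:

import os

# Set before importing graphiti_core; USE_HNSW is evaluated at import time.
os.environ['USE_HNSW'] = 'true'

from graphiti_core.search import search_utils  # noqa: E402

assert search_utils.USE_HNSW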

View file

@@ -116,6 +116,7 @@ async def add_nodes_and_edges_bulk_tx(
episodes = [dict(episode) for episode in episodic_nodes]
for episode in episodes:
episode['source'] = str(episode['source'].value)
episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')
nodes: list[dict[str, Any]] = []
for node in entity_nodes:
if node.name_embedding is None:
@@ -130,7 +131,9 @@ async def add_nodes_and_edges_bulk_tx(
}
entity_data.update(node.attributes or {})
entity_data['labels'] = list(set(node.labels + ['Entity']))
entity_data['labels'] = list(
set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
)
nodes.append(entity_data)
edges: list[dict[str, Any]] = []

View file

@@ -114,9 +114,9 @@ async def retrieve_episodes(
query: LiteralString = (
"""
MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time
"""
MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time
"""
+ group_id_filter
+ source_filter
+ """
@@ -142,3 +142,46 @@ async def retrieve_episodes(
episodes = [get_episodic_node_from_record(record) for record in result]
return list(reversed(episodes)) # Return in chronological order
async def build_dynamic_indexes(driver: GraphDriver, group_id: str):
# Make sure indices exist for this group_id in Neo4j
if driver.provider == GraphProvider.NEO4J:
await semaphore_gather(
driver.execute_query(
"""CREATE FULLTEXT INDEX $episode_content IF NOT EXISTS
FOR (e:"""
+ 'Episodic_'
+ group_id.replace('-', '')
+ """) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
episode_content='episode_content_' + group_id.replace('-', ''),
),
driver.execute_query(
"""CREATE FULLTEXT INDEX $node_name_and_summary IF NOT EXISTS FOR (n:"""
+ 'Entity_'
+ group_id.replace('-', '')
+ """) ON EACH [n.name, n.summary, n.group_id]""",
node_name_and_summary='node_name_and_summary_' + group_id.replace('-', ''),
),
driver.execute_query(
"""CREATE FULLTEXT INDEX $community_name IF NOT EXISTS
FOR (n:"""
+ 'Community_'
+ group_id.replace('-', '')
+ """) ON EACH [n.name, n.group_id]""",
community_name='Community_' + group_id.replace('-', ''),
),
driver.execute_query(
"""CREATE VECTOR INDEX $group_entity_vector IF NOT EXISTS
FOR (n:"""
+ 'Entity_'
+ group_id.replace('-', '')
+ """)
ON n.embedding
OPTIONS { indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}""",
group_entity_vector='group_entity_vector_' + group_id.replace('-', ''),
),
)
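Once build_dynamic_indexes has run for a group, the per-group fulltext and vector indexes should be listed by Neo4j. A hedged way to check, reusing the driver's execute_query result shape from this diff; the name-suffix filter is an assumption about how you might inspect them, not part of the commit:

from graphiti_core.driver.driver import GraphDriver

async def list_group_indexes(driver: GraphDriver, group_id: str) -> list[str]:
    suffix = group_id.replace('-', '')
    # SHOW INDEXES is Neo4j-only, matching the provider check in build_dynamic_indexes.
    records, _, _ = await driver.execute_query('SHOW INDEXES YIELD name RETURN name')
    return [record['name'] for record in records if record['name'].endswith(suffix)]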

View file

@@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.19.0"
version = "0.20.0pre1"
authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" },

uv.lock generated
View file

@@ -783,7 +783,7 @@ wheels = [
[[package]]
name = "graphiti-core"
version = "0.19.0"
version = "0.20.0rc1"
source = { editable = "." }
dependencies = [
{ name = "diskcache" },