use hnsw indexes (#859)
* use hnsw indexes * add migration * updates * add group_id validation * updates * add type annotation * updates * update * swap to prerelease
This commit is contained in:
parent
b31e74e5d2
commit
0ac7ded4d1
10 changed files with 189 additions and 26 deletions
|
|
@ -89,6 +89,7 @@ from graphiti_core.utils.maintenance.edge_operations import (
|
|||
)
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import (
|
||||
EPISODE_WINDOW_LEN,
|
||||
build_dynamic_indexes,
|
||||
build_indices_and_constraints,
|
||||
retrieve_episodes,
|
||||
)
|
||||
|
|
@ -450,6 +451,7 @@ class Graphiti:
|
|||
|
||||
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
||||
validate_group_id(group_id)
|
||||
await build_dynamic_indexes(self.driver, group_id)
|
||||
|
||||
previous_episodes = (
|
||||
await self.retrieve_episodes(
|
||||
|
|
|
|||
0
graphiti_core/migrations/__init__.py
Normal file
0
graphiti_core/migrations/__init__.py
Normal file
53
graphiti_core/migrations/neo4j_node_group_labels.py
Normal file
53
graphiti_core/migrations/neo4j_node_group_labels.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.helpers import validate_group_id
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import build_dynamic_indexes
|
||||
|
||||
|
||||
async def neo4j_node_group_labels(driver: GraphDriver, group_id: str, batch_size: int = 100):
    """Backfill the per-group labels on existing nodes of one group.

    Validates ``group_id``, makes sure the per-group indexes exist via
    ``build_dynamic_indexes``, and then adds the label
    ``<Base>_<group_id without hyphens>`` to every ``Episodic``, ``Entity``
    and ``Community`` node whose ``group_id`` property equals ``group_id``.

    Args:
        driver: Graph driver used to open the sessions the migration runs in.
        group_id: Group whose nodes should receive the per-group labels.
        batch_size: Number of rows handled per inner transaction
            (``IN TRANSACTIONS OF $batch_size ROWS``).
    """
    validate_group_id(group_id)
    await build_dynamic_indexes(driver, group_id)

    # The label suffix mirrors the one used when nodes are saved: the group id
    # with its hyphens removed.
    label_suffix = group_id.replace('-', '')

    # Same order as before: episodes first, then entities, then communities.
    for base_label in ('Episodic', 'Entity', 'Community'):
        # `base_label` only ever comes from the fixed tuple above, so it is safe
        # to interpolate. The per-group label is passed as a parameter and applied
        # with the dynamic-label form `$(...)`; a bare `SET n:$group_label` is not
        # accepted by Cypher.
        query = f"""
            MATCH (n:{base_label} {{group_id: $group_id}})
            CALL {{
                WITH n
                SET n:$($group_label)
            }} IN TRANSACTIONS OF $batch_size ROWS"""

        # One session per node type, matching the original migration.
        async with driver.session() as session:
            await session.run(
                query,
                group_id=group_id,
                group_label=f'{base_label}_{label_suffix}',
                batch_size=batch_size,
            )
|
||||
|
|
@ -28,13 +28,21 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
|
|||
entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j and FalkorDB
|
||||
case GraphProvider.FALKORDB:
|
||||
return """
|
||||
MERGE (n:Episodic {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
||||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
return """
|
||||
MERGE (n:Episodic {uuid: $uuid})
|
||||
SET n:$($group_label)
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
||||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
|
||||
|
||||
def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
|
||||
|
|
@ -48,7 +56,7 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
|
|||
entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j and FalkorDB
|
||||
case GraphProvider.FALKORDB:
|
||||
return """
|
||||
UNWIND $episodes AS episode
|
||||
MERGE (n:Episodic {uuid: episode.uuid})
|
||||
|
|
@ -56,6 +64,15 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
|
|||
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
return """
|
||||
UNWIND $episodes AS episode
|
||||
MERGE (n:Episodic {uuid: episode.uuid})
|
||||
SET n:$(episode.group_label)
|
||||
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
|
||||
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
|
||||
|
||||
EPISODIC_NODE_RETURN = """
|
||||
|
|
|
|||
|
|
@ -212,6 +212,7 @@ class EpisodicNode(Node):
|
|||
uuid=self.uuid,
|
||||
name=self.name,
|
||||
group_id=self.group_id,
|
||||
group_label='Episodic_' + self.group_id.replace('-', ''),
|
||||
source_description=self.source_description,
|
||||
content=self.content,
|
||||
entity_edges=self.entity_edges,
|
||||
|
|
@ -380,7 +381,7 @@ class EntityNode(Node):
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
|
||||
labels = ':'.join(self.labels + ['Entity'])
|
||||
labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
|
||||
|
||||
result = await driver.execute_query(
|
||||
get_entity_node_save_query(driver.provider, labels),
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from time import time
|
||||
from typing import Any
|
||||
|
|
@ -54,6 +55,7 @@ from graphiti_core.search.search_filters import (
|
|||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
USE_HNSW = os.getenv('USE_HNSW', '').lower() in ('true', '1', 'yes')
|
||||
|
||||
RELEVANT_SCHEMA_LIMIT = 10
|
||||
DEFAULT_MIN_SCORE = 0.6
|
||||
|
|
@ -178,11 +180,11 @@ async def edge_fulltext_search(
|
|||
# Match the edge ids and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as id
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
AND id(e)=id
|
||||
"""
|
||||
UNWIND $ids as id
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
AND id(e)=id
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m
|
||||
|
|
@ -477,11 +479,11 @@ async def node_fulltext_search(
|
|||
# Match the edge ides and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE n.uuid=i.id
|
||||
RETURN
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE n.uuid=i.id
|
||||
RETURN
|
||||
"""
|
||||
+ ENTITY_NODE_RETURN
|
||||
+ """
|
||||
ORDER BY i.score DESC
|
||||
|
|
@ -500,8 +502,14 @@ async def node_fulltext_search(
|
|||
else:
|
||||
return []
|
||||
else:
|
||||
index_name = (
|
||||
'node_name_and_summary'
|
||||
if not USE_HNSW
|
||||
else 'node_name_and_summary_'
|
||||
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
||||
)
|
||||
query = (
|
||||
get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
|
||||
get_nodes_query(driver.provider, index_name, '$query')
|
||||
+ """
|
||||
YIELD node AS n, score
|
||||
WHERE n:Entity AND n.group_id IN $group_ids
|
||||
|
|
@ -585,11 +593,11 @@ async def node_similarity_search(
|
|||
# Match the edge ides and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE id(n)=i.id
|
||||
RETURN
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE id(n)=i.id
|
||||
RETURN
|
||||
"""
|
||||
+ ENTITY_NODE_RETURN
|
||||
+ """
|
||||
ORDER BY i.score DESC
|
||||
|
|
@ -607,6 +615,36 @@ async def node_similarity_search(
|
|||
)
|
||||
else:
|
||||
return []
|
||||
elif driver.provider == GraphProvider.NEO4J and USE_HNSW:
|
||||
index_name = 'group_entity_vector_' + (
|
||||
group_ids[0].replace('-', '') if group_ids is not None else ''
|
||||
)
|
||||
query = (
|
||||
f"""
|
||||
CALL db.index.vector.queryNodes('{index_name}', {limit}, $search_vector) YIELD node AS n, score
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ filter_query
|
||||
+ """
|
||||
AND score > $min_score
|
||||
RETURN
|
||||
"""
|
||||
+ ENTITY_NODE_RETURN
|
||||
+ """
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
)
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
search_vector=search_vector,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
routing_='r',
|
||||
**query_params,
|
||||
)
|
||||
|
||||
else:
|
||||
query = (
|
||||
RUNTIME_QUERY
|
||||
|
|
@ -754,8 +792,14 @@ async def episode_fulltext_search(
|
|||
else:
|
||||
return []
|
||||
else:
|
||||
index_name = (
|
||||
'episode_content'
|
||||
if not USE_HNSW
|
||||
else 'episode_content_'
|
||||
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
||||
)
|
||||
query = (
|
||||
get_nodes_query(driver.provider, 'episode_content', '$query')
|
||||
get_nodes_query(driver.provider, index_name, '$query')
|
||||
+ """
|
||||
YIELD node AS episode, score
|
||||
MATCH (e:Episodic)
|
||||
|
|
|
|||
|
|
@ -116,6 +116,7 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
episodes = [dict(episode) for episode in episodic_nodes]
|
||||
for episode in episodes:
|
||||
episode['source'] = str(episode['source'].value)
|
||||
episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')
|
||||
nodes: list[dict[str, Any]] = []
|
||||
for node in entity_nodes:
|
||||
if node.name_embedding is None:
|
||||
|
|
@ -130,7 +131,9 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
}
|
||||
|
||||
entity_data.update(node.attributes or {})
|
||||
entity_data['labels'] = list(set(node.labels + ['Entity']))
|
||||
entity_data['labels'] = list(
|
||||
set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
|
||||
)
|
||||
nodes.append(entity_data)
|
||||
|
||||
edges: list[dict[str, Any]] = []
|
||||
|
|
|
|||
|
|
@ -114,9 +114,9 @@ async def retrieve_episodes(
|
|||
|
||||
query: LiteralString = (
|
||||
"""
|
||||
MATCH (e:Episodic)
|
||||
WHERE e.valid_at <= $reference_time
|
||||
"""
|
||||
MATCH (e:Episodic)
|
||||
WHERE e.valid_at <= $reference_time
|
||||
"""
|
||||
+ group_id_filter
|
||||
+ source_filter
|
||||
+ """
|
||||
|
|
@ -142,3 +142,46 @@ async def retrieve_episodes(
|
|||
|
||||
episodes = [get_episodic_node_from_record(record) for record in result]
|
||||
return list(reversed(episodes)) # Return in chronological order
|
||||
|
||||
|
||||
async def build_dynamic_indexes(driver: GraphDriver, group_id: str):
    """Ensure the per-group fulltext and vector indexes exist for ``group_id``.

    Only acts when the driver's provider is Neo4j; for any other provider the
    function returns without issuing a query. Four ``CREATE ... IF NOT EXISTS``
    statements are issued together through ``semaphore_gather``.

    Args:
        driver: Graph driver the index statements are executed on.
        group_id: Group id the label and index names are derived from (hyphens
            are stripped before it is used).
    """
    # Make sure indices exist for this group_id in Neo4j
    if driver.provider != GraphProvider.NEO4J:
        return

    suffix = group_id.replace('-', '')
    episodic_label = 'Episodic_' + suffix
    entity_label = 'Entity_' + suffix
    community_label = 'Community_' + suffix

    # Labels cannot be passed as query parameters in index DDL, so they are
    # concatenated into the statement; index names are passed as parameters.
    episode_content_index = driver.execute_query(
        """CREATE FULLTEXT INDEX $episode_content IF NOT EXISTS
        FOR (e:"""
        + episodic_label
        + """) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
        episode_content='episode_content_' + suffix,
    )

    node_name_and_summary_index = driver.execute_query(
        """CREATE FULLTEXT INDEX $node_name_and_summary IF NOT EXISTS FOR (n:"""
        + entity_label
        + """) ON EACH [n.name, n.summary, n.group_id]""",
        node_name_and_summary='node_name_and_summary_' + suffix,
    )

    # NOTE(review): unlike the other indexes, this index name is just the label
    # ('Community_<suffix>') rather than 'community_name_<suffix>' - confirm intended.
    community_name_index = driver.execute_query(
        """CREATE FULLTEXT INDEX $community_name IF NOT EXISTS
        FOR (n:"""
        + community_label
        + """) ON EACH [n.name, n.group_id]""",
        community_name=community_label,
    )

    # NOTE(review): the vector index targets `n.embedding` with a fixed dimension
    # of 1024 - confirm both match how entity embeddings are actually stored.
    group_entity_vector_index = driver.execute_query(
        """CREATE VECTOR INDEX $group_entity_vector IF NOT EXISTS
        FOR (n:"""
        + entity_label
        + """)
        ON n.embedding
        OPTIONS { indexConfig: {
        `vector.dimensions`: 1024,
        `vector.similarity_function`: 'cosine'
        }}""",
        group_entity_vector='group_entity_vector_' + suffix,
    )

    await semaphore_gather(
        episode_content_index,
        node_name_and_summary_index,
        community_name_index,
        group_entity_vector_index,
    )
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.19.0"
|
||||
version = "0.20.0pre1"
|
||||
authors = [
|
||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||
|
|
|
|||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "graphiti-core"
|
||||
version = "0.19.0"
|
||||
version = "0.20.0rc1"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "diskcache" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue