* cleanup

* update

* remove unused imports
This commit is contained in:
Preston Rasmussen 2025-09-05 11:30:46 -04:00 committed by GitHub
parent c0fcc82ebe
commit 1f5a1b890c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 64 additions and 284 deletions

View file

@ -89,7 +89,6 @@ from graphiti_core.utils.maintenance.edge_operations import (
)
from graphiti_core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN,
build_dynamic_indexes,
build_indices_and_constraints,
retrieve_episodes,
)
@ -451,7 +450,6 @@ class Graphiti:
validate_excluded_entity_types(excluded_entity_types, entity_types)
validate_group_id(group_id)
await build_dynamic_indexes(self.driver, group_id)
previous_episodes = (
await self.retrieve_episodes(

View file

@ -1,114 +0,0 @@
import asyncio
import csv
import os
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.helpers import validate_group_id
from graphiti_core.utils.maintenance.graph_data_operations import build_dynamic_indexes
async def neo4j_node_group_labels(driver: GraphDriver, group_id: str, batch_size: int = 100):
validate_group_id(group_id)
await build_dynamic_indexes(driver, group_id)
episode_query = """
MATCH (n:Episodic {group_id: $group_id})
CALL {
WITH n
SET n:$($group_label)
} IN TRANSACTIONS OF $batch_size ROWS"""
entity_query = """
MATCH (n:Entity {group_id: $group_id})
CALL {
WITH n
SET n:$($group_label)
} IN TRANSACTIONS OF $batch_size ROWS"""
community_query = """
MATCH (n:Community {group_id: $group_id})
CALL {
WITH n
SET n:$($group_label)
} IN TRANSACTIONS OF $batch_size ROWS"""
async with driver.session() as session:
await session.run(
episode_query,
group_id=group_id,
group_label='Episodic_' + group_id.replace('-', ''),
batch_size=batch_size,
)
async with driver.session() as session:
await session.run(
entity_query,
group_id=group_id,
group_label='Entity_' + group_id.replace('-', ''),
batch_size=batch_size,
)
async with driver.session() as session:
await session.run(
community_query,
group_id=group_id,
group_label='Community_' + group_id.replace('-', ''),
batch_size=batch_size,
)
def pop_last_n_group_ids(csv_file: str = 'group_ids.csv', count: int = 10):
with open(csv_file) as file:
reader = csv.reader(file)
group_ids = [row[0] for row in reader]
total_count = len(group_ids)
popped = group_ids[-count:]
remaining = group_ids[:-count]
with open(csv_file, 'w', newline='') as file:
writer = csv.writer(file)
for gid in remaining:
writer.writerow([gid])
return popped, total_count
async def get_group_ids(driver: GraphDriver):
query = """MATCH (n:Episodic)
RETURN DISTINCT n.group_id AS group_id"""
results, _, _ = await driver.execute_query(query)
group_ids = [result['group_id'] for result in results]
with open('group_ids.csv', 'w', newline='') as file:
writer = csv.writer(file)
for gid in group_ids:
writer.writerow([gid])
async def neo4j_node_label_migration(driver: GraphDriver, batch_size: int = 10):
group_ids, total = pop_last_n_group_ids(csv_file='group_ids.csv', count=batch_size)
while len(group_ids) > 0:
await asyncio.gather(*[neo4j_node_group_labels(driver, group_id) for group_id in group_ids])
group_ids, _ = pop_last_n_group_ids(csv_file='group_ids.csv', count=batch_size)
async def main():
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
driver = Neo4jDriver(
uri=neo4j_uri,
user=neo4j_user,
password=neo4j_password,
)
await get_group_ids(driver)
await neo4j_node_label_migration(driver)
await driver.close()
if __name__ == '__main__':
asyncio.run(main())

View file

@ -52,7 +52,6 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
case _: # Neo4j
return """
MERGE (n:Episodic {uuid: $uuid})
SET n:$($group_label)
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid
@ -96,7 +95,6 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
return """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
SET n:$(episode.group_label)
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid

View file

@ -299,9 +299,6 @@ class EpisodicNode(Node):
'source': self.source.value,
}
if driver.provider == GraphProvider.NEO4J:
episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '')
result = await driver.execute_query(
get_episode_node_save_query(driver.provider), **episode_args
)
@ -471,7 +468,7 @@ class EntityNode(Node):
)
else:
entity_data.update(self.attributes or {})
labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
labels = ':'.join(self.labels + ['Entity'])
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue

View file

@ -325,9 +325,7 @@ async def node_search(
search_tasks = []
if NodeSearchMethod.bm25 in config.search_methods:
search_tasks.append(
node_fulltext_search(
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
)
node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
)
if NodeSearchMethod.cosine_similarity in config.search_methods:
search_tasks.append(
@ -338,7 +336,6 @@ async def node_search(
group_ids,
2 * limit,
config.sim_min_score,
config.use_local_indexes,
)
)
if NodeSearchMethod.bfs in config.search_methods:
@ -434,9 +431,7 @@ async def episode_search(
search_results: list[list[EpisodicNode]] = list(
await semaphore_gather(
*[
episode_fulltext_search(
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
),
episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
]
)
)

View file

@ -24,7 +24,6 @@ from graphiti_core.search.search_utils import (
DEFAULT_MIN_SCORE,
DEFAULT_MMR_LAMBDA,
MAX_SEARCH_DEPTH,
USE_HNSW,
)
DEFAULT_SEARCH_LIMIT = 10
@ -92,7 +91,6 @@ class NodeSearchConfig(BaseModel):
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
use_local_indexes: bool = Field(default=USE_HNSW)
class EpisodeSearchConfig(BaseModel):
@ -101,7 +99,6 @@ class EpisodeSearchConfig(BaseModel):
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
use_local_indexes: bool = Field(default=USE_HNSW)
class CommunitySearchConfig(BaseModel):
@ -110,7 +107,6 @@ class CommunitySearchConfig(BaseModel):
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
use_local_indexes: bool = Field(default=USE_HNSW)
class SearchConfig(BaseModel):

View file

@ -15,7 +15,6 @@ limitations under the License.
"""
import logging
import os
from collections import defaultdict
from time import time
from typing import Any
@ -57,7 +56,6 @@ from graphiti_core.search.search_filters import (
)
logger = logging.getLogger(__name__)
USE_HNSW = os.getenv('USE_HNSW', '').lower() in ('true', '1', 'yes')
RELEVANT_SCHEMA_LIMIT = 10
DEFAULT_MIN_SCORE = 0.6
@ -210,11 +208,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values
query = (
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
+ filter_query
+ """
AND id(e)=id
@ -320,8 +318,8 @@ async def edge_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
+ filter_query
+ """
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@ -540,7 +538,6 @@ async def node_fulltext_search(
search_filter: SearchFilters,
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
use_local_indexes: bool = False,
) -> list[EntityNode]:
# BM25 search to get top nodes
fuzzy_query = fulltext_query(query, group_ids, driver)
@ -574,11 +571,11 @@ async def node_fulltext_search(
# Match the edge ides and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
@ -596,14 +593,10 @@ async def node_fulltext_search(
else:
return []
else:
index_name = (
'node_name_and_summary'
if not use_local_indexes
else 'node_name_and_summary_'
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
)
query = (
get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider)
get_nodes_query(
'node_name_and_summary', '$query', limit=limit, provider=driver.provider
)
+ yield_query
+ filter_query
+ """
@ -635,7 +628,6 @@ async def node_similarity_search(
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
min_score: float = DEFAULT_MIN_SCORE,
use_local_indexes: bool = False,
) -> list[EntityNode]:
filter_queries, filter_params = node_search_filter_query_constructor(
search_filter, driver.provider
@ -656,8 +648,8 @@ async def node_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Entity)
"""
MATCH (n:Entity)
"""
+ filter_query
+ """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -686,11 +678,11 @@ async def node_similarity_search(
# Match the edge ides and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
@ -708,40 +700,11 @@ async def node_similarity_search(
)
else:
return []
elif driver.provider == GraphProvider.NEO4J and use_local_indexes:
index_name = 'group_entity_vector_' + (
group_ids[0].replace('-', '') if group_ids is not None else ''
)
query = (
f"""
CALL db.index.vector.queryNodes('{index_name}', {limit}, $search_vector) YIELD node AS n, score
"""
+ filter_query
+ """
AND score > $min_score
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY score DESC
LIMIT $limit
"""
)
records, _, _ = await driver.execute_query(
query,
search_vector=search_vector,
limit=limit,
min_score=min_score,
routing_='r',
**filter_params,
)
else:
query = (
"""
MATCH (n:Entity)
"""
MATCH (n:Entity)
"""
+ filter_query
+ """
WITH n, """
@ -865,7 +828,6 @@ async def episode_fulltext_search(
_search_filter: SearchFilters,
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
use_local_indexes: bool = False,
) -> list[EpisodicNode]:
# BM25 search to get top episodes
fuzzy_query = fulltext_query(query, group_ids, driver)
@ -915,14 +877,8 @@ async def episode_fulltext_search(
else:
return []
else:
index_name = (
'episode_content'
if not use_local_indexes
else 'episode_content_'
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
)
query = (
get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider)
get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
+ """
YIELD node AS episode, score
MATCH (e:Episodic)
@ -1047,8 +1003,8 @@ async def community_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Community)
"""
MATCH (n:Community)
"""
+ group_filter_query
+ """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -1107,8 +1063,8 @@ async def community_similarity_search(
query = (
"""
MATCH (c:Community)
"""
MATCH (c:Community)
"""
+ group_filter_query
+ """
WITH c,
@ -1250,9 +1206,9 @@ async def get_relevant_nodes(
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
query = (
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ filter_query
+ """
WITH node, n, """
@ -1297,9 +1253,9 @@ async def get_relevant_nodes(
else:
query = (
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ filter_query
+ """
WITH node, n, """
@ -1388,9 +1344,9 @@ async def get_relevant_edges(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge
@ -1460,9 +1416,9 @@ async def get_relevant_edges(
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge, n, m, """
@ -1498,9 +1454,9 @@ async def get_relevant_edges(
else:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge, """
@ -1573,10 +1529,10 @@ async def get_edge_invalidation_candidates(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH e, edge
@ -1646,10 +1602,10 @@ async def get_edge_invalidation_candidates(
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
"""
+ filter_query
+ """
WITH edge, e, n, m, """
@ -1685,10 +1641,10 @@ async def get_edge_invalidation_candidates(
else:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH edge, e, """

View file

@ -119,8 +119,6 @@ async def add_nodes_and_edges_bulk_tx(
for episode in episodes:
episode['source'] = str(episode['source'].value)
episode.pop('labels', None)
if driver.provider == GraphProvider.NEO4J:
episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')
nodes = []
@ -143,9 +141,6 @@ async def add_nodes_and_edges_bulk_tx(
entity_data['attributes'] = json.dumps(attributes)
else:
entity_data.update(node.attributes or {})
entity_data['labels'] = list(
set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
)
nodes.append(entity_data)

View file

@ -149,9 +149,9 @@ async def retrieve_episodes(
query: LiteralString = (
"""
MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time
"""
MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time
"""
+ query_filter
+ """
RETURN
@ -175,44 +175,3 @@ async def retrieve_episodes(
episodes = [get_episodic_node_from_record(record) for record in result]
return list(reversed(episodes)) # Return in chronological order
async def build_dynamic_indexes(driver: GraphDriver, group_id: str):
# Make sure indices exist for this group_id in Neo4j
if driver.provider == GraphProvider.NEO4J:
await driver.execute_query(
"""CREATE FULLTEXT INDEX $episode_content IF NOT EXISTS
FOR (e:"""
+ 'Episodic_'
+ group_id.replace('-', '')
+ """) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
episode_content='episode_content_' + group_id.replace('-', ''),
)
await driver.execute_query(
"""CREATE FULLTEXT INDEX $node_name_and_summary IF NOT EXISTS FOR (n:"""
+ 'Entity_'
+ group_id.replace('-', '')
+ """) ON EACH [n.name, n.summary, n.group_id]""",
node_name_and_summary='node_name_and_summary_' + group_id.replace('-', ''),
)
await driver.execute_query(
"""CREATE FULLTEXT INDEX $community_name IF NOT EXISTS
FOR (n:"""
+ 'Community_'
+ group_id.replace('-', '')
+ """) ON EACH [n.name, n.group_id]""",
community_name='Community_' + group_id.replace('-', ''),
)
await driver.execute_query(
"""CREATE VECTOR INDEX $group_entity_vector IF NOT EXISTS
FOR (n:"""
+ 'Entity_'
+ group_id.replace('-', '')
+ """)
ON n.embedding
OPTIONS { indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}""",
group_entity_vector='group_entity_vector_' + group_id.replace('-', ''),
)

View file

@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.20.1"
version = "0.20.2"
authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" },

2
uv.lock generated
View file

@ -783,7 +783,7 @@ wheels = [
[[package]]
name = "graphiti-core"
version = "0.20.0"
version = "0.20.2"
source = { editable = "." }
dependencies = [
{ name = "diskcache" },