* cleanup

* update

* remove unused imports
This commit is contained in:
Preston Rasmussen 2025-09-05 11:30:46 -04:00 committed by GitHub
parent c0fcc82ebe
commit 1f5a1b890c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 64 additions and 284 deletions

View file

@ -89,7 +89,6 @@ from graphiti_core.utils.maintenance.edge_operations import (
) )
from graphiti_core.utils.maintenance.graph_data_operations import ( from graphiti_core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN, EPISODE_WINDOW_LEN,
build_dynamic_indexes,
build_indices_and_constraints, build_indices_and_constraints,
retrieve_episodes, retrieve_episodes,
) )
@ -451,7 +450,6 @@ class Graphiti:
validate_excluded_entity_types(excluded_entity_types, entity_types) validate_excluded_entity_types(excluded_entity_types, entity_types)
validate_group_id(group_id) validate_group_id(group_id)
await build_dynamic_indexes(self.driver, group_id)
previous_episodes = ( previous_episodes = (
await self.retrieve_episodes( await self.retrieve_episodes(

View file

@ -1,114 +0,0 @@
import asyncio
import csv
import os
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.helpers import validate_group_id
from graphiti_core.utils.maintenance.graph_data_operations import build_dynamic_indexes
async def neo4j_node_group_labels(driver: GraphDriver, group_id: str, batch_size: int = 100):
validate_group_id(group_id)
await build_dynamic_indexes(driver, group_id)
episode_query = """
MATCH (n:Episodic {group_id: $group_id})
CALL {
WITH n
SET n:$($group_label)
} IN TRANSACTIONS OF $batch_size ROWS"""
entity_query = """
MATCH (n:Entity {group_id: $group_id})
CALL {
WITH n
SET n:$($group_label)
} IN TRANSACTIONS OF $batch_size ROWS"""
community_query = """
MATCH (n:Community {group_id: $group_id})
CALL {
WITH n
SET n:$($group_label)
} IN TRANSACTIONS OF $batch_size ROWS"""
async with driver.session() as session:
await session.run(
episode_query,
group_id=group_id,
group_label='Episodic_' + group_id.replace('-', ''),
batch_size=batch_size,
)
async with driver.session() as session:
await session.run(
entity_query,
group_id=group_id,
group_label='Entity_' + group_id.replace('-', ''),
batch_size=batch_size,
)
async with driver.session() as session:
await session.run(
community_query,
group_id=group_id,
group_label='Community_' + group_id.replace('-', ''),
batch_size=batch_size,
)
def pop_last_n_group_ids(csv_file: str = 'group_ids.csv', count: int = 10):
with open(csv_file) as file:
reader = csv.reader(file)
group_ids = [row[0] for row in reader]
total_count = len(group_ids)
popped = group_ids[-count:]
remaining = group_ids[:-count]
with open(csv_file, 'w', newline='') as file:
writer = csv.writer(file)
for gid in remaining:
writer.writerow([gid])
return popped, total_count
async def get_group_ids(driver: GraphDriver):
query = """MATCH (n:Episodic)
RETURN DISTINCT n.group_id AS group_id"""
results, _, _ = await driver.execute_query(query)
group_ids = [result['group_id'] for result in results]
with open('group_ids.csv', 'w', newline='') as file:
writer = csv.writer(file)
for gid in group_ids:
writer.writerow([gid])
async def neo4j_node_label_migration(driver: GraphDriver, batch_size: int = 10):
group_ids, total = pop_last_n_group_ids(csv_file='group_ids.csv', count=batch_size)
while len(group_ids) > 0:
await asyncio.gather(*[neo4j_node_group_labels(driver, group_id) for group_id in group_ids])
group_ids, _ = pop_last_n_group_ids(csv_file='group_ids.csv', count=batch_size)
async def main():
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
driver = Neo4jDriver(
uri=neo4j_uri,
user=neo4j_user,
password=neo4j_password,
)
await get_group_ids(driver)
await neo4j_node_label_migration(driver)
await driver.close()
if __name__ == '__main__':
asyncio.run(main())

View file

@ -52,7 +52,6 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
case _: # Neo4j case _: # Neo4j
return """ return """
MERGE (n:Episodic {uuid: $uuid}) MERGE (n:Episodic {uuid: $uuid})
SET n:$($group_label)
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content, SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at} entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid RETURN n.uuid AS uuid
@ -96,7 +95,6 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
return """ return """
UNWIND $episodes AS episode UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid}) MERGE (n:Episodic {uuid: episode.uuid})
SET n:$(episode.group_label)
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content, SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at} entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid RETURN n.uuid AS uuid

View file

@ -299,9 +299,6 @@ class EpisodicNode(Node):
'source': self.source.value, 'source': self.source.value,
} }
if driver.provider == GraphProvider.NEO4J:
episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '')
result = await driver.execute_query( result = await driver.execute_query(
get_episode_node_save_query(driver.provider), **episode_args get_episode_node_save_query(driver.provider), **episode_args
) )
@ -471,7 +468,7 @@ class EntityNode(Node):
) )
else: else:
entity_data.update(self.attributes or {}) entity_data.update(self.attributes or {})
labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')]) labels = ':'.join(self.labels + ['Entity'])
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue

View file

@ -325,9 +325,7 @@ async def node_search(
search_tasks = [] search_tasks = []
if NodeSearchMethod.bm25 in config.search_methods: if NodeSearchMethod.bm25 in config.search_methods:
search_tasks.append( search_tasks.append(
node_fulltext_search( node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
)
) )
if NodeSearchMethod.cosine_similarity in config.search_methods: if NodeSearchMethod.cosine_similarity in config.search_methods:
search_tasks.append( search_tasks.append(
@ -338,7 +336,6 @@ async def node_search(
group_ids, group_ids,
2 * limit, 2 * limit,
config.sim_min_score, config.sim_min_score,
config.use_local_indexes,
) )
) )
if NodeSearchMethod.bfs in config.search_methods: if NodeSearchMethod.bfs in config.search_methods:
@ -434,9 +431,7 @@ async def episode_search(
search_results: list[list[EpisodicNode]] = list( search_results: list[list[EpisodicNode]] = list(
await semaphore_gather( await semaphore_gather(
*[ *[
episode_fulltext_search( episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
),
] ]
) )
) )

View file

@ -24,7 +24,6 @@ from graphiti_core.search.search_utils import (
DEFAULT_MIN_SCORE, DEFAULT_MIN_SCORE,
DEFAULT_MMR_LAMBDA, DEFAULT_MMR_LAMBDA,
MAX_SEARCH_DEPTH, MAX_SEARCH_DEPTH,
USE_HNSW,
) )
DEFAULT_SEARCH_LIMIT = 10 DEFAULT_SEARCH_LIMIT = 10
@ -92,7 +91,6 @@ class NodeSearchConfig(BaseModel):
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE) sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA) mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH) bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
use_local_indexes: bool = Field(default=USE_HNSW)
class EpisodeSearchConfig(BaseModel): class EpisodeSearchConfig(BaseModel):
@ -101,7 +99,6 @@ class EpisodeSearchConfig(BaseModel):
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE) sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA) mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH) bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
use_local_indexes: bool = Field(default=USE_HNSW)
class CommunitySearchConfig(BaseModel): class CommunitySearchConfig(BaseModel):
@ -110,7 +107,6 @@ class CommunitySearchConfig(BaseModel):
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE) sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA) mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH) bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
use_local_indexes: bool = Field(default=USE_HNSW)
class SearchConfig(BaseModel): class SearchConfig(BaseModel):

View file

@ -15,7 +15,6 @@ limitations under the License.
""" """
import logging import logging
import os
from collections import defaultdict from collections import defaultdict
from time import time from time import time
from typing import Any from typing import Any
@ -57,7 +56,6 @@ from graphiti_core.search.search_filters import (
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
USE_HNSW = os.getenv('USE_HNSW', '').lower() in ('true', '1', 'yes')
RELEVANT_SCHEMA_LIMIT = 10 RELEVANT_SCHEMA_LIMIT = 10
DEFAULT_MIN_SCORE = 0.6 DEFAULT_MIN_SCORE = 0.6
@ -210,11 +208,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values # Match the edge ids and return the values
query = ( query = (
""" """
UNWIND $ids as id UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids WHERE e.group_id IN $group_ids
AND id(e)=id AND id(e)=id
""" """
+ filter_query + filter_query
+ """ + """
AND id(e)=id AND id(e)=id
@ -320,8 +318,8 @@ async def edge_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@ -540,7 +538,6 @@ async def node_fulltext_search(
search_filter: SearchFilters, search_filter: SearchFilters,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
use_local_indexes: bool = False,
) -> list[EntityNode]: ) -> list[EntityNode]:
# BM25 search to get top nodes # BM25 search to get top nodes
fuzzy_query = fulltext_query(query, group_ids, driver) fuzzy_query = fulltext_query(query, group_ids, driver)
@ -574,11 +571,11 @@ async def node_fulltext_search(
# Match the edge ides and return the values # Match the edge ides and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE n.uuid=i.id WHERE n.uuid=i.id
RETURN RETURN
""" """
+ get_entity_node_return_query(driver.provider) + get_entity_node_return_query(driver.provider)
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -596,14 +593,10 @@ async def node_fulltext_search(
else: else:
return [] return []
else: else:
index_name = (
'node_name_and_summary'
if not use_local_indexes
else 'node_name_and_summary_'
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
)
query = ( query = (
get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider) get_nodes_query(
'node_name_and_summary', '$query', limit=limit, provider=driver.provider
)
+ yield_query + yield_query
+ filter_query + filter_query
+ """ + """
@ -635,7 +628,6 @@ async def node_similarity_search(
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
min_score: float = DEFAULT_MIN_SCORE, min_score: float = DEFAULT_MIN_SCORE,
use_local_indexes: bool = False,
) -> list[EntityNode]: ) -> list[EntityNode]:
filter_queries, filter_params = node_search_filter_query_constructor( filter_queries, filter_params = node_search_filter_query_constructor(
search_filter, driver.provider search_filter, driver.provider
@ -656,8 +648,8 @@ async def node_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Entity) MATCH (n:Entity)
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -686,11 +678,11 @@ async def node_similarity_search(
# Match the edge ides and return the values # Match the edge ides and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE id(n)=i.id WHERE id(n)=i.id
RETURN RETURN
""" """
+ get_entity_node_return_query(driver.provider) + get_entity_node_return_query(driver.provider)
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -708,40 +700,11 @@ async def node_similarity_search(
) )
else: else:
return [] return []
elif driver.provider == GraphProvider.NEO4J and use_local_indexes:
index_name = 'group_entity_vector_' + (
group_ids[0].replace('-', '') if group_ids is not None else ''
)
query = (
f"""
CALL db.index.vector.queryNodes('{index_name}', {limit}, $search_vector) YIELD node AS n, score
"""
+ filter_query
+ """
AND score > $min_score
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY score DESC
LIMIT $limit
"""
)
records, _, _ = await driver.execute_query(
query,
search_vector=search_vector,
limit=limit,
min_score=min_score,
routing_='r',
**filter_params,
)
else: else:
query = ( query = (
""" """
MATCH (n:Entity) MATCH (n:Entity)
""" """
+ filter_query + filter_query
+ """ + """
WITH n, """ WITH n, """
@ -865,7 +828,6 @@ async def episode_fulltext_search(
_search_filter: SearchFilters, _search_filter: SearchFilters,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
use_local_indexes: bool = False,
) -> list[EpisodicNode]: ) -> list[EpisodicNode]:
# BM25 search to get top episodes # BM25 search to get top episodes
fuzzy_query = fulltext_query(query, group_ids, driver) fuzzy_query = fulltext_query(query, group_ids, driver)
@ -915,14 +877,8 @@ async def episode_fulltext_search(
else: else:
return [] return []
else: else:
index_name = (
'episode_content'
if not use_local_indexes
else 'episode_content_'
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
)
query = ( query = (
get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider) get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
+ """ + """
YIELD node AS episode, score YIELD node AS episode, score
MATCH (e:Episodic) MATCH (e:Episodic)
@ -1047,8 +1003,8 @@ async def community_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Community) MATCH (n:Community)
""" """
+ group_filter_query + group_filter_query
+ """ + """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -1107,8 +1063,8 @@ async def community_similarity_search(
query = ( query = (
""" """
MATCH (c:Community) MATCH (c:Community)
""" """
+ group_filter_query + group_filter_query
+ """ + """
WITH c, WITH c,
@ -1250,9 +1206,9 @@ async def get_relevant_nodes(
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
query = ( query = (
""" """
UNWIND $nodes AS node UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id}) MATCH (n:Entity {group_id: $group_id})
""" """
+ filter_query + filter_query
+ """ + """
WITH node, n, """ WITH node, n, """
@ -1297,9 +1253,9 @@ async def get_relevant_nodes(
else: else:
query = ( query = (
""" """
UNWIND $nodes AS node UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id}) MATCH (n:Entity {group_id: $group_id})
""" """
+ filter_query + filter_query
+ """ + """
WITH node, n, """ WITH node, n, """
@ -1388,9 +1344,9 @@ async def get_relevant_edges(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge WITH e, edge
@ -1460,9 +1416,9 @@ async def get_relevant_edges(
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge, n, m, """ WITH e, edge, n, m, """
@ -1498,9 +1454,9 @@ async def get_relevant_edges(
else: else:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge, """ WITH e, edge, """
@ -1573,10 +1529,10 @@ async def get_edge_invalidation_candidates(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge WITH e, edge
@ -1646,10 +1602,10 @@ async def get_edge_invalidation_candidates(
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
""" """
+ filter_query + filter_query
+ """ + """
WITH edge, e, n, m, """ WITH edge, e, n, m, """
@ -1685,10 +1641,10 @@ async def get_edge_invalidation_candidates(
else: else:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
""" """
+ filter_query + filter_query
+ """ + """
WITH edge, e, """ WITH edge, e, """

View file

@ -119,8 +119,6 @@ async def add_nodes_and_edges_bulk_tx(
for episode in episodes: for episode in episodes:
episode['source'] = str(episode['source'].value) episode['source'] = str(episode['source'].value)
episode.pop('labels', None) episode.pop('labels', None)
if driver.provider == GraphProvider.NEO4J:
episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')
nodes = [] nodes = []
@ -143,9 +141,6 @@ async def add_nodes_and_edges_bulk_tx(
entity_data['attributes'] = json.dumps(attributes) entity_data['attributes'] = json.dumps(attributes)
else: else:
entity_data.update(node.attributes or {}) entity_data.update(node.attributes or {})
entity_data['labels'] = list(
set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
)
nodes.append(entity_data) nodes.append(entity_data)

View file

@ -149,9 +149,9 @@ async def retrieve_episodes(
query: LiteralString = ( query: LiteralString = (
""" """
MATCH (e:Episodic) MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time WHERE e.valid_at <= $reference_time
""" """
+ query_filter + query_filter
+ """ + """
RETURN RETURN
@ -175,44 +175,3 @@ async def retrieve_episodes(
episodes = [get_episodic_node_from_record(record) for record in result] episodes = [get_episodic_node_from_record(record) for record in result]
return list(reversed(episodes)) # Return in chronological order return list(reversed(episodes)) # Return in chronological order
async def build_dynamic_indexes(driver: GraphDriver, group_id: str):
# Make sure indices exist for this group_id in Neo4j
if driver.provider == GraphProvider.NEO4J:
await driver.execute_query(
"""CREATE FULLTEXT INDEX $episode_content IF NOT EXISTS
FOR (e:"""
+ 'Episodic_'
+ group_id.replace('-', '')
+ """) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
episode_content='episode_content_' + group_id.replace('-', ''),
)
await driver.execute_query(
"""CREATE FULLTEXT INDEX $node_name_and_summary IF NOT EXISTS FOR (n:"""
+ 'Entity_'
+ group_id.replace('-', '')
+ """) ON EACH [n.name, n.summary, n.group_id]""",
node_name_and_summary='node_name_and_summary_' + group_id.replace('-', ''),
)
await driver.execute_query(
"""CREATE FULLTEXT INDEX $community_name IF NOT EXISTS
FOR (n:"""
+ 'Community_'
+ group_id.replace('-', '')
+ """) ON EACH [n.name, n.group_id]""",
community_name='Community_' + group_id.replace('-', ''),
)
await driver.execute_query(
"""CREATE VECTOR INDEX $group_entity_vector IF NOT EXISTS
FOR (n:"""
+ 'Entity_'
+ group_id.replace('-', '')
+ """)
ON n.embedding
OPTIONS { indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}""",
group_entity_vector='group_entity_vector_' + group_id.replace('-', ''),
)

View file

@ -1,7 +1,7 @@
[project] [project]
name = "graphiti-core" name = "graphiti-core"
description = "A temporal graph building library" description = "A temporal graph building library"
version = "0.20.1" version = "0.20.2"
authors = [ authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" },

2
uv.lock generated
View file

@ -783,7 +783,7 @@ wheels = [
[[package]] [[package]]
name = "graphiti-core" name = "graphiti-core"
version = "0.20.0" version = "0.20.2"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "diskcache" }, { name = "diskcache" },