parent
c0fcc82ebe
commit
1f5a1b890c
11 changed files with 64 additions and 284 deletions
|
|
@ -89,7 +89,6 @@ from graphiti_core.utils.maintenance.edge_operations import (
|
|||
)
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import (
|
||||
EPISODE_WINDOW_LEN,
|
||||
build_dynamic_indexes,
|
||||
build_indices_and_constraints,
|
||||
retrieve_episodes,
|
||||
)
|
||||
|
|
@ -451,7 +450,6 @@ class Graphiti:
|
|||
|
||||
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
||||
validate_group_id(group_id)
|
||||
await build_dynamic_indexes(self.driver, group_id)
|
||||
|
||||
previous_episodes = (
|
||||
await self.retrieve_episodes(
|
||||
|
|
|
|||
|
|
@ -1,114 +0,0 @@
|
|||
import asyncio
|
||||
import csv
|
||||
import os
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
||||
from graphiti_core.helpers import validate_group_id
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import build_dynamic_indexes
|
||||
|
||||
|
||||
async def neo4j_node_group_labels(driver: GraphDriver, group_id: str, batch_size: int = 100):
    """Attach the per-group label to every node that belongs to ``group_id``.

    For each of the base labels ``Episodic``, ``Entity`` and ``Community`` the
    matching nodes receive an additional label ``<Base>_<group_id without '-'>``.
    The per-group indexes are created first via ``build_dynamic_indexes``.

    Args:
        driver: Graph driver used to open sessions.
        group_id: Group whose nodes are relabelled; checked with ``validate_group_id``.
        batch_size: Number of rows per inner transaction
            (``IN TRANSACTIONS OF $batch_size ROWS``).
    """
    validate_group_id(group_id)
    await build_dynamic_indexes(driver, group_id)

    # Dashes are stripped from the group id before it is used in a label name.
    label_suffix = group_id.replace('-', '')

    # One query shape for all three node kinds; only the base label varies.
    query_template = """
        MATCH (n:%s {group_id: $group_id})
        CALL {
            WITH n
            SET n:$($group_label)
        } IN TRANSACTIONS OF $batch_size ROWS"""

    # Same order as before: episodes, then entities, then communities,
    # each in its own session.
    for base_label in ('Episodic', 'Entity', 'Community'):
        async with driver.session() as session:
            await session.run(
                query_template % base_label,
                group_id=group_id,
                group_label=base_label + '_' + label_suffix,
                batch_size=batch_size,
            )
|
||||
|
||||
|
||||
def pop_last_n_group_ids(
    csv_file: str = 'group_ids.csv', count: int = 10
) -> tuple[list[str], int]:
    """Remove the last ``count`` group ids from a one-column CSV file.

    The file is read, the trailing ``count`` ids are taken off, and the file is
    rewritten with the ids that remain.

    Args:
        csv_file: Path of the CSV file holding one group id per row.
        count: How many ids to pop from the end. Values ``<= 0`` pop nothing;
            values larger than the file pop everything.

    Returns:
        A ``(popped, total_count)`` tuple where ``popped`` is the list of removed
        ids (in file order) and ``total_count`` is the number of ids that were in
        the file before popping.
    """
    with open(csv_file, newline='') as file:
        # Blank lines come back from csv.reader as empty rows; skip them
        # instead of failing on row[0].
        group_ids = [row[0] for row in csv.reader(file) if row]

    total_count = len(group_ids)

    # Use an explicit split point: slicing with ``-count`` is wrong for
    # count == 0 ([-0:] is the whole list and [:-0] is empty), which would
    # pop every id and wipe the file.
    split = max(total_count - max(count, 0), 0)
    popped = group_ids[split:]
    remaining = group_ids[:split]

    with open(csv_file, 'w', newline='') as file:
        csv.writer(file).writerows([gid] for gid in remaining)

    return popped, total_count
|
||||
|
||||
|
||||
async def get_group_ids(driver: GraphDriver, csv_file: str = 'group_ids.csv'):
    """Write every distinct ``group_id`` found on Episodic nodes to a CSV file.

    Args:
        driver: Graph driver whose ``execute_query`` runs the lookup.
        csv_file: Destination path; one group id is written per row. Defaults
            to ``'group_ids.csv'``, the same default used by
            ``pop_last_n_group_ids``.
    """
    query = """MATCH (n:Episodic)
    RETURN DISTINCT n.group_id AS group_id"""

    # execute_query returns a 3-tuple; only the records are needed here.
    results, _, _ = await driver.execute_query(query)

    # Any existing file at csv_file is overwritten.
    with open(csv_file, 'w', newline='') as file:
        csv.writer(file).writerows([result['group_id']] for result in results)
|
||||
|
||||
|
||||
async def neo4j_node_label_migration(driver: GraphDriver, batch_size: int = 10):
    """Relabel all groups listed in ``group_ids.csv``, ``batch_size`` groups at a time.

    Group ids are popped from the end of the CSV file in chunks; each chunk is
    processed concurrently with ``neo4j_node_group_labels``. The loop ends when
    the file yields no more ids.

    Args:
        driver: Graph driver passed through to ``neo4j_node_group_labels``.
        batch_size: Number of group ids popped and processed concurrently per
            round. This is not the per-transaction row batch size, which
            ``neo4j_node_group_labels`` keeps at its own default.
    """
    while True:
        # NOTE: ids are removed from the file before they are processed, so a
        # failure inside the gather below leaves that chunk absent from the file.
        group_ids, _ = pop_last_n_group_ids(csv_file='group_ids.csv', count=batch_size)
        if not group_ids:
            break
        await asyncio.gather(*[neo4j_node_group_labels(driver, group_id) for group_id in group_ids])
|
||||
|
||||
|
||||
async def main():
    """Run the node-label migration against the Neo4j instance from the environment.

    Connection settings are read from ``NEO4J_URI``, ``NEO4J_USER`` and
    ``NEO4J_PASSWORD``; unset or empty values fall back to the local defaults
    below. The group ids are first dumped to ``group_ids.csv`` and then
    migrated from that file.
    """
    neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
    neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
    neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'

    driver = Neo4jDriver(
        uri=neo4j_uri,
        user=neo4j_user,
        password=neo4j_password,
    )
    try:
        await get_group_ids(driver)
        await neo4j_node_label_migration(driver)
    finally:
        # Close the driver even when a migration step raises.
        await driver.close()


if __name__ == '__main__':
    asyncio.run(main())
|
||||
|
|
@ -52,7 +52,6 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
|
|||
case _: # Neo4j
|
||||
return """
|
||||
MERGE (n:Episodic {uuid: $uuid})
|
||||
SET n:$($group_label)
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
||||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
|
|
@ -96,7 +95,6 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
|
|||
return """
|
||||
UNWIND $episodes AS episode
|
||||
MERGE (n:Episodic {uuid: episode.uuid})
|
||||
SET n:$(episode.group_label)
|
||||
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
|
||||
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
|
|
|
|||
|
|
@ -299,9 +299,6 @@ class EpisodicNode(Node):
|
|||
'source': self.source.value,
|
||||
}
|
||||
|
||||
if driver.provider == GraphProvider.NEO4J:
|
||||
episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '')
|
||||
|
||||
result = await driver.execute_query(
|
||||
get_episode_node_save_query(driver.provider), **episode_args
|
||||
)
|
||||
|
|
@ -471,7 +468,7 @@ class EntityNode(Node):
|
|||
)
|
||||
else:
|
||||
entity_data.update(self.attributes or {})
|
||||
labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
|
||||
labels = ':'.join(self.labels + ['Entity'])
|
||||
|
||||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
|
|
|
|||
|
|
@ -325,9 +325,7 @@ async def node_search(
|
|||
search_tasks = []
|
||||
if NodeSearchMethod.bm25 in config.search_methods:
|
||||
search_tasks.append(
|
||||
node_fulltext_search(
|
||||
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
|
||||
)
|
||||
node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
|
||||
)
|
||||
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
||||
search_tasks.append(
|
||||
|
|
@ -338,7 +336,6 @@ async def node_search(
|
|||
group_ids,
|
||||
2 * limit,
|
||||
config.sim_min_score,
|
||||
config.use_local_indexes,
|
||||
)
|
||||
)
|
||||
if NodeSearchMethod.bfs in config.search_methods:
|
||||
|
|
@ -434,9 +431,7 @@ async def episode_search(
|
|||
search_results: list[list[EpisodicNode]] = list(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
episode_fulltext_search(
|
||||
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
|
||||
),
|
||||
episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ from graphiti_core.search.search_utils import (
|
|||
DEFAULT_MIN_SCORE,
|
||||
DEFAULT_MMR_LAMBDA,
|
||||
MAX_SEARCH_DEPTH,
|
||||
USE_HNSW,
|
||||
)
|
||||
|
||||
DEFAULT_SEARCH_LIMIT = 10
|
||||
|
|
@ -92,7 +91,6 @@ class NodeSearchConfig(BaseModel):
|
|||
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
||||
use_local_indexes: bool = Field(default=USE_HNSW)
|
||||
|
||||
|
||||
class EpisodeSearchConfig(BaseModel):
|
||||
|
|
@ -101,7 +99,6 @@ class EpisodeSearchConfig(BaseModel):
|
|||
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
||||
use_local_indexes: bool = Field(default=USE_HNSW)
|
||||
|
||||
|
||||
class CommunitySearchConfig(BaseModel):
|
||||
|
|
@ -110,7 +107,6 @@ class CommunitySearchConfig(BaseModel):
|
|||
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
||||
use_local_indexes: bool = Field(default=USE_HNSW)
|
||||
|
||||
|
||||
class SearchConfig(BaseModel):
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ limitations under the License.
|
|||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from time import time
|
||||
from typing import Any
|
||||
|
|
@ -57,7 +56,6 @@ from graphiti_core.search.search_filters import (
|
|||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
USE_HNSW = os.getenv('USE_HNSW', '').lower() in ('true', '1', 'yes')
|
||||
|
||||
RELEVANT_SCHEMA_LIMIT = 10
|
||||
DEFAULT_MIN_SCORE = 0.6
|
||||
|
|
@ -210,11 +208,11 @@ async def edge_fulltext_search(
|
|||
# Match the edge ids and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as id
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
AND id(e)=id
|
||||
"""
|
||||
UNWIND $ids as id
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
AND id(e)=id
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
AND id(e)=id
|
||||
|
|
@ -320,8 +318,8 @@ async def edge_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
||||
|
|
@ -540,7 +538,6 @@ async def node_fulltext_search(
|
|||
search_filter: SearchFilters,
|
||||
group_ids: list[str] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
use_local_indexes: bool = False,
|
||||
) -> list[EntityNode]:
|
||||
# BM25 search to get top nodes
|
||||
fuzzy_query = fulltext_query(query, group_ids, driver)
|
||||
|
|
@ -574,11 +571,11 @@ async def node_fulltext_search(
|
|||
# Match the edge ides and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE n.uuid=i.id
|
||||
RETURN
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE n.uuid=i.id
|
||||
RETURN
|
||||
"""
|
||||
+ get_entity_node_return_query(driver.provider)
|
||||
+ """
|
||||
ORDER BY i.score DESC
|
||||
|
|
@ -596,14 +593,10 @@ async def node_fulltext_search(
|
|||
else:
|
||||
return []
|
||||
else:
|
||||
index_name = (
|
||||
'node_name_and_summary'
|
||||
if not use_local_indexes
|
||||
else 'node_name_and_summary_'
|
||||
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
||||
)
|
||||
query = (
|
||||
get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider)
|
||||
get_nodes_query(
|
||||
'node_name_and_summary', '$query', limit=limit, provider=driver.provider
|
||||
)
|
||||
+ yield_query
|
||||
+ filter_query
|
||||
+ """
|
||||
|
|
@ -635,7 +628,6 @@ async def node_similarity_search(
|
|||
group_ids: list[str] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
min_score: float = DEFAULT_MIN_SCORE,
|
||||
use_local_indexes: bool = False,
|
||||
) -> list[EntityNode]:
|
||||
filter_queries, filter_params = node_search_filter_query_constructor(
|
||||
search_filter, driver.provider
|
||||
|
|
@ -656,8 +648,8 @@ async def node_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||
|
|
@ -686,11 +678,11 @@ async def node_similarity_search(
|
|||
# Match the edge ides and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE id(n)=i.id
|
||||
RETURN
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE id(n)=i.id
|
||||
RETURN
|
||||
"""
|
||||
+ get_entity_node_return_query(driver.provider)
|
||||
+ """
|
||||
ORDER BY i.score DESC
|
||||
|
|
@ -708,40 +700,11 @@ async def node_similarity_search(
|
|||
)
|
||||
else:
|
||||
return []
|
||||
elif driver.provider == GraphProvider.NEO4J and use_local_indexes:
|
||||
index_name = 'group_entity_vector_' + (
|
||||
group_ids[0].replace('-', '') if group_ids is not None else ''
|
||||
)
|
||||
query = (
|
||||
f"""
|
||||
CALL db.index.vector.queryNodes('{index_name}', {limit}, $search_vector) YIELD node AS n, score
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
AND score > $min_score
|
||||
RETURN
|
||||
"""
|
||||
+ get_entity_node_return_query(driver.provider)
|
||||
+ """
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
)
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
search_vector=search_vector,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
routing_='r',
|
||||
**filter_params,
|
||||
)
|
||||
|
||||
else:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH n, """
|
||||
|
|
@ -865,7 +828,6 @@ async def episode_fulltext_search(
|
|||
_search_filter: SearchFilters,
|
||||
group_ids: list[str] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
use_local_indexes: bool = False,
|
||||
) -> list[EpisodicNode]:
|
||||
# BM25 search to get top episodes
|
||||
fuzzy_query = fulltext_query(query, group_ids, driver)
|
||||
|
|
@ -915,14 +877,8 @@ async def episode_fulltext_search(
|
|||
else:
|
||||
return []
|
||||
else:
|
||||
index_name = (
|
||||
'episode_content'
|
||||
if not use_local_indexes
|
||||
else 'episode_content_'
|
||||
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
||||
)
|
||||
query = (
|
||||
get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider)
|
||||
get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
|
||||
+ """
|
||||
YIELD node AS episode, score
|
||||
MATCH (e:Episodic)
|
||||
|
|
@ -1047,8 +1003,8 @@ async def community_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Community)
|
||||
"""
|
||||
MATCH (n:Community)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||
|
|
@ -1107,8 +1063,8 @@ async def community_similarity_search(
|
|||
|
||||
query = (
|
||||
"""
|
||||
MATCH (c:Community)
|
||||
"""
|
||||
MATCH (c:Community)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ """
|
||||
WITH c,
|
||||
|
|
@ -1250,9 +1206,9 @@ async def get_relevant_nodes(
|
|||
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
||||
query = (
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH node, n, """
|
||||
|
|
@ -1297,9 +1253,9 @@ async def get_relevant_nodes(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH node, n, """
|
||||
|
|
@ -1388,9 +1344,9 @@ async def get_relevant_edges(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge
|
||||
|
|
@ -1460,9 +1416,9 @@ async def get_relevant_edges(
|
|||
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge, n, m, """
|
||||
|
|
@ -1498,9 +1454,9 @@ async def get_relevant_edges(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge, """
|
||||
|
|
@ -1573,10 +1529,10 @@ async def get_edge_invalidation_candidates(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge
|
||||
|
|
@ -1646,10 +1602,10 @@ async def get_edge_invalidation_candidates(
|
|||
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
||||
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
||||
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH edge, e, n, m, """
|
||||
|
|
@ -1685,10 +1641,10 @@ async def get_edge_invalidation_candidates(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH edge, e, """
|
||||
|
|
|
|||
|
|
@ -119,8 +119,6 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
for episode in episodes:
|
||||
episode['source'] = str(episode['source'].value)
|
||||
episode.pop('labels', None)
|
||||
if driver.provider == GraphProvider.NEO4J:
|
||||
episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')
|
||||
|
||||
nodes = []
|
||||
|
||||
|
|
@ -143,9 +141,6 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
entity_data['attributes'] = json.dumps(attributes)
|
||||
else:
|
||||
entity_data.update(node.attributes or {})
|
||||
entity_data['labels'] = list(
|
||||
set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
|
||||
)
|
||||
|
||||
nodes.append(entity_data)
|
||||
|
||||
|
|
|
|||
|
|
@ -149,9 +149,9 @@ async def retrieve_episodes(
|
|||
|
||||
query: LiteralString = (
|
||||
"""
|
||||
MATCH (e:Episodic)
|
||||
WHERE e.valid_at <= $reference_time
|
||||
"""
|
||||
MATCH (e:Episodic)
|
||||
WHERE e.valid_at <= $reference_time
|
||||
"""
|
||||
+ query_filter
|
||||
+ """
|
||||
RETURN
|
||||
|
|
@ -175,44 +175,3 @@ async def retrieve_episodes(
|
|||
|
||||
episodes = [get_episodic_node_from_record(record) for record in result]
|
||||
return list(reversed(episodes)) # Return in chronological order
|
||||
|
||||
|
||||
async def build_dynamic_indexes(driver: GraphDriver, group_id: str):
    """Create the per-group fulltext and vector indexes for ``group_id`` on Neo4j.

    Does nothing for providers other than Neo4j. Every statement uses
    ``IF NOT EXISTS``. Index names and the node labels they cover are derived
    from ``group_id`` with dashes removed:

    * ``episode_content_<id>`` on ``Episodic_<id>`` nodes
    * ``node_name_and_summary_<id>`` on ``Entity_<id>`` nodes
    * ``Community_<id>`` on ``Community_<id>`` nodes
    * ``group_entity_vector_<id>`` (vector) on ``Entity_<id>`` nodes

    Args:
        driver: Graph driver used to execute the index statements.
        group_id: Group the indexes are built for. It is interpolated into the
            label part of the statements, so it is presumably validated by the
            caller (e.g. ``validate_group_id``) before reaching this point.
    """
    if driver.provider != GraphProvider.NEO4J:
        return

    suffix = group_id.replace('-', '')
    episodic_label = 'Episodic_' + suffix
    entity_label = 'Entity_' + suffix
    community_label = 'Community_' + suffix

    await driver.execute_query(
        """CREATE FULLTEXT INDEX $episode_content IF NOT EXISTS
        FOR (e:"""
        + episodic_label
        + """) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
        episode_content='episode_content_' + suffix,
    )
    await driver.execute_query(
        """CREATE FULLTEXT INDEX $node_name_and_summary IF NOT EXISTS FOR (n:"""
        + entity_label
        + """) ON EACH [n.name, n.summary, n.group_id]""",
        node_name_and_summary='node_name_and_summary_' + suffix,
    )
    # NOTE(review): unlike the other indexes, this one is named after the label
    # itself rather than 'community_name_<id>'. Left unchanged so existing
    # deployments do not get a second index; confirm whether that is intended.
    await driver.execute_query(
        """CREATE FULLTEXT INDEX $community_name IF NOT EXISTS
        FOR (n:"""
        + community_label
        + """) ON EACH [n.name, n.group_id]""",
        community_name=community_label,
    )
    # Entity embeddings are read from `name_embedding` by the similarity
    # search queries, so the vector index must cover that property.
    # NOTE(review): the dimension is fixed at 1024 here; confirm it matches the
    # embedder in use.
    await driver.execute_query(
        """CREATE VECTOR INDEX $group_entity_vector IF NOT EXISTS
        FOR (n:"""
        + entity_label
        + """)
        ON n.name_embedding
        OPTIONS { indexConfig: {
            `vector.dimensions`: 1024,
            `vector.similarity_function`: 'cosine'
        }}""",
        group_entity_vector='group_entity_vector_' + suffix,
    )
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.20.1"
|
||||
version = "0.20.2"
|
||||
authors = [
|
||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||
|
|
|
|||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "graphiti-core"
|
||||
version = "0.20.0"
|
||||
version = "0.20.2"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "diskcache" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue