parent
c0fcc82ebe
commit
1f5a1b890c
11 changed files with 64 additions and 284 deletions
|
|
@ -89,7 +89,6 @@ from graphiti_core.utils.maintenance.edge_operations import (
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.graph_data_operations import (
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
||||||
EPISODE_WINDOW_LEN,
|
EPISODE_WINDOW_LEN,
|
||||||
build_dynamic_indexes,
|
|
||||||
build_indices_and_constraints,
|
build_indices_and_constraints,
|
||||||
retrieve_episodes,
|
retrieve_episodes,
|
||||||
)
|
)
|
||||||
|
|
@ -451,7 +450,6 @@ class Graphiti:
|
||||||
|
|
||||||
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
||||||
validate_group_id(group_id)
|
validate_group_id(group_id)
|
||||||
await build_dynamic_indexes(self.driver, group_id)
|
|
||||||
|
|
||||||
previous_episodes = (
|
previous_episodes = (
|
||||||
await self.retrieve_episodes(
|
await self.retrieve_episodes(
|
||||||
|
|
|
||||||
|
|
@ -1,114 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import csv
|
|
||||||
import os
|
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver
|
|
||||||
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
|
||||||
from graphiti_core.helpers import validate_group_id
|
|
||||||
from graphiti_core.utils.maintenance.graph_data_operations import build_dynamic_indexes
|
|
||||||
|
|
||||||
|
|
||||||
async def neo4j_node_group_labels(driver: GraphDriver, group_id: str, batch_size: int = 100):
|
|
||||||
validate_group_id(group_id)
|
|
||||||
await build_dynamic_indexes(driver, group_id)
|
|
||||||
|
|
||||||
episode_query = """
|
|
||||||
MATCH (n:Episodic {group_id: $group_id})
|
|
||||||
CALL {
|
|
||||||
WITH n
|
|
||||||
SET n:$($group_label)
|
|
||||||
} IN TRANSACTIONS OF $batch_size ROWS"""
|
|
||||||
|
|
||||||
entity_query = """
|
|
||||||
MATCH (n:Entity {group_id: $group_id})
|
|
||||||
CALL {
|
|
||||||
WITH n
|
|
||||||
SET n:$($group_label)
|
|
||||||
} IN TRANSACTIONS OF $batch_size ROWS"""
|
|
||||||
|
|
||||||
community_query = """
|
|
||||||
MATCH (n:Community {group_id: $group_id})
|
|
||||||
CALL {
|
|
||||||
WITH n
|
|
||||||
SET n:$($group_label)
|
|
||||||
} IN TRANSACTIONS OF $batch_size ROWS"""
|
|
||||||
|
|
||||||
async with driver.session() as session:
|
|
||||||
await session.run(
|
|
||||||
episode_query,
|
|
||||||
group_id=group_id,
|
|
||||||
group_label='Episodic_' + group_id.replace('-', ''),
|
|
||||||
batch_size=batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with driver.session() as session:
|
|
||||||
await session.run(
|
|
||||||
entity_query,
|
|
||||||
group_id=group_id,
|
|
||||||
group_label='Entity_' + group_id.replace('-', ''),
|
|
||||||
batch_size=batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with driver.session() as session:
|
|
||||||
await session.run(
|
|
||||||
community_query,
|
|
||||||
group_id=group_id,
|
|
||||||
group_label='Community_' + group_id.replace('-', ''),
|
|
||||||
batch_size=batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def pop_last_n_group_ids(csv_file: str = 'group_ids.csv', count: int = 10):
|
|
||||||
with open(csv_file) as file:
|
|
||||||
reader = csv.reader(file)
|
|
||||||
group_ids = [row[0] for row in reader]
|
|
||||||
|
|
||||||
total_count = len(group_ids)
|
|
||||||
popped = group_ids[-count:]
|
|
||||||
remaining = group_ids[:-count]
|
|
||||||
|
|
||||||
with open(csv_file, 'w', newline='') as file:
|
|
||||||
writer = csv.writer(file)
|
|
||||||
for gid in remaining:
|
|
||||||
writer.writerow([gid])
|
|
||||||
|
|
||||||
return popped, total_count
|
|
||||||
|
|
||||||
|
|
||||||
async def get_group_ids(driver: GraphDriver):
|
|
||||||
query = """MATCH (n:Episodic)
|
|
||||||
RETURN DISTINCT n.group_id AS group_id"""
|
|
||||||
|
|
||||||
results, _, _ = await driver.execute_query(query)
|
|
||||||
group_ids = [result['group_id'] for result in results]
|
|
||||||
|
|
||||||
with open('group_ids.csv', 'w', newline='') as file:
|
|
||||||
writer = csv.writer(file)
|
|
||||||
for gid in group_ids:
|
|
||||||
writer.writerow([gid])
|
|
||||||
|
|
||||||
|
|
||||||
async def neo4j_node_label_migration(driver: GraphDriver, batch_size: int = 10):
|
|
||||||
group_ids, total = pop_last_n_group_ids(csv_file='group_ids.csv', count=batch_size)
|
|
||||||
while len(group_ids) > 0:
|
|
||||||
await asyncio.gather(*[neo4j_node_group_labels(driver, group_id) for group_id in group_ids])
|
|
||||||
group_ids, _ = pop_last_n_group_ids(csv_file='group_ids.csv', count=batch_size)
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
|
|
||||||
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
|
|
||||||
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
|
|
||||||
|
|
||||||
driver = Neo4jDriver(
|
|
||||||
uri=neo4j_uri,
|
|
||||||
user=neo4j_user,
|
|
||||||
password=neo4j_password,
|
|
||||||
)
|
|
||||||
await get_group_ids(driver)
|
|
||||||
await neo4j_node_label_migration(driver)
|
|
||||||
await driver.close()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
asyncio.run(main())
|
|
||||||
|
|
@ -52,7 +52,6 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
|
||||||
case _: # Neo4j
|
case _: # Neo4j
|
||||||
return """
|
return """
|
||||||
MERGE (n:Episodic {uuid: $uuid})
|
MERGE (n:Episodic {uuid: $uuid})
|
||||||
SET n:$($group_label)
|
|
||||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
||||||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||||
RETURN n.uuid AS uuid
|
RETURN n.uuid AS uuid
|
||||||
|
|
@ -96,7 +95,6 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
|
||||||
return """
|
return """
|
||||||
UNWIND $episodes AS episode
|
UNWIND $episodes AS episode
|
||||||
MERGE (n:Episodic {uuid: episode.uuid})
|
MERGE (n:Episodic {uuid: episode.uuid})
|
||||||
SET n:$(episode.group_label)
|
|
||||||
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
|
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
|
||||||
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
||||||
RETURN n.uuid AS uuid
|
RETURN n.uuid AS uuid
|
||||||
|
|
|
||||||
|
|
@ -299,9 +299,6 @@ class EpisodicNode(Node):
|
||||||
'source': self.source.value,
|
'source': self.source.value,
|
||||||
}
|
}
|
||||||
|
|
||||||
if driver.provider == GraphProvider.NEO4J:
|
|
||||||
episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '')
|
|
||||||
|
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
get_episode_node_save_query(driver.provider), **episode_args
|
get_episode_node_save_query(driver.provider), **episode_args
|
||||||
)
|
)
|
||||||
|
|
@ -471,7 +468,7 @@ class EntityNode(Node):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
entity_data.update(self.attributes or {})
|
entity_data.update(self.attributes or {})
|
||||||
labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
|
labels = ':'.join(self.labels + ['Entity'])
|
||||||
|
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||||
|
|
|
||||||
|
|
@ -325,9 +325,7 @@ async def node_search(
|
||||||
search_tasks = []
|
search_tasks = []
|
||||||
if NodeSearchMethod.bm25 in config.search_methods:
|
if NodeSearchMethod.bm25 in config.search_methods:
|
||||||
search_tasks.append(
|
search_tasks.append(
|
||||||
node_fulltext_search(
|
node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
|
||||||
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
||||||
search_tasks.append(
|
search_tasks.append(
|
||||||
|
|
@ -338,7 +336,6 @@ async def node_search(
|
||||||
group_ids,
|
group_ids,
|
||||||
2 * limit,
|
2 * limit,
|
||||||
config.sim_min_score,
|
config.sim_min_score,
|
||||||
config.use_local_indexes,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if NodeSearchMethod.bfs in config.search_methods:
|
if NodeSearchMethod.bfs in config.search_methods:
|
||||||
|
|
@ -434,9 +431,7 @@ async def episode_search(
|
||||||
search_results: list[list[EpisodicNode]] = list(
|
search_results: list[list[EpisodicNode]] = list(
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
*[
|
*[
|
||||||
episode_fulltext_search(
|
episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
|
||||||
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,6 @@ from graphiti_core.search.search_utils import (
|
||||||
DEFAULT_MIN_SCORE,
|
DEFAULT_MIN_SCORE,
|
||||||
DEFAULT_MMR_LAMBDA,
|
DEFAULT_MMR_LAMBDA,
|
||||||
MAX_SEARCH_DEPTH,
|
MAX_SEARCH_DEPTH,
|
||||||
USE_HNSW,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
DEFAULT_SEARCH_LIMIT = 10
|
DEFAULT_SEARCH_LIMIT = 10
|
||||||
|
|
@ -92,7 +91,6 @@ class NodeSearchConfig(BaseModel):
|
||||||
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||||
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||||
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
||||||
use_local_indexes: bool = Field(default=USE_HNSW)
|
|
||||||
|
|
||||||
|
|
||||||
class EpisodeSearchConfig(BaseModel):
|
class EpisodeSearchConfig(BaseModel):
|
||||||
|
|
@ -101,7 +99,6 @@ class EpisodeSearchConfig(BaseModel):
|
||||||
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||||
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||||
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
||||||
use_local_indexes: bool = Field(default=USE_HNSW)
|
|
||||||
|
|
||||||
|
|
||||||
class CommunitySearchConfig(BaseModel):
|
class CommunitySearchConfig(BaseModel):
|
||||||
|
|
@ -110,7 +107,6 @@ class CommunitySearchConfig(BaseModel):
|
||||||
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||||
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||||
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
||||||
use_local_indexes: bool = Field(default=USE_HNSW)
|
|
||||||
|
|
||||||
|
|
||||||
class SearchConfig(BaseModel):
|
class SearchConfig(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
@ -57,7 +56,6 @@ from graphiti_core.search.search_filters import (
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
USE_HNSW = os.getenv('USE_HNSW', '').lower() in ('true', '1', 'yes')
|
|
||||||
|
|
||||||
RELEVANT_SCHEMA_LIMIT = 10
|
RELEVANT_SCHEMA_LIMIT = 10
|
||||||
DEFAULT_MIN_SCORE = 0.6
|
DEFAULT_MIN_SCORE = 0.6
|
||||||
|
|
@ -210,11 +208,11 @@ async def edge_fulltext_search(
|
||||||
# Match the edge ids and return the values
|
# Match the edge ids and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as id
|
UNWIND $ids as id
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
WHERE e.group_id IN $group_ids
|
WHERE e.group_id IN $group_ids
|
||||||
AND id(e)=id
|
AND id(e)=id
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
AND id(e)=id
|
AND id(e)=id
|
||||||
|
|
@ -320,8 +318,8 @@ async def edge_similarity_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
||||||
|
|
@ -540,7 +538,6 @@ async def node_fulltext_search(
|
||||||
search_filter: SearchFilters,
|
search_filter: SearchFilters,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
use_local_indexes: bool = False,
|
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
# BM25 search to get top nodes
|
# BM25 search to get top nodes
|
||||||
fuzzy_query = fulltext_query(query, group_ids, driver)
|
fuzzy_query = fulltext_query(query, group_ids, driver)
|
||||||
|
|
@ -574,11 +571,11 @@ async def node_fulltext_search(
|
||||||
# Match the edge ides and return the values
|
# Match the edge ides and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as i
|
UNWIND $ids as i
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE n.uuid=i.id
|
WHERE n.uuid=i.id
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ get_entity_node_return_query(driver.provider)
|
+ get_entity_node_return_query(driver.provider)
|
||||||
+ """
|
+ """
|
||||||
ORDER BY i.score DESC
|
ORDER BY i.score DESC
|
||||||
|
|
@ -596,14 +593,10 @@ async def node_fulltext_search(
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
index_name = (
|
|
||||||
'node_name_and_summary'
|
|
||||||
if not use_local_indexes
|
|
||||||
else 'node_name_and_summary_'
|
|
||||||
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
|
||||||
)
|
|
||||||
query = (
|
query = (
|
||||||
get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider)
|
get_nodes_query(
|
||||||
|
'node_name_and_summary', '$query', limit=limit, provider=driver.provider
|
||||||
|
)
|
||||||
+ yield_query
|
+ yield_query
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
|
|
@ -635,7 +628,6 @@ async def node_similarity_search(
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
min_score: float = DEFAULT_MIN_SCORE,
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
use_local_indexes: bool = False,
|
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
filter_queries, filter_params = node_search_filter_query_constructor(
|
filter_queries, filter_params = node_search_filter_query_constructor(
|
||||||
search_filter, driver.provider
|
search_filter, driver.provider
|
||||||
|
|
@ -656,8 +648,8 @@ async def node_similarity_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||||
|
|
@ -686,11 +678,11 @@ async def node_similarity_search(
|
||||||
# Match the edge ides and return the values
|
# Match the edge ides and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as i
|
UNWIND $ids as i
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE id(n)=i.id
|
WHERE id(n)=i.id
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ get_entity_node_return_query(driver.provider)
|
+ get_entity_node_return_query(driver.provider)
|
||||||
+ """
|
+ """
|
||||||
ORDER BY i.score DESC
|
ORDER BY i.score DESC
|
||||||
|
|
@ -708,40 +700,11 @@ async def node_similarity_search(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
elif driver.provider == GraphProvider.NEO4J and use_local_indexes:
|
|
||||||
index_name = 'group_entity_vector_' + (
|
|
||||||
group_ids[0].replace('-', '') if group_ids is not None else ''
|
|
||||||
)
|
|
||||||
query = (
|
|
||||||
f"""
|
|
||||||
CALL db.index.vector.queryNodes('{index_name}', {limit}, $search_vector) YIELD node AS n, score
|
|
||||||
"""
|
|
||||||
+ filter_query
|
|
||||||
+ """
|
|
||||||
AND score > $min_score
|
|
||||||
RETURN
|
|
||||||
"""
|
|
||||||
+ get_entity_node_return_query(driver.provider)
|
|
||||||
+ """
|
|
||||||
ORDER BY score DESC
|
|
||||||
LIMIT $limit
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
query,
|
|
||||||
search_vector=search_vector,
|
|
||||||
limit=limit,
|
|
||||||
min_score=min_score,
|
|
||||||
routing_='r',
|
|
||||||
**filter_params,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH n, """
|
WITH n, """
|
||||||
|
|
@ -865,7 +828,6 @@ async def episode_fulltext_search(
|
||||||
_search_filter: SearchFilters,
|
_search_filter: SearchFilters,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
use_local_indexes: bool = False,
|
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
# BM25 search to get top episodes
|
# BM25 search to get top episodes
|
||||||
fuzzy_query = fulltext_query(query, group_ids, driver)
|
fuzzy_query = fulltext_query(query, group_ids, driver)
|
||||||
|
|
@ -915,14 +877,8 @@ async def episode_fulltext_search(
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
index_name = (
|
|
||||||
'episode_content'
|
|
||||||
if not use_local_indexes
|
|
||||||
else 'episode_content_'
|
|
||||||
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
|
||||||
)
|
|
||||||
query = (
|
query = (
|
||||||
get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider)
|
get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
|
||||||
+ """
|
+ """
|
||||||
YIELD node AS episode, score
|
YIELD node AS episode, score
|
||||||
MATCH (e:Episodic)
|
MATCH (e:Episodic)
|
||||||
|
|
@ -1047,8 +1003,8 @@ async def community_similarity_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community)
|
MATCH (n:Community)
|
||||||
"""
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||||
|
|
@ -1107,8 +1063,8 @@ async def community_similarity_search(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (c:Community)
|
MATCH (c:Community)
|
||||||
"""
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH c,
|
WITH c,
|
||||||
|
|
@ -1250,9 +1206,9 @@ async def get_relevant_nodes(
|
||||||
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $nodes AS node
|
UNWIND $nodes AS node
|
||||||
MATCH (n:Entity {group_id: $group_id})
|
MATCH (n:Entity {group_id: $group_id})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH node, n, """
|
WITH node, n, """
|
||||||
|
|
@ -1297,9 +1253,9 @@ async def get_relevant_nodes(
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $nodes AS node
|
UNWIND $nodes AS node
|
||||||
MATCH (n:Entity {group_id: $group_id})
|
MATCH (n:Entity {group_id: $group_id})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH node, n, """
|
WITH node, n, """
|
||||||
|
|
@ -1388,9 +1344,9 @@ async def get_relevant_edges(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge
|
WITH e, edge
|
||||||
|
|
@ -1460,9 +1416,9 @@ async def get_relevant_edges(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge, n, m, """
|
WITH e, edge, n, m, """
|
||||||
|
|
@ -1498,9 +1454,9 @@ async def get_relevant_edges(
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge, """
|
WITH e, edge, """
|
||||||
|
|
@ -1573,10 +1529,10 @@ async def get_edge_invalidation_candidates(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge
|
WITH e, edge
|
||||||
|
|
@ -1646,10 +1602,10 @@ async def get_edge_invalidation_candidates(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
||||||
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH edge, e, n, m, """
|
WITH edge, e, n, m, """
|
||||||
|
|
@ -1685,10 +1641,10 @@ async def get_edge_invalidation_candidates(
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH edge, e, """
|
WITH edge, e, """
|
||||||
|
|
|
||||||
|
|
@ -119,8 +119,6 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
for episode in episodes:
|
for episode in episodes:
|
||||||
episode['source'] = str(episode['source'].value)
|
episode['source'] = str(episode['source'].value)
|
||||||
episode.pop('labels', None)
|
episode.pop('labels', None)
|
||||||
if driver.provider == GraphProvider.NEO4J:
|
|
||||||
episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')
|
|
||||||
|
|
||||||
nodes = []
|
nodes = []
|
||||||
|
|
||||||
|
|
@ -143,9 +141,6 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
entity_data['attributes'] = json.dumps(attributes)
|
entity_data['attributes'] = json.dumps(attributes)
|
||||||
else:
|
else:
|
||||||
entity_data.update(node.attributes or {})
|
entity_data.update(node.attributes or {})
|
||||||
entity_data['labels'] = list(
|
|
||||||
set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
|
|
||||||
)
|
|
||||||
|
|
||||||
nodes.append(entity_data)
|
nodes.append(entity_data)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -149,9 +149,9 @@ async def retrieve_episodes(
|
||||||
|
|
||||||
query: LiteralString = (
|
query: LiteralString = (
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic)
|
MATCH (e:Episodic)
|
||||||
WHERE e.valid_at <= $reference_time
|
WHERE e.valid_at <= $reference_time
|
||||||
"""
|
"""
|
||||||
+ query_filter
|
+ query_filter
|
||||||
+ """
|
+ """
|
||||||
RETURN
|
RETURN
|
||||||
|
|
@ -175,44 +175,3 @@ async def retrieve_episodes(
|
||||||
|
|
||||||
episodes = [get_episodic_node_from_record(record) for record in result]
|
episodes = [get_episodic_node_from_record(record) for record in result]
|
||||||
return list(reversed(episodes)) # Return in chronological order
|
return list(reversed(episodes)) # Return in chronological order
|
||||||
|
|
||||||
|
|
||||||
async def build_dynamic_indexes(driver: GraphDriver, group_id: str):
|
|
||||||
# Make sure indices exist for this group_id in Neo4j
|
|
||||||
if driver.provider == GraphProvider.NEO4J:
|
|
||||||
await driver.execute_query(
|
|
||||||
"""CREATE FULLTEXT INDEX $episode_content IF NOT EXISTS
|
|
||||||
FOR (e:"""
|
|
||||||
+ 'Episodic_'
|
|
||||||
+ group_id.replace('-', '')
|
|
||||||
+ """) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
|
|
||||||
episode_content='episode_content_' + group_id.replace('-', ''),
|
|
||||||
)
|
|
||||||
await driver.execute_query(
|
|
||||||
"""CREATE FULLTEXT INDEX $node_name_and_summary IF NOT EXISTS FOR (n:"""
|
|
||||||
+ 'Entity_'
|
|
||||||
+ group_id.replace('-', '')
|
|
||||||
+ """) ON EACH [n.name, n.summary, n.group_id]""",
|
|
||||||
node_name_and_summary='node_name_and_summary_' + group_id.replace('-', ''),
|
|
||||||
)
|
|
||||||
await driver.execute_query(
|
|
||||||
"""CREATE FULLTEXT INDEX $community_name IF NOT EXISTS
|
|
||||||
FOR (n:"""
|
|
||||||
+ 'Community_'
|
|
||||||
+ group_id.replace('-', '')
|
|
||||||
+ """) ON EACH [n.name, n.group_id]""",
|
|
||||||
community_name='Community_' + group_id.replace('-', ''),
|
|
||||||
)
|
|
||||||
await driver.execute_query(
|
|
||||||
"""CREATE VECTOR INDEX $group_entity_vector IF NOT EXISTS
|
|
||||||
FOR (n:"""
|
|
||||||
+ 'Entity_'
|
|
||||||
+ group_id.replace('-', '')
|
|
||||||
+ """)
|
|
||||||
ON n.embedding
|
|
||||||
OPTIONS { indexConfig: {
|
|
||||||
`vector.dimensions`: 1024,
|
|
||||||
`vector.similarity_function`: 'cosine'
|
|
||||||
}}""",
|
|
||||||
group_entity_vector='group_entity_vector_' + group_id.replace('-', ''),
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.20.1"
|
version = "0.20.2"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||||
|
|
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.20.0"
|
version = "0.20.2"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "diskcache" },
|
{ name = "diskcache" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue