This commit is contained in:
prestonrasmussen 2025-09-12 11:52:35 -04:00
parent 2f8b25acf7
commit 28d1adf631
7 changed files with 199 additions and 155 deletions

View file

@ -80,22 +80,11 @@ class IsPresidentOf(BaseModel):
async def main(use_bulk: bool = False): async def main(use_bulk: bool = False):
setup_logging() setup_logging()
graph_driver = Neo4jDriver( client = Graphiti(
neo4j_uri, neo4j_uri,
neo4j_user, neo4j_user,
neo4j_password, neo4j_password,
aoss_host=aoss_host,
aoss_port=aoss_port,
region='us-west-2',
service='es',
) )
# client = Graphiti(
# neo4j_uri,
# neo4j_user,
# neo4j_password,
# )
client = Graphiti(graph_driver=graph_driver)
await client.driver.create_aoss_indices()
await clear_data(client.driver) await clear_data(client.driver)
await client.build_indices_and_constraints() await client.build_indices_and_constraints()
messages = parse_podcast_messages() messages = parse_podcast_messages()

View file

@ -38,6 +38,11 @@ logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10 DEFAULT_SIZE = 10
EPISODE_INDEX_NAME = 'episodes-test'
ENTTITY_INDEX_NAME = 'entities_test'
COMMUNITY_INDEX_NAME = 'communities-test'
ENTITY_EDGE_INDEX_NAME = 'entity_edges_test'
class GraphProvider(Enum): class GraphProvider(Enum):
NEO4J = 'neo4j' NEO4J = 'neo4j'
@ -48,20 +53,19 @@ class GraphProvider(Enum):
aoss_indices = [ aoss_indices = [
{ {
'index_name': 'entities_test', 'index_name': ENTTITY_INDEX_NAME,
'body': { 'body': {
'settings': {'index': {'knn': True}},
'mappings': { 'mappings': {
'properties': { 'properties': {
'uuid': {'type': 'keyword'}, 'uuid': {'type': 'keyword'},
'name': {'type': 'text'}, 'name': {'type': 'text'},
'summary': {'type': 'text'}, 'summary': {'type': 'text'},
'group_id': {'type': 'text'}, 'group_id': {'type': 'text'},
'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'name_embedding': { 'name_embedding': {
'type': 'knn_vector', 'type': 'knn_vector',
'dims': EMBEDDING_DIM, 'dimension': EMBEDDING_DIM,
'index': True,
'similarity': 'cosine',
'method': { 'method': {
'engine': 'faiss', 'engine': 'faiss',
'space_type': 'cosinesimil', 'space_type': 'cosinesimil',
@ -70,11 +74,11 @@ aoss_indices = [
}, },
}, },
} }
} },
}, },
}, },
{ {
'index_name': 'communities_test', 'index_name': COMMUNITY_INDEX_NAME,
'body': { 'body': {
'mappings': { 'mappings': {
'properties': { 'properties': {
@ -86,7 +90,7 @@ aoss_indices = [
}, },
}, },
{ {
'index_name': 'episodes', 'index_name': EPISODE_INDEX_NAME,
'body': { 'body': {
'mappings': { 'mappings': {
'properties': { 'properties': {
@ -95,30 +99,29 @@ aoss_indices = [
'source': {'type': 'text'}, 'source': {'type': 'text'},
'source_description': {'type': 'text'}, 'source_description': {'type': 'text'},
'group_id': {'type': 'text'}, 'group_id': {'type': 'text'},
'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
} }
} }
}, },
}, },
{ {
'index_name': 'entity_edges_test', 'index_name': ENTITY_EDGE_INDEX_NAME,
'body': { 'body': {
'settings': {'index': {'knn': True}},
'mappings': { 'mappings': {
'properties': { 'properties': {
'uuid': {'type': 'keyword'}, 'uuid': {'type': 'keyword'},
'name': {'type': 'text'}, 'name': {'type': 'text'},
'fact': {'type': 'text'}, 'fact': {'type': 'text'},
'group_id': {'type': 'text'}, 'group_id': {'type': 'text'},
'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'expired_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, 'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'invalid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, 'invalid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'fact_embedding': { 'fact_embedding': {
'type': 'knn_vector', 'type': 'knn_vector',
'dims': EMBEDDING_DIM, 'dimension': EMBEDDING_DIM,
'index': True,
'similarity': 'cosine',
'method': { 'method': {
'engine': 'faiss', 'engine': 'faiss',
'space_type': 'cosinesimil', 'space_type': 'cosinesimil',
@ -127,7 +130,7 @@ aoss_indices = [
}, },
}, },
} }
} },
}, },
}, },
] ]
@ -219,19 +222,36 @@ class GraphDriver(ABC):
client.indices.put_alias(index=physical_index_name, name=alias_name) client.indices.put_alias(index=physical_index_name, name=alias_name)
# Allow some time for index creation # Allow some time for index creation
await asyncio.sleep(60) await asyncio.sleep(1)
async def delete_aoss_indices(self): async def delete_aoss_indices(self):
for index in aoss_indices: client = self.aoss_client
index_name = index['index_name']
client = self.aoss_client
if not client: if not client:
logger.warning('No OpenSearch client found') logger.warning('No OpenSearch client found')
return return
if client.indices.exists(index=index_name): for entry in aoss_indices:
client.indices.delete(index=index_name) alias_name = entry['index_name']
try:
# Resolve alias → indices
alias_info = client.indices.get_alias(name=alias_name)
indices = list(alias_info.keys())
if not indices:
logger.info(f"No indices found for alias '{alias_name}'")
continue
for index in indices:
if client.indices.exists(index=index):
client.indices.delete(index=index)
logger.info(f"Deleted index '{index}' (alias: {alias_name})")
else:
logger.warning(f"Index '{index}' not found for alias '{alias_name}'")
except Exception as e:
logger.error(f"Error deleting indices for alias '{alias_name}': {e}")
async def clear_aoss_indices(self): async def clear_aoss_indices(self):
client = self.aoss_client client = self.aoss_client
@ -277,7 +297,9 @@ class GraphDriver(ABC):
item[p] = d[p] item[p] = d[p]
to_index.append(item) to_index.append(item)
success, failed = helpers.bulk(client, to_index, stats_only=True) success, failed = helpers.bulk(
client, to_index, stats_only=True, request_timeout=60
)
return success if failed == 0 else success return success if failed == 0 else success

View file

@ -56,8 +56,9 @@ class Neo4jDriver(GraphDriver):
database: str = 'neo4j', database: str = 'neo4j',
aoss_host: str | None = None, aoss_host: str | None = None,
aoss_port: int | None = None, aoss_port: int | None = None,
region: str | None = None, aws_profile_name: str | None = None,
service: str | None = None, aws_region: str | None = None,
aws_service: str | None = None,
): ):
super().__init__() super().__init__()
self.client = AsyncGraphDatabase.driver( self.client = AsyncGraphDatabase.driver(
@ -69,9 +70,9 @@ class Neo4jDriver(GraphDriver):
self.aoss_client = None self.aoss_client = None
if aoss_host and aoss_port and boto3 is not None: if aoss_host and aoss_port and boto3 is not None:
try: try:
region = region region = aws_region
service = service service = aws_service
credentials = boto3.Session(profile_name='zep-development').get_credentials() credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
auth = AWSV4SignerAuth(credentials, region, service) auth = AWSV4SignerAuth(credentials, region, service)
self.aoss_client = OpenSearch( self.aoss_client = OpenSearch(

View file

@ -25,7 +25,7 @@ from uuid import uuid4
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import LiteralString from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphProvider from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, GraphDriver, GraphProvider
from graphiti_core.embedder import EmbedderClient from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date from graphiti_core.helpers import parse_db_date
@ -79,7 +79,7 @@ class Edge(BaseModel, ABC):
if driver.aoss_client: if driver.aoss_client:
await driver.aoss_client.delete( await driver.aoss_client.delete(
index='entity_edges', id=self.uuid, routing=self.group_id index=ENTITY_EDGE_INDEX_NAME, id=self.uuid, routing=self.group_id
) )
logger.debug(f'Deleted Edge: {self.uuid}') logger.debug(f'Deleted Edge: {self.uuid}')
@ -115,7 +115,7 @@ class Edge(BaseModel, ABC):
if driver.aoss_client: if driver.aoss_client:
await driver.aoss_client.delete_by_query( await driver.aoss_client.delete_by_query(
index='entity_edges', index=ENTITY_EDGE_INDEX_NAME,
body={'query': {'terms': {'uuid': uuids}}}, body={'query': {'terms': {'uuid': uuids}}},
) )
@ -272,7 +272,7 @@ class EntityEdge(Edge):
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
'size': 1, 'size': 1,
}, },
index='entity_edges', index=ENTITY_EDGE_INDEX_NAME,
routing=self.group_id, routing=self.group_id,
) )
@ -325,7 +325,7 @@ class EntityEdge(Edge):
edge_data.update(self.attributes or {}) edge_data.update(self.attributes or {})
if driver.aoss_client: if driver.aoss_client:
driver.save_to_aoss('entity_edges', [edge_data]) # pyright: ignore reportAttributeAccessIssue driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query( result = await driver.execute_query(
get_entity_edge_save_query(driver.provider), get_entity_edge_save_query(driver.provider),

View file

@ -26,7 +26,14 @@ from uuid import uuid4
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import LiteralString from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphProvider from graphiti_core.driver.driver import (
COMMUNITY_INDEX_NAME,
ENTITY_EDGE_INDEX_NAME,
ENTTITY_INDEX_NAME,
EPISODE_INDEX_NAME,
GraphDriver,
GraphProvider,
)
from graphiti_core.embedder import EmbedderClient from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import NodeNotFoundError from graphiti_core.errors import NodeNotFoundError
from graphiti_core.helpers import parse_db_date from graphiti_core.helpers import parse_db_date
@ -110,7 +117,7 @@ class Node(BaseModel, ABC):
if driver.aoss_client: if driver.aoss_client:
# Delete the node from OpenSearch indices # Delete the node from OpenSearch indices
for index in ('episodes', 'entities', 'communities'): for index in (EPISODE_INDEX_NAME, ENTTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
await driver.aoss_client.delete( await driver.aoss_client.delete(
index=index, id=self.uuid, routing=self.group_id index=index, id=self.uuid, routing=self.group_id
) )
@ -119,7 +126,9 @@ class Node(BaseModel, ABC):
if edge_uuids: if edge_uuids:
actions = [] actions = []
for eid in edge_uuids: for eid in edge_uuids:
actions.append({'delete': {'_index': 'entity_edges', '_id': eid}}) actions.append(
{'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
)
await driver.aoss_client.bulk(body=actions) await driver.aoss_client.bulk(body=actions)
@ -187,25 +196,25 @@ class Node(BaseModel, ABC):
if driver.aoss_client: if driver.aoss_client:
await driver.aoss_client.delete_by_query( await driver.aoss_client.delete_by_query(
index='episodes', index=EPISODE_INDEX_NAME,
body={'query': {'term': {'group_id': group_id}}}, body={'query': {'term': {'group_id': group_id}}},
routing=group_id, routing=group_id,
) )
await driver.aoss_client.delete_by_query( await driver.aoss_client.delete_by_query(
index='entities', index=ENTTITY_INDEX_NAME,
body={'query': {'term': {'group_id': group_id}}}, body={'query': {'term': {'group_id': group_id}}},
routing=group_id, routing=group_id,
) )
await driver.aoss_client.delete_by_query( await driver.aoss_client.delete_by_query(
index='communities', index=COMMUNITY_INDEX_NAME,
body={'query': {'term': {'group_id': group_id}}}, body={'query': {'term': {'group_id': group_id}}},
routing=group_id, routing=group_id,
) )
await driver.aoss_client.delete_by_query( await driver.aoss_client.delete_by_query(
index='entity_edges', index=ENTITY_EDGE_INDEX_NAME,
body={'query': {'term': {'group_id': group_id}}}, body={'query': {'term': {'group_id': group_id}}},
routing=group_id, routing=group_id,
) )
@ -319,7 +328,7 @@ class Node(BaseModel, ABC):
) )
if driver.aoss_client: if driver.aoss_client:
for index in ('episodes', 'entities', 'communities'): for index in (EPISODE_INDEX_NAME, ENTTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
await driver.aoss_client.delete_by_query( await driver.aoss_client.delete_by_query(
index=index, index=index,
body={'query': {'terms': {'uuid': uuids}}}, body={'query': {'terms': {'uuid': uuids}}},
@ -327,7 +336,8 @@ class Node(BaseModel, ABC):
if edge_uuids: if edge_uuids:
actions = [ actions = [
{'delete': {'_index': 'entity_edges', '_id': eid}} for eid in edge_uuids {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
for eid in edge_uuids
] ]
await driver.aoss_client.bulk(body=actions) await driver.aoss_client.bulk(body=actions)
@ -509,7 +519,7 @@ class EntityNode(Node):
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
'size': 1, 'size': 1,
}, },
index='entities', index=ENTTITY_INDEX_NAME,
routing=self.group_id, routing=self.group_id,
) )
@ -557,7 +567,7 @@ class EntityNode(Node):
labels = ':'.join(self.labels + ['Entity']) labels = ':'.join(self.labels + ['Entity'])
if driver.aoss_client: if driver.aoss_client:
driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue driver.save_to_aoss(ENTTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query( result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)), get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),

View file

@ -23,7 +23,13 @@ import numpy as np
from numpy._typing import NDArray from numpy._typing import NDArray
from typing_extensions import LiteralString from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphProvider from graphiti_core.driver.driver import (
ENTITY_EDGE_INDEX_NAME,
ENTTITY_INDEX_NAME,
EPISODE_INDEX_NAME,
GraphDriver,
GraphProvider,
)
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.graph_queries import ( from graphiti_core.graph_queries import (
get_nodes_query, get_nodes_query,
@ -209,11 +215,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values # Match the edge ids and return the values
query = ( query = (
""" """
UNWIND $ids as id UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids WHERE e.group_id IN $group_ids
AND id(e)=id AND id(e)=id
""" """
+ filter_query + filter_query
+ """ + """
AND id(e)=id AND id(e)=id
@ -249,16 +255,19 @@ async def edge_fulltext_search(
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
filters = build_aoss_edge_filters(group_ids or [], search_filter) filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = driver.aoss_client.search( res = driver.aoss_client.search(
index='entity_edges', index=ENTITY_EDGE_INDEX_NAME,
routing=route, routing=route,
_source=['uuid'], _source=['uuid'],
query={ body={
'bool': { 'query': {
'filter': filters, 'bool': {
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}], 'filter': filters,
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
}
} }
}, },
) )
if res['hits']['total']['value'] > 0: if res['hits']['total']['value'] > 0:
input_uuids = {} input_uuids = {}
for r in res['hits']['hits']: for r in res['hits']['hits']:
@ -344,8 +353,8 @@ async def edge_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@ -407,16 +416,17 @@ async def edge_similarity_search(
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
filters = build_aoss_edge_filters(group_ids or [], search_filter) filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = driver.aoss_client.search( res = driver.aoss_client.search(
index='entity_edges', index=ENTITY_EDGE_INDEX_NAME,
routing=route, routing=route,
_source=['uuid'], _source=['uuid'],
knn={ size=limit,
'field': 'fact_embedding', body={
'query_vector': search_vector, 'query': {
'k': limit, 'knn': {
'num_candidates': 1000, 'fact_embedding': {'vector': list(map(float, search_vector)), 'k': limit}
}
}
}, },
query={'bool': {'filter': filters}},
) )
if res['hits']['total']['value'] > 0: if res['hits']['total']['value'] > 0:
@ -428,6 +438,7 @@ async def edge_similarity_search(
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys())) entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_edges return entity_edges
return []
else: else:
query = ( query = (
@ -622,11 +633,11 @@ async def node_fulltext_search(
# Match the node ids and return the values # Match the node ids and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE n.uuid=i.id WHERE n.uuid=i.id
RETURN RETURN
""" """
+ get_entity_node_return_query(driver.provider) + get_entity_node_return_query(driver.provider)
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -647,24 +658,26 @@ async def node_fulltext_search(
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter) filters = build_aoss_node_filters(group_ids or [], search_filter)
res = driver.aoss_client.search( res = driver.aoss_client.search(
'entities', index=ENTTITY_INDEX_NAME,
routing=route, routing=route,
_source=['uuid'], _source=['uuid'],
query={ size=limit,
'bool': { body={
'filter': filters, 'query': {
'must': [ 'bool': {
{ 'filter': filters,
'multi_match': { 'must': [
'query': query, {
'field': ['name', 'summary'], 'multi_match': {
'operator': 'or', 'query': query,
'fields': ['name', 'summary'], # ✅ fixed key
'operator': 'or',
}
} }
} ],
], }
} }
}, },
limit=limit,
) )
if res['hits']['total']['value'] > 0: if res['hits']['total']['value'] > 0:
@ -734,8 +747,8 @@ async def node_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Entity) MATCH (n:Entity)
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -764,11 +777,11 @@ async def node_similarity_search(
# Match the node ids and return the values # Match the node ids and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE id(n)=i.id WHERE id(n)=i.id
RETURN RETURN
""" """
+ get_entity_node_return_query(driver.provider) + get_entity_node_return_query(driver.provider)
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -790,16 +803,17 @@ async def node_similarity_search(
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter) filters = build_aoss_node_filters(group_ids or [], search_filter)
res = driver.aoss_client.search( res = driver.aoss_client.search(
index='entities', index=ENTTITY_INDEX_NAME,
routing=route, routing=route,
_source=['uuid'], _source=['uuid'],
knn={ size=limit,
'field': 'fact_embedding', body={
'query_vector': search_vector, 'query': {
'k': limit, 'knn': {
'num_candidates': 1000, 'name_embedding': {'vector': list(map(float, search_vector)), 'k': limit}
}
}
}, },
query={'bool': {'filter': filters}},
) )
if res['hits']['total']['value'] > 0: if res['hits']['total']['value'] > 0:
@ -811,11 +825,12 @@ async def node_similarity_search(
entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys())) entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_nodes return entity_nodes
return []
else: else:
query = ( query = (
""" """
MATCH (n:Entity) MATCH (n:Entity)
""" """
+ filter_query + filter_query
+ """ + """
WITH n, """ WITH n, """
@ -989,7 +1004,7 @@ async def episode_fulltext_search(
elif driver.aoss_client: elif driver.aoss_client:
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
res = driver.aoss_client.search( res = driver.aoss_client.search(
'episodes', EPISODE_INDEX_NAME,
routing=route, routing=route,
_source=['uuid'], _source=['uuid'],
query={ query={
@ -1147,8 +1162,8 @@ async def community_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Community) MATCH (n:Community)
""" """
+ group_filter_query + group_filter_query
+ """ + """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -1207,8 +1222,8 @@ async def community_similarity_search(
query = ( query = (
""" """
MATCH (c:Community) MATCH (c:Community)
""" """
+ group_filter_query + group_filter_query
+ """ + """
WITH c, WITH c,
@ -1350,9 +1365,9 @@ async def get_relevant_nodes(
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
query = ( query = (
""" """
UNWIND $nodes AS node UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id}) MATCH (n:Entity {group_id: $group_id})
""" """
+ filter_query + filter_query
+ """ + """
WITH node, n, """ WITH node, n, """
@ -1397,9 +1412,9 @@ async def get_relevant_nodes(
else: else:
query = ( query = (
""" """
UNWIND $nodes AS node UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id}) MATCH (n:Entity {group_id: $group_id})
""" """
+ filter_query + filter_query
+ """ + """
WITH node, n, """ WITH node, n, """
@ -1488,9 +1503,9 @@ async def get_relevant_edges(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge WITH e, edge
@ -1560,9 +1575,9 @@ async def get_relevant_edges(
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge, n, m, """ WITH e, edge, n, m, """
@ -1599,9 +1614,9 @@ async def get_relevant_edges(
# First get edge candidates # First get edge candidates
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
RETURN RETURN
@ -1647,9 +1662,9 @@ async def get_relevant_edges(
else: else:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge, """ WITH e, edge, """
@ -1722,10 +1737,10 @@ async def get_edge_invalidation_candidates(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge WITH e, edge
@ -1795,10 +1810,10 @@ async def get_edge_invalidation_candidates(
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
""" """
+ filter_query + filter_query
+ """ + """
WITH edge, e, n, m, """ WITH edge, e, n, m, """
@ -1834,10 +1849,10 @@ async def get_edge_invalidation_candidates(
else: else:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
""" """
+ filter_query + filter_query
+ """ + """
WITH edge, e, """ WITH edge, e, """

View file

@ -23,7 +23,14 @@ import numpy as np
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Any from typing_extensions import Any
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider from graphiti_core.driver.driver import (
ENTITY_EDGE_INDEX_NAME,
ENTTITY_INDEX_NAME,
EPISODE_INDEX_NAME,
GraphDriver,
GraphDriverSession,
GraphProvider,
)
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
from graphiti_core.embedder import EmbedderClient from graphiti_core.embedder import EmbedderClient
from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.graphiti_types import GraphitiClients
@ -203,9 +210,9 @@ async def add_nodes_and_edges_bulk_tx(
) )
if driver.aoss_client: if driver.aoss_client:
driver.save_to_aoss('episodes', episodes) driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
driver.save_to_aoss('entities', nodes) driver.save_to_aoss(ENTTITY_INDEX_NAME, nodes)
driver.save_to_aoss('entity_edges', edges) driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
async def extract_nodes_and_edges_bulk( async def extract_nodes_and_edges_bulk(