This commit is contained in:
prestonrasmussen 2025-09-14 01:23:13 -04:00
parent e0066ff235
commit fd1c360e8c
7 changed files with 60 additions and 80 deletions

View file

@ -25,7 +25,6 @@ from pydantic import BaseModel, Field
from transcript_parser import parse_podcast_messages from transcript_parser import parse_podcast_messages
from graphiti_core import Graphiti from graphiti_core import Graphiti
from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.nodes import EpisodeType from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.bulk_utils import RawEpisode from graphiti_core.utils.bulk_utils import RawEpisode
from graphiti_core.utils.maintenance.graph_data_operations import clear_data from graphiti_core.utils.maintenance.graph_data_operations import clear_data
@ -35,8 +34,6 @@ load_dotenv()
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687' neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j' neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password' neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
aoss_host = os.environ.get('AOSS_HOST') or None
aoss_port = os.environ.get('AOSS_PORT') or None
def setup_logging(): def setup_logging():
@ -80,25 +77,12 @@ class IsPresidentOf(BaseModel):
async def main(use_bulk: bool = False): async def main(use_bulk: bool = False):
setup_logging() setup_logging()
graph_driver = Neo4jDriver( client = Graphiti(
neo4j_uri, neo4j_uri,
neo4j_user, neo4j_user,
neo4j_password, neo4j_password,
aoss_host=aoss_host,
aoss_port=int(aoss_port),
aws_profile_name='zep-development',
aws_region='us-west-2',
aws_service='es',
) )
# client = Graphiti(
# neo4j_uri,
# neo4j_user,
# neo4j_password,
# )
client = Graphiti(graph_driver=graph_driver)
await clear_data(client.driver) await clear_data(client.driver)
await client.driver.delete_aoss_indices()
await client.driver.create_aoss_indices()
await client.build_indices_and_constraints() await client.build_indices_and_constraints()
messages = parse_podcast_messages() messages = parse_podcast_messages()
group_id = str(uuid4()) group_id = str(uuid4())

View file

@ -274,10 +274,7 @@ class GraphDriver(ABC):
response = await client.delete_by_query( response = await client.delete_by_query(
index=index_name, index=index_name,
body={'query': {'match_all': {}}}, body={'query': {'match_all': {}}},
refresh=True, slices='auto',
conflicts='proceed',
wait_for_completion=True,
slices='auto', # improves coverage/concurrency
) )
logger.info(f"Cleared index '{index_name}': {response}") logger.info(f"Cleared index '{index_name}': {response}")
except Exception as e: except Exception as e:
@ -287,7 +284,7 @@ class GraphDriver(ABC):
async def save_to_aoss(self, name: str, data: list[dict]) -> int: async def save_to_aoss(self, name: str, data: list[dict]) -> int:
client = self.aoss_client client = self.aoss_client
if not client: if not client or not helpers:
logger.warning('No OpenSearch client found') logger.warning('No OpenSearch client found')
return 0 return 0

View file

@ -32,7 +32,6 @@ try:
AIOHttpConnection, AIOHttpConnection,
AsyncOpenSearch, AsyncOpenSearch,
AWSV4SignerAuth, AWSV4SignerAuth,
RequestsHttpConnection,
Urllib3AWSV4SignerAuth, Urllib3AWSV4SignerAuth,
Urllib3HttpConnection, Urllib3HttpConnection,
) )

View file

@ -237,12 +237,12 @@ class NeptuneDriver(GraphDriver):
'You must provide an AOSS endpoint to create an OpenSearch driver.' 'You must provide an AOSS endpoint to create an OpenSearch driver.'
) )
if not client.indices.exists(index=index_name): if not client.indices.exists(index=index_name):
client.indices.create(index=index_name, body=index['body']) await client.indices.create(index=index_name, body=index['body'])
alias_name = index.get('alias_name', index_name) alias_name = index.get('alias_name', index_name)
if not client.indices.exists_alias(name=alias_name, index=index_name): if not client.indices.exists_alias(name=alias_name, index=index_name):
client.indices.put_alias(index=index_name, name=alias_name) await client.indices.put_alias(index=index_name, name=alias_name)
# Sleep for 1 minute to let the index creation complete # Sleep for 1 minute to let the index creation complete
await asyncio.sleep(60) await asyncio.sleep(60)

View file

@ -103,8 +103,9 @@ class Node(BaseModel, ABC):
case GraphProvider.NEO4J: case GraphProvider.NEO4J:
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (n {uuid: $uuid})-[r]-() MATCH (n {uuid: $uuid})
WHERE n:Entity OR n:Episodic OR n:Community WHERE n:Entity OR n:Episodic OR n:Community
OPTIONAL MATCH (n)-[r]-()
WITH collect(r.uuid) AS edge_uuids, n WITH collect(r.uuid) AS edge_uuids, n
DETACH DELETE n DETACH DELETE n
RETURN edge_uuids RETURN edge_uuids

View file

@ -215,11 +215,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values # Match the edge ids and return the values
query = ( query = (
""" """
UNWIND $ids as id UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids WHERE e.group_id IN $group_ids
AND id(e)=id AND id(e)=id
""" """
+ filter_query + filter_query
+ """ + """
AND id(e)=id AND id(e)=id
@ -353,8 +353,8 @@ async def edge_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@ -637,11 +637,11 @@ async def node_fulltext_search(
# Match the node ids and return the values # Match the node ids and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE n.uuid=i.id WHERE n.uuid=i.id
RETURN RETURN
""" """
+ get_entity_node_return_query(driver.provider) + get_entity_node_return_query(driver.provider)
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -751,8 +751,8 @@ async def node_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Entity) MATCH (n:Entity)
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -781,11 +781,11 @@ async def node_similarity_search(
# Match the node ids and return the values # Match the node ids and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE id(n)=i.id WHERE id(n)=i.id
RETURN RETURN
""" """
+ get_entity_node_return_query(driver.provider) + get_entity_node_return_query(driver.provider)
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -808,7 +808,7 @@ async def node_similarity_search(
filters = build_aoss_node_filters(group_ids or [], search_filter) filters = build_aoss_node_filters(group_ids or [], search_filter)
res = await driver.aoss_client.search( res = await driver.aoss_client.search(
index=ENTITY_INDEX_NAME, index=ENTITY_INDEX_NAME,
routing=route, params={'routing': route},
_source=['uuid'], _source=['uuid'],
size=limit, size=limit,
body={ body={
@ -837,8 +837,8 @@ async def node_similarity_search(
else: else:
query = ( query = (
""" """
MATCH (n:Entity) MATCH (n:Entity)
""" """
+ filter_query + filter_query
+ """ + """
WITH n, """ WITH n, """
@ -1170,8 +1170,8 @@ async def community_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Community) MATCH (n:Community)
""" """
+ group_filter_query + group_filter_query
+ """ + """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -1230,8 +1230,8 @@ async def community_similarity_search(
query = ( query = (
""" """
MATCH (c:Community) MATCH (c:Community)
""" """
+ group_filter_query + group_filter_query
+ """ + """
WITH c, WITH c,
@ -1373,9 +1373,9 @@ async def get_relevant_nodes(
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
query = ( query = (
""" """
UNWIND $nodes AS node UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id}) MATCH (n:Entity {group_id: $group_id})
""" """
+ filter_query + filter_query
+ """ + """
WITH node, n, """ WITH node, n, """
@ -1420,9 +1420,9 @@ async def get_relevant_nodes(
else: else:
query = ( query = (
""" """
UNWIND $nodes AS node UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id}) MATCH (n:Entity {group_id: $group_id})
""" """
+ filter_query + filter_query
+ """ + """
WITH node, n, """ WITH node, n, """
@ -1511,9 +1511,9 @@ async def get_relevant_edges(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge WITH e, edge
@ -1583,9 +1583,9 @@ async def get_relevant_edges(
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge, n, m, """ WITH e, edge, n, m, """
@ -1621,9 +1621,9 @@ async def get_relevant_edges(
else: else:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge, """ WITH e, edge, """
@ -1696,10 +1696,10 @@ async def get_edge_invalidation_candidates(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge WITH e, edge
@ -1769,10 +1769,10 @@ async def get_edge_invalidation_candidates(
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
""" """
+ filter_query + filter_query
+ """ + """
WITH edge, e, n, m, """ WITH edge, e, n, m, """
@ -1808,10 +1808,10 @@ async def get_edge_invalidation_candidates(
else: else:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
""" """
+ filter_query + filter_query
+ """ + """
WITH edge, e, """ WITH edge, e, """

View file

@ -17,7 +17,6 @@ limitations under the License.
import logging import logging
from datetime import datetime from datetime import datetime
from time import time from time import time
from xml.dom.minidom import Entity
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import LiteralString from typing_extensions import LiteralString