async update

This commit is contained in:
prestonrasmussen 2025-09-14 00:53:38 -04:00
parent 42bbd93c38
commit e0066ff235
7 changed files with 117 additions and 89 deletions

View file

@ -25,6 +25,7 @@ from pydantic import BaseModel, Field
from transcript_parser import parse_podcast_messages from transcript_parser import parse_podcast_messages
from graphiti_core import Graphiti from graphiti_core import Graphiti
from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.nodes import EpisodeType from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.bulk_utils import RawEpisode from graphiti_core.utils.bulk_utils import RawEpisode
from graphiti_core.utils.maintenance.graph_data_operations import clear_data from graphiti_core.utils.maintenance.graph_data_operations import clear_data
@ -34,6 +35,8 @@ load_dotenv()
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687' neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j' neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password' neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
aoss_host = os.environ.get('AOSS_HOST') or None
aoss_port = os.environ.get('AOSS_PORT') or None
def setup_logging(): def setup_logging():
@ -77,12 +80,25 @@ class IsPresidentOf(BaseModel):
async def main(use_bulk: bool = False): async def main(use_bulk: bool = False):
setup_logging() setup_logging()
client = Graphiti( graph_driver = Neo4jDriver(
neo4j_uri, neo4j_uri,
neo4j_user, neo4j_user,
neo4j_password, neo4j_password,
aoss_host=aoss_host,
aoss_port=int(aoss_port),
aws_profile_name='zep-development',
aws_region='us-west-2',
aws_service='es',
) )
# client = Graphiti(
# neo4j_uri,
# neo4j_user,
# neo4j_password,
# )
client = Graphiti(graph_driver=graph_driver)
await clear_data(client.driver) await clear_data(client.driver)
await client.driver.delete_aoss_indices()
await client.driver.create_aoss_indices()
await client.build_indices_and_constraints() await client.build_indices_and_constraints()
messages = parse_podcast_messages() messages = parse_podcast_messages()
group_id = str(uuid4()) group_id = str(uuid4())

View file

@ -24,10 +24,12 @@ from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any from typing import Any
from dotenv import load_dotenv
from graphiti_core.embedder.client import EMBEDDING_DIM from graphiti_core.embedder.client import EMBEDDING_DIM
try: try:
from opensearchpy import OpenSearch, helpers from opensearchpy import AsyncOpenSearch, helpers
_HAS_OPENSEARCH = True _HAS_OPENSEARCH = True
except ImportError: except ImportError:
@ -39,6 +41,8 @@ logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10 DEFAULT_SIZE = 10
load_dotenv()
ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities') ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities')
EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes') EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes')
COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities') COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
@ -62,7 +66,7 @@ aoss_indices = [
'uuid': {'type': 'keyword'}, 'uuid': {'type': 'keyword'},
'name': {'type': 'text'}, 'name': {'type': 'text'},
'summary': {'type': 'text'}, 'summary': {'type': 'text'},
'group_id': {'type': 'text'}, 'group_id': {'type': 'keyword'},
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'name_embedding': { 'name_embedding': {
'type': 'knn_vector', 'type': 'knn_vector',
@ -85,7 +89,7 @@ aoss_indices = [
'properties': { 'properties': {
'uuid': {'type': 'keyword'}, 'uuid': {'type': 'keyword'},
'name': {'type': 'text'}, 'name': {'type': 'text'},
'group_id': {'type': 'text'}, 'group_id': {'type': 'keyword'},
} }
} }
}, },
@ -99,7 +103,7 @@ aoss_indices = [
'content': {'type': 'text'}, 'content': {'type': 'text'},
'source': {'type': 'text'}, 'source': {'type': 'text'},
'source_description': {'type': 'text'}, 'source_description': {'type': 'text'},
'group_id': {'type': 'text'}, 'group_id': {'type': 'keyword'},
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
} }
@ -115,7 +119,7 @@ aoss_indices = [
'uuid': {'type': 'keyword'}, 'uuid': {'type': 'keyword'},
'name': {'type': 'text'}, 'name': {'type': 'text'},
'fact': {'type': 'text'}, 'fact': {'type': 'text'},
'group_id': {'type': 'text'}, 'group_id': {'type': 'keyword'},
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
@ -167,7 +171,7 @@ class GraphDriver(ABC):
'' # Neo4j (default) syntax does not require a prefix for fulltext queries '' # Neo4j (default) syntax does not require a prefix for fulltext queries
) )
_database: str _database: str
aoss_client: OpenSearch | None # type: ignore aoss_client: AsyncOpenSearch | None # type: ignore
@abstractmethod @abstractmethod
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine: def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@ -209,7 +213,7 @@ class GraphDriver(ABC):
alias_name = index['index_name'] alias_name = index['index_name']
# If alias already exists, skip (idempotent behavior) # If alias already exists, skip (idempotent behavior)
if client.indices.exists_alias(name=alias_name): if await client.indices.exists_alias(name=alias_name):
continue continue
# Build a physical index name with timestamp # Build a physical index name with timestamp
@ -217,10 +221,10 @@ class GraphDriver(ABC):
physical_index_name = f'{alias_name}_{ts_suffix}' physical_index_name = f'{alias_name}_{ts_suffix}'
# Create the index # Create the index
client.indices.create(index=physical_index_name, body=index['body']) await client.indices.create(index=physical_index_name, body=index['body'])
# Point alias to it # Point alias to it
client.indices.put_alias(index=physical_index_name, name=alias_name) await client.indices.put_alias(index=physical_index_name, name=alias_name)
# Allow some time for index creation # Allow some time for index creation
await asyncio.sleep(1) await asyncio.sleep(1)
@ -237,7 +241,7 @@ class GraphDriver(ABC):
try: try:
# Resolve alias → indices # Resolve alias → indices
alias_info = client.indices.get_alias(name=alias_name) alias_info = await client.indices.get_alias(name=alias_name)
indices = list(alias_info.keys()) indices = list(alias_info.keys())
if not indices: if not indices:
@ -245,8 +249,8 @@ class GraphDriver(ABC):
continue continue
for index in indices: for index in indices:
if client.indices.exists(index=index): if await client.indices.exists(index=index):
client.indices.delete(index=index) await client.indices.delete(index=index)
logger.info(f"Deleted index '{index}' (alias: {alias_name})") logger.info(f"Deleted index '{index}' (alias: {alias_name})")
else: else:
logger.warning(f"Index '{index}' not found for alias '{alias_name}'") logger.warning(f"Index '{index}' not found for alias '{alias_name}'")
@ -264,14 +268,16 @@ class GraphDriver(ABC):
for index in aoss_indices: for index in aoss_indices:
index_name = index['index_name'] index_name = index['index_name']
if client.indices.exists(index=index_name): if await client.indices.exists(index=index_name):
try: try:
# Delete all documents but keep the index # Delete all documents but keep the index
response = client.delete_by_query( response = await client.delete_by_query(
index=index_name, index=index_name,
body={'query': {'match_all': {}}}, body={'query': {'match_all': {}}},
refresh=True, refresh=True,
conflicts='proceed', conflicts='proceed',
wait_for_completion=True,
slices='auto', # improves coverage/concurrency
) )
logger.info(f"Cleared index '{index_name}': {response}") logger.info(f"Cleared index '{index_name}': {response}")
except Exception as e: except Exception as e:
@ -281,7 +287,7 @@ class GraphDriver(ABC):
async def save_to_aoss(self, name: str, data: list[dict]) -> int: async def save_to_aoss(self, name: str, data: list[dict]) -> int:
client = self.aoss_client client = self.aoss_client
if not client or not helpers: if not client:
logger.warning('No OpenSearch client found') logger.warning('No OpenSearch client found')
return 0 return 0
@ -289,16 +295,20 @@ class GraphDriver(ABC):
if name.lower() == index['index_name']: if name.lower() == index['index_name']:
to_index = [] to_index = []
for d in data: for d in data:
item = { doc = {}
'_index': name,
'_routing': d.get('group_id'), # shard routing
}
for p in index['body']['mappings']['properties']: for p in index['body']['mappings']['properties']:
if p in d: # protect against missing fields if p in d: # protect against missing fields
item[p] = d[p] doc[p] = d[p]
item = {
'_index': name,
'_id': d['uuid'],
'_routing': d.get('group_id'),
'_source': doc,
}
to_index.append(item) to_index.append(item)
success, failed = helpers.bulk( success, failed = await helpers.async_bulk(
client, to_index, stats_only=True, request_timeout=60 client, to_index, stats_only=True, request_timeout=60
) )

View file

@ -29,8 +29,9 @@ logger = logging.getLogger(__name__)
try: try:
import boto3 import boto3
from opensearchpy import ( from opensearchpy import (
AIOHttpConnection,
AsyncOpenSearch,
AWSV4SignerAuth, AWSV4SignerAuth,
OpenSearch,
RequestsHttpConnection, RequestsHttpConnection,
Urllib3AWSV4SignerAuth, Urllib3AWSV4SignerAuth,
Urllib3HttpConnection, Urllib3HttpConnection,
@ -75,12 +76,12 @@ class Neo4jDriver(GraphDriver):
credentials = boto3.Session(profile_name=aws_profile_name).get_credentials() credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
auth = AWSV4SignerAuth(credentials, region or '', service or '') auth = AWSV4SignerAuth(credentials, region or '', service or '')
self.aoss_client = OpenSearch( self.aoss_client = AsyncOpenSearch(
hosts=[{'host': aoss_host, 'port': aoss_port}], hosts=[{'host': aoss_host, 'port': aoss_port}],
http_auth=auth, auth=auth,
use_ssl=True, use_ssl=True,
verify_certs=True, verify_certs=True,
connection_class=RequestsHttpConnection, connection_class=AIOHttpConnection,
pool_maxsize=20, pool_maxsize=20,
) # type: ignore ) # type: ignore
except Exception as e: except Exception as e:

View file

@ -267,7 +267,7 @@ class EntityEdge(Edge):
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
""" """
elif driver.aoss_client: elif driver.aoss_client:
resp = driver.aoss_client.search( resp = await driver.aoss_client.search(
body={ body={
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
'size': 1, 'size': 1,

View file

@ -513,7 +513,7 @@ class EntityNode(Node):
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
""" """
elif driver.aoss_client: elif driver.aoss_client:
resp = driver.aoss_client.search( resp = await driver.aoss_client.search(
body={ body={
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
'size': 1, 'size': 1,

View file

@ -215,11 +215,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values # Match the edge ids and return the values
query = ( query = (
""" """
UNWIND $ids as id UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids WHERE e.group_id IN $group_ids
AND id(e)=id AND id(e)=id
""" """
+ filter_query + filter_query
+ """ + """
AND id(e)=id AND id(e)=id
@ -254,7 +254,7 @@ async def edge_fulltext_search(
elif driver.aoss_client: elif driver.aoss_client:
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
filters = build_aoss_edge_filters(group_ids or [], search_filter) filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = driver.aoss_client.search( res = await driver.aoss_client.search(
index=ENTITY_EDGE_INDEX_NAME, index=ENTITY_EDGE_INDEX_NAME,
routing=route, routing=route,
_source=['uuid'], _source=['uuid'],
@ -353,8 +353,8 @@ async def edge_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@ -415,7 +415,7 @@ async def edge_similarity_search(
elif driver.aoss_client: elif driver.aoss_client:
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
filters = build_aoss_edge_filters(group_ids or [], search_filter) filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = driver.aoss_client.search( res = await driver.aoss_client.search(
index=ENTITY_EDGE_INDEX_NAME, index=ENTITY_EDGE_INDEX_NAME,
routing=route, routing=route,
_source=['uuid'], _source=['uuid'],
@ -637,11 +637,11 @@ async def node_fulltext_search(
# Match the node ids and return the values # Match the node ids and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE n.uuid=i.id WHERE n.uuid=i.id
RETURN RETURN
""" """
+ get_entity_node_return_query(driver.provider) + get_entity_node_return_query(driver.provider)
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -661,7 +661,7 @@ async def node_fulltext_search(
elif driver.aoss_client: elif driver.aoss_client:
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter) filters = build_aoss_node_filters(group_ids or [], search_filter)
res = driver.aoss_client.search( res = await driver.aoss_client.search(
index=ENTITY_INDEX_NAME, index=ENTITY_INDEX_NAME,
routing=route, routing=route,
_source=['uuid'], _source=['uuid'],
@ -674,7 +674,7 @@ async def node_fulltext_search(
{ {
'multi_match': { 'multi_match': {
'query': query, 'query': query,
'fields': ['name', 'summary'], # ✅ fixed key 'fields': ['name', 'summary'],
'operator': 'or', 'operator': 'or',
} }
} }
@ -751,8 +751,8 @@ async def node_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Entity) MATCH (n:Entity)
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -781,11 +781,11 @@ async def node_similarity_search(
# Match the node ids and return the values # Match the node ids and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE id(n)=i.id WHERE id(n)=i.id
RETURN RETURN
""" """
+ get_entity_node_return_query(driver.provider) + get_entity_node_return_query(driver.provider)
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -806,7 +806,7 @@ async def node_similarity_search(
elif driver.aoss_client: elif driver.aoss_client:
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter) filters = build_aoss_node_filters(group_ids or [], search_filter)
res = driver.aoss_client.search( res = await driver.aoss_client.search(
index=ENTITY_INDEX_NAME, index=ENTITY_INDEX_NAME,
routing=route, routing=route,
_source=['uuid'], _source=['uuid'],
@ -837,8 +837,8 @@ async def node_similarity_search(
else: else:
query = ( query = (
""" """
MATCH (n:Entity) MATCH (n:Entity)
""" """
+ filter_query + filter_query
+ """ + """
WITH n, """ WITH n, """
@ -1011,7 +1011,7 @@ async def episode_fulltext_search(
return [] return []
elif driver.aoss_client: elif driver.aoss_client:
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
res = driver.aoss_client.search( res = await driver.aoss_client.search(
EPISODE_INDEX_NAME, EPISODE_INDEX_NAME,
routing=route, routing=route,
_source=['uuid'], _source=['uuid'],
@ -1170,8 +1170,8 @@ async def community_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Community) MATCH (n:Community)
""" """
+ group_filter_query + group_filter_query
+ """ + """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -1230,8 +1230,8 @@ async def community_similarity_search(
query = ( query = (
""" """
MATCH (c:Community) MATCH (c:Community)
""" """
+ group_filter_query + group_filter_query
+ """ + """
WITH c, WITH c,
@ -1373,9 +1373,9 @@ async def get_relevant_nodes(
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
query = ( query = (
""" """
UNWIND $nodes AS node UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id}) MATCH (n:Entity {group_id: $group_id})
""" """
+ filter_query + filter_query
+ """ + """
WITH node, n, """ WITH node, n, """
@ -1420,9 +1420,9 @@ async def get_relevant_nodes(
else: else:
query = ( query = (
""" """
UNWIND $nodes AS node UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id}) MATCH (n:Entity {group_id: $group_id})
""" """
+ filter_query + filter_query
+ """ + """
WITH node, n, """ WITH node, n, """
@ -1511,9 +1511,9 @@ async def get_relevant_edges(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge WITH e, edge
@ -1583,9 +1583,9 @@ async def get_relevant_edges(
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge, n, m, """ WITH e, edge, n, m, """
@ -1621,9 +1621,9 @@ async def get_relevant_edges(
else: else:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge, """ WITH e, edge, """
@ -1696,10 +1696,10 @@ async def get_edge_invalidation_candidates(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge WITH e, edge
@ -1769,10 +1769,10 @@ async def get_edge_invalidation_candidates(
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
""" """
+ filter_query + filter_query
+ """ + """
WITH edge, e, n, m, """ WITH edge, e, n, m, """
@ -1808,10 +1808,10 @@ async def get_edge_invalidation_candidates(
else: else:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
""" """
+ filter_query + filter_query
+ """ + """
WITH edge, e, """ WITH edge, e, """

View file

@ -17,6 +17,7 @@ limitations under the License.
import logging import logging
from datetime import datetime from datetime import datetime
from time import time from time import time
from xml.dom.minidom import Entity
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import LiteralString from typing_extensions import LiteralString
@ -260,7 +261,7 @@ async def resolve_extracted_edges(
embedder = clients.embedder embedder = clients.embedder
await create_entity_edge_embeddings(embedder, extracted_edges) await create_entity_edge_embeddings(embedder, extracted_edges)
valid_uuids_list: list[list[str]] = await semaphore_gather( valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
*[ *[
EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid) EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
for edge in extracted_edges for edge in extracted_edges
@ -274,9 +275,9 @@ async def resolve_extracted_edges(
extracted_edge.fact, extracted_edge.fact,
group_ids=[extracted_edge.group_id], group_ids=[extracted_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF, config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(edge_uuids=valid_uuids), search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
) )
for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list, strict=True) for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
] ]
) )