This commit is contained in:
prestonrasmussen 2025-09-14 01:36:08 -04:00
parent fd1c360e8c
commit d7ae1c92b4
3 changed files with 83 additions and 77 deletions

View file

@ -79,7 +79,9 @@ class Edge(BaseModel, ABC):
if driver.aoss_client: if driver.aoss_client:
await driver.aoss_client.delete( await driver.aoss_client.delete(
index=ENTITY_EDGE_INDEX_NAME, id=self.uuid, routing=self.group_id index=ENTITY_EDGE_INDEX_NAME,
id=self.uuid,
params={'routing': self.group_id},
) )
logger.debug(f'Deleted Edge: {self.uuid}') logger.debug(f'Deleted Edge: {self.uuid}')
@ -273,7 +275,7 @@ class EntityEdge(Edge):
'size': 1, 'size': 1,
}, },
index=ENTITY_EDGE_INDEX_NAME, index=ENTITY_EDGE_INDEX_NAME,
routing=self.group_id, params={'routing': self.group_id},
) )
if resp['hits']['hits']: if resp['hits']['hits']:

View file

@ -119,7 +119,9 @@ class Node(BaseModel, ABC):
# Delete the node from OpenSearch indices # Delete the node from OpenSearch indices
for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME): for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
await driver.aoss_client.delete( await driver.aoss_client.delete(
index=index, id=self.uuid, routing=self.group_id index=index,
id=self.uuid,
params={'routing': self.group_id},
) )
# Bulk delete the detached edges # Bulk delete the detached edges
@ -198,25 +200,25 @@ class Node(BaseModel, ABC):
await driver.aoss_client.delete_by_query( await driver.aoss_client.delete_by_query(
index=EPISODE_INDEX_NAME, index=EPISODE_INDEX_NAME,
body={'query': {'term': {'group_id': group_id}}}, body={'query': {'term': {'group_id': group_id}}},
routing=group_id, params={'routing': group_id},
) )
await driver.aoss_client.delete_by_query( await driver.aoss_client.delete_by_query(
index=ENTITY_INDEX_NAME, index=ENTITY_INDEX_NAME,
body={'query': {'term': {'group_id': group_id}}}, body={'query': {'term': {'group_id': group_id}}},
routing=group_id, params={'routing': group_id},
) )
await driver.aoss_client.delete_by_query( await driver.aoss_client.delete_by_query(
index=COMMUNITY_INDEX_NAME, index=COMMUNITY_INDEX_NAME,
body={'query': {'term': {'group_id': group_id}}}, body={'query': {'term': {'group_id': group_id}}},
routing=group_id, params={'routing': group_id},
) )
await driver.aoss_client.delete_by_query( await driver.aoss_client.delete_by_query(
index=ENTITY_EDGE_INDEX_NAME, index=ENTITY_EDGE_INDEX_NAME,
body={'query': {'term': {'group_id': group_id}}}, body={'query': {'term': {'group_id': group_id}}},
routing=group_id, params={'routing': group_id},
) )
case GraphProvider.KUZU: case GraphProvider.KUZU:
@ -520,7 +522,7 @@ class EntityNode(Node):
'size': 1, 'size': 1,
}, },
index=ENTITY_INDEX_NAME, index=ENTITY_INDEX_NAME,
routing=self.group_id, params={'routing': self.group_id},
) )
if resp['hits']['hits']: if resp['hits']['hits']:

View file

@ -215,11 +215,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values # Match the edge ids and return the values
query = ( query = (
""" """
UNWIND $ids as id UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids WHERE e.group_id IN $group_ids
AND id(e)=id AND id(e)=id
""" """
+ filter_query + filter_query
+ """ + """
AND id(e)=id AND id(e)=id
@ -256,15 +256,16 @@ async def edge_fulltext_search(
filters = build_aoss_edge_filters(group_ids or [], search_filter) filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = await driver.aoss_client.search( res = await driver.aoss_client.search(
index=ENTITY_EDGE_INDEX_NAME, index=ENTITY_EDGE_INDEX_NAME,
routing=route, params={'routing': route},
_source=['uuid'],
body={ body={
'size': limit,
'_source': ['uuid'],
'query': { 'query': {
'bool': { 'bool': {
'filter': filters, 'filter': filters,
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}], 'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
} }
} },
}, },
) )
@ -353,8 +354,8 @@ async def edge_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@ -417,10 +418,10 @@ async def edge_similarity_search(
filters = build_aoss_edge_filters(group_ids or [], search_filter) filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = await driver.aoss_client.search( res = await driver.aoss_client.search(
index=ENTITY_EDGE_INDEX_NAME, index=ENTITY_EDGE_INDEX_NAME,
routing=route, params={'routing': route},
_source=['uuid'],
size=limit,
body={ body={
'size': limit,
'_source': ['uuid'],
'query': { 'query': {
'knn': { 'knn': {
'fact_embedding': { 'fact_embedding': {
@ -429,7 +430,7 @@ async def edge_similarity_search(
'filter': {'bool': {'filter': filters}}, 'filter': {'bool': {'filter': filters}},
} }
} }
} },
}, },
) )
@ -637,11 +638,11 @@ async def node_fulltext_search(
# Match the node ids and return the values # Match the node ids and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE n.uuid=i.id WHERE n.uuid=i.id
RETURN RETURN
""" """
+ get_entity_node_return_query(driver.provider) + get_entity_node_return_query(driver.provider)
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -663,10 +664,10 @@ async def node_fulltext_search(
filters = build_aoss_node_filters(group_ids or [], search_filter) filters = build_aoss_node_filters(group_ids or [], search_filter)
res = await driver.aoss_client.search( res = await driver.aoss_client.search(
index=ENTITY_INDEX_NAME, index=ENTITY_INDEX_NAME,
routing=route, params={'routing': route},
_source=['uuid'],
size=limit,
body={ body={
'_source': ['uuid'],
'size': limit,
'query': { 'query': {
'bool': { 'bool': {
'filter': filters, 'filter': filters,
@ -680,7 +681,7 @@ async def node_fulltext_search(
} }
], ],
} }
} },
}, },
) )
@ -751,8 +752,8 @@ async def node_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Entity) MATCH (n:Entity)
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -781,11 +782,11 @@ async def node_similarity_search(
# Match the node ids and return the values # Match the node ids and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE id(n)=i.id WHERE id(n)=i.id
RETURN RETURN
""" """
+ get_entity_node_return_query(driver.provider) + get_entity_node_return_query(driver.provider)
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -809,9 +810,9 @@ async def node_similarity_search(
res = await driver.aoss_client.search( res = await driver.aoss_client.search(
index=ENTITY_INDEX_NAME, index=ENTITY_INDEX_NAME,
params={'routing': route}, params={'routing': route},
_source=['uuid'],
size=limit,
body={ body={
'size': limit,
'_source': ['uuid'],
'query': { 'query': {
'knn': { 'knn': {
'name_embedding': { 'name_embedding': {
@ -820,7 +821,7 @@ async def node_similarity_search(
'filter': {'bool': {'filter': filters}}, 'filter': {'bool': {'filter': filters}},
} }
} }
} },
}, },
) )
@ -837,8 +838,8 @@ async def node_similarity_search(
else: else:
query = ( query = (
""" """
MATCH (n:Entity) MATCH (n:Entity)
""" """
+ filter_query + filter_query
+ """ + """
WITH n, """ WITH n, """
@ -1013,9 +1014,10 @@ async def episode_fulltext_search(
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
res = await driver.aoss_client.search( res = await driver.aoss_client.search(
EPISODE_INDEX_NAME, EPISODE_INDEX_NAME,
routing=route, params={'routing': route},
_source=['uuid'],
query={ query={
'size': limit,
'_source': ['uuid'],
'bool': { 'bool': {
'filter': {'terms': group_ids}, 'filter': {'terms': group_ids},
'must': [ 'must': [
@ -1027,7 +1029,7 @@ async def episode_fulltext_search(
} }
} }
], ],
} },
}, },
limit=limit, limit=limit,
) )
@ -1170,8 +1172,8 @@ async def community_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Community) MATCH (n:Community)
""" """
+ group_filter_query + group_filter_query
+ """ + """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -1230,8 +1232,8 @@ async def community_similarity_search(
query = ( query = (
""" """
MATCH (c:Community) MATCH (c:Community)
""" """
+ group_filter_query + group_filter_query
+ """ + """
WITH c, WITH c,
@ -1373,9 +1375,9 @@ async def get_relevant_nodes(
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
query = ( query = (
""" """
UNWIND $nodes AS node UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id}) MATCH (n:Entity {group_id: $group_id})
""" """
+ filter_query + filter_query
+ """ + """
WITH node, n, """ WITH node, n, """
@ -1420,9 +1422,9 @@ async def get_relevant_nodes(
else: else:
query = ( query = (
""" """
UNWIND $nodes AS node UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id}) MATCH (n:Entity {group_id: $group_id})
""" """
+ filter_query + filter_query
+ """ + """
WITH node, n, """ WITH node, n, """
@ -1511,9 +1513,9 @@ async def get_relevant_edges(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge WITH e, edge
@ -1583,9 +1585,9 @@ async def get_relevant_edges(
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge, n, m, """ WITH e, edge, n, m, """
@ -1621,9 +1623,9 @@ async def get_relevant_edges(
else: else:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge, """ WITH e, edge, """
@ -1696,10 +1698,10 @@ async def get_edge_invalidation_candidates(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge WITH e, edge
@ -1769,10 +1771,10 @@ async def get_edge_invalidation_candidates(
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
""" """
+ filter_query + filter_query
+ """ + """
WITH edge, e, n, m, """ WITH edge, e, n, m, """
@ -1808,10 +1810,10 @@ async def get_edge_invalidation_candidates(
else: else:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
""" """
+ filter_query + filter_query
+ """ + """
WITH edge, e, """ WITH edge, e, """