cleanup
This commit is contained in:
parent
8e442d4634
commit
b036d45329
4 changed files with 75 additions and 83 deletions
|
|
@ -195,6 +195,9 @@ class GraphDriver(ABC):
|
|||
|
||||
async def create_aoss_indices(self):
|
||||
client = self.aoss_client
|
||||
if not client:
|
||||
logger.warning('No OpenSearch client found')
|
||||
return
|
||||
|
||||
for index in aoss_indices:
|
||||
alias_name = index['index_name']
|
||||
|
|
@ -220,10 +223,20 @@ class GraphDriver(ABC):
|
|||
for index in aoss_indices:
|
||||
index_name = index['index_name']
|
||||
client = self.aoss_client
|
||||
|
||||
if not client:
|
||||
logger.warning('No OpenSearch client found')
|
||||
return
|
||||
|
||||
if client.indices.exists(index=index_name):
|
||||
client.indices.delete(index=index_name)
|
||||
|
||||
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||
client = self.aoss_client
|
||||
if not client:
|
||||
logger.warning('No OpenSearch client found')
|
||||
return 0
|
||||
|
||||
for index in aoss_indices:
|
||||
if name.lower() == index['index_name']:
|
||||
to_index = []
|
||||
|
|
@ -237,7 +250,7 @@ class GraphDriver(ABC):
|
|||
item[p] = d[p]
|
||||
to_index.append(item)
|
||||
|
||||
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
|
||||
success, failed = helpers.bulk(client, to_index, stats_only=True)
|
||||
|
||||
return success if failed == 0 else success
|
||||
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class FalkorDriverSession(GraphDriverSession):
|
||||
provider = GraphProvider.FALKORDB
|
||||
aoss_client: None
|
||||
|
||||
def __init__(self, graph: FalkorGraph):
|
||||
self.graph = graph
|
||||
|
|
|
|||
|
|
@ -92,6 +92,7 @@ SCHEMA_QUERIES = """
|
|||
|
||||
class KuzuDriver(GraphDriver):
|
||||
provider: GraphProvider = GraphProvider.KUZU
|
||||
aoss_client: None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -209,11 +209,11 @@ async def edge_fulltext_search(
|
|||
# Match the edge ids and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as id
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
AND id(e)=id
|
||||
"""
|
||||
UNWIND $ids as id
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
AND id(e)=id
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
AND id(e)=id
|
||||
|
|
@ -343,8 +343,8 @@ async def edge_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
||||
|
|
@ -620,11 +620,11 @@ async def node_fulltext_search(
|
|||
# Match the edge ides and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE n.uuid=i.id
|
||||
RETURN
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE n.uuid=i.id
|
||||
RETURN
|
||||
"""
|
||||
+ get_entity_node_return_query(driver.provider)
|
||||
+ """
|
||||
ORDER BY i.score DESC
|
||||
|
|
@ -731,8 +731,8 @@ async def node_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||
|
|
@ -761,11 +761,11 @@ async def node_similarity_search(
|
|||
# Match the edge ides and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE id(n)=i.id
|
||||
RETURN
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE id(n)=i.id
|
||||
RETURN
|
||||
"""
|
||||
+ get_entity_node_return_query(driver.provider)
|
||||
+ """
|
||||
ORDER BY i.score DESC
|
||||
|
|
@ -810,8 +810,8 @@ async def node_similarity_search(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH n, """
|
||||
|
|
@ -989,7 +989,7 @@ async def episode_fulltext_search(
|
|||
_source=['uuid'],
|
||||
query={
|
||||
'bool': {
|
||||
'filter': [{'terms': {'group_id': group_ids}}],
|
||||
'filter': {'terms': group_ids},
|
||||
'must': [
|
||||
{
|
||||
'multi_match': {
|
||||
|
|
@ -1005,37 +1005,14 @@ async def episode_fulltext_search(
|
|||
)
|
||||
|
||||
if res['hits']['total']['value'] > 0:
|
||||
# Calculate Cosine similarity then return the edge ids
|
||||
input_uuids = []
|
||||
input_uuids = {}
|
||||
for r in res['hits']['hits']:
|
||||
input_uuids.append({'uuid': r['_source']['uuid'], 'score': r['_score']})
|
||||
input_uuids[r['_source']['uuid']] = r['_score']
|
||||
|
||||
# Match the edge ides and return the values
|
||||
query = """
|
||||
UNWIND $uuids as i
|
||||
MATCH (e:Episodic)
|
||||
WHERE e.uuid=i.uuid
|
||||
RETURN
|
||||
e.content AS content,
|
||||
e.created_at AS created_at,
|
||||
e.valid_at AS valid_at,
|
||||
e.uuid AS uuid,
|
||||
e.name AS name,
|
||||
e.group_id AS group_id,
|
||||
e.source_description AS source_description,
|
||||
e.source AS source,
|
||||
e.entity_edges AS entity_edges
|
||||
ORDER BY i.score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
uuids=input_uuids,
|
||||
query=fuzzy_query,
|
||||
limit=limit,
|
||||
routing_='r',
|
||||
**filter_params,
|
||||
)
|
||||
# Get nodes
|
||||
episodes = await EpisodicNode.get_by_uuids(driver, list(input_uuids.keys()))
|
||||
episodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||
return episodes
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
|
|
@ -1165,8 +1142,8 @@ async def community_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Community)
|
||||
"""
|
||||
MATCH (n:Community)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||
|
|
@ -1225,8 +1202,8 @@ async def community_similarity_search(
|
|||
|
||||
query = (
|
||||
"""
|
||||
MATCH (c:Community)
|
||||
"""
|
||||
MATCH (c:Community)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ """
|
||||
WITH c,
|
||||
|
|
@ -1368,9 +1345,9 @@ async def get_relevant_nodes(
|
|||
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
||||
query = (
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH node, n, """
|
||||
|
|
@ -1415,9 +1392,9 @@ async def get_relevant_nodes(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH node, n, """
|
||||
|
|
@ -1506,9 +1483,9 @@ async def get_relevant_edges(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge
|
||||
|
|
@ -1578,9 +1555,9 @@ async def get_relevant_edges(
|
|||
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge, n, m, """
|
||||
|
|
@ -1616,9 +1593,9 @@ async def get_relevant_edges(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge, """
|
||||
|
|
@ -1691,10 +1668,10 @@ async def get_edge_invalidation_candidates(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge
|
||||
|
|
@ -1764,10 +1741,10 @@ async def get_edge_invalidation_candidates(
|
|||
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
||||
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
||||
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH edge, e, n, m, """
|
||||
|
|
@ -1803,10 +1780,10 @@ async def get_edge_invalidation_candidates(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH edge, e, """
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue